Enhance CustomLLM and JWTAuthLLM initialization with model parameter

- Update CustomLLM to accept a model parameter during initialization
- Modify test cases to include the new model argument
- Ensure JWTAuthLLM and TimeoutHandlingLLM also utilize the model parameter in their constructors
- Update type hints in create_llm function to support both LLM and BaseLLM types
This commit is contained in:
Lorenze Jay
2025-03-12 08:16:59 -07:00
parent b305ef8f48
commit 902c330113
2 changed files with 7 additions and 7 deletions

View File

@@ -7,7 +7,7 @@ from crewai.llm import LLM, BaseLLM

 def create_llm(
     llm_value: Union[str, LLM, Any, None] = None,
-) -> Optional[LLM]:
+) -> Optional[LLM | BaseLLM]:
     """
     Creates or returns an LLM instance based on the given llm_value.

View File

@@ -15,13 +15,13 @@ class CustomLLM(BaseLLM):
     that returns a predefined response for testing purposes.
     """

-    def __init__(self, response="Default response"):
+    def __init__(self, response="Default response", model="test-model"):
         """Initialize the CustomLLM with a predefined response.

         Args:
             response: The predefined response to return from call().
         """
-        super().__init__()
+        super().__init__(model=model)
         self.response = response
         self.call_count = 0
@@ -99,7 +99,7 @@ def test_custom_llm_implementation():
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_custom_llm_within_crew():
     """Test that a custom LLM implementation works with create_llm."""
-    custom_llm = CustomLLM(response="Hello! Nice to meet you!")
+    custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model")

     agent = Agent(
         role="Say Hi",
@@ -130,7 +130,7 @@ def test_custom_llm_within_crew():
 def test_custom_llm_message_formatting():
     """Test that the custom LLM properly formats messages"""
-    custom_llm = CustomLLM(response="Test response")
+    custom_llm = CustomLLM(response="Test response", model="test-model")

     # Test with string input
     result = custom_llm.call("Test message")
@@ -149,7 +149,7 @@ class JWTAuthLLM(BaseLLM):
     """Custom LLM implementation with JWT authentication."""

     def __init__(self, jwt_token: str):
-        super().__init__()
+        super().__init__(model="test-model")
         if not jwt_token or not isinstance(jwt_token, str):
             raise ValueError("Invalid JWT token")
         self.jwt_token = jwt_token
@@ -228,7 +228,7 @@ class TimeoutHandlingLLM(BaseLLM):
             max_retries: Maximum number of retry attempts.
             timeout: Timeout in seconds for each API call.
         """
-        super().__init__()
+        super().__init__(model="test-model")
         self.max_retries = max_retries
         self.timeout = timeout
         self.calls = []