diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index 3c6f1eb75..2242d305b 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -7,7 +7,7 @@ from crewai.llm import LLM, BaseLLM def create_llm( llm_value: Union[str, LLM, Any, None] = None, -) -> Optional[LLM]: +) -> Optional[LLM | BaseLLM]: """ Creates or returns an LLM instance based on the given llm_value. diff --git a/tests/custom_llm_test.py b/tests/custom_llm_test.py index 38716cc65..6bee5b31d 100644 --- a/tests/custom_llm_test.py +++ b/tests/custom_llm_test.py @@ -15,13 +15,13 @@ class CustomLLM(BaseLLM): that returns a predefined response for testing purposes. """ - def __init__(self, response="Default response"): + def __init__(self, response="Default response", model="test-model"): """Initialize the CustomLLM with a predefined response. Args: response: The predefined response to return from call(). """ - super().__init__() + super().__init__(model=model) self.response = response self.call_count = 0 @@ -99,7 +99,7 @@ def test_custom_llm_implementation(): @pytest.mark.vcr(filter_headers=["authorization"]) def test_custom_llm_within_crew(): """Test that a custom LLM implementation works with create_llm.""" - custom_llm = CustomLLM(response="Hello! Nice to meet you!") + custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model") agent = Agent( role="Say Hi", @@ -130,7 +130,7 @@ def test_custom_llm_within_crew(): def test_custom_llm_message_formatting(): """Test that the custom LLM properly formats messages""" - custom_llm = CustomLLM(response="Test response") + custom_llm = CustomLLM(response="Test response", model="test-model") # Test with string input result = custom_llm.call("Test message") @@ -149,7 +149,7 @@ class JWTAuthLLM(BaseLLM): """Custom LLM implementation with JWT authentication.""" def __init__(self, jwt_token: str): - super().__init__() + super().__init__(model="test-model") if not jwt_token or not isinstance(jwt_token, str): raise ValueError("Invalid JWT token") self.jwt_token = jwt_token @@ -228,7 +228,7 @@ class TimeoutHandlingLLM(BaseLLM): max_retries: Maximum number of retry attempts. timeout: Timeout in seconds for each API call. """ - super().__init__() + super().__init__(model="test-model") self.max_retries = max_retries self.timeout = timeout self.calls = []