Enhance CustomLLM and JWTAuthLLM initialization with model parameter

- Update CustomLLM to accept a model parameter during initialization
- Modify test cases to include the new model argument
- Ensure JWTAuthLLM and TimeoutHandlingLLM also utilize the model parameter in their constructors
- Update type hints in create_llm function to support both LLM and BaseLLM types
This commit is contained in:
Lorenze Jay
2025-03-12 08:16:59 -07:00
parent b305ef8f48
commit 902c330113
2 changed files with 7 additions and 7 deletions

View File

@@ -7,7 +7,7 @@ from crewai.llm import LLM, BaseLLM
def create_llm(
llm_value: Union[str, LLM, Any, None] = None,
) -> Optional[LLM]:
) -> Optional[LLM | BaseLLM]:
"""
Creates or returns an LLM instance based on the given llm_value.

View File

@@ -15,13 +15,13 @@ class CustomLLM(BaseLLM):
that returns a predefined response for testing purposes.
"""
def __init__(self, response="Default response"):
def __init__(self, response="Default response", model="test-model"):
"""Initialize the CustomLLM with a predefined response.
Args:
response: The predefined response to return from call().
"""
super().__init__()
super().__init__(model=model)
self.response = response
self.call_count = 0
@@ -99,7 +99,7 @@ def test_custom_llm_implementation():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_llm_within_crew():
"""Test that a custom LLM implementation works with create_llm."""
custom_llm = CustomLLM(response="Hello! Nice to meet you!")
custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model")
agent = Agent(
role="Say Hi",
@@ -130,7 +130,7 @@ def test_custom_llm_within_crew():
def test_custom_llm_message_formatting():
"""Test that the custom LLM properly formats messages"""
custom_llm = CustomLLM(response="Test response")
custom_llm = CustomLLM(response="Test response", model="test-model")
# Test with string input
result = custom_llm.call("Test message")
@@ -149,7 +149,7 @@ class JWTAuthLLM(BaseLLM):
"""Custom LLM implementation with JWT authentication."""
def __init__(self, jwt_token: str):
super().__init__()
super().__init__(model="test-model")
if not jwt_token or not isinstance(jwt_token, str):
raise ValueError("Invalid JWT token")
self.jwt_token = jwt_token
@@ -228,7 +228,7 @@ class TimeoutHandlingLLM(BaseLLM):
max_retries: Maximum number of retry attempts.
timeout: Timeout in seconds for each API call.
"""
super().__init__()
super().__init__(model="test-model")
self.max_retries = max_retries
self.timeout = timeout
self.calls = []