Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-07 07:08:31 +00:00
Enhance CustomLLM and JWTAuthLLM initialization with model parameter
- Update CustomLLM to accept a model parameter during initialization
- Modify test cases to include the new model argument
- Ensure JWTAuthLLM and TimeoutHandlingLLM also utilize the model parameter in their constructors
- Update type hints in create_llm function to support both LLM and BaseLLM types
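For context, a minimal sketch of a BaseLLM subclass following the pattern this commit introduces: the model identifier is accepted in __init__ and forwarded to super().__init__(model=model). The class name EchoLLM and the exact call() signature are illustrative assumptions, not part of the crewAI API.

from crewai.llm import BaseLLM

class EchoLLM(BaseLLM):
    """Hypothetical test double that always returns a canned response."""

    def __init__(self, response="Default response", model="test-model"):
        # Forward the model name to BaseLLM, as the updated constructors in this diff do.
        super().__init__(model=model)
        self.response = response
        self.call_count = 0

    def call(self, messages, **kwargs):
        # Signature assumed from the tests' custom_llm.call("Test message") usage.
        self.call_count += 1
        return self.response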
@@ -7,7 +7,7 @@ from crewai.llm import LLM, BaseLLM
 def create_llm(
     llm_value: Union[str, LLM, Any, None] = None,
-) -> Optional[LLM]:
+) -> Optional[LLM | BaseLLM]:
     """
     Creates or returns an LLM instance based on the given llm_value.
@@ -15,13 +15,13 @@ class CustomLLM(BaseLLM):
     that returns a predefined response for testing purposes.
     """

-    def __init__(self, response="Default response"):
+    def __init__(self, response="Default response", model="test-model"):
         """Initialize the CustomLLM with a predefined response.

         Args:
             response: The predefined response to return from call().
         """
-        super().__init__()
+        super().__init__(model=model)
         self.response = response
         self.call_count = 0
@@ -99,7 +99,7 @@ def test_custom_llm_implementation():
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_custom_llm_within_crew():
     """Test that a custom LLM implementation works with create_llm."""
-    custom_llm = CustomLLM(response="Hello! Nice to meet you!")
+    custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model")

     agent = Agent(
         role="Say Hi",
@@ -130,7 +130,7 @@ def test_custom_llm_within_crew():

 def test_custom_llm_message_formatting():
     """Test that the custom LLM properly formats messages"""
-    custom_llm = CustomLLM(response="Test response")
+    custom_llm = CustomLLM(response="Test response", model="test-model")

     # Test with string input
     result = custom_llm.call("Test message")
@@ -149,7 +149,7 @@ class JWTAuthLLM(BaseLLM):
     """Custom LLM implementation with JWT authentication."""

     def __init__(self, jwt_token: str):
-        super().__init__()
+        super().__init__(model="test-model")
         if not jwt_token or not isinstance(jwt_token, str):
             raise ValueError("Invalid JWT token")
         self.jwt_token = jwt_token
@@ -228,7 +228,7 @@ class TimeoutHandlingLLM(BaseLLM):
             max_retries: Maximum number of retry attempts.
             timeout: Timeout in seconds for each API call.
         """
-        super().__init__()
+        super().__init__(model="test-model")
         self.max_retries = max_retries
         self.timeout = timeout
         self.calls = []
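With create_llm now typed to return Optional[LLM | BaseLLM], a custom BaseLLM instance can be passed through and used directly; a hedged usage sketch (EchoLLM is the hypothetical helper above, and the import path for create_llm is not shown in this diff):

custom = EchoLLM(response="Hello! Nice to meet you!", model="test-model")
llm = create_llm(custom)            # expected to hand back the same BaseLLM instance
assert isinstance(llm, BaseLLM)
print(llm.call("Test message"))     # -> "Hello! Nice to meet you!"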