mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
Enhance CustomLLM and JWTAuthLLM initialization with model parameter
- Update CustomLLM to accept a model parameter during initialization - Modify test cases to include the new model argument - Ensure JWTAuthLLM and TimeoutHandlingLLM also utilize the model parameter in their constructors - Update type hints in create_llm function to support both LLM and BaseLLM types
This commit is contained in:
@@ -7,7 +7,7 @@ from crewai.llm import LLM, BaseLLM
|
|||||||
|
|
||||||
def create_llm(
|
def create_llm(
|
||||||
llm_value: Union[str, LLM, Any, None] = None,
|
llm_value: Union[str, LLM, Any, None] = None,
|
||||||
) -> Optional[LLM]:
|
) -> Optional[LLM | BaseLLM]:
|
||||||
"""
|
"""
|
||||||
Creates or returns an LLM instance based on the given llm_value.
|
Creates or returns an LLM instance based on the given llm_value.
|
||||||
|
|
||||||
|
|||||||
@@ -15,13 +15,13 @@ class CustomLLM(BaseLLM):
|
|||||||
that returns a predefined response for testing purposes.
|
that returns a predefined response for testing purposes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, response="Default response"):
|
def __init__(self, response="Default response", model="test-model"):
|
||||||
"""Initialize the CustomLLM with a predefined response.
|
"""Initialize the CustomLLM with a predefined response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response: The predefined response to return from call().
|
response: The predefined response to return from call().
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__(model=model)
|
||||||
self.response = response
|
self.response = response
|
||||||
self.call_count = 0
|
self.call_count = 0
|
||||||
|
|
||||||
@@ -99,7 +99,7 @@ def test_custom_llm_implementation():
|
|||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_custom_llm_within_crew():
|
def test_custom_llm_within_crew():
|
||||||
"""Test that a custom LLM implementation works with create_llm."""
|
"""Test that a custom LLM implementation works with create_llm."""
|
||||||
custom_llm = CustomLLM(response="Hello! Nice to meet you!")
|
custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model")
|
||||||
|
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
role="Say Hi",
|
role="Say Hi",
|
||||||
@@ -130,7 +130,7 @@ def test_custom_llm_within_crew():
|
|||||||
|
|
||||||
def test_custom_llm_message_formatting():
|
def test_custom_llm_message_formatting():
|
||||||
"""Test that the custom LLM properly formats messages"""
|
"""Test that the custom LLM properly formats messages"""
|
||||||
custom_llm = CustomLLM(response="Test response")
|
custom_llm = CustomLLM(response="Test response", model="test-model")
|
||||||
|
|
||||||
# Test with string input
|
# Test with string input
|
||||||
result = custom_llm.call("Test message")
|
result = custom_llm.call("Test message")
|
||||||
@@ -149,7 +149,7 @@ class JWTAuthLLM(BaseLLM):
|
|||||||
"""Custom LLM implementation with JWT authentication."""
|
"""Custom LLM implementation with JWT authentication."""
|
||||||
|
|
||||||
def __init__(self, jwt_token: str):
|
def __init__(self, jwt_token: str):
|
||||||
super().__init__()
|
super().__init__(model="test-model")
|
||||||
if not jwt_token or not isinstance(jwt_token, str):
|
if not jwt_token or not isinstance(jwt_token, str):
|
||||||
raise ValueError("Invalid JWT token")
|
raise ValueError("Invalid JWT token")
|
||||||
self.jwt_token = jwt_token
|
self.jwt_token = jwt_token
|
||||||
@@ -228,7 +228,7 @@ class TimeoutHandlingLLM(BaseLLM):
|
|||||||
max_retries: Maximum number of retry attempts.
|
max_retries: Maximum number of retry attempts.
|
||||||
timeout: Timeout in seconds for each API call.
|
timeout: Timeout in seconds for each API call.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__(model="test-model")
|
||||||
self.max_retries = max_retries
|
self.max_retries = max_retries
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.calls = []
|
self.calls = []
|
||||||
|
|||||||
Reference in New Issue
Block a user