Enhance custom LLM implementation with better error handling, documentation, and test coverage

Co-Authored-By: Joe Moura <joao@crewai.com>
Devin AI
2025-03-04 17:50:52 +00:00
parent 22aeeaadbe
commit 963ed23b63
4 changed files with 686 additions and 94 deletions

@@ -7,9 +7,19 @@ from crewai.utilities.llm_utils import create_llm
 
 
 class CustomLLM(BaseLLM):
-    """Custom LLM implementation for testing."""
+    """Custom LLM implementation for testing.
+
+    This is a simple implementation of the BaseLLM abstract base class
+    that returns a predefined response for testing purposes.
+    """
+
     def __init__(self, response: str = "Custom LLM response"):
+        """Initialize the CustomLLM with a predefined response.
+
+        Args:
+            response: The predefined response to return from call().
+        """
         super().__init__()
         self.response = response
         self.calls = []
         self.stop = []
@@ -21,7 +31,17 @@ class CustomLLM(BaseLLM):
         callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
     ) -> Union[str, Any]:
-        """Record the call and return the predefined response."""
+        """Record the call and return the predefined response.
+
+        Args:
+            messages: Input messages for the LLM.
+            tools: Optional list of tool schemas for function calling.
+            callbacks: Optional list of callback functions.
+            available_functions: Optional dict mapping function names to callables.
+
+        Returns:
+            The predefined response string.
+        """
         self.calls.append({
             "messages": messages,
             "tools": tools,
@@ -31,15 +51,27 @@ class CustomLLM(BaseLLM):
         return self.response
 
     def supports_function_calling(self) -> bool:
-        """Return True to indicate that function calling is supported."""
+        """Return True to indicate that function calling is supported.
+
+        Returns:
+            True, indicating that this LLM supports function calling.
+        """
         return True
 
     def supports_stop_words(self) -> bool:
-        """Return True to indicate that stop words are supported."""
+        """Return True to indicate that stop words are supported.
+
+        Returns:
+            True, indicating that this LLM supports stop words.
+        """
         return True
 
     def get_context_window_size(self) -> int:
-        """Return a default context window size."""
+        """Return a default context window size.
+
+        Returns:
+            8192, a typical context window size for modern LLMs.
+        """
         return 8192
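Note: taken together, these methods make CustomLLM a drop-in stub wherever a BaseLLM is accepted: it records each call and echoes its canned response. A minimal usage sketch based only on the code shown above:

llm = CustomLLM(response="predefined")
assert llm.call("Hello") == "predefined"    # returns the canned response
assert llm.calls[0]["messages"] == "Hello"  # and records the call for inspection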
@@ -119,3 +151,147 @@ def test_custom_llm_with_jwt_auth():
     assert len(jwt_llm.calls) > 0
     # Verify that the response from the JWT-authenticated LLM was used
     assert response == "Response from JWT-authenticated LLM"
+
+
+def test_jwt_auth_llm_validation():
+    """Test that JWT token validation works correctly."""
+    # Test with invalid JWT token (empty string)
+    with pytest.raises(ValueError, match="Invalid JWT token"):
+        JWTAuthLLM(jwt_token="")
+
+    # Test with invalid JWT token (non-string)
+    with pytest.raises(ValueError, match="Invalid JWT token"):
+        JWTAuthLLM(jwt_token=None)
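Note: JWTAuthLLM itself is defined earlier in this test file and does not appear in this hunk. A minimal sketch consistent with the assertions above (the class name, the jwt_token parameter, the "Invalid JWT token" message, and the response string come from the tests; everything else is assumed):

class JWTAuthLLM(BaseLLM):
    """Hypothetical sketch; only the behavior exercised by the tests is grounded."""

    def __init__(self, jwt_token: str):
        # The validation tests expect ValueError for empty or non-string tokens
        if not isinstance(jwt_token, str) or not jwt_token:
            raise ValueError("Invalid JWT token")
        super().__init__()
        self.jwt_token = jwt_token
        self.calls = []

    def call(self, messages, tools=None, callbacks=None, available_functions=None):
        # Record the call so that len(jwt_llm.calls) > 0 holds in the test above
        self.calls.append({"messages": messages, "tools": tools})
        return "Response from JWT-authenticated LLM"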
+
+
+class TimeoutHandlingLLM(BaseLLM):
+    """Custom LLM implementation with timeout handling and retry logic."""
+
+    def __init__(self, max_retries: int = 3, timeout: int = 30):
+        """Initialize the TimeoutHandlingLLM with retry and timeout settings.
+
+        Args:
+            max_retries: Maximum number of retry attempts.
+            timeout: Timeout in seconds for each API call.
+        """
+        super().__init__()
+        self.max_retries = max_retries
+        self.timeout = timeout
+        self.calls = []
+        self.stop = []
+        self.fail_count = 0  # Number of times to simulate failure
+
+    def call(
+        self,
+        messages: Union[str, List[Dict[str, str]]],
+        tools: Optional[List[dict]] = None,
+        callbacks: Optional[List[Any]] = None,
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> Union[str, Any]:
+        """Simulate API calls with timeout handling and retry logic.
+
+        Args:
+            messages: Input messages for the LLM.
+            tools: Optional list of tool schemas for function calling.
+            callbacks: Optional list of callback functions.
+            available_functions: Optional dict mapping function names to callables.
+
+        Returns:
+            A response string based on whether this is the first attempt or a retry.
+
+        Raises:
+            TimeoutError: If all retry attempts fail.
+        """
+        # Record the initial call
+        self.calls.append({
+            "messages": messages,
+            "tools": tools,
+            "callbacks": callbacks,
+            "available_functions": available_functions,
+            "attempt": 0
+        })
+
+        # Simulate retry logic
+        for attempt in range(self.max_retries):
+            # Skip the first attempt recording since we already did that above
+            if attempt == 0:
+                # Simulate a failure if fail_count > 0
+                if self.fail_count > 0:
+                    self.fail_count -= 1
+                    # If we've used all retries, raise an error
+                    if attempt == self.max_retries - 1:
+                        raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
+                    # Otherwise, continue to the next attempt (simulating backoff)
+                    continue
+                else:
+                    # Success on first attempt
+                    return "First attempt response"
+            else:
+                # This is a retry attempt (attempt > 0)
+                # Always record retry attempts
+                self.calls.append({
+                    "retry_attempt": attempt,
+                    "messages": messages,
+                    "tools": tools,
+                    "callbacks": callbacks,
+                    "available_functions": available_functions
+                })
+
+                # Simulate a failure if fail_count > 0
+                if self.fail_count > 0:
+                    self.fail_count -= 1
+                    # If we've used all retries, raise an error
+                    if attempt == self.max_retries - 1:
+                        raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
+                    # Otherwise, continue to the next attempt (simulating backoff)
+                    continue
+                else:
+                    # Success on retry
+                    return "Response after retry"
+
+    def supports_function_calling(self) -> bool:
+        """Return True to indicate that function calling is supported.
+
+        Returns:
+            True, indicating that this LLM supports function calling.
+        """
+        return True
+
+    def supports_stop_words(self) -> bool:
+        """Return True to indicate that stop words are supported.
+
+        Returns:
+            True, indicating that this LLM supports stop words.
+        """
+        return True
+
+    def get_context_window_size(self) -> int:
+        """Return a default context window size.
+
+        Returns:
+            8192, a typical context window size for modern LLMs.
+        """
+        return 8192
+
+
+def test_timeout_handling_llm():
+    """Test a custom LLM implementation with timeout handling and retry logic."""
+    # Test successful first attempt
+    llm = TimeoutHandlingLLM()
+    response = llm.call("Test message")
+    assert response == "First attempt response"
+    assert len(llm.calls) == 1
+
+    # Test successful retry
+    llm = TimeoutHandlingLLM()
+    llm.fail_count = 1  # Fail once, then succeed
+    response = llm.call("Test message")
+    assert response == "Response after retry"
+    assert len(llm.calls) == 2  # Initial call + successful retry call
+
+    # Test failure after all retries
+    llm = TimeoutHandlingLLM(max_retries=2)
+    llm.fail_count = 2  # Fail twice, which is all retries
+    with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
+        llm.call("Test message")
+    assert len(llm.calls) == 2  # Initial call + failed retry attempt
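Note: the retry branch above only simulates backoff by continuing to the next loop iteration. In a real client, the retry would wait before re-issuing the request; a minimal sketch of exponential backoff (the call_with_backoff helper, the delay schedule, and the use of time.sleep are illustrative assumptions, not part of this commit):

import time
from typing import Callable

def call_with_backoff(llm_request: Callable[[], str], max_retries: int = 3, base_delay: float = 1.0) -> str:
    """Retry llm_request with exponentially growing delays between attempts."""
    for attempt in range(max_retries):
        try:
            return llm_request()
        except TimeoutError:
            if attempt == max_retries - 1:
                raise  # Out of retries: propagate the timeout to the caller
            time.sleep(base_delay * (2 ** attempt))  # Wait 1s, 2s, 4s, ...

Usage sketch: call_with_backoff(lambda: TimeoutHandlingLLM().call("Test message")).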