Files
crewAI/lib/crewai/tests/test_custom_llm.py
Greyson LaLonde 20704742e2
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
feat: async llm support
feat: introduce async contract to BaseLLM

feat: add async call support for:

Azure provider

Anthropic provider

OpenAI provider

Gemini provider

Bedrock provider

LiteLLM provider

chore: expand scrubbed header fields (conftest, anthropic, bedrock)

chore: update docs to cover async functionality

chore: update and harden tests to support acall; re-add uri for cassette compatibility

chore: generate missing cassette

fix: ensure acall is non-abstract and set supports_tools = true for supported Anthropic models

chore: improve Bedrock async docstring and general test robustness
2025-12-01 18:56:56 -05:00

373 lines
12 KiB
Python

from typing import Any, Dict, List, Optional, Union
import pytest
from crewai import Agent, Crew, Process, Task
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.llm_utils import create_llm
class CustomLLM(BaseLLM):
"""Custom LLM implementation for testing.
This is a simple implementation of the BaseLLM abstract base class
that returns a predefined response for testing purposes.
"""
def __init__(self, response="Default response", model="test-model"):
"""Initialize the CustomLLM with a predefined response.
Args:
response: The predefined response to return from call().
"""
super().__init__(model=model)
self.response = response
self.call_count = 0
def call(
self,
messages,
tools=None,
callbacks=None,
available_functions=None,
from_task=None,
from_agent=None,
response_model=None,
):
"""
Mock LLM call that returns a predefined response.
Properly formats messages to match OpenAI's expected structure.
"""
self.call_count += 1
# If input is a string, convert to proper message format
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
# Ensure each message has properly formatted content
for message in messages:
if isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
# Return predefined response in expected format
if "Thought:" in str(messages):
return f"Thought: I will say hi\nFinal Answer: {self.response}"
return self.response
def supports_function_calling(self) -> bool:
"""Return False to indicate that function calling is not supported.
Returns:
False, indicating that this LLM does not support function calling.
"""
return False
def supports_stop_words(self) -> bool:
"""Return False to indicate that stop words are not supported.
Returns:
False, indicating that this LLM does not support stop words.
"""
return False
def get_context_window_size(self) -> int:
"""Return a default context window size.
Returns:
4096, a typical context window size for modern LLMs.
"""
return 4096
async def acall(self, messages, tools=None, callbacks=None, available_functions=None, from_task=None, from_agent=None, response_model=None):
raise NotImplementedError
@pytest.mark.vcr()
def test_custom_llm_implementation():
"""Test that a custom LLM implementation works with create_llm."""
custom_llm = CustomLLM(response="The answer is 42")
# Test that create_llm returns the custom LLM instance directly
result_llm = create_llm(custom_llm)
assert result_llm is custom_llm
# Test calling the custom LLM
response = result_llm.call(
"What is the answer to life, the universe, and everything?"
)
# Verify that the response from the custom LLM was used
assert "42" in response
@pytest.mark.vcr()
def test_custom_llm_within_crew():
"""Test that a custom LLM implementation works with create_llm."""
custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model")
agent = Agent(
role="Say Hi",
goal="Say hi to the user",
backstory="""You just say hi to the user""",
llm=custom_llm,
)
task = Task(
description="Say hi to the user",
expected_output="A greeting to the user",
agent=agent,
)
crew = Crew(
agents=[agent],
tasks=[task],
process=Process.sequential,
)
result = crew.kickoff()
# Assert the LLM was called
assert custom_llm.call_count > 0
# Assert we got a response
assert "Hello!" in result.raw
def test_custom_llm_message_formatting():
"""Test that the custom LLM properly formats messages"""
custom_llm = CustomLLM(response="Test response", model="test-model")
# Test with string input
result = custom_llm.call("Test message")
assert result == "Test response"
# Test with message list
messages = [
{"role": "system", "content": "System message"},
{"role": "user", "content": "User message"},
]
result = custom_llm.call(messages)
assert result == "Test response"
class JWTAuthLLM(BaseLLM):
"""Custom LLM implementation with JWT authentication."""
def __init__(self, jwt_token: str):
super().__init__(model="test-model")
if not jwt_token or not isinstance(jwt_token, str):
raise ValueError("Invalid JWT token")
self.jwt_token = jwt_token
self.calls = []
self.stop = []
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task=None,
from_agent=None,
response_model=None,
) -> Union[str, Any]:
"""Record the call and return a predefined response."""
self.calls.append(
{
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
)
# In a real implementation, this would use the JWT token to authenticate
# with an external service
return "Response from JWT-authenticated LLM"
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported."""
return True
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported."""
return True
def get_context_window_size(self) -> int:
"""Return a default context window size."""
return 8192
async def acall(self, messages, tools=None, callbacks=None, available_functions=None, from_task=None, from_agent=None, response_model=None):
raise NotImplementedError
def test_custom_llm_with_jwt_auth():
"""Test a custom LLM implementation with JWT authentication."""
jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
# Test that create_llm returns the JWT-authenticated LLM instance directly
result_llm = create_llm(jwt_llm)
assert result_llm is jwt_llm
# Test calling the JWT-authenticated LLM
response = result_llm.call("Test message")
# Verify that the JWT-authenticated LLM was called
assert len(jwt_llm.calls) > 0
# Verify that the response from the JWT-authenticated LLM was used
assert response == "Response from JWT-authenticated LLM"
def test_jwt_auth_llm_validation():
"""Test that JWT token validation works correctly."""
# Test with invalid JWT token (empty string)
with pytest.raises(ValueError, match="Invalid JWT token"):
JWTAuthLLM(jwt_token="")
# Test with invalid JWT token (non-string)
with pytest.raises(ValueError, match="Invalid JWT token"):
JWTAuthLLM(jwt_token=None)
class TimeoutHandlingLLM(BaseLLM):
"""Custom LLM implementation with timeout handling and retry logic."""
def __init__(self, max_retries: int = 3, timeout: int = 30):
"""Initialize the TimeoutHandlingLLM with retry and timeout settings.
Args:
max_retries: Maximum number of retry attempts.
timeout: Timeout in seconds for each API call.
"""
super().__init__(model="test-model")
self.max_retries = max_retries
self.timeout = timeout
self.calls = []
self.stop = []
self.fail_count = 0 # Number of times to simulate failure
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task=None,
from_agent=None,
response_model=None,
) -> Union[str, Any]:
"""Simulate API calls with timeout handling and retry logic.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
A response string based on whether this is the first attempt or a retry.
Raises:
TimeoutError: If all retry attempts fail.
"""
# Record the initial call
self.calls.append(
{
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
"attempt": 0,
}
)
# Simulate retry logic
for attempt in range(self.max_retries):
# Skip the first attempt recording since we already did that above
if attempt == 0:
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
# Success on first attempt
return "First attempt response"
# This is a retry attempt (attempt > 0)
# Always record retry attempts
self.calls.append(
{
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
)
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
# Success on retry
return "Response after retry"
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported.
Returns:
True, indicating that this LLM supports function calling.
"""
return True
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported.
Returns:
True, indicating that this LLM supports stop words.
"""
return True
def get_context_window_size(self) -> int:
"""Return a default context window size.
Returns:
8192, a typical context window size for modern LLMs.
"""
return 8192
async def acall(self, messages, tools=None, callbacks=None, available_functions=None, from_task=None, from_agent=None, response_model=None):
raise NotImplementedError
def test_timeout_handling_llm():
"""Test a custom LLM implementation with timeout handling and retry logic."""
# Test successful first attempt
llm = TimeoutHandlingLLM()
response = llm.call("Test message")
assert response == "First attempt response"
assert len(llm.calls) == 1
# Test successful retry
llm = TimeoutHandlingLLM()
llm.fail_count = 1 # Fail once, then succeed
response = llm.call("Test message")
assert response == "Response after retry"
assert len(llm.calls) == 2 # Initial call + successful retry call
# Test failure after all retries
llm = TimeoutHandlingLLM(max_retries=2)
llm.fail_count = 2 # Fail twice, which is all retries
with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
llm.call("Test message")
assert len(llm.calls) == 2 # Initial call + failed retry attempt