Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-16 04:18:35 +00:00)
* Add support for custom LLM implementations
* Fix import sorting and type annotations
* Fix linting issues with import sorting
* Fix type errors in crew.py by updating tool-related methods to return List[BaseTool]
* Enhance custom LLM implementation with better error handling, documentation, and test coverage
* Refactor LLM module by extracting BaseLLM to a separate file: move the BaseLLM abstract base class from llm.py to the new src/crewai/llms/base_llm.py, update imports in __init__.py and llm.py, and update test cases to use the new import path; functionality is unchanged, module structure is improved
* Add AISuite LLM support and update dependencies: integrate AISuite as a new third-party LLM option, add the aisuite package to pyproject.toml and uv.lock, make BaseLLM initialization more flexible, remove unnecessary LLM imports across multiple files, and implement AISuiteLLM with basic chat completion functionality
* Update AISuiteLLM and LLM utility type handling: accept a string or a list of message dictionaries, broaden the LLM type annotations in the utility function, and remove the default `self.stop` attribute from BaseLLM initialization
* Update LLM imports and type hints across multiple files: use LLM instead of BaseLLM in crew_chat.py, update type hints in llm_utils.py, and add an optional `stop` parameter to BaseLLM initialization
* Improve stop words handling in CrewAgentExecutor: merge and deduplicate existing stop words from the LLM configuration and update type hints to support both LLM and BaseLLM
* Remove abstract method set_callbacks from BaseLLM class
* Enhance CustomLLM and JWTAuthLLM initialization with a model parameter: update the CustomLLM, JWTAuthLLM, and TimeoutHandlingLLM constructors and test cases, and update create_llm type hints to support both LLM and BaseLLM
* Enhance the create_llm function to accept both LLM and BaseLLM instances
* Update the return type of initialize_chat_llm to allow both LLM and BaseLLM instances
* Refactor AISuiteLLM to pass an optional tools parameter through _prepare_completion_params and the chat completion method, clean up print statements, and remove unused tool_calls handling
* Refactor the Crew class and LLM hierarchy for improved type handling and code clarity: change LLM to inherit from BaseLLM, streamline tool handling in CrewAgentExecutor, give BaseLLM default implementations for stop words and context window size, clean up AISuiteLLM, and remove the unused `stream` method from BaseLLM

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Joe Moura <joao@crewai.com>
Co-authored-by: Lorenze Jay <lorenzejaytech@gmail.com>
Co-authored-by: João Moura <joaomdmoura@gmail.com>
Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com>
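Taken together, these changes let a caller plug a hand-rolled LLM into crewAI by subclassing BaseLLM and handing the instance to create_llm or an Agent. The following is a minimal, illustrative sketch of that pattern; the class name EchoLLM and its canned response are invented here, and it assumes call() is the only override BaseLLM requires (the commit notes above give stop-word and context-window methods default implementations). The test file below exercises the same pattern in full.

# Minimal sketch of the custom-LLM pattern described above (EchoLLM is a
# hypothetical example class, not part of crewAI).
from crewai import Agent
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.llm_utils import create_llm


class EchoLLM(BaseLLM):
    def __init__(self, model: str = "test-model"):
        super().__init__(model=model)

    def call(self, messages, tools=None, callbacks=None, available_functions=None):
        # A real implementation would forward `messages` to a provider here.
        return "hello from EchoLLM"


llm = create_llm(EchoLLM())  # create_llm returns BaseLLM instances unchanged
agent = Agent(role="Greeter", goal="Say hi", backstory="You just say hi", llm=llm)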
360 lines · 12 KiB · Python
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock

import pytest

from crewai import Agent, Crew, Process, Task
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.llm_utils import create_llm


class CustomLLM(BaseLLM):
    """Custom LLM implementation for testing.

    This is a simple implementation of the BaseLLM abstract base class
    that returns a predefined response for testing purposes.
    """

    def __init__(self, response="Default response", model="test-model"):
        """Initialize the CustomLLM with a predefined response.

        Args:
            response: The predefined response to return from call().
            model: The model identifier passed through to BaseLLM.
        """
        super().__init__(model=model)
        self.response = response
        self.call_count = 0

    def call(
        self,
        messages,
        tools=None,
        callbacks=None,
        available_functions=None,
    ):
        """Mock LLM call that returns a predefined response.

        Properly formats messages to match OpenAI's expected structure.
        """
        self.call_count += 1

        # If input is a string, convert to proper message format
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        # Ensure each message has properly formatted content
        for message in messages:
            if isinstance(message["content"], str):
                message["content"] = [{"type": "text", "text": message["content"]}]

        # Return predefined response in expected format
        if "Thought:" in str(messages):
            return f"Thought: I will say hi\nFinal Answer: {self.response}"
        return self.response

    def supports_function_calling(self) -> bool:
        """Return False to indicate that function calling is not supported.

        Returns:
            False, indicating that this LLM does not support function calling.
        """
        return False

    def supports_stop_words(self) -> bool:
        """Return False to indicate that stop words are not supported.

        Returns:
            False, indicating that this LLM does not support stop words.
        """
        return False

    def get_context_window_size(self) -> int:
        """Return a default context window size.

        Returns:
            4096, a typical context window size for modern LLMs.
        """
        return 4096


@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_llm_implementation():
    """Test that a custom LLM implementation works with create_llm."""
    custom_llm = CustomLLM(response="The answer is 42")

    # Test that create_llm returns the custom LLM instance directly
    result_llm = create_llm(custom_llm)

    assert result_llm is custom_llm

    # Test calling the custom LLM
    response = result_llm.call(
        "What is the answer to life, the universe, and everything?"
    )

    # Verify that the response from the custom LLM was used
    assert "42" in response


@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_llm_within_crew():
    """Test that a custom LLM implementation works within a crew."""
    custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model")

    agent = Agent(
        role="Say Hi",
        goal="Say hi to the user",
        backstory="""You just say hi to the user""",
        llm=custom_llm,
    )

    task = Task(
        description="Say hi to the user",
        expected_output="A greeting to the user",
        agent=agent,
    )

    crew = Crew(
        agents=[agent],
        tasks=[task],
        process=Process.sequential,
    )

    result = crew.kickoff()

    # Assert the LLM was called
    assert custom_llm.call_count > 0
    # Assert we got a response
    assert "Hello!" in result.raw


def test_custom_llm_message_formatting():
    """Test that the custom LLM properly formats messages."""
    custom_llm = CustomLLM(response="Test response", model="test-model")

    # Test with string input
    result = custom_llm.call("Test message")
    assert result == "Test response"

    # Test with message list
    messages = [
        {"role": "system", "content": "System message"},
        {"role": "user", "content": "User message"},
    ]
    result = custom_llm.call(messages)
    assert result == "Test response"


class JWTAuthLLM(BaseLLM):
    """Custom LLM implementation with JWT authentication."""

    def __init__(self, jwt_token: str):
        super().__init__(model="test-model")
        if not jwt_token or not isinstance(jwt_token, str):
            raise ValueError("Invalid JWT token")
        self.jwt_token = jwt_token
        self.calls = []
        self.stop = []

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Record the call and return a predefined response."""
        self.calls.append(
            {
                "messages": messages,
                "tools": tools,
                "callbacks": callbacks,
                "available_functions": available_functions,
            }
        )
        # In a real implementation, this would use the JWT token to authenticate
        # with an external service
        return "Response from JWT-authenticated LLM"

    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported."""
        return True

    def supports_stop_words(self) -> bool:
        """Return True to indicate that stop words are supported."""
        return True

    def get_context_window_size(self) -> int:
        """Return a default context window size."""
        return 8192


def test_custom_llm_with_jwt_auth():
    """Test a custom LLM implementation with JWT authentication."""
    jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")

    # Test that create_llm returns the JWT-authenticated LLM instance directly
    result_llm = create_llm(jwt_llm)

    assert result_llm is jwt_llm

    # Test calling the JWT-authenticated LLM
    response = result_llm.call("Test message")

    # Verify that the JWT-authenticated LLM was called
    assert len(jwt_llm.calls) > 0
    # Verify that the response from the JWT-authenticated LLM was used
    assert response == "Response from JWT-authenticated LLM"


def test_jwt_auth_llm_validation():
    """Test that JWT token validation works correctly."""
    # Test with invalid JWT token (empty string)
    with pytest.raises(ValueError, match="Invalid JWT token"):
        JWTAuthLLM(jwt_token="")

    # Test with invalid JWT token (non-string)
    with pytest.raises(ValueError, match="Invalid JWT token"):
        JWTAuthLLM(jwt_token=None)


class TimeoutHandlingLLM(BaseLLM):
    """Custom LLM implementation with timeout handling and retry logic."""

    def __init__(self, max_retries: int = 3, timeout: int = 30):
        """Initialize the TimeoutHandlingLLM with retry and timeout settings.

        Args:
            max_retries: Maximum number of retry attempts.
            timeout: Timeout in seconds for each API call.
        """
        super().__init__(model="test-model")
        self.max_retries = max_retries
        self.timeout = timeout
        self.calls = []
        self.stop = []
        self.fail_count = 0  # Number of times to simulate failure

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Simulate API calls with timeout handling and retry logic.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            A response string based on whether this is the first attempt or a retry.

        Raises:
            TimeoutError: If all retry attempts fail.
        """
        # Record the initial call
        self.calls.append(
            {
                "messages": messages,
                "tools": tools,
                "callbacks": callbacks,
                "available_functions": available_functions,
                "attempt": 0,
            }
        )

        # Simulate retry logic
        for attempt in range(self.max_retries):
            # Skip the first attempt recording since we already did that above
            if attempt == 0:
                # Simulate a failure if fail_count > 0
                if self.fail_count > 0:
                    self.fail_count -= 1
                    # If we've used all retries, raise an error
                    if attempt == self.max_retries - 1:
                        raise TimeoutError(
                            f"LLM request failed after {self.max_retries} attempts"
                        )
                    # Otherwise, continue to the next attempt (simulating backoff)
                    continue
                else:
                    # Success on first attempt
                    return "First attempt response"
            else:
                # This is a retry attempt (attempt > 0)
                # Always record retry attempts
                self.calls.append(
                    {
                        "retry_attempt": attempt,
                        "messages": messages,
                        "tools": tools,
                        "callbacks": callbacks,
                        "available_functions": available_functions,
                    }
                )

                # Simulate a failure if fail_count > 0
                if self.fail_count > 0:
                    self.fail_count -= 1
                    # If we've used all retries, raise an error
                    if attempt == self.max_retries - 1:
                        raise TimeoutError(
                            f"LLM request failed after {self.max_retries} attempts"
                        )
                    # Otherwise, continue to the next attempt (simulating backoff)
                    continue
                else:
                    # Success on retry
                    return "Response after retry"

    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported.

        Returns:
            True, indicating that this LLM supports function calling.
        """
        return True

    def supports_stop_words(self) -> bool:
        """Return True to indicate that stop words are supported.

        Returns:
            True, indicating that this LLM supports stop words.
        """
        return True

    def get_context_window_size(self) -> int:
        """Return a default context window size.

        Returns:
            8192, a typical context window size for modern LLMs.
        """
        return 8192


def test_timeout_handling_llm():
    """Test a custom LLM implementation with timeout handling and retry logic."""
    # Test successful first attempt
    llm = TimeoutHandlingLLM()
    response = llm.call("Test message")
    assert response == "First attempt response"
    assert len(llm.calls) == 1

    # Test successful retry
    llm = TimeoutHandlingLLM()
    llm.fail_count = 1  # Fail once, then succeed
    response = llm.call("Test message")
    assert response == "Response after retry"
    assert len(llm.calls) == 2  # Initial call + successful retry call

    # Test failure after all retries
    llm = TimeoutHandlingLLM(max_retries=2)
    llm.fail_count = 2  # Fail twice, which is all retries
    with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
        llm.call("Test message")
    assert len(llm.calls) == 2  # Initial call + failed retry attempt