Compare commits

...

26 Commits

Author SHA1 Message Date
Lorenze Jay
02d262edf3 Merge branch 'devin/1741108142-custom-llm-support' of github.com:crewAIInc/crewAI into devin/1741108142-custom-llm-support 2025-03-25 10:08:13 -07:00
Lorenze Jay
4778117344 Implement dynamic import and installation prompt for 'aisuite' package in AISuiteLLM class to ensure required dependencies are met at runtime. 2025-03-25 10:07:31 -07:00
Brandon Hancock (bhancock_ai)
ac995cab55 Merge branch 'main' into devin/1741108142-custom-llm-support 2025-03-25 12:36:54 -04:00
Lorenze Jay
773da3f994 Remove unused stream method from BaseLLM class to enhance code clarity and maintainability. 2025-03-24 14:07:10 -07:00
Lorenze Jay
ae7f7468d7 Merge branch 'main' of github.com:crewAIInc/crewAI into devin/1741108142-custom-llm-support 2025-03-24 14:02:49 -07:00
Lorenze Jay
782ae12694 Merge branch 'devin/1741108142-custom-llm-support' of github.com:crewAIInc/crewAI into devin/1741108142-custom-llm-support 2025-03-24 13:56:27 -07:00
Lorenze Jay
e659c352df Refactor Crew class and LLM hierarchy for improved type handling and code clarity
- Update Crew class methods to enhance readability with consistent formatting and type hints.
- Change LLM class to inherit from BaseLLM for better structure.
- Remove unnecessary type checks and streamline tool handling in CrewAgentExecutor.
- Adjust BaseLLM to provide default implementations for stop words and context window size methods.
- Clean up AISuiteLLM by removing unused methods related to stop words and context window size.
2025-03-24 13:56:23 -07:00
João Moura
9fc7d2de0a Merge branch 'main' into devin/1741108142-custom-llm-support 2025-03-14 03:11:33 -03:00
Lorenze Jay
7cae76a631 Remove unused tool_calls handling in AISuiteLLM chat completion method for cleaner code. 2025-03-13 11:59:20 -07:00
Lorenze Jay
21cbbd790c Merge branch 'main' of github.com:crewAIInc/crewAI into devin/1741108142-custom-llm-support 2025-03-13 11:56:08 -07:00
Lorenze Jay
144b96986e Refactor AISuiteLLM to include tools parameter in completion methods
- Update the _prepare_completion_params method to accept an optional tools parameter
- Modify the chat completion method to utilize the new tools parameter for enhanced functionality
- Clean up print statements for better code clarity
2025-03-13 08:10:33 -07:00
Lorenze Jay
9a70528296 Update type hint for initialize_chat_llm to support BaseLLM
- Modify the return type of initialize_chat_llm function to allow for both LLM and BaseLLM instances
- Ensure compatibility with recent changes in create_llm function
2025-03-12 14:59:37 -07:00
Lorenze Jay
80c31fb55b Enhance create_llm function to support BaseLLM type
- Update the create_llm function to accept both LLM and BaseLLM instances
- Ensure compatibility with existing LLM handling logic
2025-03-12 14:56:47 -07:00
Lorenze Jay
902c330113 Enhance CustomLLM and JWTAuthLLM initialization with model parameter
- Update CustomLLM to accept a model parameter during initialization
- Modify test cases to include the new model argument
- Ensure JWTAuthLLM and TimeoutHandlingLLM also utilize the model parameter in their constructors
- Update type hints in create_llm function to support both LLM and BaseLLM types
2025-03-12 08:16:59 -07:00
Lorenze Jay
b305ef8f48 Remove abstract method set_callbacks from BaseLLM class 2025-03-12 08:16:35 -07:00
Lorenze Jay
afe220d3e8 Improve stop words handling in CrewAgentExecutor
- Add support for handling existing stop words in LLM configuration
- Ensure stop words are correctly merged and deduplicated
- Update type hints to support both LLM and BaseLLM types
2025-03-12 08:15:30 -07:00
Lorenze Jay
7cb3c8bb4b Update LLM imports and type hints across multiple files
- Modify imports in crew_chat.py to use LLM instead of BaseLLM
- Update type hints in llm_utils.py to use LLM type
- Add optional `stop` parameter to BaseLLM initialization
- Refactor type handling for LLM creation and usage
2025-03-11 16:01:10 -07:00
Lorenze Jay
a40abbf490 Update AISuiteLLM and LLM utility type handling
- Modify AISuiteLLM to support more flexible input types for messages
- Update type hints in AISuiteLLM to allow string or list of message dictionaries
- Enhance LLM utility function to support broader LLM type annotations
- Remove default `self.stop` attribute from BaseLLM initialization
2025-03-11 15:57:08 -07:00
Lorenze Jay
25c64ae86d Add AISuite LLM support and update dependencies
- Integrate AISuite as a new third-party LLM option
- Update pyproject.toml and uv.lock to include aisuite package
- Modify BaseLLM to support more flexible initialization
- Remove unnecessary LLM imports across multiple files
- Implement AISuiteLLM with basic chat completion functionality
2025-03-11 15:48:49 -07:00
Lorenze Jay
0cece5fd59 Merge branch 'main' of github.com:crewAIInc/crewAI into devin/1741108142-custom-llm-support 2025-03-10 09:34:33 -07:00
Lorenze Jay
709941c4c7 Refactor LLM module by extracting BaseLLM to a separate file
This commit moves the BaseLLM abstract base class from llm.py to a new file llms/base_llm.py to improve code organization. The changes include:

- Creating a new file src/crewai/llms/base_llm.py
- Moving the BaseLLM class to the new file
- Updating imports in __init__.py and llm.py to reflect the new location
- Updating test cases to use the new import path

The refactoring maintains the existing functionality while improving the project's module structure.
2025-03-04 15:54:46 -08:00
Devin AI
963ed23b63 Enhance custom LLM implementation with better error handling, documentation, and test coverage
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-04 17:50:52 +00:00
Devin AI
22aeeaadbe Fix type errors in crew.py by updating tool-related methods to return List[BaseTool]
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-04 17:42:58 +00:00
Devin AI
7201161207 Fix linting issues with import sorting
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-04 17:19:36 +00:00
Devin AI
687303ad63 Fix import sorting and type annotations
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-04 17:13:12 +00:00
Devin AI
ec8e705bbc Add support for custom LLM implementations
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-04 17:09:17 +00:00
16 changed files with 1691 additions and 61 deletions

642
docs/custom_llm.md Normal file
View File

@@ -0,0 +1,642 @@
# Custom LLM Implementations
CrewAI now supports custom LLM implementations through the `BaseLLM` abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
## Using Custom LLM Implementations
To create a custom LLM implementation, you need to:
1. Inherit from the `BaseLLM` abstract base class
2. Implement the required methods:
- `call()`: The main method to call the LLM with messages
- `supports_function_calling()`: Whether the LLM supports function calling
- `supports_stop_words()`: Whether the LLM supports stop words
- `get_context_window_size()`: The context window size of the LLM
## Example: Basic Custom LLM
```python
from crewai import BaseLLM
from typing import Any, Dict, List, Optional, Union
class CustomLLM(BaseLLM):
def __init__(self, api_key: str, endpoint: str):
super().__init__() # Initialize the base class to set default attributes
if not api_key or not isinstance(api_key, str):
raise ValueError("Invalid API key: must be a non-empty string")
if not endpoint or not isinstance(endpoint, str):
raise ValueError("Invalid endpoint URL: must be a non-empty string")
self.api_key = api_key
self.endpoint = endpoint
self.stop = [] # You can customize stop words if needed
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with the given messages.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
Either a text response from the LLM or the result of a tool function call.
Raises:
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
ValueError: If the response format is invalid.
"""
# Implement your own logic to call the LLM
# For example, using requests:
import requests
try:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# Convert string message to proper format if needed
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
data = {
"messages": messages,
"tools": tools
}
response = requests.post(
self.endpoint,
headers=headers,
json=data,
timeout=30 # Set a reasonable timeout
)
response.raise_for_status() # Raise an exception for HTTP errors
return response.json()["choices"][0]["message"]["content"]
except requests.Timeout:
raise TimeoutError("LLM request timed out")
except requests.RequestException as e:
raise RuntimeError(f"LLM request failed: {str(e)}")
except (KeyError, IndexError, ValueError) as e:
raise ValueError(f"Invalid response format: {str(e)}")
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
True if the LLM supports function calling, False otherwise.
"""
# Return True if your LLM supports function calling
return True
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns:
True if the LLM supports stop words, False otherwise.
"""
# Return True if your LLM supports stop words
return True
def get_context_window_size(self) -> int:
"""Get the context window size of the LLM.
Returns:
The context window size as an integer.
"""
# Return the context window size of your LLM
return 8192
```
## Error Handling Best Practices
When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices:
### 1. Implement Try-Except Blocks for API Calls
Always wrap API calls in try-except blocks to handle different types of errors:
```python
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
try:
# API call implementation
response = requests.post(
self.endpoint,
headers=self.headers,
json=self.prepare_payload(messages),
timeout=30 # Set a reasonable timeout
)
response.raise_for_status() # Raise an exception for HTTP errors
return response.json()["choices"][0]["message"]["content"]
except requests.Timeout:
raise TimeoutError("LLM request timed out")
except requests.RequestException as e:
raise RuntimeError(f"LLM request failed: {str(e)}")
except (KeyError, IndexError, ValueError) as e:
raise ValueError(f"Invalid response format: {str(e)}")
```
### 2. Implement Retry Logic for Transient Failures
For transient failures like network issues or rate limiting, implement retry logic with exponential backoff:
```python
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
import time
max_retries = 3
retry_delay = 1 # seconds
for attempt in range(max_retries):
try:
response = requests.post(
self.endpoint,
headers=self.headers,
json=self.prepare_payload(messages),
timeout=30
)
response.raise_for_status()
return response.json()["choices"][0]["message"]["content"]
except (requests.Timeout, requests.ConnectionError) as e:
if attempt < max_retries - 1:
time.sleep(retry_delay * (2 ** attempt)) # Exponential backoff
continue
raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}")
except requests.RequestException as e:
raise RuntimeError(f"LLM request failed: {str(e)}")
```
### 3. Validate Input Parameters
Always validate input parameters to prevent runtime errors:
```python
def __init__(self, api_key: str, endpoint: str):
super().__init__()
if not api_key or not isinstance(api_key, str):
raise ValueError("Invalid API key: must be a non-empty string")
if not endpoint or not isinstance(endpoint, str):
raise ValueError("Invalid endpoint URL: must be a non-empty string")
self.api_key = api_key
self.endpoint = endpoint
```
### 4. Handle Authentication Errors Gracefully
Provide clear error messages for authentication failures:
```python
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
try:
response = requests.post(self.endpoint, headers=self.headers, json=data)
if response.status_code == 401:
raise ValueError("Authentication failed: Invalid API key or token")
elif response.status_code == 403:
raise ValueError("Authorization failed: Insufficient permissions")
response.raise_for_status()
# Process response
except Exception as e:
# Handle error
raise
```
## Example: JWT-based Authentication
For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:
```python
from crewai import BaseLLM, Agent, Task
from typing import Any, Dict, List, Optional, Union
class JWTAuthLLM(BaseLLM):
def __init__(self, jwt_token: str, endpoint: str):
super().__init__() # Initialize the base class to set default attributes
if not jwt_token or not isinstance(jwt_token, str):
raise ValueError("Invalid JWT token: must be a non-empty string")
if not endpoint or not isinstance(endpoint, str):
raise ValueError("Invalid endpoint URL: must be a non-empty string")
self.jwt_token = jwt_token
self.endpoint = endpoint
self.stop = [] # You can customize stop words if needed
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with JWT authentication.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
Either a text response from the LLM or the result of a tool function call.
Raises:
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
ValueError: If the response format is invalid.
"""
# Implement your own logic to call the LLM with JWT authentication
import requests
try:
headers = {
"Authorization": f"Bearer {self.jwt_token}",
"Content-Type": "application/json"
}
# Convert string message to proper format if needed
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
data = {
"messages": messages,
"tools": tools
}
response = requests.post(
self.endpoint,
headers=headers,
json=data,
timeout=30 # Set a reasonable timeout
)
if response.status_code == 401:
raise ValueError("Authentication failed: Invalid JWT token")
elif response.status_code == 403:
raise ValueError("Authorization failed: Insufficient permissions")
response.raise_for_status() # Raise an exception for HTTP errors
return response.json()["choices"][0]["message"]["content"]
except requests.Timeout:
raise TimeoutError("LLM request timed out")
except requests.RequestException as e:
raise RuntimeError(f"LLM request failed: {str(e)}")
except (KeyError, IndexError, ValueError) as e:
raise ValueError(f"Invalid response format: {str(e)}")
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
True if the LLM supports function calling, False otherwise.
"""
return True
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns:
True if the LLM supports stop words, False otherwise.
"""
return True
def get_context_window_size(self) -> int:
"""Get the context window size of the LLM.
Returns:
The context window size as an integer.
"""
return 8192
```
## Troubleshooting
Here are some common issues you might encounter when implementing custom LLMs and how to resolve them:
### 1. Authentication Failures
**Symptoms**: 401 Unauthorized or 403 Forbidden errors
**Solutions**:
- Verify that your API key or JWT token is valid and not expired
- Check that you're using the correct authentication header format
- Ensure that your token has the necessary permissions
### 2. Timeout Issues
**Symptoms**: Requests taking too long or timing out
**Solutions**:
- Implement timeout handling as shown in the examples
- Use retry logic with exponential backoff
- Consider using a more reliable network connection
### 3. Response Parsing Errors
**Symptoms**: KeyError, IndexError, or ValueError when processing responses
**Solutions**:
- Validate the response format before accessing nested fields
- Implement proper error handling for malformed responses
- Check the API documentation for the expected response format
### 4. Rate Limiting
**Symptoms**: 429 Too Many Requests errors
**Solutions**:
- Implement rate limiting in your custom LLM
- Add exponential backoff for retries
- Consider using a token bucket algorithm for more precise rate control
## Advanced Features
### Logging
Adding logging to your custom LLM can help with debugging and monitoring:
```python
import logging
from typing import Any, Dict, List, Optional, Union
class LoggingLLM(BaseLLM):
def __init__(self, api_key: str, endpoint: str):
super().__init__()
self.api_key = api_key
self.endpoint = endpoint
self.logger = logging.getLogger("crewai.llm.custom")
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages")
try:
# API call implementation
response = self._make_api_call(messages, tools)
self.logger.debug(f"LLM response received: {response[:100]}...")
return response
except Exception as e:
self.logger.error(f"LLM call failed: {str(e)}")
raise
```
### Rate Limiting
Implementing rate limiting can help avoid overwhelming the LLM API:
```python
import time
from typing import Any, Dict, List, Optional, Union
class RateLimitedLLM(BaseLLM):
def __init__(
self,
api_key: str,
endpoint: str,
requests_per_minute: int = 60
):
super().__init__()
self.api_key = api_key
self.endpoint = endpoint
self.requests_per_minute = requests_per_minute
self.request_times: List[float] = []
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
self._enforce_rate_limit()
# Record this request time
self.request_times.append(time.time())
# Make the actual API call
return self._make_api_call(messages, tools)
def _enforce_rate_limit(self) -> None:
"""Enforce the rate limit by waiting if necessary."""
now = time.time()
# Remove request times older than 1 minute
self.request_times = [t for t in self.request_times if now - t < 60]
if len(self.request_times) >= self.requests_per_minute:
# Calculate how long to wait
oldest_request = min(self.request_times)
wait_time = 60 - (now - oldest_request)
if wait_time > 0:
time.sleep(wait_time)
```
### Metrics Collection
Collecting metrics can help you monitor your LLM usage:
```python
import time
from typing import Any, Dict, List, Optional, Union
class MetricsCollectingLLM(BaseLLM):
def __init__(self, api_key: str, endpoint: str):
super().__init__()
self.api_key = api_key
self.endpoint = endpoint
self.metrics: Dict[str, Any] = {
"total_calls": 0,
"total_tokens": 0,
"errors": 0,
"latency": []
}
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
start_time = time.time()
self.metrics["total_calls"] += 1
try:
response = self._make_api_call(messages, tools)
# Estimate tokens (simplified)
if isinstance(messages, str):
token_estimate = len(messages) // 4
else:
token_estimate = sum(len(m.get("content", "")) // 4 for m in messages)
self.metrics["total_tokens"] += token_estimate
return response
except Exception as e:
self.metrics["errors"] += 1
raise
finally:
latency = time.time() - start_time
self.metrics["latency"].append(latency)
def get_metrics(self) -> Dict[str, Any]:
"""Return the collected metrics."""
avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0
return {
**self.metrics,
"avg_latency": avg_latency
}
```
## Advanced Usage: Function Calling
If your LLM supports function calling, you can implement the function calling logic in your custom LLM:
```python
import json
from typing import Any, Dict, List, Optional, Union
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
import requests
try:
headers = {
"Authorization": f"Bearer {self.jwt_token}",
"Content-Type": "application/json"
}
# Convert string message to proper format if needed
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
data = {
"messages": messages,
"tools": tools
}
response = requests.post(
self.endpoint,
headers=headers,
json=data,
timeout=30
)
response.raise_for_status()
response_data = response.json()
# Check if the LLM wants to call a function
if response_data["choices"][0]["message"].get("tool_calls"):
tool_calls = response_data["choices"][0]["message"]["tool_calls"]
# Process each tool call
for tool_call in tool_calls:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
if available_functions and function_name in available_functions:
function_to_call = available_functions[function_name]
function_response = function_to_call(**function_args)
# Add the function response to the messages
messages.append({
"role": "tool",
"tool_call_id": tool_call["id"],
"name": function_name,
"content": str(function_response)
})
# Call the LLM again with the updated messages
return self.call(messages, tools, callbacks, available_functions)
# Return the text response if no function call
return response_data["choices"][0]["message"]["content"]
except requests.Timeout:
raise TimeoutError("LLM request timed out")
except requests.RequestException as e:
raise RuntimeError(f"LLM request failed: {str(e)}")
except (KeyError, IndexError, ValueError) as e:
raise ValueError(f"Invalid response format: {str(e)}")
```
## Using Your Custom LLM with CrewAI
Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:
```python
from crewai import Agent, Task, Crew
from typing import Dict, Any
# Create your custom LLM instance
jwt_llm = JWTAuthLLM(
jwt_token="your.jwt.token",
endpoint="https://your-llm-endpoint.com/v1/chat/completions"
)
# Use it with an agent
agent = Agent(
role="Research Assistant",
goal="Find information on a topic",
backstory="You are a research assistant tasked with finding information.",
llm=jwt_llm,
)
# Create a task for the agent
task = Task(
description="Research the benefits of exercise",
agent=agent,
expected_output="A summary of the benefits of exercise",
)
# Execute the task
result = agent.execute_task(task)
print(result)
# Or use it with a crew
crew = Crew(
agents=[agent],
tasks=[task],
manager_llm=jwt_llm, # Use your custom LLM for the manager
)
# Run the crew
result = crew.kickoff()
print(result)
```
## Implementing Your Own Authentication Mechanism
The `BaseLLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
- OAuth tokens
- Client certificates
- Custom headers
- Session-based authentication
- Any other authentication method required by your LLM provider
Simply implement the appropriate authentication logic in your custom LLM class.

View File

@@ -64,6 +64,9 @@ mem0 = ["mem0ai>=0.1.29"]
docling = [
"docling>=2.12.0",
]
aisuite = [
"aisuite>=0.1.10",
]
[tool.uv]
dev-dependencies = [

View File

@@ -5,6 +5,7 @@ from crewai.crew import Crew
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.process import Process
from crewai.task import Task
@@ -21,6 +22,7 @@ __all__ = [
"Process",
"Task",
"LLM",
"BaseLLM",
"Flow",
"Knowledge",
]

View File

@@ -11,7 +11,7 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.llm import LLM
from crewai.llm import BaseLLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.security import Fingerprint
from crewai.task import Task
@@ -71,10 +71,10 @@ class Agent(BaseAgent):
default=True,
description="Use system prompt for the agent.",
)
llm: Union[str, InstanceOf[LLM], Any] = Field(
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
description="Language model that will run the agent.", default=None
)
system_template: Optional[str] = Field(
@@ -118,7 +118,9 @@ class Agent(BaseAgent):
self.agent_ops_agent_name = self.role
self.llm = create_llm(self.llm)
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
if self.function_calling_llm and not isinstance(
self.function_calling_llm, BaseLLM
):
self.function_calling_llm = create_llm(self.function_calling_llm)
if not self.agent_executor:

View File

@@ -13,7 +13,7 @@ from crewai.agents.parser import (
OutputParserException,
)
from crewai.agents.tools_handler import ToolsHandler
from crewai.llm import LLM
from crewai.llm import BaseLLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
from crewai.utilities import I18N, Printer
@@ -61,7 +61,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
callbacks: List[Any] = [],
):
self._i18n: I18N = I18N()
self.llm: LLM = llm
self.llm: BaseLLM = llm
self.task = task
self.agent = agent
self.crew = crew
@@ -87,8 +87,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.tool_name_to_tool_map: Dict[str, BaseTool] = {
tool.name: tool for tool in self.tools
}
self.stop = stop_words
self.llm.stop = list(set(self.llm.stop + self.stop))
existing_stop = self.llm.stop or []
self.llm.stop = list(
set(
existing_stop + self.stop
if isinstance(existing_stop, list)
else self.stop
)
)
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
if "system" in self.prompt:

View File

@@ -14,7 +14,7 @@ from packaging import version
from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.crew import Crew
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM
from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm
@@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
print()
def initialize_chat_llm(crew: Crew) -> Optional[LLM]:
def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]:
"""Initializes the chat LLM and handles exceptions."""
try:
return create_llm(crew.chat_llm)

View File

@@ -6,8 +6,9 @@ import warnings
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast
from langchain_core.tools import BaseTool as LangchainBaseTool
from pydantic import (
UUID4,
BaseModel,
@@ -26,7 +27,7 @@ from crewai.agents.cache import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
@@ -37,7 +38,7 @@ from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import Tool
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINING_DATA_FILE
@@ -153,7 +154,7 @@ class Crew(BaseModel):
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: Optional[Any] = Field(
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: Optional[BaseAgent] = Field(
@@ -187,7 +188,7 @@ class Crew(BaseModel):
default=None,
description="Maximum number of requests per minute for the crew execution to be respected.",
)
prompt_file: str = Field(
prompt_file: Optional[str] = Field(
default=None,
description="Path to the prompt json file to be used for the crew.",
)
@@ -199,7 +200,7 @@ class Crew(BaseModel):
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: Optional[Any] = Field(
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
default=None,
description="Language model that will run the AgentPlanner if planning is True.",
)
@@ -215,7 +216,7 @@ class Crew(BaseModel):
default=None,
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
)
chat_llm: Optional[Any] = Field(
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
@@ -819,7 +820,12 @@ class Crew(BaseModel):
# Determine which tools to use - task tools take precedence over agent tools
tools_for_task = task.tools or agent_to_use.tools or []
tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task)
# Prepare tools and ensure they're compatible with task execution
tools_for_task = self._prepare_tools(
agent_to_use,
task,
cast(Union[List[Tool], List[BaseTool]], tools_for_task),
)
self._log_task_start(task, agent_to_use.role)
@@ -838,7 +844,7 @@ class Crew(BaseModel):
future = task.execute_async(
agent=agent_to_use,
context=context,
tools=tools_for_task,
tools=cast(List[BaseTool], tools_for_task),
)
futures.append((task, future, task_index))
else:
@@ -850,7 +856,7 @@ class Crew(BaseModel):
task_output = task.execute_sync(
agent=agent_to_use,
context=context,
tools=tools_for_task,
tools=cast(List[BaseTool], tools_for_task),
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
@@ -888,10 +894,12 @@ class Crew(BaseModel):
return None
def _prepare_tools(
self, agent: BaseAgent, task: Task, tools: List[Tool]
) -> List[Tool]:
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
# Add delegation tools if agent allows delegation
if agent.allow_delegation:
if hasattr(agent, "allow_delegation") and getattr(
agent, "allow_delegation", False
):
if self.process == Process.hierarchical:
if self.manager_agent:
tools = self._update_manager_tools(task, tools)
@@ -900,17 +908,24 @@ class Crew(BaseModel):
"Manager agent is required for hierarchical process."
)
elif agent and agent.allow_delegation:
elif agent:
tools = self._add_delegation_tools(task, tools)
# Add code execution tools if agent allows code execution
if agent.allow_code_execution:
if hasattr(agent, "allow_code_execution") and getattr(
agent, "allow_code_execution", False
):
tools = self._add_code_execution_tools(agent, tools)
if agent and agent.multimodal:
if (
agent
and hasattr(agent, "multimodal")
and getattr(agent, "multimodal", False)
):
tools = self._add_multimodal_tools(agent, tools)
return tools
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
return cast(List[BaseTool], tools)
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
if self.process == Process.hierarchical:
@@ -918,11 +933,13 @@ class Crew(BaseModel):
return task.agent
def _merge_tools(
self, existing_tools: List[Tool], new_tools: List[Tool]
) -> List[Tool]:
self,
existing_tools: Union[List[Tool], List[BaseTool]],
new_tools: Union[List[Tool], List[BaseTool]],
) -> List[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
if not new_tools:
return existing_tools
return cast(List[BaseTool], existing_tools)
# Create mapping of tool names to new tools
new_tool_map = {tool.name: tool for tool in new_tools}
@@ -933,23 +950,41 @@ class Crew(BaseModel):
# Add all new tools
tools.extend(new_tools)
return tools
return cast(List[BaseTool], tools)
def _inject_delegation_tools(
self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent]
):
delegation_tools = task_agent.get_delegation_tools(agents)
return self._merge_tools(tools, delegation_tools)
self,
tools: Union[List[Tool], List[BaseTool]],
task_agent: BaseAgent,
agents: List[BaseAgent],
) -> List[BaseTool]:
if hasattr(task_agent, "get_delegation_tools"):
delegation_tools = task_agent.get_delegation_tools(agents)
# Cast delegation_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
return cast(List[BaseTool], tools)
def _add_multimodal_tools(self, agent: BaseAgent, tools: List[Tool]):
multimodal_tools = agent.get_multimodal_tools()
return self._merge_tools(tools, multimodal_tools)
def _add_multimodal_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
if hasattr(agent, "get_multimodal_tools"):
multimodal_tools = agent.get_multimodal_tools()
# Cast multimodal_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
return cast(List[BaseTool], tools)
def _add_code_execution_tools(self, agent: BaseAgent, tools: List[Tool]):
code_tools = agent.get_code_execution_tools()
return self._merge_tools(tools, code_tools)
def _add_code_execution_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
if hasattr(agent, "get_code_execution_tools"):
code_tools = agent.get_code_execution_tools()
# Cast code_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
return cast(List[BaseTool], tools)
def _add_delegation_tools(self, task: Task, tools: List[Tool]):
def _add_delegation_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
if not tools:
@@ -957,7 +992,7 @@ class Crew(BaseModel):
tools = self._inject_delegation_tools(
tools, task.agent, agents_for_delegation
)
return tools
return cast(List[BaseTool], tools)
def _log_task_start(self, task: Task, role: str = "None"):
if self.output_log_file:
@@ -965,7 +1000,9 @@ class Crew(BaseModel):
task_name=task.name, task=task.description, agent=role, status="started"
)
def _update_manager_tools(self, task: Task, tools: List[Tool]):
def _update_manager_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
if self.manager_agent:
if task.agent:
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
@@ -973,7 +1010,7 @@ class Crew(BaseModel):
tools = self._inject_delegation_tools(
tools, self.manager_agent, self.agents
)
return tools
return cast(List[BaseTool], tools)
def _get_context(self, task: Task, task_outputs: List[TaskOutput]):
context = (
@@ -1214,13 +1251,14 @@ class Crew(BaseModel):
def test(
self,
n_iterations: int,
eval_llm: Union[str, InstanceOf[LLM]],
eval_llm: Union[str, InstanceOf[BaseLLM]],
inputs: Optional[Dict[str, Any]] = None,
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
try:
eval_llm = create_llm(eval_llm)
if not eval_llm:
# Create LLM instance and ensure it's of type LLM for CrewEvaluator
llm_instance = create_llm(eval_llm)
if not llm_instance:
raise ValueError("Failed to create LLM instance.")
crewai_event_bus.emit(
@@ -1228,12 +1266,12 @@ class Crew(BaseModel):
CrewTestStartedEvent(
crew_name=self.name or "crew",
n_iterations=n_iterations,
eval_llm=eval_llm,
eval_llm=llm_instance,
inputs=inputs,
),
)
test_crew = self.copy()
evaluator = CrewEvaluator(test_crew, eval_llm) # type: ignore[arg-type]
evaluator = CrewEvaluator(test_crew, llm_instance)
for i in range(1, n_iterations + 1):
evaluator.set_iteration(i)

View File

@@ -40,6 +40,7 @@ with warnings.catch_warnings():
from litellm.utils import supports_response_schema
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.events import crewai_event_bus
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
@@ -218,7 +219,7 @@ class StreamingChoices(TypedDict):
finish_reason: Optional[str]
class LLM:
class LLM(BaseLLM):
def __init__(
self,
model: str,

View File

@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
class BaseLLM(ABC):
"""Abstract base class for LLM implementations.
This class defines the interface that all LLM implementations must follow.
Users can extend this class to create custom LLM implementations that don't
rely on litellm's authentication mechanism.
Custom LLM implementations should handle error cases gracefully, including
timeouts, authentication failures, and malformed responses. They should also
implement proper validation for input parameters and provide clear error
messages when things go wrong.
Attributes:
stop (list): A list of stop sequences that the LLM should use to stop generation.
This is used by the CrewAgentExecutor and other components.
"""
model: str
temperature: Optional[float] = None
stop: Optional[List[str]] = None
def __init__(
self,
model: str,
temperature: Optional[float] = None,
):
"""Initialize the BaseLLM with default attributes.
This constructor sets default values for attributes that are expected
by the CrewAgentExecutor and other components.
All custom LLM implementations should call super().__init__() to ensure
that these default attributes are properly initialized.
"""
self.model = model
self.temperature = temperature
self.stop = []
@abstractmethod
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with the given messages.
Args:
messages: Input messages for the LLM.
Can be a string or list of message dictionaries.
If string, it will be converted to a single user message.
If list, each dict must have 'role' and 'content' keys.
tools: Optional list of tool schemas for function calling.
Each tool should define its name, description, and parameters.
callbacks: Optional list of callback functions to be executed
during and after the LLM call.
available_functions: Optional dict mapping function names to callables
that can be invoked by the LLM.
Returns:
Either a text response from the LLM (str) or
the result of a tool function call (Any).
Raises:
ValueError: If the messages format is invalid.
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
"""
pass
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns:
bool: True if the LLM supports stop words, False otherwise.
"""
return True # Default implementation assumes support for stop words
def get_context_window_size(self) -> int:
"""Get the context window size for the LLM.
Returns:
int: The number of tokens/characters the model can handle.
"""
# Default implementation - subclasses should override with model-specific values
return 4096

58
src/crewai/llms/third_party/ai_suite.py vendored Normal file
View File

@@ -0,0 +1,58 @@
from typing import Any, Dict, List, Optional, Union
from crewai.llms.base_llm import BaseLLM
class AISuiteLLM(BaseLLM):
def __init__(self, model: str, temperature: Optional[float] = None, **kwargs):
super().__init__(model, temperature, **kwargs)
try:
import aisuite as ai
except ImportError:
import click
if click.confirm(
"You are missing the 'aisuite' package. Would you like to install it?"
):
import subprocess
try:
subprocess.run(["uv", "add", "aisuite"], check=True)
import aisuite as ai
except subprocess.CalledProcessError as e:
raise ImportError(f"Failed to install 'aisuite' package: {str(e)}")
else:
raise ImportError(
"The 'aisuite' package is required for this functionality."
)
self.client = ai.Client()
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
completion_params = self._prepare_completion_params(messages, tools)
response = self.client.chat.completions.create(**completion_params)
return response.choices[0].message.content
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
return {
"model": self.model,
"messages": messages,
"temperature": self.temperature,
"tools": tools,
}
def supports_function_calling(self) -> bool:
return False

View File

@@ -6,7 +6,7 @@ from rich.console import Console
from rich.table import Table
from crewai.agent import Agent
from crewai.llm import LLM
from crewai.llm import BaseLLM
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
from crewai.telemetry import Telemetry
@@ -24,7 +24,7 @@ class CrewEvaluator:
Attributes:
crew (Crew): The crew of agents to evaluate.
eval_llm (LLM): Language model instance to use for evaluations
eval_llm (BaseLLM): Language model instance to use for evaluations
tasks_scores (defaultdict): A dictionary to store the scores of the agents for each task.
iteration (int): The current iteration of the evaluation.
"""
@@ -33,7 +33,7 @@ class CrewEvaluator:
run_execution_times: defaultdict = defaultdict(list)
iteration: int = 0
def __init__(self, crew, eval_llm: InstanceOf[LLM]):
def __init__(self, crew, eval_llm: InstanceOf[BaseLLM]):
self.crew = crew
self.llm = eval_llm
self._telemetry = Telemetry()

View File

@@ -2,28 +2,28 @@ import os
from typing import Any, Dict, List, Optional, Union
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM
def create_llm(
llm_value: Union[str, LLM, Any, None] = None,
) -> Optional[LLM]:
) -> Optional[LLM | BaseLLM]:
"""
Creates or returns an LLM instance based on the given llm_value.
Args:
llm_value (str | LLM | Any | None):
llm_value (str | BaseLLM | Any | None):
- str: The model name (e.g., "gpt-4").
- LLM: Already instantiated LLM, returned as-is.
- BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
- Any: Attempt to extract known attributes like model_name, temperature, etc.
- None: Use environment-based or fallback default model.
Returns:
An LLM instance if successful, or None if something fails.
A BaseLLM instance if successful, or None if something fails.
"""
# 1) If llm_value is already an LLM object, return it directly
if isinstance(llm_value, LLM):
# 1) If llm_value is already a BaseLLM or LLM object, return it directly
if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM):
return llm_value
# 2) If llm_value is a string (model name)

View File

@@ -0,0 +1,107 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the answer to life, the universe, and everything?"}],
"model": "gpt-4o-mini", "tools": null}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '206'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.61.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.61.0
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.8
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
content: "{\n \"id\": \"chatcmpl-B7W6FS0wpfndLdg12G3H6ZAXcYhJi\",\n \"object\":
\"chat.completion\",\n \"created\": 1741131387,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
\"assistant\",\n \"content\": \"The answer to life, the universe, and
everything, famously found in Douglas Adams' \\\"The Hitchhiker's Guide to the
Galaxy,\\\" is the number 42. However, the question itself is left ambiguous,
leading to much speculation and humor in the story.\",\n \"refusal\":
null\n },\n \"logprobs\": null,\n \"finish_reason\": \"stop\"\n
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 30,\n \"completion_tokens\":
54,\n \"total_tokens\": 84,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n
\ \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
\"default\",\n \"system_fingerprint\": \"fp_06737a9306\"\n}\n"
headers:
CF-RAY:
- 91b532234c18cf1f-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Tue, 04 Mar 2025 23:36:28 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=DgLb6UAE6W4Oeto1Bi2RiKXQVV5TTzkXdXWFdmAEwQQ-1741131388-1.0.1.1-jWQtsT95wOeQbmIxAK7cv8gJWxYi1tQ.IupuJzBDnZr7iEChwVUQBRfnYUBJPDsNly3bakCDArjD_S.FLKwH6xUfvlxgfd4YSBhBPy7bcgw;
path=/; expires=Wed, 05-Mar-25 00:06:28 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=Oa59XCmqjKLKwU34la1hkTunN57JW20E.ZHojvRBfow-1741131388236-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '776'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999960'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_97824e8fe7c1aca3fbcba7c925388b39
http_version: HTTP/1.1
status_code: 200
version: 1

View File

@@ -0,0 +1,305 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [{"role": "system", "content": "You are Say Hi.
You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give
my best complete final answer to the task respond using the exact following
format:\n\nThought: I now can give a great answer\nFinal Answer: Your final
answer must be the great and the most complete as possible, it must be outcome
described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user",
"content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria
for your final answer: A greeting to the user\nyou MUST return the actual complete
content as the final answer, not a summary.\n\nBegin! This is VERY important
to you, use the tools available and give your best Final Answer, your job depends
on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '931'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.61.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.61.0
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.8
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n
\ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n
\ \"code\": \"missing_required_parameter\"\n }\n}"
headers:
CF-RAY:
- 91b54660799a15b4-SJC
Connection:
- keep-alive
Content-Length:
- '219'
Content-Type:
- application/json
Date:
- Tue, 04 Mar 2025 23:50:16 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=OwS.6cyfDpbxxx8vPp4THv5eNoDMQK0qSVN.wSUyOYk-1741132216-1.0.1.1-QBVd08CjfmDBpNnYQM5ILGbTUWKh6SDM9E4ARG4SV2Z9Q4ltFSFLXoo38OGJApUNZmzn4PtRsyAPsHt_dsrHPF6MD17FPcGtrnAHqCjJrfU;
path=/; expires=Wed, 05-Mar-25 00:20:16 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=n_ebDsAOhJm5Mc7OMx8JDiOaZq5qzHCnVxyS3KN0BwA-1741132216951-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '19'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999974'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_042a4e8f9432f6fde7a02037bb6caafa
http_version: HTTP/1.1
status_code: 400
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [{"role": "system", "content": "You are Say Hi.
You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give
my best complete final answer to the task respond using the exact following
format:\n\nThought: I now can give a great answer\nFinal Answer: Your final
answer must be the great and the most complete as possible, it must be outcome
described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user",
"content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria
for your final answer: A greeting to the user\nyou MUST return the actual complete
content as the final answer, not a summary.\n\nBegin! This is VERY important
to you, use the tools available and give your best Final Answer, your job depends
on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '931'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.61.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.61.0
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.8
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n
\ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n
\ \"code\": \"missing_required_parameter\"\n }\n}"
headers:
CF-RAY:
- 91b54664bb1acef1-SJC
Connection:
- keep-alive
Content-Length:
- '219'
Content-Type:
- application/json
Date:
- Tue, 04 Mar 2025 23:50:17 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=.wGU4pJEajaSzFWjp05TBQwWbCNA2CgpYNu7UYOzbbM-1741132217-1.0.1.1-NoLiAx4qkplllldYYxZCOSQGsX6hsPUJIEyqmt84B3g7hjW1s7.jk9C9PYzXagHWjT0sQ9Ny4LZBA94lDJTfDBZpty8NJQha7ZKW0P_msH8;
path=/; expires=Wed, 05-Mar-25 00:20:17 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=GAjgJjVLtN49bMeWdWZDYLLkEkK51z5kxK4nKqhAzxY-1741132217161-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '25'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999974'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_7a1d027da1ef4468e861e570c72e98fb
http_version: HTTP/1.1
status_code: 400
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [{"role": "system", "content": "You are Say Hi.
You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give
my best complete final answer to the task respond using the exact following
format:\n\nThought: I now can give a great answer\nFinal Answer: Your final
answer must be the great and the most complete as possible, it must be outcome
described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user",
"content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria
for your final answer: A greeting to the user\nyou MUST return the actual complete
content as the final answer, not a summary.\n\nBegin! This is VERY important
to you, use the tools available and give your best Final Answer, your job depends
on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '931'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.61.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.61.0
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.8
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n
\ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n
\ \"code\": \"missing_required_parameter\"\n }\n}"
headers:
CF-RAY:
- 91b54666183beb22-SJC
Connection:
- keep-alive
Content-Length:
- '219'
Content-Type:
- application/json
Date:
- Tue, 04 Mar 2025 23:50:17 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=VwjWHHpkZMJlosI9RbMqxYDBS1t0JK4tWpAy4lST2QM-1741132217-1.0.1.1-u7PU.ZvVBTXNB5R8vaYfWdPXAjWZ3ZcTAy656VaGDZmKIckk5od._eQdn0W0EGVtEMm3TuF60z4GZAPDwMYvb3_3cw1RuEMmQbp4IIrl7VY;
path=/; expires=Wed, 05-Mar-25 00:20:17 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=NglAAsQBoiabMuuHFgilRjflSPFqS38VGKnGyweuCuw-1741132217438-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '56'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999974'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_3c335b308b82cc2214783a4bf2fc0fd4
http_version: HTTP/1.1
status_code: 400
version: 1

359
tests/custom_llm_test.py Normal file
View File

@@ -0,0 +1,359 @@
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
import pytest
from crewai import Agent, Crew, Process, Task
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.llm_utils import create_llm
class CustomLLM(BaseLLM):
"""Custom LLM implementation for testing.
This is a simple implementation of the BaseLLM abstract base class
that returns a predefined response for testing purposes.
"""
def __init__(self, response="Default response", model="test-model"):
"""Initialize the CustomLLM with a predefined response.
Args:
response: The predefined response to return from call().
"""
super().__init__(model=model)
self.response = response
self.call_count = 0
def call(
self,
messages,
tools=None,
callbacks=None,
available_functions=None,
):
"""
Mock LLM call that returns a predefined response.
Properly formats messages to match OpenAI's expected structure.
"""
self.call_count += 1
# If input is a string, convert to proper message format
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
# Ensure each message has properly formatted content
for message in messages:
if isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
# Return predefined response in expected format
if "Thought:" in str(messages):
return f"Thought: I will say hi\nFinal Answer: {self.response}"
return self.response
def supports_function_calling(self) -> bool:
"""Return False to indicate that function calling is not supported.
Returns:
False, indicating that this LLM does not support function calling.
"""
return False
def supports_stop_words(self) -> bool:
"""Return False to indicate that stop words are not supported.
Returns:
False, indicating that this LLM does not support stop words.
"""
return False
def get_context_window_size(self) -> int:
"""Return a default context window size.
Returns:
4096, a typical context window size for modern LLMs.
"""
return 4096
@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_llm_implementation():
"""Test that a custom LLM implementation works with create_llm."""
custom_llm = CustomLLM(response="The answer is 42")
# Test that create_llm returns the custom LLM instance directly
result_llm = create_llm(custom_llm)
assert result_llm is custom_llm
# Test calling the custom LLM
response = result_llm.call(
"What is the answer to life, the universe, and everything?"
)
# Verify that the response from the custom LLM was used
assert "42" in response
@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_llm_within_crew():
"""Test that a custom LLM implementation works with create_llm."""
custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model")
agent = Agent(
role="Say Hi",
goal="Say hi to the user",
backstory="""You just say hi to the user""",
llm=custom_llm,
)
task = Task(
description="Say hi to the user",
expected_output="A greeting to the user",
agent=agent,
)
crew = Crew(
agents=[agent],
tasks=[task],
process=Process.sequential,
)
result = crew.kickoff()
# Assert the LLM was called
assert custom_llm.call_count > 0
# Assert we got a response
assert "Hello!" in result.raw
def test_custom_llm_message_formatting():
"""Test that the custom LLM properly formats messages"""
custom_llm = CustomLLM(response="Test response", model="test-model")
# Test with string input
result = custom_llm.call("Test message")
assert result == "Test response"
# Test with message list
messages = [
{"role": "system", "content": "System message"},
{"role": "user", "content": "User message"},
]
result = custom_llm.call(messages)
assert result == "Test response"
class JWTAuthLLM(BaseLLM):
"""Custom LLM implementation with JWT authentication."""
def __init__(self, jwt_token: str):
super().__init__(model="test-model")
if not jwt_token or not isinstance(jwt_token, str):
raise ValueError("Invalid JWT token")
self.jwt_token = jwt_token
self.calls = []
self.stop = []
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Record the call and return a predefined response."""
self.calls.append(
{
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
)
# In a real implementation, this would use the JWT token to authenticate
# with an external service
return "Response from JWT-authenticated LLM"
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported."""
return True
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported."""
return True
def get_context_window_size(self) -> int:
"""Return a default context window size."""
return 8192
def test_custom_llm_with_jwt_auth():
"""Test a custom LLM implementation with JWT authentication."""
jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
# Test that create_llm returns the JWT-authenticated LLM instance directly
result_llm = create_llm(jwt_llm)
assert result_llm is jwt_llm
# Test calling the JWT-authenticated LLM
response = result_llm.call("Test message")
# Verify that the JWT-authenticated LLM was called
assert len(jwt_llm.calls) > 0
# Verify that the response from the JWT-authenticated LLM was used
assert response == "Response from JWT-authenticated LLM"
def test_jwt_auth_llm_validation():
"""Test that JWT token validation works correctly."""
# Test with invalid JWT token (empty string)
with pytest.raises(ValueError, match="Invalid JWT token"):
JWTAuthLLM(jwt_token="")
# Test with invalid JWT token (non-string)
with pytest.raises(ValueError, match="Invalid JWT token"):
JWTAuthLLM(jwt_token=None)
class TimeoutHandlingLLM(BaseLLM):
"""Custom LLM implementation with timeout handling and retry logic."""
def __init__(self, max_retries: int = 3, timeout: int = 30):
"""Initialize the TimeoutHandlingLLM with retry and timeout settings.
Args:
max_retries: Maximum number of retry attempts.
timeout: Timeout in seconds for each API call.
"""
super().__init__(model="test-model")
self.max_retries = max_retries
self.timeout = timeout
self.calls = []
self.stop = []
self.fail_count = 0 # Number of times to simulate failure
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Simulate API calls with timeout handling and retry logic.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
A response string based on whether this is the first attempt or a retry.
Raises:
TimeoutError: If all retry attempts fail.
"""
# Record the initial call
self.calls.append(
{
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
"attempt": 0,
}
)
# Simulate retry logic
for attempt in range(self.max_retries):
# Skip the first attempt recording since we already did that above
if attempt == 0:
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on first attempt
return "First attempt response"
else:
# This is a retry attempt (attempt > 0)
# Always record retry attempts
self.calls.append(
{
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
)
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on retry
return "Response after retry"
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported.
Returns:
True, indicating that this LLM supports function calling.
"""
return True
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported.
Returns:
True, indicating that this LLM supports stop words.
"""
return True
def get_context_window_size(self) -> int:
"""Return a default context window size.
Returns:
8192, a typical context window size for modern LLMs.
"""
return 8192
def test_timeout_handling_llm():
"""Test a custom LLM implementation with timeout handling and retry logic."""
# Test successful first attempt
llm = TimeoutHandlingLLM()
response = llm.call("Test message")
assert response == "First attempt response"
assert len(llm.calls) == 1
# Test successful retry
llm = TimeoutHandlingLLM()
llm.fail_count = 1 # Fail once, then succeed
response = llm.call("Test message")
assert response == "Response after retry"
assert len(llm.calls) == 2 # Initial call + successful retry call
# Test failure after all retries
llm = TimeoutHandlingLLM(max_retries=2)
llm.fail_count = 2 # Fail twice, which is all retries
with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
llm.call("Test message")
assert len(llm.calls) == 2 # Initial call + failed retry attempt

16
uv.lock generated
View File

@@ -139,6 +139,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17", size = 7617 },
]
[[package]]
name = "aisuite"
version = "0.1.10"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
]
sdist = { url = "https://files.pythonhosted.org/packages/6a/9d/c7a8a76abb9011dd2bc9a5cb8ffa8231640e20bbdae177ce9ab6cb67c66c/aisuite-0.1.10.tar.gz", hash = "sha256:170e62d4c91fecb22e82a04e058154a111cef473681171e5df7346272e77f414", size = 29052 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/58/c2/9a34a01516de107e5f9406dbfd319b6004340708101d67fa107373da4058/aisuite-0.1.10-py3-none-any.whl", hash = "sha256:c8510ebe38d6546b6a06819171e201fcaf0bf9ae020ffcfe19b6bd90430781ad", size = 43984 },
]
[[package]]
name = "alembic"
version = "1.13.3"
@@ -651,6 +663,9 @@ dependencies = [
agentops = [
{ name = "agentops" },
]
aisuite = [
{ name = "aisuite" },
]
docling = [
{ name = "docling" },
]
@@ -698,6 +713,7 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "agentops", marker = "extra == 'agentops'", specifier = ">=0.3.0" },
{ name = "aisuite", marker = "extra == 'aisuite'", specifier = ">=0.1.10" },
{ name = "appdirs", specifier = ">=1.4.4" },
{ name = "auth0-python", specifier = ">=4.7.1" },
{ name = "blinker", specifier = ">=1.9.0" },