Compare commits

...

8 Commits

Author SHA1 Message Date
Devin AI
73880d407b Implement improvements based on PR feedback: enhanced error handling in agent.py, JWT token validation, and rate limiting in custom_llm_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-14 06:49:41 +00:00
Devin AI
2a573d8df9 Add test cassette for LLM authentication error handling
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-14 06:36:15 +00:00
Devin AI
1b8c07760e Simplify LLM implementation by consolidating LLM and BaseLLM classes
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-14 06:35:42 +00:00
Devin AI
963ed23b63 Enhance custom LLM implementation with better error handling, documentation, and test coverage
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-04 17:50:52 +00:00
Devin AI
22aeeaadbe Fix type errors in crew.py by updating tool-related methods to return List[BaseTool]
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-04 17:42:58 +00:00
Devin AI
7201161207 Fix linting issues with import sorting
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-04 17:19:36 +00:00
Devin AI
687303ad63 Fix import sorting and type annotations
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-04 17:13:12 +00:00
Devin AI
ec8e705bbc Add support for custom LLM implementations
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-04 17:09:17 +00:00
9 changed files with 1676 additions and 58 deletions

681
docs/custom_llm.md Normal file
View File

@@ -0,0 +1,681 @@
# Custom LLM Implementations
CrewAI supports custom LLM implementations through the `LLM` base class. By subclassing it, you can plug in providers that don't rely on litellm's authentication mechanism.
## Creating a Custom LLM Implementation
To create a custom LLM implementation, you need to (a minimal skeleton follows this list):
1. Inherit from the `LLM` base class
2. Implement the required methods:
- `call()`: The main method to call the LLM with messages
- `supports_function_calling()`: Whether the LLM supports function calling
- `supports_stop_words()`: Whether the LLM supports stop words
- `get_context_window_size()`: The context window size of the LLM
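A minimal skeleton that satisfies this interface looks like the following (the `MyLLM` name and the canned response are placeholders for illustration; the sections below add real transport, error handling, and authentication):
```python
from typing import Any, Dict, List, Optional, Union

from crewai import LLM


class MyLLM(LLM):
    """Smallest possible custom LLM: implements only the required interface."""

    def __init__(self):
        super().__init__()  # initializes default attributes such as `stop`

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        # Replace this canned reply with a real request to your provider.
        return "Hello from MyLLM"

    def supports_function_calling(self) -> bool:
        return False  # flip to True once call() handles the `tools` parameter

    def supports_stop_words(self) -> bool:
        return False

    def get_context_window_size(self) -> int:
        return 8192
```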
## Using the Default LLM Implementation
If you don't need a custom LLM implementation, you can use the default implementation provided by CrewAI:
```python
from crewai import LLM

# Create a default LLM instance
llm = LLM.create(model="gpt-4")

# Or with more parameters
llm = LLM.create(
    model="gpt-4",
    temperature=0.7,
    max_tokens=1000,
    api_key="your-api-key"
)
```
## Example: Basic Custom LLM
```python
from crewai import LLM
from typing import Any, Dict, List, Optional, Union


class CustomLLM(LLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not api_key or not isinstance(api_key, str):
            raise ValueError("Invalid API key: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.api_key = api_key
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with the given messages.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            Either a text response from the LLM or the result of a tool function call.

        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM.
        # For example, using requests:
        import requests

        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]
            data = {
                "messages": messages,
                "tools": tools
            }
            response = requests.post(
                self.endpoint,
                headers=headers,
                json=data,
                timeout=30  # Set a reasonable timeout
            )
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")

    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.

        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        # Return True if your LLM supports function calling
        return True

    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.

        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        # Return True if your LLM supports stop words
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.

        Returns:
            The context window size as an integer.
        """
        # Return the context window size of your LLM
        return 8192
```
## Error Handling Best Practices
When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices:
### 1. Implement Try-Except Blocks for API Calls
Always wrap API calls in try-except blocks to handle different types of errors:
```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        # API call implementation.
        # `self.headers` and `self.prepare_payload()` are helpers you define on
        # your class (see the full example above for a concrete version).
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),
            timeout=30  # Set a reasonable timeout
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")
```
### 2. Implement Retry Logic for Transient Failures
For transient failures like network issues or rate limiting, implement retry logic with exponential backoff:
```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import time

    max_retries = 3
    retry_delay = 1  # seconds
    for attempt in range(max_retries):
        try:
            response = requests.post(
                self.endpoint,
                headers=self.headers,
                json=self.prepare_payload(messages),
                timeout=30
            )
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]
        except (requests.Timeout, requests.ConnectionError) as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (2 ** attempt))  # Exponential backoff
                continue
            raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
```
### 3. Validate Input Parameters
Always validate input parameters to prevent runtime errors:
```python
def __init__(self, api_key: str, endpoint: str):
    super().__init__()
    if not api_key or not isinstance(api_key, str):
        raise ValueError("Invalid API key: must be a non-empty string")
    if not endpoint or not isinstance(endpoint, str):
        raise ValueError("Invalid endpoint URL: must be a non-empty string")
    self.api_key = api_key
    self.endpoint = endpoint
```
### 4. Handle Authentication Errors Gracefully
Provide clear error messages for authentication failures:
```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        # `self.headers` and `data` are prepared as in the earlier examples.
        response = requests.post(self.endpoint, headers=self.headers, json=data)
        if response.status_code == 401:
            raise ValueError("Authentication failed: Invalid API key or token")
        elif response.status_code == 403:
            raise ValueError("Authorization failed: Insufficient permissions")
        response.raise_for_status()
        # Process response
    except Exception as e:
        # Handle error
        raise
```
## Example: JWT-based Authentication
For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:
```python
from crewai import LLM
from typing import Any, Dict, List, Optional, Union


class JWTAuthLLM(LLM):
    def __init__(self, jwt_token: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not jwt_token or not isinstance(jwt_token, str):
            raise ValueError("Invalid JWT token: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.jwt_token = jwt_token
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with JWT authentication.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            Either a text response from the LLM or the result of a tool function call.

        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid or authentication fails.
        """
        # Implement your own logic to call the LLM with JWT authentication
        import requests

        try:
            headers = {
                "Authorization": f"Bearer {self.jwt_token}",
                "Content-Type": "application/json"
            }
            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]
            data = {
                "messages": messages,
                "tools": tools
            }
            response = requests.post(
                self.endpoint,
                headers=headers,
                json=data,
                timeout=30  # Set a reasonable timeout
            )
            if response.status_code == 401:
                raise ValueError("Authentication failed: Invalid JWT token")
            elif response.status_code == 403:
                raise ValueError("Authorization failed: Insufficient permissions")
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError) as e:
            # ValueError is deliberately not caught here so the authentication and
            # authorization errors raised above propagate with their original message.
            raise ValueError(f"Invalid response format: {str(e)}")

    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.

        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        return True

    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.

        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.

        Returns:
            The context window size as an integer.
        """
        return 8192
```
## Troubleshooting
Here are some common issues you might encounter when implementing custom LLMs and how to resolve them:
### 1. Authentication Failures
**Symptoms**: 401 Unauthorized or 403 Forbidden errors
**Solutions**:
- Verify that your API key or JWT token is valid and not expired
- Check that you're using the correct authentication header format
- Ensure that your token has the necessary permissions (a quick local expiry check is sketched below)
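For JWT-based setups, one way to catch the expired-token case before it turns into a 401 is to inspect the token's `exp` claim locally, as the following sketch does with PyJWT (this assumes the `pyjwt` package is installed and deliberately skips signature verification, since only the claims are needed here):
```python
import time

import jwt  # PyJWT


def ensure_token_not_expired(jwt_token: str, buffer_seconds: int = 60) -> None:
    """Raise ValueError if the JWT is expired or about to expire."""
    # Decode without verifying the signature; we only need the claims locally.
    decoded = jwt.decode(jwt_token, options={"verify_signature": False})
    exp = decoded.get("exp")
    if exp is not None and exp < time.time() + buffer_seconds:
        raise ValueError("JWT token is expired or about to expire; refresh it before calling the LLM")
```
You could call a helper like this at the top of `call()`; the test suite added in this PR uses the same decode-without-verification approach to validate tokens.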
### 2. Timeout Issues
**Symptoms**: Requests taking too long or timing out
**Solutions**:
- Implement timeout handling as shown in the examples
- Use retry logic with exponential backoff
- Consider using a more reliable network connection
### 3. Response Parsing Errors
**Symptoms**: KeyError, IndexError, or ValueError when processing responses
**Solutions**:
- Validate the response format before accessing nested fields
- Implement proper error handling for malformed responses
- Check the API documentation for the expected response format (a defensive parsing helper is sketched below)
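As a concrete illustration, here is a small helper (assuming the OpenAI-style `choices[0].message.content` shape used throughout this guide) that validates the payload before indexing into it:
```python
from typing import Any, Dict


def extract_content(response_data: Dict[str, Any]) -> str:
    """Pull the message content out of an OpenAI-style response, failing loudly."""
    choices = response_data.get("choices")
    if not isinstance(choices, list) or not choices:
        raise ValueError(f"Invalid response: missing 'choices' list: {response_data!r}")
    message = choices[0].get("message")
    if not isinstance(message, dict) or "content" not in message:
        raise ValueError(f"Invalid response: missing 'message.content': {choices[0]!r}")
    return message["content"]
```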
### 4. Rate Limiting
**Symptoms**: 429 Too Many Requests errors
**Solutions**:
- Implement rate limiting in your custom LLM
- Add exponential backoff for retries
- Consider using a token bucket algorithm for more precise rate control, as sketched below
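The sliding-window limiter shown under Advanced Features below caps requests per minute; a token bucket is an alternative that smooths bursts more precisely. A minimal sketch (the class name and parameters are illustrative, not part of CrewAI):
```python
import time


class TokenBucket:
    """Simple token bucket: refills at `rate` tokens per second, up to `capacity`."""

    def __init__(self, rate: float, capacity: int):
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)
        self.last_refill = time.monotonic()

    def acquire(self) -> None:
        """Block until a token is available, then consume it."""
        while True:
            now = time.monotonic()
            # Refill based on elapsed time, capped at capacity.
            self.tokens = min(self.capacity, self.tokens + (now - self.last_refill) * self.rate)
            self.last_refill = now
            if self.tokens >= 1:
                self.tokens -= 1
                return
            # Sleep just long enough for the next token to become available.
            time.sleep((1 - self.tokens) / self.rate)
```
You would construct one bucket per provider and call `acquire()` at the top of `call()` before issuing the HTTP request.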
## Advanced Features
### Logging
Adding logging to your custom LLM can help with debugging and monitoring:
```python
import logging
from typing import Any, Dict, List, Optional, Union

from crewai import LLM


class LoggingLLM(LLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.logger = logging.getLogger("crewai.llm.custom")

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages")
        try:
            # API call implementation; _make_api_call() is a helper you define
            # (for example, the requests-based logic from the earlier examples).
            response = self._make_api_call(messages, tools)
            self.logger.debug(f"LLM response received: {response[:100]}...")
            return response
        except Exception as e:
            self.logger.error(f"LLM call failed: {str(e)}")
            raise
```
### Rate Limiting
Implementing rate limiting can help avoid overwhelming the LLM API:
```python
import time
from typing import Any, Dict, List, Optional, Union

from crewai import LLM


class RateLimitedLLM(LLM):
    def __init__(
        self,
        api_key: str,
        endpoint: str,
        requests_per_minute: int = 60
    ):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.requests_per_minute = requests_per_minute
        self.request_times: List[float] = []

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self._enforce_rate_limit()
        # Record this request time
        self.request_times.append(time.time())
        # Make the actual API call (_make_api_call() is a helper you define)
        return self._make_api_call(messages, tools)

    def _enforce_rate_limit(self) -> None:
        """Enforce the rate limit by waiting if necessary."""
        now = time.time()
        # Remove request times older than 1 minute
        self.request_times = [t for t in self.request_times if now - t < 60]
        if len(self.request_times) >= self.requests_per_minute:
            # Calculate how long to wait
            oldest_request = min(self.request_times)
            wait_time = 60 - (now - oldest_request)
            if wait_time > 0:
                time.sleep(wait_time)
```
### Metrics Collection
Collecting metrics can help you monitor your LLM usage:
```python
import time
from typing import Any, Dict, List, Optional, Union

from crewai import LLM


class MetricsCollectingLLM(LLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.metrics: Dict[str, Any] = {
            "total_calls": 0,
            "total_tokens": 0,
            "errors": 0,
            "latency": []
        }

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        start_time = time.time()
        self.metrics["total_calls"] += 1
        try:
            # _make_api_call() is a helper you define (see the earlier examples)
            response = self._make_api_call(messages, tools)
            # Estimate tokens (simplified)
            if isinstance(messages, str):
                token_estimate = len(messages) // 4
            else:
                token_estimate = sum(len(m.get("content", "")) // 4 for m in messages)
            self.metrics["total_tokens"] += token_estimate
            return response
        except Exception:
            self.metrics["errors"] += 1
            raise
        finally:
            latency = time.time() - start_time
            self.metrics["latency"].append(latency)

    def get_metrics(self) -> Dict[str, Any]:
        """Return the collected metrics."""
        avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0
        return {
            **self.metrics,
            "avg_latency": avg_latency
        }
```
## Advanced Usage: Function Calling
If your LLM supports function calling, you can implement the function calling logic in your custom LLM:
```python
import json
from typing import Any, Dict, List, Optional, Union


# Implemented as a method on your custom LLM class (e.g. the JWTAuthLLM shown above).
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import requests

    try:
        headers = {
            "Authorization": f"Bearer {self.jwt_token}",
            "Content-Type": "application/json"
        }
        # Convert string message to proper format if needed
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        data = {
            "messages": messages,
            "tools": tools
        }
        response = requests.post(
            self.endpoint,
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()
        response_data = response.json()
        # Check if the LLM wants to call a function
        if response_data["choices"][0]["message"].get("tool_calls"):
            tool_calls = response_data["choices"][0]["message"]["tool_calls"]
            # Process each tool call
            for tool_call in tool_calls:
                function_name = tool_call["function"]["name"]
                function_args = json.loads(tool_call["function"]["arguments"])
                if available_functions and function_name in available_functions:
                    function_to_call = available_functions[function_name]
                    function_response = function_to_call(**function_args)
                    # Add the function response to the messages
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": function_name,
                        "content": str(function_response)
                    })
            # Call the LLM again with the updated messages
            return self.call(messages, tools, callbacks, available_functions)
        # Return the text response if no function call
        return response_data["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")
```
## Using Your Custom LLM with CrewAI
Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:
```python
from crewai import Agent, Task, Crew

# Create your custom LLM instance (JWTAuthLLM is the class defined above)
jwt_llm = JWTAuthLLM(
    jwt_token="your.jwt.token",
    endpoint="https://your-llm-endpoint.com/v1/chat/completions"
)

# Use it with an agent
agent = Agent(
    role="Research Assistant",
    goal="Find information on a topic",
    backstory="You are a research assistant tasked with finding information.",
    llm=jwt_llm,
)

# Create a task for the agent
task = Task(
    description="Research the benefits of exercise",
    agent=agent,
    expected_output="A summary of the benefits of exercise",
)

# Execute the task
result = agent.execute_task(task)
print(result)

# Or use it with a crew
crew = Crew(
    agents=[agent],
    tasks=[task],
    manager_llm=jwt_llm,  # Use your custom LLM for the manager
)

# Run the crew
result = crew.kickoff()
print(result)
```
## Implementing Your Own Authentication Mechanism
The `LLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
- OAuth tokens
- Client certificates
- Custom headers
- Session-based authentication
- Any other authentication method required by your LLM provider
Simply implement the appropriate authentication logic in your custom LLM class.
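For instance, a provider that expects custom headers plus session-based authentication could be wired up roughly like this sketch (the header names, endpoint, and response shape are illustrative assumptions, not a real CrewAI or provider API):
```python
from typing import Any, Dict, List, Optional, Union

import requests

from crewai import LLM


class CustomHeaderLLM(LLM):
    """Sketch of session-based auth with custom headers (illustrative only)."""

    def __init__(self, client_id: str, client_secret: str, endpoint: str):
        super().__init__()
        self.endpoint = endpoint
        # A persistent session carries cookies and custom headers across calls.
        self.session = requests.Session()
        # Hypothetical header names; substitute whatever your provider requires.
        self.session.headers.update({
            "X-Client-Id": client_id,
            "X-Client-Secret": client_secret,
            "Content-Type": "application/json",
        })

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        response = self.session.post(
            self.endpoint,
            json={"messages": messages, "tools": tools},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def supports_function_calling(self) -> bool:
        return False

    def supports_stop_words(self) -> bool:
        return False

    def get_context_window_size(self) -> int:
        return 8192
```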
## Migrating from BaseLLM to LLM
If you were previously using `BaseLLM`, you can simply replace it with `LLM`:
```python
# Old code
from crewai import BaseLLM

class CustomLLM(BaseLLM):
    ...  # your implementation


# New code
from crewai import LLM

class CustomLLM(LLM):
    ...  # your implementation
```
The `BaseLLM` class is still available for backward compatibility but will be removed in a future release. It now inherits from `LLM` and emits a deprecation warning when instantiated.

View File

@@ -4,7 +4,7 @@ from crewai.agent import Agent
from crewai.crew import Crew
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM, DefaultLLM
from crewai.process import Process
from crewai.task import Task
@@ -21,6 +21,8 @@ __all__ = [
"Process",
"Task",
"LLM",
"BaseLLM",
"DefaultLLM",
"Flow",
"Knowledge",
]

View File

@@ -11,7 +11,7 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.task import Task
from crewai.tools import BaseTool
@@ -70,10 +70,10 @@ class Agent(BaseAgent):
default=True,
description="Use system prompt for the agent.",
)
llm: Union[str, InstanceOf[LLM], Any] = Field(
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
description="Language model that will run the agent.", default=None
)
system_template: Optional[str] = Field(
@@ -116,9 +116,16 @@ class Agent(BaseAgent):
def post_init_setup(self):
self.agent_ops_agent_name = self.role
self.llm = create_llm(self.llm)
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
self.function_calling_llm = create_llm(self.function_calling_llm)
try:
self.llm = create_llm(self.llm)
except Exception as e:
raise RuntimeError(f"Failed to initialize LLM for agent '{self.role}': {str(e)}")
if self.function_calling_llm and not isinstance(self.function_calling_llm, BaseLLM):
try:
self.function_calling_llm = create_llm(self.function_calling_llm)
except Exception as e:
raise RuntimeError(f"Failed to initialize function calling LLM for agent '{self.role}': {str(e)}")
if not self.agent_executor:
self._setup_agent_executor()

View File

@@ -14,7 +14,7 @@ from packaging import version
from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.crew import Crew
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM
from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm
@@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
print()
def initialize_chat_llm(crew: Crew) -> Optional[LLM]:
def initialize_chat_llm(crew: Crew) -> Optional[BaseLLM]:
"""Initializes the chat LLM and handles exceptions."""
try:
return create_llm(crew.chat_llm)
@@ -220,7 +220,7 @@ def get_user_input() -> str:
def handle_user_input(
user_input: str,
chat_llm: LLM,
chat_llm: BaseLLM,
messages: List[Dict[str, str]],
crew_tool_schema: Dict[str, Any],
available_functions: Dict[str, Any],

View File

@@ -6,8 +6,9 @@ import warnings
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast
from langchain_core.tools import BaseTool as LangchainBaseTool
from pydantic import (
UUID4,
BaseModel,
@@ -26,7 +27,7 @@ from crewai.agents.cache import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
@@ -36,7 +37,7 @@ from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import Tool
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINING_DATA_FILE
@@ -150,14 +151,14 @@ class Crew(BaseModel):
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: Optional[Any] = Field(
manager_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: Optional[BaseAgent] = Field(
description="Custom agent that will be used as manager.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
description="Language model that will run the agent.", default=None
description="Language model that will be used for function calling.", default=None
)
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
@@ -184,7 +185,7 @@ class Crew(BaseModel):
default=None,
description="Maximum number of requests per minute for the crew execution to be respected.",
)
prompt_file: str = Field(
prompt_file: Optional[str] = Field(
default=None,
description="Path to the prompt json file to be used for the crew.",
)
@@ -196,7 +197,7 @@ class Crew(BaseModel):
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: Optional[Any] = Field(
planning_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
default=None,
description="Language model that will run the AgentPlanner if planning is True.",
)
@@ -212,7 +213,7 @@ class Crew(BaseModel):
default=None,
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
)
chat_llm: Optional[Any] = Field(
chat_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
@@ -798,7 +799,8 @@ class Crew(BaseModel):
# Determine which tools to use - task tools take precedence over agent tools
tools_for_task = task.tools or agent_to_use.tools or []
tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task)
# Prepare tools and ensure they're compatible with task execution
tools_for_task = self._prepare_tools(agent_to_use, task, cast(Union[List[Tool], List[BaseTool]], tools_for_task))
self._log_task_start(task, agent_to_use.role)
@@ -817,7 +819,7 @@ class Crew(BaseModel):
future = task.execute_async(
agent=agent_to_use,
context=context,
tools=tools_for_task,
tools=cast(List[BaseTool], tools_for_task),
)
futures.append((task, future, task_index))
else:
@@ -829,7 +831,7 @@ class Crew(BaseModel):
task_output = task.execute_sync(
agent=agent_to_use,
context=context,
tools=tools_for_task,
tools=cast(List[BaseTool], tools_for_task),
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
@@ -867,10 +869,10 @@ class Crew(BaseModel):
return None
def _prepare_tools(
self, agent: BaseAgent, task: Task, tools: List[Tool]
) -> List[Tool]:
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
# Add delegation tools if agent allows delegation
if agent.allow_delegation:
if hasattr(agent, "allow_delegation") and getattr(agent, "allow_delegation", False):
if self.process == Process.hierarchical:
if self.manager_agent:
tools = self._update_manager_tools(task, tools)
@@ -879,17 +881,18 @@ class Crew(BaseModel):
"Manager agent is required for hierarchical process."
)
elif agent and agent.allow_delegation:
elif agent:
tools = self._add_delegation_tools(task, tools)
# Add code execution tools if agent allows code execution
if agent.allow_code_execution:
if hasattr(agent, "allow_code_execution") and getattr(agent, "allow_code_execution", False):
tools = self._add_code_execution_tools(agent, tools)
if agent and agent.multimodal:
if agent and hasattr(agent, "multimodal") and getattr(agent, "multimodal", False):
tools = self._add_multimodal_tools(agent, tools)
return tools
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
return cast(List[BaseTool], tools)
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
if self.process == Process.hierarchical:
@@ -897,11 +900,11 @@ class Crew(BaseModel):
return task.agent
def _merge_tools(
self, existing_tools: List[Tool], new_tools: List[Tool]
) -> List[Tool]:
self, existing_tools: Union[List[Tool], List[BaseTool]], new_tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
if not new_tools:
return existing_tools
return cast(List[BaseTool], existing_tools)
# Create mapping of tool names to new tools
new_tool_map = {tool.name: tool for tool in new_tools}
@@ -912,23 +915,32 @@ class Crew(BaseModel):
# Add all new tools
tools.extend(new_tools)
return tools
return cast(List[BaseTool], tools)
def _inject_delegation_tools(
self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent]
):
delegation_tools = task_agent.get_delegation_tools(agents)
return self._merge_tools(tools, delegation_tools)
self, tools: Union[List[Tool], List[BaseTool]], task_agent: BaseAgent, agents: List[BaseAgent]
) -> List[BaseTool]:
if hasattr(task_agent, "get_delegation_tools"):
delegation_tools = task_agent.get_delegation_tools(agents)
# Cast delegation_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
return cast(List[BaseTool], tools)
def _add_multimodal_tools(self, agent: BaseAgent, tools: List[Tool]):
multimodal_tools = agent.get_multimodal_tools()
return self._merge_tools(tools, multimodal_tools)
def _add_multimodal_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
if hasattr(agent, "get_multimodal_tools"):
multimodal_tools = agent.get_multimodal_tools()
# Cast multimodal_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
return cast(List[BaseTool], tools)
def _add_code_execution_tools(self, agent: BaseAgent, tools: List[Tool]):
code_tools = agent.get_code_execution_tools()
return self._merge_tools(tools, code_tools)
def _add_code_execution_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
if hasattr(agent, "get_code_execution_tools"):
code_tools = agent.get_code_execution_tools()
# Cast code_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
return cast(List[BaseTool], tools)
def _add_delegation_tools(self, task: Task, tools: List[Tool]):
def _add_delegation_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
if not tools:
@@ -936,7 +948,7 @@ class Crew(BaseModel):
tools = self._inject_delegation_tools(
tools, task.agent, agents_for_delegation
)
return tools
return cast(List[BaseTool], tools)
def _log_task_start(self, task: Task, role: str = "None"):
if self.output_log_file:
@@ -944,7 +956,7 @@ class Crew(BaseModel):
task_name=task.name, task=task.description, agent=role, status="started"
)
def _update_manager_tools(self, task: Task, tools: List[Tool]):
def _update_manager_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
if self.manager_agent:
if task.agent:
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
@@ -952,7 +964,7 @@ class Crew(BaseModel):
tools = self._inject_delegation_tools(
tools, self.manager_agent, self.agents
)
return tools
return cast(List[BaseTool], tools)
def _get_context(self, task: Task, task_outputs: List[TaskOutput]):
context = (
@@ -1198,21 +1210,27 @@ class Crew(BaseModel):
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
try:
eval_llm = create_llm(eval_llm)
if not eval_llm:
# Create LLM instance and ensure it's of type LLM for CrewEvaluator
llm_instance = create_llm(eval_llm)
if not llm_instance:
raise ValueError("Failed to create LLM instance.")
# Ensure we have an LLM instance (not just BaseLLM) for CrewEvaluator
from crewai.llm import LLM
if not isinstance(llm_instance, LLM):
raise TypeError("CrewEvaluator requires an LLM instance, not a BaseLLM instance.")
crewai_event_bus.emit(
self,
CrewTestStartedEvent(
crew_name=self.name or "crew",
n_iterations=n_iterations,
eval_llm=eval_llm,
eval_llm=llm_instance,
inputs=inputs,
),
)
test_crew = self.copy()
evaluator = CrewEvaluator(test_crew, eval_llm) # type: ignore[arg-type]
evaluator = CrewEvaluator(test_crew, llm_instance)
for i in range(1, n_iterations + 1):
evaluator.set_iteration(i)

View File

@@ -4,6 +4,7 @@ import os
import sys
import threading
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
@@ -34,6 +35,223 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
load_dotenv()
class LLM(ABC):
"""Base class for LLM implementations.
This class defines the interface that all LLM implementations must follow.
Users can extend this class to create custom LLM implementations that don't
rely on litellm's authentication mechanism.
Custom LLM implementations should handle error cases gracefully, including
timeouts, authentication failures, and malformed responses. They should also
implement proper validation for input parameters and provide clear error
messages when things go wrong.
Attributes:
stop (list): A list of stop sequences that the LLM should use to stop generation.
This is used by the CrewAgentExecutor and other components.
"""
def __new__(cls, *args, **kwargs):
"""Create a new LLM instance.
This method handles backward compatibility by creating a DefaultLLM instance
when the LLM class is instantiated directly with parameters.
Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
Either a new LLM instance or a DefaultLLM instance for backward compatibility.
"""
if cls is LLM and (args or kwargs.get('model') is not None):
# Import locally to avoid circular imports
# This is safe because DefaultLLM is defined later in this file
DefaultLLM = globals().get('DefaultLLM')
if DefaultLLM is None:
# If DefaultLLM is not yet defined, return a placeholder
# that will be replaced with a real DefaultLLM instance later
return object.__new__(cls)
return DefaultLLM(*args, **kwargs)
return super().__new__(cls)
def __init__(self):
"""Initialize the LLM with default attributes.
This constructor sets default values for attributes that are expected
by the CrewAgentExecutor and other components.
All custom LLM implementations should call super().__init__() to ensure
that these default attributes are properly initialized.
"""
self.stop = []
@classmethod
def create(
cls,
model: str,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
**kwargs,
) -> 'DefaultLLM':
"""Create a default LLM instance using litellm.
This factory method creates a default LLM instance using litellm as the backend.
It's the recommended way to create LLM instances for most users.
Args:
model: The model name (e.g., "gpt-4").
timeout: Optional timeout for the LLM call.
temperature: Optional temperature for the LLM call.
top_p: Optional top_p for the LLM call.
n: Optional n for the LLM call.
stop: Optional stop sequences for the LLM call.
max_completion_tokens: Optional max_completion_tokens for the LLM call.
max_tokens: Optional max_tokens for the LLM call.
presence_penalty: Optional presence_penalty for the LLM call.
frequency_penalty: Optional frequency_penalty for the LLM call.
logit_bias: Optional logit_bias for the LLM call.
response_format: Optional response_format for the LLM call.
seed: Optional seed for the LLM call.
logprobs: Optional logprobs for the LLM call.
top_logprobs: Optional top_logprobs for the LLM call.
base_url: Optional base_url for the LLM call.
api_base: Optional api_base for the LLM call.
api_version: Optional api_version for the LLM call.
api_key: Optional api_key for the LLM call.
callbacks: Optional callbacks for the LLM call.
reasoning_effort: Optional reasoning_effort for the LLM call.
**kwargs: Additional keyword arguments for the LLM call.
Returns:
A DefaultLLM instance configured with the provided parameters.
"""
from crewai.llm import DefaultLLM
return DefaultLLM(
model=model,
timeout=timeout,
temperature=temperature,
top_p=top_p,
n=n,
stop=stop,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
response_format=response_format,
seed=seed,
logprobs=logprobs,
top_logprobs=top_logprobs,
base_url=base_url,
api_base=api_base,
api_version=api_version,
api_key=api_key,
callbacks=callbacks,
reasoning_effort=reasoning_effort,
**kwargs,
)
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with the given messages.
Args:
messages: Input messages for the LLM.
Can be a string or list of message dictionaries.
If string, it will be converted to a single user message.
If list, each dict must have 'role' and 'content' keys.
tools: Optional list of tool schemas for function calling.
Each tool should define its name, description, and parameters.
callbacks: Optional list of callback functions to be executed
during and after the LLM call.
available_functions: Optional dict mapping function names to callables
that can be invoked by the LLM.
Returns:
Either a text response from the LLM (str) or
the result of a tool function call (Any).
Raises:
ValueError: If the messages format is invalid.
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
NotImplementedError: If this method is not implemented by a subclass.
"""
raise NotImplementedError("Subclasses must implement call()")
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
This method should return True if the LLM implementation supports
function calling (tools), and False otherwise. If this method returns
True, the LLM should be able to handle the 'tools' parameter in the
call() method.
Returns:
True if the LLM supports function calling, False otherwise.
Raises:
NotImplementedError: If this method is not implemented by a subclass.
"""
raise NotImplementedError("Subclasses must implement supports_function_calling()")
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
This method should return True if the LLM implementation supports
stop words, and False otherwise. If this method returns True, the
LLM should respect the 'stop' attribute when generating responses.
Returns:
True if the LLM supports stop words, False otherwise.
Raises:
NotImplementedError: If this method is not implemented by a subclass.
"""
raise NotImplementedError("Subclasses must implement supports_stop_words()")
def get_context_window_size(self) -> int:
"""Get the context window size of the LLM.
This method should return the maximum number of tokens that the LLM
can process in a single request. This is used by CrewAI to ensure
that messages don't exceed the LLM's context window.
Returns:
The context window size as an integer.
Raises:
NotImplementedError: If this method is not implemented by a subclass.
"""
raise NotImplementedError("Subclasses must implement get_context_window_size()")
class FilteredStream:
def __init__(self, original_stream):
self._original_stream = original_stream
@@ -126,7 +344,14 @@ def suppress_warnings():
sys.stderr = old_stderr
class LLM:
class DefaultLLM(LLM):
"""Default LLM implementation using litellm.
This class provides a concrete implementation of the LLM interface
using litellm as the backend. It's the default implementation used
by CrewAI when no custom LLM is provided.
"""
def __init__(
self,
model: str,
@@ -152,6 +377,8 @@ class LLM:
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
**kwargs,
):
super().__init__() # Initialize the base class
self.model = model
self.timeout = timeout
self.temperature = temperature
@@ -180,7 +407,7 @@ class LLM:
# Normalize self.stop to always be a List[str]
if stop is None:
self.stop: List[str] = []
self.stop = [] # Already initialized in base class
elif isinstance(stop, str):
self.stop = [stop]
else:
@@ -564,3 +791,27 @@ class LLM:
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
class BaseLLM(LLM):
"""Deprecated: Use LLM instead.
This class is kept for backward compatibility and will be removed in a future release.
It inherits from LLM and provides the same interface, but emits a deprecation warning
when instantiated.
"""
def __init__(self):
"""Initialize the BaseLLM with a deprecation warning.
This constructor emits a deprecation warning and then calls the parent class's
constructor to initialize the LLM.
"""
import warnings
warnings.warn(
"BaseLLM is deprecated and will be removed in a future release. "
"Use LLM instead for custom implementations.",
DeprecationWarning,
stacklevel=2
)
super().__init__()

View File

@@ -2,7 +2,7 @@ import os
from typing import Any, Dict, List, Optional, Union
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM
def create_llm(
@@ -19,17 +19,17 @@ def create_llm(
- None: Use environment-based or fallback default model.
Returns:
An LLM instance if successful, or None if something fails.
A LLM instance if successful, or None if something fails.
"""
# 1) If llm_value is already an LLM object, return it directly
# 1) If llm_value is already a LLM object, return it directly
if isinstance(llm_value, LLM):
return llm_value
# 2) If llm_value is a string (model name)
if isinstance(llm_value, str):
try:
created_llm = LLM(model=llm_value)
created_llm = LLM.create(model=llm_value)
return created_llm
except Exception as e:
print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
@@ -56,7 +56,7 @@ def create_llm(
base_url: Optional[str] = getattr(llm_value, "base_url", None)
api_base: Optional[str] = getattr(llm_value, "api_base", None)
created_llm = LLM(
created_llm = LLM.create(
model=model,
temperature=temperature,
max_tokens=max_tokens,
@@ -175,7 +175,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
# Try creating the LLM
try:
new_llm = LLM(**llm_params)
new_llm = LLM.create(**llm_params)
return new_llm
except Exception as e:
print(

View File

@@ -0,0 +1,89 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are test role. test backstory\nYour
personal goal is: test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Test task\n\nThis
is the expected criteria for your final answer: Test output\nyou MUST return
the actual complete content as the final answer, not a summary.\n\nBegin! This
is VERY important to you, use the tools available and give your best Final Answer,
your job depends on it!\n\nThought:"}], "model": "gpt-4", "stop": ["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '805'
content-type:
- application/json
cookie:
- _cfuvid=xecEkmr_qTiKn7EKC7aeGN5bpsbPM9ofyIsipL4VCYM-1734033219265-0.0.1.1-604800000
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.61.0
x-stainless-arch:
- x64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- Linux
x-stainless-package-version:
- 1.61.0
x-stainless-raw-response:
- 'true'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.7
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
content: "{\n \"error\": {\n \"message\": \"Incorrect API key provided:
sk-proj-********************************************************************************************************************************************************sLcA.
You can find your API key at https://platform.openai.com/account/api-keys.\",\n
\ \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\":
\"invalid_api_key\"\n }\n}\n"
headers:
CF-RAY:
- 9201beec18a0762e-SEA
Connection:
- keep-alive
Content-Length:
- '414'
Content-Type:
- application/json; charset=utf-8
Date:
- Fri, 14 Mar 2025 06:34:31 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=wF6OyTyATDK7A9tGqAdaSB3QZfmd34JWPicYlDC1hug-1741934071-1.0.1.1-nZThPWX_7A9FsU7Z14PyrVhl6mCD99iuk9ujCFkNCCdepMHEwK9EXoDrP4IBBCXxkXmKjrVTSaQ63zpcociXuMHR8JKhth2fRUV2H4hMldY;
path=/; expires=Fri, 14-Mar-25 07:04:31 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=rn5IWZdYMRmbyCa2_84MkWO46MIaP6soWc8npaLc9iQ-1741934071787-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
X-Content-Type-Options:
- nosniff
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
vary:
- Origin
x-request-id:
- req_f55471c8eb5755daaef3d63eab5a95de
http_version: HTTP/1.1
status_code: 401
version: 1

570
tests/custom_llm_test.py Normal file
View File

@@ -0,0 +1,570 @@
from collections import deque
from typing import Any, Dict, List, Optional, Union
import time
import jwt
import pytest
from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm
class CustomLLM(LLM):
"""Custom LLM implementation for testing.
This is a simple implementation of the LLM abstract base class
that returns a predefined response for testing purposes.
"""
def __init__(self, response: str = "Custom LLM response"):
"""Initialize the CustomLLM with a predefined response.
Args:
response: The predefined response to return from call().
"""
super().__init__()
self.response = response
self.calls = []
self.stop = []
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Record the call and return the predefined response.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
The predefined response string.
"""
self.calls.append({
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions
})
return self.response
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported.
Returns:
True, indicating that this LLM supports function calling.
"""
return True
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported.
Returns:
True, indicating that this LLM supports stop words.
"""
return True
def get_context_window_size(self) -> int:
"""Return a default context window size.
Returns:
8192, a typical context window size for modern LLMs.
"""
return 8192
def test_custom_llm_implementation():
"""Test that a custom LLM implementation works with create_llm."""
custom_llm = CustomLLM(response="The answer is 42")
# Test that create_llm returns the custom LLM instance directly
result_llm = create_llm(custom_llm)
assert result_llm is custom_llm
# Test calling the custom LLM
response = result_llm.call("What is the answer to life, the universe, and everything?")
# Verify that the custom LLM was called
assert len(custom_llm.calls) > 0
# Verify that the response from the custom LLM was used
assert response == "The answer is 42"
class JWTAuthLLM(LLM):
"""Custom LLM implementation with JWT authentication.
This class demonstrates how to implement a custom LLM that uses JWT
authentication instead of API key-based authentication. It validates
the JWT token before each call and checks for token expiration.
"""
def __init__(self, jwt_token: str, expiration_buffer: int = 60):
"""Initialize the JWTAuthLLM with a JWT token.
Args:
jwt_token: The JWT token to use for authentication.
expiration_buffer: Buffer time in seconds to warn about token expiration.
Default is 60 seconds.
Raises:
ValueError: If the JWT token is invalid or missing.
"""
super().__init__()
if not jwt_token or not isinstance(jwt_token, str):
raise ValueError("Invalid JWT token")
self.jwt_token = jwt_token
self.expiration_buffer = expiration_buffer
self.calls = []
self.stop = []
# Validate the token immediately
self._validate_token()
def _validate_token(self) -> None:
"""Validate the JWT token.
Checks if the token is valid and not expired. Also warns if the token
is about to expire within the expiration_buffer time.
Raises:
ValueError: If the token is invalid, expired, or malformed.
"""
try:
# Decode without verification to check expiration
# In a real implementation, you would verify the signature
decoded = jwt.decode(self.jwt_token, options={"verify_signature": False})
# Check if token is expired or about to expire
if 'exp' in decoded:
expiration_time = decoded['exp']
current_time = time.time()
if expiration_time < current_time:
raise ValueError("JWT token has expired")
if expiration_time < current_time + self.expiration_buffer:
# Token will expire soon, log a warning
import logging
logging.warning(f"JWT token will expire in {expiration_time - current_time} seconds")
except jwt.PyJWTError as e:
raise ValueError(f"Invalid JWT token format: {str(e)}")
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with JWT authentication.
Validates the JWT token before making the call to ensure it's still valid.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
The LLM response.
Raises:
ValueError: If the JWT token is invalid or expired.
TimeoutError: If the request times out.
ConnectionError: If there's a connection issue.
"""
# Validate token before making the call
self._validate_token()
self.calls.append({
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions
})
# In a real implementation, this would use the JWT token to authenticate
# with an external service
return "Response from JWT-authenticated LLM"
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported."""
return True
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported."""
return True
def get_context_window_size(self) -> int:
"""Return a default context window size."""
return 8192
def test_custom_llm_with_jwt_auth():
"""Test a custom LLM implementation with JWT authentication."""
# Create a valid JWT token that expires 1 hour from now
valid_token = jwt.encode(
{"exp": int(time.time()) + 3600},
"secret",
algorithm="HS256"
)
jwt_llm = JWTAuthLLM(jwt_token=valid_token)
# Test that create_llm returns the JWT-authenticated LLM instance directly
result_llm = create_llm(jwt_llm)
assert result_llm is jwt_llm
# Test calling the JWT-authenticated LLM
response = result_llm.call("Test message")
# Verify that the JWT-authenticated LLM was called
assert len(jwt_llm.calls) > 0
# Verify that the response from the JWT-authenticated LLM was used
assert response == "Response from JWT-authenticated LLM"
def test_jwt_auth_llm_validation():
"""Test that JWT token validation works correctly."""
# Test with invalid JWT token (empty string)
with pytest.raises(ValueError, match="Invalid JWT token"):
JWTAuthLLM(jwt_token="")
# Test with invalid JWT token (non-string)
with pytest.raises(ValueError, match="Invalid JWT token"):
JWTAuthLLM(jwt_token=None)
# Test with expired token
# Create a token that expired 1 hour ago
expired_token = jwt.encode(
{"exp": int(time.time()) - 3600},
"secret",
algorithm="HS256"
)
with pytest.raises(ValueError, match="JWT token has expired"):
JWTAuthLLM(jwt_token=expired_token)
# Test with malformed token
with pytest.raises(ValueError, match="Invalid JWT token format"):
JWTAuthLLM(jwt_token="not.a.valid.jwt.token")
# Test with valid token
# Create a token that expires 1 hour from now
valid_token = jwt.encode(
{"exp": int(time.time()) + 3600},
"secret",
algorithm="HS256"
)
# This should not raise an exception
jwt_llm = JWTAuthLLM(jwt_token=valid_token)
assert jwt_llm.jwt_token == valid_token
class TimeoutHandlingLLM(LLM):
"""Custom LLM implementation with timeout handling and retry logic."""
def __init__(self, max_retries: int = 3, timeout: int = 30):
"""Initialize the TimeoutHandlingLLM with retry and timeout settings.
Args:
max_retries: Maximum number of retry attempts.
timeout: Timeout in seconds for each API call.
"""
super().__init__()
self.max_retries = max_retries
self.timeout = timeout
self.calls = []
self.stop = []
self.fail_count = 0 # Number of times to simulate failure
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Simulate API calls with timeout handling and retry logic.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
A response string based on whether this is the first attempt or a retry.
Raises:
TimeoutError: If all retry attempts fail.
"""
# Record the initial call
self.calls.append({
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
"attempt": 0
})
# Simulate retry logic
for attempt in range(self.max_retries):
# Skip the first attempt recording since we already did that above
if attempt == 0:
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on first attempt
return "First attempt response"
else:
# This is a retry attempt (attempt > 0)
# Always record retry attempts
self.calls.append({
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions
})
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on retry
return "Response after retry"
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported.
Returns:
True, indicating that this LLM supports function calling.
"""
return True
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported.
Returns:
True, indicating that this LLM supports stop words.
"""
return True
def get_context_window_size(self) -> int:
"""Return a default context window size.
Returns:
8192, a typical context window size for modern LLMs.
"""
return 8192
def test_timeout_handling_llm():
"""Test a custom LLM implementation with timeout handling and retry logic."""
# Test successful first attempt
llm = TimeoutHandlingLLM()
response = llm.call("Test message")
assert response == "First attempt response"
assert len(llm.calls) == 1
# Test successful retry
llm = TimeoutHandlingLLM()
llm.fail_count = 1 # Fail once, then succeed
response = llm.call("Test message")
assert response == "Response after retry"
assert len(llm.calls) == 2 # Initial call + successful retry call
# Test failure after all retries
llm = TimeoutHandlingLLM(max_retries=2)
llm.fail_count = 2 # Fail twice, which is all retries
with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
llm.call("Test message")
assert len(llm.calls) == 2 # Initial call + failed retry attempt
def test_rate_limited_llm():
"""Test that rate limiting works correctly."""
# Create a rate limited LLM with a very low limit (2 requests per minute)
llm = RateLimitedLLM(requests_per_minute=2)
# First request should succeed
response1 = llm.call("Test message 1")
assert response1 == "Rate limited response"
assert len(llm.calls) == 1
# Second request should succeed
response2 = llm.call("Test message 2")
assert response2 == "Rate limited response"
assert len(llm.calls) == 2
# Third request should fail due to rate limiting
with pytest.raises(ValueError, match="Rate limit exceeded"):
llm.call("Test message 3")
# Test with invalid requests_per_minute
with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"):
RateLimitedLLM(requests_per_minute=0)
with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"):
RateLimitedLLM(requests_per_minute=-1)
def test_rate_limit_reset():
"""Test that rate limits reset after the time window passes."""
# Create a rate limited LLM with a very low limit (1 request per minute)
# and a short time window for testing (1 second instead of 60 seconds)
time_window = 1 # 1 second instead of 60 seconds
llm = RateLimitedLLM(requests_per_minute=1, time_window=time_window)
# First request should succeed
response1 = llm.call("Test message 1")
assert response1 == "Rate limited response"
# Second request should fail due to rate limiting
with pytest.raises(ValueError, match="Rate limit exceeded"):
llm.call("Test message 2")
# Wait for the rate limit to reset
import time
time.sleep(time_window + 0.1) # Add a small buffer
# After waiting, we should be able to make another request
response3 = llm.call("Test message 3")
assert response3 == "Rate limited response"
assert len(llm.calls) == 2 # First and third requests
class RateLimitedLLM(LLM):
"""Custom LLM implementation with rate limiting.
This class demonstrates how to implement a custom LLM with rate limiting
capabilities. It uses a sliding window algorithm to ensure that no more
than a specified number of requests are made within a given time period.
"""
def __init__(self, requests_per_minute: int = 60, base_response: str = "Rate limited response", time_window: int = 60):
"""Initialize the RateLimitedLLM with rate limiting parameters.
Args:
requests_per_minute: Maximum number of requests allowed per minute.
base_response: Default response to return.
time_window: Time window in seconds for rate limiting (default: 60).
This is configurable for testing purposes.
Raises:
ValueError: If requests_per_minute is not a positive integer.
"""
super().__init__()
if not isinstance(requests_per_minute, int) or requests_per_minute <= 0:
raise ValueError("requests_per_minute must be a positive integer")
self.requests_per_minute = requests_per_minute
self.base_response = base_response
self.time_window = time_window
self.request_times = deque()
self.calls = []
self.stop = []
def _check_rate_limit(self) -> None:
"""Check if the current request exceeds the rate limit.
This method implements a sliding window rate limiting algorithm.
It keeps track of request timestamps and ensures that no more than
`requests_per_minute` requests are made within the configured time window.
Raises:
ValueError: If the rate limit is exceeded.
"""
current_time = time.time()
# Remove requests older than the time window
while self.request_times and current_time - self.request_times[0] > self.time_window:
self.request_times.popleft()
# Check if we've exceeded the rate limit
if len(self.request_times) >= self.requests_per_minute:
wait_time = self.time_window - (current_time - self.request_times[0])
raise ValueError(
f"Rate limit exceeded. Maximum {self.requests_per_minute} "
f"requests per {self.time_window} seconds. Try again in {wait_time:.2f} seconds."
)
# Record this request
self.request_times.append(current_time)
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with rate limiting.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
The LLM response.
Raises:
ValueError: If the rate limit is exceeded.
"""
# Check rate limit before making the call
self._check_rate_limit()
self.calls.append({
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions
})
return self.base_response
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported.
Returns:
True, indicating that this LLM supports function calling.
"""
return True
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported.
Returns:
True, indicating that this LLM supports stop words.
"""
return True
def get_context_window_size(self) -> int:
"""Return a default context window size.
Returns:
8192, a typical context window size for modern LLMs.
"""
return 8192