mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 05:08:12 +00:00
Compare commits
2 Commits
devin/1741
...
bugfix/con
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1eb4717352 | ||
|
|
a1f35e768f |
@@ -224,6 +224,7 @@ CrewAI provides a wide range of events that you can listen for:
|
||||
- **LLMCallStartedEvent**: Emitted when an LLM call starts
|
||||
- **LLMCallCompletedEvent**: Emitted when an LLM call completes
|
||||
- **LLMCallFailedEvent**: Emitted when an LLM call fails
|
||||
- **LLMStreamChunkEvent**: Emitted for each chunk received during streaming LLM responses
|
||||
|
||||
## Event Handler Structure
|
||||
|
||||
|
||||
@@ -540,6 +540,46 @@ In this section, you'll find detailed examples that help you select, configure,
|
||||
</Accordion>
|
||||
</AccordionGroup>
|
||||
|
||||
## Streaming Responses
|
||||
|
||||
CrewAI supports streaming responses from LLMs, allowing your application to receive and process outputs in real-time as they're generated.
|
||||
|
||||
<Tabs>
|
||||
<Tab title="Basic Setup">
|
||||
Enable streaming by setting the `stream` parameter to `True` when initializing your LLM:
|
||||
|
||||
```python
|
||||
from crewai import LLM
|
||||
|
||||
# Create an LLM with streaming enabled
|
||||
llm = LLM(
|
||||
model="openai/gpt-4o",
|
||||
stream=True # Enable streaming
|
||||
)
|
||||
```
|
||||
|
||||
When streaming is enabled, responses are delivered in chunks as they're generated, creating a more responsive user experience.
|
||||
</Tab>
|
||||
|
||||
<Tab title="Event Handling">
|
||||
CrewAI emits events for each chunk received during streaming:
|
||||
|
||||
```python
|
||||
from crewai import LLM
|
||||
from crewai.utilities.events import EventHandler, LLMStreamChunkEvent
|
||||
|
||||
class MyEventHandler(EventHandler):
|
||||
def on_llm_stream_chunk(self, event: LLMStreamChunkEvent):
|
||||
# Process each chunk as it arrives
|
||||
print(f"Received chunk: {event.chunk}")
|
||||
|
||||
# Register the event handler
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
crewai_event_bus.register_handler(MyEventHandler())
|
||||
```
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
## Structured LLM Calls
|
||||
|
||||
CrewAI supports structured responses from LLM calls by allowing you to define a `response_format` using a Pydantic model. This enables the framework to automatically parse and validate the output, making it easier to integrate the response into your application without manual post-processing.
|
||||
@@ -669,46 +709,4 @@ Learn how to get the most out of your LLM configuration:
|
||||
Use larger context models for extensive tasks
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
# Large context model
|
||||
llm = LLM(model="openai/gpt-4o") # 128K tokens
|
||||
```
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
## Getting Help
|
||||
|
||||
If you need assistance, these resources are available:
|
||||
|
||||
<CardGroup cols={3}>
|
||||
<Card
|
||||
title="LiteLLM Documentation"
|
||||
href="https://docs.litellm.ai/docs/"
|
||||
icon="book"
|
||||
>
|
||||
Comprehensive documentation for LiteLLM integration and troubleshooting common issues.
|
||||
</Card>
|
||||
<Card
|
||||
title="GitHub Issues"
|
||||
href="https://github.com/joaomdmoura/crewAI/issues"
|
||||
icon="bug"
|
||||
>
|
||||
Report bugs, request features, or browse existing issues for solutions.
|
||||
</Card>
|
||||
<Card
|
||||
title="Community Forum"
|
||||
href="https://community.crewai.com"
|
||||
icon="comment-question"
|
||||
>
|
||||
Connect with other CrewAI users, share experiences, and get help from the community.
|
||||
</Card>
|
||||
</CardGroup>
|
||||
|
||||
<Note>
|
||||
Best Practices for API Key Security:
|
||||
- Use environment variables or secure vaults
|
||||
- Never commit keys to version control
|
||||
- Rotate keys regularly
|
||||
- Use separate keys for development and production
|
||||
- Monitor key usage for unusual patterns
|
||||
</Note>
|
||||
|
||||
@@ -1,681 +0,0 @@
|
||||
# Custom LLM Implementations
|
||||
|
||||
CrewAI supports custom LLM implementations through the `LLM` base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
|
||||
|
||||
## Using Custom LLM Implementations
|
||||
|
||||
To create a custom LLM implementation, you need to:
|
||||
|
||||
1. Inherit from the `LLM` base class
|
||||
2. Implement the required methods:
|
||||
- `call()`: The main method to call the LLM with messages
|
||||
- `supports_function_calling()`: Whether the LLM supports function calling
|
||||
- `supports_stop_words()`: Whether the LLM supports stop words
|
||||
- `get_context_window_size()`: The context window size of the LLM
|
||||
|
||||
## Using the Default LLM Implementation
|
||||
|
||||
If you don't need a custom LLM implementation, you can use the default implementation provided by CrewAI:
|
||||
|
||||
```python
|
||||
from crewai import LLM
|
||||
|
||||
# Create a default LLM instance
|
||||
llm = LLM.create(model="gpt-4")
|
||||
|
||||
# Or with more parameters
|
||||
llm = LLM.create(
|
||||
model="gpt-4",
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
api_key="your-api-key"
|
||||
)
|
||||
```
|
||||
|
||||
## Example: Basic Custom LLM
|
||||
|
||||
```python
|
||||
from crewai import LLM
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
class CustomLLM(LLM):
|
||||
def __init__(self, api_key: str, endpoint: str):
|
||||
super().__init__() # Initialize the base class to set default attributes
|
||||
if not api_key or not isinstance(api_key, str):
|
||||
raise ValueError("Invalid API key: must be a non-empty string")
|
||||
if not endpoint or not isinstance(endpoint, str):
|
||||
raise ValueError("Invalid endpoint URL: must be a non-empty string")
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
self.stop = [] # You can customize stop words if needed
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call the LLM with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
callbacks: Optional list of callback functions.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
|
||||
Returns:
|
||||
Either a text response from the LLM or the result of a tool function call.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the LLM request times out.
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
ValueError: If the response format is invalid.
|
||||
"""
|
||||
# Implement your own logic to call the LLM
|
||||
# For example, using requests:
|
||||
import requests
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# Convert string message to proper format if needed
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
data = {
|
||||
"messages": messages,
|
||||
"tools": tools
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
self.endpoint,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=30 # Set a reasonable timeout
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
except requests.Timeout:
|
||||
raise TimeoutError("LLM request timed out")
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"LLM request failed: {str(e)}")
|
||||
except (KeyError, IndexError, ValueError) as e:
|
||||
raise ValueError(f"Invalid response format: {str(e)}")
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the LLM supports function calling.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports function calling, False otherwise.
|
||||
"""
|
||||
# Return True if your LLM supports function calling
|
||||
return True
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports stop words, False otherwise.
|
||||
"""
|
||||
# Return True if your LLM supports stop words
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size of the LLM.
|
||||
|
||||
Returns:
|
||||
The context window size as an integer.
|
||||
"""
|
||||
# Return the context window size of your LLM
|
||||
return 8192
|
||||
```
|
||||
|
||||
## Error Handling Best Practices
|
||||
|
||||
When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices:
|
||||
|
||||
### 1. Implement Try-Except Blocks for API Calls
|
||||
|
||||
Always wrap API calls in try-except blocks to handle different types of errors:
|
||||
|
||||
```python
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
try:
|
||||
# API call implementation
|
||||
response = requests.post(
|
||||
self.endpoint,
|
||||
headers=self.headers,
|
||||
json=self.prepare_payload(messages),
|
||||
timeout=30 # Set a reasonable timeout
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
except requests.Timeout:
|
||||
raise TimeoutError("LLM request timed out")
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"LLM request failed: {str(e)}")
|
||||
except (KeyError, IndexError, ValueError) as e:
|
||||
raise ValueError(f"Invalid response format: {str(e)}")
|
||||
```
|
||||
|
||||
### 2. Implement Retry Logic for Transient Failures
|
||||
|
||||
For transient failures like network issues or rate limiting, implement retry logic with exponential backoff:
|
||||
|
||||
```python
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
import time
|
||||
|
||||
max_retries = 3
|
||||
retry_delay = 1 # seconds
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = requests.post(
|
||||
self.endpoint,
|
||||
headers=self.headers,
|
||||
json=self.prepare_payload(messages),
|
||||
timeout=30
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
except (requests.Timeout, requests.ConnectionError) as e:
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(retry_delay * (2 ** attempt)) # Exponential backoff
|
||||
continue
|
||||
raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}")
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"LLM request failed: {str(e)}")
|
||||
```
|
||||
|
||||
### 3. Validate Input Parameters
|
||||
|
||||
Always validate input parameters to prevent runtime errors:
|
||||
|
||||
```python
|
||||
def __init__(self, api_key: str, endpoint: str):
|
||||
super().__init__()
|
||||
if not api_key or not isinstance(api_key, str):
|
||||
raise ValueError("Invalid API key: must be a non-empty string")
|
||||
if not endpoint or not isinstance(endpoint, str):
|
||||
raise ValueError("Invalid endpoint URL: must be a non-empty string")
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
```
|
||||
|
||||
### 4. Handle Authentication Errors Gracefully
|
||||
|
||||
Provide clear error messages for authentication failures:
|
||||
|
||||
```python
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
try:
|
||||
response = requests.post(self.endpoint, headers=self.headers, json=data)
|
||||
if response.status_code == 401:
|
||||
raise ValueError("Authentication failed: Invalid API key or token")
|
||||
elif response.status_code == 403:
|
||||
raise ValueError("Authorization failed: Insufficient permissions")
|
||||
response.raise_for_status()
|
||||
# Process response
|
||||
except Exception as e:
|
||||
# Handle error
|
||||
raise
|
||||
```
|
||||
|
||||
## Example: JWT-based Authentication
|
||||
|
||||
For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:
|
||||
|
||||
```python
|
||||
from crewai import LLM, Agent, Task
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
class JWTAuthLLM(LLM):
|
||||
def __init__(self, jwt_token: str, endpoint: str):
|
||||
super().__init__() # Initialize the base class to set default attributes
|
||||
if not jwt_token or not isinstance(jwt_token, str):
|
||||
raise ValueError("Invalid JWT token: must be a non-empty string")
|
||||
if not endpoint or not isinstance(endpoint, str):
|
||||
raise ValueError("Invalid endpoint URL: must be a non-empty string")
|
||||
self.jwt_token = jwt_token
|
||||
self.endpoint = endpoint
|
||||
self.stop = [] # You can customize stop words if needed
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call the LLM with JWT authentication.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
callbacks: Optional list of callback functions.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
|
||||
Returns:
|
||||
Either a text response from the LLM or the result of a tool function call.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the LLM request times out.
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
ValueError: If the response format is invalid.
|
||||
"""
|
||||
# Implement your own logic to call the LLM with JWT authentication
|
||||
import requests
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.jwt_token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# Convert string message to proper format if needed
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
data = {
|
||||
"messages": messages,
|
||||
"tools": tools
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
self.endpoint,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=30 # Set a reasonable timeout
|
||||
)
|
||||
|
||||
if response.status_code == 401:
|
||||
raise ValueError("Authentication failed: Invalid JWT token")
|
||||
elif response.status_code == 403:
|
||||
raise ValueError("Authorization failed: Insufficient permissions")
|
||||
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
except requests.Timeout:
|
||||
raise TimeoutError("LLM request timed out")
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"LLM request failed: {str(e)}")
|
||||
except (KeyError, IndexError, ValueError) as e:
|
||||
raise ValueError(f"Invalid response format: {str(e)}")
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the LLM supports function calling.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports function calling, False otherwise.
|
||||
"""
|
||||
return True
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports stop words, False otherwise.
|
||||
"""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size of the LLM.
|
||||
|
||||
Returns:
|
||||
The context window size as an integer.
|
||||
"""
|
||||
return 8192
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
Here are some common issues you might encounter when implementing custom LLMs and how to resolve them:
|
||||
|
||||
### 1. Authentication Failures
|
||||
|
||||
**Symptoms**: 401 Unauthorized or 403 Forbidden errors
|
||||
|
||||
**Solutions**:
|
||||
- Verify that your API key or JWT token is valid and not expired
|
||||
- Check that you're using the correct authentication header format
|
||||
- Ensure that your token has the necessary permissions
|
||||
|
||||
### 2. Timeout Issues
|
||||
|
||||
**Symptoms**: Requests taking too long or timing out
|
||||
|
||||
**Solutions**:
|
||||
- Implement timeout handling as shown in the examples
|
||||
- Use retry logic with exponential backoff
|
||||
- Consider using a more reliable network connection
|
||||
|
||||
### 3. Response Parsing Errors
|
||||
|
||||
**Symptoms**: KeyError, IndexError, or ValueError when processing responses
|
||||
|
||||
**Solutions**:
|
||||
- Validate the response format before accessing nested fields
|
||||
- Implement proper error handling for malformed responses
|
||||
- Check the API documentation for the expected response format
|
||||
|
||||
### 4. Rate Limiting
|
||||
|
||||
**Symptoms**: 429 Too Many Requests errors
|
||||
|
||||
**Solutions**:
|
||||
- Implement rate limiting in your custom LLM
|
||||
- Add exponential backoff for retries
|
||||
- Consider using a token bucket algorithm for more precise rate control
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Logging
|
||||
|
||||
Adding logging to your custom LLM can help with debugging and monitoring:
|
||||
|
||||
```python
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
class LoggingLLM(BaseLLM):
|
||||
def __init__(self, api_key: str, endpoint: str):
|
||||
super().__init__()
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
self.logger = logging.getLogger("crewai.llm.custom")
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages")
|
||||
try:
|
||||
# API call implementation
|
||||
response = self._make_api_call(messages, tools)
|
||||
self.logger.debug(f"LLM response received: {response[:100]}...")
|
||||
return response
|
||||
except Exception as e:
|
||||
self.logger.error(f"LLM call failed: {str(e)}")
|
||||
raise
|
||||
```
|
||||
|
||||
### Rate Limiting
|
||||
|
||||
Implementing rate limiting can help avoid overwhelming the LLM API:
|
||||
|
||||
```python
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
class RateLimitedLLM(BaseLLM):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
endpoint: str,
|
||||
requests_per_minute: int = 60
|
||||
):
|
||||
super().__init__()
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.request_times: List[float] = []
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
self._enforce_rate_limit()
|
||||
# Record this request time
|
||||
self.request_times.append(time.time())
|
||||
# Make the actual API call
|
||||
return self._make_api_call(messages, tools)
|
||||
|
||||
def _enforce_rate_limit(self) -> None:
|
||||
"""Enforce the rate limit by waiting if necessary."""
|
||||
now = time.time()
|
||||
# Remove request times older than 1 minute
|
||||
self.request_times = [t for t in self.request_times if now - t < 60]
|
||||
|
||||
if len(self.request_times) >= self.requests_per_minute:
|
||||
# Calculate how long to wait
|
||||
oldest_request = min(self.request_times)
|
||||
wait_time = 60 - (now - oldest_request)
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
```
|
||||
|
||||
### Metrics Collection
|
||||
|
||||
Collecting metrics can help you monitor your LLM usage:
|
||||
|
||||
```python
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
class MetricsCollectingLLM(BaseLLM):
|
||||
def __init__(self, api_key: str, endpoint: str):
|
||||
super().__init__()
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
self.metrics: Dict[str, Any] = {
|
||||
"total_calls": 0,
|
||||
"total_tokens": 0,
|
||||
"errors": 0,
|
||||
"latency": []
|
||||
}
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
start_time = time.time()
|
||||
self.metrics["total_calls"] += 1
|
||||
|
||||
try:
|
||||
response = self._make_api_call(messages, tools)
|
||||
# Estimate tokens (simplified)
|
||||
if isinstance(messages, str):
|
||||
token_estimate = len(messages) // 4
|
||||
else:
|
||||
token_estimate = sum(len(m.get("content", "")) // 4 for m in messages)
|
||||
self.metrics["total_tokens"] += token_estimate
|
||||
return response
|
||||
except Exception as e:
|
||||
self.metrics["errors"] += 1
|
||||
raise
|
||||
finally:
|
||||
latency = time.time() - start_time
|
||||
self.metrics["latency"].append(latency)
|
||||
|
||||
def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Return the collected metrics."""
|
||||
avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0
|
||||
return {
|
||||
**self.metrics,
|
||||
"avg_latency": avg_latency
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced Usage: Function Calling
|
||||
|
||||
If your LLM supports function calling, you can implement the function calling logic in your custom LLM:
|
||||
|
||||
```python
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
import requests
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.jwt_token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# Convert string message to proper format if needed
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
data = {
|
||||
"messages": messages,
|
||||
"tools": tools
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
self.endpoint,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=30
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
|
||||
# Check if the LLM wants to call a function
|
||||
if response_data["choices"][0]["message"].get("tool_calls"):
|
||||
tool_calls = response_data["choices"][0]["message"]["tool_calls"]
|
||||
|
||||
# Process each tool call
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
|
||||
if available_functions and function_name in available_functions:
|
||||
function_to_call = available_functions[function_name]
|
||||
function_response = function_to_call(**function_args)
|
||||
|
||||
# Add the function response to the messages
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call["id"],
|
||||
"name": function_name,
|
||||
"content": str(function_response)
|
||||
})
|
||||
|
||||
# Call the LLM again with the updated messages
|
||||
return self.call(messages, tools, callbacks, available_functions)
|
||||
|
||||
# Return the text response if no function call
|
||||
return response_data["choices"][0]["message"]["content"]
|
||||
except requests.Timeout:
|
||||
raise TimeoutError("LLM request timed out")
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"LLM request failed: {str(e)}")
|
||||
except (KeyError, IndexError, ValueError) as e:
|
||||
raise ValueError(f"Invalid response format: {str(e)}")
|
||||
```
|
||||
|
||||
## Using Your Custom LLM with CrewAI
|
||||
|
||||
Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew
|
||||
from typing import Dict, Any
|
||||
|
||||
# Create your custom LLM instance
|
||||
jwt_llm = JWTAuthLLM(
|
||||
jwt_token="your.jwt.token",
|
||||
endpoint="https://your-llm-endpoint.com/v1/chat/completions"
|
||||
)
|
||||
|
||||
# Use it with an agent
|
||||
agent = Agent(
|
||||
role="Research Assistant",
|
||||
goal="Find information on a topic",
|
||||
backstory="You are a research assistant tasked with finding information.",
|
||||
llm=jwt_llm,
|
||||
)
|
||||
|
||||
# Create a task for the agent
|
||||
task = Task(
|
||||
description="Research the benefits of exercise",
|
||||
agent=agent,
|
||||
expected_output="A summary of the benefits of exercise",
|
||||
)
|
||||
|
||||
# Execute the task
|
||||
result = agent.execute_task(task)
|
||||
print(result)
|
||||
|
||||
# Or use it with a crew
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
manager_llm=jwt_llm, # Use your custom LLM for the manager
|
||||
)
|
||||
|
||||
# Run the crew
|
||||
result = crew.kickoff()
|
||||
print(result)
|
||||
```
|
||||
|
||||
## Implementing Your Own Authentication Mechanism
|
||||
|
||||
The `LLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
|
||||
|
||||
- OAuth tokens
|
||||
- Client certificates
|
||||
- Custom headers
|
||||
- Session-based authentication
|
||||
- Any other authentication method required by your LLM provider
|
||||
|
||||
Simply implement the appropriate authentication logic in your custom LLM class.
|
||||
|
||||
## Migrating from BaseLLM to LLM
|
||||
|
||||
If you were previously using `BaseLLM`, you can simply replace it with `LLM`:
|
||||
|
||||
```python
|
||||
# Old code
|
||||
from crewai import BaseLLM
|
||||
|
||||
class CustomLLM(BaseLLM):
|
||||
# ...
|
||||
|
||||
# New code
|
||||
from crewai import LLM
|
||||
|
||||
class CustomLLM(LLM):
|
||||
# ...
|
||||
```
|
||||
|
||||
The `BaseLLM` class is still available for backward compatibility but will be removed in a future release. It now inherits from `LLM` and emits a deprecation warning when instantiated.
|
||||
@@ -4,7 +4,7 @@ from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.llm import LLM, BaseLLM, DefaultLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.process import Process
|
||||
from crewai.task import Task
|
||||
|
||||
@@ -21,8 +21,6 @@ __all__ = [
|
||||
"Process",
|
||||
"Task",
|
||||
"LLM",
|
||||
"BaseLLM",
|
||||
"DefaultLLM",
|
||||
"Flow",
|
||||
"Knowledge",
|
||||
]
|
||||
|
||||
@@ -11,7 +11,7 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.task import Task
|
||||
from crewai.tools import BaseTool
|
||||
@@ -70,10 +70,10 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
||||
llm: Union[str, InstanceOf[LLM], Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: Optional[str] = Field(
|
||||
@@ -116,16 +116,9 @@ class Agent(BaseAgent):
|
||||
def post_init_setup(self):
|
||||
self.agent_ops_agent_name = self.role
|
||||
|
||||
try:
|
||||
self.llm = create_llm(self.llm)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize LLM for agent '{self.role}': {str(e)}")
|
||||
|
||||
if self.function_calling_llm and not isinstance(self.function_calling_llm, BaseLLM):
|
||||
try:
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize function calling LLM for agent '{self.role}': {str(e)}")
|
||||
self.llm = create_llm(self.llm)
|
||||
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
|
||||
if not self.agent_executor:
|
||||
self._setup_agent_executor()
|
||||
|
||||
@@ -14,7 +14,7 @@ from packaging import version
|
||||
from crewai.cli.utils import read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.crew import Crew
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.types.crew_chat import ChatInputField, ChatInputs
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
@@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
|
||||
print()
|
||||
|
||||
|
||||
def initialize_chat_llm(crew: Crew) -> Optional[BaseLLM]:
|
||||
def initialize_chat_llm(crew: Crew) -> Optional[LLM]:
|
||||
"""Initializes the chat LLM and handles exceptions."""
|
||||
try:
|
||||
return create_llm(crew.chat_llm)
|
||||
@@ -220,7 +220,7 @@ def get_user_input() -> str:
|
||||
|
||||
def handle_user_input(
|
||||
user_input: str,
|
||||
chat_llm: BaseLLM,
|
||||
chat_llm: LLM,
|
||||
messages: List[Dict[str, str]],
|
||||
crew_tool_schema: Dict[str, Any],
|
||||
available_functions: Dict[str, Any],
|
||||
|
||||
@@ -6,9 +6,8 @@ import warnings
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from langchain_core.tools import BaseTool as LangchainBaseTool
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
@@ -27,7 +26,7 @@ from crewai.agents.cache import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
@@ -37,7 +36,7 @@ from crewai.task import Task
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.tools.base_tool import Tool
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
@@ -151,14 +150,14 @@ class Crew(BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
manager_llm: Optional[Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: Optional[BaseAgent] = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
description="Language model that will be used for function calling.", default=None
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
@@ -197,7 +196,7 @@ class Crew(BaseModel):
|
||||
default=False,
|
||||
description="Plan the crew execution and add the plan to the crew.",
|
||||
)
|
||||
planning_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
planning_llm: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="Language model that will run the AgentPlanner if planning is True.",
|
||||
)
|
||||
@@ -213,7 +212,7 @@ class Crew(BaseModel):
|
||||
default=None,
|
||||
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
|
||||
)
|
||||
chat_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
chat_llm: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="LLM used to handle chatting with the crew.",
|
||||
)
|
||||
@@ -799,8 +798,7 @@ class Crew(BaseModel):
|
||||
|
||||
# Determine which tools to use - task tools take precedence over agent tools
|
||||
tools_for_task = task.tools or agent_to_use.tools or []
|
||||
# Prepare tools and ensure they're compatible with task execution
|
||||
tools_for_task = self._prepare_tools(agent_to_use, task, cast(Union[List[Tool], List[BaseTool]], tools_for_task))
|
||||
tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task)
|
||||
|
||||
self._log_task_start(task, agent_to_use.role)
|
||||
|
||||
@@ -810,6 +808,7 @@ class Crew(BaseModel):
|
||||
)
|
||||
if skipped_task_output:
|
||||
task_outputs.append(skipped_task_output)
|
||||
last_sync_output = skipped_task_output
|
||||
continue
|
||||
|
||||
if task.async_execution:
|
||||
@@ -819,26 +818,31 @@ class Crew(BaseModel):
|
||||
future = task.execute_async(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
tools=tools_for_task,
|
||||
)
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
# Process any pending async tasks before executing a sync task
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
processed_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
task_outputs.extend(processed_outputs)
|
||||
futures.clear()
|
||||
|
||||
context = self._get_context(task, task_outputs)
|
||||
task_output = task.execute_sync(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
tools=tools_for_task,
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
last_sync_output = task_output
|
||||
self._process_task_result(task, task_output)
|
||||
self._store_execution_log(task, task_output, task_index, was_replayed)
|
||||
|
||||
# Process any remaining async tasks at the end
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
processed_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
task_outputs.extend(processed_outputs)
|
||||
|
||||
return self._create_crew_output(task_outputs)
|
||||
|
||||
@@ -850,12 +854,17 @@ class Crew(BaseModel):
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> Optional[TaskOutput]:
|
||||
# Process any pending async tasks to ensure we have the most up-to-date context
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
processed_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
task_outputs.extend(processed_outputs)
|
||||
futures.clear()
|
||||
|
||||
# Get the previous output to evaluate the condition
|
||||
previous_output = task_outputs[-1] if task_outputs else None
|
||||
if previous_output is not None and not task.should_execute(previous_output):
|
||||
|
||||
# If there's no previous output or the condition evaluates to False, skip the task
|
||||
if previous_output is None or not task.should_execute(previous_output):
|
||||
self._logger.log(
|
||||
"debug",
|
||||
f"Skipping conditional task: {task.description}",
|
||||
@@ -863,16 +872,21 @@ class Crew(BaseModel):
|
||||
)
|
||||
skipped_task_output = task.get_skipped_task_output()
|
||||
|
||||
# Store the execution log for the skipped task
|
||||
if not was_replayed:
|
||||
self._store_execution_log(task, skipped_task_output, task_index)
|
||||
|
||||
# Set the output on the task itself so it can be referenced later
|
||||
task.output = skipped_task_output
|
||||
|
||||
return skipped_task_output
|
||||
return None
|
||||
|
||||
def _prepare_tools(
|
||||
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, agent: BaseAgent, task: Task, tools: List[Tool]
|
||||
) -> List[Tool]:
|
||||
# Add delegation tools if agent allows delegation
|
||||
if hasattr(agent, "allow_delegation") and getattr(agent, "allow_delegation", False):
|
||||
if agent.allow_delegation:
|
||||
if self.process == Process.hierarchical:
|
||||
if self.manager_agent:
|
||||
tools = self._update_manager_tools(task, tools)
|
||||
@@ -881,18 +895,17 @@ class Crew(BaseModel):
|
||||
"Manager agent is required for hierarchical process."
|
||||
)
|
||||
|
||||
elif agent:
|
||||
elif agent and agent.allow_delegation:
|
||||
tools = self._add_delegation_tools(task, tools)
|
||||
|
||||
# Add code execution tools if agent allows code execution
|
||||
if hasattr(agent, "allow_code_execution") and getattr(agent, "allow_code_execution", False):
|
||||
if agent.allow_code_execution:
|
||||
tools = self._add_code_execution_tools(agent, tools)
|
||||
|
||||
if agent and hasattr(agent, "multimodal") and getattr(agent, "multimodal", False):
|
||||
if agent and agent.multimodal:
|
||||
tools = self._add_multimodal_tools(agent, tools)
|
||||
|
||||
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
|
||||
return cast(List[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
|
||||
if self.process == Process.hierarchical:
|
||||
@@ -900,11 +913,11 @@ class Crew(BaseModel):
|
||||
return task.agent
|
||||
|
||||
def _merge_tools(
|
||||
self, existing_tools: Union[List[Tool], List[BaseTool]], new_tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, existing_tools: List[Tool], new_tools: List[Tool]
|
||||
) -> List[Tool]:
|
||||
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
|
||||
if not new_tools:
|
||||
return cast(List[BaseTool], existing_tools)
|
||||
return existing_tools
|
||||
|
||||
# Create mapping of tool names to new tools
|
||||
new_tool_map = {tool.name: tool for tool in new_tools}
|
||||
@@ -915,32 +928,23 @@ class Crew(BaseModel):
|
||||
# Add all new tools
|
||||
tools.extend(new_tools)
|
||||
|
||||
return cast(List[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _inject_delegation_tools(
|
||||
self, tools: Union[List[Tool], List[BaseTool]], task_agent: BaseAgent, agents: List[BaseAgent]
|
||||
) -> List[BaseTool]:
|
||||
if hasattr(task_agent, "get_delegation_tools"):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
# Cast delegation_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent]
|
||||
):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
return self._merge_tools(tools, delegation_tools)
|
||||
|
||||
def _add_multimodal_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
if hasattr(agent, "get_multimodal_tools"):
|
||||
multimodal_tools = agent.get_multimodal_tools()
|
||||
# Cast multimodal_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
def _add_multimodal_tools(self, agent: BaseAgent, tools: List[Tool]):
|
||||
multimodal_tools = agent.get_multimodal_tools()
|
||||
return self._merge_tools(tools, multimodal_tools)
|
||||
|
||||
def _add_code_execution_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
if hasattr(agent, "get_code_execution_tools"):
|
||||
code_tools = agent.get_code_execution_tools()
|
||||
# Cast code_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
def _add_code_execution_tools(self, agent: BaseAgent, tools: List[Tool]):
|
||||
code_tools = agent.get_code_execution_tools()
|
||||
return self._merge_tools(tools, code_tools)
|
||||
|
||||
def _add_delegation_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
def _add_delegation_tools(self, task: Task, tools: List[Tool]):
|
||||
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
|
||||
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
|
||||
if not tools:
|
||||
@@ -948,7 +952,7 @@ class Crew(BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, task.agent, agents_for_delegation
|
||||
)
|
||||
return cast(List[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _log_task_start(self, task: Task, role: str = "None"):
|
||||
if self.output_log_file:
|
||||
@@ -956,7 +960,7 @@ class Crew(BaseModel):
|
||||
task_name=task.name, task=task.description, agent=role, status="started"
|
||||
)
|
||||
|
||||
def _update_manager_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
def _update_manager_tools(self, task: Task, tools: List[Tool]):
|
||||
if self.manager_agent:
|
||||
if task.agent:
|
||||
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
|
||||
@@ -964,7 +968,7 @@ class Crew(BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, self.manager_agent, self.agents
|
||||
)
|
||||
return cast(List[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _get_context(self, task: Task, task_outputs: List[TaskOutput]):
|
||||
context = (
|
||||
@@ -1210,27 +1214,21 @@ class Crew(BaseModel):
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
||||
try:
|
||||
# Create LLM instance and ensure it's of type LLM for CrewEvaluator
|
||||
llm_instance = create_llm(eval_llm)
|
||||
if not llm_instance:
|
||||
eval_llm = create_llm(eval_llm)
|
||||
if not eval_llm:
|
||||
raise ValueError("Failed to create LLM instance.")
|
||||
|
||||
# Ensure we have an LLM instance (not just BaseLLM) for CrewEvaluator
|
||||
from crewai.llm import LLM
|
||||
if not isinstance(llm_instance, LLM):
|
||||
raise TypeError("CrewEvaluator requires an LLM instance, not a BaseLLM instance.")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CrewTestStartedEvent(
|
||||
crew_name=self.name or "crew",
|
||||
n_iterations=n_iterations,
|
||||
eval_llm=llm_instance,
|
||||
eval_llm=eval_llm,
|
||||
inputs=inputs,
|
||||
),
|
||||
)
|
||||
test_crew = self.copy()
|
||||
evaluator = CrewEvaluator(test_crew, llm_instance)
|
||||
evaluator = CrewEvaluator(test_crew, eval_llm) # type: ignore[arg-type]
|
||||
|
||||
for i in range(1, n_iterations + 1):
|
||||
evaluator.set_iteration(i)
|
||||
|
||||
50
src/crewai/flow/task_decorator.py
Normal file
50
src/crewai/flow/task_decorator.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
def task(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator for Flow methods that return a Task.
|
||||
|
||||
This decorator ensures that when a method returns a ConditionalTask,
|
||||
the condition is properly evaluated based on the previous task's output.
|
||||
|
||||
Args:
|
||||
func: The method to decorate
|
||||
|
||||
Returns:
|
||||
The decorated method
|
||||
"""
|
||||
setattr(func, "is_task", True)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
result = func(self, *args, **kwargs)
|
||||
|
||||
# Set the task name if not already set
|
||||
if hasattr(result, "name") and not result.name:
|
||||
result.name = func.__name__
|
||||
|
||||
# If this is a ConditionalTask, ensure it has a valid condition
|
||||
if isinstance(result, ConditionalTask):
|
||||
# If the condition is a boolean, wrap it in a function
|
||||
if isinstance(result.condition, bool):
|
||||
bool_value = result.condition
|
||||
result.condition = lambda _: bool_value
|
||||
|
||||
# Get the previous task output if available
|
||||
previous_outputs = getattr(self, "_method_outputs", [])
|
||||
previous_output = previous_outputs[-1] if previous_outputs else None
|
||||
|
||||
# If there's a previous output and it's a TaskOutput, check if we should execute
|
||||
if previous_output and isinstance(previous_output, TaskOutput):
|
||||
if not result.should_execute(previous_output):
|
||||
# Return a skipped task output instead of the task
|
||||
return result.get_skipped_task_output()
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
@@ -4,9 +4,18 @@ import os
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel
|
||||
@@ -16,6 +25,7 @@ from crewai.utilities.events.llm_events import (
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
|
||||
|
||||
@@ -23,8 +33,11 @@ with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
import litellm
|
||||
from litellm import Choices
|
||||
from litellm.litellm_core_utils.get_supported_openai_params import (
|
||||
get_supported_openai_params,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import get_supported_openai_params, supports_response_schema
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
@@ -35,223 +48,6 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class LLM(ABC):
|
||||
"""Base class for LLM implementations.
|
||||
|
||||
This class defines the interface that all LLM implementations must follow.
|
||||
Users can extend this class to create custom LLM implementations that don't
|
||||
rely on litellm's authentication mechanism.
|
||||
|
||||
Custom LLM implementations should handle error cases gracefully, including
|
||||
timeouts, authentication failures, and malformed responses. They should also
|
||||
implement proper validation for input parameters and provide clear error
|
||||
messages when things go wrong.
|
||||
|
||||
Attributes:
|
||||
stop (list): A list of stop sequences that the LLM should use to stop generation.
|
||||
This is used by the CrewAgentExecutor and other components.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""Create a new LLM instance.
|
||||
|
||||
This method handles backward compatibility by creating a DefaultLLM instance
|
||||
when the LLM class is instantiated directly with parameters.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments.
|
||||
**kwargs: Keyword arguments.
|
||||
|
||||
Returns:
|
||||
Either a new LLM instance or a DefaultLLM instance for backward compatibility.
|
||||
"""
|
||||
if cls is LLM and (args or kwargs.get('model') is not None):
|
||||
# Import locally to avoid circular imports
|
||||
# This is safe because DefaultLLM is defined later in this file
|
||||
DefaultLLM = globals().get('DefaultLLM')
|
||||
if DefaultLLM is None:
|
||||
# If DefaultLLM is not yet defined, return a placeholder
|
||||
# that will be replaced with a real DefaultLLM instance later
|
||||
return object.__new__(cls)
|
||||
return DefaultLLM(*args, **kwargs)
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the LLM with default attributes.
|
||||
|
||||
This constructor sets default values for attributes that are expected
|
||||
by the CrewAgentExecutor and other components.
|
||||
|
||||
All custom LLM implementations should call super().__init__() to ensure
|
||||
that these default attributes are properly initialized.
|
||||
"""
|
||||
self.stop = []
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
model: str,
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
seed: Optional[int] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
**kwargs,
|
||||
) -> 'DefaultLLM':
|
||||
"""Create a default LLM instance using litellm.
|
||||
|
||||
This factory method creates a default LLM instance using litellm as the backend.
|
||||
It's the recommended way to create LLM instances for most users.
|
||||
|
||||
Args:
|
||||
model: The model name (e.g., "gpt-4").
|
||||
timeout: Optional timeout for the LLM call.
|
||||
temperature: Optional temperature for the LLM call.
|
||||
top_p: Optional top_p for the LLM call.
|
||||
n: Optional n for the LLM call.
|
||||
stop: Optional stop sequences for the LLM call.
|
||||
max_completion_tokens: Optional max_completion_tokens for the LLM call.
|
||||
max_tokens: Optional max_tokens for the LLM call.
|
||||
presence_penalty: Optional presence_penalty for the LLM call.
|
||||
frequency_penalty: Optional frequency_penalty for the LLM call.
|
||||
logit_bias: Optional logit_bias for the LLM call.
|
||||
response_format: Optional response_format for the LLM call.
|
||||
seed: Optional seed for the LLM call.
|
||||
logprobs: Optional logprobs for the LLM call.
|
||||
top_logprobs: Optional top_logprobs for the LLM call.
|
||||
base_url: Optional base_url for the LLM call.
|
||||
api_base: Optional api_base for the LLM call.
|
||||
api_version: Optional api_version for the LLM call.
|
||||
api_key: Optional api_key for the LLM call.
|
||||
callbacks: Optional callbacks for the LLM call.
|
||||
reasoning_effort: Optional reasoning_effort for the LLM call.
|
||||
**kwargs: Additional keyword arguments for the LLM call.
|
||||
|
||||
Returns:
|
||||
A DefaultLLM instance configured with the provided parameters.
|
||||
"""
|
||||
from crewai.llm import DefaultLLM
|
||||
|
||||
return DefaultLLM(
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
stop=stop,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
base_url=base_url,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
callbacks=callbacks,
|
||||
reasoning_effort=reasoning_effort,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call the LLM with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
Can be a string or list of message dictionaries.
|
||||
If string, it will be converted to a single user message.
|
||||
If list, each dict must have 'role' and 'content' keys.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
Each tool should define its name, description, and parameters.
|
||||
callbacks: Optional list of callback functions to be executed
|
||||
during and after the LLM call.
|
||||
available_functions: Optional dict mapping function names to callables
|
||||
that can be invoked by the LLM.
|
||||
|
||||
Returns:
|
||||
Either a text response from the LLM (str) or
|
||||
the result of a tool function call (Any).
|
||||
|
||||
Raises:
|
||||
ValueError: If the messages format is invalid.
|
||||
TimeoutError: If the LLM request times out.
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
NotImplementedError: If this method is not implemented by a subclass.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement call()")
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the LLM supports function calling.
|
||||
|
||||
This method should return True if the LLM implementation supports
|
||||
function calling (tools), and False otherwise. If this method returns
|
||||
True, the LLM should be able to handle the 'tools' parameter in the
|
||||
call() method.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports function calling, False otherwise.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If this method is not implemented by a subclass.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement supports_function_calling()")
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
|
||||
This method should return True if the LLM implementation supports
|
||||
stop words, and False otherwise. If this method returns True, the
|
||||
LLM should respect the 'stop' attribute when generating responses.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports stop words, False otherwise.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If this method is not implemented by a subclass.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement supports_stop_words()")
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size of the LLM.
|
||||
|
||||
This method should return the maximum number of tokens that the LLM
|
||||
can process in a single request. This is used by CrewAI to ensure
|
||||
that messages don't exceed the LLM's context window.
|
||||
|
||||
Returns:
|
||||
The context window size as an integer.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If this method is not implemented by a subclass.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement get_context_window_size()")
|
||||
|
||||
|
||||
class FilteredStream:
|
||||
def __init__(self, original_stream):
|
||||
self._original_stream = original_stream
|
||||
@@ -344,14 +140,18 @@ def suppress_warnings():
|
||||
sys.stderr = old_stderr
|
||||
|
||||
|
||||
class DefaultLLM(LLM):
|
||||
"""Default LLM implementation using litellm.
|
||||
|
||||
This class provides a concrete implementation of the LLM interface
|
||||
using litellm as the backend. It's the default implementation used
|
||||
by CrewAI when no custom LLM is provided.
|
||||
"""
|
||||
|
||||
class Delta(TypedDict):
|
||||
content: Optional[str]
|
||||
role: Optional[str]
|
||||
|
||||
|
||||
class StreamingChoices(TypedDict):
|
||||
delta: Delta
|
||||
index: int
|
||||
finish_reason: Optional[str]
|
||||
|
||||
|
||||
class LLM:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
@@ -375,10 +175,9 @@ class DefaultLLM(LLM):
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__() # Initialize the base class
|
||||
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.temperature = temperature
|
||||
@@ -402,12 +201,13 @@ class DefaultLLM(LLM):
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.additional_params = kwargs
|
||||
self.is_anthropic = self._is_anthropic_model(model)
|
||||
self.stream = stream
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
# Normalize self.stop to always be a List[str]
|
||||
if stop is None:
|
||||
self.stop = [] # Already initialized in base class
|
||||
self.stop: List[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
@@ -428,6 +228,432 @@ class DefaultLLM(LLM):
|
||||
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
|
||||
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Prepare parameters for the completion call.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM
|
||||
tools: Optional list of tool schemas
|
||||
callbacks: Optional list of callback functions
|
||||
available_functions: Optional dict of available functions
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Parameters for the completion call
|
||||
"""
|
||||
# --- 1) Format messages according to provider requirements
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
formatted_messages = self._format_messages_for_provider(messages)
|
||||
|
||||
# --- 2) Prepare the parameters for the completion call
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": formatted_messages,
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"logit_bias": self.logit_bias,
|
||||
"response_format": self.response_format,
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"api_base": self.api_base,
|
||||
"base_url": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"stream": self.stream,
|
||||
"tools": tools,
|
||||
"reasoning_effort": self.reasoning_effort,
|
||||
**self.additional_params,
|
||||
}
|
||||
|
||||
# Remove None values from params
|
||||
return {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
def _handle_streaming_response(
|
||||
self,
|
||||
params: Dict[str, Any],
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Handle a streaming response from the LLM.
|
||||
|
||||
Args:
|
||||
params: Parameters for the completion call
|
||||
callbacks: Optional list of callback functions
|
||||
available_functions: Dict of available functions
|
||||
|
||||
Returns:
|
||||
str: The complete response text
|
||||
|
||||
Raises:
|
||||
Exception: If no content is received from the streaming response
|
||||
"""
|
||||
# --- 1) Initialize response tracking
|
||||
full_response = ""
|
||||
last_chunk = None
|
||||
chunk_count = 0
|
||||
usage_info = None
|
||||
|
||||
# --- 2) Make sure stream is set to True and include usage metrics
|
||||
params["stream"] = True
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
try:
|
||||
# --- 3) Process each chunk in the stream
|
||||
for chunk in litellm.completion(**params):
|
||||
chunk_count += 1
|
||||
last_chunk = chunk
|
||||
|
||||
# Extract content from the chunk
|
||||
chunk_content = None
|
||||
|
||||
# Safely extract content from various chunk formats
|
||||
try:
|
||||
# Try to access choices safely
|
||||
choices = None
|
||||
if isinstance(chunk, dict) and "choices" in chunk:
|
||||
choices = chunk["choices"]
|
||||
elif hasattr(chunk, "choices"):
|
||||
# Check if choices is not a type but an actual attribute with value
|
||||
if not isinstance(getattr(chunk, "choices"), type):
|
||||
choices = getattr(chunk, "choices")
|
||||
|
||||
# Try to extract usage information if available
|
||||
if isinstance(chunk, dict) and "usage" in chunk:
|
||||
usage_info = chunk["usage"]
|
||||
elif hasattr(chunk, "usage"):
|
||||
# Check if usage is not a type but an actual attribute with value
|
||||
if not isinstance(getattr(chunk, "usage"), type):
|
||||
usage_info = getattr(chunk, "usage")
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
|
||||
# Handle different delta formats
|
||||
delta = None
|
||||
if isinstance(choice, dict) and "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
elif hasattr(choice, "delta"):
|
||||
delta = getattr(choice, "delta")
|
||||
|
||||
# Extract content from delta
|
||||
if delta:
|
||||
# Handle dict format
|
||||
if isinstance(delta, dict):
|
||||
if "content" in delta and delta["content"] is not None:
|
||||
chunk_content = delta["content"]
|
||||
# Handle object format
|
||||
elif hasattr(delta, "content"):
|
||||
chunk_content = getattr(delta, "content")
|
||||
|
||||
# Handle case where content might be None or empty
|
||||
if chunk_content is None and isinstance(delta, dict):
|
||||
# Some models might send empty content chunks
|
||||
chunk_content = ""
|
||||
except Exception as e:
|
||||
logging.debug(f"Error extracting content from chunk: {e}")
|
||||
logging.debug(f"Chunk format: {type(chunk)}, content: {chunk}")
|
||||
|
||||
# Only add non-None content to the response
|
||||
if chunk_content is not None:
|
||||
# Add the chunk content to the full response
|
||||
full_response += chunk_content
|
||||
|
||||
# Emit the chunk event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(chunk=chunk_content),
|
||||
)
|
||||
|
||||
# --- 4) Fallback to non-streaming if no content received
|
||||
if not full_response.strip() and chunk_count == 0:
|
||||
logging.warning(
|
||||
"No chunks received in streaming response, falling back to non-streaming"
|
||||
)
|
||||
non_streaming_params = params.copy()
|
||||
non_streaming_params["stream"] = False
|
||||
non_streaming_params.pop(
|
||||
"stream_options", None
|
||||
) # Remove stream_options for non-streaming call
|
||||
return self._handle_non_streaming_response(
|
||||
non_streaming_params, callbacks, available_functions
|
||||
)
|
||||
|
||||
# --- 5) Handle empty response with chunks
|
||||
if not full_response.strip() and chunk_count > 0:
|
||||
logging.warning(
|
||||
f"Received {chunk_count} chunks but no content was extracted"
|
||||
)
|
||||
if last_chunk is not None:
|
||||
try:
|
||||
# Try to extract content from the last chunk's message
|
||||
choices = None
|
||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||
choices = last_chunk["choices"]
|
||||
elif hasattr(last_chunk, "choices"):
|
||||
if not isinstance(getattr(last_chunk, "choices"), type):
|
||||
choices = getattr(last_chunk, "choices")
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
|
||||
# Try to get content from message
|
||||
message = None
|
||||
if isinstance(choice, dict) and "message" in choice:
|
||||
message = choice["message"]
|
||||
elif hasattr(choice, "message"):
|
||||
message = getattr(choice, "message")
|
||||
|
||||
if message:
|
||||
content = None
|
||||
if isinstance(message, dict) and "content" in message:
|
||||
content = message["content"]
|
||||
elif hasattr(message, "content"):
|
||||
content = getattr(message, "content")
|
||||
|
||||
if content:
|
||||
full_response = content
|
||||
logging.info(
|
||||
f"Extracted content from last chunk message: {full_response}"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug(f"Error extracting content from last chunk: {e}")
|
||||
logging.debug(
|
||||
f"Last chunk format: {type(last_chunk)}, content: {last_chunk}"
|
||||
)
|
||||
|
||||
# --- 6) If still empty, raise an error instead of using a default response
|
||||
if not full_response.strip():
|
||||
raise Exception(
|
||||
"No content received from streaming response. Received empty chunks or failed to extract content."
|
||||
)
|
||||
|
||||
# --- 7) Check for tool calls in the final response
|
||||
tool_calls = None
|
||||
try:
|
||||
if last_chunk:
|
||||
choices = None
|
||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||
choices = last_chunk["choices"]
|
||||
elif hasattr(last_chunk, "choices"):
|
||||
if not isinstance(getattr(last_chunk, "choices"), type):
|
||||
choices = getattr(last_chunk, "choices")
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
|
||||
message = None
|
||||
if isinstance(choice, dict) and "message" in choice:
|
||||
message = choice["message"]
|
||||
elif hasattr(choice, "message"):
|
||||
message = getattr(choice, "message")
|
||||
|
||||
if message:
|
||||
if isinstance(message, dict) and "tool_calls" in message:
|
||||
tool_calls = message["tool_calls"]
|
||||
elif hasattr(message, "tool_calls"):
|
||||
tool_calls = getattr(message, "tool_calls")
|
||||
except Exception as e:
|
||||
logging.debug(f"Error checking for tool calls: {e}")
|
||||
|
||||
# --- 8) If no tool calls or no available functions, return the text response directly
|
||||
if not tool_calls or not available_functions:
|
||||
# Log token usage if available in streaming mode
|
||||
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
|
||||
# Emit completion event and return response
|
||||
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
|
||||
return full_response
|
||||
|
||||
# --- 9) Handle tool calls if present
|
||||
tool_result = self._handle_tool_call(tool_calls, available_functions)
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
|
||||
# --- 10) Log token usage if available in streaming mode
|
||||
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
|
||||
|
||||
# --- 11) Emit completion event and return response
|
||||
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
|
||||
return full_response
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in streaming response: {str(e)}")
|
||||
if full_response.strip():
|
||||
logging.warning(f"Returning partial response despite error: {str(e)}")
|
||||
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
|
||||
return full_response
|
||||
|
||||
# Emit failed event and re-raise the exception
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=str(e)),
|
||||
)
|
||||
raise Exception(f"Failed to get streaming response: {str(e)}")
|
||||
|
||||
def _handle_streaming_callbacks(
|
||||
self,
|
||||
callbacks: Optional[List[Any]],
|
||||
usage_info: Optional[Dict[str, Any]],
|
||||
last_chunk: Optional[Any],
|
||||
) -> None:
|
||||
"""Handle callbacks with usage info for streaming responses.
|
||||
|
||||
Args:
|
||||
callbacks: Optional list of callback functions
|
||||
usage_info: Usage information collected during streaming
|
||||
last_chunk: The last chunk received from the streaming response
|
||||
"""
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
# Use the usage_info we've been tracking
|
||||
if not usage_info:
|
||||
# Try to get usage from the last chunk if we haven't already
|
||||
try:
|
||||
if last_chunk:
|
||||
if (
|
||||
isinstance(last_chunk, dict)
|
||||
and "usage" in last_chunk
|
||||
):
|
||||
usage_info = last_chunk["usage"]
|
||||
elif hasattr(last_chunk, "usage"):
|
||||
if not isinstance(
|
||||
getattr(last_chunk, "usage"), type
|
||||
):
|
||||
usage_info = getattr(last_chunk, "usage")
|
||||
except Exception as e:
|
||||
logging.debug(f"Error extracting usage info: {e}")
|
||||
|
||||
if usage_info:
|
||||
callback.log_success_event(
|
||||
kwargs={}, # We don't have the original params here
|
||||
response_obj={"usage": usage_info},
|
||||
start_time=0,
|
||||
end_time=0,
|
||||
)
|
||||
|
||||
def _handle_non_streaming_response(
|
||||
self,
|
||||
params: Dict[str, Any],
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Handle a non-streaming response from the LLM.
|
||||
|
||||
Args:
|
||||
params: Parameters for the completion call
|
||||
callbacks: Optional list of callback functions
|
||||
available_functions: Dict of available functions
|
||||
|
||||
Returns:
|
||||
str: The response text
|
||||
"""
|
||||
# --- 1) Make the completion call
|
||||
response = litellm.completion(**params)
|
||||
|
||||
# --- 2) Extract response message and content
|
||||
response_message = cast(Choices, cast(ModelResponse, response).choices)[
|
||||
0
|
||||
].message
|
||||
text_response = response_message.content or ""
|
||||
|
||||
# --- 3) Handle callbacks with usage info
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
usage_info = getattr(response, "usage", None)
|
||||
if usage_info:
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
response_obj={"usage": usage_info},
|
||||
start_time=0,
|
||||
end_time=0,
|
||||
)
|
||||
|
||||
# --- 4) Check for tool calls
|
||||
tool_calls = getattr(response_message, "tool_calls", [])
|
||||
|
||||
# --- 5) If no tool calls or no available functions, return the text response directly
|
||||
if not tool_calls or not available_functions:
|
||||
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
|
||||
return text_response
|
||||
|
||||
# --- 6) Handle tool calls if present
|
||||
tool_result = self._handle_tool_call(tool_calls, available_functions)
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
|
||||
# --- 7) If tool call handling didn't return a result, emit completion event and return text response
|
||||
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
|
||||
return text_response
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_calls: List[Any],
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[str]:
|
||||
"""Handle a tool call from the LLM.
|
||||
|
||||
Args:
|
||||
tool_calls: List of tool calls from the LLM
|
||||
available_functions: Dict of available functions
|
||||
|
||||
Returns:
|
||||
Optional[str]: The result of the tool call, or None if no tool call was made
|
||||
"""
|
||||
# --- 1) Validate tool calls and available functions
|
||||
if not tool_calls or not available_functions:
|
||||
return None
|
||||
|
||||
# --- 2) Extract function name from first tool call
|
||||
tool_call = tool_calls[0]
|
||||
function_name = tool_call.function.name
|
||||
function_args = {} # Initialize to empty dict to avoid unbound variable
|
||||
|
||||
# --- 3) Check if function is available
|
||||
if function_name in available_functions:
|
||||
try:
|
||||
# --- 3.1) Parse function arguments
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
fn = available_functions[function_name]
|
||||
|
||||
# --- 3.2) Execute function
|
||||
result = fn(**function_args)
|
||||
|
||||
# --- 3.3) Emit success event
|
||||
self._handle_emit_call_events(result, LLMCallType.TOOL_CALL)
|
||||
return result
|
||||
except Exception as e:
|
||||
# --- 3.4) Handle execution errors
|
||||
fn = available_functions.get(
|
||||
function_name, lambda: None
|
||||
) # Ensure fn is always a callable
|
||||
logging.error(f"Error executing function '{function_name}': {e}")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolExecutionErrorEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
tool_class=fn,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
|
||||
)
|
||||
return None
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
@@ -457,22 +683,8 @@ class DefaultLLM(LLM):
|
||||
TypeError: If messages format is invalid
|
||||
ValueError: If response format is not supported
|
||||
LLMContextLengthExceededException: If input exceeds model's context limit
|
||||
|
||||
Examples:
|
||||
# Example 1: Simple string input
|
||||
>>> response = llm.call("Return the name of a random city.")
|
||||
>>> print(response)
|
||||
"Paris"
|
||||
|
||||
# Example 2: Message list with system and user messages
|
||||
>>> messages = [
|
||||
... {"role": "system", "content": "You are a geography expert"},
|
||||
... {"role": "user", "content": "What is France's capital?"}
|
||||
... ]
|
||||
>>> response = llm.call(messages)
|
||||
>>> print(response)
|
||||
"The capital of France is Paris."
|
||||
"""
|
||||
# --- 1) Emit call started event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
@@ -482,127 +694,38 @@ class DefaultLLM(LLM):
|
||||
available_functions=available_functions,
|
||||
),
|
||||
)
|
||||
# Validate parameters before proceeding with the call.
|
||||
|
||||
# --- 2) Validate parameters before proceeding with the call
|
||||
self._validate_call_params()
|
||||
|
||||
# --- 3) Convert string messages to proper format if needed
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
# For O1 models, system messages are not supported.
|
||||
# Convert any system messages into assistant messages.
|
||||
# --- 4) Handle O1 model special case (system messages not supported)
|
||||
if "o1" in self.model.lower():
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
message["role"] = "assistant"
|
||||
|
||||
# --- 5) Set up callbacks if provided
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
self.set_callbacks(callbacks)
|
||||
|
||||
try:
|
||||
# --- 1) Format messages according to provider requirements
|
||||
formatted_messages = self._format_messages_for_provider(messages)
|
||||
# --- 6) Prepare parameters for the completion call
|
||||
params = self._prepare_completion_params(messages, tools)
|
||||
|
||||
# --- 2) Prepare the parameters for the completion call
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": formatted_messages,
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"logit_bias": self.logit_bias,
|
||||
"response_format": self.response_format,
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"api_base": self.api_base,
|
||||
"base_url": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"stream": False,
|
||||
"tools": tools,
|
||||
"reasoning_effort": self.reasoning_effort,
|
||||
**self.additional_params,
|
||||
}
|
||||
|
||||
# Remove None values from params
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
# --- 2) Make the completion call
|
||||
response = litellm.completion(**params)
|
||||
response_message = cast(Choices, cast(ModelResponse, response).choices)[
|
||||
0
|
||||
].message
|
||||
text_response = response_message.content or ""
|
||||
tool_calls = getattr(response_message, "tool_calls", [])
|
||||
|
||||
# --- 3) Handle callbacks with usage info
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
usage_info = getattr(response, "usage", None)
|
||||
if usage_info:
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
response_obj={"usage": usage_info},
|
||||
start_time=0,
|
||||
end_time=0,
|
||||
)
|
||||
|
||||
# --- 4) If no tool calls, return the text response
|
||||
if not tool_calls or not available_functions:
|
||||
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
|
||||
return text_response
|
||||
|
||||
# --- 5) Handle the tool call
|
||||
tool_call = tool_calls[0]
|
||||
function_name = tool_call.function.name
|
||||
|
||||
if function_name in available_functions:
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.warning(f"Failed to parse function arguments: {e}")
|
||||
return text_response
|
||||
|
||||
fn = available_functions[function_name]
|
||||
try:
|
||||
# Call the actual tool function
|
||||
result = fn(**function_args)
|
||||
self._handle_emit_call_events(result, LLMCallType.TOOL_CALL)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error executing function '{function_name}': {e}"
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolExecutionErrorEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
tool_class=fn,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=f"Tool execution error: {str(e)}"
|
||||
),
|
||||
)
|
||||
return text_response
|
||||
|
||||
else:
|
||||
logging.warning(
|
||||
f"Tool call requested unknown function '{function_name}'"
|
||||
# --- 7) Make the completion call and handle response
|
||||
if self.stream:
|
||||
return self._handle_streaming_response(
|
||||
params, callbacks, available_functions
|
||||
)
|
||||
else:
|
||||
return self._handle_non_streaming_response(
|
||||
params, callbacks, available_functions
|
||||
)
|
||||
return text_response
|
||||
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
@@ -653,6 +776,20 @@ class DefaultLLM(LLM):
|
||||
"Invalid message format. Each message must be a dict with 'role' and 'content' keys"
|
||||
)
|
||||
|
||||
# Handle O1 models specially
|
||||
if "o1" in self.model.lower():
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
# Convert system messages to assistant messages
|
||||
if msg["role"] == "system":
|
||||
formatted_messages.append(
|
||||
{"role": "assistant", "content": msg["content"]}
|
||||
)
|
||||
else:
|
||||
formatted_messages.append(msg)
|
||||
return formatted_messages
|
||||
|
||||
# Handle Anthropic models
|
||||
if not self.is_anthropic:
|
||||
return messages
|
||||
|
||||
@@ -663,7 +800,7 @@ class DefaultLLM(LLM):
|
||||
|
||||
return messages
|
||||
|
||||
def _get_custom_llm_provider(self) -> str:
|
||||
def _get_custom_llm_provider(self) -> Optional[str]:
|
||||
"""
|
||||
Derives the custom_llm_provider from the model string.
|
||||
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
|
||||
@@ -672,7 +809,7 @@ class DefaultLLM(LLM):
|
||||
"""
|
||||
if "/" in self.model:
|
||||
return self.model.split("/")[0]
|
||||
return "openai"
|
||||
return None
|
||||
|
||||
def _validate_call_params(self) -> None:
|
||||
"""
|
||||
@@ -695,10 +832,12 @@ class DefaultLLM(LLM):
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
try:
|
||||
params = get_supported_openai_params(model=self.model)
|
||||
return params is not None and "tools" in params
|
||||
provider = self._get_custom_llm_provider()
|
||||
return litellm.utils.supports_function_calling(
|
||||
self.model, custom_llm_provider=provider
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to get supported params: {str(e)}")
|
||||
logging.error(f"Failed to check function calling support: {str(e)}")
|
||||
return False
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
@@ -791,27 +930,3 @@ class DefaultLLM(LLM):
|
||||
|
||||
litellm.success_callback = success_callbacks
|
||||
litellm.failure_callback = failure_callbacks
|
||||
|
||||
|
||||
class BaseLLM(LLM):
|
||||
"""Deprecated: Use LLM instead.
|
||||
|
||||
This class is kept for backward compatibility and will be removed in a future release.
|
||||
It inherits from LLM and provides the same interface, but emits a deprecation warning
|
||||
when instantiated.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the BaseLLM with a deprecation warning.
|
||||
|
||||
This constructor emits a deprecation warning and then calls the parent class's
|
||||
constructor to initialize the LLM.
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"BaseLLM is deprecated and will be removed in a future release. "
|
||||
"Use LLM instead for custom implementations.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
|
||||
from crewai import Crew
|
||||
from crewai.project.utils import memoize
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
"""Decorators for defining crew components and their behaviors."""
|
||||
|
||||
@@ -21,13 +23,35 @@ def after_kickoff(func):
|
||||
|
||||
def task(func):
|
||||
"""Marks a method as a crew task."""
|
||||
func.is_task = True
|
||||
setattr(func, "is_task", True)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
result = func(*args, **kwargs)
|
||||
if not result.name:
|
||||
|
||||
# Set the task name if not already set
|
||||
if hasattr(result, "name") and not result.name:
|
||||
result.name = func.__name__
|
||||
|
||||
# If this is a ConditionalTask, ensure it has a valid condition
|
||||
if isinstance(result, ConditionalTask):
|
||||
# If the condition is a boolean, wrap it in a function
|
||||
if isinstance(result.condition, bool):
|
||||
bool_value = result.condition
|
||||
result.condition = lambda _: bool_value
|
||||
|
||||
# Get the previous task output if available
|
||||
self = args[0] if args else None
|
||||
if self and hasattr(self, "_method_outputs"):
|
||||
previous_outputs = getattr(self, "_method_outputs", [])
|
||||
previous_output = previous_outputs[-1] if previous_outputs else None
|
||||
|
||||
# If there's a previous output and it's a TaskOutput, check if we should execute
|
||||
if previous_output and isinstance(previous_output, TaskOutput):
|
||||
if not result.should_execute(previous_output):
|
||||
# Return a skipped task output instead of the task
|
||||
return result.get_skipped_task_output()
|
||||
|
||||
return result
|
||||
|
||||
return memoize(wrapper)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Union, cast
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -14,17 +14,23 @@ class ConditionalTask(Task):
|
||||
"""
|
||||
|
||||
condition: Callable[[TaskOutput], bool] = Field(
|
||||
default=None,
|
||||
description="Maximum number of retries for an agent to execute a task when an error occurs.",
|
||||
default=lambda _: True, # Default to always execute
|
||||
description="Function that determines whether the task should be executed or a boolean value.",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
condition: Callable[[Any], bool],
|
||||
condition: Union[Callable[[Any], bool], bool],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.condition = condition
|
||||
|
||||
# If condition is a boolean, wrap it in a function that always returns that boolean
|
||||
if isinstance(condition, bool):
|
||||
bool_value = condition
|
||||
self.condition = lambda _: bool_value
|
||||
else:
|
||||
self.condition = cast(Callable[[TaskOutput], bool], condition)
|
||||
|
||||
def should_execute(self, context: TaskOutput) -> bool:
|
||||
"""
|
||||
|
||||
@@ -14,7 +14,12 @@ from .agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
)
|
||||
from .task_events import TaskStartedEvent, TaskCompletedEvent, TaskFailedEvent, TaskEvaluationEvent
|
||||
from .task_events import (
|
||||
TaskStartedEvent,
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
TaskEvaluationEvent,
|
||||
)
|
||||
from .flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowStartedEvent,
|
||||
@@ -34,7 +39,13 @@ from .tool_usage_events import (
|
||||
ToolUsageEvent,
|
||||
ToolValidateInputErrorEvent,
|
||||
)
|
||||
from .llm_events import LLMCallCompletedEvent, LLMCallFailedEvent, LLMCallStartedEvent
|
||||
from .llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
|
||||
# events
|
||||
from .event_listener import EventListener
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from io import StringIO
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
@@ -11,6 +12,7 @@ from crewai.utilities.events.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
|
||||
from .agent_events import AgentExecutionCompletedEvent, AgentExecutionStartedEvent
|
||||
@@ -46,6 +48,8 @@ class EventListener(BaseEventListener):
|
||||
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
|
||||
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
|
||||
execution_spans: Dict[Task, Any] = Field(default_factory=dict)
|
||||
next_chunk = 0
|
||||
text_stream = StringIO()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
@@ -280,9 +284,20 @@ class EventListener(BaseEventListener):
|
||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||
def on_llm_call_failed(source, event: LLMCallFailedEvent):
|
||||
self.logger.log(
|
||||
f"❌ LLM Call Failed: '{event.error}'",
|
||||
f"❌ LLM call failed: {event.error}",
|
||||
event.timestamp,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def on_llm_stream_chunk(source, event: LLMStreamChunkEvent):
|
||||
self.text_stream.write(event.chunk)
|
||||
|
||||
self.text_stream.seek(self.next_chunk)
|
||||
|
||||
# Read from the in-memory stream
|
||||
content = self.text_stream.read()
|
||||
print(content, end="", flush=True)
|
||||
self.next_chunk = self.text_stream.tell()
|
||||
|
||||
|
||||
event_listener = EventListener()
|
||||
|
||||
@@ -23,6 +23,12 @@ from .flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from .llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from .task_events import (
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
@@ -58,4 +64,8 @@ EventTypes = Union[
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageStartedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMStreamChunkEvent,
|
||||
]
|
||||
|
||||
@@ -34,3 +34,10 @@ class LLMCallFailedEvent(CrewEvent):
|
||||
|
||||
error: str
|
||||
type: str = "llm_call_failed"
|
||||
|
||||
|
||||
class LLMStreamChunkEvent(CrewEvent):
|
||||
"""Event emitted when a streaming chunk is received"""
|
||||
|
||||
type: str = "llm_stream_chunk"
|
||||
chunk: str
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
|
||||
|
||||
def create_llm(
|
||||
@@ -19,17 +19,17 @@ def create_llm(
|
||||
- None: Use environment-based or fallback default model.
|
||||
|
||||
Returns:
|
||||
A LLM instance if successful, or None if something fails.
|
||||
An LLM instance if successful, or None if something fails.
|
||||
"""
|
||||
|
||||
# 1) If llm_value is already a LLM object, return it directly
|
||||
# 1) If llm_value is already an LLM object, return it directly
|
||||
if isinstance(llm_value, LLM):
|
||||
return llm_value
|
||||
|
||||
# 2) If llm_value is a string (model name)
|
||||
if isinstance(llm_value, str):
|
||||
try:
|
||||
created_llm = LLM.create(model=llm_value)
|
||||
created_llm = LLM(model=llm_value)
|
||||
return created_llm
|
||||
except Exception as e:
|
||||
print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
|
||||
@@ -56,7 +56,7 @@ def create_llm(
|
||||
base_url: Optional[str] = getattr(llm_value, "base_url", None)
|
||||
api_base: Optional[str] = getattr(llm_value, "api_base", None)
|
||||
|
||||
created_llm = LLM.create(
|
||||
created_llm = LLM(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
@@ -175,7 +175,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
|
||||
|
||||
# Try creating the LLM
|
||||
try:
|
||||
new_llm = LLM.create(**llm_params)
|
||||
new_llm = LLM(**llm_params)
|
||||
return new_llm
|
||||
except Exception as e:
|
||||
print(
|
||||
|
||||
@@ -18,6 +18,7 @@ from crewai.tools.tool_calling import InstructorToolCalling
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
from crewai.utilities import RPMController
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.llm_events import LLMStreamChunkEvent
|
||||
from crewai.utilities.events.tool_usage_events import ToolUsageFinishedEvent
|
||||
|
||||
|
||||
@@ -259,9 +260,7 @@ def test_cache_hitting():
|
||||
def handle_tool_end(source, event):
|
||||
received_events.append(event)
|
||||
|
||||
with (
|
||||
patch.object(CacheHandler, "read") as read,
|
||||
):
|
||||
with (patch.object(CacheHandler, "read") as read,):
|
||||
read.return_value = "0"
|
||||
task = Task(
|
||||
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
|
||||
|
||||
2571
tests/cassettes/test_crew_kickoff_streaming_usage_metrics.yaml
Normal file
2571
tests/cassettes/test_crew_kickoff_streaming_usage_metrics.yaml
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,89 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are test role. test backstory\nYour
|
||||
personal goal is: test goal\nTo give my best complete final answer to the task
|
||||
respond using the exact following format:\n\nThought: I now can give a great
|
||||
answer\nFinal Answer: Your final answer must be the great and the most complete
|
||||
as possible, it must be outcome described.\n\nI MUST use these formats, my job
|
||||
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Test task\n\nThis
|
||||
is the expected criteria for your final answer: Test output\nyou MUST return
|
||||
the actual complete content as the final answer, not a summary.\n\nBegin! This
|
||||
is VERY important to you, use the tools available and give your best Final Answer,
|
||||
your job depends on it!\n\nThought:"}], "model": "gpt-4", "stop": ["\nObservation:"]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '805'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- _cfuvid=xecEkmr_qTiKn7EKC7aeGN5bpsbPM9ofyIsipL4VCYM-1734033219265-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.61.0
|
||||
x-stainless-arch:
|
||||
- x64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- Linux
|
||||
x-stainless-package-version:
|
||||
- 1.61.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.7
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
content: "{\n \"error\": {\n \"message\": \"Incorrect API key provided:
|
||||
sk-proj-********************************************************************************************************************************************************sLcA.
|
||||
You can find your API key at https://platform.openai.com/account/api-keys.\",\n
|
||||
\ \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\":
|
||||
\"invalid_api_key\"\n }\n}\n"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 9201beec18a0762e-SEA
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '414'
|
||||
Content-Type:
|
||||
- application/json; charset=utf-8
|
||||
Date:
|
||||
- Fri, 14 Mar 2025 06:34:31 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=wF6OyTyATDK7A9tGqAdaSB3QZfmd34JWPicYlDC1hug-1741934071-1.0.1.1-nZThPWX_7A9FsU7Z14PyrVhl6mCD99iuk9ujCFkNCCdepMHEwK9EXoDrP4IBBCXxkXmKjrVTSaQ63zpcociXuMHR8JKhth2fRUV2H4hMldY;
|
||||
path=/; expires=Fri, 14-Mar-25 07:04:31 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=rn5IWZdYMRmbyCa2_84MkWO46MIaP6soWc8npaLc9iQ-1741934071787-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
strict-transport-security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
vary:
|
||||
- Origin
|
||||
x-request-id:
|
||||
- req_f55471c8eb5755daaef3d63eab5a95de
|
||||
http_version: HTTP/1.1
|
||||
status_code: 401
|
||||
version: 1
|
||||
190
tests/conditional_task_test.py
Normal file
190
tests/conditional_task_test.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
# Create mock agents for testing
|
||||
researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="Research information",
|
||||
backstory="You are a researcher with expertise in finding information.",
|
||||
)
|
||||
|
||||
writer = Agent(
|
||||
role="Writer",
|
||||
goal="Write content",
|
||||
backstory="You are a writer with expertise in creating engaging content.",
|
||||
)
|
||||
|
||||
|
||||
def test_conditional_task_with_boolean_false():
|
||||
"""Test that a conditional task with a boolean False condition is skipped."""
|
||||
task1 = Task(
|
||||
description="Initial task",
|
||||
expected_output="Initial output",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
# Use a boolean False directly as the condition
|
||||
task2 = ConditionalTask(
|
||||
description="Conditional task that should be skipped",
|
||||
expected_output="This should not be executed",
|
||||
agent=writer,
|
||||
condition=False,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[task1, task2],
|
||||
)
|
||||
|
||||
with patch.object(Task, "execute_sync") as mock_execute_sync:
|
||||
mock_execute_sync.return_value = TaskOutput(
|
||||
description="Task 1 description",
|
||||
raw="Task 1 output",
|
||||
agent="Researcher",
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
# Only the first task should be executed
|
||||
assert mock_execute_sync.call_count == 1
|
||||
|
||||
# The conditional task should be skipped
|
||||
assert task2.output is not None
|
||||
assert task2.output.raw == ""
|
||||
|
||||
# The final output should be from the first task
|
||||
assert result.raw.startswith("Task 1 output")
|
||||
|
||||
|
||||
def test_conditional_task_with_boolean_true():
|
||||
"""Test that a conditional task with a boolean True condition is executed."""
|
||||
task1 = Task(
|
||||
description="Initial task",
|
||||
expected_output="Initial output",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
# Use a boolean True directly as the condition
|
||||
task2 = ConditionalTask(
|
||||
description="Conditional task that should be executed",
|
||||
expected_output="This should be executed",
|
||||
agent=writer,
|
||||
condition=True,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[task1, task2],
|
||||
)
|
||||
|
||||
with patch.object(Task, "execute_sync") as mock_execute_sync:
|
||||
mock_execute_sync.return_value = TaskOutput(
|
||||
description="Task output",
|
||||
raw="Task output",
|
||||
agent="Agent",
|
||||
)
|
||||
|
||||
crew.kickoff()
|
||||
|
||||
# Both tasks should be executed
|
||||
assert mock_execute_sync.call_count == 2
|
||||
|
||||
|
||||
def test_multiple_sequential_conditional_tasks():
|
||||
"""Test that multiple conditional tasks in sequence work correctly."""
|
||||
task1 = Task(
|
||||
description="Initial task",
|
||||
expected_output="Initial output",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
# First conditional task (will be executed)
|
||||
task2 = ConditionalTask(
|
||||
description="First conditional task",
|
||||
expected_output="First conditional output",
|
||||
agent=writer,
|
||||
condition=True,
|
||||
)
|
||||
|
||||
# Second conditional task (will be skipped)
|
||||
task3 = ConditionalTask(
|
||||
description="Second conditional task",
|
||||
expected_output="Second conditional output",
|
||||
agent=researcher,
|
||||
condition=False,
|
||||
)
|
||||
|
||||
# Third conditional task (will be executed)
|
||||
task4 = ConditionalTask(
|
||||
description="Third conditional task",
|
||||
expected_output="Third conditional output",
|
||||
agent=writer,
|
||||
condition=True,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[task1, task2, task3, task4],
|
||||
)
|
||||
|
||||
with patch.object(Task, "execute_sync") as mock_execute_sync:
|
||||
mock_execute_sync.return_value = TaskOutput(
|
||||
description="Task output",
|
||||
raw="Task output",
|
||||
agent="Agent",
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
# Tasks 1, 2, and 4 should be executed (task 3 is skipped)
|
||||
assert mock_execute_sync.call_count == 3
|
||||
|
||||
# Task 3 should be skipped
|
||||
assert task3.output is not None
|
||||
assert task3.output.raw == ""
|
||||
|
||||
|
||||
def test_last_task_conditional():
|
||||
"""Test that a conditional task at the end of the task list works correctly."""
|
||||
task1 = Task(
|
||||
description="Initial task",
|
||||
expected_output="Initial output",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
# Last task is conditional and will be skipped
|
||||
task2 = ConditionalTask(
|
||||
description="Last conditional task",
|
||||
expected_output="Last conditional output",
|
||||
agent=writer,
|
||||
condition=False,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[task1, task2],
|
||||
)
|
||||
|
||||
with patch.object(Task, "execute_sync") as mock_execute_sync:
|
||||
mock_execute_sync.return_value = TaskOutput(
|
||||
description="Task 1 output",
|
||||
raw="Task 1 output",
|
||||
agent="Researcher",
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
# Only the first task should be executed
|
||||
assert mock_execute_sync.call_count == 1
|
||||
|
||||
# The conditional task should be skipped
|
||||
assert task2.output is not None
|
||||
assert task2.output.raw == ""
|
||||
|
||||
# The final output should be from the first task
|
||||
assert result.raw.startswith("Task 1 output")
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from concurrent.futures import Future
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -35,6 +36,11 @@ from crewai.utilities.events.crew_events import (
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
|
||||
|
||||
# Skip streaming tests when running in CI/CD environments
|
||||
skip_streaming_in_ci = pytest.mark.skipif(
|
||||
os.getenv("CI") is not None, reason="Skipping streaming tests in CI/CD environments"
|
||||
)
|
||||
|
||||
ceo = Agent(
|
||||
role="CEO",
|
||||
goal="Make sure the writers in your company produce amazing content.",
|
||||
@@ -948,6 +954,7 @@ def test_api_calls_throttling(capsys):
|
||||
moveon.assert_called()
|
||||
|
||||
|
||||
@skip_streaming_in_ci
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_kickoff_usage_metrics():
|
||||
inputs = [
|
||||
@@ -960,6 +967,7 @@ def test_crew_kickoff_usage_metrics():
|
||||
role="{topic} Researcher",
|
||||
goal="Express hot takes on {topic}.",
|
||||
backstory="You have a lot of experience with {topic}.",
|
||||
llm=LLM(model="gpt-4o"),
|
||||
)
|
||||
|
||||
task = Task(
|
||||
@@ -968,12 +976,50 @@ def test_crew_kickoff_usage_metrics():
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Use real LLM calls instead of mocking
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
results = crew.kickoff_for_each(inputs=inputs)
|
||||
|
||||
assert len(results) == len(inputs)
|
||||
for result in results:
|
||||
# Assert that all required keys are in usage_metrics and their values are not None
|
||||
# Assert that all required keys are in usage_metrics and their values are greater than 0
|
||||
assert result.token_usage.total_tokens > 0
|
||||
assert result.token_usage.prompt_tokens > 0
|
||||
assert result.token_usage.completion_tokens > 0
|
||||
assert result.token_usage.successful_requests > 0
|
||||
assert result.token_usage.cached_prompt_tokens == 0
|
||||
|
||||
|
||||
@skip_streaming_in_ci
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_kickoff_streaming_usage_metrics():
|
||||
inputs = [
|
||||
{"topic": "dog"},
|
||||
{"topic": "cat"},
|
||||
{"topic": "apple"},
|
||||
]
|
||||
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
goal="Express hot takes on {topic}.",
|
||||
backstory="You have a lot of experience with {topic}.",
|
||||
llm=LLM(model="gpt-4o", stream=True),
|
||||
max_iter=3,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Give me an analysis around {topic}.",
|
||||
expected_output="1 bullet point about {topic} that's under 15 words.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Use real LLM calls instead of mocking
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
results = crew.kickoff_for_each(inputs=inputs)
|
||||
|
||||
assert len(results) == len(inputs)
|
||||
for result in results:
|
||||
# Assert that all required keys are in usage_metrics and their values are greater than 0
|
||||
assert result.token_usage.total_tokens > 0
|
||||
assert result.token_usage.prompt_tokens > 0
|
||||
assert result.token_usage.completion_tokens > 0
|
||||
@@ -3973,3 +4019,5 @@ def test_crew_with_knowledge_sources_works_with_copy():
|
||||
assert crew_copy.knowledge_sources == crew.knowledge_sources
|
||||
assert len(crew_copy.agents) == len(crew.agents)
|
||||
assert len(crew_copy.tasks) == len(crew.tasks)
|
||||
|
||||
assert len(crew_copy.tasks) == len(crew.tasks)
|
||||
|
||||
@@ -1,570 +0,0 @@
|
||||
from collections import deque
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import time
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
|
||||
class CustomLLM(LLM):
|
||||
"""Custom LLM implementation for testing.
|
||||
|
||||
This is a simple implementation of the LLM abstract base class
|
||||
that returns a predefined response for testing purposes.
|
||||
"""
|
||||
|
||||
def __init__(self, response: str = "Custom LLM response"):
|
||||
"""Initialize the CustomLLM with a predefined response.
|
||||
|
||||
Args:
|
||||
response: The predefined response to return from call().
|
||||
"""
|
||||
super().__init__()
|
||||
self.response = response
|
||||
self.calls = []
|
||||
self.stop = []
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Record the call and return the predefined response.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
callbacks: Optional list of callback functions.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
|
||||
Returns:
|
||||
The predefined response string.
|
||||
"""
|
||||
self.calls.append({
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"callbacks": callbacks,
|
||||
"available_functions": available_functions
|
||||
})
|
||||
return self.response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Return True to indicate that function calling is supported.
|
||||
|
||||
Returns:
|
||||
True, indicating that this LLM supports function calling.
|
||||
"""
|
||||
return True
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Return True to indicate that stop words are supported.
|
||||
|
||||
Returns:
|
||||
True, indicating that this LLM supports stop words.
|
||||
"""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Return a default context window size.
|
||||
|
||||
Returns:
|
||||
8192, a typical context window size for modern LLMs.
|
||||
"""
|
||||
return 8192
|
||||
|
||||
|
||||
def test_custom_llm_implementation():
|
||||
"""Test that a custom LLM implementation works with create_llm."""
|
||||
custom_llm = CustomLLM(response="The answer is 42")
|
||||
|
||||
# Test that create_llm returns the custom LLM instance directly
|
||||
result_llm = create_llm(custom_llm)
|
||||
|
||||
assert result_llm is custom_llm
|
||||
|
||||
# Test calling the custom LLM
|
||||
response = result_llm.call("What is the answer to life, the universe, and everything?")
|
||||
|
||||
# Verify that the custom LLM was called
|
||||
assert len(custom_llm.calls) > 0
|
||||
# Verify that the response from the custom LLM was used
|
||||
assert response == "The answer is 42"
|
||||
|
||||
|
||||
class JWTAuthLLM(LLM):
|
||||
"""Custom LLM implementation with JWT authentication.
|
||||
|
||||
This class demonstrates how to implement a custom LLM that uses JWT
|
||||
authentication instead of API key-based authentication. It validates
|
||||
the JWT token before each call and checks for token expiration.
|
||||
"""
|
||||
|
||||
def __init__(self, jwt_token: str, expiration_buffer: int = 60):
|
||||
"""Initialize the JWTAuthLLM with a JWT token.
|
||||
|
||||
Args:
|
||||
jwt_token: The JWT token to use for authentication.
|
||||
expiration_buffer: Buffer time in seconds to warn about token expiration.
|
||||
Default is 60 seconds.
|
||||
|
||||
Raises:
|
||||
ValueError: If the JWT token is invalid or missing.
|
||||
"""
|
||||
super().__init__()
|
||||
if not jwt_token or not isinstance(jwt_token, str):
|
||||
raise ValueError("Invalid JWT token")
|
||||
|
||||
self.jwt_token = jwt_token
|
||||
self.expiration_buffer = expiration_buffer
|
||||
self.calls = []
|
||||
self.stop = []
|
||||
|
||||
# Validate the token immediately
|
||||
self._validate_token()
|
||||
|
||||
def _validate_token(self) -> None:
|
||||
"""Validate the JWT token.
|
||||
|
||||
Checks if the token is valid and not expired. Also warns if the token
|
||||
is about to expire within the expiration_buffer time.
|
||||
|
||||
Raises:
|
||||
ValueError: If the token is invalid, expired, or malformed.
|
||||
"""
|
||||
try:
|
||||
# Decode without verification to check expiration
|
||||
# In a real implementation, you would verify the signature
|
||||
decoded = jwt.decode(self.jwt_token, options={"verify_signature": False})
|
||||
|
||||
# Check if token is expired or about to expire
|
||||
if 'exp' in decoded:
|
||||
expiration_time = decoded['exp']
|
||||
current_time = time.time()
|
||||
|
||||
if expiration_time < current_time:
|
||||
raise ValueError("JWT token has expired")
|
||||
|
||||
if expiration_time < current_time + self.expiration_buffer:
|
||||
# Token will expire soon, log a warning
|
||||
import logging
|
||||
logging.warning(f"JWT token will expire in {expiration_time - current_time} seconds")
|
||||
except jwt.PyJWTError as e:
|
||||
raise ValueError(f"Invalid JWT token format: {str(e)}")
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call the LLM with JWT authentication.
|
||||
|
||||
Validates the JWT token before making the call to ensure it's still valid.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
callbacks: Optional list of callback functions.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
|
||||
Returns:
|
||||
The LLM response.
|
||||
|
||||
Raises:
|
||||
ValueError: If the JWT token is invalid or expired.
|
||||
TimeoutError: If the request times out.
|
||||
ConnectionError: If there's a connection issue.
|
||||
"""
|
||||
# Validate token before making the call
|
||||
self._validate_token()
|
||||
|
||||
self.calls.append({
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"callbacks": callbacks,
|
||||
"available_functions": available_functions
|
||||
})
|
||||
|
||||
# In a real implementation, this would use the JWT token to authenticate
|
||||
# with an external service
|
||||
return "Response from JWT-authenticated LLM"
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Return True to indicate that function calling is supported."""
|
||||
return True
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Return True to indicate that stop words are supported."""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Return a default context window size."""
|
||||
return 8192
|
||||
|
||||
|
||||
def test_custom_llm_with_jwt_auth():
|
||||
"""Test a custom LLM implementation with JWT authentication."""
|
||||
# Create a valid JWT token that expires 1 hour from now
|
||||
valid_token = jwt.encode(
|
||||
{"exp": int(time.time()) + 3600},
|
||||
"secret",
|
||||
algorithm="HS256"
|
||||
)
|
||||
|
||||
jwt_llm = JWTAuthLLM(jwt_token=valid_token)
|
||||
|
||||
# Test that create_llm returns the JWT-authenticated LLM instance directly
|
||||
result_llm = create_llm(jwt_llm)
|
||||
|
||||
assert result_llm is jwt_llm
|
||||
|
||||
# Test calling the JWT-authenticated LLM
|
||||
response = result_llm.call("Test message")
|
||||
|
||||
# Verify that the JWT-authenticated LLM was called
|
||||
assert len(jwt_llm.calls) > 0
|
||||
# Verify that the response from the JWT-authenticated LLM was used
|
||||
assert response == "Response from JWT-authenticated LLM"
|
||||
|
||||
|
||||
def test_jwt_auth_llm_validation():
|
||||
"""Test that JWT token validation works correctly."""
|
||||
# Test with invalid JWT token (empty string)
|
||||
with pytest.raises(ValueError, match="Invalid JWT token"):
|
||||
JWTAuthLLM(jwt_token="")
|
||||
|
||||
# Test with invalid JWT token (non-string)
|
||||
with pytest.raises(ValueError, match="Invalid JWT token"):
|
||||
JWTAuthLLM(jwt_token=None)
|
||||
|
||||
# Test with expired token
|
||||
# Create a token that expired 1 hour ago
|
||||
expired_token = jwt.encode(
|
||||
{"exp": int(time.time()) - 3600},
|
||||
"secret",
|
||||
algorithm="HS256"
|
||||
)
|
||||
with pytest.raises(ValueError, match="JWT token has expired"):
|
||||
JWTAuthLLM(jwt_token=expired_token)
|
||||
|
||||
# Test with malformed token
|
||||
with pytest.raises(ValueError, match="Invalid JWT token format"):
|
||||
JWTAuthLLM(jwt_token="not.a.valid.jwt.token")
|
||||
|
||||
# Test with valid token
|
||||
# Create a token that expires 1 hour from now
|
||||
valid_token = jwt.encode(
|
||||
{"exp": int(time.time()) + 3600},
|
||||
"secret",
|
||||
algorithm="HS256"
|
||||
)
|
||||
# This should not raise an exception
|
||||
jwt_llm = JWTAuthLLM(jwt_token=valid_token)
|
||||
assert jwt_llm.jwt_token == valid_token
|
||||
|
||||
|
||||
class TimeoutHandlingLLM(LLM):
|
||||
"""Custom LLM implementation with timeout handling and retry logic."""
|
||||
|
||||
def __init__(self, max_retries: int = 3, timeout: int = 30):
|
||||
"""Initialize the TimeoutHandlingLLM with retry and timeout settings.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts.
|
||||
timeout: Timeout in seconds for each API call.
|
||||
"""
|
||||
super().__init__()
|
||||
self.max_retries = max_retries
|
||||
self.timeout = timeout
|
||||
self.calls = []
|
||||
self.stop = []
|
||||
self.fail_count = 0 # Number of times to simulate failure
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Simulate API calls with timeout handling and retry logic.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
callbacks: Optional list of callback functions.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
|
||||
Returns:
|
||||
A response string based on whether this is the first attempt or a retry.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If all retry attempts fail.
|
||||
"""
|
||||
# Record the initial call
|
||||
self.calls.append({
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"callbacks": callbacks,
|
||||
"available_functions": available_functions,
|
||||
"attempt": 0
|
||||
})
|
||||
|
||||
# Simulate retry logic
|
||||
for attempt in range(self.max_retries):
|
||||
# Skip the first attempt recording since we already did that above
|
||||
if attempt == 0:
|
||||
# Simulate a failure if fail_count > 0
|
||||
if self.fail_count > 0:
|
||||
self.fail_count -= 1
|
||||
# If we've used all retries, raise an error
|
||||
if attempt == self.max_retries - 1:
|
||||
raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
|
||||
# Otherwise, continue to the next attempt (simulating backoff)
|
||||
continue
|
||||
else:
|
||||
# Success on first attempt
|
||||
return "First attempt response"
|
||||
else:
|
||||
# This is a retry attempt (attempt > 0)
|
||||
# Always record retry attempts
|
||||
self.calls.append({
|
||||
"retry_attempt": attempt,
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"callbacks": callbacks,
|
||||
"available_functions": available_functions
|
||||
})
|
||||
|
||||
# Simulate a failure if fail_count > 0
|
||||
if self.fail_count > 0:
|
||||
self.fail_count -= 1
|
||||
# If we've used all retries, raise an error
|
||||
if attempt == self.max_retries - 1:
|
||||
raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
|
||||
# Otherwise, continue to the next attempt (simulating backoff)
|
||||
continue
|
||||
else:
|
||||
# Success on retry
|
||||
return "Response after retry"
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Return True to indicate that function calling is supported.
|
||||
|
||||
Returns:
|
||||
True, indicating that this LLM supports function calling.
|
||||
"""
|
||||
return True
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Return True to indicate that stop words are supported.
|
||||
|
||||
Returns:
|
||||
True, indicating that this LLM supports stop words.
|
||||
"""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Return a default context window size.
|
||||
|
||||
Returns:
|
||||
8192, a typical context window size for modern LLMs.
|
||||
"""
|
||||
return 8192
|
||||
|
||||
|
||||
def test_timeout_handling_llm():
|
||||
"""Test a custom LLM implementation with timeout handling and retry logic."""
|
||||
# Test successful first attempt
|
||||
llm = TimeoutHandlingLLM()
|
||||
response = llm.call("Test message")
|
||||
assert response == "First attempt response"
|
||||
assert len(llm.calls) == 1
|
||||
|
||||
# Test successful retry
|
||||
llm = TimeoutHandlingLLM()
|
||||
llm.fail_count = 1 # Fail once, then succeed
|
||||
response = llm.call("Test message")
|
||||
assert response == "Response after retry"
|
||||
assert len(llm.calls) == 2 # Initial call + successful retry call
|
||||
|
||||
# Test failure after all retries
|
||||
llm = TimeoutHandlingLLM(max_retries=2)
|
||||
llm.fail_count = 2 # Fail twice, which is all retries
|
||||
with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
|
||||
llm.call("Test message")
|
||||
assert len(llm.calls) == 2 # Initial call + failed retry attempt
|
||||
|
||||
|
||||
def test_rate_limited_llm():
|
||||
"""Test that rate limiting works correctly."""
|
||||
# Create a rate limited LLM with a very low limit (2 requests per minute)
|
||||
llm = RateLimitedLLM(requests_per_minute=2)
|
||||
|
||||
# First request should succeed
|
||||
response1 = llm.call("Test message 1")
|
||||
assert response1 == "Rate limited response"
|
||||
assert len(llm.calls) == 1
|
||||
|
||||
# Second request should succeed
|
||||
response2 = llm.call("Test message 2")
|
||||
assert response2 == "Rate limited response"
|
||||
assert len(llm.calls) == 2
|
||||
|
||||
# Third request should fail due to rate limiting
|
||||
with pytest.raises(ValueError, match="Rate limit exceeded"):
|
||||
llm.call("Test message 3")
|
||||
|
||||
# Test with invalid requests_per_minute
|
||||
with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"):
|
||||
RateLimitedLLM(requests_per_minute=0)
|
||||
|
||||
with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"):
|
||||
RateLimitedLLM(requests_per_minute=-1)
|
||||
|
||||
|
||||
def test_rate_limit_reset():
|
||||
"""Test that rate limits reset after the time window passes."""
|
||||
# Create a rate limited LLM with a very low limit (1 request per minute)
|
||||
# and a short time window for testing (1 second instead of 60 seconds)
|
||||
time_window = 1 # 1 second instead of 60 seconds
|
||||
llm = RateLimitedLLM(requests_per_minute=1, time_window=time_window)
|
||||
|
||||
# First request should succeed
|
||||
response1 = llm.call("Test message 1")
|
||||
assert response1 == "Rate limited response"
|
||||
|
||||
# Second request should fail due to rate limiting
|
||||
with pytest.raises(ValueError, match="Rate limit exceeded"):
|
||||
llm.call("Test message 2")
|
||||
|
||||
# Wait for the rate limit to reset
|
||||
import time
|
||||
time.sleep(time_window + 0.1) # Add a small buffer
|
||||
|
||||
# After waiting, we should be able to make another request
|
||||
response3 = llm.call("Test message 3")
|
||||
assert response3 == "Rate limited response"
|
||||
assert len(llm.calls) == 2 # First and third requests
|
||||
|
||||
|
||||
class RateLimitedLLM(LLM):
|
||||
"""Custom LLM implementation with rate limiting.
|
||||
|
||||
This class demonstrates how to implement a custom LLM with rate limiting
|
||||
capabilities. It uses a sliding window algorithm to ensure that no more
|
||||
than a specified number of requests are made within a given time period.
|
||||
"""
|
||||
|
||||
def __init__(self, requests_per_minute: int = 60, base_response: str = "Rate limited response", time_window: int = 60):
|
||||
"""Initialize the RateLimitedLLM with rate limiting parameters.
|
||||
|
||||
Args:
|
||||
requests_per_minute: Maximum number of requests allowed per minute.
|
||||
base_response: Default response to return.
|
||||
time_window: Time window in seconds for rate limiting (default: 60).
|
||||
This is configurable for testing purposes.
|
||||
|
||||
Raises:
|
||||
ValueError: If requests_per_minute is not a positive integer.
|
||||
"""
|
||||
super().__init__()
|
||||
if not isinstance(requests_per_minute, int) or requests_per_minute <= 0:
|
||||
raise ValueError("requests_per_minute must be a positive integer")
|
||||
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.base_response = base_response
|
||||
self.time_window = time_window
|
||||
self.request_times = deque()
|
||||
self.calls = []
|
||||
self.stop = []
|
||||
|
||||
def _check_rate_limit(self) -> None:
|
||||
"""Check if the current request exceeds the rate limit.
|
||||
|
||||
This method implements a sliding window rate limiting algorithm.
|
||||
It keeps track of request timestamps and ensures that no more than
|
||||
`requests_per_minute` requests are made within the configured time window.
|
||||
|
||||
Raises:
|
||||
ValueError: If the rate limit is exceeded.
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# Remove requests older than the time window
|
||||
while self.request_times and current_time - self.request_times[0] > self.time_window:
|
||||
self.request_times.popleft()
|
||||
|
||||
# Check if we've exceeded the rate limit
|
||||
if len(self.request_times) >= self.requests_per_minute:
|
||||
wait_time = self.time_window - (current_time - self.request_times[0])
|
||||
raise ValueError(
|
||||
f"Rate limit exceeded. Maximum {self.requests_per_minute} "
|
||||
f"requests per {self.time_window} seconds. Try again in {wait_time:.2f} seconds."
|
||||
)
|
||||
|
||||
# Record this request
|
||||
self.request_times.append(current_time)
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call the LLM with rate limiting.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
callbacks: Optional list of callback functions.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
|
||||
Returns:
|
||||
The LLM response.
|
||||
|
||||
Raises:
|
||||
ValueError: If the rate limit is exceeded.
|
||||
"""
|
||||
# Check rate limit before making the call
|
||||
self._check_rate_limit()
|
||||
|
||||
self.calls.append({
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"callbacks": callbacks,
|
||||
"available_functions": available_functions
|
||||
})
|
||||
|
||||
return self.base_response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Return True to indicate that function calling is supported.
|
||||
|
||||
Returns:
|
||||
True, indicating that this LLM supports function calling.
|
||||
"""
|
||||
return True
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Return True to indicate that stop words are supported.
|
||||
|
||||
Returns:
|
||||
True, indicating that this LLM supports stop words.
|
||||
"""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Return a default context window size.
|
||||
|
||||
Returns:
|
||||
8192, a typical context window size for modern LLMs.
|
||||
"""
|
||||
return 8192
|
||||
152
tests/flow_conditional_task_test.py
Normal file
152
tests/flow_conditional_task_test.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Task
|
||||
from crewai.flow import Flow, listen, start
|
||||
from crewai.project.annotations import task
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
# Create mock agents for testing
|
||||
researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="Research information",
|
||||
backstory="You are a researcher with expertise in finding information.",
|
||||
)
|
||||
|
||||
writer = Agent(
|
||||
role="Writer",
|
||||
goal="Write content",
|
||||
backstory="You are a writer with expertise in creating engaging content.",
|
||||
)
|
||||
|
||||
|
||||
class TestFlowWithConditionalTasks(Flow):
|
||||
"""Test flow with conditional tasks."""
|
||||
|
||||
@start()
|
||||
@task
|
||||
def initial_task(self):
|
||||
"""Initial task that always executes."""
|
||||
return Task(
|
||||
description="Initial task",
|
||||
expected_output="Initial output",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
@listen(initial_task)
|
||||
@task
|
||||
def conditional_task_false(self):
|
||||
"""Conditional task that should be skipped."""
|
||||
return ConditionalTask(
|
||||
description="Conditional task that should be skipped",
|
||||
expected_output="This should not be executed",
|
||||
agent=writer,
|
||||
condition=False,
|
||||
)
|
||||
|
||||
@listen(initial_task)
|
||||
@task
|
||||
def conditional_task_true(self):
|
||||
"""Conditional task that should be executed."""
|
||||
return ConditionalTask(
|
||||
description="Conditional task that should be executed",
|
||||
expected_output="This should be executed",
|
||||
agent=writer,
|
||||
condition=True,
|
||||
)
|
||||
|
||||
@listen(conditional_task_true)
|
||||
@task
|
||||
def final_task(self):
|
||||
"""Final task that executes after the conditional task."""
|
||||
return Task(
|
||||
description="Final task",
|
||||
expected_output="Final output",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
|
||||
def test_flow_with_conditional_tasks():
|
||||
"""Test that conditional tasks work correctly in a Flow."""
|
||||
flow = TestFlowWithConditionalTasks()
|
||||
|
||||
with patch.object(Task, "execute_sync") as mock_execute_sync:
|
||||
mock_execute_sync.return_value = TaskOutput(
|
||||
description="Task output",
|
||||
raw="Task output",
|
||||
agent="Agent",
|
||||
)
|
||||
|
||||
flow.kickoff()
|
||||
|
||||
# The initial task, conditional_task_true, and final_task should be executed
|
||||
# conditional_task_false should be skipped
|
||||
assert mock_execute_sync.call_count == 3
|
||||
|
||||
|
||||
class TestFlowWithSequentialConditionalTasks(Flow):
|
||||
"""Test flow with sequential conditional tasks."""
|
||||
|
||||
@start()
|
||||
@task
|
||||
def initial_task(self):
|
||||
"""Initial task that always executes."""
|
||||
return Task(
|
||||
description="Initial task",
|
||||
expected_output="Initial output",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
@listen(initial_task)
|
||||
@task
|
||||
def conditional_task_1(self):
|
||||
"""First conditional task that should be executed."""
|
||||
return ConditionalTask(
|
||||
description="First conditional task",
|
||||
expected_output="First conditional output",
|
||||
agent=writer,
|
||||
condition=True,
|
||||
)
|
||||
|
||||
@listen(conditional_task_1)
|
||||
@task
|
||||
def conditional_task_2(self):
|
||||
"""Second conditional task that should be skipped."""
|
||||
return ConditionalTask(
|
||||
description="Second conditional task",
|
||||
expected_output="Second conditional output",
|
||||
agent=researcher,
|
||||
condition=False,
|
||||
)
|
||||
|
||||
@listen(conditional_task_2)
|
||||
@task
|
||||
def conditional_task_3(self):
|
||||
"""Third conditional task that should be executed."""
|
||||
return ConditionalTask(
|
||||
description="Third conditional task",
|
||||
expected_output="Third conditional output",
|
||||
agent=writer,
|
||||
condition=True,
|
||||
)
|
||||
|
||||
|
||||
def test_flow_with_sequential_conditional_tasks():
|
||||
"""Test that sequential conditional tasks work correctly in a Flow."""
|
||||
flow = TestFlowWithSequentialConditionalTasks()
|
||||
|
||||
with patch.object(Task, "execute_sync") as mock_execute_sync:
|
||||
mock_execute_sync.return_value = TaskOutput(
|
||||
description="Task output",
|
||||
raw="Task output",
|
||||
agent="Agent",
|
||||
)
|
||||
|
||||
flow.kickoff()
|
||||
|
||||
# The initial_task and conditional_task_1 should be executed
|
||||
# conditional_task_2 should be skipped, and since it's skipped,
|
||||
# conditional_task_3 should not be triggered
|
||||
assert mock_execute_sync.call_count == 2
|
||||
@@ -219,7 +219,7 @@ def test_get_custom_llm_provider_gemini():
|
||||
|
||||
def test_get_custom_llm_provider_openai():
|
||||
llm = LLM(model="gpt-4")
|
||||
assert llm._get_custom_llm_provider() == "openai"
|
||||
assert llm._get_custom_llm_provider() == None
|
||||
|
||||
|
||||
def test_validate_call_params_supported():
|
||||
@@ -285,6 +285,7 @@ def test_o3_mini_reasoning_effort_medium():
|
||||
assert isinstance(result, str)
|
||||
assert "Paris" in result
|
||||
|
||||
|
||||
def test_context_window_validation():
|
||||
"""Test that context window validation works correctly."""
|
||||
# Test valid window size
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "user", "content": "Tell me a short joke"}], "model":
|
||||
"gpt-3.5-turbo", "stop": [], "stream": true}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '121'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- _cfuvid=IY8ppO70AMHr2skDSUsGh71zqHHdCQCZ3OvkPi26NBc-1740424913267-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.65.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.65.1
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- '600.0'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.8
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: 'data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Why"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
|
||||
couldn"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"''t"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
|
||||
the"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
|
||||
bicycle"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
|
||||
stand"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
|
||||
up"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
|
||||
by"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
|
||||
itself"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
|
||||
Because"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
|
||||
it"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
|
||||
was"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
|
||||
two"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"-t"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"ired"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
|
||||
|
||||
|
||||
data: [DONE]
|
||||
|
||||
|
||||
'
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 91ab1bcbad95bcda-ATL
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- text/event-stream; charset=utf-8
|
||||
Date:
|
||||
- Mon, 03 Mar 2025 18:13:34 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=Jydtg8l0yjWRI2vKmejdq.C1W.sasIwEbTrV2rUt6V0-1741025614-1.0.1.1-Af3gmq.j2ecn9QEa3aCVY09QU4VqoW2GTk9AjvzPA.jyAZlwhJd4paniSt3kSusH0tryW03iC8uaX826hb2xzapgcfSm6Jdh_eWh_BMCh_8;
|
||||
path=/; expires=Mon, 03-Mar-25 18:43:34 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=5wzaJSCvT1p1Eazad55wDvp1JsgxrlghhmmU9tx0fMs-1741025614868-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '127'
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
strict-transport-security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '50000000'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '49999978'
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_2a2a04977ace88fdd64cf570f80c0202
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,107 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "user", "content": "Tell me a short joke"}], "model":
|
||||
"gpt-4o", "stop": [], "stream": false}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '115'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.65.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.65.1
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- '600.0'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.8
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFJBbtswELzrFVteerEKSZbrxpcCDuBTUfSUtigCgSZXEhuKJLirNEbg
|
||||
vxeSHMtBXSAXHmZ2BjPLfU4AhNFiA0K1klUXbLpde/X1tvtW/tnfrW6//Lzb7UraLn8s2+xpJxaD
|
||||
wu9/o+IX1Qflu2CRjXcTrSJKxsE1X5d5kRWrdT4SnddoB1kTOC19WmRFmWaf0uzjSdh6o5DEBn4l
|
||||
AADP4ztEdBqfxAayxQvSIZFsUGzOQwAiejsgQhIZYulYLGZSecfoxtTf2wNo794zkDLo2BATcOyJ
|
||||
QbLv6DNsUcmeELjFA3TyAaEPgI8YD9wa17y7NI5Y9ySHXq639oQfz0mtb0L0ezrxZ7w2zlBbRZTk
|
||||
3ZCK2AcxsscE4H7cSP+qpAjRd4Er9g/oBsO8mOzE/AVXSPYs7YwX5eKKW6WRpbF0sVGhpGpRz8p5
|
||||
/bLXxl8QyUXnf8Nc8556G9e8xX4mlMLAqKsQURv1uvA8FnE40P+NnXc8BhaE8dEorNhgHP5BYy17
|
||||
O92OoAMxdlVtXIMxRDMdUB2qWt3UuV5ny5VIjslfAAAA//8DADx20t9JAwAA
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 91bbfc033e461d6e-ATL
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Wed, 05 Mar 2025 19:22:51 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=LecfSlhN6VGr4kTlMiMCqRPInNb1m8zOikTZxtsE_WM-1741202571-1.0.1.1-T8nh2g1PcqyLIV97_HH9Q_nSUyCtaiFAOzvMxlswn6XjJCcSLJhi_fmkbylwppwoRPTxgs4S6VsVH0mp4ZcDTABBbtemKj7vS8QRDpRrmsU;
|
||||
path=/; expires=Wed, 05-Mar-25 19:52:51 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=wyMrJP5k5bgWyD8rsK4JPvAJ78JWrsrT0lyV9DP4WZM-1741202571727-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '416'
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
strict-transport-security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29999978'
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_f42504d00bda0a492dced0ba3cf302d8
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@@ -38,6 +39,7 @@ from crewai.utilities.events.llm_events import (
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.utilities.events.task_events import (
|
||||
TaskCompletedEvent,
|
||||
@@ -48,6 +50,11 @@ from crewai.utilities.events.tool_usage_events import (
|
||||
ToolUsageErrorEvent,
|
||||
)
|
||||
|
||||
# Skip streaming tests when running in CI/CD environments
|
||||
skip_streaming_in_ci = pytest.mark.skipif(
|
||||
os.getenv("CI") is not None, reason="Skipping streaming tests in CI/CD environments"
|
||||
)
|
||||
|
||||
base_agent = Agent(
|
||||
role="base_agent",
|
||||
llm="gpt-4o-mini",
|
||||
@@ -615,3 +622,152 @@ def test_llm_emits_call_failed_event():
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].type == "llm_call_failed"
|
||||
assert received_events[0].error == error_message
|
||||
|
||||
|
||||
@skip_streaming_in_ci
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_emits_stream_chunk_events():
|
||||
"""Test that LLM emits stream chunk events when streaming is enabled."""
|
||||
received_chunks = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
|
||||
# Create an LLM with streaming enabled
|
||||
llm = LLM(model="gpt-4o", stream=True)
|
||||
|
||||
# Call the LLM with a simple message
|
||||
response = llm.call("Tell me a short joke")
|
||||
|
||||
# Verify that we received chunks
|
||||
assert len(received_chunks) > 0
|
||||
|
||||
# Verify that concatenating all chunks equals the final response
|
||||
assert "".join(received_chunks) == response
|
||||
|
||||
|
||||
@skip_streaming_in_ci
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_no_stream_chunks_when_streaming_disabled():
|
||||
"""Test that LLM doesn't emit stream chunk events when streaming is disabled."""
|
||||
received_chunks = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
|
||||
# Create an LLM with streaming disabled
|
||||
llm = LLM(model="gpt-4o", stream=False)
|
||||
|
||||
# Call the LLM with a simple message
|
||||
response = llm.call("Tell me a short joke")
|
||||
|
||||
# Verify that we didn't receive any chunks
|
||||
assert len(received_chunks) == 0
|
||||
|
||||
# Verify we got a response
|
||||
assert response and isinstance(response, str)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_streaming_fallback_to_non_streaming():
|
||||
"""Test that streaming falls back to non-streaming when there's an error."""
|
||||
received_chunks = []
|
||||
fallback_called = False
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
|
||||
# Create an LLM with streaming enabled
|
||||
llm = LLM(model="gpt-4o", stream=True)
|
||||
|
||||
# Store original methods
|
||||
original_call = llm.call
|
||||
|
||||
# Create a mock call method that handles the streaming error
|
||||
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
|
||||
nonlocal fallback_called
|
||||
# Emit a couple of chunks to simulate partial streaming
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))
|
||||
|
||||
# Mark that fallback would be called
|
||||
fallback_called = True
|
||||
|
||||
# Return a response as if fallback succeeded
|
||||
return "Fallback response after streaming error"
|
||||
|
||||
# Replace the call method with our mock
|
||||
llm.call = mock_call
|
||||
|
||||
try:
|
||||
# Call the LLM
|
||||
response = llm.call("Tell me a short joke")
|
||||
|
||||
# Verify that we received some chunks
|
||||
assert len(received_chunks) == 2
|
||||
assert received_chunks[0] == "Test chunk 1"
|
||||
assert received_chunks[1] == "Test chunk 2"
|
||||
|
||||
# Verify fallback was triggered
|
||||
assert fallback_called
|
||||
|
||||
# Verify we got the fallback response
|
||||
assert response == "Fallback response after streaming error"
|
||||
|
||||
finally:
|
||||
# Restore the original method
|
||||
llm.call = original_call
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_streaming_empty_response_handling():
|
||||
"""Test that streaming handles empty responses correctly."""
|
||||
received_chunks = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
|
||||
# Create an LLM with streaming enabled
|
||||
llm = LLM(model="gpt-3.5-turbo", stream=True)
|
||||
|
||||
# Store original methods
|
||||
original_call = llm.call
|
||||
|
||||
# Create a mock call method that simulates empty chunks
|
||||
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
|
||||
# Emit a few empty chunks
|
||||
for _ in range(3):
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
|
||||
|
||||
# Return the default message for empty responses
|
||||
return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
|
||||
|
||||
# Replace the call method with our mock
|
||||
llm.call = mock_call
|
||||
|
||||
try:
|
||||
# Call the LLM - this should handle empty response
|
||||
response = llm.call("Tell me a short joke")
|
||||
|
||||
# Verify that we received empty chunks
|
||||
assert len(received_chunks) == 3
|
||||
assert all(chunk == "" for chunk in received_chunks)
|
||||
|
||||
# Verify the response is the default message for empty responses
|
||||
assert "I apologize" in response and "couldn't generate" in response
|
||||
|
||||
finally:
|
||||
# Restore the original method
|
||||
llm.call = original_call
|
||||
|
||||
Reference in New Issue
Block a user