Compare commits


6 Commits

Author | SHA1 | Message | Date
Devin AI | ad785baa16 | Fix syntax errors in memory storage implementation (Co-Authored-By: Joe Moura <joao@crewai.com>) | 2025-03-04 20:31:39 +00:00
Devin AI | 81e48947fb | Fix import sorting issues to resolve linting errors (Co-Authored-By: Joe Moura <joao@crewai.com>) | 2025-03-04 20:26:14 +00:00
Devin AI | b13590a359 | Improve type system and test coverage for custom memory storage (Co-Authored-By: Joe Moura <joao@crewai.com>) | 2025-03-04 20:23:51 +00:00
Devin AI | 541fa13df7 | Fix remaining type checking issues (Co-Authored-By: Joe Moura <joao@crewai.com>) | 2025-03-04 19:53:35 +00:00
Devin AI | 5e29ac5f7b | Fix linting and type checking issues (Co-Authored-By: Joe Moura <joao@crewai.com>) | 2025-03-04 19:53:24 +00:00
Devin AI | 583ac5711f | Add support for custom memory storage implementations (fixes #2278) (Co-Authored-By: Joe Moura <joao@crewai.com>) | 2025-03-04 19:50:54 +00:00
35 changed files with 1100 additions and 5765 deletions

View File

@@ -0,0 +1,151 @@
# Custom Memory Storage
CrewAI supports custom memory storage implementations for the different memory types. You can provide your own storage by extending the `Storage` interface and passing an instance either directly to the memory classes or through the `memory_config` parameter.
## Implementing a Custom Storage
To create a custom storage implementation, you need to extend the `Storage` interface and implement the required methods:
```python
from typing import Any, Dict, List
from crewai.memory.storage.interface import Storage

class CustomStorage(Storage):
    """Custom storage implementation."""

    def __init__(self):
        # Initialize your storage backend
        self.data = []

    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        """Save a value with metadata to the storage."""
        # Implement your save logic
        self.data.append({"value": value, "metadata": metadata})

    def search(
        self, query: str, limit: int = 3, score_threshold: float = 0.35
    ) -> List[Any]:
        """Search for values in the storage."""
        # Implement your search logic; this simple example ignores the query
        # and score_threshold and only honors the result limit
        return [
            {"context": item["value"], "metadata": item["metadata"]}
            for item in self.data
        ][:limit]

    def reset(self) -> None:
        """Reset the storage."""
        # Implement your reset logic
        self.data = []
```
## Using Custom Storage
There are two ways to provide custom storage implementations to CrewAI:
### 1. Pass Custom Storage to Memory Instances
You can create memory instances with custom storage and pass them to the Crew:
```python
from crewai import Crew, Agent
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.user.user_memory import UserMemory

# Create custom storage instances
short_term_storage = CustomStorage()
long_term_storage = CustomStorage()
entity_storage = CustomStorage()
user_storage = CustomStorage()

# Create memory instances with custom storage
short_term_memory = ShortTermMemory(storage=short_term_storage)
long_term_memory = LongTermMemory(storage=long_term_storage)
entity_memory = EntityMemory(storage=entity_storage)
user_memory = UserMemory(storage=user_storage)

# Create a crew with custom memory instances
crew = Crew(
    agents=[Agent(role="researcher", goal="research", backstory="I am a researcher")],
    memory=True,
    short_term_memory=short_term_memory,
    long_term_memory=long_term_memory,
    entity_memory=entity_memory,
    memory_config={"user_memory": user_memory},
)
```
### 2. Pass Custom Storage through Memory Config
You can also provide custom storage implementations through the `memory_config` parameter:
```python
from crewai import Crew, Agent

# Create a crew with custom storage in memory_config
crew = Crew(
    agents=[Agent(role="researcher", goal="research", backstory="I am a researcher")],
    memory=True,
    memory_config={
        "storage": {
            "short_term": CustomStorage(),
            "long_term": CustomStorage(),
            "entity": CustomStorage(),
            "user": CustomStorage(),
        }
    },
)
```
## Example: Redis Storage
Here's an example of a custom storage implementation using Redis:
```python
import json
from typing import Any, Dict, List

import redis

from crewai.memory.storage.interface import Storage


class RedisStorage(Storage):
    """Redis-based storage implementation."""

    def __init__(self, redis_url="redis://localhost:6379/0", prefix="crewai"):
        self.redis = redis.from_url(redis_url)
        self.prefix = prefix

    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        """Save a value with metadata to Redis."""
        # Simple sequential key based on the current key count
        key = f"{self.prefix}:{len(self.redis.keys(f'{self.prefix}:*'))}"
        data = {"value": value, "metadata": metadata}
        self.redis.set(key, json.dumps(data))

    def search(
        self, query: str, limit: int = 3, score_threshold: float = 0.35
    ) -> List[Any]:
        """Search for values in Redis."""
        # This simple implementation returns stored values up to `limit`.
        # In a real implementation, you would use Redis search capabilities.
        results = []
        for key in self.redis.keys(f"{self.prefix}:*"):
            data = json.loads(self.redis.get(key))
            results.append({"context": data["value"], "metadata": data["metadata"]})
            if len(results) >= limit:
                break
        return results

    def reset(self) -> None:
        """Reset the Redis storage."""
        for key in self.redis.keys(f"{self.prefix}:*"):
            self.redis.delete(key)
```
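To plug this into a crew, pass it to a memory instance using the same pattern shown above. A minimal sketch (the connection URL and key prefix are placeholder values for your own Redis instance):
```python
from crewai.memory.short_term.short_term_memory import ShortTermMemory

# Back short-term memory with the RedisStorage defined above
redis_storage = RedisStorage(redis_url="redis://localhost:6379/0", prefix="my_crew")
short_term_memory = ShortTermMemory(storage=redis_storage)
```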
## Benefits of Custom Storage
Using custom storage implementations allows you to:
1. Store memory data in external databases or services
2. Implement custom search algorithms
3. Share memory between different crews or applications
4. Persist memory across application restarts
5. Implement custom memory retention policies
By extending the `Storage` interface, you can integrate CrewAI with any storage backend that suits your needs.

View File

@@ -224,7 +224,6 @@ CrewAI provides a wide range of events that you can listen for:
- **LLMCallStartedEvent**: Emitted when an LLM call starts
- **LLMCallCompletedEvent**: Emitted when an LLM call completes
- **LLMCallFailedEvent**: Emitted when an LLM call fails
- **LLMStreamChunkEvent**: Emitted for each chunk received during streaming LLM responses
## Event Handler Structure

View File

@@ -540,46 +540,6 @@ In this section, you'll find detailed examples that help you select, configure,
</Accordion>
</AccordionGroup>
## Streaming Responses
CrewAI supports streaming responses from LLMs, allowing your application to receive and process outputs in real-time as they're generated.
<Tabs>
<Tab title="Basic Setup">
Enable streaming by setting the `stream` parameter to `True` when initializing your LLM:
```python
from crewai import LLM
# Create an LLM with streaming enabled
llm = LLM(
model="openai/gpt-4o",
stream=True # Enable streaming
)
```
When streaming is enabled, responses are delivered in chunks as they're generated, creating a more responsive user experience.
</Tab>
<Tab title="Event Handling">
CrewAI emits events for each chunk received during streaming:
```python
from crewai import LLM
from crewai.utilities.events import EventHandler, LLMStreamChunkEvent
class MyEventHandler(EventHandler):
def on_llm_stream_chunk(self, event: LLMStreamChunkEvent):
# Process each chunk as it arrives
print(f"Received chunk: {event.chunk}")
# Register the event handler
from crewai.utilities.events import crewai_event_bus
crewai_event_bus.register_handler(MyEventHandler())
```
</Tab>
</Tabs>
## Structured LLM Calls
CrewAI supports structured responses from LLM calls by allowing you to define a `response_format` using a Pydantic model. This enables the framework to automatically parse and validate the output, making it easier to integrate the response into your application without manual post-processing.
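For illustration, a minimal sketch of such a call (the `CityInfo` model and its fields are assumptions made up for this example, not part of the documented API):
```python
from pydantic import BaseModel
from crewai import LLM

class CityInfo(BaseModel):
    name: str
    country: str

# Configure the LLM to return output conforming to the CityInfo schema
llm = LLM(
    model="openai/gpt-4o",
    response_format=CityInfo,  # assumption: a Pydantic model, as described above
)
result = llm.call("Name one European city and the country it is in.")
print(result)
```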
@@ -709,4 +669,46 @@ Learn how to get the most out of your LLM configuration:
Use larger context models for extensive tasks
</Tip>
```python
# Large context model
llm = LLM(model="openai/gpt-4o") # 128K tokens
```
</Tab>
</Tabs>
## Getting Help
If you need assistance, these resources are available:
<CardGroup cols={3}>
<Card
title="LiteLLM Documentation"
href="https://docs.litellm.ai/docs/"
icon="book"
>
Comprehensive documentation for LiteLLM integration and troubleshooting common issues.
</Card>
<Card
title="GitHub Issues"
href="https://github.com/joaomdmoura/crewAI/issues"
icon="bug"
>
Report bugs, request features, or browse existing issues for solutions.
</Card>
<Card
title="Community Forum"
href="https://community.crewai.com"
icon="comment-question"
>
Connect with other CrewAI users, share experiences, and get help from the community.
</Card>
</CardGroup>
<Note>
Best Practices for API Key Security:
- Use environment variables or secure vaults (see the sketch after this note)
- Never commit keys to version control
- Rotate keys regularly
- Use separate keys for development and production
- Monitor key usage for unusual patterns
</Note>
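For example, a minimal sketch of reading a key from the environment instead of hard-coding it (assuming the `api_key` parameter accepted by `LLM`):
```python
import os
from crewai import LLM

# Read the API key from the environment rather than embedding it in source code
llm = LLM(model="openai/gpt-4o", api_key=os.environ["OPENAI_API_KEY"])
```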

View File

@@ -184,7 +184,7 @@ class Crew(BaseModel):
default=None,
description="Maximum number of requests per minute for the crew execution to be respected.",
)
prompt_file: Optional[str] = Field(
prompt_file: str = Field(
default=None,
description="Path to the prompt json file to be used for the crew.",
)
@@ -262,8 +262,19 @@ class Crew(BaseModel):
def create_crew_memory(self) -> "Crew":
"""Set private attributes."""
if self.memory:
from crewai.memory.storage.rag_storage import RAGStorage
# Create default storage instances for each memory type if needed
long_term_storage = RAGStorage(type="long_term", crew=self, embedder_config=self.embedder)
short_term_storage = RAGStorage(type="short_term", crew=self, embedder_config=self.embedder)
entity_storage = RAGStorage(type="entity", crew=self, embedder_config=self.embedder)
self._long_term_memory = (
self.long_term_memory if self.long_term_memory else LongTermMemory()
self.long_term_memory if self.long_term_memory else LongTermMemory(
crew=self,
embedder_config=self.embedder,
storage=long_term_storage
)
)
self._short_term_memory = (
self.short_term_memory
@@ -271,12 +282,17 @@ class Crew(BaseModel):
else ShortTermMemory(
crew=self,
embedder_config=self.embedder,
storage=short_term_storage
)
)
self._entity_memory = (
self.entity_memory
if self.entity_memory
else EntityMemory(crew=self, embedder_config=self.embedder)
else EntityMemory(
crew=self,
embedder_config=self.embedder,
storage=entity_storage
)
)
if (
self.memory_config and "user_memory" in self.memory_config
@@ -808,7 +824,6 @@ class Crew(BaseModel):
)
if skipped_task_output:
task_outputs.append(skipped_task_output)
last_sync_output = skipped_task_output
continue
if task.async_execution:
@@ -822,10 +837,8 @@ class Crew(BaseModel):
)
futures.append((task, future, task_index))
else:
# Process any pending async tasks before executing a sync task
if futures:
processed_outputs = self._process_async_tasks(futures, was_replayed)
task_outputs.extend(processed_outputs)
task_outputs = self._process_async_tasks(futures, was_replayed)
futures.clear()
context = self._get_context(task, task_outputs)
@@ -835,14 +848,11 @@ class Crew(BaseModel):
tools=tools_for_task,
)
task_outputs.append(task_output)
last_sync_output = task_output
self._process_task_result(task, task_output)
self._store_execution_log(task, task_output, task_index, was_replayed)
# Process any remaining async tasks at the end
if futures:
processed_outputs = self._process_async_tasks(futures, was_replayed)
task_outputs.extend(processed_outputs)
task_outputs = self._process_async_tasks(futures, was_replayed)
return self._create_crew_output(task_outputs)
@@ -854,17 +864,12 @@ class Crew(BaseModel):
task_index: int,
was_replayed: bool,
) -> Optional[TaskOutput]:
# Process any pending async tasks to ensure we have the most up-to-date context
if futures:
processed_outputs = self._process_async_tasks(futures, was_replayed)
task_outputs.extend(processed_outputs)
task_outputs = self._process_async_tasks(futures, was_replayed)
futures.clear()
# Get the previous output to evaluate the condition
previous_output = task_outputs[-1] if task_outputs else None
# If there's no previous output or the condition evaluates to False, skip the task
if previous_output is None or not task.should_execute(previous_output):
if previous_output is not None and not task.should_execute(previous_output):
self._logger.log(
"debug",
f"Skipping conditional task: {task.description}",
@@ -872,13 +877,8 @@ class Crew(BaseModel):
)
skipped_task_output = task.get_skipped_task_output()
# Store the execution log for the skipped task
if not was_replayed:
self._store_execution_log(task, skipped_task_output, task_index)
# Set the output on the task itself so it can be referenced later
task.output = skipped_task_output
return skipped_task_output
return None

View File

@@ -1,50 +0,0 @@
from functools import wraps
from typing import Any, Callable, Optional, Union, cast
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
def task(func: Callable) -> Callable:
"""
Decorator for Flow methods that return a Task.
This decorator ensures that when a method returns a ConditionalTask,
the condition is properly evaluated based on the previous task's output.
Args:
func: The method to decorate
Returns:
The decorated method
"""
setattr(func, "is_task", True)
@wraps(func)
def wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
# Set the task name if not already set
if hasattr(result, "name") and not result.name:
result.name = func.__name__
# If this is a ConditionalTask, ensure it has a valid condition
if isinstance(result, ConditionalTask):
# If the condition is a boolean, wrap it in a function
if isinstance(result.condition, bool):
bool_value = result.condition
result.condition = lambda _: bool_value
# Get the previous task output if available
previous_outputs = getattr(self, "_method_outputs", [])
previous_output = previous_outputs[-1] if previous_outputs else None
# If there's a previous output and it's a TaskOutput, check if we should execute
if previous_output and isinstance(previous_output, TaskOutput):
if not result.should_execute(previous_output):
# Return a skipped task output instead of the task
return result.get_skipped_task_output()
return result
return wrapper

View File

@@ -5,17 +5,7 @@ import sys
import threading
import warnings
from contextlib import contextmanager
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Type,
TypedDict,
Union,
cast,
)
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
from dotenv import load_dotenv
from pydantic import BaseModel
@@ -25,7 +15,6 @@ from crewai.utilities.events.llm_events import (
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMCallType,
LLMStreamChunkEvent,
)
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
@@ -33,11 +22,8 @@ with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
import litellm
from litellm import Choices
from litellm.litellm_core_utils.get_supported_openai_params import (
get_supported_openai_params,
)
from litellm.types.utils import ModelResponse
from litellm.utils import supports_response_schema
from litellm.utils import get_supported_openai_params, supports_response_schema
from crewai.utilities.events import crewai_event_bus
@@ -140,17 +126,6 @@ def suppress_warnings():
sys.stderr = old_stderr
class Delta(TypedDict):
content: Optional[str]
role: Optional[str]
class StreamingChoices(TypedDict):
delta: Delta
index: int
finish_reason: Optional[str]
class LLM:
def __init__(
self,
@@ -175,7 +150,6 @@ class LLM:
api_key: Optional[str] = None,
callbacks: List[Any] = [],
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
stream: bool = False,
**kwargs,
):
self.model = model
@@ -201,7 +175,6 @@ class LLM:
self.reasoning_effort = reasoning_effort
self.additional_params = kwargs
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
litellm.drop_params = True
@@ -228,432 +201,6 @@ class LLM:
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
"""Prepare parameters for the completion call.
Args:
messages: Input messages for the LLM
tools: Optional list of tool schemas
callbacks: Optional list of callback functions
available_functions: Optional dict of available functions
Returns:
Dict[str, Any]: Parameters for the completion call
"""
# --- 1) Format messages according to provider requirements
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
formatted_messages = self._format_messages_for_provider(messages)
# --- 2) Prepare the parameters for the completion call
params = {
"model": self.model,
"messages": formatted_messages,
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
"api_base": self.api_base,
"base_url": self.base_url,
"api_version": self.api_version,
"api_key": self.api_key,
"stream": self.stream,
"tools": tools,
"reasoning_effort": self.reasoning_effort,
**self.additional_params,
}
# Remove None values from params
return {k: v for k, v in params.items() if v is not None}
def _handle_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> str:
"""Handle a streaming response from the LLM.
Args:
params: Parameters for the completion call
callbacks: Optional list of callback functions
available_functions: Dict of available functions
Returns:
str: The complete response text
Raises:
Exception: If no content is received from the streaming response
"""
# --- 1) Initialize response tracking
full_response = ""
last_chunk = None
chunk_count = 0
usage_info = None
# --- 2) Make sure stream is set to True and include usage metrics
params["stream"] = True
params["stream_options"] = {"include_usage": True}
try:
# --- 3) Process each chunk in the stream
for chunk in litellm.completion(**params):
chunk_count += 1
last_chunk = chunk
# Extract content from the chunk
chunk_content = None
# Safely extract content from various chunk formats
try:
# Try to access choices safely
choices = None
if isinstance(chunk, dict) and "choices" in chunk:
choices = chunk["choices"]
elif hasattr(chunk, "choices"):
# Check if choices is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "choices"), type):
choices = getattr(chunk, "choices")
# Try to extract usage information if available
if isinstance(chunk, dict) and "usage" in chunk:
usage_info = chunk["usage"]
elif hasattr(chunk, "usage"):
# Check if usage is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "usage"), type):
usage_info = getattr(chunk, "usage")
if choices and len(choices) > 0:
choice = choices[0]
# Handle different delta formats
delta = None
if isinstance(choice, dict) and "delta" in choice:
delta = choice["delta"]
elif hasattr(choice, "delta"):
delta = getattr(choice, "delta")
# Extract content from delta
if delta:
# Handle dict format
if isinstance(delta, dict):
if "content" in delta and delta["content"] is not None:
chunk_content = delta["content"]
# Handle object format
elif hasattr(delta, "content"):
chunk_content = getattr(delta, "content")
# Handle case where content might be None or empty
if chunk_content is None and isinstance(delta, dict):
# Some models might send empty content chunks
chunk_content = ""
except Exception as e:
logging.debug(f"Error extracting content from chunk: {e}")
logging.debug(f"Chunk format: {type(chunk)}, content: {chunk}")
# Only add non-None content to the response
if chunk_content is not None:
# Add the chunk content to the full response
full_response += chunk_content
# Emit the chunk event
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(chunk=chunk_content),
)
# --- 4) Fallback to non-streaming if no content received
if not full_response.strip() and chunk_count == 0:
logging.warning(
"No chunks received in streaming response, falling back to non-streaming"
)
non_streaming_params = params.copy()
non_streaming_params["stream"] = False
non_streaming_params.pop(
"stream_options", None
) # Remove stream_options for non-streaming call
return self._handle_non_streaming_response(
non_streaming_params, callbacks, available_functions
)
# --- 5) Handle empty response with chunks
if not full_response.strip() and chunk_count > 0:
logging.warning(
f"Received {chunk_count} chunks but no content was extracted"
)
if last_chunk is not None:
try:
# Try to extract content from the last chunk's message
choices = None
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if choices and len(choices) > 0:
choice = choices[0]
# Try to get content from message
message = None
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
if message:
content = None
if isinstance(message, dict) and "content" in message:
content = message["content"]
elif hasattr(message, "content"):
content = getattr(message, "content")
if content:
full_response = content
logging.info(
f"Extracted content from last chunk message: {full_response}"
)
except Exception as e:
logging.debug(f"Error extracting content from last chunk: {e}")
logging.debug(
f"Last chunk format: {type(last_chunk)}, content: {last_chunk}"
)
# --- 6) If still empty, raise an error instead of using a default response
if not full_response.strip():
raise Exception(
"No content received from streaming response. Received empty chunks or failed to extract content."
)
# --- 7) Check for tool calls in the final response
tool_calls = None
try:
if last_chunk:
choices = None
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if choices and len(choices) > 0:
choice = choices[0]
message = None
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
if message:
if isinstance(message, dict) and "tool_calls" in message:
tool_calls = message["tool_calls"]
elif hasattr(message, "tool_calls"):
tool_calls = getattr(message, "tool_calls")
except Exception as e:
logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly
if not tool_calls or not available_functions:
# Log token usage if available in streaming mode
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
# Emit completion event and return response
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
return full_response
# --- 9) Handle tool calls if present
tool_result = self._handle_tool_call(tool_calls, available_functions)
if tool_result is not None:
return tool_result
# --- 10) Log token usage if available in streaming mode
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
# --- 11) Emit completion event and return response
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
return full_response
except Exception as e:
logging.error(f"Error in streaming response: {str(e)}")
if full_response.strip():
logging.warning(f"Returning partial response despite error: {str(e)}")
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
return full_response
# Emit failed event and re-raise the exception
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=str(e)),
)
raise Exception(f"Failed to get streaming response: {str(e)}")
def _handle_streaming_callbacks(
self,
callbacks: Optional[List[Any]],
usage_info: Optional[Dict[str, Any]],
last_chunk: Optional[Any],
) -> None:
"""Handle callbacks with usage info for streaming responses.
Args:
callbacks: Optional list of callback functions
usage_info: Usage information collected during streaming
last_chunk: The last chunk received from the streaming response
"""
if callbacks and len(callbacks) > 0:
for callback in callbacks:
if hasattr(callback, "log_success_event"):
# Use the usage_info we've been tracking
if not usage_info:
# Try to get usage from the last chunk if we haven't already
try:
if last_chunk:
if (
isinstance(last_chunk, dict)
and "usage" in last_chunk
):
usage_info = last_chunk["usage"]
elif hasattr(last_chunk, "usage"):
if not isinstance(
getattr(last_chunk, "usage"), type
):
usage_info = getattr(last_chunk, "usage")
except Exception as e:
logging.debug(f"Error extracting usage info: {e}")
if usage_info:
callback.log_success_event(
kwargs={}, # We don't have the original params here
response_obj={"usage": usage_info},
start_time=0,
end_time=0,
)
def _handle_non_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> str:
"""Handle a non-streaming response from the LLM.
Args:
params: Parameters for the completion call
callbacks: Optional list of callback functions
available_functions: Dict of available functions
Returns:
str: The response text
"""
# --- 1) Make the completion call
response = litellm.completion(**params)
# --- 2) Extract response message and content
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
text_response = response_message.content or ""
# --- 3) Handle callbacks with usage info
if callbacks and len(callbacks) > 0:
for callback in callbacks:
if hasattr(callback, "log_success_event"):
usage_info = getattr(response, "usage", None)
if usage_info:
callback.log_success_event(
kwargs=params,
response_obj={"usage": usage_info},
start_time=0,
end_time=0,
)
# --- 4) Check for tool calls
tool_calls = getattr(response_message, "tool_calls", [])
# --- 5) If no tool calls or no available functions, return the text response directly
if not tool_calls or not available_functions:
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
return text_response
# --- 6) Handle tool calls if present
tool_result = self._handle_tool_call(tool_calls, available_functions)
if tool_result is not None:
return tool_result
# --- 7) If tool call handling didn't return a result, emit completion event and return text response
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
return text_response
def _handle_tool_call(
self,
tool_calls: List[Any],
available_functions: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
"""Handle a tool call from the LLM.
Args:
tool_calls: List of tool calls from the LLM
available_functions: Dict of available functions
Returns:
Optional[str]: The result of the tool call, or None if no tool call was made
"""
# --- 1) Validate tool calls and available functions
if not tool_calls or not available_functions:
return None
# --- 2) Extract function name from first tool call
tool_call = tool_calls[0]
function_name = tool_call.function.name
function_args = {} # Initialize to empty dict to avoid unbound variable
# --- 3) Check if function is available
if function_name in available_functions:
try:
# --- 3.1) Parse function arguments
function_args = json.loads(tool_call.function.arguments)
fn = available_functions[function_name]
# --- 3.2) Execute function
result = fn(**function_args)
# --- 3.3) Emit success event
self._handle_emit_call_events(result, LLMCallType.TOOL_CALL)
return result
except Exception as e:
# --- 3.4) Handle execution errors
fn = available_functions.get(
function_name, lambda: None
) # Ensure fn is always a callable
logging.error(f"Error executing function '{function_name}': {e}")
crewai_event_bus.emit(
self,
event=ToolExecutionErrorEvent(
tool_name=function_name,
tool_args=function_args,
tool_class=fn,
error=str(e),
),
)
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
)
return None
def call(
self,
messages: Union[str, List[Dict[str, str]]],
@@ -683,8 +230,22 @@ class LLM:
TypeError: If messages format is invalid
ValueError: If response format is not supported
LLMContextLengthExceededException: If input exceeds model's context limit
Examples:
# Example 1: Simple string input
>>> response = llm.call("Return the name of a random city.")
>>> print(response)
"Paris"
# Example 2: Message list with system and user messages
>>> messages = [
... {"role": "system", "content": "You are a geography expert"},
... {"role": "user", "content": "What is France's capital?"}
... ]
>>> response = llm.call(messages)
>>> print(response)
"The capital of France is Paris."
"""
# --- 1) Emit call started event
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(
@@ -694,38 +255,127 @@ class LLM:
available_functions=available_functions,
),
)
# --- 2) Validate parameters before proceeding with the call
# Validate parameters before proceeding with the call.
self._validate_call_params()
# --- 3) Convert string messages to proper format if needed
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
# --- 4) Handle O1 model special case (system messages not supported)
# For O1 models, system messages are not supported.
# Convert any system messages into assistant messages.
if "o1" in self.model.lower():
for message in messages:
if message.get("role") == "system":
message["role"] = "assistant"
# --- 5) Set up callbacks if provided
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# --- 6) Prepare parameters for the completion call
params = self._prepare_completion_params(messages, tools)
# --- 1) Format messages according to provider requirements
formatted_messages = self._format_messages_for_provider(messages)
# --- 2) Prepare the parameters for the completion call
params = {
"model": self.model,
"messages": formatted_messages,
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
"api_base": self.api_base,
"base_url": self.base_url,
"api_version": self.api_version,
"api_key": self.api_key,
"stream": False,
"tools": tools,
"reasoning_effort": self.reasoning_effort,
**self.additional_params,
}
# Remove None values from params
params = {k: v for k, v in params.items() if v is not None}
# --- 2) Make the completion call
response = litellm.completion(**params)
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
text_response = response_message.content or ""
tool_calls = getattr(response_message, "tool_calls", [])
# --- 3) Handle callbacks with usage info
if callbacks and len(callbacks) > 0:
for callback in callbacks:
if hasattr(callback, "log_success_event"):
usage_info = getattr(response, "usage", None)
if usage_info:
callback.log_success_event(
kwargs=params,
response_obj={"usage": usage_info},
start_time=0,
end_time=0,
)
# --- 4) If no tool calls, return the text response
if not tool_calls or not available_functions:
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
return text_response
# --- 5) Handle the tool call
tool_call = tool_calls[0]
function_name = tool_call.function.name
if function_name in available_functions:
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.warning(f"Failed to parse function arguments: {e}")
return text_response
fn = available_functions[function_name]
try:
# Call the actual tool function
result = fn(**function_args)
self._handle_emit_call_events(result, LLMCallType.TOOL_CALL)
return result
except Exception as e:
logging.error(
f"Error executing function '{function_name}': {e}"
)
crewai_event_bus.emit(
self,
event=ToolExecutionErrorEvent(
tool_name=function_name,
tool_args=function_args,
tool_class=fn,
error=str(e),
),
)
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=f"Tool execution error: {str(e)}"
),
)
return text_response
# --- 7) Make the completion call and handle response
if self.stream:
return self._handle_streaming_response(
params, callbacks, available_functions
)
else:
return self._handle_non_streaming_response(
params, callbacks, available_functions
logging.warning(
f"Tool call requested unknown function '{function_name}'"
)
return text_response
except Exception as e:
crewai_event_bus.emit(
@@ -776,20 +426,6 @@ class LLM:
"Invalid message format. Each message must be a dict with 'role' and 'content' keys"
)
# Handle O1 models specially
if "o1" in self.model.lower():
formatted_messages = []
for msg in messages:
# Convert system messages to assistant messages
if msg["role"] == "system":
formatted_messages.append(
{"role": "assistant", "content": msg["content"]}
)
else:
formatted_messages.append(msg)
return formatted_messages
# Handle Anthropic models
if not self.is_anthropic:
return messages
@@ -800,7 +436,7 @@ class LLM:
return messages
def _get_custom_llm_provider(self) -> Optional[str]:
def _get_custom_llm_provider(self) -> str:
"""
Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
@@ -809,7 +445,7 @@ class LLM:
"""
if "/" in self.model:
return self.model.split("/")[0]
return None
return "openai"
def _validate_call_params(self) -> None:
"""
@@ -832,12 +468,10 @@ class LLM:
def supports_function_calling(self) -> bool:
try:
provider = self._get_custom_llm_provider()
return litellm.utils.supports_function_calling(
self.model, custom_llm_provider=provider
)
params = get_supported_openai_params(model=self.model)
return params is not None and "tools" in params
except Exception as e:
logging.error(f"Failed to check function calling support: {str(e)}")
logging.error(f"Failed to get supported params: {str(e)}")
return False
def supports_stop_words(self) -> bool:

View File

@@ -47,7 +47,7 @@ class ContextualMemory:
stm_results = self.stm.search(query)
formatted_results = "\n".join(
[
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
f"- {result.get('memory', result.get('context', ''))}"
for result in stm_results
]
)
@@ -58,7 +58,7 @@ class ContextualMemory:
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
"""
ltm_results = self.ltm.search(task, latest_n=2)
ltm_results = self.ltm.search(query=task, limit=2)
if not ltm_results:
return None
@@ -80,9 +80,9 @@ class ContextualMemory:
em_results = self.em.search(query)
formatted_results = "\n".join(
[
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
f"- {result.get('memory', result.get('context', ''))}"
for result in em_results
] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
]
)
return f"Entities:\n{formatted_results}" if em_results else ""
@@ -99,6 +99,6 @@ class ContextualMemory:
return ""
formatted_memories = "\n".join(
f"- {result['memory']}" for result in user_memories
f"- {result.get('memory', result.get('context', ''))}" for result in user_memories
)
return f"User memories/preferences:\n{formatted_memories}"

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Dict, Optional
from pydantic import PrivateAttr
@@ -17,47 +17,71 @@ class EntityMemory(Memory):
_memory_provider: Optional[str] = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = None
memory_config = None
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider")
else:
memory_provider = None
if memory_provider == "mem0":
memory_config = crew.memory_config
memory_provider = memory_config.get("provider")
# If no storage is provided, try to create one
if storage is None:
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
storage = Mem0Storage(type="entities", crew=crew)
else:
storage = (
storage
if storage
else RAGStorage(
type="entities",
allow_reset=True,
embedder_config=embedder_config,
# Try to select storage using helper method
storage = self._select_storage(
storage=storage,
memory_config=memory_config,
storage_type="entity",
crew=crew,
path=path,
default_storage_factory=lambda path, crew: RAGStorage(
type="entities",
allow_reset=True,
crew=crew,
embedder_config=embedder_config,
path=path,
)
)
)
except ValueError:
# Fallback to default storage
storage = RAGStorage(
type="entities",
allow_reset=True,
crew=crew,
embedder_config=embedder_config,
path=path,
)
# Initialize with parameters
super().__init__(
storage=storage,
embedder_config=embedder_config,
memory_provider=memory_provider
)
super().__init__(storage=storage)
self._memory_provider = memory_provider
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
"""Saves an entity item into the SQLite storage."""
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
"""Saves an entity item or value into the storage."""
if isinstance(value, EntityMemoryItem):
item = value
if self.memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
# Handle regular value and metadata
super().save(value, metadata, agent)
def reset(self) -> None:
try:

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
@@ -14,23 +14,77 @@ class LongTermMemory(Memory):
LongTermMemoryItem instances.
"""
def __init__(self, storage=None, path=None):
if not storage:
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage)
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
metadata = item.metadata
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
datetime=item.datetime,
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = None
memory_config = None
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_config = crew.memory_config
memory_provider = memory_config.get("provider")
# Initialize with basic parameters
super().__init__(
storage=storage,
embedder_config=embedder_config,
memory_provider=memory_provider
)
try:
# Try to select storage using helper method
self.storage = self._select_storage(
storage=storage,
memory_config=memory_config,
storage_type="long_term",
crew=crew,
path=path,
default_storage_factory=lambda path, crew: LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
)
except ValueError:
# Fallback to default storage
self.storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
"""Saves a value into the memory."""
if isinstance(value, LongTermMemoryItem):
item = value
item_metadata = item.metadata or {}
item_metadata.update({"agent": item.agent, "expected_output": item.expected_output})
# Handle special storage types like Mem0Storage
if hasattr(self.storage, "save") and callable(getattr(self.storage, "save")) and hasattr(self.storage.save, "__code__") and "task_description" in self.storage.save.__code__.co_varnames:
self.storage.save(
task_description=item.task,
score=item_metadata.get("quality", 0),
metadata=item_metadata,
datetime=item.datetime,
)
else:
# Use standard storage interface
self.storage.save(item.task, item_metadata)
else:
# Handle regular value and metadata
super().save(value, metadata, agent)
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
"""Search for values in the memory."""
# Try to use the standard storage interface first
if hasattr(self.storage, "search") and callable(getattr(self.storage, "search")):
return self.storage.search(query=query, limit=limit, score_threshold=score_threshold)
# Fall back to load method for backward compatibility
elif hasattr(self.storage, "load") and callable(getattr(self.storage, "load")):
return self.storage.load(query, limit)
else:
raise AttributeError("Storage does not implement search or load method")
def reset(self) -> None:
self.storage.reset()

View File

@@ -1,20 +1,62 @@
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, cast
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field
from crewai.memory.storage.interface import SearchResult, Storage
class Memory(BaseModel):
T = TypeVar('T', bound=Storage)
class Memory(BaseModel, Generic[T]):
"""
Base class for memory, now supporting agent tags and generic metadata.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
embedder_config: Optional[Dict[str, Any]] = None
storage: T
memory_provider: Optional[str] = Field(default=None, exclude=True)
storage: Any
def __init__(self, storage: Any, **data: Any):
def __init__(self, storage: T, **data: Any):
super().__init__(storage=storage, **data)
def _select_storage(
self,
storage: Optional[T] = None,
memory_config: Optional[Dict[str, Any]] = None,
storage_type: str = "",
crew=None,
path: Optional[str] = None,
default_storage_factory: Optional[Callable] = None,
) -> T:
"""Helper method to select the appropriate storage based on configuration"""
# Use the provided storage if available
if storage:
return storage
# Use storage from memory_config if available
if memory_config and "storage" in memory_config:
storage_config = memory_config.get("storage", {})
if storage_type in storage_config and storage_config[storage_type]:
return cast(T, storage_config[storage_type])
# Use Mem0Storage if specified in memory_config
if memory_config and memory_config.get("provider") == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
return cast(T, Mem0Storage(type=storage_type, crew=crew))
except ImportError:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
# Use default storage if provided
if default_storage_factory:
return cast(T, default_storage_factory(path=path, crew=crew))
# Fallback to empty storage
raise ValueError(f"No storage available for {storage_type}")
def save(
self,
value: Any,
@@ -25,14 +67,19 @@ class Memory(BaseModel):
if agent:
metadata["agent"] = agent
self.storage.save(value, metadata)
if self.storage:
self.storage.save(value, metadata)
else:
raise ValueError("Storage is not initialized")
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
) -> List[SearchResult]:
if not self.storage:
raise ValueError("Storage is not initialized")
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
)

View File

@@ -19,32 +19,43 @@ class ShortTermMemory(Memory):
_memory_provider: Optional[str] = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = None
memory_config = None
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider")
else:
memory_provider = None
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
storage = Mem0Storage(type="short_term", crew=crew)
else:
storage = (
storage
if storage
else RAGStorage(
memory_config = crew.memory_config
memory_provider = memory_config.get("provider")
# Initialize with basic parameters
super().__init__(
storage=storage,
embedder_config=embedder_config,
memory_provider=memory_provider
)
try:
# Try to select storage using helper method
self.storage = self._select_storage(
storage=storage,
memory_config=memory_config,
storage_type="short_term",
crew=crew,
path=path,
default_storage_factory=lambda path, crew: RAGStorage(
type="short_term",
embedder_config=embedder_config,
crew=crew,
embedder_config=embedder_config,
path=path,
)
)
super().__init__(storage=storage)
self._memory_provider = memory_provider
except ValueError:
# Fallback to default storage
self.storage = RAGStorage(
type="short_term",
crew=crew,
embedder_config=embedder_config,
path=path,
)
def save(
self,
@@ -53,7 +64,7 @@ class ShortTermMemory(Memory):
agent: Optional[str] = None,
) -> None:
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
if self._memory_provider == "mem0":
if self.memory_provider == "mem0":
item.data = f"Remember the following insights from Agent run: {item.data}"
super().save(value=item.data, metadata=item.metadata, agent=item.agent)

View File

@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from crewai.memory.storage.interface import SearchResult, Storage
class BaseRAGStorage(ABC):
class BaseRAGStorage(Storage[Any], ABC):
"""
Base class for RAG-based Storage implementations.
"""
@@ -44,9 +46,8 @@ class BaseRAGStorage(ABC):
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
) -> List[SearchResult]:
"""Search for entries in the storage."""
pass

View File

@@ -1,16 +1,39 @@
from typing import Any, Dict, List
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Generic, List, Protocol, TypeVar, TypedDict, runtime_checkable
from pydantic import BaseModel, ConfigDict
class Storage:
class SearchResult(TypedDict, total=False):
"""Type definition for search results"""
context: str
metadata: Dict[str, Any]
score: float
memory: str # For Mem0Storage compatibility
T = TypeVar('T')
@runtime_checkable
class StorageProtocol(Protocol):
"""Protocol defining the storage interface"""
def save(self, value: Any, metadata: Dict[str, Any]) -> None: ...
def search(self, query: str, limit: int, score_threshold: float) -> List[Any]: ...
def reset(self) -> None: ...
class Storage(ABC, Generic[T]):
"""Abstract base class defining the storage interface"""
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
pass
@abstractmethod
def search(
self, query: str, limit: int, score_threshold: float
) -> Dict[str, Any] | List[Any]:
return {}
) -> List[SearchResult]:
pass
@abstractmethod
def reset(self) -> None:
pass

View File

@@ -111,3 +111,9 @@ class Mem0Storage(Storage):
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return agents
def reset(self) -> None:
"""Reset the storage by clearing all memories."""
# Mem0 doesn't have a direct reset method, but we can implement
# this in the future if needed. For now, we'll just pass.
pass

View File

@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.memory.storage.interface import SearchResult
from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
@@ -37,7 +38,7 @@ class RAGStorage(BaseRAGStorage):
search efficiency.
"""
app: ClientAPI | None = None
app: Optional[ClientAPI] = None
def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
@@ -112,9 +113,8 @@ class RAGStorage(BaseRAGStorage):
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
) -> List[SearchResult]:
if not hasattr(self, "app"):
self._initialize_app()
@@ -124,8 +124,7 @@ class RAGStorage(BaseRAGStorage):
results = []
for i in range(len(response["ids"][0])):
result = {
"id": response["ids"][0][i],
result: SearchResult = {
"metadata": response["metadatas"][0][i],
"context": response["documents"][0][i],
"score": response["distances"][0][i],
@@ -138,7 +137,7 @@ class RAGStorage(BaseRAGStorage):
logging.error(f"Error during {self.type} search: {str(e)}")
return []
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()

View File

@@ -11,15 +11,46 @@ class UserMemory(Memory):
MemoryItem instances.
"""
def __init__(self, crew=None):
def __init__(self, crew=None, embedder_config=None, storage=None, path=None, **kwargs):
memory_provider = None
memory_config = None
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_config = crew.memory_config
memory_provider = memory_config.get("provider")
# Initialize with basic parameters
super().__init__(
storage=storage,
embedder_config=embedder_config,
memory_provider=memory_provider
)
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
# Try to select storage using helper method
from crewai.memory.storage.rag_storage import RAGStorage
self.storage = self._select_storage(
storage=storage,
memory_config=memory_config,
storage_type="user",
crew=crew,
path=path,
default_storage_factory=lambda path, crew: RAGStorage(
type="user",
crew=crew,
embedder_config=embedder_config,
path=path,
)
)
except ValueError:
# Fallback to default storage
from crewai.memory.storage.rag_storage import RAGStorage
self.storage = RAGStorage(
type="user",
crew=crew,
embedder_config=embedder_config,
path=path,
)
storage = Mem0Storage(type="user", crew=crew)
super().__init__(storage)
def save(
self,
@@ -43,3 +74,9 @@ class UserMemory(Memory):
score_threshold=score_threshold,
)
return results
def reset(self) -> None:
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the user memory: {e}")

View File

@@ -1,10 +1,8 @@
from functools import wraps
from typing import Any, Callable, Optional, Union, cast
from typing import Callable
from crewai import Crew
from crewai.project.utils import memoize
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
"""Decorators for defining crew components and their behaviors."""
@@ -23,35 +21,13 @@ def after_kickoff(func):
def task(func):
"""Marks a method as a crew task."""
setattr(func, "is_task", True)
func.is_task = True
@wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
# Set the task name if not already set
if hasattr(result, "name") and not result.name:
if not result.name:
result.name = func.__name__
# If this is a ConditionalTask, ensure it has a valid condition
if isinstance(result, ConditionalTask):
# If the condition is a boolean, wrap it in a function
if isinstance(result.condition, bool):
bool_value = result.condition
result.condition = lambda _: bool_value
# Get the previous task output if available
self = args[0] if args else None
if self and hasattr(self, "_method_outputs"):
previous_outputs = getattr(self, "_method_outputs", [])
previous_output = previous_outputs[-1] if previous_outputs else None
# If there's a previous output and it's a TaskOutput, check if we should execute
if previous_output and isinstance(previous_output, TaskOutput):
if not result.should_execute(previous_output):
# Return a skipped task output instead of the task
return result.get_skipped_task_output()
return result
return memoize(wrapper)

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Union, cast
from typing import Any, Callable
from pydantic import Field
@@ -14,23 +14,17 @@ class ConditionalTask(Task):
"""
condition: Callable[[TaskOutput], bool] = Field(
default=lambda _: True, # Default to always execute
description="Function that determines whether the task should be executed or a boolean value.",
default=None,
description="Maximum number of retries for an agent to execute a task when an error occurs.",
)
def __init__(
self,
condition: Union[Callable[[Any], bool], bool],
condition: Callable[[Any], bool],
**kwargs,
):
super().__init__(**kwargs)
# If condition is a boolean, wrap it in a function that always returns that boolean
if isinstance(condition, bool):
bool_value = condition
self.condition = lambda _: bool_value
else:
self.condition = cast(Callable[[TaskOutput], bool], condition)
self.condition = condition
def should_execute(self, context: TaskOutput) -> bool:
"""

View File

@@ -14,12 +14,7 @@ from .agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
)
from .task_events import (
TaskStartedEvent,
TaskCompletedEvent,
TaskFailedEvent,
TaskEvaluationEvent,
)
from .task_events import TaskStartedEvent, TaskCompletedEvent, TaskFailedEvent, TaskEvaluationEvent
from .flow_events import (
FlowCreatedEvent,
FlowStartedEvent,
@@ -39,13 +34,7 @@ from .tool_usage_events import (
ToolUsageEvent,
ToolValidateInputErrorEvent,
)
from .llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMCallType,
LLMStreamChunkEvent,
)
from .llm_events import LLMCallCompletedEvent, LLMCallFailedEvent, LLMCallStartedEvent
# events
from .event_listener import EventListener

View File

@@ -1,4 +1,3 @@
from io import StringIO
from typing import Any, Dict
from pydantic import Field, PrivateAttr
@@ -12,7 +11,6 @@ from crewai.utilities.events.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMStreamChunkEvent,
)
from .agent_events import AgentExecutionCompletedEvent, AgentExecutionStartedEvent
@@ -48,8 +46,6 @@ class EventListener(BaseEventListener):
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
execution_spans: Dict[Task, Any] = Field(default_factory=dict)
next_chunk = 0
text_stream = StringIO()
def __new__(cls):
if cls._instance is None:
@@ -284,20 +280,9 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(LLMCallFailedEvent)
def on_llm_call_failed(source, event: LLMCallFailedEvent):
self.logger.log(
f"❌ LLM call failed: {event.error}",
f"❌ LLM Call Failed: '{event.error}'",
event.timestamp,
)
@crewai_event_bus.on(LLMStreamChunkEvent)
def on_llm_stream_chunk(source, event: LLMStreamChunkEvent):
self.text_stream.write(event.chunk)
self.text_stream.seek(self.next_chunk)
# Read from the in-memory stream
content = self.text_stream.read()
print(content, end="", flush=True)
self.next_chunk = self.text_stream.tell()
event_listener = EventListener()

View File

@@ -23,12 +23,6 @@ from .flow_events import (
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from .llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMStreamChunkEvent,
)
from .task_events import (
TaskCompletedEvent,
TaskFailedEvent,
@@ -64,8 +58,4 @@ EventTypes = Union[
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
ToolUsageStartedEvent,
LLMCallStartedEvent,
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMStreamChunkEvent,
]

View File

@@ -34,10 +34,3 @@ class LLMCallFailedEvent(CrewEvent):
error: str
type: str = "llm_call_failed"
class LLMStreamChunkEvent(CrewEvent):
"""Event emitted when a streaming chunk is received"""
type: str = "llm_stream_chunk"
chunk: str

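The event class removed above pairs with the bus handler deleted from the listener earlier in this diff. A minimal stand-alone sketch of the same streaming hook, assuming the `crewai_event_bus` decorator and the `stream=True` LLM flag used elsewhere in these tests:

```python
from crewai import LLM
from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.llm_events import LLMStreamChunkEvent

# Print streamed tokens as they arrive; the handler stays registered for the process lifetime.
@crewai_event_bus.on(LLMStreamChunkEvent)
def print_chunk(source, event: LLMStreamChunkEvent):
    print(event.chunk, end="", flush=True)

llm = LLM(model="gpt-4o", stream=True)
llm.call("Tell me a short joke")
```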

@@ -18,7 +18,6 @@ from crewai.tools.tool_calling import InstructorToolCalling
from crewai.tools.tool_usage import ToolUsage
from crewai.utilities import RPMController
from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.llm_events import LLMStreamChunkEvent
from crewai.utilities.events.tool_usage_events import ToolUsageFinishedEvent
@@ -260,7 +259,9 @@ def test_cache_hitting():
def handle_tool_end(source, event):
received_events.append(event)
with (patch.object(CacheHandler, "read") as read,):
with (
patch.object(CacheHandler, "read") as read,
):
read.return_value = "0"
task = Task(
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -1,190 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from crewai import Agent, Crew, Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
# Create mock agents for testing
researcher = Agent(
role="Researcher",
goal="Research information",
backstory="You are a researcher with expertise in finding information.",
)
writer = Agent(
role="Writer",
goal="Write content",
backstory="You are a writer with expertise in creating engaging content.",
)
def test_conditional_task_with_boolean_false():
"""Test that a conditional task with a boolean False condition is skipped."""
task1 = Task(
description="Initial task",
expected_output="Initial output",
agent=researcher,
)
# Use a boolean False directly as the condition
task2 = ConditionalTask(
description="Conditional task that should be skipped",
expected_output="This should not be executed",
agent=writer,
condition=False,
)
crew = Crew(
agents=[researcher, writer],
tasks=[task1, task2],
)
with patch.object(Task, "execute_sync") as mock_execute_sync:
mock_execute_sync.return_value = TaskOutput(
description="Task 1 description",
raw="Task 1 output",
agent="Researcher",
)
result = crew.kickoff()
# Only the first task should be executed
assert mock_execute_sync.call_count == 1
# The conditional task should be skipped
assert task2.output is not None
assert task2.output.raw == ""
# The final output should be from the first task
assert result.raw.startswith("Task 1 output")
def test_conditional_task_with_boolean_true():
"""Test that a conditional task with a boolean True condition is executed."""
task1 = Task(
description="Initial task",
expected_output="Initial output",
agent=researcher,
)
# Use a boolean True directly as the condition
task2 = ConditionalTask(
description="Conditional task that should be executed",
expected_output="This should be executed",
agent=writer,
condition=True,
)
crew = Crew(
agents=[researcher, writer],
tasks=[task1, task2],
)
with patch.object(Task, "execute_sync") as mock_execute_sync:
mock_execute_sync.return_value = TaskOutput(
description="Task output",
raw="Task output",
agent="Agent",
)
crew.kickoff()
# Both tasks should be executed
assert mock_execute_sync.call_count == 2
def test_multiple_sequential_conditional_tasks():
"""Test that multiple conditional tasks in sequence work correctly."""
task1 = Task(
description="Initial task",
expected_output="Initial output",
agent=researcher,
)
# First conditional task (will be executed)
task2 = ConditionalTask(
description="First conditional task",
expected_output="First conditional output",
agent=writer,
condition=True,
)
# Second conditional task (will be skipped)
task3 = ConditionalTask(
description="Second conditional task",
expected_output="Second conditional output",
agent=researcher,
condition=False,
)
# Third conditional task (will be executed)
task4 = ConditionalTask(
description="Third conditional task",
expected_output="Third conditional output",
agent=writer,
condition=True,
)
crew = Crew(
agents=[researcher, writer],
tasks=[task1, task2, task3, task4],
)
with patch.object(Task, "execute_sync") as mock_execute_sync:
mock_execute_sync.return_value = TaskOutput(
description="Task output",
raw="Task output",
agent="Agent",
)
result = crew.kickoff()
# Tasks 1, 2, and 4 should be executed (task 3 is skipped)
assert mock_execute_sync.call_count == 3
# Task 3 should be skipped
assert task3.output is not None
assert task3.output.raw == ""
def test_last_task_conditional():
"""Test that a conditional task at the end of the task list works correctly."""
task1 = Task(
description="Initial task",
expected_output="Initial output",
agent=researcher,
)
# Last task is conditional and will be skipped
task2 = ConditionalTask(
description="Last conditional task",
expected_output="Last conditional output",
agent=writer,
condition=False,
)
crew = Crew(
agents=[researcher, writer],
tasks=[task1, task2],
)
with patch.object(Task, "execute_sync") as mock_execute_sync:
mock_execute_sync.return_value = TaskOutput(
description="Task 1 output",
raw="Task 1 output",
agent="Researcher",
)
result = crew.kickoff()
# Only the first task should be executed
assert mock_execute_sync.call_count == 1
# The conditional task should be skipped
assert task2.output is not None
assert task2.output.raw == ""
# The final output should be from the first task
assert result.raw.startswith("Task 1 output")


@@ -2,7 +2,6 @@
import hashlib
import json
import os
from concurrent.futures import Future
from unittest import mock
from unittest.mock import MagicMock, patch
@@ -36,11 +35,6 @@ from crewai.utilities.events.crew_events import (
from crewai.utilities.rpm_controller import RPMController
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
# Skip streaming tests when running in CI/CD environments
skip_streaming_in_ci = pytest.mark.skipif(
os.getenv("CI") is not None, reason="Skipping streaming tests in CI/CD environments"
)
ceo = Agent(
role="CEO",
goal="Make sure the writers in your company produce amazing content.",
@@ -954,7 +948,6 @@ def test_api_calls_throttling(capsys):
moveon.assert_called()
@skip_streaming_in_ci
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_usage_metrics():
inputs = [
@@ -967,7 +960,6 @@ def test_crew_kickoff_usage_metrics():
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
llm=LLM(model="gpt-4o"),
)
task = Task(
@@ -976,50 +968,12 @@ def test_crew_kickoff_usage_metrics():
agent=agent,
)
# Use real LLM calls instead of mocking
crew = Crew(agents=[agent], tasks=[task])
results = crew.kickoff_for_each(inputs=inputs)
assert len(results) == len(inputs)
for result in results:
# Assert that all required keys are in usage_metrics and their values are greater than 0
assert result.token_usage.total_tokens > 0
assert result.token_usage.prompt_tokens > 0
assert result.token_usage.completion_tokens > 0
assert result.token_usage.successful_requests > 0
assert result.token_usage.cached_prompt_tokens == 0
@skip_streaming_in_ci
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_streaming_usage_metrics():
inputs = [
{"topic": "dog"},
{"topic": "cat"},
{"topic": "apple"},
]
agent = Agent(
role="{topic} Researcher",
goal="Express hot takes on {topic}.",
backstory="You have a lot of experience with {topic}.",
llm=LLM(model="gpt-4o", stream=True),
max_iter=3,
)
task = Task(
description="Give me an analysis around {topic}.",
expected_output="1 bullet point about {topic} that's under 15 words.",
agent=agent,
)
# Use real LLM calls instead of mocking
crew = Crew(agents=[agent], tasks=[task])
results = crew.kickoff_for_each(inputs=inputs)
assert len(results) == len(inputs)
for result in results:
# Assert that all required keys are in usage_metrics and their values are greater than 0
# Assert that all required keys are in usage_metrics and their values are not None
assert result.token_usage.total_tokens > 0
assert result.token_usage.prompt_tokens > 0
assert result.token_usage.completion_tokens > 0
@@ -4019,5 +3973,3 @@ def test_crew_with_knowledge_sources_works_with_copy():
assert crew_copy.knowledge_sources == crew.knowledge_sources
assert len(crew_copy.agents) == len(crew.agents)
assert len(crew_copy.tasks) == len(crew.tasks)
assert len(crew_copy.tasks) == len(crew.tasks)


@@ -1,152 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from crewai import Agent, Task
from crewai.flow import Flow, listen, start
from crewai.project.annotations import task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
# Create mock agents for testing
researcher = Agent(
role="Researcher",
goal="Research information",
backstory="You are a researcher with expertise in finding information.",
)
writer = Agent(
role="Writer",
goal="Write content",
backstory="You are a writer with expertise in creating engaging content.",
)
class TestFlowWithConditionalTasks(Flow):
"""Test flow with conditional tasks."""
@start()
@task
def initial_task(self):
"""Initial task that always executes."""
return Task(
description="Initial task",
expected_output="Initial output",
agent=researcher,
)
@listen(initial_task)
@task
def conditional_task_false(self):
"""Conditional task that should be skipped."""
return ConditionalTask(
description="Conditional task that should be skipped",
expected_output="This should not be executed",
agent=writer,
condition=False,
)
@listen(initial_task)
@task
def conditional_task_true(self):
"""Conditional task that should be executed."""
return ConditionalTask(
description="Conditional task that should be executed",
expected_output="This should be executed",
agent=writer,
condition=True,
)
@listen(conditional_task_true)
@task
def final_task(self):
"""Final task that executes after the conditional task."""
return Task(
description="Final task",
expected_output="Final output",
agent=researcher,
)
def test_flow_with_conditional_tasks():
"""Test that conditional tasks work correctly in a Flow."""
flow = TestFlowWithConditionalTasks()
with patch.object(Task, "execute_sync") as mock_execute_sync:
mock_execute_sync.return_value = TaskOutput(
description="Task output",
raw="Task output",
agent="Agent",
)
flow.kickoff()
# The initial task, conditional_task_true, and final_task should be executed
# conditional_task_false should be skipped
assert mock_execute_sync.call_count == 3
class TestFlowWithSequentialConditionalTasks(Flow):
"""Test flow with sequential conditional tasks."""
@start()
@task
def initial_task(self):
"""Initial task that always executes."""
return Task(
description="Initial task",
expected_output="Initial output",
agent=researcher,
)
@listen(initial_task)
@task
def conditional_task_1(self):
"""First conditional task that should be executed."""
return ConditionalTask(
description="First conditional task",
expected_output="First conditional output",
agent=writer,
condition=True,
)
@listen(conditional_task_1)
@task
def conditional_task_2(self):
"""Second conditional task that should be skipped."""
return ConditionalTask(
description="Second conditional task",
expected_output="Second conditional output",
agent=researcher,
condition=False,
)
@listen(conditional_task_2)
@task
def conditional_task_3(self):
"""Third conditional task that should be executed."""
return ConditionalTask(
description="Third conditional task",
expected_output="Third conditional output",
agent=writer,
condition=True,
)
def test_flow_with_sequential_conditional_tasks():
"""Test that sequential conditional tasks work correctly in a Flow."""
flow = TestFlowWithSequentialConditionalTasks()
with patch.object(Task, "execute_sync") as mock_execute_sync:
mock_execute_sync.return_value = TaskOutput(
description="Task output",
raw="Task output",
agent="Agent",
)
flow.kickoff()
# The initial_task and conditional_task_1 should be executed
# conditional_task_2 should be skipped, and since it's skipped,
# conditional_task_3 should not be triggered
assert mock_execute_sync.call_count == 2


@@ -219,7 +219,7 @@ def test_get_custom_llm_provider_gemini():
def test_get_custom_llm_provider_openai():
llm = LLM(model="gpt-4")
assert llm._get_custom_llm_provider() == None
assert llm._get_custom_llm_provider() == "openai"
def test_validate_call_params_supported():
@@ -285,7 +285,6 @@ def test_o3_mini_reasoning_effort_medium():
assert isinstance(result, str)
assert "Paris" in result
def test_context_window_validation():
"""Test that context window validation works correctly."""
# Test valid window size


@@ -7,7 +7,31 @@ from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
@pytest.fixture
def long_term_memory():
"""Fixture to create a LongTermMemory instance"""
return LongTermMemory()
# Create a mock storage for testing
from crewai.memory.storage.interface import Storage
class MockStorage(Storage):
def __init__(self):
self.data = []
def save(self, value, metadata):
self.data.append({"value": value, "metadata": metadata})
def search(self, query, limit=3, score_threshold=0.35):
return [
{
"context": item["value"],
"metadata": item["metadata"],
"score": 0.5,
"datetime": item["metadata"].get("datetime", "test_datetime")
}
for item in self.data
]
def reset(self):
self.data = []
return LongTermMemory(storage=MockStorage())
def test_save_and_search(long_term_memory):
@@ -20,7 +44,7 @@ def test_save_and_search(long_term_memory):
metadata={"task": "test_task", "quality": 0.5},
)
long_term_memory.save(memory)
find = long_term_memory.search("test_task", latest_n=5)[0]
find = long_term_memory.search(query="test_task", limit=5)[0]
assert find["score"] == 0.5
assert find["datetime"] == "test_datetime"
assert find["metadata"]["agent"] == "test_agent"


@@ -12,6 +12,8 @@ from crewai.task import Task
@pytest.fixture
def short_term_memory():
"""Fixture to create a ShortTermMemory instance"""
from crewai.memory.storage.rag_storage import RAGStorage
agent = Agent(
role="Researcher",
goal="Search relevant data and provide results",
@@ -25,7 +27,10 @@ def short_term_memory():
expected_output="A list of relevant URLs based on the search query.",
agent=agent,
)
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
storage = RAGStorage(type="short_term")
crew = Crew(agents=[agent], tasks=[task])
return ShortTermMemory(storage=storage, crew=crew)
def test_save_and_search(short_term_memory):

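Outside the fixture, the same wiring can be done directly. A sketch that reuses the `RAGStorage(type="short_term")` call from the fixture and the `save`/`search` keywords exercised by the tests below:

```python
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.storage.rag_storage import RAGStorage

# Explicitly construct the built-in RAG-backed storage instead of relying on defaults.
memory = ShortTermMemory(storage=RAGStorage(type="short_term"))

memory.save("Observed fact about the task", {"source": "sketch"})
print(memory.search("Observed"))
```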

@@ -0,0 +1,211 @@
from typing import Any, Dict, List
import pytest
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.storage.interface import SearchResult, Storage
from crewai.memory.user.user_memory import UserMemory
class CustomStorage(Storage[Any]):
"""Custom storage implementation for testing."""
def __init__(self):
self.data = []
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
self.data.append({"value": value, "metadata": metadata})
def search(
self, query: str, limit: int = 3, score_threshold: float = 0.35
) -> List[SearchResult]:
return [{"context": item["value"], "metadata": item["metadata"], "score": 0.9} for item in self.data]
def reset(self) -> None:
self.data = []
def test_custom_storage_with_short_term_memory():
"""Test that custom storage works with short term memory."""
custom_storage = CustomStorage()
memory = ShortTermMemory(storage=custom_storage)
memory.save("test value", {"key": "value"})
results = memory.search("test")
assert len(results) > 0
assert results[0]["context"] == "test value"
assert results[0]["metadata"]["key"] == "value"
def test_custom_storage_with_long_term_memory():
"""Test that custom storage works with long term memory."""
custom_storage = CustomStorage()
memory = LongTermMemory(storage=custom_storage)
memory.save("test value", {"key": "value"})
results = memory.search("test")
assert len(results) > 0
assert results[0]["context"] == "test value"
assert results[0]["metadata"]["key"] == "value"
def test_custom_storage_with_entity_memory():
"""Test that custom storage works with entity memory."""
custom_storage = CustomStorage()
memory = EntityMemory(storage=custom_storage)
memory.save("test value", {"key": "value"})
results = memory.search("test")
assert len(results) > 0
assert results[0]["context"] == "test value"
assert results[0]["metadata"]["key"] == "value"
def test_custom_storage_with_user_memory():
"""Test that custom storage works with user memory."""
custom_storage = CustomStorage()
memory = UserMemory(storage=custom_storage)
memory.save("test value", {"key": "value"})
results = memory.search("test")
assert len(results) > 0
# UserMemory prepends "Remember the details about the user: " to the value
assert "test value" in results[0]["context"]
assert results[0]["metadata"]["key"] == "value"
def test_custom_storage_with_crew():
"""Test that custom storage works with crew."""
short_term_storage = CustomStorage()
long_term_storage = CustomStorage()
entity_storage = CustomStorage()
user_storage = CustomStorage()
# Create memory instances with custom storage
short_term_memory = ShortTermMemory(storage=short_term_storage)
long_term_memory = LongTermMemory(storage=long_term_storage)
entity_memory = EntityMemory(storage=entity_storage)
user_memory = UserMemory(storage=user_storage)
# Create a crew with custom memory instances
crew = Crew(
agents=[Agent(role="test", goal="test", backstory="test")],
memory=True,
short_term_memory=short_term_memory,
long_term_memory=long_term_memory,
entity_memory=entity_memory,
memory_config={"user_memory": user_memory},
)
# Test that the crew has the custom memory instances
assert crew._short_term_memory.storage == short_term_storage
assert crew._long_term_memory.storage == long_term_storage
assert crew._entity_memory.storage == entity_storage
assert crew._user_memory.storage == user_storage
def test_custom_storage_with_memory_config():
"""Test that custom storage works with memory_config."""
short_term_storage = CustomStorage()
long_term_memory = LongTermMemory(storage=CustomStorage())
entity_memory = EntityMemory(storage=CustomStorage())
user_memory = UserMemory(storage=CustomStorage())
# Create a crew with custom storage in memory_config
crew = Crew(
agents=[Agent(role="test", goal="test", backstory="test")],
memory=True,
short_term_memory=ShortTermMemory(storage=short_term_storage),
long_term_memory=long_term_memory,
entity_memory=entity_memory,
memory_config={
"user_memory": user_memory
},
)
# Test that the crew has the custom storage instances
assert crew._short_term_memory.storage == short_term_storage
assert crew._long_term_memory == long_term_memory
assert crew._entity_memory == entity_memory
assert crew._user_memory == user_memory
def test_custom_storage_error_handling():
"""Test error handling with custom storage."""
# Test exception propagation
class ErrorStorage(Storage[Any]):
"""Storage implementation that raises exceptions."""
def __init__(self):
self.data = []
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
raise ValueError("Save error")
def search(
self, query: str, limit: int = 3, score_threshold: float = 0.35
) -> List[SearchResult]:
raise ValueError("Search error")
def reset(self) -> None:
raise ValueError("Reset error")
storage = ErrorStorage()
memory = ShortTermMemory(storage=storage)
with pytest.raises(ValueError, match="Save error"):
memory.save("test", {})
with pytest.raises(ValueError, match="Search error"):
memory.search("test")
with pytest.raises(Exception, match="An error occurred while resetting the short-term memory: Reset error"):
memory.reset()
def test_custom_storage_edge_cases():
"""Test edge cases with custom storage."""
class EdgeCaseStorage(Storage[Any]):
"""Storage implementation for testing edge cases."""
def __init__(self):
self.data = []
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
self.data.append({"value": value, "metadata": metadata})
def search(
self, query: str, limit: int = 3, score_threshold: float = 0.35
) -> List[SearchResult]:
return [{"context": item["value"], "metadata": item["metadata"], "score": 0.5} for item in self.data]
def reset(self) -> None:
self.data = []
storage = EdgeCaseStorage()
memory = ShortTermMemory(storage=storage)
# Test empty query
memory.save("test value", {"key": "value"})
results = memory.search("")
assert len(results) > 0
# Test very large metadata
large_metadata = {"key" + str(i): "value" * 100 for i in range(100)}
memory.save("test value", large_metadata)
results = memory.search("test")
assert len(results) > 0
assert results[1]["metadata"] == large_metadata
# Test unicode and special characters
unicode_value = "测试值 with special chars: !@#$%^&*()"
memory.save(unicode_value, {"key": "value"})
results = memory.search("测试")
assert len(results) > 0
assert unicode_value in results[2]["context"]


@@ -1,170 +0,0 @@
interactions:
- request:
body: '{"messages": [{"role": "user", "content": "Tell me a short joke"}], "model":
"gpt-3.5-turbo", "stop": [], "stream": true}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '121'
content-type:
- application/json
cookie:
- _cfuvid=IY8ppO70AMHr2skDSUsGh71zqHHdCQCZ3OvkPi26NBc-1740424913267-0.0.1.1-604800000
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.65.1
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.65.1
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.8
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: 'data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Why"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
couldn"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"''t"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
the"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
bicycle"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
stand"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
up"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
by"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
itself"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
Because"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
it"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
was"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"
two"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"-t"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"ired"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-B74aE2TDl9ZbKx2fXoVatoMDnErNm","object":"chat.completion.chunk","created":1741025614,"model":"gpt-3.5-turbo-0125","service_tier":"default","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
data: [DONE]
'
headers:
CF-RAY:
- 91ab1bcbad95bcda-ATL
Connection:
- keep-alive
Content-Type:
- text/event-stream; charset=utf-8
Date:
- Mon, 03 Mar 2025 18:13:34 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=Jydtg8l0yjWRI2vKmejdq.C1W.sasIwEbTrV2rUt6V0-1741025614-1.0.1.1-Af3gmq.j2ecn9QEa3aCVY09QU4VqoW2GTk9AjvzPA.jyAZlwhJd4paniSt3kSusH0tryW03iC8uaX826hb2xzapgcfSm6Jdh_eWh_BMCh_8;
path=/; expires=Mon, 03-Mar-25 18:43:34 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=5wzaJSCvT1p1Eazad55wDvp1JsgxrlghhmmU9tx0fMs-1741025614868-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '127'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '10000'
x-ratelimit-limit-tokens:
- '50000000'
x-ratelimit-remaining-requests:
- '9999'
x-ratelimit-remaining-tokens:
- '49999978'
x-ratelimit-reset-requests:
- 6ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_2a2a04977ace88fdd64cf570f80c0202
status:
code: 200
message: OK
version: 1


@@ -1,107 +0,0 @@
interactions:
- request:
body: '{"messages": [{"role": "user", "content": "Tell me a short joke"}], "model":
"gpt-4o", "stop": [], "stream": false}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '115'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.65.1
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.65.1
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.8
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFJBbtswELzrFVteerEKSZbrxpcCDuBTUfSUtigCgSZXEhuKJLirNEbg
vxeSHMtBXSAXHmZ2BjPLfU4AhNFiA0K1klUXbLpde/X1tvtW/tnfrW6//Lzb7UraLn8s2+xpJxaD
wu9/o+IX1Qflu2CRjXcTrSJKxsE1X5d5kRWrdT4SnddoB1kTOC19WmRFmWaf0uzjSdh6o5DEBn4l
AADP4ztEdBqfxAayxQvSIZFsUGzOQwAiejsgQhIZYulYLGZSecfoxtTf2wNo794zkDLo2BATcOyJ
QbLv6DNsUcmeELjFA3TyAaEPgI8YD9wa17y7NI5Y9ySHXq639oQfz0mtb0L0ezrxZ7w2zlBbRZTk
3ZCK2AcxsscE4H7cSP+qpAjRd4Er9g/oBsO8mOzE/AVXSPYs7YwX5eKKW6WRpbF0sVGhpGpRz8p5
/bLXxl8QyUXnf8Nc8556G9e8xX4mlMLAqKsQURv1uvA8FnE40P+NnXc8BhaE8dEorNhgHP5BYy17
O92OoAMxdlVtXIMxRDMdUB2qWt3UuV5ny5VIjslfAAAA//8DADx20t9JAwAA
headers:
CF-RAY:
- 91bbfc033e461d6e-ATL
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 05 Mar 2025 19:22:51 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=LecfSlhN6VGr4kTlMiMCqRPInNb1m8zOikTZxtsE_WM-1741202571-1.0.1.1-T8nh2g1PcqyLIV97_HH9Q_nSUyCtaiFAOzvMxlswn6XjJCcSLJhi_fmkbylwppwoRPTxgs4S6VsVH0mp4ZcDTABBbtemKj7vS8QRDpRrmsU;
path=/; expires=Wed, 05-Mar-25 19:52:51 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=wyMrJP5k5bgWyD8rsK4JPvAJ78JWrsrT0lyV9DP4WZM-1741202571727-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '416'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '10000'
x-ratelimit-limit-tokens:
- '30000000'
x-ratelimit-remaining-requests:
- '9999'
x-ratelimit-remaining-tokens:
- '29999978'
x-ratelimit-reset-requests:
- 6ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_f42504d00bda0a492dced0ba3cf302d8
status:
code: 200
message: OK
version: 1


@@ -1,4 +1,3 @@
import os
from datetime import datetime
from unittest.mock import Mock, patch
@@ -39,7 +38,6 @@ from crewai.utilities.events.llm_events import (
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMCallType,
LLMStreamChunkEvent,
)
from crewai.utilities.events.task_events import (
TaskCompletedEvent,
@@ -50,11 +48,6 @@ from crewai.utilities.events.tool_usage_events import (
ToolUsageErrorEvent,
)
# Skip streaming tests when running in CI/CD environments
skip_streaming_in_ci = pytest.mark.skipif(
os.getenv("CI") is not None, reason="Skipping streaming tests in CI/CD environments"
)
base_agent = Agent(
role="base_agent",
llm="gpt-4o-mini",
@@ -622,152 +615,3 @@ def test_llm_emits_call_failed_event():
assert len(received_events) == 1
assert received_events[0].type == "llm_call_failed"
assert received_events[0].error == error_message
@skip_streaming_in_ci
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_emits_stream_chunk_events():
"""Test that LLM emits stream chunk events when streaming is enabled."""
received_chunks = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled
llm = LLM(model="gpt-4o", stream=True)
# Call the LLM with a simple message
response = llm.call("Tell me a short joke")
# Verify that we received chunks
assert len(received_chunks) > 0
# Verify that concatenating all chunks equals the final response
assert "".join(received_chunks) == response
@skip_streaming_in_ci
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_no_stream_chunks_when_streaming_disabled():
"""Test that LLM doesn't emit stream chunk events when streaming is disabled."""
received_chunks = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
# Create an LLM with streaming disabled
llm = LLM(model="gpt-4o", stream=False)
# Call the LLM with a simple message
response = llm.call("Tell me a short joke")
# Verify that we didn't receive any chunks
assert len(received_chunks) == 0
# Verify we got a response
assert response and isinstance(response, str)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_streaming_fallback_to_non_streaming():
"""Test that streaming falls back to non-streaming when there's an error."""
received_chunks = []
fallback_called = False
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled
llm = LLM(model="gpt-4o", stream=True)
# Store original methods
original_call = llm.call
# Create a mock call method that handles the streaming error
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
nonlocal fallback_called
# Emit a couple of chunks to simulate partial streaming
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))
# Mark that fallback would be called
fallback_called = True
# Return a response as if fallback succeeded
return "Fallback response after streaming error"
# Replace the call method with our mock
llm.call = mock_call
try:
# Call the LLM
response = llm.call("Tell me a short joke")
# Verify that we received some chunks
assert len(received_chunks) == 2
assert received_chunks[0] == "Test chunk 1"
assert received_chunks[1] == "Test chunk 2"
# Verify fallback was triggered
assert fallback_called
# Verify we got the fallback response
assert response == "Fallback response after streaming error"
finally:
# Restore the original method
llm.call = original_call
@pytest.mark.vcr(filter_headers=["authorization"])
def test_streaming_empty_response_handling():
"""Test that streaming handles empty responses correctly."""
received_chunks = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled
llm = LLM(model="gpt-3.5-turbo", stream=True)
# Store original methods
original_call = llm.call
# Create a mock call method that simulates empty chunks
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
# Emit a few empty chunks
for _ in range(3):
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
# Return the default message for empty responses
return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
# Replace the call method with our mock
llm.call = mock_call
try:
# Call the LLM - this should handle empty response
response = llm.call("Tell me a short joke")
# Verify that we received empty chunks
assert len(received_chunks) == 3
assert all(chunk == "" for chunk in received_chunks)
# Verify the response is the default message for empty responses
assert "I apologize" in response and "couldn't generate" in response
finally:
# Restore the original method
llm.call = original_call