refactor: unify rag storage with instance-specific client support (#3455)
Some checks failed: the Notify Downstream, Update Test Durations (3.10-3.13), and Build uv cache (3.10-3.13) workflows were cancelled.

- ignore line length errors globally
- migrate knowledge/memory and crew query_knowledge to `SearchResult`
- remove legacy chromadb utils; fix empty metadata handling
- restore openai as default embedding provider; support instance-specific clients
- update and fix tests for `SearchResult` migration and rag changes
Author: Greyson LaLonde
Committed by: GitHub
Date: 2025-09-17 14:46:54 -04:00
Parent: 81bd81e5f5
Commit: f28e78c5ba
30 changed files with 1956 additions and 976 deletions
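The "instance-specific client" in the title refers to KnowledgeStorage holding its own RAG client when an embedder is supplied, and otherwise falling back to the shared client returned by get_rag_client() (see the knowledge_storage.py diff below). A minimal sketch of that fallback pattern, using stand-in names rather than the actual crewai classes:

from __future__ import annotations


class FakeRagClient:
    """Stand-in for a vector-store client (not crewai's BaseClient)."""

    def __init__(self, label: str) -> None:
        self.label = label


_GLOBAL_CLIENT = FakeRagClient("global")


def get_global_client() -> FakeRagClient:
    """Illustrative counterpart of crewai's get_rag_client()."""
    return _GLOBAL_CLIENT


class Storage:
    """Mirrors the KnowledgeStorage._client / _get_client() pattern from this commit."""

    def __init__(self, client: FakeRagClient | None = None) -> None:
        self._client = client  # set only when an instance-specific client is wanted

    def _get_client(self) -> FakeRagClient:
        # Prefer the instance-specific client; otherwise use the shared one.
        return self._client if self._client else get_global_client()


print(Storage()._get_client().label)  # -> global
print(Storage(FakeRagClient("per-instance"))._get_client().label)  # -> per-instance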

View File: pyproject.toml

@@ -131,13 +131,14 @@ select = [
     "I001", # sort imports
     "I002", # remove unused imports
 ]
-ignore = ["E501"] # ignore line too long
+ignore = ["E501"] # ignore line too long globally
 [tool.ruff.lint.per-file-ignores]
-"tests/**/*.py" = ["S101"] # Allow assert statements in tests
+"tests/**/*.py" = ["S101", "RET504"] # Allow assert statements and unnecessary assignments before return in tests
 [tool.mypy]
-exclude = ["src/crewai/cli/templates", "tests"]
+exclude = ["src/crewai/cli/templates", "tests/"]
 [tool.bandit]
 exclude_dirs = ["src/crewai/cli/templates"]

View File: src/crewai/crew.py

@@ -3,26 +3,17 @@ import json
 import re
 import uuid
 import warnings
+from collections.abc import Callable
 from concurrent.futures import Future
 from copy import copy as shallow_copy
 from hashlib import md5
 from typing import (
     Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Union,
     cast,
 )
 from opentelemetry import baggage
 from opentelemetry.context import attach, detach
-from crewai.utilities.crew.models import CrewContext
 from pydantic import (
     UUID4,
     BaseModel,
@@ -39,26 +30,14 @@ from crewai.agent import Agent
 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.cache import CacheHandler
 from crewai.crews.crew_output import CrewOutput
-from crewai.flow.flow_trackable import FlowTrackable
-from crewai.knowledge.knowledge import Knowledge
-from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
-from crewai.llm import LLM, BaseLLM
-from crewai.memory.entity.entity_memory import EntityMemory
-from crewai.memory.external.external_memory import ExternalMemory
-from crewai.memory.long_term.long_term_memory import LongTermMemory
-from crewai.memory.short_term.short_term_memory import ShortTermMemory
-from crewai.process import Process
-from crewai.security import Fingerprint, SecurityConfig
-from crewai.task import Task
-from crewai.tasks.conditional_task import ConditionalTask
-from crewai.tasks.task_output import TaskOutput
-from crewai.tools.agent_tools.agent_tools import AgentTools
-from crewai.tools.base_tool import BaseTool, Tool
-from crewai.types.usage_metrics import UsageMetrics
-from crewai.utilities import I18N, FileHandler, Logger, RPMController
-from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
-from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
-from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
+from crewai.events.event_bus import crewai_event_bus
+from crewai.events.event_listener import EventListener
+from crewai.events.listeners.tracing.trace_listener import (
+    TraceCollectionListener,
+)
+from crewai.events.listeners.tracing.utils import (
+    is_tracing_enabled,
+)
 from crewai.events.types.crew_events import (
     CrewKickoffCompletedEvent,
     CrewKickoffFailedEvent,
@@ -70,16 +49,28 @@ from crewai.events.types.crew_events import (
     CrewTrainFailedEvent,
     CrewTrainStartedEvent,
 )
-from crewai.events.event_bus import crewai_event_bus
-from crewai.events.event_listener import EventListener
-from crewai.events.listeners.tracing.trace_listener import (
-    TraceCollectionListener,
-)
-from crewai.events.listeners.tracing.utils import (
-    is_tracing_enabled,
-)
+from crewai.flow.flow_trackable import FlowTrackable
+from crewai.knowledge.knowledge import Knowledge
+from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
+from crewai.llm import LLM, BaseLLM
+from crewai.memory.entity.entity_memory import EntityMemory
+from crewai.memory.external.external_memory import ExternalMemory
+from crewai.memory.long_term.long_term_memory import LongTermMemory
+from crewai.memory.short_term.short_term_memory import ShortTermMemory
+from crewai.process import Process
+from crewai.rag.types import SearchResult
+from crewai.security import Fingerprint, SecurityConfig
+from crewai.task import Task
+from crewai.tasks.conditional_task import ConditionalTask
+from crewai.tasks.task_output import TaskOutput
+from crewai.tools.agent_tools.agent_tools import AgentTools
+from crewai.tools.base_tool import BaseTool, Tool
+from crewai.types.usage_metrics import UsageMetrics
+from crewai.utilities import I18N, FileHandler, Logger, RPMController
+from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
+from crewai.utilities.crew.models import CrewContext
+from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
+from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
 from crewai.utilities.formatter import (
     aggregate_raw_outputs_from_task_outputs,
     aggregate_raw_outputs_from_tasks,
@@ -94,28 +85,40 @@ warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
class Crew(FlowTrackable, BaseModel): class Crew(FlowTrackable, BaseModel):
""" """
Represents a group of agents, defining how they should collaborate and the tasks they should perform. Represents a group of agents, defining how they should collaborate and the
tasks they should perform.
Attributes: Attributes:
tasks: List of tasks assigned to the crew. tasks: List of tasks assigned to the crew.
agents: List of agents part of this crew. agents: List of agents part of this crew.
manager_llm: The language model that will run manager agent. manager_llm: The language model that will run manager agent.
manager_agent: Custom agent that will be used as manager. manager_agent: Custom agent that will be used as manager.
memory: Whether the crew should use memory to store memories of it's execution. memory: Whether the crew should use memory to store memories of it's
cache: Whether the crew should use a cache to store the results of the tools execution. execution.
function_calling_llm: The language model that will run the tool calling for all the agents. cache: Whether the crew should use a cache to store the results of the
process: The process flow that the crew will follow (e.g., sequential, hierarchical). tools execution.
function_calling_llm: The language model that will run the tool calling
for all the agents.
process: The process flow that the crew will follow (e.g., sequential,
hierarchical).
verbose: Indicates the verbosity level for logging during execution. verbose: Indicates the verbosity level for logging during execution.
config: Configuration settings for the crew. config: Configuration settings for the crew.
max_rpm: Maximum number of requests per minute for the crew execution to be respected. max_rpm: Maximum number of requests per minute for the crew execution to
be respected.
prompt_file: Path to the prompt json file to be used for the crew. prompt_file: Path to the prompt json file to be used for the crew.
id: A unique identifier for the crew instance. id: A unique identifier for the crew instance.
task_callback: Callback to be executed after each task for every agents execution. task_callback: Callback to be executed after each task for every agents
step_callback: Callback to be executed after each step for every agents execution. execution.
share_crew: Whether you want to share the complete crew information and execution with crewAI to make the library better, and allow us to train models. step_callback: Callback to be executed after each step for every agents
execution.
share_crew: Whether you want to share the complete crew information and
execution with crewAI to make the library better, and allow us to
train models.
planning: Plan the crew execution and add the plan to the crew. planning: Plan the crew execution and add the plan to the crew.
chat_llm: The language model used for orchestrating chat interactions with the crew. chat_llm: The language model used for orchestrating chat interactions
security_config: Security configuration for the crew, including fingerprinting. with the crew.
security_config: Security configuration for the crew, including
fingerprinting.
""" """
__hash__ = object.__hash__ # type: ignore __hash__ = object.__hash__ # type: ignore
@@ -124,13 +127,13 @@ class Crew(FlowTrackable, BaseModel):
_logger: Logger = PrivateAttr() _logger: Logger = PrivateAttr()
_file_handler: FileHandler = PrivateAttr() _file_handler: FileHandler = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler()) _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
_short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr() _short_term_memory: InstanceOf[ShortTermMemory] | None = PrivateAttr()
_long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr() _long_term_memory: InstanceOf[LongTermMemory] | None = PrivateAttr()
_entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr() _entity_memory: InstanceOf[EntityMemory] | None = PrivateAttr()
_external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr() _external_memory: InstanceOf[ExternalMemory] | None = PrivateAttr()
_train: Optional[bool] = PrivateAttr(default=False) _train: bool | None = PrivateAttr(default=False)
_train_iteration: Optional[int] = PrivateAttr() _train_iteration: int | None = PrivateAttr()
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None) _inputs: dict[str, Any] | None = PrivateAttr(default=None)
_logging_color: str = PrivateAttr( _logging_color: str = PrivateAttr(
default="bold_purple", default="bold_purple",
) )
@@ -138,107 +141,121 @@ class Crew(FlowTrackable, BaseModel):
default_factory=TaskOutputStorageHandler default_factory=TaskOutputStorageHandler
) )
name: Optional[str] = Field(default="crew") name: str | None = Field(default="crew")
cache: bool = Field(default=True) cache: bool = Field(default=True)
tasks: List[Task] = Field(default_factory=list) tasks: list[Task] = Field(default_factory=list)
agents: List[BaseAgent] = Field(default_factory=list) agents: list[BaseAgent] = Field(default_factory=list)
process: Process = Field(default=Process.sequential) process: Process = Field(default=Process.sequential)
verbose: bool = Field(default=False) verbose: bool = Field(default=False)
memory: bool = Field( memory: bool = Field(
default=False, default=False,
description="Whether the crew should use memory to store memories of it's execution", description="If crew should use memory to store memories of it's execution",
) )
short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field( short_term_memory: InstanceOf[ShortTermMemory] | None = Field(
default=None, default=None,
description="An Instance of the ShortTermMemory to be used by the Crew", description="An Instance of the ShortTermMemory to be used by the Crew",
) )
long_term_memory: Optional[InstanceOf[LongTermMemory]] = Field( long_term_memory: InstanceOf[LongTermMemory] | None = Field(
default=None, default=None,
description="An Instance of the LongTermMemory to be used by the Crew", description="An Instance of the LongTermMemory to be used by the Crew",
) )
entity_memory: Optional[InstanceOf[EntityMemory]] = Field( entity_memory: InstanceOf[EntityMemory] | None = Field(
default=None, default=None,
description="An Instance of the EntityMemory to be used by the Crew", description="An Instance of the EntityMemory to be used by the Crew",
) )
external_memory: Optional[InstanceOf[ExternalMemory]] = Field( external_memory: InstanceOf[ExternalMemory] | None = Field(
default=None, default=None,
description="An Instance of the ExternalMemory to be used by the Crew", description="An Instance of the ExternalMemory to be used by the Crew",
) )
embedder: Optional[dict] = Field( embedder: dict | None = Field(
default=None, default=None,
description="Configuration for the embedder to be used for the crew.", description="Configuration for the embedder to be used for the crew.",
) )
usage_metrics: Optional[UsageMetrics] = Field( usage_metrics: UsageMetrics | None = Field(
default=None, default=None,
description="Metrics for the LLM usage during all tasks execution.", description="Metrics for the LLM usage during all tasks execution.",
) )
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
description="Language model that will run the agent.", default=None description="Language model that will run the agent.", default=None
) )
manager_agent: Optional[BaseAgent] = Field( manager_agent: BaseAgent | None = Field(
description="Custom agent that will be used as manager.", default=None description="Custom agent that will be used as manager.", default=None
) )
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field( function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
description="Language model that will run the agent.", default=None description="Language model that will run the agent.", default=None
) )
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None) config: Json | dict[str, Any] | None = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True) id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
share_crew: Optional[bool] = Field(default=False) share_crew: bool | None = Field(default=False)
step_callback: Optional[Any] = Field( step_callback: Any | None = Field(
default=None, default=None,
description="Callback to be executed after each step for all agents execution.", description="Callback to be executed after each step for all agents execution.",
) )
task_callback: Optional[Any] = Field( task_callback: Any | None = Field(
default=None, default=None,
description="Callback to be executed after each task for all agents execution.", description="Callback to be executed after each task for all agents execution.",
) )
before_kickoff_callbacks: List[ before_kickoff_callbacks: list[
Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]] Callable[[dict[str, Any] | None], dict[str, Any] | None]
] = Field( ] = Field(
default_factory=list, default_factory=list,
description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.", description=(
"List of callbacks to be executed before crew kickoff. "
"It may be used to adjust inputs before the crew is executed."
),
) )
after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field( after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
default_factory=list, default_factory=list,
description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.", description=(
"List of callbacks to be executed after crew kickoff. "
"It may be used to adjust the output of the crew."
),
) )
max_rpm: Optional[int] = Field( max_rpm: int | None = Field(
default=None, default=None,
description="Maximum number of requests per minute for the crew execution to be respected.", description=(
"Maximum number of requests per minute for the crew execution "
"to be respected."
),
) )
prompt_file: Optional[str] = Field( prompt_file: str | None = Field(
default=None, default=None,
description="Path to the prompt json file to be used for the crew.", description="Path to the prompt json file to be used for the crew.",
) )
output_log_file: Optional[Union[bool, str]] = Field( output_log_file: bool | str | None = Field(
default=None, default=None,
description="Path to the log file to be saved", description="Path to the log file to be saved",
) )
planning: Optional[bool] = Field( planning: bool | None = Field(
default=False, default=False,
description="Plan the crew execution and add the plan to the crew.", description="Plan the crew execution and add the plan to the crew.",
) )
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
default=None, default=None,
description="Language model that will run the AgentPlanner if planning is True.", description=(
"Language model that will run the AgentPlanner if planning is True."
),
) )
task_execution_output_json_files: Optional[List[str]] = Field( task_execution_output_json_files: list[str] | None = Field(
default=None, default=None,
description="List of file paths for task execution JSON files.", description="List of file paths for task execution JSON files.",
) )
execution_logs: List[Dict[str, Any]] = Field( execution_logs: list[dict[str, Any]] = Field(
default=[], default=[],
description="List of execution logs for tasks", description="List of execution logs for tasks",
) )
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field( knowledge_sources: list[BaseKnowledgeSource] | None = Field(
default=None, default=None,
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.", description=(
"Knowledge sources for the crew. Add knowledge sources to the "
"knowledge object."
),
) )
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
default=None, default=None,
description="LLM used to handle chatting with the crew.", description="LLM used to handle chatting with the crew.",
) )
knowledge: Optional[Knowledge] = Field( knowledge: Knowledge | None = Field(
default=None, default=None,
description="Knowledge for the crew.", description="Knowledge for the crew.",
) )
@@ -246,18 +263,18 @@ class Crew(FlowTrackable, BaseModel):
default_factory=SecurityConfig, default_factory=SecurityConfig,
description="Security configuration for the crew, including fingerprinting.", description="Security configuration for the crew, including fingerprinting.",
) )
token_usage: Optional[UsageMetrics] = Field( token_usage: UsageMetrics | None = Field(
default=None, default=None,
description="Metrics for the LLM usage during all tasks execution.", description="Metrics for the LLM usage during all tasks execution.",
) )
tracing: Optional[bool] = Field( tracing: bool | None = Field(
default=False, default=False,
description="Whether to enable tracing for the crew.", description="Whether to enable tracing for the crew.",
) )
@field_validator("id", mode="before") @field_validator("id", mode="before")
@classmethod @classmethod
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None: def _deny_user_set_id(cls, v: UUID4 | None) -> None:
"""Prevent manual setting of the 'id' field by users.""" """Prevent manual setting of the 'id' field by users."""
if v: if v:
raise PydanticCustomError( raise PydanticCustomError(
@@ -266,9 +283,7 @@ class Crew(FlowTrackable, BaseModel):
@field_validator("config", mode="before") @field_validator("config", mode="before")
@classmethod @classmethod
def check_config_type( def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
cls, v: Union[Json, Dict[str, Any]]
) -> Union[Json, Dict[str, Any]]:
"""Validates that the config is a valid type. """Validates that the config is a valid type.
Args: Args:
v: The config to be validated. v: The config to be validated.
@@ -314,7 +329,8 @@ class Crew(FlowTrackable, BaseModel):
def create_crew_memory(self) -> "Crew": def create_crew_memory(self) -> "Crew":
"""Initialize private memory attributes.""" """Initialize private memory attributes."""
self._external_memory = ( self._external_memory = (
# External memory doesnt support a default value since it was designed to be managed entirely externally # External memory does not support a default value since it was
# designed to be managed entirely externally
self.external_memory.set_crew(self) if self.external_memory else None self.external_memory.set_crew(self) if self.external_memory else None
) )
@@ -355,7 +371,10 @@ class Crew(FlowTrackable, BaseModel):
if not self.manager_llm and not self.manager_agent: if not self.manager_llm and not self.manager_agent:
raise PydanticCustomError( raise PydanticCustomError(
"missing_manager_llm_or_manager_agent", "missing_manager_llm_or_manager_agent",
"Attribute `manager_llm` or `manager_agent` is required when using hierarchical process.", (
"Attribute `manager_llm` or `manager_agent` is required "
"when using hierarchical process."
),
{}, {},
) )
@@ -398,7 +417,10 @@ class Crew(FlowTrackable, BaseModel):
if task.agent is None: if task.agent is None:
raise PydanticCustomError( raise PydanticCustomError(
"missing_agent_in_task", "missing_agent_in_task",
f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString" (
f"Sequential process error: Agent is missing in the task "
f"with the following description: {task.description}"
), # type: ignore # Dynamic string in error message
{}, {},
) )
@@ -459,7 +481,10 @@ class Crew(FlowTrackable, BaseModel):
if task.async_execution and isinstance(task, ConditionalTask): if task.async_execution and isinstance(task, ConditionalTask):
raise PydanticCustomError( raise PydanticCustomError(
"invalid_async_conditional_task", "invalid_async_conditional_task",
f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString" (
f"Conditional Task: {task.description}, "
f"cannot be executed asynchronously."
),
{}, {},
) )
return self return self
@@ -478,7 +503,9 @@ class Crew(FlowTrackable, BaseModel):
for j in range(i - 1, -1, -1): for j in range(i - 1, -1, -1):
if self.tasks[j] == context_task: if self.tasks[j] == context_task:
raise ValueError( raise ValueError(
f"Task '{task.description}' is asynchronous and cannot include other sequential asynchronous tasks in its context." f"Task '{task.description}' is asynchronous and "
f"cannot include other sequential asynchronous "
f"tasks in its context."
) )
if not self.tasks[j].async_execution: if not self.tasks[j].async_execution:
break break
@@ -496,13 +523,15 @@ class Crew(FlowTrackable, BaseModel):
continue # Skip context tasks not in the main tasks list continue # Skip context tasks not in the main tasks list
if task_indices[id(context_task)] > task_indices[id(task)]: if task_indices[id(context_task)] > task_indices[id(task)]:
raise ValueError( raise ValueError(
f"Task '{task.description}' has a context dependency on a future task '{context_task.description}', which is not allowed." f"Task '{task.description}' has a context dependency "
f"on a future task '{context_task.description}', "
f"which is not allowed."
) )
return self return self
@property @property
def key(self) -> str: def key(self) -> str:
source: List[str] = [agent.key for agent in self.agents] + [ source: list[str] = [agent.key for agent in self.agents] + [
task.key for task in self.tasks task.key for task in self.tasks
] ]
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest() return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
@@ -518,9 +547,9 @@ class Crew(FlowTrackable, BaseModel):
return self.security_config.fingerprint return self.security_config.fingerprint
def _setup_from_config(self): def _setup_from_config(self):
assert self.config is not None, "Config should not be None."
"""Initializes agents and tasks from the provided config.""" """Initializes agents and tasks from the provided config."""
if self.config is None:
raise ValueError("Config should not be None.")
if not self.config.get("agents") or not self.config.get("tasks"): if not self.config.get("agents") or not self.config.get("tasks"):
raise PydanticCustomError( raise PydanticCustomError(
"missing_keys_in_config", "Config should have 'agents' and 'tasks'.", {} "missing_keys_in_config", "Config should have 'agents' and 'tasks'.", {}
@@ -530,7 +559,7 @@ class Crew(FlowTrackable, BaseModel):
self.agents = [Agent(**agent) for agent in self.config["agents"]] self.agents = [Agent(**agent) for agent in self.config["agents"]]
self.tasks = [self._create_task(task) for task in self.config["tasks"]] self.tasks = [self._create_task(task) for task in self.config["tasks"]]
def _create_task(self, task_config: Dict[str, Any]) -> Task: def _create_task(self, task_config: dict[str, Any]) -> Task:
"""Creates a task instance from its configuration. """Creates a task instance from its configuration.
Args: Args:
@@ -559,7 +588,7 @@ class Crew(FlowTrackable, BaseModel):
CrewTrainingHandler(filename).initialize_file() CrewTrainingHandler(filename).initialize_file()
def train( def train(
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = None self, n_iterations: int, filename: str, inputs: dict[str, Any] | None = None
) -> None: ) -> None:
"""Trains the crew for a given number of iterations.""" """Trains the crew for a given number of iterations."""
inputs = inputs or {} inputs = inputs or {}
@@ -611,7 +640,7 @@ class Crew(FlowTrackable, BaseModel):
def kickoff( def kickoff(
self, self,
inputs: Optional[Dict[str, Any]] = None, inputs: dict[str, Any] | None = None,
) -> CrewOutput: ) -> CrewOutput:
ctx = baggage.set_baggage( ctx = baggage.set_baggage(
"crew_context", CrewContext(id=str(self.id), key=self.key) "crew_context", CrewContext(id=str(self.id), key=self.key)
@@ -682,9 +711,9 @@ class Crew(FlowTrackable, BaseModel):
finally: finally:
detach(token) detach(token)
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]: def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
"""Executes the Crew's workflow for each input in the list and aggregates results.""" """Executes the Crew's workflow for each input and aggregates results."""
results: List[CrewOutput] = [] results: list[CrewOutput] = []
# Initialize the parent crew's usage metrics # Initialize the parent crew's usage metrics
total_usage_metrics = UsageMetrics() total_usage_metrics = UsageMetrics()
@@ -703,14 +732,12 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset() self._task_output_handler.reset()
return results return results
async def kickoff_async( async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> CrewOutput:
self, inputs: Optional[Dict[str, Any]] = None
) -> CrewOutput:
"""Asynchronous kickoff method to start the crew execution.""" """Asynchronous kickoff method to start the crew execution."""
inputs = inputs or {} inputs = inputs or {}
return await asyncio.to_thread(self.kickoff, inputs) return await asyncio.to_thread(self.kickoff, inputs)
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]: async def kickoff_for_each_async(self, inputs: list[dict]) -> list[CrewOutput]:
crew_copies = [self.copy() for _ in inputs] crew_copies = [self.copy() for _ in inputs]
async def run_crew(crew, input_data): async def run_crew(crew, input_data):
@@ -739,7 +766,9 @@ class Crew(FlowTrackable, BaseModel):
tasks=self.tasks, planning_agent_llm=self.planning_llm tasks=self.tasks, planning_agent_llm=self.planning_llm
)._handle_crew_planning() )._handle_crew_planning()
for task, step_plan in zip(self.tasks, result.list_of_plans_per_task): for task, step_plan in zip(
self.tasks, result.list_of_plans_per_task, strict=False
):
task.description += step_plan.plan task.description += step_plan.plan
def _store_execution_log( def _store_execution_log(
@@ -776,7 +805,7 @@ class Crew(FlowTrackable, BaseModel):
return self._execute_tasks(self.tasks) return self._execute_tasks(self.tasks)
def _run_hierarchical_process(self) -> CrewOutput: def _run_hierarchical_process(self) -> CrewOutput:
"""Creates and assigns a manager agent to make sure the crew completes the tasks.""" """Creates and assigns a manager agent to complete the tasks."""
self._create_manager_agent() self._create_manager_agent()
return self._execute_tasks(self.tasks) return self._execute_tasks(self.tasks)
@@ -807,23 +836,24 @@ class Crew(FlowTrackable, BaseModel):
def _execute_tasks( def _execute_tasks(
self, self,
tasks: List[Task], tasks: list[Task],
start_index: Optional[int] = 0, start_index: int | None = 0,
was_replayed: bool = False, was_replayed: bool = False,
) -> CrewOutput: ) -> CrewOutput:
"""Executes tasks sequentially and returns the final output. """Executes tasks sequentially and returns the final output.
Args: Args:
tasks (List[Task]): List of tasks to execute tasks (List[Task]): List of tasks to execute
manager (Optional[BaseAgent], optional): Manager agent to use for delegation. Defaults to None. manager (Optional[BaseAgent], optional): Manager agent to use for
delegation. Defaults to None.
Returns: Returns:
CrewOutput: Final output of the crew CrewOutput: Final output of the crew
""" """
task_outputs: List[TaskOutput] = [] task_outputs: list[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput], int]] = [] futures: list[tuple[Task, Future[TaskOutput], int]] = []
last_sync_output: Optional[TaskOutput] = None last_sync_output: TaskOutput | None = None
for task_index, task in enumerate(tasks): for task_index, task in enumerate(tasks):
if start_index is not None and task_index < start_index: if start_index is not None and task_index < start_index:
@@ -838,7 +868,9 @@ class Crew(FlowTrackable, BaseModel):
agent_to_use = self._get_agent_to_use(task) agent_to_use = self._get_agent_to_use(task)
if agent_to_use is None: if agent_to_use is None:
raise ValueError( raise ValueError(
f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided." f"No agent available for task: {task.description}. "
f"Ensure that either the task has an assigned agent "
f"or a manager agent is provided."
) )
# Determine which tools to use - task tools take precedence over agent tools # Determine which tools to use - task tools take precedence over agent tools
@@ -847,7 +879,7 @@ class Crew(FlowTrackable, BaseModel):
tools_for_task = self._prepare_tools( tools_for_task = self._prepare_tools(
agent_to_use, agent_to_use,
task, task,
cast(Union[List[Tool], List[BaseTool]], tools_for_task), cast(list[Tool] | list[BaseTool], tools_for_task),
) )
self._log_task_start(task, agent_to_use.role) self._log_task_start(task, agent_to_use.role)
@@ -867,7 +899,7 @@ class Crew(FlowTrackable, BaseModel):
future = task.execute_async( future = task.execute_async(
agent=agent_to_use, agent=agent_to_use,
context=context, context=context,
tools=cast(List[BaseTool], tools_for_task), tools=cast(list[BaseTool], tools_for_task),
) )
futures.append((task, future, task_index)) futures.append((task, future, task_index))
else: else:
@@ -879,7 +911,7 @@ class Crew(FlowTrackable, BaseModel):
task_output = task.execute_sync( task_output = task.execute_sync(
agent=agent_to_use, agent=agent_to_use,
context=context, context=context,
tools=cast(List[BaseTool], tools_for_task), tools=cast(list[BaseTool], tools_for_task),
) )
task_outputs.append(task_output) task_outputs.append(task_output)
self._process_task_result(task, task_output) self._process_task_result(task, task_output)
@@ -893,11 +925,11 @@ class Crew(FlowTrackable, BaseModel):
def _handle_conditional_task( def _handle_conditional_task(
self, self,
task: ConditionalTask, task: ConditionalTask,
task_outputs: List[TaskOutput], task_outputs: list[TaskOutput],
futures: List[Tuple[Task, Future[TaskOutput], int]], futures: list[tuple[Task, Future[TaskOutput], int]],
task_index: int, task_index: int,
was_replayed: bool, was_replayed: bool,
) -> Optional[TaskOutput]: ) -> TaskOutput | None:
if futures: if futures:
task_outputs = self._process_async_tasks(futures, was_replayed) task_outputs = self._process_async_tasks(futures, was_replayed)
futures.clear() futures.clear()
@@ -917,8 +949,8 @@ class Crew(FlowTrackable, BaseModel):
return None return None
def _prepare_tools( def _prepare_tools(
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]] self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
) -> List[BaseTool]: ) -> list[BaseTool]:
# Add delegation tools if agent allows delegation # Add delegation tools if agent allows delegation
if hasattr(agent, "allow_delegation") and getattr( if hasattr(agent, "allow_delegation") and getattr(
agent, "allow_delegation", False agent, "allow_delegation", False
@@ -947,22 +979,22 @@ class Crew(FlowTrackable, BaseModel):
): ):
tools = self._add_multimodal_tools(agent, tools) tools = self._add_multimodal_tools(agent, tools)
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async # Return a List[BaseTool] compatible with Task.execute_sync and execute_async
return cast(List[BaseTool], tools) return cast(list[BaseTool], tools)
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]: def _get_agent_to_use(self, task: Task) -> BaseAgent | None:
if self.process == Process.hierarchical: if self.process == Process.hierarchical:
return self.manager_agent return self.manager_agent
return task.agent return task.agent
def _merge_tools( def _merge_tools(
self, self,
existing_tools: Union[List[Tool], List[BaseTool]], existing_tools: list[Tool] | list[BaseTool],
new_tools: Union[List[Tool], List[BaseTool]], new_tools: list[Tool] | list[BaseTool],
) -> List[BaseTool]: ) -> list[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates by tool name.""" """Merge new tools into existing tools list, avoiding duplicates."""
if not new_tools: if not new_tools:
return cast(List[BaseTool], existing_tools) return cast(list[BaseTool], existing_tools)
# Create mapping of tool names to new tools # Create mapping of tool names to new tools
new_tool_map = {tool.name: tool for tool in new_tools} new_tool_map = {tool.name: tool for tool in new_tools}
@@ -973,41 +1005,41 @@ class Crew(FlowTrackable, BaseModel):
# Add all new tools # Add all new tools
tools.extend(new_tools) tools.extend(new_tools)
return cast(List[BaseTool], tools) return cast(list[BaseTool], tools)
def _inject_delegation_tools( def _inject_delegation_tools(
self, self,
tools: Union[List[Tool], List[BaseTool]], tools: list[Tool] | list[BaseTool],
task_agent: BaseAgent, task_agent: BaseAgent,
agents: List[BaseAgent], agents: list[BaseAgent],
) -> List[BaseTool]: ) -> list[BaseTool]:
if hasattr(task_agent, "get_delegation_tools"): if hasattr(task_agent, "get_delegation_tools"):
delegation_tools = task_agent.get_delegation_tools(agents) delegation_tools = task_agent.get_delegation_tools(agents)
# Cast delegation_tools to the expected type for _merge_tools # Cast delegation_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools)) return self._merge_tools(tools, cast(list[BaseTool], delegation_tools))
return cast(List[BaseTool], tools) return cast(list[BaseTool], tools)
def _add_multimodal_tools( def _add_multimodal_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]] self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> List[BaseTool]: ) -> list[BaseTool]:
if hasattr(agent, "get_multimodal_tools"): if hasattr(agent, "get_multimodal_tools"):
multimodal_tools = agent.get_multimodal_tools() multimodal_tools = agent.get_multimodal_tools()
# Cast multimodal_tools to the expected type for _merge_tools # Cast multimodal_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools)) return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
return cast(List[BaseTool], tools) return cast(list[BaseTool], tools)
def _add_code_execution_tools( def _add_code_execution_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]] self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> List[BaseTool]: ) -> list[BaseTool]:
if hasattr(agent, "get_code_execution_tools"): if hasattr(agent, "get_code_execution_tools"):
code_tools = agent.get_code_execution_tools() code_tools = agent.get_code_execution_tools()
# Cast code_tools to the expected type for _merge_tools # Cast code_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], code_tools)) return self._merge_tools(tools, cast(list[BaseTool], code_tools))
return cast(List[BaseTool], tools) return cast(list[BaseTool], tools)
def _add_delegation_tools( def _add_delegation_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]] self, task: Task, tools: list[Tool] | list[BaseTool]
) -> List[BaseTool]: ) -> list[BaseTool]:
agents_for_delegation = [agent for agent in self.agents if agent != task.agent] agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent: if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
if not tools: if not tools:
@@ -1015,7 +1047,7 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools( tools = self._inject_delegation_tools(
tools, task.agent, agents_for_delegation tools, task.agent, agents_for_delegation
) )
return cast(List[BaseTool], tools) return cast(list[BaseTool], tools)
def _log_task_start(self, task: Task, role: str = "None"): def _log_task_start(self, task: Task, role: str = "None"):
if self.output_log_file: if self.output_log_file:
@@ -1024,8 +1056,8 @@ class Crew(FlowTrackable, BaseModel):
) )
def _update_manager_tools( def _update_manager_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]] self, task: Task, tools: list[Tool] | list[BaseTool]
) -> List[BaseTool]: ) -> list[BaseTool]:
if self.manager_agent: if self.manager_agent:
if task.agent: if task.agent:
tools = self._inject_delegation_tools(tools, task.agent, [task.agent]) tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
@@ -1033,18 +1065,17 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools( tools = self._inject_delegation_tools(
tools, self.manager_agent, self.agents tools, self.manager_agent, self.agents
) )
return cast(List[BaseTool], tools) return cast(list[BaseTool], tools)
def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str: def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
if not task.context: if not task.context:
return "" return ""
context = ( return (
aggregate_raw_outputs_from_task_outputs(task_outputs) aggregate_raw_outputs_from_task_outputs(task_outputs)
if task.context is NOT_SPECIFIED if task.context is NOT_SPECIFIED
else aggregate_raw_outputs_from_tasks(task.context) else aggregate_raw_outputs_from_tasks(task.context)
) )
return context
def _process_task_result(self, task: Task, output: TaskOutput) -> None: def _process_task_result(self, task: Task, output: TaskOutput) -> None:
role = task.agent.role if task.agent is not None else "None" role = task.agent.role if task.agent is not None else "None"
@@ -1057,7 +1088,7 @@ class Crew(FlowTrackable, BaseModel):
output=output.raw, output=output.raw,
) )
def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput: def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput:
if not task_outputs: if not task_outputs:
raise ValueError("No task outputs available to create crew output.") raise ValueError("No task outputs available to create crew output.")
@@ -1088,10 +1119,10 @@ class Crew(FlowTrackable, BaseModel):
def _process_async_tasks( def _process_async_tasks(
self, self,
futures: List[Tuple[Task, Future[TaskOutput], int]], futures: list[tuple[Task, Future[TaskOutput], int]],
was_replayed: bool = False, was_replayed: bool = False,
) -> List[TaskOutput]: ) -> list[TaskOutput]:
task_outputs: List[TaskOutput] = [] task_outputs: list[TaskOutput] = []
for future_task, future, task_index in futures: for future_task, future, task_index in futures:
task_output = future.result() task_output = future.result()
task_outputs.append(task_output) task_outputs.append(task_output)
@@ -1101,9 +1132,7 @@ class Crew(FlowTrackable, BaseModel):
) )
return task_outputs return task_outputs
def _find_task_index( def _find_task_index(self, task_id: str, stored_outputs: list[Any]) -> int | None:
self, task_id: str, stored_outputs: List[Any]
) -> Optional[int]:
return next( return next(
( (
index index
@@ -1113,9 +1142,8 @@ class Crew(FlowTrackable, BaseModel):
None, None,
) )
def replay( def replay(self, task_id: str, inputs: dict[str, Any] | None = None) -> CrewOutput:
self, task_id: str, inputs: Optional[Dict[str, Any]] = None """Replay the crew execution from a specific task."""
) -> CrewOutput:
stored_outputs = self._task_output_handler.load() stored_outputs = self._task_output_handler.load()
if not stored_outputs: if not stored_outputs:
raise ValueError(f"Task with id {task_id} not found in the crew's tasks.") raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")
@@ -1151,19 +1179,19 @@ class Crew(FlowTrackable, BaseModel):
self.tasks[i].output = task_output self.tasks[i].output = task_output
self._logging_color = "bold_blue" self._logging_color = "bold_blue"
result = self._execute_tasks(self.tasks, start_index, True) return self._execute_tasks(self.tasks, start_index, True)
return result
def query_knowledge( def query_knowledge(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35 self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> Union[List[Dict[str, Any]], None]: ) -> list[SearchResult] | None:
"""Query the crew's knowledge base for relevant information."""
if self.knowledge: if self.knowledge:
return self.knowledge.query( return self.knowledge.query(
query, results_limit=results_limit, score_threshold=score_threshold query, results_limit=results_limit, score_threshold=score_threshold
) )
return None return None
def fetch_inputs(self) -> Set[str]: def fetch_inputs(self) -> set[str]:
""" """
Gathers placeholders (e.g., {something}) referenced in tasks or agents. Gathers placeholders (e.g., {something}) referenced in tasks or agents.
Scans each task's 'description' + 'expected_output', and each agent's Scans each task's 'description' + 'expected_output', and each agent's
@@ -1172,11 +1200,11 @@ class Crew(FlowTrackable, BaseModel):
Returns a set of all discovered placeholder names. Returns a set of all discovered placeholder names.
""" """
placeholder_pattern = re.compile(r"\{(.+?)\}") placeholder_pattern = re.compile(r"\{(.+?)\}")
required_inputs: Set[str] = set() required_inputs: set[str] = set()
# Scan tasks for inputs # Scan tasks for inputs
for task in self.tasks: for task in self.tasks:
# description and expected_output might contain e.g. {topic}, {user_name}, etc. # description and expected_output might contain e.g. {topic}, {user_name}
text = f"{task.description or ''} {task.expected_output or ''}" text = f"{task.description or ''} {task.expected_output or ''}"
required_inputs.update(placeholder_pattern.findall(text)) required_inputs.update(placeholder_pattern.findall(text))
@@ -1230,7 +1258,7 @@ class Crew(FlowTrackable, BaseModel):
cloned_tasks.append(cloned_task) cloned_tasks.append(cloned_task)
task_mapping[task.key] = cloned_task task_mapping[task.key] = cloned_task
for cloned_task, original_task in zip(cloned_tasks, self.tasks): for cloned_task, original_task in zip(cloned_tasks, self.tasks, strict=False):
if isinstance(original_task.context, list): if isinstance(original_task.context, list):
cloned_context = [ cloned_context = [
task_mapping[context_task.key] task_mapping[context_task.key]
@@ -1256,7 +1284,7 @@ class Crew(FlowTrackable, BaseModel):
copied_data.pop("agents", None) copied_data.pop("agents", None)
copied_data.pop("tasks", None) copied_data.pop("tasks", None)
copied_crew = Crew( return Crew(
**copied_data, **copied_data,
agents=cloned_agents, agents=cloned_agents,
tasks=cloned_tasks, tasks=cloned_tasks,
@@ -1266,15 +1294,13 @@ class Crew(FlowTrackable, BaseModel):
manager_llm=manager_llm, manager_llm=manager_llm,
) )
return copied_crew
def _set_tasks_callbacks(self) -> None: def _set_tasks_callbacks(self) -> None:
"""Sets callback for every task suing task_callback""" """Sets callback for every task suing task_callback"""
for task in self.tasks: for task in self.tasks:
if not task.callback: if not task.callback:
task.callback = self.task_callback task.callback = self.task_callback
def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None: def _interpolate_inputs(self, inputs: dict[str, Any]) -> None:
"""Interpolates the inputs in the tasks and agents.""" """Interpolates the inputs in the tasks and agents."""
[ [
task.interpolate_inputs_and_add_conversation_history( task.interpolate_inputs_and_add_conversation_history(
@@ -1307,10 +1333,13 @@ class Crew(FlowTrackable, BaseModel):
def test( def test(
self, self,
n_iterations: int, n_iterations: int,
eval_llm: Union[str, InstanceOf[BaseLLM]], eval_llm: str | InstanceOf[BaseLLM],
inputs: Optional[Dict[str, Any]] = None, inputs: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures.""" """Test and evaluate the Crew with the given inputs for n iterations.
Uses concurrent.futures for concurrent execution.
"""
try: try:
# Create LLM instance and ensure it's of type LLM for CrewEvaluator # Create LLM instance and ensure it's of type LLM for CrewEvaluator
llm_instance = create_llm(eval_llm) llm_instance = create_llm(eval_llm)
@@ -1350,7 +1379,11 @@ class Crew(FlowTrackable, BaseModel):
raise raise
def __repr__(self): def __repr__(self):
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})" return (
f"Crew(id={self.id}, process={self.process}, "
f"number_of_agents={len(self.agents)}, "
f"number_of_tasks={len(self.tasks)})"
)
def reset_memories(self, command_type: str) -> None: def reset_memories(self, command_type: str) -> None:
"""Reset specific or all memories for the crew. """Reset specific or all memories for the crew.
@@ -1364,7 +1397,7 @@ class Crew(FlowTrackable, BaseModel):
ValueError: If an invalid command type is provided. ValueError: If an invalid command type is provided.
RuntimeError: If memory reset operation fails. RuntimeError: If memory reset operation fails.
""" """
VALID_TYPES = frozenset( valid_types = frozenset(
[ [
"long", "long",
"short", "short",
@@ -1377,9 +1410,10 @@ class Crew(FlowTrackable, BaseModel):
] ]
) )
if command_type not in VALID_TYPES: if command_type not in valid_types:
raise ValueError( raise ValueError(
f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}" f"Invalid command type. Must be one of: "
f"{', '.join(sorted(valid_types))}"
) )
try: try:
@@ -1389,7 +1423,7 @@ class Crew(FlowTrackable, BaseModel):
self._reset_specific_memory(command_type) self._reset_specific_memory(command_type)
except Exception as e: except Exception as e:
error_msg = f"Failed to reset {command_type} memory: {str(e)}" error_msg = f"Failed to reset {command_type} memory: {e!s}"
self._logger.log("error", error_msg) self._logger.log("error", error_msg)
raise RuntimeError(error_msg) from e raise RuntimeError(error_msg) from e
@@ -1397,7 +1431,7 @@ class Crew(FlowTrackable, BaseModel):
"""Reset all available memory systems.""" """Reset all available memory systems."""
memory_systems = self._get_memory_systems() memory_systems = self._get_memory_systems()
for memory_type, config in memory_systems.items(): for config in memory_systems.values():
if (system := config.get("system")) is not None: if (system := config.get("system")) is not None:
name = config.get("name") name = config.get("name")
try: try:
@@ -1405,11 +1439,13 @@ class Crew(FlowTrackable, BaseModel):
reset_fn(system) reset_fn(system)
self._logger.log( self._logger.log(
"info", "info",
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset", f"[Crew ({self.name if self.name else self.id})] "
f"{name} memory has been reset",
) )
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}" f"[Crew ({self.name if self.name else self.id})] "
f"Failed to reset {name} memory: {e!s}"
) from e ) from e
def _reset_specific_memory(self, memory_type: str) -> None: def _reset_specific_memory(self, memory_type: str) -> None:
@@ -1434,18 +1470,21 @@ class Crew(FlowTrackable, BaseModel):
reset_fn(system) reset_fn(system)
self._logger.log( self._logger.log(
"info", "info",
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset", f"[Crew ({self.name if self.name else self.id})] "
f"{name} memory has been reset",
) )
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}" f"[Crew ({self.name if self.name else self.id})] "
f"Failed to reset {name} memory: {e!s}"
) from e ) from e
def _get_memory_systems(self): def _get_memory_systems(self):
"""Get all available memory systems with their configuration. """Get all available memory systems with their configuration.
Returns: Returns:
Dict containing all memory systems with their reset functions and display names. Dict containing all memory systems with their reset functions and
display names.
""" """
def default_reset(memory): def default_reset(memory):
@@ -1506,7 +1545,7 @@ class Crew(FlowTrackable, BaseModel):
}, },
} }
def reset_knowledge(self, knowledges: List[Knowledge]) -> None: def reset_knowledge(self, knowledges: list[Knowledge]) -> None:
"""Reset crew and agent knowledge storage.""" """Reset crew and agent knowledge storage."""
for ks in knowledges: for ks in knowledges:
ks.reset() ks.reset()
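With this change, Crew.query_knowledge returns list[SearchResult] | None instead of list[dict[str, Any]] | None. A small sketch of how a caller might consume the new return type; it assumes an already-configured crew, and the helper function below is hypothetical rather than part of crewai:

from crewai import Crew
from crewai.rag.types import SearchResult


def top_knowledge_hits(crew: Crew, question: str) -> list[SearchResult]:
    """Return up to three knowledge hits, or an empty list if no knowledge is configured."""
    hits = crew.query_knowledge([question], results_limit=3, score_threshold=0.35)
    return hits or []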

View File: src/crewai/knowledge/knowledge.py

@@ -1,10 +1,11 @@
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any
 from pydantic import BaseModel, ConfigDict, Field
 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
+from crewai.rag.types import SearchResult
 os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
@@ -13,23 +14,23 @@ class Knowledge(BaseModel):
     """
     Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
     Args:
-        sources: List[BaseKnowledgeSource] = Field(default_factory=list)
-        storage: Optional[KnowledgeStorage] = Field(default=None)
-        embedder: Optional[Dict[str, Any]] = None
+        sources: list[BaseKnowledgeSource] = Field(default_factory=list)
+        storage: KnowledgeStorage | None = Field(default=None)
+        embedder: dict[str, Any] | None = None
     """
-    sources: List[BaseKnowledgeSource] = Field(default_factory=list)
+    sources: list[BaseKnowledgeSource] = Field(default_factory=list)
     model_config = ConfigDict(arbitrary_types_allowed=True)
-    storage: Optional[KnowledgeStorage] = Field(default=None)
-    embedder: Optional[Dict[str, Any]] = None
-    collection_name: Optional[str] = None
+    storage: KnowledgeStorage | None = Field(default=None)
+    embedder: dict[str, Any] | None = None
+    collection_name: str | None = None
     def __init__(
         self,
         collection_name: str,
-        sources: List[BaseKnowledgeSource],
-        embedder: Optional[Dict[str, Any]] = None,
-        storage: Optional[KnowledgeStorage] = None,
+        sources: list[BaseKnowledgeSource],
+        embedder: dict[str, Any] | None = None,
+        storage: KnowledgeStorage | None = None,
         **data,
     ):
         super().__init__(**data)
@@ -40,11 +41,10 @@ class Knowledge(BaseModel):
             embedder=embedder, collection_name=collection_name
         )
         self.sources = sources
-        self.storage.initialize_knowledge_storage()
     def query(
-        self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
-    ) -> List[Dict[str, Any]]:
+        self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
+    ) -> list[SearchResult]:
         """
         Query across all knowledge sources to find the most relevant information.
         Returns the top_k most relevant chunks.
@@ -55,12 +55,11 @@ class Knowledge(BaseModel):
         if self.storage is None:
             raise ValueError("Storage is not initialized.")
-        results = self.storage.search(
+        return self.storage.search(
             query,
             limit=results_limit,
             score_threshold=score_threshold,
         )
-        return results
     def add_sources(self):
         try:

View File: src/crewai/knowledge/storage/base_knowledge_storage.py

@@ -1,5 +1,7 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any
+from crewai.rag.types import SearchResult
 class BaseKnowledgeStorage(ABC):
@@ -8,22 +10,17 @@ class BaseKnowledgeStorage(ABC):
     @abstractmethod
     def search(
         self,
-        query: List[str],
+        query: list[str],
         limit: int = 3,
-        filter: Optional[dict] = None,
+        metadata_filter: dict[str, Any] | None = None,
         score_threshold: float = 0.35,
-    ) -> List[Dict[str, Any]]:
+    ) -> list[SearchResult]:
         """Search for documents in the knowledge base."""
-        pass
     @abstractmethod
-    def save(
-        self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
-    ) -> None:
+    def save(self, documents: list[str]) -> None:
         """Save documents to the knowledge base."""
-        pass
     @abstractmethod
     def reset(self) -> None:
         """Reset the knowledge base."""
-        pass
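For anyone maintaining a custom knowledge storage backend, the slimmed-down abstract interface above (search with a metadata_filter, save without a metadata argument, reset) is what now has to be implemented. A rough in-memory sketch based only on the signatures shown in this diff; the concrete shape of SearchResult lives in crewai.rag.types and is not reproduced here:

from typing import Any, cast

from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.types import SearchResult


class InMemoryKnowledgeStorage(BaseKnowledgeStorage):
    """Toy backend kept entirely in memory; real backends would hit a vector store."""

    def __init__(self) -> None:
        self._documents: list[str] = []

    def save(self, documents: list[str]) -> None:
        # The refactored signature takes only the documents; metadata handling
        # is now the concrete backend's concern.
        self._documents.extend(documents)

    def search(
        self,
        query: list[str],
        limit: int = 3,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.35,
    ) -> list[SearchResult]:
        # Naive substring matching stands in for embedding-based retrieval;
        # the real SearchResult record shape is not reconstructed here.
        matches = [doc for doc in self._documents if any(q in doc for q in query)]
        return cast(list[SearchResult], matches[:limit])

    def reset(self) -> None:
        self._documents.clear()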

View File: src/crewai/knowledge/storage/knowledge_storage.py

@@ -1,24 +1,16 @@
-import hashlib
 import logging
-import os
-import shutil
-from typing import Any, Dict, List, Optional, Union
-import chromadb
-import chromadb.errors
-from chromadb.api import ClientAPI
-from chromadb.api.types import OneOrMany
-from chromadb.config import Settings
 import warnings
+from typing import Any, cast
 from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
-from crewai.rag.embeddings.configurator import EmbeddingConfigurator
-from crewai.utilities.chromadb import sanitize_collection_name
-from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
+from crewai.rag.chromadb.config import ChromaDBConfig
+from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
+from crewai.rag.config.utils import get_rag_client
+from crewai.rag.core.base_client import BaseClient
+from crewai.rag.embeddings.factory import get_embedding_function
+from crewai.rag.factory import create_client
+from crewai.rag.types import BaseRecord, SearchResult
 from crewai.utilities.logger import Logger
-from crewai.utilities.paths import db_storage_path
-from crewai.utilities.chromadb import create_persistent_client
-from crewai.utilities.logger_utils import suppress_logging
 class KnowledgeStorage(BaseKnowledgeStorage):
@@ -27,167 +19,101 @@ class KnowledgeStorage(BaseKnowledgeStorage):
     search efficiency.
     """
-    collection: Optional[chromadb.Collection] = None
-    collection_name: Optional[str] = "knowledge"
-    app: Optional[ClientAPI] = None
     def __init__(
         self,
-        embedder: Optional[Dict[str, Any]] = None,
-        collection_name: Optional[str] = None,
-    ):
+        embedder: dict[str, Any] | None = None,
+        collection_name: str | None = None,
+    ) -> None:
         self.collection_name = collection_name
-        self._set_embedder_config(embedder)
-    def search(
-        self,
-        query: List[str],
-        limit: int = 3,
-        filter: Optional[dict] = None,
-        score_threshold: float = 0.35,
-    ) -> List[Dict[str, Any]]:
-        with suppress_logging(
-            "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
-        ):
-            if self.collection:
-                fetched = self.collection.query(
-                    query_texts=query,
-                    n_results=limit,
-                    where=filter,
-                )
-                results = []
-                for i in range(len(fetched["ids"][0])): # type: ignore
-                    result = {
-                        "id": fetched["ids"][0][i], # type: ignore
-                        "metadata": fetched["metadatas"][0][i], # type: ignore
-                        "context": fetched["documents"][0][i], # type: ignore
-                        "score": fetched["distances"][0][i], # type: ignore
-                    }
-                    if result["score"] >= score_threshold:
-                        results.append(result)
-                return results
-            else:
-                raise Exception("Collection not initialized")
-    def initialize_knowledge_storage(self):
-        # Suppress deprecation warnings from chromadb, which are not relevant to us
-        # TODO: Remove this once we upgrade chromadb to at least 1.0.8.
+        self._client: BaseClient | None = None
         warnings.filterwarnings(
             "ignore",
             message=r".*'model_fields'.*is deprecated.*",
             module=r"^chromadb(\.|$)",
         )
-        self.app = create_persistent_client(
-            path=os.path.join(db_storage_path(), "knowledge"),
-            settings=Settings(allow_reset=True),
-        )
+        if embedder:
+            embedding_function = get_embedding_function(embedder)
+            config = ChromaDBConfig(
+                embedding_function=cast(
+                    ChromaEmbeddingFunctionWrapper, embedding_function
+                )
+            )
+            self._client = create_client(config)
+    def _get_client(self) -> BaseClient:
+        """Get the appropriate client - instance-specific or global."""
+        return self._client if self._client else get_rag_client()
+    def search(
+        self,
+        query: list[str],
+        limit: int = 3,
+        metadata_filter: dict[str, Any] | None = None,
+        score_threshold: float = 0.35,
+    ) -> list[SearchResult]:
         try:
+            if not query:
+                raise ValueError("Query cannot be empty")
+            client = self._get_client()
             collection_name = (
                 f"knowledge_{self.collection_name}"
                 if self.collection_name
                 else "knowledge"
             )
-            if self.app:
-                self.collection = self.app.get_or_create_collection(
-                    name=sanitize_collection_name(collection_name),
-                    embedding_function=self.embedder,
-                )
-            else:
-                raise Exception("Vector Database Client not initialized")
+            query_text = " ".join(query) if len(query) > 1 else query[0]
except Exception:
raise Exception("Failed to create or get collection")
def reset(self): return client.search(
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY) collection_name=collection_name,
if not self.app: query=query_text,
self.app = create_persistent_client( limit=limit,
path=base_path, settings=Settings(allow_reset=True) metadata_filter=metadata_filter,
score_threshold=score_threshold,
) )
self.app.reset()
shutil.rmtree(base_path)
self.app = None
self.collection = None
def save(
self,
documents: List[str],
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
):
if not self.collection:
raise Exception("Collection not initialized")
try:
# Create a dictionary to store unique documents
unique_docs = {}
# Generate IDs and create a mapping of id -> (document, metadata)
for idx, doc in enumerate(documents):
doc_id = hashlib.sha256(doc.encode("utf-8")).hexdigest()
doc_metadata = None
if metadata is not None:
if isinstance(metadata, list):
doc_metadata = metadata[idx]
else:
doc_metadata = metadata
unique_docs[doc_id] = (doc, doc_metadata)
# Prepare filtered lists for ChromaDB
filtered_docs = []
filtered_metadata = []
filtered_ids = []
# Build the filtered lists
for doc_id, (doc, meta) in unique_docs.items():
filtered_docs.append(doc)
filtered_metadata.append(meta)
filtered_ids.append(doc_id)
# If we have no metadata at all, set it to None
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
None if all(m is None for m in filtered_metadata) else filtered_metadata
)
self.collection.upsert(
documents=filtered_docs,
metadatas=final_metadata,
ids=filtered_ids,
)
except chromadb.errors.InvalidDimensionException as e:
Logger(verbose=True).log(
"error",
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
) from e
except Exception as e: except Exception as e:
logging.error(f"Error during knowledge search: {e!s}")
return []
def reset(self) -> None:
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
client.delete_collection(collection_name=collection_name)
except Exception as e:
logging.error(f"Error during knowledge reset: {e!s}")
def save(self, documents: list[str]) -> None:
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
client.get_or_create_collection(collection_name=collection_name)
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
client.add_documents(
collection_name=collection_name, documents=rag_documents
)
except Exception as e:
if "dimension mismatch" in str(e).lower():
Logger(verbose=True).log(
"error",
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
) from e
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red") Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise raise
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
"""Set the embedding configuration for the knowledge storage.
Args:
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
If None or empty, defaults to the default embedding function.
"""
self.embedder = (
EmbeddingConfigurator().configure_embedder(embedder)
if embedder
else self._create_default_embedding_function()
)
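Sketch of the two client paths introduced here (the embedder dict shape is an assumption, mirroring crewai's provider/config convention): without an embedder the storage falls back to the process-wide client from get_rag_client(); with one it builds its own ChromaDB client via create_client.

from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage

shared = KnowledgeStorage(collection_name="handbook")  # uses the global rag client
dedicated = KnowledgeStorage(
    embedder={"provider": "openai", "config": {"model": "text-embedding-3-small"}},  # assumed shape
    collection_name="handbook",
)  # builds an instance-specific ChromaDB client

dedicated.save(["CrewAI now supports instance-specific RAG clients."])
print(dedicated.search(["instance-specific clients"], limit=1))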

View File

@@ -1,12 +1,12 @@
from typing import Any, Dict, List from crewai.rag.types import SearchResult
def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str: def extract_knowledge_context(knowledge_snippets: list[SearchResult]) -> str:
"""Extract knowledge from the task prompt.""" """Extract knowledge from the task prompt."""
valid_snippets = [ valid_snippets = [
result["context"] result["content"]
for result in knowledge_snippets for result in knowledge_snippets
if result and result.get("context") if result and result.get("content")
] ]
snippet = "\n".join(valid_snippets) snippet = "\n".join(valid_snippets)
return f"Additional Information: {snippet}" if valid_snippets else "" return f"Additional Information: {snippet}" if valid_snippets else ""
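Quick illustration of the helper with the renamed key (the snippets below are hypothetical): only truthy "content" values are joined, prefixed with "Additional Information:".

snippets = [
    {"id": "1", "content": "Paris is the capital of France.", "metadata": {}, "score": 0.9},
    {"id": "2", "content": "", "metadata": {}, "score": 0.1},
]
print(extract_knowledge_context(snippets))
# Additional Information: Paris is the capital of France.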

View File

@@ -1,4 +1,6 @@
from typing import Optional, TYPE_CHECKING from __future__ import annotations
from typing import TYPE_CHECKING
from crewai.memory import ( from crewai.memory import (
EntityMemory, EntityMemory,
@@ -19,9 +21,9 @@ class ContextualMemory:
ltm: LongTermMemory, ltm: LongTermMemory,
em: EntityMemory, em: EntityMemory,
exm: ExternalMemory, exm: ExternalMemory,
agent: Optional["Agent"] = None, agent: Agent | None = None,
task: Optional["Task"] = None, task: Task | None = None,
): ) -> None:
self.stm = stm self.stm = stm
self.ltm = ltm self.ltm = ltm
self.em = em self.em = em
@@ -42,7 +44,7 @@ class ContextualMemory:
self.exm.agent = self.agent self.exm.agent = self.agent
self.exm.task = self.task self.exm.task = self.task
def build_context_for_task(self, task, context) -> str: def build_context_for_task(self, task: Task, context: str) -> str:
""" """
Automatically builds a minimal, highly relevant set of contextual information Automatically builds a minimal, highly relevant set of contextual information
for a given task. for a given task.
@@ -52,14 +54,15 @@ class ContextualMemory:
if query == "": if query == "":
return "" return ""
context = [] context_parts = [
context.append(self._fetch_ltm_context(task.description)) self._fetch_ltm_context(task.description),
context.append(self._fetch_stm_context(query)) self._fetch_stm_context(query),
context.append(self._fetch_entity_context(query)) self._fetch_entity_context(query),
context.append(self._fetch_external_context(query)) self._fetch_external_context(query),
return "\n".join(filter(None, context)) ]
return "\n".join(filter(None, context_parts))
def _fetch_stm_context(self, query) -> str: def _fetch_stm_context(self, query: str) -> str:
""" """
Fetches recent relevant insights from STM related to the task's description and expected_output, Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points. formatted as bullet points.
@@ -70,11 +73,11 @@ class ContextualMemory:
stm_results = self.stm.search(query) stm_results = self.stm.search(query)
formatted_results = "\n".join( formatted_results = "\n".join(
[f"- {result['context']}" for result in stm_results] [f"- {result['content']}" for result in stm_results]
) )
return f"Recent Insights:\n{formatted_results}" if stm_results else "" return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]: def _fetch_ltm_context(self, task: str) -> str | None:
""" """
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output, Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points. formatted as bullet points.
@@ -90,14 +93,14 @@ class ContextualMemory:
formatted_results = [ formatted_results = [
suggestion suggestion
for result in ltm_results for result in ltm_results
for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" for suggestion in result["metadata"]["suggestions"]
] ]
formatted_results = list(dict.fromkeys(formatted_results)) formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]") formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
return f"Historical Data:\n{formatted_results}" if ltm_results else "" return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str: def _fetch_entity_context(self, query: str) -> str:
""" """
Fetches relevant entity information from Entity Memory related to the task's description and expected_output, Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points. formatted as bullet points.
@@ -107,7 +110,7 @@ class ContextualMemory:
em_results = self.em.search(query) em_results = self.em.search(query)
formatted_results = "\n".join( formatted_results = "\n".join(
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" [f"- {result['content']}" for result in em_results]
) )
return f"Entities:\n{formatted_results}" if em_results else "" return f"Entities:\n{formatted_results}" if em_results else ""
@@ -128,6 +131,6 @@ class ContextualMemory:
return "" return ""
formatted_memories = "\n".join( formatted_memories = "\n".join(
f"- {result['context']}" for result in external_memories f"- {result['content']}" for result in external_memories
) )
return f"External memories:\n{formatted_memories}" return f"External memories:\n{formatted_memories}"
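Minimal illustration (plain dicts, not the class itself) of the key rename the contextual memory relies on: each memory search now yields "content", which the fetchers join into bullet lists.

stm_results = [
    {"content": "User prefers concise answers."},
    {"content": "The report deadline is Friday."},
]
formatted = "\n".join(f"- {result['content']}" for result in stm_results)
print(f"Recent Insights:\n{formatted}")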

View File

@@ -1,12 +1,13 @@
import os import os
import re import re
from collections import defaultdict from collections import defaultdict
from typing import Any, Iterable from collections.abc import Iterable
from typing import Any
from mem0 import Memory, MemoryClient # type: ignore[import-untyped] from mem0 import Memory, MemoryClient # type: ignore[import-untyped,import-not-found]
from crewai.memory.storage.interface import Storage from crewai.memory.storage.interface import Storage
from crewai.utilities.chromadb import sanitize_collection_name from crewai.rag.chromadb.utils import _sanitize_collection_name
MAX_AGENT_ID_LENGTH_MEM0 = 255 MAX_AGENT_ID_LENGTH_MEM0 = 255
@@ -15,6 +16,7 @@ class Mem0Storage(Storage):
""" """
Extends Storage to handle embedding and searching across entities using Mem0. Extends Storage to handle embedding and searching across entities using Mem0.
""" """
def __init__(self, type, crew=None, config=None): def __init__(self, type, crew=None, config=None):
super().__init__() super().__init__()
@@ -30,7 +32,8 @@ class Mem0Storage(Storage):
supported_types = {"short_term", "long_term", "entities", "external"} supported_types = {"short_term", "long_term", "entities", "external"}
if type not in supported_types: if type not in supported_types:
raise ValueError( raise ValueError(
f"Invalid type '{type}' for Mem0Storage. Must be one of: {', '.join(supported_types)}" f"Invalid type '{type}' for Mem0Storage. "
f"Must be one of: {', '.join(supported_types)}"
) )
def _extract_config_values(self): def _extract_config_values(self):
@@ -68,7 +71,8 @@ class Mem0Storage(Storage):
- Includes user_id and agent_id if both are present. - Includes user_id and agent_id if both are present.
- Includes user_id if only user_id is present. - Includes user_id if only user_id is present.
- Includes agent_id if only agent_id is present. - Includes agent_id if only agent_id is present.
- Includes run_id if memory_type is 'short_term' and mem0_run_id is present. - Includes run_id if memory_type is 'short_term' and
mem0_run_id is present.
""" """
filter = defaultdict(list) filter = defaultdict(list)
@@ -91,10 +95,14 @@ class Mem0Storage(Storage):
def save(self, value: Any, metadata: dict[str, Any]) -> None: def save(self, value: Any, metadata: dict[str, Any]) -> None:
def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str: def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str:
return next( return next(
(m.get("content", "") for m in reversed(list(messages)) if m.get("role") == role), (
"" m.get("content", "")
for m in reversed(list(messages))
if m.get("role") == role
),
"",
) )
conversations = [] conversations = []
messages = metadata.pop("messages", None) messages = metadata.pop("messages", None)
if messages: if messages:
@@ -103,7 +111,7 @@ class Mem0Storage(Storage):
if user_msg := self._get_user_message(last_user): if user_msg := self._get_user_message(last_user):
conversations.append({"role": "user", "content": user_msg}) conversations.append({"role": "user", "content": user_msg})
if assistant_msg := self._get_assistant_message(last_assistant): if assistant_msg := self._get_assistant_message(last_assistant):
conversations.append({"role": "assistant", "content": assistant_msg}) conversations.append({"role": "assistant", "content": assistant_msg})
else: else:
@@ -115,13 +123,13 @@ class Mem0Storage(Storage):
"short_term": "short_term", "short_term": "short_term",
"long_term": "long_term", "long_term": "long_term",
"entities": "entity", "entities": "entity",
"external": "external" "external": "external",
} }
# Shared base params # Shared base params
params: dict[str, Any] = { params: dict[str, Any] = {
"metadata": {"type": base_metadata[self.memory_type], **metadata}, "metadata": {"type": base_metadata[self.memory_type], **metadata},
"infer": self.infer "infer": self.infer,
} }
# MemoryClient-specific overrides # MemoryClient-specific overrides
@@ -142,13 +150,15 @@ class Mem0Storage(Storage):
self.memory.add(conversations, **params) self.memory.add(conversations, **params)
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> list[Any]: def search(
self, query: str, limit: int = 3, score_threshold: float = 0.35
) -> list[Any]:
params = { params = {
"query": query, "query": query,
"limit": limit, "limit": limit,
"version": "v2", "version": "v2",
"output_format": "v1.1" "output_format": "v1.1",
} }
if user_id := self.config.get("user_id", ""): if user_id := self.config.get("user_id", ""):
params["user_id"] = user_id params["user_id"] = user_id
@@ -169,10 +179,10 @@ class Mem0Storage(Storage):
# automatically when the crew is created. # automatically when the crew is created.
params["filters"] = self._create_filter_for_search() params["filters"] = self._create_filter_for_search()
params['threshold'] = score_threshold params["threshold"] = score_threshold
if isinstance(self.memory, Memory): if isinstance(self.memory, Memory):
del params["metadata"], params["version"], params['output_format'] del params["metadata"], params["version"], params["output_format"]
if params.get("run_id"): if params.get("run_id"):
del params["run_id"] del params["run_id"]
@@ -180,7 +190,7 @@ class Mem0Storage(Storage):
# This makes it compatible for Contextual Memory to retrieve # This makes it compatible for Contextual Memory to retrieve
for result in results["results"]: for result in results["results"]:
result["context"] = result["memory"] result["content"] = result["memory"]
return [r for r in results["results"]] return [r for r in results["results"]]
@@ -201,7 +211,9 @@ class Mem0Storage(Storage):
agents = self.crew.agents agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents] agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents) agents = "_".join(agents)
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0) return _sanitize_collection_name(
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
)
def _get_assistant_message(self, text: str) -> str: def _get_assistant_message(self, text: str) -> str:
marker = "Final Answer:" marker = "Final Answer:"
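The mem0-backed search keeps the same contract by copying each raw "memory" field into "content" before returning, as the hunk above shows; a standalone sketch of that normalization:

raw = {"results": [{"memory": "Brandon likes Mexican food.", "score": 0.8}]}
for result in raw["results"]:
    result["content"] = result["memory"]
print(raw["results"][0]["content"])  # Brandon likes Mexican food.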

View File

@@ -1,17 +1,16 @@
import logging import logging
import os import warnings
import shutil from typing import Any
import uuid
from typing import Any, Dict, List, Optional from crewai.rag.chromadb.config import ChromaDBConfig
from chromadb.api import ClientAPI from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator from crewai.rag.types import BaseRecord
from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
from crewai.utilities.logger_utils import suppress_logging
import warnings
class RAGStorage(BaseRAGStorage): class RAGStorage(BaseRAGStorage):
@@ -20,8 +19,6 @@ class RAGStorage(BaseRAGStorage):
search efficiency. search efficiency.
""" """
app: ClientAPI | None = None
def __init__( def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None self, type, allow_reset=True, embedder_config=None, crew=None, path=None
): ):
@@ -33,37 +30,25 @@ class RAGStorage(BaseRAGStorage):
self.storage_file_name = self._build_storage_file_name(type, agents) self.storage_file_name = self._build_storage_file_name(type, agents)
self.type = type self.type = type
self._client: BaseClient | None = None
self.allow_reset = allow_reset self.allow_reset = allow_reset
self.path = path self.path = path
self._initialize_app()
def _set_embedder_config(self):
configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self):
from chromadb.config import Settings
# Suppress deprecation warnings from chromadb, which are not relevant to us
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",
message=r".*'model_fields'.*is deprecated.*", message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)", module=r"^chromadb(\.|$)",
) )
self._set_embedder_config() if self.embedder_config:
embedding_function = get_embedding_function(self.embedder_config)
config = ChromaDBConfig(embedding_function=embedding_function)
self._client = create_client(config)
self.app = create_persistent_client( def _get_client(self) -> BaseClient:
path=self.path if self.path else self.storage_file_name, """Get the appropriate client - instance-specific or global."""
settings=Settings(allow_reset=self.allow_reset), return self._client if self._client else get_rag_client()
)
self.collection = self.app.get_or_create_collection(
name=self.type, embedding_function=self.embedder_config
)
logging.info(f"Collection found or created: {self.collection}")
def _sanitize_role(self, role: str) -> str: def _sanitize_role(self, role: str) -> str:
""" """
@@ -85,77 +70,65 @@ class RAGStorage(BaseRAGStorage):
return f"{base_path}/{file_name}" return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: Dict[str, Any]) -> None: def save(self, value: Any, metadata: dict[str, Any]) -> None:
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
try: try:
self._generate_embedding(value, metadata) client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
client.get_or_create_collection(collection_name=collection_name)
document: BaseRecord = {"content": value}
if metadata:
document["metadata"] = metadata
client.add_documents(collection_name=collection_name, documents=[document])
except Exception as e: except Exception as e:
logging.error(f"Error during {self.type} save: {str(e)}") logging.error(f"Error during {self.type} save: {e!s}")
def search( def search(
self, self,
query: str, query: str,
limit: int = 3, limit: int = 3,
filter: Optional[dict] = None, filter: dict[str, Any] | None = None,
score_threshold: float = 0.35, score_threshold: float = 0.35,
) -> List[Any]: ) -> list[Any]:
if not hasattr(self, "app"):
self._initialize_app()
try: try:
with suppress_logging( client = self._get_client()
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR collection_name = (
): f"memory_{self.type}_{self.agents}"
response = self.collection.query(query_texts=query, n_results=limit) if self.agents
else f"memory_{self.type}"
results = [] )
for i in range(len(response["ids"][0])): return client.search(
result = { collection_name=collection_name,
"id": response["ids"][0][i], query=query,
"metadata": response["metadatas"][0][i], limit=limit,
"context": response["documents"][0][i], metadata_filter=filter,
"score": response["distances"][0][i], score_threshold=score_threshold,
} )
if result["score"] >= score_threshold:
results.append(result)
return results
except Exception as e: except Exception as e:
logging.error(f"Error during {self.type} search: {str(e)}") logging.error(f"Error during {self.type} search: {e!s}")
return [] return []
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
self.collection.add(
documents=[text],
metadatas=[metadata or {}],
ids=[str(uuid.uuid4())],
)
def reset(self) -> None: def reset(self) -> None:
try: try:
if self.app: client = self._get_client()
self.app.reset() collection_name = (
shutil.rmtree(f"{db_storage_path()}/{self.type}") f"memory_{self.type}_{self.agents}"
self.app = None if self.agents
self.collection = None else f"memory_{self.type}"
)
client.delete_collection(collection_name=collection_name)
except Exception as e: except Exception as e:
if "attempt to write a readonly database" in str(e): if "attempt to write a readonly database" in str(
# Ignore this specific error e
) or "does not exist" in str(e):
# Ignore readonly database and collection not found errors (already reset)
pass pass
else: else:
raise Exception( raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}" f"An error occurred while resetting the {self.type} memory: {e}"
) ) from e
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
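Hedged usage sketch of the rewritten memory storage (collection names follow the memory_<type> or memory_<type>_<agents> pattern shown above; the crew-less constructor mirrors the existing defaults):

from crewai.memory.storage.rag_storage import RAGStorage

storage = RAGStorage(type="short_term", allow_reset=True, embedder_config=None, crew=None)
storage.save("The user asked about ChromaDB collections.", metadata={"task": "demo"})

for hit in storage.search("ChromaDB collections", limit=3, score_threshold=0.35):
    print(hit["content"], hit["score"])

storage.reset()  # deletes the collection; missing collections are now ignored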

View File

@@ -4,8 +4,9 @@ import logging
from typing import Any from typing import Any
from chromadb.api.types import ( from chromadb.api.types import (
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction, EmbeddingFunction as ChromaEmbeddingFunction,
)
from chromadb.api.types import (
QueryResult, QueryResult,
) )
from typing_extensions import Unpack from typing_extensions import Unpack
@@ -23,13 +24,13 @@ from crewai.rag.chromadb.utils import (
_process_query_results, _process_query_results,
_sanitize_collection_name, _sanitize_collection_name,
) )
from crewai.utilities.logger_utils import suppress_logging
from crewai.rag.core.base_client import ( from crewai.rag.core.base_client import (
BaseClient, BaseClient,
BaseCollectionParams,
BaseCollectionAddParams, BaseCollectionAddParams,
BaseCollectionParams,
) )
from crewai.rag.types import SearchResult from crewai.rag.types import SearchResult
from crewai.utilities.logger_utils import suppress_logging
class ChromaDBClient(BaseClient): class ChromaDBClient(BaseClient):
@@ -46,7 +47,7 @@ class ChromaDBClient(BaseClient):
def __init__( def __init__(
self, self,
client: ChromaDBClientType, client: ChromaDBClientType,
embedding_function: ChromaEmbeddingFunction[Embeddable], embedding_function: ChromaEmbeddingFunction,
) -> None: ) -> None:
"""Initialize ChromaDBClient with client and embedding function. """Initialize ChromaDBClient with client and embedding function.
@@ -306,10 +307,12 @@ class ChromaDBClient(BaseClient):
) )
prepared = _prepare_documents_for_chromadb(documents) prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
collection.upsert( collection.upsert(
ids=prepared.ids, ids=prepared.ids,
documents=prepared.texts, documents=prepared.texts,
metadatas=prepared.metadatas, metadatas=metadatas,
) )
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
@@ -347,10 +350,12 @@ class ChromaDBClient(BaseClient):
embedding_function=self.embedding_function, embedding_function=self.embedding_function,
) )
prepared = _prepare_documents_for_chromadb(documents) prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
await collection.upsert( await collection.upsert(
ids=prepared.ids, ids=prepared.ids,
documents=prepared.texts, documents=prepared.texts,
metadatas=prepared.metadatas, metadatas=metadatas,
) )
def search( def search(
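Why the empty-metadata guard matters (a standalone sketch of the condition added above): ChromaDB doesn't accept empty metadata dicts, so when every prepared metadata is empty the client now passes None instead.

prepared_metadatas = [{}, {}]  # e.g. two documents saved without metadata
metadatas = prepared_metadatas if any(m for m in prepared_metadatas) else None
assert metadatas is None  # upsert receives metadatas=None rather than [{}, {}]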

View File

@@ -3,18 +3,18 @@
import warnings import warnings
from dataclasses import field from dataclasses import field
from typing import Literal, cast from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from chromadb.config import Settings from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.chromadb.constants import ( from crewai.rag.chromadb.constants import (
DEFAULT_TENANT,
DEFAULT_DATABASE, DEFAULT_DATABASE,
DEFAULT_STORAGE_PATH, DEFAULT_STORAGE_PATH,
DEFAULT_TENANT,
) )
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",

View File

@@ -2,11 +2,12 @@
import os import os
from hashlib import md5 from hashlib import md5
import portalocker import portalocker
from chromadb import PersistentClient from chromadb import PersistentClient
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.client import ChromaDBClient from crewai.rag.chromadb.client import ChromaDBClient
from crewai.rag.chromadb.config import ChromaDBConfig
def create_client(config: ChromaDBConfig) -> ChromaDBClient: def create_client(config: ChromaDBConfig) -> ChromaDBClient:
@@ -23,6 +24,7 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
""" """
persist_dir = config.settings.persist_directory persist_dir = config.settings.persist_directory
os.makedirs(persist_dir, exist_ok=True)
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest() lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock") lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
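Sketch of the factory change, assuming ChromaDBConfig's defaults (default embedding function and storage path): create_client now creates the persist directory before placing the lock file inside it, so pointing a fresh config at a non-existent path no longer fails.

from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.factory import create_client

config = ChromaDBConfig()  # assumed to fall back to its default settings
client = create_client(config)  # persist directory is created if missing
client.get_or_create_collection(collection_name="knowledge_demo")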

View File

@@ -3,27 +3,28 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, NamedTuple from typing import Any, NamedTuple
from pydantic import GetCoreSchemaHandler from chromadb.api import AsyncClientAPI, ClientAPI
from pydantic_core import CoreSchema, core_schema
from chromadb.api import ClientAPI, AsyncClientAPI
from chromadb.api.configuration import CollectionConfigurationInterface from chromadb.api.configuration import CollectionConfigurationInterface
from chromadb.api.types import ( from chromadb.api.types import (
CollectionMetadata, CollectionMetadata,
DataLoader, DataLoader,
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction,
Include, Include,
Loadable, Loadable,
Where, Where,
WhereDocument, WhereDocument,
) )
from chromadb.api.types import (
EmbeddingFunction as ChromaEmbeddingFunction,
)
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
ChromaDBClientType = ClientAPI | AsyncClientAPI ChromaDBClientType = ClientAPI | AsyncClientAPI
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]): class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation.""" """Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
@classmethod @classmethod
@@ -44,7 +45,7 @@ class PreparedDocuments(NamedTuple):
Attributes: Attributes:
ids: List of document IDs ids: List of document IDs
texts: List of document texts texts: List of document texts
metadatas: List of document metadata mappings metadatas: List of document metadata mappings (empty dict for no metadata)
""" """
ids: list[str] ids: list[str]
@@ -85,7 +86,7 @@ class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
configuration: CollectionConfigurationInterface configuration: CollectionConfigurationInterface
metadata: CollectionMetadata metadata: CollectionMetadata
embedding_function: ChromaEmbeddingFunction[Embeddable] embedding_function: ChromaEmbeddingFunction
data_loader: DataLoader[Loadable] data_loader: DataLoader[Loadable]
get_or_create: bool get_or_create: bool

View File

@@ -5,13 +5,14 @@ from collections.abc import Mapping
from typing import Literal, TypeGuard, cast from typing import Literal, TypeGuard, cast
from chromadb.api import AsyncClientAPI, ClientAPI from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from chromadb.api.types import ( from chromadb.api.types import (
Include, Include,
IncludeEnum, IncludeEnum,
QueryResult, QueryResult,
) )
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from crewai.rag.chromadb.constants import ( from crewai.rag.chromadb.constants import (
DEFAULT_COLLECTION, DEFAULT_COLLECTION,
INVALID_CHARS_PATTERN, INVALID_CHARS_PATTERN,
@@ -78,7 +79,7 @@ def _prepare_documents_for_chromadb(
metadata = doc.get("metadata") metadata = doc.get("metadata")
if metadata: if metadata:
if isinstance(metadata, list): if isinstance(metadata, list):
metadatas.append(metadata[0] if metadata else {}) metadatas.append(metadata[0] if metadata and metadata[0] else {})
else: else:
metadatas.append(metadata) metadatas.append(metadata)
else: else:
@@ -154,7 +155,7 @@ def _convert_chromadb_results_to_search_results(
""" """
search_results: list[SearchResult] = [] search_results: list[SearchResult] = []
include_strings = [item.value for item in include] include_strings = [item.value for item in include] if include else []
ids = results["ids"][0] if results.get("ids") else [] ids = results["ids"][0] if results.get("ids") else []
@@ -188,7 +189,9 @@ def _convert_chromadb_results_to_search_results(
result: SearchResult = { result: SearchResult = {
"id": doc_id, "id": doc_id,
"content": documents[i] if documents and i < len(documents) else "", "content": documents[i] if documents and i < len(documents) else "",
"metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {}, "metadata": dict(metadatas[i])
if metadatas and i < len(metadatas) and metadatas[i] is not None
else {},
"score": score, "score": score,
} }
search_results.append(result) search_results.append(result)
@@ -271,7 +274,7 @@ def _sanitize_collection_name(
sanitized = sanitized[:-1] + "z" sanitized = sanitized[:-1] + "z"
if len(sanitized) < MIN_COLLECTION_LENGTH: if len(sanitized) < MIN_COLLECTION_LENGTH:
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized)) sanitized += "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
if len(sanitized) > max_collection_length: if len(sanitized) > max_collection_length:
sanitized = sanitized[:max_collection_length] sanitized = sanitized[:max_collection_length]
if not sanitized[-1].isalnum(): if not sanitized[-1].isalnum():
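Illustrative calls to the private sanitizer (the rules carry over from the removed crewai.utilities.chromadb helper shown further down: 3-63 characters, alphanumeric start and end, invalid characters replaced with underscores):

from crewai.rag.chromadb.utils import _sanitize_collection_name

print(_sanitize_collection_name(name="my collection!"))  # spaces/punctuation replaced, trailing char fixed
print(_sanitize_collection_name(name="ab"))              # padded with "x" up to the minimum length
print(_sanitize_collection_name(name="a" * 100))         # truncated to the maximum length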

View File

@@ -1,15 +1,15 @@
"""Protocol for vector database client implementations.""" """Protocol for vector database client implementations."""
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable, Annotated from typing import Annotated, Any, Protocol, runtime_checkable
from typing_extensions import Unpack, Required, TypedDict
from pydantic import GetCoreSchemaHandler from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema from pydantic_core import CoreSchema, core_schema
from typing_extensions import Required, TypedDict, Unpack
from crewai.rag.types import ( from crewai.rag.types import (
EmbeddingFunction,
BaseRecord, BaseRecord,
EmbeddingFunction,
SearchResult, SearchResult,
) )
@@ -57,7 +57,7 @@ class BaseCollectionSearchParams(BaseCollectionParams, total=False):
query: Required[str] query: Required[str]
limit: int limit: int
metadata_filter: dict[str, Any] metadata_filter: dict[str, Any] | None
score_threshold: float score_threshold: float

View File

@@ -10,8 +10,8 @@ from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction, CohereEmbeddingFunction,
) )
from chromadb.utils.embedding_functions.google_embedding_function import ( from chromadb.utils.embedding_functions.google_embedding_function import (
GooglePalmEmbeddingFunction,
GoogleGenerativeAiEmbeddingFunction, GoogleGenerativeAiEmbeddingFunction,
GooglePalmEmbeddingFunction,
GoogleVertexEmbeddingFunction, GoogleVertexEmbeddingFunction,
) )
from chromadb.utils.embedding_functions.huggingface_embedding_function import ( from chromadb.utils.embedding_functions.huggingface_embedding_function import (
@@ -60,7 +60,7 @@ def get_embedding_function(
EmbeddingFunction instance ready for use with ChromaDB EmbeddingFunction instance ready for use with ChromaDB
Supported providers: Supported providers:
- openai: OpenAI embeddings (default) - openai: OpenAI embeddings
- cohere: Cohere embeddings - cohere: Cohere embeddings
- ollama: Ollama local embeddings - ollama: Ollama local embeddings
- huggingface: HuggingFace embeddings - huggingface: HuggingFace embeddings
@@ -77,7 +77,7 @@ def get_embedding_function(
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB) - onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
Examples: Examples:
# Use default OpenAI with retry logic # Use default OpenAI embedding
>>> embedder = get_embedding_function() >>> embedder = get_embedding_function()
# Use Cohere with dict # Use Cohere with dict
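Sketch of the embedding factory with openai restored as the default (the provider/config dict shape below is an assumption based on the docstring examples, not something this diff shows explicitly):

from crewai.rag.embeddings.factory import get_embedding_function

default_fn = get_embedding_function()  # OpenAI embeddings; needs OPENAI_API_KEY
ollama_fn = get_embedding_function(
    {"provider": "ollama", "config": {"model": "nomic-embed-text"}}  # assumed dict shape
)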

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional from typing import Any
class BaseRAGStorage(ABC): class BaseRAGStorage(ABC):
@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
self, self,
type: str, type: str,
allow_reset: bool = True, allow_reset: bool = True,
embedder_config: Optional[Dict[str, Any]] = None, embedder_config: dict[str, Any] | None = None,
crew: Any = None, crew: Any = None,
): ):
self.type = type self.type = type
@@ -32,45 +32,21 @@ class BaseRAGStorage(ABC):
@abstractmethod @abstractmethod
def _sanitize_role(self, role: str) -> str: def _sanitize_role(self, role: str) -> str:
"""Sanitizes agent roles to ensure valid directory names.""" """Sanitizes agent roles to ensure valid directory names."""
pass
@abstractmethod @abstractmethod
def save(self, value: Any, metadata: Dict[str, Any]) -> None: def save(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value with metadata to the storage.""" """Save a value with metadata to the storage."""
pass
@abstractmethod @abstractmethod
def search( def search(
self, self,
query: str, query: str,
limit: int = 3, limit: int = 3,
filter: Optional[dict] = None, filter: dict[str, Any] | None = None,
score_threshold: float = 0.35, score_threshold: float = 0.35,
) -> List[Any]: ) -> list[Any]:
"""Search for entries in the storage.""" """Search for entries in the storage."""
pass
@abstractmethod @abstractmethod
def reset(self) -> None: def reset(self) -> None:
"""Reset the storage.""" """Reset the storage."""
pass
@abstractmethod
def _generate_embedding(
self, text: str, metadata: Optional[Dict[str, Any]] = None
) -> Any:
"""Generate an embedding for the given text and metadata."""
pass
@abstractmethod
def _initialize_app(self):
"""Initialize the vector db."""
pass
def setup_config(self, config: Dict[str, Any]):
"""Setup the config of the storage."""
pass
def initialize_client(self):
"""Initialize the client of the storage. This should setup the app and the db collection"""
pass

View File

@@ -1,83 +0,0 @@
import os
import re
import portalocker
from chromadb import PersistentClient
from hashlib import md5
from typing import Optional
from crewai.utilities.paths import db_storage_path
MIN_COLLECTION_LENGTH = 3
MAX_COLLECTION_LENGTH = 63
DEFAULT_COLLECTION = "default_collection"
# Compiled regex patterns for better performance
INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
def is_ipv4_pattern(name: str) -> bool:
"""
Check if a string matches an IPv4 address pattern.
Args:
name: The string to check
Returns:
True if the string matches an IPv4 pattern, False otherwise
"""
return bool(IPV4_PATTERN.match(name))
def sanitize_collection_name(
name: Optional[str], max_collection_length: int = MAX_COLLECTION_LENGTH
) -> str:
"""
Sanitize a collection name to meet ChromaDB requirements:
1. 3-63 characters long
2. Starts and ends with alphanumeric character
3. Contains only alphanumeric characters, underscores, or hyphens
4. No consecutive periods
5. Not a valid IPv4 address
Args:
name: The original collection name to sanitize
Returns:
A sanitized collection name that meets ChromaDB requirements
"""
if not name:
return DEFAULT_COLLECTION
if is_ipv4_pattern(name):
name = f"ip_{name}"
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
if not sanitized[0].isalnum():
sanitized = "a" + sanitized
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
if len(sanitized) < MIN_COLLECTION_LENGTH:
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
if len(sanitized) > max_collection_length:
sanitized = sanitized[:max_collection_length]
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
return sanitized
def create_persistent_client(path: str, **kwargs):
"""
Creates a persistent client for ChromaDB with a lock file to prevent
concurrent creations. Works for both multi-threads and multi-processes
environments.
"""
lock_id = md5(path.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(db_storage_path(), f"chromadb-{lock_id}.lock")
with portalocker.Lock(lockfile):
client = PersistentClient(path=path, **kwargs)
return client

View File

@@ -9,19 +9,19 @@ import pytest
from crewai import Agent, Crew, Task from crewai import Agent, Crew, Task
from crewai.agents.cache import CacheHandler from crewai.agents.cache import CacheHandler
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.llm import LLM from crewai.llm import LLM
from crewai.process import Process
from crewai.tools import tool from crewai.tools import tool
from crewai.tools.tool_calling import InstructorToolCalling from crewai.tools.tool_calling import InstructorToolCalling
from crewai.tools.tool_usage import ToolUsage from crewai.tools.tool_usage import ToolUsage
from crewai.utilities import RPMController from crewai.utilities import RPMController
from crewai.utilities.errors import AgentRepositoryError from crewai.utilities.errors import AgentRepositoryError
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.process import Process
def test_agent_llm_creation_with_env_vars(): def test_agent_llm_creation_with_env_vars():
@@ -445,7 +445,7 @@ def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_powered_by_new_o_model_family_that_uses_tool(): def test_agent_powered_by_new_o_model_family_that_uses_tool():
@tool @tool
def comapny_customer_data() -> float: def comapny_customer_data() -> str:
"""Useful for getting customer related data.""" """Useful for getting customer related data."""
return "The company has 42 customers" return "The company has 42 customers"
@@ -559,9 +559,9 @@ def test_agent_repeated_tool_usage(capsys):
expected_message = ( expected_message = (
"I tried reusing the same input, I must stop using this action input." "I tried reusing the same input, I must stop using this action input."
) )
assert ( assert expected_message in output, (
expected_message in output f"Expected message not found in output. Output was: {output}"
), f"Expected message not found in output. Output was: {output}" )
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -602,9 +602,9 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
has_max_iterations = "maximum iterations reached" in output_lower has_max_iterations = "maximum iterations reached" in output_lower
has_final_answer = "final answer" in output_lower or "42" in captured.out has_final_answer = "final answer" in output_lower or "42" in captured.out
assert ( assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
has_repeated_usage_message or (has_max_iterations and has_final_answer) f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
), f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..." )
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -880,7 +880,7 @@ def test_agent_step_callback():
with patch.object(StepCallback, "callback") as callback: with patch.object(StepCallback, "callback") as callback:
@tool @tool
def learn_about_AI() -> str: def learn_about_ai() -> str:
"""Useful for when you need to learn about AI to write an paragraph about it.""" """Useful for when you need to learn about AI to write an paragraph about it."""
return "AI is a very broad field." return "AI is a very broad field."
@@ -888,7 +888,7 @@ def test_agent_step_callback():
role="test role", role="test role",
goal="test goal", goal="test goal",
backstory="test backstory", backstory="test backstory",
tools=[learn_about_AI], tools=[learn_about_ai],
step_callback=StepCallback().callback, step_callback=StepCallback().callback,
) )
@@ -910,7 +910,7 @@ def test_agent_function_calling_llm():
llm = "gpt-4o" llm = "gpt-4o"
@tool @tool
def learn_about_AI() -> str: def learn_about_ai() -> str:
"""Useful for when you need to learn about AI to write an paragraph about it.""" """Useful for when you need to learn about AI to write an paragraph about it."""
return "AI is a very broad field." return "AI is a very broad field."
@@ -918,7 +918,7 @@ def test_agent_function_calling_llm():
role="test role", role="test role",
goal="test goal", goal="test goal",
backstory="test backstory", backstory="test backstory",
tools=[learn_about_AI], tools=[learn_about_ai],
llm="gpt-4o", llm="gpt-4o",
max_iter=2, max_iter=2,
function_calling_llm=llm, function_calling_llm=llm,
@@ -1356,7 +1356,7 @@ def test_agent_training_handler(crew_training_handler):
verbose=True, verbose=True,
) )
crew_training_handler().load.return_value = { crew_training_handler().load.return_value = {
f"{str(agent.id)}": {"0": {"human_feedback": "good"}} f"{agent.id!s}": {"0": {"human_feedback": "good"}}
} }
result = agent._training_handler(task_prompt=task_prompt) result = agent._training_handler(task_prompt=task_prompt)
@@ -1473,7 +1473,7 @@ def test_agent_with_custom_stop_words():
) )
assert isinstance(agent.llm, LLM) assert isinstance(agent.llm, LLM)
assert set(agent.llm.stop) == set(stop_words + ["\nObservation:"]) assert set(agent.llm.stop) == set([*stop_words, "\nObservation:"])
assert all(word in agent.llm.stop for word in stop_words) assert all(word in agent.llm.stop for word in stop_words)
assert "\nObservation:" in agent.llm.stop assert "\nObservation:" in agent.llm.stop
@@ -1530,7 +1530,7 @@ def test_llm_call_with_error():
llm = LLM(model="non-existent-model") llm = LLM(model="non-existent-model")
messages = [{"role": "user", "content": "This should fail"}] messages = [{"role": "user", "content": "This should fail"}]
with pytest.raises(Exception): with pytest.raises(Exception): # noqa: B017
llm.call(messages) llm.call(messages)
@@ -1830,11 +1830,11 @@ def test_agent_execute_task_with_ollama():
def test_agent_with_knowledge_sources(): def test_agent_with_knowledge_sources():
content = "Brandon's favorite color is red and he likes Mexican food." content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content) string_source = StringKnowledgeSource(content=content)
with patch("crewai.knowledge") as MockKnowledge: with patch("crewai.knowledge") as mock_knowledge:
mock_knowledge_instance = MockKnowledge.return_value mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source] mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.search.return_value = [{"content": content}] mock_knowledge_instance.search.return_value = [{"content": content}]
MockKnowledge.add_sources.return_value = [string_source] mock_knowledge.add_sources.return_value = [string_source]
agent = Agent( agent = Agent(
role="Information Agent", role="Information Agent",
@@ -1863,12 +1863,25 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
content = "Brandon's favorite color is red and he likes Mexican food." content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content) string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig(results_limit=10, score_threshold=0.5) knowledge_config = KnowledgeConfig(results_limit=10, score_threshold=0.5)
with patch( with (
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" patch(
) as MockKnowledge: "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
mock_knowledge_instance = MockKnowledge.return_value ) as mock_knowledge_storage,
mock_knowledge_instance.sources = [string_source] patch(
mock_knowledge_instance.query.return_value = [{"content": content}] "crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
with patch.object(Knowledge, "query") as mock_knowledge_query: with patch.object(Knowledge, "query") as mock_knowledge_query:
agent = Agent( agent = Agent(
role="Information Agent", role="Information Agent",
@@ -1898,15 +1911,27 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_defau
content = "Brandon's favorite color is red and he likes Mexican food." content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content) string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig() knowledge_config = KnowledgeConfig()
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" with (
) as MockKnowledge: patch(
mock_knowledge_instance = MockKnowledge.return_value "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
mock_knowledge_instance.sources = [string_source] ) as mock_knowledge_storage,
mock_knowledge_instance.query.return_value = [{"content": content}] patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
with patch.object(Knowledge, "query") as mock_knowledge_query: with patch.object(Knowledge, "query") as mock_knowledge_query:
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig()
agent = Agent( agent = Agent(
role="Information Agent", role="Information Agent",
goal="Provide information based on knowledge sources", goal="Provide information based on knowledge sources",
@@ -1935,10 +1960,16 @@ def test_agent_with_knowledge_sources_extensive_role():
content = "Brandon's favorite color is red and he likes Mexican food." content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content) string_source = StringKnowledgeSource(content=content)
with patch("crewai.knowledge") as MockKnowledge: with (
mock_knowledge_instance = MockKnowledge.return_value patch("crewai.knowledge") as mock_knowledge,
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage.save"
) as mock_save,
):
mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source] mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}] mock_knowledge_instance.query.return_value = [{"content": content}]
mock_save.return_value = None
agent = Agent( agent = Agent(
role="Information Agent with extensive role description that is longer than 80 characters", role="Information Agent with extensive role description that is longer than 80 characters",
@@ -1968,8 +1999,8 @@ def test_agent_with_knowledge_sources_works_with_copy():
with patch(
"crewai.knowledge.source.base_knowledge_source.BaseKnowledgeSource",
autospec=True,
) as mock_knowledge_source:
mock_knowledge_source_instance = mock_knowledge_source.return_value
mock_knowledge_source_instance.__class__ = BaseKnowledgeSource
mock_knowledge_source_instance.sources = [string_source]
@@ -1983,9 +2014,9 @@ def test_agent_with_knowledge_sources_works_with_copy():
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage:
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
agent.knowledge_storage = mock_knowledge_storage_instance
agent_copy = agent.copy()
@@ -2004,11 +2035,30 @@ def test_agent_with_knowledge_sources_generate_search_query():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
with (
patch("crewai.knowledge") as mock_knowledge,
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
agent = Agent(
role="Information Agent with extensive role description that is longer than 80 characters",
goal="Provide information based on knowledge sources",
@@ -2270,7 +2320,26 @@ def test_get_knowledge_search_query():
i18n = I18N()
task_prompt = task.prompt()
with (
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
patch.object(agent, "_get_knowledge_search_query") as mock_get_query,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
mock_get_query.return_value = "Capital of France"
crew = Crew(agents=[agent], tasks=[task])
@@ -2312,9 +2381,9 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
# Mock embedchain initialization to prevent race conditions in parallel CI execution
with patch("embedchain.client.Client.setup"):
from crewai_tools import (
EnterpriseActionTool,
FileReadTool,
SerperDevTool,
)
mock_get_response = MagicMock()
@@ -2347,7 +2416,7 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
tool_action = EnterpriseActionTool(
name="test_name",
description="test_description",
enterprise_action_token="test_token",  # noqa: S106
action_name="test_action_name",
action_schema={"test": "test"},
)
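The same three-way patch (KnowledgeStorage, the storage class imported by base_knowledge_source, and ChromaDBClient) recurs across these agent tests. If the pattern keeps spreading it could be captured once in a fixture roughly like the one below; this is a sketch with hypothetical names, not something the change itself adds.

import pytest
from unittest.mock import patch


@pytest.fixture
def mocked_knowledge_backend():
    """Bundle the KnowledgeStorage / ChromaDBClient patches repeated in the tests above."""
    with (
        patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage") as storage_cls,
        patch("crewai.knowledge.source.base_knowledge_source.KnowledgeStorage") as base_storage_cls,
        patch("crewai.rag.chromadb.client.ChromaDBClient") as chromadb_cls,
    ):
        storage = storage_cls.return_value
        storage.save.return_value = None
        chromadb_cls.return_value.add_documents.return_value = None
        base_storage_cls.return_value = storage
        yield storage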
View File
@@ -1,7 +1,6 @@
"""Test Knowledge creation and querying functionality.""" """Test Knowledge creation and querying functionality."""
from pathlib import Path from pathlib import Path
from typing import List, Union
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@@ -23,7 +22,7 @@ def mock_vector_db():
instance = mock.return_value
instance.query.return_value = [
{
"content": "Brandon's favorite color is blue and he likes Mexican food.",
"score": 0.9,
}
]
@@ -44,13 +43,13 @@ def test_single_short_string(mock_vector_db):
content=content, metadata={"preference": "personal"}
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite color?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("blue" in result["content"].lower() for result in results)
# Verify the mock was called
mock_vector_db.query.assert_called_once()
@@ -84,14 +83,14 @@ def test_single_2k_character_string(mock_vector_db):
content=content, metadata={"preference": "personal"}
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite movie?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("inception" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -109,7 +108,7 @@ def test_multiple_short_strings(mock_vector_db):
# Mock the vector db query response
mock_vector_db.query.return_value = [
{"content": "Brandon has a dog named Max.", "score": 0.9}
]
mock_vector_db.sources = string_sources
@@ -119,7 +118,7 @@ def test_multiple_short_strings(mock_vector_db):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("max" in result["content"].lower() for result in results)
# Verify the mock was called
mock_vector_db.query.assert_called_once()
@@ -180,7 +179,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
]
mock_vector_db.sources = string_sources
mock_vector_db.query.return_value = [{"content": contents[1], "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite book?"
@@ -188,7 +187,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
# Assert that the correct information is retrieved
assert any(
"the hitchhiker's guide to the galaxy" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -205,13 +204,13 @@ def test_single_short_file(mock_vector_db, tmpdir):
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What sport does Brandon like?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("basketball" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -247,13 +246,13 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite movie?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("inception" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -286,13 +285,13 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
]
mock_vector_db.sources = file_sources
mock_vector_db.query.return_value = [
{"content": "Brandon lives in New York.", "score": 0.9}
]
# Perform a query
query = "What city does he reside in?"
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("new york" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -360,7 +359,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
mock_vector_db.sources = file_sources
mock_vector_db.query.return_value = [
{
"content": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
"score": 0.9,
}
]
@@ -370,7 +369,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
# Assert that the correct information is retrieved
assert any(
"the hitchhiker's guide to the galaxy" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -407,14 +406,14 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
# Combine string and file sources
mock_vector_db.sources = string_sources + file_sources
mock_vector_db.query.return_value = [{"content": file_contents[1], "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite book?"
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("the alchemist" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -430,7 +429,7 @@ def test_pdf_knowledge_source(mock_vector_db):
)
mock_vector_db.sources = [pdf_source]
mock_vector_db.query.return_value = [
{"content": "crewai create crew latest-ai-development", "score": 0.9}
]
# Perform a query
@@ -439,7 +438,7 @@ def test_pdf_knowledge_source(mock_vector_db):
# Assert that the correct information is retrieved
assert any(
"crewai create crew latest-ai-development" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -467,7 +466,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [csv_source]
mock_vector_db.query.return_value = [
{"content": "Brandon is 30 years old.", "score": 0.9}
]
# Perform a query
@@ -475,7 +474,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("30" in result["content"] for result in results)
mock_vector_db.query.assert_called_once()
@@ -502,7 +501,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [json_source]
mock_vector_db.query.return_value = [
{"content": "Alice lives in Los Angeles.", "score": 0.9}
]
# Perform a query
@@ -510,7 +509,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("los angeles" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -518,7 +517,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
"""Test ExcelKnowledgeSource with a simple Excel file."""
# Create an Excel file with sample data
import pandas as pd  # type: ignore[import-untyped]
excel_data = {
"Name": ["Brandon", "Alice", "Bob"],
@@ -535,7 +534,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [excel_source]
mock_vector_db.query.return_value = [
{"content": "Brandon is 30 years old.", "score": 0.9}
]
# Perform a query
@@ -543,7 +542,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("30" in result["content"] for result in results)
mock_vector_db.query.assert_called_once()
@@ -557,20 +556,20 @@ def test_docling_source(mock_vector_db):
mock_vector_db.sources = [docling_source]
mock_vector_db.query.return_value = [
{
"content": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
"score": 0.9,
}
]
# Perform a query
query = "What is reward hacking?"
results = mock_vector_db.query(query)
assert any("reward hacking" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@pytest.mark.vcr
def test_multiple_docling_sources() -> None:
urls: list[Path | str] = [
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
"https://lilianweng.github.io/posts/2024-07-07-hallucination/",
]
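The context-to-content key rename running through this file matches the SearchResult shape the storage layer now returns. A plausible minimal definition, assuming a TypedDict with content, score, and metadata keys; the real type lives in the rag module and may differ.

from typing import Any, TypedDict


class SearchResult(TypedDict, total=False):
    """Result shape implied by the assertions above."""

    content: str  # text returned by the vector store
    score: float  # similarity score used for threshold filtering
    metadata: dict[str, Any]  # source details such as {"source": "doc1"}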
View File
@@ -0,0 +1,191 @@
"""Tests for Knowledge SearchResult type conversion and integration."""
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.knowledge import Knowledge # type: ignore[import-untyped]
from crewai.knowledge.source.string_knowledge_source import ( # type: ignore[import-untyped]
StringKnowledgeSource,
)
from crewai.knowledge.utils.knowledge_utils import ( # type: ignore[import-untyped]
extract_knowledge_context,
)
def test_knowledge_query_returns_searchresult() -> None:
"""Test that Knowledge.query returns SearchResult format."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.return_value = [
{
"content": "AI is fascinating",
"score": 0.9,
"metadata": {"source": "doc1"},
},
{
"content": "Machine learning rocks",
"score": 0.8,
"metadata": {"source": "doc2"},
},
]
sources = [StringKnowledgeSource(content="Test knowledge content")]
knowledge = Knowledge(collection_name="test_collection", sources=sources)
results = knowledge.query(
["AI technology"], results_limit=5, score_threshold=0.3
)
mock_storage.search.assert_called_once_with(
["AI technology"], limit=5, score_threshold=0.3
)
assert isinstance(results, list)
assert len(results) == 2
for result in results:
assert isinstance(result, dict)
assert "content" in result
assert "score" in result
assert "metadata" in result
assert results[0]["content"] == "AI is fascinating"
assert results[0]["score"] == 0.9
assert results[1]["content"] == "Machine learning rocks"
assert results[1]["score"] == 0.8
def test_knowledge_query_with_empty_results() -> None:
"""Test Knowledge.query with empty search results."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.return_value = []
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="empty_test", sources=sources)
results = knowledge.query(["nonexistent query"])
assert isinstance(results, list)
assert len(results) == 0
def test_extract_knowledge_context_with_searchresult() -> None:
"""Test extract_knowledge_context works with SearchResult format."""
search_results = [
{"content": "Python is great for AI", "score": 0.95, "metadata": {}},
{"content": "Machine learning algorithms", "score": 0.88, "metadata": {}},
{"content": "Deep learning frameworks", "score": 0.82, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert "Additional Information:" in context
assert "Python is great for AI" in context
assert "Machine learning algorithms" in context
assert "Deep learning frameworks" in context
expected_content = (
"Python is great for AI\nMachine learning algorithms\nDeep learning frameworks"
)
assert expected_content in context
def test_extract_knowledge_context_with_empty_content() -> None:
"""Test extract_knowledge_context handles empty or invalid content."""
search_results = [
{"content": "", "score": 0.5, "metadata": {}},
{"content": None, "score": 0.4, "metadata": {}},
{"score": 0.3, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert context == ""
def test_extract_knowledge_context_filters_invalid_results() -> None:
"""Test that extract_knowledge_context filters out invalid results."""
search_results: list[dict[str, Any] | None] = [
{"content": "Valid content 1", "score": 0.9, "metadata": {}},
{"content": "", "score": 0.8, "metadata": {}},
{"content": "Valid content 2", "score": 0.7, "metadata": {}},
None,
{"content": None, "score": 0.6, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert "Additional Information:" in context
assert "Valid content 1" in context
assert "Valid content 2" in context
assert context.count("\n") == 1
@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_storage_exception_handling(
mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test Knowledge handles storage exceptions gracefully."""
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.side_effect = Exception("Storage error")
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="error_test", sources=sources)
with pytest.raises(ValueError, match="Storage is not initialized"):
knowledge.storage = None
knowledge.query(["test query"])
def test_knowledge_add_sources_integration() -> None:
"""Test Knowledge.add_sources integrates properly with storage."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
sources = [
StringKnowledgeSource(content="Content 1"),
StringKnowledgeSource(content="Content 2"),
]
knowledge = Knowledge(collection_name="add_sources_test", sources=sources)
knowledge.add_sources()
for source in sources:
assert source.storage == mock_storage
def test_knowledge_reset_integration() -> None:
"""Test Knowledge.reset integrates with storage."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="reset_test", sources=sources)
knowledge.reset()
mock_storage.reset.assert_called_once()
@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_reset_without_storage(
mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test Knowledge.reset raises error when storage is None."""
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="no_storage_test", sources=sources)
knowledge.storage = None
with pytest.raises(ValueError, match="Storage is not initialized"):
knowledge.reset()
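The assertions above pin the joining behavior down fairly tightly. A minimal sketch that matches them, assuming the helper only filters out falsy content and prefixes the block; the shipped knowledge_utils implementation may differ in details.

from typing import Any


def extract_knowledge_context_sketch(results: list[dict[str, Any] | None]) -> str:
    """Join non-empty result contents under an 'Additional Information:' header."""
    contents = [
        r["content"]
        for r in results
        if r and isinstance(r.get("content"), str) and r["content"]
    ]
    if not contents:
        return ""  # empty or invalid results yield an empty context, per the tests
    return "Additional Information: " + "\n".join(contents)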
View File
@@ -0,0 +1,196 @@
"""Integration tests for KnowledgeStorage RAG client migration."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
KnowledgeStorage,
)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.create_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_knowledge_storage_uses_rag_client(
mock_get_embedding: MagicMock,
mock_create_client: MagicMock,
mock_get_client: MagicMock,
) -> None:
"""Test that KnowledgeStorage properly integrates with RAG client."""
mock_client = MagicMock()
mock_create_client.return_value = mock_client
mock_get_client.return_value = mock_client
mock_client.search.return_value = [
{"content": "test content", "score": 0.9, "metadata": {"source": "test"}}
]
embedder_config = {"provider": "openai", "model": "text-embedding-3-small"}
storage = KnowledgeStorage(
embedder=embedder_config, collection_name="test_knowledge"
)
mock_create_client.assert_called_once()
results = storage.search(["test query"], limit=5, score_threshold=0.3)
mock_get_client.assert_not_called()
mock_client.search.assert_called_once_with(
collection_name="knowledge_test_knowledge",
query="test query",
limit=5,
metadata_filter=None,
score_threshold=0.3,
)
assert isinstance(results, list)
assert len(results) == 1
assert isinstance(results[0], dict)
assert "content" in results[0]
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_collection_name_prefixing(mock_get_client: MagicMock) -> None:
"""Test that collection names are properly prefixed."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage(collection_name="custom_knowledge")
storage.search(["test"], limit=1)
mock_client.search.assert_called_once()
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["collection_name"] == "knowledge_custom_knowledge"
mock_client.reset_mock()
storage_default = KnowledgeStorage()
storage_default.search(["test"], limit=1)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["collection_name"] == "knowledge"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_save_documents_integration(mock_get_client: MagicMock) -> None:
"""Test document saving through RAG client."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
storage = KnowledgeStorage(collection_name="test_docs")
documents = ["Document 1 content", "Document 2 content"]
storage.save(documents)
mock_client.get_or_create_collection.assert_called_once_with(
collection_name="knowledge_test_docs"
)
mock_client.add_documents.assert_called_once()
call_kwargs = mock_client.add_documents.call_args.kwargs
added_docs = call_kwargs["documents"]
assert len(added_docs) == 2
assert added_docs[0]["content"] == "Document 1 content"
assert added_docs[1]["content"] == "Document 2 content"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_reset_integration(mock_get_client: MagicMock) -> None:
"""Test collection reset through RAG client."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
storage = KnowledgeStorage(collection_name="test_reset")
storage.reset()
mock_client.delete_collection.assert_called_once_with(
collection_name="knowledge_test_reset"
)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_search_error_handling(mock_get_client: MagicMock) -> None:
"""Test error handling during search operations."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = Exception("RAG client error")
storage = KnowledgeStorage(collection_name="error_test")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_embedding_configuration_flow(
mock_get_embedding: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test that embedding configuration flows properly to RAG client."""
mock_embedding_func = MagicMock()
mock_get_embedding.return_value = mock_embedding_func
mock_get_client.return_value = MagicMock()
embedder_config = {
"provider": "sentence-transformer",
"model_name": "all-MiniLM-L6-v2",
}
KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
mock_get_embedding.assert_called_once_with(embedder_config)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_query_list_conversion(mock_get_client: MagicMock) -> None:
"""Test that query list is properly converted to string."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage()
storage.search(["single query"])
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["query"] == "single query"
mock_client.reset_mock()
storage.search(["query one", "query two"])
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["query"] == "query one query two"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_metadata_filter_handling(mock_get_client: MagicMock) -> None:
"""Test metadata filter parameter handling."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage()
metadata_filter = {"category": "technical", "priority": "high"}
storage.search(["test"], metadata_filter=metadata_filter)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["metadata_filter"] == metadata_filter
mock_client.reset_mock()
storage.search(["test"], metadata_filter=None)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["metadata_filter"] is None
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None:
"""Test specific handling of dimension mismatch errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.get_or_create_collection.return_value = None
mock_client.add_documents.side_effect = Exception("dimension mismatch detected")
storage = KnowledgeStorage(collection_name="dimension_test")
with pytest.raises(ValueError, match="Embedding dimension mismatch"):
storage.save(["test document"])
View File
@@ -1,19 +1,20 @@
from collections import defaultdict
from unittest.mock import ANY, patch
import pytest
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.task import Task
@pytest.fixture
@@ -38,22 +39,23 @@ def short_term_memory():
def test_short_term_memory_search_events(short_term_memory):
events = defaultdict(list)
with patch("crewai.rag.chromadb.client.ChromaDBClient.search", return_value=[]):
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
# Call the save method
short_term_memory.search(
query="test value",
limit=3,
score_threshold=0.35,
)
assert len(events["MemoryQueryStartedEvent"]) == 1
assert len(events["MemoryQueryCompletedEvent"]) == 1
@@ -173,12 +175,12 @@ def test_save_and_search(short_term_memory):
expected_result = [
{
"content": memory.data,
"metadata": {"agent": "test_agent"},
"score": 0.95,
}
]
with patch.object(ShortTermMemory, "search", return_value=expected_result):
find = short_term_memory.search("test value", score_threshold=0.01)[0]
assert find["content"] == memory.data, "Data value mismatch."
assert find["metadata"]["agent"] == "test_agent", "Agent value mismatch."
View File
@@ -285,6 +285,43 @@ class TestChromaDBClient:
metadatas=[{"source": "test1"}, {"source": "test2"}],
)
def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
"""Test add_documents with documents that have no metadata."""
mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"content": "Document without metadata"},
{"content": "Another document", "metadata": None},
{"content": "Document with metadata", "metadata": {"key": "value"}},
]
client.add_documents(collection_name="test_collection", documents=documents)
# Verify upsert was called with empty dicts for missing metadata
mock_collection.upsert.assert_called_once()
call_args = mock_collection.upsert.call_args
assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]
def test_add_documents_all_without_metadata(
self, client, mock_chromadb_client
) -> None:
"""Test add_documents when all documents have no metadata."""
mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"content": "Document 1"},
{"content": "Document 2"},
{"content": "Document 3"},
]
client.add_documents(collection_name="test_collection", documents=documents)
mock_collection.upsert.assert_called_once()
call_args = mock_collection.upsert.call_args
assert call_args[1]["metadatas"] is None
def test_add_documents_empty_list_raises_error(
self, client, mock_chromadb_client
) -> None:
@@ -358,6 +395,31 @@ class TestChromaDBClient:
metadatas=[{"source": "test1"}, {"source": "test2"}],
)
@pytest.mark.asyncio
async def test_aadd_documents_without_metadata(
self, async_client, mock_async_chromadb_client
) -> None:
"""Test aadd_documents with documents that have no metadata."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_collection = AsyncMock(
return_value=mock_collection
)
documents: list[BaseRecord] = [
{"content": "Document without metadata"},
{"content": "Another document", "metadata": None},
{"content": "Document with metadata", "metadata": {"key": "value"}},
]
await async_client.aadd_documents(
collection_name="test_collection", documents=documents
)
# Verify upsert was called with empty dicts for missing metadata
mock_collection.upsert.assert_called_once()
call_args = mock_collection.upsert.call_args
assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]
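The two metadata cases above (mixed vs. entirely missing) imply a normalization step along these lines; a sketch, assuming the None fallback exists because an all-empty metadata list and no metadata at all are handled differently by the upsert call.

from typing import Any


def normalize_metadatas(documents: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
    """Empty dict per document without metadata; None when no document has any."""
    metadatas = [doc.get("metadata") or {} for doc in documents]
    if not any(metadatas):  # all empty -> pass None to upsert, per the tests
        return None
    return metadatas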
@pytest.mark.asyncio
async def test_aadd_documents_empty_list_raises_error(
self, async_client, mock_async_chromadb_client
View File
@@ -0,0 +1,95 @@
"""Tests for ChromaDB utility functions."""
from crewai.rag.chromadb.utils import (
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
_is_ipv4_pattern,
_sanitize_collection_name,
)
class TestChromaDBUtils:
"""Test suite for ChromaDB utility functions."""
def test_sanitize_collection_name_long_name(self) -> None:
"""Test sanitizing a very long collection name."""
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
sanitized = _sanitize_collection_name(long_name)
assert len(sanitized) <= MAX_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_sanitize_collection_name_special_chars(self) -> None:
"""Test sanitizing a name with special characters."""
special_chars = "Agent@123!#$%^&*()"
sanitized = _sanitize_collection_name(special_chars)
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_sanitize_collection_name_short_name(self) -> None:
"""Test sanitizing a very short name."""
short_name = "A"
sanitized = _sanitize_collection_name(short_name)
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
def test_sanitize_collection_name_bad_ends(self) -> None:
"""Test sanitizing a name with non-alphanumeric start/end."""
bad_ends = "_Agent_"
sanitized = _sanitize_collection_name(bad_ends)
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
def test_sanitize_collection_name_none(self) -> None:
"""Test sanitizing a None value."""
sanitized = _sanitize_collection_name(None)
assert sanitized == "default_collection"
def test_sanitize_collection_name_ipv4_pattern(self) -> None:
"""Test sanitizing an IPv4 address."""
ipv4 = "192.168.1.1"
sanitized = _sanitize_collection_name(ipv4)
assert sanitized.startswith("ip_")
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_is_ipv4_pattern(self) -> None:
"""Test IPv4 pattern detection."""
assert _is_ipv4_pattern("192.168.1.1") is True
assert _is_ipv4_pattern("not.an.ip.address") is False
def test_sanitize_collection_name_properties(self) -> None:
"""Test that sanitized collection names always meet ChromaDB requirements."""
test_cases: list[str] = [
"A" * 100, # Very long name
"_start_with_underscore",
"end_with_underscore_",
"contains@special#characters",
"192.168.1.1", # IPv4 address
"a" * 2, # Too short
]
for test_case in test_cases:
sanitized = _sanitize_collection_name(test_case)
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert len(sanitized) <= MAX_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
def test_sanitize_collection_name_empty_string(self) -> None:
"""Test sanitizing an empty string."""
sanitized = _sanitize_collection_name("")
assert sanitized == "default_collection"
def test_sanitize_collection_name_whitespace_only(self) -> None:
"""Test sanitizing a string with only whitespace."""
sanitized = _sanitize_collection_name(" ")
assert (
sanitized == "a__z"
) # Spaces become underscores, padded to meet requirements
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
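Taken together, these cases constrain the sanitizer to roughly the shape below; a sketch consistent with the listed expectations, with the constants and regex assumed rather than copied from crewai.rag.chromadb.utils.

import re

MIN_COLLECTION_LENGTH = 3  # assumed values; the tests import the real constants
MAX_COLLECTION_LENGTH = 63


def sanitize_collection_name_sketch(name: str | None) -> str:
    """Coerce an arbitrary name into a ChromaDB-legal collection name."""
    if not name:
        return "default_collection"
    if re.fullmatch(r"(\d{1,3}\.){3}\d{1,3}", name):
        name = f"ip_{name}"  # IPv4-looking names get a prefix, per the tests
    sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
    if not sanitized[0].isalnum():
        sanitized = "a" + sanitized  # force an alphanumeric start
    if not sanitized[-1].isalnum():
        sanitized = sanitized + "z"  # force an alphanumeric end
    sanitized = sanitized[:MAX_COLLECTION_LENGTH]
    if not sanitized[-1].isalnum():
        sanitized = sanitized[:-1] + "z"  # truncation may expose a bad last char
    while len(sanitized) < MIN_COLLECTION_LENGTH:
        sanitized += "z"
    return sanitized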
View File
@@ -0,0 +1,250 @@
"""Enhanced tests for embedding function factory."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
get_embedding_function,
)
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
def test_get_embedding_function_default() -> None:
"""Test default embedding function when no config provided."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
with patch(
"crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
):
result = get_embedding_function()
mock_openai.assert_called_once_with(
api_key="test-api-key", model_name="text-embedding-3-small"
)
assert result == mock_instance
def test_get_embedding_function_with_embedding_options() -> None:
"""Test embedding function creation with EmbeddingOptions object."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
options = EmbeddingOptions(
provider="openai", api_key="test-key", model="text-embedding-3-large"
)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
# OpenAI uses model_name parameter, not model
assert result == mock_instance
def test_get_embedding_function_sentence_transformer() -> None:
"""Test sentence transformer embedding function."""
with patch(
"crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction"
) as mock_st:
mock_instance = MagicMock()
mock_st.return_value = mock_instance
config = {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"}
result = get_embedding_function(config)
mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2")
assert result == mock_instance
def test_get_embedding_function_ollama() -> None:
"""Test Ollama embedding function."""
with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as mock_ollama:
mock_instance = MagicMock()
mock_ollama.return_value = mock_instance
config = {
"provider": "ollama",
"model_name": "nomic-embed-text",
"url": "http://localhost:11434",
}
result = get_embedding_function(config)
mock_ollama.assert_called_once_with(
model_name="nomic-embed-text", url="http://localhost:11434"
)
assert result == mock_instance
def test_get_embedding_function_cohere() -> None:
"""Test Cohere embedding function."""
with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as mock_cohere:
mock_instance = MagicMock()
mock_cohere.return_value = mock_instance
config = {
"provider": "cohere",
"api_key": "cohere-key",
"model_name": "embed-english-v3.0",
}
result = get_embedding_function(config)
mock_cohere.assert_called_once_with(
api_key="cohere-key", model_name="embed-english-v3.0"
)
assert result == mock_instance
def test_get_embedding_function_huggingface() -> None:
"""Test HuggingFace embedding function."""
with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as mock_hf:
mock_instance = MagicMock()
mock_hf.return_value = mock_instance
config = {
"provider": "huggingface",
"api_key": "hf-token",
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
}
result = get_embedding_function(config)
mock_hf.assert_called_once_with(
api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
)
assert result == mock_instance
def test_get_embedding_function_onnx() -> None:
"""Test ONNX embedding function."""
with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as mock_onnx:
mock_instance = MagicMock()
mock_onnx.return_value = mock_instance
config = {"provider": "onnx"}
result = get_embedding_function(config)
mock_onnx.assert_called_once()
assert result == mock_instance
def test_get_embedding_function_google_palm() -> None:
"""Test Google PaLM embedding function."""
with patch(
"crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction"
) as mock_palm:
mock_instance = MagicMock()
mock_palm.return_value = mock_instance
config = {"provider": "google-palm", "api_key": "palm-key"}
result = get_embedding_function(config)
mock_palm.assert_called_once_with(api_key="palm-key")
assert result == mock_instance
def test_get_embedding_function_amazon_bedrock() -> None:
"""Test Amazon Bedrock embedding function."""
with patch(
"crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction"
) as mock_bedrock:
mock_instance = MagicMock()
mock_bedrock.return_value = mock_instance
config = {
"provider": "amazon-bedrock",
"region_name": "us-west-2",
"model_name": "amazon.titan-embed-text-v1",
}
result = get_embedding_function(config)
mock_bedrock.assert_called_once_with(
region_name="us-west-2", model_name="amazon.titan-embed-text-v1"
)
assert result == mock_instance
def test_get_embedding_function_jina() -> None:
"""Test Jina embedding function."""
with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as mock_jina:
mock_instance = MagicMock()
mock_jina.return_value = mock_instance
config = {
"provider": "jina",
"api_key": "jina-key",
"model_name": "jina-embeddings-v2-base-en",
}
result = get_embedding_function(config)
mock_jina.assert_called_once_with(
api_key="jina-key", model_name="jina-embeddings-v2-base-en"
)
assert result == mock_instance
def test_get_embedding_function_unsupported_provider() -> None:
"""Test handling of unsupported provider."""
config = {"provider": "unsupported-provider"}
with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
get_embedding_function(config)
def test_get_embedding_function_config_modification() -> None:
"""Test that original config dict is not modified."""
original_config = {
"provider": "openai",
"api_key": "test-key",
"model": "text-embedding-3-small",
}
config_copy = original_config.copy()
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
get_embedding_function(config_copy)
assert config_copy == original_config
def test_get_embedding_function_exclude_none_values() -> None:
"""Test that None values are excluded from embedding function calls."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
options = EmbeddingOptions(provider="openai", api_key="test-key", model=None)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
assert "model" not in call_kwargs
assert result == mock_instance
def test_get_embedding_function_instructor() -> None:
"""Test Instructor embedding function."""
with patch(
"crewai.rag.embeddings.factory.InstructorEmbeddingFunction"
) as mock_instructor:
mock_instance = MagicMock()
mock_instructor.return_value = mock_instance
config = {"provider": "instructor", "model_name": "hkunlp/instructor-large"}
result = get_embedding_function(config)
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
assert result == mock_instance
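Each provider test above boils down to picking a class by provider name and forwarding the remaining keys, with None values dropped and the caller's dict left untouched. A condensed sketch of that dispatch, with the provider mapping passed in as an assumption rather than taken from factory.py.

from typing import Any


def embedding_dispatch_sketch(
    config: dict[str, Any] | None,
    providers: dict[str, type],
) -> Any:
    """Choose the embedding-function class and forward keyword arguments."""
    cfg = dict(config) if config else {"provider": "openai"}  # copy; caller's dict stays intact
    provider = cfg.pop("provider")
    if provider not in providers:
        raise ValueError(f"Unsupported provider: {provider}")
    kwargs = {k: v for k, v in cfg.items() if v is not None}  # drop None values
    return providers[provider](**kwargs)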
View File
@@ -0,0 +1,218 @@
"""Tests for RAG client error handling scenarios."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
KnowledgeStorage,
)
from crewai.memory.storage.rag_storage import RAGStorage # type: ignore[import-untyped]
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_connection_failure(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles RAG client connection failures."""
mock_get_client.side_effect = ConnectionError("Unable to connect to ChromaDB")
storage = KnowledgeStorage(collection_name="connection_test")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_search_timeout(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles search timeouts gracefully."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = TimeoutError("Search operation timed out")
storage = KnowledgeStorage(collection_name="timeout_test")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_collection_not_found(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles missing collections."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = ValueError(
"Collection 'knowledge_missing' does not exist"
)
storage = KnowledgeStorage(collection_name="missing_collection")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles invalid embedding configurations."""
mock_get_client.return_value = MagicMock()
with patch(
"crewai.knowledge.storage.knowledge_storage.get_embedding_function"
) as mock_get_embedding:
mock_get_embedding.side_effect = ValueError(
"Unsupported provider: invalid_provider"
)
with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
KnowledgeStorage(
embedder={"provider": "invalid_provider"},
collection_name="invalid_embedding_test",
)
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_client_failure(mock_get_client: MagicMock) -> None:
"""Test RAGStorage handles RAG client failures in memory operations."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = RuntimeError("ChromaDB server error")
storage = RAGStorage("short_term", crew=None)
results = storage.search("test query")
assert results == []
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_save_failure(mock_get_client: MagicMock) -> None:
"""Test RAGStorage handles save operation failures."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.add_documents.side_effect = Exception("Failed to add documents")
storage = RAGStorage("long_term", crew=None)
storage.save("test memory", {"key": "value"})
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_readonly_database(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage reset handles readonly database errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.delete_collection.side_effect = Exception(
"attempt to write a readonly database"
)
storage = KnowledgeStorage(collection_name="readonly_test")
storage.reset()
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_collection_does_not_exist(
mock_get_client: MagicMock,
) -> None:
"""Test KnowledgeStorage reset handles non-existent collections."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.delete_collection.side_effect = Exception("Collection does not exist")
storage = KnowledgeStorage(collection_name="nonexistent_test")
storage.reset()
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_reset_failure_propagation(mock_get_client: MagicMock) -> None:
"""Test RAGStorage reset propagates unexpected errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.delete_collection.side_effect = Exception("Unexpected database error")
storage = RAGStorage("entities", crew=None)
with pytest.raises(
Exception, match="An error occurred while resetting the entities memory"
):
storage.reset()
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_malformed_search_results(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles malformed search results."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = [
{"content": "valid result", "metadata": {"source": "test"}},
{"invalid": "missing content field", "metadata": {"source": "test"}},
None,
{"content": None, "metadata": {"source": "test"}},
]
storage = KnowledgeStorage(collection_name="malformed_test")
results = storage.search(["test query"])
assert isinstance(results, list)
assert len(results) == 4
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_network_interruption(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles network interruptions during operations."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = [
ConnectionError("Network interruption"),
[{"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}],
]
storage = KnowledgeStorage(collection_name="network_test")
first_attempt = storage.search(["test query"])
assert first_attempt == []
mock_client.search.side_effect = None
mock_client.search.return_value = [
{"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}
]
second_attempt = storage.search(["test query"])
assert len(second_attempt) == 1
assert second_attempt[0]["content"] == "recovered result"
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_collection_creation_failure(mock_get_client: MagicMock) -> None:
"""Test RAGStorage handles collection creation failures."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.get_or_create_collection.side_effect = Exception(
"Failed to create collection"
)
storage = RAGStorage("user_memory", crew=None)
storage.save("test data", {"metadata": "test"})
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_embedding_dimension_mismatch_detailed(
mock_get_client: MagicMock,
) -> None:
"""Test detailed handling of embedding dimension mismatch errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.get_or_create_collection.return_value = None
mock_client.add_documents.side_effect = Exception(
"Embedding dimension mismatch: expected 384, got 1536"
)
storage = KnowledgeStorage(collection_name="dimension_detailed_test")
with pytest.raises(ValueError) as exc_info:
storage.save(["test document"])
assert "Embedding dimension mismatch" in str(exc_info.value)
assert "Make sure you're using the same embedding model" in str(exc_info.value)
assert "crewai reset-memories -a" in str(exc_info.value)
View File
@@ -1,8 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from mem0 import Memory, MemoryClient
from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -13,6 +12,67 @@ class MockCrew:
self.agents = [MagicMock(role="Test Agent")]
# Test data constants
SYSTEM_CONTENT = (
"You are Friendly chatbot assistant. You are a kind and "
"knowledgeable chatbot assistant. You excel at understanding user needs, "
"providing helpful responses, and maintaining engaging conversations. "
"You remember previous interactions to provide a personalized experience.\n"
"Your personal goal is: Engage in useful and interesting conversations "
"with users while remembering context.\n"
"To give my best complete final answer to the task respond using the exact "
"following format:\n\n"
"Thought: I now can give a great answer\n"
"Final Answer: Your final answer must be the great and the most complete "
"as possible, it must be outcome described.\n\n"
"I MUST use these formats, my job depends on it!"
)
USER_CONTENT = (
"\nCurrent Task: Respond to user conversation. User message: "
"What do you know about me?\n\n"
"This is the expected criteria for your final answer: Contextually "
"appropriate, helpful, and friendly response.\n"
"you MUST return the actual complete content as the final answer, "
"not a summary.\n\n"
"# Useful context: \nExternal memories:\n"
"- User is from India\n"
"- User is interested in the solar system\n"
"- User name is Vidit Ostwal\n"
"- User is interested in French cuisine\n\n"
"Begin! This is VERY important to you, use the tools available and give "
"your best Final Answer, your job depends on it!\n\n"
"Thought:"
)
ASSISTANT_CONTENT = (
"I now can give a great answer \n"
"Final Answer: Hi Vidit! From our previous conversations, I know you're "
"from India and have a great interest in the solar system. It's fascinating "
"to explore the wonders of space, isn't it? Also, I remember you have a "
"passion for French cuisine, which has so many delightful dishes to explore. "
"If there's anything specific you'd like to discuss or learn about—whether "
"it's about the solar system or some great French recipes—feel free to let "
"me know! I'm here to help."
)
TEST_DESCRIPTION = (
"Respond to user conversation. User message: What do you know about me?"
)
# Extracted content (after processing by _get_user_message and _get_assistant_message)
EXTRACTED_USER_CONTENT = "What do you know about me?"
EXTRACTED_ASSISTANT_CONTENT = (
"Hi Vidit! From our previous conversations, I know you're "
"from India and have a great interest in the solar system. It's fascinating "
"to explore the wonders of space, isn't it? Also, I remember you have a "
"passion for French cuisine, which has so many delightful dishes to explore. "
"If there's anything specific you'd like to discuss or learn about—whether "
"it's about the solar system or some great French recipes—feel free to let "
"me know! I'm here to help."
)
@pytest.fixture
def mock_mem0_memory():
"""Fixture to create a mock Memory instance"""
@@ -24,7 +84,9 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
"""Fixture to create a Mem0Storage instance with mocked dependencies""" """Fixture to create a Mem0Storage instance with mocked dependencies"""
# Patch the Memory class to return our mock # Patch the Memory class to return our mock
with patch("mem0.memory.main.Memory.from_config", return_value=mock_mem0_memory) as mock_from_config: with patch(
"mem0.Memory.from_config", return_value=mock_mem0_memory
) as mock_from_config:
config = { config = {
"vector_store": { "vector_store": {
"provider": "mock_vector_store", "provider": "mock_vector_store",
@@ -55,7 +117,14 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
# Parameters like run_id, includes, and excludes doesn't matter in Memory OSS
crew = MockCrew()
embedder_config={"user_id": "test_user", "local_mem0_config": config, "run_id": "my_run_id", "includes": "include1","excludes": "exclude1", "infer" : True}
embedder_config = {
"user_id": "test_user",
"local_mem0_config": config,
"run_id": "my_run_id",
"includes": "include1",
"excludes": "exclude1",
"infer": True,
}
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=embedder_config)
return mem0_storage, mock_from_config, config
@@ -83,28 +152,31 @@ def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_clie
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
crew = MockCrew()
embedder_config={
embedder_config = {
"user_id": "test_user",
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"run_id": "my_run_id",
"includes": "include1",
"excludes": "exclude1",
"infer": True
"infer": True,
}
return Mem0Storage(type="short_term", crew=crew, config=embedder_config)
@pytest.fixture
def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client, mock_mem0_memory):
def mem0_storage_with_memory_client_using_explictly_config(
mock_mem0_memory_client, mock_mem0_memory
):
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
# We need to patch both MemoryClient and Memory to prevent actual initialization
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client), \
patch.object(Memory, "__new__", return_value=mock_mem0_memory):
with (
patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client),
patch.object(Memory, "__new__", return_value=mock_mem0_memory),
):
crew = MockCrew()
new_config = {"provider": "mem0", "config": {"api_key": "new-api-key"}}
@@ -138,18 +210,23 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
mock_mem0_memory_client.update_project = MagicMock()
new_categories = [
{"lifestyle_management_concerns": "Tracks daily routines, habits, hobbies and interests including cooking, time management and work-life balance"},
{
"lifestyle_management_concerns": (
"Tracks daily routines, habits, hobbies and interests "
"including cooking, time management and work-life balance"
)
},
]
crew = MockCrew()
config={
config = {
"user_id": "test_user",
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"custom_categories": new_categories
"custom_categories": new_categories,
}
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
_ = Mem0Storage(type="short_term", crew=crew, config=config)
@@ -159,8 +236,6 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
)
def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test save method for different memory types"""
mem0_storage, _, _ = mem0_storage_with_mocked_config
@@ -168,68 +243,134 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {'description': 'Respond to user conversation. User message: What do you know about me?', 'messages': [{'role': 'system', 'content': 'You are Friendly chatbot assistant. You are a kind and knowledgeable chatbot assistant. You excel at understanding user needs, providing helpful responses, and maintaining engaging conversations. You remember previous interactions to provide a personalized experience.\nYour personal goal is: Engage in useful and interesting conversations with users while remembering context.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!'}, {'role': 'user', 'content': '\nCurrent Task: Respond to user conversation. User message: What do you know about me?\n\nThis is the expected criteria for your final answer: Contextually appropriate, helpful, and friendly response.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n- User is from India\n- User is interested in the solar system\n- User name is Vidit Ostwal\n- User is interested in French cuisine\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'}, {'role': 'assistant', 'content': "I now can give a great answer \nFinal Answer: Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}], 'agent': 'Friendly chatbot assistant'}
test_metadata = {
"description": TEST_DESCRIPTION,
"messages": [
{"role": "system", "content": SYSTEM_CONTENT},
{"role": "user", "content": USER_CONTENT},
{"role": "assistant", "content": ASSISTANT_CONTENT},
],
"agent": "Friendly chatbot assistant",
}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{'role': 'user', 'content': 'What do you know about me?'}, {'role': 'assistant', 'content': "Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}],
[
{"role": "user", "content": EXTRACTED_USER_CONTENT},
{
"role": "assistant",
"content": EXTRACTED_ASSISTANT_CONTENT,
},
],
infer=True,
metadata={'type': 'short_term', 'description': 'Respond to user conversation. User message: What do you know about me?', 'agent': 'Friendly chatbot assistant'},
metadata={
"type": "short_term",
"description": TEST_DESCRIPTION,
"agent": "Friendly chatbot assistant",
},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent'
agent_id="Test_Agent",
)
def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config):
mem0_storage, _, _ = mem0_storage_with_mocked_config
mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")]
mem0_storage.crew.agents = [
MagicMock(role="Test Agent"),
MagicMock(role="Test Agent 2"),
MagicMock(role="Test Agent 3"),
]
mem0_storage.memory.add = MagicMock()
test_value = "This is a test memory"
test_metadata = {
"description": TEST_DESCRIPTION,
"messages": [
{"role": "system", "content": SYSTEM_CONTENT},
{"role": "user", "content": USER_CONTENT},
{"role": "assistant", "content": ASSISTANT_CONTENT},
],
"agent": "Friendly chatbot assistant",
}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[
{"role": "user", "content": EXTRACTED_USER_CONTENT},
{
"role": "assistant",
"content": EXTRACTED_ASSISTANT_CONTENT,
},
],
infer=True,
metadata={
"type": "short_term",
"description": TEST_DESCRIPTION,
"agent": "Friendly chatbot assistant",
},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent_Test_Agent_2_Test_Agent_3'
agent_id="Test_Agent_Test_Agent_2_Test_Agent_3",
)
def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
def test_save_method_with_memory_client(
mem0_storage_with_memory_client_using_config_from_crew,
):
"""Test save method for different memory types"""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mem0_storage.memory.add = MagicMock()
# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {
"description": TEST_DESCRIPTION,
"messages": [
{"role": "system", "content": SYSTEM_CONTENT},
{"role": "user", "content": USER_CONTENT},
{"role": "assistant", "content": ASSISTANT_CONTENT},
],
"agent": "Friendly chatbot assistant",
}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[
{"role": "user", "content": EXTRACTED_USER_CONTENT},
{
"role": "assistant",
"content": EXTRACTED_ASSISTANT_CONTENT,
},
],
infer=True,
metadata={
"type": "short_term",
"description": TEST_DESCRIPTION,
"agent": "Friendly chatbot assistant",
},
version="v2",
run_id="my_run_id",
includes="include1",
excludes="exclude1",
output_format='v1.1',
output_format="v1.1",
user_id='test_user',
user_id="test_user",
agent_id='Test_Agent'
agent_id="Test_Agent",
)
def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test search method for different memory types"""
mem0_storage, _, _ = mem0_storage_with_mocked_config
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
@@ -238,18 +379,25 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
query="test query", query="test query",
limit=5, limit=5,
user_id="test_user", user_id="test_user",
filters={'AND': [{'run_id': 'my_run_id'}]}, filters={"AND": [{"run_id": "my_run_id"}]},
threshold=0.5 threshold=0.5,
) )
assert len(results) == 2 assert len(results) == 2
assert results[0]["context"] == "Result 1" assert results[0]["content"] == "Result 1"
def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew): def test_search_method_with_memory_client(
mem0_storage_with_memory_client_using_config_from_crew,
):
"""Test search method for different memory types""" """Test search method for different memory types"""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]} mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
mem0_storage.memory.search = MagicMock(return_value=mock_results) mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5) results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
@@ -259,15 +407,15 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_
limit=5,
metadata={"type": "short_term"},
user_id="test_user",
version='v2',
version="v2",
run_id="my_run_id",
output_format='v1.1',
output_format="v1.1",
filters={'AND': [{'run_id': 'my_run_id'}]},
filters={"AND": [{"run_id": "my_run_id"}]},
threshold=0.5
threshold=0.5,
)
assert len(results) == 2
assert results[0]["context"] == "Result 1"
assert results[0]["content"] == "Result 1"
def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
@@ -275,14 +423,12 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
crew = MockCrew()
config={
"user_id": "test_user",
"api_key": "ABCDEFGH"
}
config = {"user_id": "test_user", "api_key": "ABCDEFGH"}
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=config)
assert mem0_storage.infer is True
def test_save_memory_using_agent_entity(mock_mem0_memory_client):
config = {
"agent_id": "agent-123",
@@ -293,19 +439,25 @@ def test_save_memory_using_agent_entity(mock_mem0_memory_client):
mem0_storage = Mem0Storage(type="external", config=config) mem0_storage = Mem0Storage(type="external", config=config)
mem0_storage.save("test memory", {"key": "value"}) mem0_storage.save("test memory", {"key": "value"})
mem0_storage.memory.add.assert_called_once_with( mem0_storage.memory.add.assert_called_once_with(
[{'role': 'assistant' , 'content': 'test memory'}], [{"role": "assistant", "content": "test memory"}],
infer=True, infer=True,
metadata={"type": "external", "key": "value"}, metadata={"type": "external", "key": "value"},
agent_id="agent-123", agent_id="agent-123",
) )
def test_search_method_with_agent_entity(): def test_search_method_with_agent_entity():
config = { config = {
"agent_id": "agent-123", "agent_id": "agent-123",
} }
mock_memory = MagicMock(spec=Memory) mock_memory = MagicMock(spec=Memory)
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]} mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
with patch.object(Memory, "__new__", return_value=mock_memory): with patch.object(Memory, "__new__", return_value=mock_memory):
mem0_storage = Mem0Storage(type="external", config=config) mem0_storage = Mem0Storage(type="external", config=config)
@@ -314,22 +466,29 @@ def test_search_method_with_agent_entity():
results = mem0_storage.search("test query", limit=5, score_threshold=0.5) results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with( mem0_storage.memory.search.assert_called_once_with(
query="test query", query="test query",
limit=5, limit=5,
filters={"AND": [{"agent_id": "agent-123"}]}, filters={"AND": [{"agent_id": "agent-123"}]},
threshold=0.5, threshold=0.5,
) )
assert len(results) == 2 assert len(results) == 2
assert results[0]["context"] == "Result 1" assert results[0]["content"] == "Result 1"
def test_search_method_with_agent_id_and_user_id(): def test_search_method_with_agent_id_and_user_id():
mock_memory = MagicMock(spec=Memory) mock_memory = MagicMock(spec=Memory)
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]} mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
with patch.object(Memory, "__new__", return_value=mock_memory): with patch.object(Memory, "__new__", return_value=mock_memory):
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"}) mem0_storage = Mem0Storage(
type="external", config={"agent_id": "agent-123", "user_id": "user-123"}
)
mem0_storage.memory.search = MagicMock(return_value=mock_results) mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5) results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
@@ -337,10 +496,10 @@ def test_search_method_with_agent_id_and_user_id():
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
user_id='user-123',
user_id="user-123",
filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["context"] == "Result 1"
assert results[0]["content"] == "Result 1"

View File

@@ -1,123 +0,0 @@
import multiprocessing
import tempfile
import unittest
from chromadb.config import Settings
from unittest.mock import patch, MagicMock
from crewai.utilities.chromadb import (
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
is_ipv4_pattern,
sanitize_collection_name,
create_persistent_client,
)
def persistent_client_worker(path, queue):
try:
create_persistent_client(path=path)
queue.put(None)
except Exception as e:
queue.put(e)
class TestChromadbUtils(unittest.TestCase):
def test_sanitize_collection_name_long_name(self):
"""Test sanitizing a very long collection name."""
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
sanitized = sanitize_collection_name(long_name)
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_sanitize_collection_name_special_chars(self):
"""Test sanitizing a name with special characters."""
special_chars = "Agent@123!#$%^&*()"
sanitized = sanitize_collection_name(special_chars)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_sanitize_collection_name_short_name(self):
"""Test sanitizing a very short name."""
short_name = "A"
sanitized = sanitize_collection_name(short_name)
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_sanitize_collection_name_bad_ends(self):
"""Test sanitizing a name with non-alphanumeric start/end."""
bad_ends = "_Agent_"
sanitized = sanitize_collection_name(bad_ends)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_sanitize_collection_name_none(self):
"""Test sanitizing a None value."""
sanitized = sanitize_collection_name(None)
self.assertEqual(sanitized, "default_collection")
def test_sanitize_collection_name_ipv4_pattern(self):
"""Test sanitizing an IPv4 address."""
ipv4 = "192.168.1.1"
sanitized = sanitize_collection_name(ipv4)
self.assertTrue(sanitized.startswith("ip_"))
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_is_ipv4_pattern(self):
"""Test IPv4 pattern detection."""
self.assertTrue(is_ipv4_pattern("192.168.1.1"))
self.assertFalse(is_ipv4_pattern("not.an.ip.address"))
def test_sanitize_collection_name_properties(self):
"""Test that sanitized collection names always meet ChromaDB requirements."""
test_cases = [
"A" * 100, # Very long name
"_start_with_underscore",
"end_with_underscore_",
"contains@special#characters",
"192.168.1.1", # IPv4 address
"a" * 2, # Too short
]
for test_case in test_cases:
sanitized = sanitize_collection_name(test_case)
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_create_persistent_client_passes_args(self):
with patch(
"crewai.utilities.chromadb.PersistentClient"
) as mock_persistent_client, tempfile.TemporaryDirectory() as tmpdir:
mock_instance = MagicMock()
mock_persistent_client.return_value = mock_instance
settings = Settings(allow_reset=True)
client = create_persistent_client(path=tmpdir, settings=settings)
mock_persistent_client.assert_called_once_with(
path=tmpdir, settings=settings
)
self.assertIs(client, mock_instance)
def test_create_persistent_client_process_safe(self):
with tempfile.TemporaryDirectory() as tmpdir:
queue = multiprocessing.Queue()
processes = [
multiprocessing.Process(
target=persistent_client_worker, args=(tmpdir, queue)
)
for _ in range(5)
]
[p.start() for p in processes]
[p.join() for p in processes]
errors = [queue.get(timeout=5) for _ in processes]
self.assertTrue(all(err is None for err in errors))

View File

@@ -29,13 +29,15 @@ def mock_knowledge_source():
""" """
return StringKnowledgeSource(content=content) return StringKnowledgeSource(content=content)
@patch('crewai.knowledge.storage.knowledge_storage.chromadb')
def test_knowledge_included_in_planning(mock_chroma): @patch("crewai.rag.config.utils.get_rag_client")
def test_knowledge_included_in_planning(mock_get_client):
"""Test that verifies knowledge sources are properly included in planning.""" """Test that verifies knowledge sources are properly included in planning."""
# Mock ChromaDB collection # Mock RAG client
mock_collection = mock_chroma.return_value.get_or_create_collection.return_value mock_client = mock_get_client.return_value
mock_collection.add.return_value = None mock_client.get_or_create_collection.return_value = None
mock_client.add_documents.return_value = None
# Create an agent with knowledge # Create an agent with knowledge
agent = Agent( agent = Agent(
role="AI Researcher", role="AI Researcher",
@@ -45,14 +47,14 @@ def test_knowledge_included_in_planning(mock_chroma):
StringKnowledgeSource(
content="AI systems require careful training and validation."
)
]
],
)
# Create a task for the agent
task = Task(
description="Explain the basics of AI systems",
expected_output="A clear explanation of AI fundamentals",
agent=agent
agent=agent,
)
# Create a crew planner
@@ -62,23 +64,29 @@ def test_knowledge_included_in_planning(mock_chroma):
task_summary = planner._create_tasks_summary()
# Verify that knowledge is included in planning when present
assert "AI systems require careful training" in task_summary, \
assert "AI systems require careful training" in task_summary, (
"Knowledge content should be present in task summary when knowledge exists"
assert '"agent_knowledge"' in task_summary, \
)
assert '"agent_knowledge"' in task_summary, (
"agent_knowledge field should be present in task summary when knowledge exists"
)
# Verify that knowledge is properly formatted
assert isinstance(task.agent.knowledge_sources, list), \
assert isinstance(task.agent.knowledge_sources, list), (
"Knowledge sources should be stored in a list"
assert len(task.agent.knowledge_sources) > 0, \
)
assert len(task.agent.knowledge_sources) > 0, (
"At least one knowledge source should be present"
assert task.agent.knowledge_sources[0].content in task_summary, \
)
assert task.agent.knowledge_sources[0].content in task_summary, (
"Knowledge source content should be included in task summary"
)
# Verify that other expected components are still present
assert task.description in task_summary, \
assert task.description in task_summary, (
"Task description should be present in task summary"
assert task.expected_output in task_summary, \
)
assert task.expected_output in task_summary, (
"Expected output should be present in task summary"
assert agent.role in task_summary, \
)
"Agent role should be present in task summary"
assert agent.role in task_summary, "Agent role should be present in task summary"