diff --git a/pyproject.toml b/pyproject.toml index 91f2eb629..a2587389b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,13 +131,14 @@ select = [ "I001", # sort imports "I002", # remove unused imports ] -ignore = ["E501"] # ignore line too long +ignore = ["E501"] # ignore line too long globally [tool.ruff.lint.per-file-ignores] -"tests/**/*.py" = ["S101"] # Allow assert statements in tests +"tests/**/*.py" = ["S101", "RET504"] # Allow assert statements and unnecessary assignments before return in tests [tool.mypy] -exclude = ["src/crewai/cli/templates", "tests"] +exclude = ["src/crewai/cli/templates", "tests/"] + [tool.bandit] exclude_dirs = ["src/crewai/cli/templates"] diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 9185d143d..124966116 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -3,26 +3,17 @@ import json import re import uuid import warnings +from collections.abc import Callable from concurrent.futures import Future from copy import copy as shallow_copy from hashlib import md5 from typing import ( Any, - Callable, - Dict, - List, - Optional, - Set, - Tuple, - Union, cast, ) from opentelemetry import baggage from opentelemetry.context import attach, detach - -from crewai.utilities.crew.models import CrewContext - from pydantic import ( UUID4, BaseModel, @@ -39,26 +30,14 @@ from crewai.agent import Agent from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.cache import CacheHandler from crewai.crews.crew_output import CrewOutput -from crewai.flow.flow_trackable import FlowTrackable -from crewai.knowledge.knowledge import Knowledge -from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource -from crewai.llm import LLM, BaseLLM -from crewai.memory.entity.entity_memory import EntityMemory -from crewai.memory.external.external_memory import ExternalMemory -from crewai.memory.long_term.long_term_memory import LongTermMemory -from crewai.memory.short_term.short_term_memory import ShortTermMemory -from crewai.process import Process -from crewai.security import Fingerprint, SecurityConfig -from crewai.task import Task -from crewai.tasks.conditional_task import ConditionalTask -from crewai.tasks.task_output import TaskOutput -from crewai.tools.agent_tools.agent_tools import AgentTools -from crewai.tools.base_tool import BaseTool, Tool -from crewai.types.usage_metrics import UsageMetrics -from crewai.utilities import I18N, FileHandler, Logger, RPMController -from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE -from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator -from crewai.utilities.evaluators.task_evaluator import TaskEvaluator +from crewai.events.event_bus import crewai_event_bus +from crewai.events.event_listener import EventListener +from crewai.events.listeners.tracing.trace_listener import ( + TraceCollectionListener, +) +from crewai.events.listeners.tracing.utils import ( + is_tracing_enabled, +) from crewai.events.types.crew_events import ( CrewKickoffCompletedEvent, CrewKickoffFailedEvent, @@ -70,16 +49,28 @@ from crewai.events.types.crew_events import ( CrewTrainFailedEvent, CrewTrainStartedEvent, ) -from crewai.events.event_bus import crewai_event_bus -from crewai.events.event_listener import EventListener -from crewai.events.listeners.tracing.trace_listener import ( - TraceCollectionListener, -) - - -from crewai.events.listeners.tracing.utils import ( - is_tracing_enabled, -) +from crewai.flow.flow_trackable import FlowTrackable +from crewai.knowledge.knowledge import Knowledge +from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource +from crewai.llm import LLM, BaseLLM +from crewai.memory.entity.entity_memory import EntityMemory +from crewai.memory.external.external_memory import ExternalMemory +from crewai.memory.long_term.long_term_memory import LongTermMemory +from crewai.memory.short_term.short_term_memory import ShortTermMemory +from crewai.process import Process +from crewai.rag.types import SearchResult +from crewai.security import Fingerprint, SecurityConfig +from crewai.task import Task +from crewai.tasks.conditional_task import ConditionalTask +from crewai.tasks.task_output import TaskOutput +from crewai.tools.agent_tools.agent_tools import AgentTools +from crewai.tools.base_tool import BaseTool, Tool +from crewai.types.usage_metrics import UsageMetrics +from crewai.utilities import I18N, FileHandler, Logger, RPMController +from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE +from crewai.utilities.crew.models import CrewContext +from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator +from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.utilities.formatter import ( aggregate_raw_outputs_from_task_outputs, aggregate_raw_outputs_from_tasks, @@ -94,28 +85,40 @@ warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd") class Crew(FlowTrackable, BaseModel): """ - Represents a group of agents, defining how they should collaborate and the tasks they should perform. + Represents a group of agents, defining how they should collaborate and the + tasks they should perform. Attributes: tasks: List of tasks assigned to the crew. agents: List of agents part of this crew. manager_llm: The language model that will run manager agent. manager_agent: Custom agent that will be used as manager. - memory: Whether the crew should use memory to store memories of it's execution. - cache: Whether the crew should use a cache to store the results of the tools execution. - function_calling_llm: The language model that will run the tool calling for all the agents. - process: The process flow that the crew will follow (e.g., sequential, hierarchical). + memory: Whether the crew should use memory to store memories of it's + execution. + cache: Whether the crew should use a cache to store the results of the + tools execution. + function_calling_llm: The language model that will run the tool calling + for all the agents. + process: The process flow that the crew will follow (e.g., sequential, + hierarchical). verbose: Indicates the verbosity level for logging during execution. config: Configuration settings for the crew. - max_rpm: Maximum number of requests per minute for the crew execution to be respected. + max_rpm: Maximum number of requests per minute for the crew execution to + be respected. prompt_file: Path to the prompt json file to be used for the crew. id: A unique identifier for the crew instance. - task_callback: Callback to be executed after each task for every agents execution. - step_callback: Callback to be executed after each step for every agents execution. - share_crew: Whether you want to share the complete crew information and execution with crewAI to make the library better, and allow us to train models. + task_callback: Callback to be executed after each task for every agents + execution. + step_callback: Callback to be executed after each step for every agents + execution. + share_crew: Whether you want to share the complete crew information and + execution with crewAI to make the library better, and allow us to + train models. planning: Plan the crew execution and add the plan to the crew. - chat_llm: The language model used for orchestrating chat interactions with the crew. - security_config: Security configuration for the crew, including fingerprinting. + chat_llm: The language model used for orchestrating chat interactions + with the crew. + security_config: Security configuration for the crew, including + fingerprinting. """ __hash__ = object.__hash__ # type: ignore @@ -124,13 +127,13 @@ class Crew(FlowTrackable, BaseModel): _logger: Logger = PrivateAttr() _file_handler: FileHandler = PrivateAttr() _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler()) - _short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr() - _long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr() - _entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr() - _external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr() - _train: Optional[bool] = PrivateAttr(default=False) - _train_iteration: Optional[int] = PrivateAttr() - _inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None) + _short_term_memory: InstanceOf[ShortTermMemory] | None = PrivateAttr() + _long_term_memory: InstanceOf[LongTermMemory] | None = PrivateAttr() + _entity_memory: InstanceOf[EntityMemory] | None = PrivateAttr() + _external_memory: InstanceOf[ExternalMemory] | None = PrivateAttr() + _train: bool | None = PrivateAttr(default=False) + _train_iteration: int | None = PrivateAttr() + _inputs: dict[str, Any] | None = PrivateAttr(default=None) _logging_color: str = PrivateAttr( default="bold_purple", ) @@ -138,107 +141,121 @@ class Crew(FlowTrackable, BaseModel): default_factory=TaskOutputStorageHandler ) - name: Optional[str] = Field(default="crew") + name: str | None = Field(default="crew") cache: bool = Field(default=True) - tasks: List[Task] = Field(default_factory=list) - agents: List[BaseAgent] = Field(default_factory=list) + tasks: list[Task] = Field(default_factory=list) + agents: list[BaseAgent] = Field(default_factory=list) process: Process = Field(default=Process.sequential) verbose: bool = Field(default=False) memory: bool = Field( default=False, - description="Whether the crew should use memory to store memories of it's execution", + description="If crew should use memory to store memories of it's execution", ) - short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field( + short_term_memory: InstanceOf[ShortTermMemory] | None = Field( default=None, description="An Instance of the ShortTermMemory to be used by the Crew", ) - long_term_memory: Optional[InstanceOf[LongTermMemory]] = Field( + long_term_memory: InstanceOf[LongTermMemory] | None = Field( default=None, description="An Instance of the LongTermMemory to be used by the Crew", ) - entity_memory: Optional[InstanceOf[EntityMemory]] = Field( + entity_memory: InstanceOf[EntityMemory] | None = Field( default=None, description="An Instance of the EntityMemory to be used by the Crew", ) - external_memory: Optional[InstanceOf[ExternalMemory]] = Field( + external_memory: InstanceOf[ExternalMemory] | None = Field( default=None, description="An Instance of the ExternalMemory to be used by the Crew", ) - embedder: Optional[dict] = Field( + embedder: dict | None = Field( default=None, description="Configuration for the embedder to be used for the crew.", ) - usage_metrics: Optional[UsageMetrics] = Field( + usage_metrics: UsageMetrics | None = Field( default=None, description="Metrics for the LLM usage during all tasks execution.", ) - manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( + manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field( description="Language model that will run the agent.", default=None ) - manager_agent: Optional[BaseAgent] = Field( + manager_agent: BaseAgent | None = Field( description="Custom agent that will be used as manager.", default=None ) - function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field( + function_calling_llm: str | InstanceOf[LLM] | Any | None = Field( description="Language model that will run the agent.", default=None ) - config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None) + config: Json | dict[str, Any] | None = Field(default=None) id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True) - share_crew: Optional[bool] = Field(default=False) - step_callback: Optional[Any] = Field( + share_crew: bool | None = Field(default=False) + step_callback: Any | None = Field( default=None, description="Callback to be executed after each step for all agents execution.", ) - task_callback: Optional[Any] = Field( + task_callback: Any | None = Field( default=None, description="Callback to be executed after each task for all agents execution.", ) - before_kickoff_callbacks: List[ - Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]] + before_kickoff_callbacks: list[ + Callable[[dict[str, Any] | None], dict[str, Any] | None] ] = Field( default_factory=list, - description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.", + description=( + "List of callbacks to be executed before crew kickoff. " + "It may be used to adjust inputs before the crew is executed." + ), ) - after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field( + after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field( default_factory=list, - description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.", + description=( + "List of callbacks to be executed after crew kickoff. " + "It may be used to adjust the output of the crew." + ), ) - max_rpm: Optional[int] = Field( + max_rpm: int | None = Field( default=None, - description="Maximum number of requests per minute for the crew execution to be respected.", + description=( + "Maximum number of requests per minute for the crew execution " + "to be respected." + ), ) - prompt_file: Optional[str] = Field( + prompt_file: str | None = Field( default=None, description="Path to the prompt json file to be used for the crew.", ) - output_log_file: Optional[Union[bool, str]] = Field( + output_log_file: bool | str | None = Field( default=None, description="Path to the log file to be saved", ) - planning: Optional[bool] = Field( + planning: bool | None = Field( default=False, description="Plan the crew execution and add the plan to the crew.", ) - planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( + planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field( default=None, - description="Language model that will run the AgentPlanner if planning is True.", + description=( + "Language model that will run the AgentPlanner if planning is True." + ), ) - task_execution_output_json_files: Optional[List[str]] = Field( + task_execution_output_json_files: list[str] | None = Field( default=None, description="List of file paths for task execution JSON files.", ) - execution_logs: List[Dict[str, Any]] = Field( + execution_logs: list[dict[str, Any]] = Field( default=[], description="List of execution logs for tasks", ) - knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field( + knowledge_sources: list[BaseKnowledgeSource] | None = Field( default=None, - description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.", + description=( + "Knowledge sources for the crew. Add knowledge sources to the " + "knowledge object." + ), ) - chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( + chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field( default=None, description="LLM used to handle chatting with the crew.", ) - knowledge: Optional[Knowledge] = Field( + knowledge: Knowledge | None = Field( default=None, description="Knowledge for the crew.", ) @@ -246,18 +263,18 @@ class Crew(FlowTrackable, BaseModel): default_factory=SecurityConfig, description="Security configuration for the crew, including fingerprinting.", ) - token_usage: Optional[UsageMetrics] = Field( + token_usage: UsageMetrics | None = Field( default=None, description="Metrics for the LLM usage during all tasks execution.", ) - tracing: Optional[bool] = Field( + tracing: bool | None = Field( default=False, description="Whether to enable tracing for the crew.", ) @field_validator("id", mode="before") @classmethod - def _deny_user_set_id(cls, v: Optional[UUID4]) -> None: + def _deny_user_set_id(cls, v: UUID4 | None) -> None: """Prevent manual setting of the 'id' field by users.""" if v: raise PydanticCustomError( @@ -266,9 +283,7 @@ class Crew(FlowTrackable, BaseModel): @field_validator("config", mode="before") @classmethod - def check_config_type( - cls, v: Union[Json, Dict[str, Any]] - ) -> Union[Json, Dict[str, Any]]: + def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]: """Validates that the config is a valid type. Args: v: The config to be validated. @@ -314,7 +329,8 @@ class Crew(FlowTrackable, BaseModel): def create_crew_memory(self) -> "Crew": """Initialize private memory attributes.""" self._external_memory = ( - # External memory doesn’t support a default value since it was designed to be managed entirely externally + # External memory does not support a default value since it was + # designed to be managed entirely externally self.external_memory.set_crew(self) if self.external_memory else None ) @@ -355,7 +371,10 @@ class Crew(FlowTrackable, BaseModel): if not self.manager_llm and not self.manager_agent: raise PydanticCustomError( "missing_manager_llm_or_manager_agent", - "Attribute `manager_llm` or `manager_agent` is required when using hierarchical process.", + ( + "Attribute `manager_llm` or `manager_agent` is required " + "when using hierarchical process." + ), {}, ) @@ -398,7 +417,10 @@ class Crew(FlowTrackable, BaseModel): if task.agent is None: raise PydanticCustomError( "missing_agent_in_task", - f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString" + ( + f"Sequential process error: Agent is missing in the task " + f"with the following description: {task.description}" + ), # type: ignore # Dynamic string in error message {}, ) @@ -459,7 +481,10 @@ class Crew(FlowTrackable, BaseModel): if task.async_execution and isinstance(task, ConditionalTask): raise PydanticCustomError( "invalid_async_conditional_task", - f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString" + ( + f"Conditional Task: {task.description}, " + f"cannot be executed asynchronously." + ), {}, ) return self @@ -478,7 +503,9 @@ class Crew(FlowTrackable, BaseModel): for j in range(i - 1, -1, -1): if self.tasks[j] == context_task: raise ValueError( - f"Task '{task.description}' is asynchronous and cannot include other sequential asynchronous tasks in its context." + f"Task '{task.description}' is asynchronous and " + f"cannot include other sequential asynchronous " + f"tasks in its context." ) if not self.tasks[j].async_execution: break @@ -496,13 +523,15 @@ class Crew(FlowTrackable, BaseModel): continue # Skip context tasks not in the main tasks list if task_indices[id(context_task)] > task_indices[id(task)]: raise ValueError( - f"Task '{task.description}' has a context dependency on a future task '{context_task.description}', which is not allowed." + f"Task '{task.description}' has a context dependency " + f"on a future task '{context_task.description}', " + f"which is not allowed." ) return self @property def key(self) -> str: - source: List[str] = [agent.key for agent in self.agents] + [ + source: list[str] = [agent.key for agent in self.agents] + [ task.key for task in self.tasks ] return md5("|".join(source).encode(), usedforsecurity=False).hexdigest() @@ -518,9 +547,9 @@ class Crew(FlowTrackable, BaseModel): return self.security_config.fingerprint def _setup_from_config(self): - assert self.config is not None, "Config should not be None." - """Initializes agents and tasks from the provided config.""" + if self.config is None: + raise ValueError("Config should not be None.") if not self.config.get("agents") or not self.config.get("tasks"): raise PydanticCustomError( "missing_keys_in_config", "Config should have 'agents' and 'tasks'.", {} @@ -530,7 +559,7 @@ class Crew(FlowTrackable, BaseModel): self.agents = [Agent(**agent) for agent in self.config["agents"]] self.tasks = [self._create_task(task) for task in self.config["tasks"]] - def _create_task(self, task_config: Dict[str, Any]) -> Task: + def _create_task(self, task_config: dict[str, Any]) -> Task: """Creates a task instance from its configuration. Args: @@ -559,7 +588,7 @@ class Crew(FlowTrackable, BaseModel): CrewTrainingHandler(filename).initialize_file() def train( - self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = None + self, n_iterations: int, filename: str, inputs: dict[str, Any] | None = None ) -> None: """Trains the crew for a given number of iterations.""" inputs = inputs or {} @@ -611,7 +640,7 @@ class Crew(FlowTrackable, BaseModel): def kickoff( self, - inputs: Optional[Dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, ) -> CrewOutput: ctx = baggage.set_baggage( "crew_context", CrewContext(id=str(self.id), key=self.key) @@ -682,9 +711,9 @@ class Crew(FlowTrackable, BaseModel): finally: detach(token) - def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]: - """Executes the Crew's workflow for each input in the list and aggregates results.""" - results: List[CrewOutput] = [] + def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]: + """Executes the Crew's workflow for each input and aggregates results.""" + results: list[CrewOutput] = [] # Initialize the parent crew's usage metrics total_usage_metrics = UsageMetrics() @@ -703,14 +732,12 @@ class Crew(FlowTrackable, BaseModel): self._task_output_handler.reset() return results - async def kickoff_async( - self, inputs: Optional[Dict[str, Any]] = None - ) -> CrewOutput: + async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> CrewOutput: """Asynchronous kickoff method to start the crew execution.""" inputs = inputs or {} return await asyncio.to_thread(self.kickoff, inputs) - async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]: + async def kickoff_for_each_async(self, inputs: list[dict]) -> list[CrewOutput]: crew_copies = [self.copy() for _ in inputs] async def run_crew(crew, input_data): @@ -739,7 +766,9 @@ class Crew(FlowTrackable, BaseModel): tasks=self.tasks, planning_agent_llm=self.planning_llm )._handle_crew_planning() - for task, step_plan in zip(self.tasks, result.list_of_plans_per_task): + for task, step_plan in zip( + self.tasks, result.list_of_plans_per_task, strict=False + ): task.description += step_plan.plan def _store_execution_log( @@ -776,7 +805,7 @@ class Crew(FlowTrackable, BaseModel): return self._execute_tasks(self.tasks) def _run_hierarchical_process(self) -> CrewOutput: - """Creates and assigns a manager agent to make sure the crew completes the tasks.""" + """Creates and assigns a manager agent to complete the tasks.""" self._create_manager_agent() return self._execute_tasks(self.tasks) @@ -807,23 +836,24 @@ class Crew(FlowTrackable, BaseModel): def _execute_tasks( self, - tasks: List[Task], - start_index: Optional[int] = 0, + tasks: list[Task], + start_index: int | None = 0, was_replayed: bool = False, ) -> CrewOutput: """Executes tasks sequentially and returns the final output. Args: tasks (List[Task]): List of tasks to execute - manager (Optional[BaseAgent], optional): Manager agent to use for delegation. Defaults to None. + manager (Optional[BaseAgent], optional): Manager agent to use for + delegation. Defaults to None. Returns: CrewOutput: Final output of the crew """ - task_outputs: List[TaskOutput] = [] - futures: List[Tuple[Task, Future[TaskOutput], int]] = [] - last_sync_output: Optional[TaskOutput] = None + task_outputs: list[TaskOutput] = [] + futures: list[tuple[Task, Future[TaskOutput], int]] = [] + last_sync_output: TaskOutput | None = None for task_index, task in enumerate(tasks): if start_index is not None and task_index < start_index: @@ -838,7 +868,9 @@ class Crew(FlowTrackable, BaseModel): agent_to_use = self._get_agent_to_use(task) if agent_to_use is None: raise ValueError( - f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided." + f"No agent available for task: {task.description}. " + f"Ensure that either the task has an assigned agent " + f"or a manager agent is provided." ) # Determine which tools to use - task tools take precedence over agent tools @@ -847,7 +879,7 @@ class Crew(FlowTrackable, BaseModel): tools_for_task = self._prepare_tools( agent_to_use, task, - cast(Union[List[Tool], List[BaseTool]], tools_for_task), + cast(list[Tool] | list[BaseTool], tools_for_task), ) self._log_task_start(task, agent_to_use.role) @@ -867,7 +899,7 @@ class Crew(FlowTrackable, BaseModel): future = task.execute_async( agent=agent_to_use, context=context, - tools=cast(List[BaseTool], tools_for_task), + tools=cast(list[BaseTool], tools_for_task), ) futures.append((task, future, task_index)) else: @@ -879,7 +911,7 @@ class Crew(FlowTrackable, BaseModel): task_output = task.execute_sync( agent=agent_to_use, context=context, - tools=cast(List[BaseTool], tools_for_task), + tools=cast(list[BaseTool], tools_for_task), ) task_outputs.append(task_output) self._process_task_result(task, task_output) @@ -893,11 +925,11 @@ class Crew(FlowTrackable, BaseModel): def _handle_conditional_task( self, task: ConditionalTask, - task_outputs: List[TaskOutput], - futures: List[Tuple[Task, Future[TaskOutput], int]], + task_outputs: list[TaskOutput], + futures: list[tuple[Task, Future[TaskOutput], int]], task_index: int, was_replayed: bool, - ) -> Optional[TaskOutput]: + ) -> TaskOutput | None: if futures: task_outputs = self._process_async_tasks(futures, was_replayed) futures.clear() @@ -917,8 +949,8 @@ class Crew(FlowTrackable, BaseModel): return None def _prepare_tools( - self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]] - ) -> List[BaseTool]: + self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool] + ) -> list[BaseTool]: # Add delegation tools if agent allows delegation if hasattr(agent, "allow_delegation") and getattr( agent, "allow_delegation", False @@ -947,22 +979,22 @@ class Crew(FlowTrackable, BaseModel): ): tools = self._add_multimodal_tools(agent, tools) - # Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async - return cast(List[BaseTool], tools) + # Return a List[BaseTool] compatible with Task.execute_sync and execute_async + return cast(list[BaseTool], tools) - def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]: + def _get_agent_to_use(self, task: Task) -> BaseAgent | None: if self.process == Process.hierarchical: return self.manager_agent return task.agent def _merge_tools( self, - existing_tools: Union[List[Tool], List[BaseTool]], - new_tools: Union[List[Tool], List[BaseTool]], - ) -> List[BaseTool]: - """Merge new tools into existing tools list, avoiding duplicates by tool name.""" + existing_tools: list[Tool] | list[BaseTool], + new_tools: list[Tool] | list[BaseTool], + ) -> list[BaseTool]: + """Merge new tools into existing tools list, avoiding duplicates.""" if not new_tools: - return cast(List[BaseTool], existing_tools) + return cast(list[BaseTool], existing_tools) # Create mapping of tool names to new tools new_tool_map = {tool.name: tool for tool in new_tools} @@ -973,41 +1005,41 @@ class Crew(FlowTrackable, BaseModel): # Add all new tools tools.extend(new_tools) - return cast(List[BaseTool], tools) + return cast(list[BaseTool], tools) def _inject_delegation_tools( self, - tools: Union[List[Tool], List[BaseTool]], + tools: list[Tool] | list[BaseTool], task_agent: BaseAgent, - agents: List[BaseAgent], - ) -> List[BaseTool]: + agents: list[BaseAgent], + ) -> list[BaseTool]: if hasattr(task_agent, "get_delegation_tools"): delegation_tools = task_agent.get_delegation_tools(agents) # Cast delegation_tools to the expected type for _merge_tools - return self._merge_tools(tools, cast(List[BaseTool], delegation_tools)) - return cast(List[BaseTool], tools) + return self._merge_tools(tools, cast(list[BaseTool], delegation_tools)) + return cast(list[BaseTool], tools) def _add_multimodal_tools( - self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]] - ) -> List[BaseTool]: + self, agent: BaseAgent, tools: list[Tool] | list[BaseTool] + ) -> list[BaseTool]: if hasattr(agent, "get_multimodal_tools"): multimodal_tools = agent.get_multimodal_tools() # Cast multimodal_tools to the expected type for _merge_tools - return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools)) - return cast(List[BaseTool], tools) + return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools)) + return cast(list[BaseTool], tools) def _add_code_execution_tools( - self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]] - ) -> List[BaseTool]: + self, agent: BaseAgent, tools: list[Tool] | list[BaseTool] + ) -> list[BaseTool]: if hasattr(agent, "get_code_execution_tools"): code_tools = agent.get_code_execution_tools() # Cast code_tools to the expected type for _merge_tools - return self._merge_tools(tools, cast(List[BaseTool], code_tools)) - return cast(List[BaseTool], tools) + return self._merge_tools(tools, cast(list[BaseTool], code_tools)) + return cast(list[BaseTool], tools) def _add_delegation_tools( - self, task: Task, tools: Union[List[Tool], List[BaseTool]] - ) -> List[BaseTool]: + self, task: Task, tools: list[Tool] | list[BaseTool] + ) -> list[BaseTool]: agents_for_delegation = [agent for agent in self.agents if agent != task.agent] if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent: if not tools: @@ -1015,7 +1047,7 @@ class Crew(FlowTrackable, BaseModel): tools = self._inject_delegation_tools( tools, task.agent, agents_for_delegation ) - return cast(List[BaseTool], tools) + return cast(list[BaseTool], tools) def _log_task_start(self, task: Task, role: str = "None"): if self.output_log_file: @@ -1024,8 +1056,8 @@ class Crew(FlowTrackable, BaseModel): ) def _update_manager_tools( - self, task: Task, tools: Union[List[Tool], List[BaseTool]] - ) -> List[BaseTool]: + self, task: Task, tools: list[Tool] | list[BaseTool] + ) -> list[BaseTool]: if self.manager_agent: if task.agent: tools = self._inject_delegation_tools(tools, task.agent, [task.agent]) @@ -1033,18 +1065,17 @@ class Crew(FlowTrackable, BaseModel): tools = self._inject_delegation_tools( tools, self.manager_agent, self.agents ) - return cast(List[BaseTool], tools) + return cast(list[BaseTool], tools) - def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str: + def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str: if not task.context: return "" - context = ( + return ( aggregate_raw_outputs_from_task_outputs(task_outputs) if task.context is NOT_SPECIFIED else aggregate_raw_outputs_from_tasks(task.context) ) - return context def _process_task_result(self, task: Task, output: TaskOutput) -> None: role = task.agent.role if task.agent is not None else "None" @@ -1057,7 +1088,7 @@ class Crew(FlowTrackable, BaseModel): output=output.raw, ) - def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput: + def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput: if not task_outputs: raise ValueError("No task outputs available to create crew output.") @@ -1088,10 +1119,10 @@ class Crew(FlowTrackable, BaseModel): def _process_async_tasks( self, - futures: List[Tuple[Task, Future[TaskOutput], int]], + futures: list[tuple[Task, Future[TaskOutput], int]], was_replayed: bool = False, - ) -> List[TaskOutput]: - task_outputs: List[TaskOutput] = [] + ) -> list[TaskOutput]: + task_outputs: list[TaskOutput] = [] for future_task, future, task_index in futures: task_output = future.result() task_outputs.append(task_output) @@ -1101,9 +1132,7 @@ class Crew(FlowTrackable, BaseModel): ) return task_outputs - def _find_task_index( - self, task_id: str, stored_outputs: List[Any] - ) -> Optional[int]: + def _find_task_index(self, task_id: str, stored_outputs: list[Any]) -> int | None: return next( ( index @@ -1113,9 +1142,8 @@ class Crew(FlowTrackable, BaseModel): None, ) - def replay( - self, task_id: str, inputs: Optional[Dict[str, Any]] = None - ) -> CrewOutput: + def replay(self, task_id: str, inputs: dict[str, Any] | None = None) -> CrewOutput: + """Replay the crew execution from a specific task.""" stored_outputs = self._task_output_handler.load() if not stored_outputs: raise ValueError(f"Task with id {task_id} not found in the crew's tasks.") @@ -1151,19 +1179,19 @@ class Crew(FlowTrackable, BaseModel): self.tasks[i].output = task_output self._logging_color = "bold_blue" - result = self._execute_tasks(self.tasks, start_index, True) - return result + return self._execute_tasks(self.tasks, start_index, True) def query_knowledge( - self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35 - ) -> Union[List[Dict[str, Any]], None]: + self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35 + ) -> list[SearchResult] | None: + """Query the crew's knowledge base for relevant information.""" if self.knowledge: return self.knowledge.query( query, results_limit=results_limit, score_threshold=score_threshold ) return None - def fetch_inputs(self) -> Set[str]: + def fetch_inputs(self) -> set[str]: """ Gathers placeholders (e.g., {something}) referenced in tasks or agents. Scans each task's 'description' + 'expected_output', and each agent's @@ -1172,11 +1200,11 @@ class Crew(FlowTrackable, BaseModel): Returns a set of all discovered placeholder names. """ placeholder_pattern = re.compile(r"\{(.+?)\}") - required_inputs: Set[str] = set() + required_inputs: set[str] = set() # Scan tasks for inputs for task in self.tasks: - # description and expected_output might contain e.g. {topic}, {user_name}, etc. + # description and expected_output might contain e.g. {topic}, {user_name} text = f"{task.description or ''} {task.expected_output or ''}" required_inputs.update(placeholder_pattern.findall(text)) @@ -1230,7 +1258,7 @@ class Crew(FlowTrackable, BaseModel): cloned_tasks.append(cloned_task) task_mapping[task.key] = cloned_task - for cloned_task, original_task in zip(cloned_tasks, self.tasks): + for cloned_task, original_task in zip(cloned_tasks, self.tasks, strict=False): if isinstance(original_task.context, list): cloned_context = [ task_mapping[context_task.key] @@ -1256,7 +1284,7 @@ class Crew(FlowTrackable, BaseModel): copied_data.pop("agents", None) copied_data.pop("tasks", None) - copied_crew = Crew( + return Crew( **copied_data, agents=cloned_agents, tasks=cloned_tasks, @@ -1266,15 +1294,13 @@ class Crew(FlowTrackable, BaseModel): manager_llm=manager_llm, ) - return copied_crew - def _set_tasks_callbacks(self) -> None: """Sets callback for every task suing task_callback""" for task in self.tasks: if not task.callback: task.callback = self.task_callback - def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None: + def _interpolate_inputs(self, inputs: dict[str, Any]) -> None: """Interpolates the inputs in the tasks and agents.""" [ task.interpolate_inputs_and_add_conversation_history( @@ -1307,10 +1333,13 @@ class Crew(FlowTrackable, BaseModel): def test( self, n_iterations: int, - eval_llm: Union[str, InstanceOf[BaseLLM]], - inputs: Optional[Dict[str, Any]] = None, + eval_llm: str | InstanceOf[BaseLLM], + inputs: dict[str, Any] | None = None, ) -> None: - """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures.""" + """Test and evaluate the Crew with the given inputs for n iterations. + + Uses concurrent.futures for concurrent execution. + """ try: # Create LLM instance and ensure it's of type LLM for CrewEvaluator llm_instance = create_llm(eval_llm) @@ -1350,7 +1379,11 @@ class Crew(FlowTrackable, BaseModel): raise def __repr__(self): - return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})" + return ( + f"Crew(id={self.id}, process={self.process}, " + f"number_of_agents={len(self.agents)}, " + f"number_of_tasks={len(self.tasks)})" + ) def reset_memories(self, command_type: str) -> None: """Reset specific or all memories for the crew. @@ -1364,7 +1397,7 @@ class Crew(FlowTrackable, BaseModel): ValueError: If an invalid command type is provided. RuntimeError: If memory reset operation fails. """ - VALID_TYPES = frozenset( + valid_types = frozenset( [ "long", "short", @@ -1377,9 +1410,10 @@ class Crew(FlowTrackable, BaseModel): ] ) - if command_type not in VALID_TYPES: + if command_type not in valid_types: raise ValueError( - f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}" + f"Invalid command type. Must be one of: " + f"{', '.join(sorted(valid_types))}" ) try: @@ -1389,7 +1423,7 @@ class Crew(FlowTrackable, BaseModel): self._reset_specific_memory(command_type) except Exception as e: - error_msg = f"Failed to reset {command_type} memory: {str(e)}" + error_msg = f"Failed to reset {command_type} memory: {e!s}" self._logger.log("error", error_msg) raise RuntimeError(error_msg) from e @@ -1397,7 +1431,7 @@ class Crew(FlowTrackable, BaseModel): """Reset all available memory systems.""" memory_systems = self._get_memory_systems() - for memory_type, config in memory_systems.items(): + for config in memory_systems.values(): if (system := config.get("system")) is not None: name = config.get("name") try: @@ -1405,11 +1439,13 @@ class Crew(FlowTrackable, BaseModel): reset_fn(system) self._logger.log( "info", - f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset", + f"[Crew ({self.name if self.name else self.id})] " + f"{name} memory has been reset", ) except Exception as e: raise RuntimeError( - f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}" + f"[Crew ({self.name if self.name else self.id})] " + f"Failed to reset {name} memory: {e!s}" ) from e def _reset_specific_memory(self, memory_type: str) -> None: @@ -1434,18 +1470,21 @@ class Crew(FlowTrackable, BaseModel): reset_fn(system) self._logger.log( "info", - f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset", + f"[Crew ({self.name if self.name else self.id})] " + f"{name} memory has been reset", ) except Exception as e: raise RuntimeError( - f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}" + f"[Crew ({self.name if self.name else self.id})] " + f"Failed to reset {name} memory: {e!s}" ) from e def _get_memory_systems(self): """Get all available memory systems with their configuration. Returns: - Dict containing all memory systems with their reset functions and display names. + Dict containing all memory systems with their reset functions and + display names. """ def default_reset(memory): @@ -1506,7 +1545,7 @@ class Crew(FlowTrackable, BaseModel): }, } - def reset_knowledge(self, knowledges: List[Knowledge]) -> None: + def reset_knowledge(self, knowledges: list[Knowledge]) -> None: """Reset crew and agent knowledge storage.""" for ks in knowledges: ks.reset() diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py index 2340dec90..3330ba6ce 100644 --- a/src/crewai/knowledge/knowledge.py +++ b/src/crewai/knowledge/knowledge.py @@ -1,10 +1,11 @@ import os -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage +from crewai.rag.types import SearchResult os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed @@ -13,23 +14,23 @@ class Knowledge(BaseModel): """ Knowledge is a collection of sources and setup for the vector store to save and query relevant context. Args: - sources: List[BaseKnowledgeSource] = Field(default_factory=list) - storage: Optional[KnowledgeStorage] = Field(default=None) - embedder: Optional[Dict[str, Any]] = None + sources: list[BaseKnowledgeSource] = Field(default_factory=list) + storage: KnowledgeStorage | None = Field(default=None) + embedder: dict[str, Any] | None = None """ - sources: List[BaseKnowledgeSource] = Field(default_factory=list) + sources: list[BaseKnowledgeSource] = Field(default_factory=list) model_config = ConfigDict(arbitrary_types_allowed=True) - storage: Optional[KnowledgeStorage] = Field(default=None) - embedder: Optional[Dict[str, Any]] = None - collection_name: Optional[str] = None + storage: KnowledgeStorage | None = Field(default=None) + embedder: dict[str, Any] | None = None + collection_name: str | None = None def __init__( self, collection_name: str, - sources: List[BaseKnowledgeSource], - embedder: Optional[Dict[str, Any]] = None, - storage: Optional[KnowledgeStorage] = None, + sources: list[BaseKnowledgeSource], + embedder: dict[str, Any] | None = None, + storage: KnowledgeStorage | None = None, **data, ): super().__init__(**data) @@ -40,11 +41,10 @@ class Knowledge(BaseModel): embedder=embedder, collection_name=collection_name ) self.sources = sources - self.storage.initialize_knowledge_storage() def query( - self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35 - ) -> List[Dict[str, Any]]: + self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35 + ) -> list[SearchResult]: """ Query across all knowledge sources to find the most relevant information. Returns the top_k most relevant chunks. @@ -55,12 +55,11 @@ class Knowledge(BaseModel): if self.storage is None: raise ValueError("Storage is not initialized.") - results = self.storage.search( + return self.storage.search( query, limit=results_limit, score_threshold=score_threshold, ) - return results def add_sources(self): try: diff --git a/src/crewai/knowledge/storage/base_knowledge_storage.py b/src/crewai/knowledge/storage/base_knowledge_storage.py index d4887e85b..376ed6612 100644 --- a/src/crewai/knowledge/storage/base_knowledge_storage.py +++ b/src/crewai/knowledge/storage/base_knowledge_storage.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any + +from crewai.rag.types import SearchResult class BaseKnowledgeStorage(ABC): @@ -8,22 +10,17 @@ class BaseKnowledgeStorage(ABC): @abstractmethod def search( self, - query: List[str], + query: list[str], limit: int = 3, - filter: Optional[dict] = None, + metadata_filter: dict[str, Any] | None = None, score_threshold: float = 0.35, - ) -> List[Dict[str, Any]]: + ) -> list[SearchResult]: """Search for documents in the knowledge base.""" - pass @abstractmethod - def save( - self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]] - ) -> None: + def save(self, documents: list[str]) -> None: """Save documents to the knowledge base.""" - pass @abstractmethod def reset(self) -> None: """Reset the knowledge base.""" - pass diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index 3629dc7ce..4aeb58e15 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -1,24 +1,16 @@ -import hashlib import logging -import os -import shutil -from typing import Any, Dict, List, Optional, Union - -import chromadb -import chromadb.errors -from chromadb.api import ClientAPI -from chromadb.api.types import OneOrMany -from chromadb.config import Settings import warnings +from typing import Any, cast from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage -from crewai.rag.embeddings.configurator import EmbeddingConfigurator -from crewai.utilities.chromadb import sanitize_collection_name -from crewai.utilities.constants import KNOWLEDGE_DIRECTORY +from crewai.rag.chromadb.config import ChromaDBConfig +from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper +from crewai.rag.config.utils import get_rag_client +from crewai.rag.core.base_client import BaseClient +from crewai.rag.embeddings.factory import get_embedding_function +from crewai.rag.factory import create_client +from crewai.rag.types import BaseRecord, SearchResult from crewai.utilities.logger import Logger -from crewai.utilities.paths import db_storage_path -from crewai.utilities.chromadb import create_persistent_client -from crewai.utilities.logger_utils import suppress_logging class KnowledgeStorage(BaseKnowledgeStorage): @@ -27,167 +19,101 @@ class KnowledgeStorage(BaseKnowledgeStorage): search efficiency. """ - collection: Optional[chromadb.Collection] = None - collection_name: Optional[str] = "knowledge" - app: Optional[ClientAPI] = None - def __init__( self, - embedder: Optional[Dict[str, Any]] = None, - collection_name: Optional[str] = None, - ): + embedder: dict[str, Any] | None = None, + collection_name: str | None = None, + ) -> None: self.collection_name = collection_name - self._set_embedder_config(embedder) + self._client: BaseClient | None = None - def search( - self, - query: List[str], - limit: int = 3, - filter: Optional[dict] = None, - score_threshold: float = 0.35, - ) -> List[Dict[str, Any]]: - with suppress_logging( - "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR - ): - if self.collection: - fetched = self.collection.query( - query_texts=query, - n_results=limit, - where=filter, - ) - results = [] - for i in range(len(fetched["ids"][0])): # type: ignore - result = { - "id": fetched["ids"][0][i], # type: ignore - "metadata": fetched["metadatas"][0][i], # type: ignore - "context": fetched["documents"][0][i], # type: ignore - "score": fetched["distances"][0][i], # type: ignore - } - if result["score"] >= score_threshold: - results.append(result) - return results - else: - raise Exception("Collection not initialized") - - def initialize_knowledge_storage(self): - # Suppress deprecation warnings from chromadb, which are not relevant to us - # TODO: Remove this once we upgrade chromadb to at least 1.0.8. warnings.filterwarnings( "ignore", message=r".*'model_fields'.*is deprecated.*", module=r"^chromadb(\.|$)", ) - self.app = create_persistent_client( - path=os.path.join(db_storage_path(), "knowledge"), - settings=Settings(allow_reset=True), - ) + if embedder: + embedding_function = get_embedding_function(embedder) + config = ChromaDBConfig( + embedding_function=cast( + ChromaEmbeddingFunctionWrapper, embedding_function + ) + ) + self._client = create_client(config) + def _get_client(self) -> BaseClient: + """Get the appropriate client - instance-specific or global.""" + return self._client if self._client else get_rag_client() + + def search( + self, + query: list[str], + limit: int = 3, + metadata_filter: dict[str, Any] | None = None, + score_threshold: float = 0.35, + ) -> list[SearchResult]: try: + if not query: + raise ValueError("Query cannot be empty") + + client = self._get_client() collection_name = ( f"knowledge_{self.collection_name}" if self.collection_name else "knowledge" ) - if self.app: - self.collection = self.app.get_or_create_collection( - name=sanitize_collection_name(collection_name), - embedding_function=self.embedder, - ) - else: - raise Exception("Vector Database Client not initialized") - except Exception: - raise Exception("Failed to create or get collection") + query_text = " ".join(query) if len(query) > 1 else query[0] - def reset(self): - base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY) - if not self.app: - self.app = create_persistent_client( - path=base_path, settings=Settings(allow_reset=True) + return client.search( + collection_name=collection_name, + query=query_text, + limit=limit, + metadata_filter=metadata_filter, + score_threshold=score_threshold, ) - - self.app.reset() - shutil.rmtree(base_path) - self.app = None - self.collection = None - - def save( - self, - documents: List[str], - metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, - ): - if not self.collection: - raise Exception("Collection not initialized") - - try: - # Create a dictionary to store unique documents - unique_docs = {} - - # Generate IDs and create a mapping of id -> (document, metadata) - for idx, doc in enumerate(documents): - doc_id = hashlib.sha256(doc.encode("utf-8")).hexdigest() - doc_metadata = None - if metadata is not None: - if isinstance(metadata, list): - doc_metadata = metadata[idx] - else: - doc_metadata = metadata - unique_docs[doc_id] = (doc, doc_metadata) - - # Prepare filtered lists for ChromaDB - filtered_docs = [] - filtered_metadata = [] - filtered_ids = [] - - # Build the filtered lists - for doc_id, (doc, meta) in unique_docs.items(): - filtered_docs.append(doc) - filtered_metadata.append(meta) - filtered_ids.append(doc_id) - - # If we have no metadata at all, set it to None - final_metadata: Optional[OneOrMany[chromadb.Metadata]] = ( - None if all(m is None for m in filtered_metadata) else filtered_metadata - ) - - self.collection.upsert( - documents=filtered_docs, - metadatas=final_metadata, - ids=filtered_ids, - ) - except chromadb.errors.InvalidDimensionException as e: - Logger(verbose=True).log( - "error", - "Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`", - "red", - ) - raise ValueError( - "Embedding dimension mismatch. Make sure you're using the same embedding model " - "across all operations with this collection." - "Try resetting the collection using `crewai reset-memories -a`" - ) from e except Exception as e: + logging.error(f"Error during knowledge search: {e!s}") + return [] + + def reset(self) -> None: + try: + client = self._get_client() + collection_name = ( + f"knowledge_{self.collection_name}" + if self.collection_name + else "knowledge" + ) + client.delete_collection(collection_name=collection_name) + except Exception as e: + logging.error(f"Error during knowledge reset: {e!s}") + + def save(self, documents: list[str]) -> None: + try: + client = self._get_client() + collection_name = ( + f"knowledge_{self.collection_name}" + if self.collection_name + else "knowledge" + ) + client.get_or_create_collection(collection_name=collection_name) + + rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents] + + client.add_documents( + collection_name=collection_name, documents=rag_documents + ) + except Exception as e: + if "dimension mismatch" in str(e).lower(): + Logger(verbose=True).log( + "error", + "Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`", + "red", + ) + raise ValueError( + "Embedding dimension mismatch. Make sure you're using the same embedding model " + "across all operations with this collection." + "Try resetting the collection using `crewai reset-memories -a`" + ) from e Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red") raise - - def _create_default_embedding_function(self): - from chromadb.utils.embedding_functions.openai_embedding_function import ( - OpenAIEmbeddingFunction, - ) - - return OpenAIEmbeddingFunction( - api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" - ) - - def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None: - """Set the embedding configuration for the knowledge storage. - - Args: - embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder. - If None or empty, defaults to the default embedding function. - """ - self.embedder = ( - EmbeddingConfigurator().configure_embedder(embedder) - if embedder - else self._create_default_embedding_function() - ) diff --git a/src/crewai/knowledge/utils/knowledge_utils.py b/src/crewai/knowledge/utils/knowledge_utils.py index bdd8b9a4e..98f2af197 100644 --- a/src/crewai/knowledge/utils/knowledge_utils.py +++ b/src/crewai/knowledge/utils/knowledge_utils.py @@ -1,12 +1,12 @@ -from typing import Any, Dict, List +from crewai.rag.types import SearchResult -def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str: +def extract_knowledge_context(knowledge_snippets: list[SearchResult]) -> str: """Extract knowledge from the task prompt.""" valid_snippets = [ - result["context"] + result["content"] for result in knowledge_snippets - if result and result.get("context") + if result and result.get("content") ] snippet = "\n".join(valid_snippets) return f"Additional Information: {snippet}" if valid_snippets else "" diff --git a/src/crewai/memory/contextual/contextual_memory.py b/src/crewai/memory/contextual/contextual_memory.py index 3a0f86b70..ba7906ae1 100644 --- a/src/crewai/memory/contextual/contextual_memory.py +++ b/src/crewai/memory/contextual/contextual_memory.py @@ -1,4 +1,6 @@ -from typing import Optional, TYPE_CHECKING +from __future__ import annotations + +from typing import TYPE_CHECKING from crewai.memory import ( EntityMemory, @@ -19,9 +21,9 @@ class ContextualMemory: ltm: LongTermMemory, em: EntityMemory, exm: ExternalMemory, - agent: Optional["Agent"] = None, - task: Optional["Task"] = None, - ): + agent: Agent | None = None, + task: Task | None = None, + ) -> None: self.stm = stm self.ltm = ltm self.em = em @@ -42,7 +44,7 @@ class ContextualMemory: self.exm.agent = self.agent self.exm.task = self.task - def build_context_for_task(self, task, context) -> str: + def build_context_for_task(self, task: Task, context: str) -> str: """ Automatically builds a minimal, highly relevant set of contextual information for a given task. @@ -52,14 +54,15 @@ class ContextualMemory: if query == "": return "" - context = [] - context.append(self._fetch_ltm_context(task.description)) - context.append(self._fetch_stm_context(query)) - context.append(self._fetch_entity_context(query)) - context.append(self._fetch_external_context(query)) - return "\n".join(filter(None, context)) + context_parts = [ + self._fetch_ltm_context(task.description), + self._fetch_stm_context(query), + self._fetch_entity_context(query), + self._fetch_external_context(query), + ] + return "\n".join(filter(None, context_parts)) - def _fetch_stm_context(self, query) -> str: + def _fetch_stm_context(self, query: str) -> str: """ Fetches recent relevant insights from STM related to the task's description and expected_output, formatted as bullet points. @@ -70,11 +73,11 @@ class ContextualMemory: stm_results = self.stm.search(query) formatted_results = "\n".join( - [f"- {result['context']}" for result in stm_results] + [f"- {result['content']}" for result in stm_results] ) return f"Recent Insights:\n{formatted_results}" if stm_results else "" - def _fetch_ltm_context(self, task) -> Optional[str]: + def _fetch_ltm_context(self, task: str) -> str | None: """ Fetches historical data or insights from LTM that are relevant to the task's description and expected_output, formatted as bullet points. @@ -90,14 +93,14 @@ class ContextualMemory: formatted_results = [ suggestion for result in ltm_results - for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" + for suggestion in result["metadata"]["suggestions"] ] formatted_results = list(dict.fromkeys(formatted_results)) formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]") return f"Historical Data:\n{formatted_results}" if ltm_results else "" - def _fetch_entity_context(self, query) -> str: + def _fetch_entity_context(self, query: str) -> str: """ Fetches relevant entity information from Entity Memory related to the task's description and expected_output, formatted as bullet points. @@ -107,7 +110,7 @@ class ContextualMemory: em_results = self.em.search(query) formatted_results = "\n".join( - [f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" + [f"- {result['content']}" for result in em_results] ) return f"Entities:\n{formatted_results}" if em_results else "" @@ -128,6 +131,6 @@ class ContextualMemory: return "" formatted_memories = "\n".join( - f"- {result['context']}" for result in external_memories + f"- {result['content']}" for result in external_memories ) return f"External memories:\n{formatted_memories}" diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index 020e0058e..128aa6ed8 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -1,12 +1,13 @@ import os import re from collections import defaultdict -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any -from mem0 import Memory, MemoryClient # type: ignore[import-untyped] +from mem0 import Memory, MemoryClient # type: ignore[import-untyped,import-not-found] from crewai.memory.storage.interface import Storage -from crewai.utilities.chromadb import sanitize_collection_name +from crewai.rag.chromadb.utils import _sanitize_collection_name MAX_AGENT_ID_LENGTH_MEM0 = 255 @@ -15,6 +16,7 @@ class Mem0Storage(Storage): """ Extends Storage to handle embedding and searching across entities using Mem0. """ + def __init__(self, type, crew=None, config=None): super().__init__() @@ -30,7 +32,8 @@ class Mem0Storage(Storage): supported_types = {"short_term", "long_term", "entities", "external"} if type not in supported_types: raise ValueError( - f"Invalid type '{type}' for Mem0Storage. Must be one of: {', '.join(supported_types)}" + f"Invalid type '{type}' for Mem0Storage. " + f"Must be one of: {', '.join(supported_types)}" ) def _extract_config_values(self): @@ -68,7 +71,8 @@ class Mem0Storage(Storage): - Includes user_id and agent_id if both are present. - Includes user_id if only user_id is present. - Includes agent_id if only agent_id is present. - - Includes run_id if memory_type is 'short_term' and mem0_run_id is present. + - Includes run_id if memory_type is 'short_term' and + mem0_run_id is present. """ filter = defaultdict(list) @@ -91,10 +95,14 @@ class Mem0Storage(Storage): def save(self, value: Any, metadata: dict[str, Any]) -> None: def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str: return next( - (m.get("content", "") for m in reversed(list(messages)) if m.get("role") == role), - "" + ( + m.get("content", "") + for m in reversed(list(messages)) + if m.get("role") == role + ), + "", ) - + conversations = [] messages = metadata.pop("messages", None) if messages: @@ -103,7 +111,7 @@ class Mem0Storage(Storage): if user_msg := self._get_user_message(last_user): conversations.append({"role": "user", "content": user_msg}) - + if assistant_msg := self._get_assistant_message(last_assistant): conversations.append({"role": "assistant", "content": assistant_msg}) else: @@ -115,13 +123,13 @@ class Mem0Storage(Storage): "short_term": "short_term", "long_term": "long_term", "entities": "entity", - "external": "external" + "external": "external", } # Shared base params params: dict[str, Any] = { "metadata": {"type": base_metadata[self.memory_type], **metadata}, - "infer": self.infer + "infer": self.infer, } # MemoryClient-specific overrides @@ -142,13 +150,15 @@ class Mem0Storage(Storage): self.memory.add(conversations, **params) - def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> list[Any]: + def search( + self, query: str, limit: int = 3, score_threshold: float = 0.35 + ) -> list[Any]: params = { "query": query, "limit": limit, "version": "v2", - "output_format": "v1.1" - } + "output_format": "v1.1", + } if user_id := self.config.get("user_id", ""): params["user_id"] = user_id @@ -169,10 +179,10 @@ class Mem0Storage(Storage): # automatically when the crew is created. params["filters"] = self._create_filter_for_search() - params['threshold'] = score_threshold + params["threshold"] = score_threshold if isinstance(self.memory, Memory): - del params["metadata"], params["version"], params['output_format'] + del params["metadata"], params["version"], params["output_format"] if params.get("run_id"): del params["run_id"] @@ -180,7 +190,7 @@ class Mem0Storage(Storage): # This makes it compatible for Contextual Memory to retrieve for result in results["results"]: - result["context"] = result["memory"] + result["content"] = result["memory"] return [r for r in results["results"]] @@ -201,7 +211,9 @@ class Mem0Storage(Storage): agents = self.crew.agents agents = [self._sanitize_role(agent.role) for agent in agents] agents = "_".join(agents) - return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0) + return _sanitize_collection_name( + name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0 + ) def _get_assistant_message(self, text: str) -> str: marker = "Final Answer:" diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index 504da2fce..b52ec384e 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -1,17 +1,16 @@ import logging -import os -import shutil -import uuid +import warnings +from typing import Any -from typing import Any, Dict, List, Optional -from chromadb.api import ClientAPI +from crewai.rag.chromadb.config import ChromaDBConfig +from crewai.rag.config.utils import get_rag_client +from crewai.rag.core.base_client import BaseClient +from crewai.rag.embeddings.factory import get_embedding_function +from crewai.rag.factory import create_client from crewai.rag.storage.base_rag_storage import BaseRAGStorage -from crewai.rag.embeddings.configurator import EmbeddingConfigurator -from crewai.utilities.chromadb import create_persistent_client +from crewai.rag.types import BaseRecord from crewai.utilities.constants import MAX_FILE_NAME_LENGTH from crewai.utilities.paths import db_storage_path -from crewai.utilities.logger_utils import suppress_logging -import warnings class RAGStorage(BaseRAGStorage): @@ -20,8 +19,6 @@ class RAGStorage(BaseRAGStorage): search efficiency. """ - app: ClientAPI | None = None - def __init__( self, type, allow_reset=True, embedder_config=None, crew=None, path=None ): @@ -33,37 +30,25 @@ class RAGStorage(BaseRAGStorage): self.storage_file_name = self._build_storage_file_name(type, agents) self.type = type + self._client: BaseClient | None = None self.allow_reset = allow_reset self.path = path - self._initialize_app() - def _set_embedder_config(self): - configurator = EmbeddingConfigurator() - self.embedder_config = configurator.configure_embedder(self.embedder_config) - - def _initialize_app(self): - from chromadb.config import Settings - - # Suppress deprecation warnings from chromadb, which are not relevant to us - # TODO: Remove this once we upgrade chromadb to at least 1.0.8. warnings.filterwarnings( "ignore", message=r".*'model_fields'.*is deprecated.*", module=r"^chromadb(\.|$)", ) - self._set_embedder_config() + if self.embedder_config: + embedding_function = get_embedding_function(self.embedder_config) + config = ChromaDBConfig(embedding_function=embedding_function) + self._client = create_client(config) - self.app = create_persistent_client( - path=self.path if self.path else self.storage_file_name, - settings=Settings(allow_reset=self.allow_reset), - ) - - self.collection = self.app.get_or_create_collection( - name=self.type, embedding_function=self.embedder_config - ) - logging.info(f"Collection found or created: {self.collection}") + def _get_client(self) -> BaseClient: + """Get the appropriate client - instance-specific or global.""" + return self._client if self._client else get_rag_client() def _sanitize_role(self, role: str) -> str: """ @@ -85,77 +70,65 @@ class RAGStorage(BaseRAGStorage): return f"{base_path}/{file_name}" - def save(self, value: Any, metadata: Dict[str, Any]) -> None: - if not hasattr(self, "app") or not hasattr(self, "collection"): - self._initialize_app() + def save(self, value: Any, metadata: dict[str, Any]) -> None: try: - self._generate_embedding(value, metadata) + client = self._get_client() + collection_name = ( + f"memory_{self.type}_{self.agents}" + if self.agents + else f"memory_{self.type}" + ) + client.get_or_create_collection(collection_name=collection_name) + + document: BaseRecord = {"content": value} + if metadata: + document["metadata"] = metadata + + client.add_documents(collection_name=collection_name, documents=[document]) except Exception as e: - logging.error(f"Error during {self.type} save: {str(e)}") + logging.error(f"Error during {self.type} save: {e!s}") def search( self, query: str, limit: int = 3, - filter: Optional[dict] = None, + filter: dict[str, Any] | None = None, score_threshold: float = 0.35, - ) -> List[Any]: - if not hasattr(self, "app"): - self._initialize_app() - + ) -> list[Any]: try: - with suppress_logging( - "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR - ): - response = self.collection.query(query_texts=query, n_results=limit) - - results = [] - for i in range(len(response["ids"][0])): - result = { - "id": response["ids"][0][i], - "metadata": response["metadatas"][0][i], - "context": response["documents"][0][i], - "score": response["distances"][0][i], - } - if result["score"] >= score_threshold: - results.append(result) - - return results + client = self._get_client() + collection_name = ( + f"memory_{self.type}_{self.agents}" + if self.agents + else f"memory_{self.type}" + ) + return client.search( + collection_name=collection_name, + query=query, + limit=limit, + metadata_filter=filter, + score_threshold=score_threshold, + ) except Exception as e: - logging.error(f"Error during {self.type} search: {str(e)}") + logging.error(f"Error during {self.type} search: {e!s}") return [] - def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore - if not hasattr(self, "app") or not hasattr(self, "collection"): - self._initialize_app() - - self.collection.add( - documents=[text], - metadatas=[metadata or {}], - ids=[str(uuid.uuid4())], - ) - def reset(self) -> None: try: - if self.app: - self.app.reset() - shutil.rmtree(f"{db_storage_path()}/{self.type}") - self.app = None - self.collection = None + client = self._get_client() + collection_name = ( + f"memory_{self.type}_{self.agents}" + if self.agents + else f"memory_{self.type}" + ) + client.delete_collection(collection_name=collection_name) except Exception as e: - if "attempt to write a readonly database" in str(e): - # Ignore this specific error + if "attempt to write a readonly database" in str( + e + ) or "does not exist" in str(e): + # Ignore readonly database and collection not found errors (already reset) pass else: raise Exception( f"An error occurred while resetting the {self.type} memory: {e}" - ) - - def _create_default_embedding_function(self): - from chromadb.utils.embedding_functions.openai_embedding_function import ( - OpenAIEmbeddingFunction, - ) - - return OpenAIEmbeddingFunction( - api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" - ) + ) from e diff --git a/src/crewai/rag/chromadb/client.py b/src/crewai/rag/chromadb/client.py index ca3ae62b8..3a9d140d4 100644 --- a/src/crewai/rag/chromadb/client.py +++ b/src/crewai/rag/chromadb/client.py @@ -4,8 +4,9 @@ import logging from typing import Any from chromadb.api.types import ( - Embeddable, EmbeddingFunction as ChromaEmbeddingFunction, +) +from chromadb.api.types import ( QueryResult, ) from typing_extensions import Unpack @@ -23,13 +24,13 @@ from crewai.rag.chromadb.utils import ( _process_query_results, _sanitize_collection_name, ) -from crewai.utilities.logger_utils import suppress_logging from crewai.rag.core.base_client import ( BaseClient, - BaseCollectionParams, BaseCollectionAddParams, + BaseCollectionParams, ) from crewai.rag.types import SearchResult +from crewai.utilities.logger_utils import suppress_logging class ChromaDBClient(BaseClient): @@ -46,7 +47,7 @@ class ChromaDBClient(BaseClient): def __init__( self, client: ChromaDBClientType, - embedding_function: ChromaEmbeddingFunction[Embeddable], + embedding_function: ChromaEmbeddingFunction, ) -> None: """Initialize ChromaDBClient with client and embedding function. @@ -306,10 +307,12 @@ class ChromaDBClient(BaseClient): ) prepared = _prepare_documents_for_chromadb(documents) + # ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty + metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None collection.upsert( ids=prepared.ids, documents=prepared.texts, - metadatas=prepared.metadatas, + metadatas=metadatas, ) async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: @@ -347,10 +350,12 @@ class ChromaDBClient(BaseClient): embedding_function=self.embedding_function, ) prepared = _prepare_documents_for_chromadb(documents) + # ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty + metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None await collection.upsert( ids=prepared.ids, documents=prepared.texts, - metadatas=prepared.metadatas, + metadatas=metadatas, ) def search( diff --git a/src/crewai/rag/chromadb/config.py b/src/crewai/rag/chromadb/config.py index 33a3ed9ae..033f8ff32 100644 --- a/src/crewai/rag/chromadb/config.py +++ b/src/crewai/rag/chromadb/config.py @@ -3,18 +3,18 @@ import warnings from dataclasses import field from typing import Literal, cast -from pydantic.dataclasses import dataclass as pyd_dataclass + from chromadb.config import Settings from chromadb.utils.embedding_functions import DefaultEmbeddingFunction +from pydantic.dataclasses import dataclass as pyd_dataclass -from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper -from crewai.rag.config.base import BaseRagConfig from crewai.rag.chromadb.constants import ( - DEFAULT_TENANT, DEFAULT_DATABASE, DEFAULT_STORAGE_PATH, + DEFAULT_TENANT, ) - +from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper +from crewai.rag.config.base import BaseRagConfig warnings.filterwarnings( "ignore", diff --git a/src/crewai/rag/chromadb/factory.py b/src/crewai/rag/chromadb/factory.py index 60bf69131..44def6495 100644 --- a/src/crewai/rag/chromadb/factory.py +++ b/src/crewai/rag/chromadb/factory.py @@ -2,11 +2,12 @@ import os from hashlib import md5 + import portalocker from chromadb import PersistentClient -from crewai.rag.chromadb.config import ChromaDBConfig from crewai.rag.chromadb.client import ChromaDBClient +from crewai.rag.chromadb.config import ChromaDBConfig def create_client(config: ChromaDBConfig) -> ChromaDBClient: @@ -23,6 +24,7 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient: """ persist_dir = config.settings.persist_directory + os.makedirs(persist_dir, exist_ok=True) lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest() lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock") diff --git a/src/crewai/rag/chromadb/types.py b/src/crewai/rag/chromadb/types.py index 11c480ea3..23db5b77a 100644 --- a/src/crewai/rag/chromadb/types.py +++ b/src/crewai/rag/chromadb/types.py @@ -3,27 +3,28 @@ from collections.abc import Mapping from typing import Any, NamedTuple -from pydantic import GetCoreSchemaHandler -from pydantic_core import CoreSchema, core_schema -from chromadb.api import ClientAPI, AsyncClientAPI +from chromadb.api import AsyncClientAPI, ClientAPI from chromadb.api.configuration import CollectionConfigurationInterface from chromadb.api.types import ( CollectionMetadata, DataLoader, - Embeddable, - EmbeddingFunction as ChromaEmbeddingFunction, Include, Loadable, Where, WhereDocument, ) +from chromadb.api.types import ( + EmbeddingFunction as ChromaEmbeddingFunction, +) +from pydantic import GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams ChromaDBClientType = ClientAPI | AsyncClientAPI -class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]): +class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction): """Base class for ChromaDB EmbeddingFunction to work with Pydantic validation.""" @classmethod @@ -44,7 +45,7 @@ class PreparedDocuments(NamedTuple): Attributes: ids: List of document IDs texts: List of document texts - metadatas: List of document metadata mappings + metadatas: List of document metadata mappings (empty dict for no metadata) """ ids: list[str] @@ -85,7 +86,7 @@ class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False): configuration: CollectionConfigurationInterface metadata: CollectionMetadata - embedding_function: ChromaEmbeddingFunction[Embeddable] + embedding_function: ChromaEmbeddingFunction data_loader: DataLoader[Loadable] get_or_create: bool diff --git a/src/crewai/rag/chromadb/utils.py b/src/crewai/rag/chromadb/utils.py index 23f66f4c0..93865b203 100644 --- a/src/crewai/rag/chromadb/utils.py +++ b/src/crewai/rag/chromadb/utils.py @@ -5,13 +5,14 @@ from collections.abc import Mapping from typing import Literal, TypeGuard, cast from chromadb.api import AsyncClientAPI, ClientAPI +from chromadb.api.models.AsyncCollection import AsyncCollection +from chromadb.api.models.Collection import Collection from chromadb.api.types import ( Include, IncludeEnum, QueryResult, ) -from chromadb.api.models.AsyncCollection import AsyncCollection -from chromadb.api.models.Collection import Collection + from crewai.rag.chromadb.constants import ( DEFAULT_COLLECTION, INVALID_CHARS_PATTERN, @@ -78,7 +79,7 @@ def _prepare_documents_for_chromadb( metadata = doc.get("metadata") if metadata: if isinstance(metadata, list): - metadatas.append(metadata[0] if metadata else {}) + metadatas.append(metadata[0] if metadata and metadata[0] else {}) else: metadatas.append(metadata) else: @@ -154,7 +155,7 @@ def _convert_chromadb_results_to_search_results( """ search_results: list[SearchResult] = [] - include_strings = [item.value for item in include] + include_strings = [item.value for item in include] if include else [] ids = results["ids"][0] if results.get("ids") else [] @@ -188,7 +189,9 @@ def _convert_chromadb_results_to_search_results( result: SearchResult = { "id": doc_id, "content": documents[i] if documents and i < len(documents) else "", - "metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {}, + "metadata": dict(metadatas[i]) + if metadatas and i < len(metadatas) and metadatas[i] is not None + else {}, "score": score, } search_results.append(result) @@ -271,7 +274,7 @@ def _sanitize_collection_name( sanitized = sanitized[:-1] + "z" if len(sanitized) < MIN_COLLECTION_LENGTH: - sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized)) + sanitized += "x" * (MIN_COLLECTION_LENGTH - len(sanitized)) if len(sanitized) > max_collection_length: sanitized = sanitized[:max_collection_length] if not sanitized[-1].isalnum(): diff --git a/src/crewai/rag/core/base_client.py b/src/crewai/rag/core/base_client.py index d7fb48a50..f526d2faa 100644 --- a/src/crewai/rag/core/base_client.py +++ b/src/crewai/rag/core/base_client.py @@ -1,15 +1,15 @@ """Protocol for vector database client implementations.""" from abc import abstractmethod -from typing import Any, Protocol, runtime_checkable, Annotated -from typing_extensions import Unpack, Required, TypedDict +from typing import Annotated, Any, Protocol, runtime_checkable + from pydantic import GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema - +from typing_extensions import Required, TypedDict, Unpack from crewai.rag.types import ( - EmbeddingFunction, BaseRecord, + EmbeddingFunction, SearchResult, ) @@ -57,7 +57,7 @@ class BaseCollectionSearchParams(BaseCollectionParams, total=False): query: Required[str] limit: int - metadata_filter: dict[str, Any] + metadata_filter: dict[str, Any] | None score_threshold: float diff --git a/src/crewai/rag/embeddings/factory.py b/src/crewai/rag/embeddings/factory.py index ff3a78c17..0b76ef36a 100644 --- a/src/crewai/rag/embeddings/factory.py +++ b/src/crewai/rag/embeddings/factory.py @@ -10,8 +10,8 @@ from chromadb.utils.embedding_functions.cohere_embedding_function import ( CohereEmbeddingFunction, ) from chromadb.utils.embedding_functions.google_embedding_function import ( - GooglePalmEmbeddingFunction, GoogleGenerativeAiEmbeddingFunction, + GooglePalmEmbeddingFunction, GoogleVertexEmbeddingFunction, ) from chromadb.utils.embedding_functions.huggingface_embedding_function import ( @@ -60,7 +60,7 @@ def get_embedding_function( EmbeddingFunction instance ready for use with ChromaDB Supported providers: - - openai: OpenAI embeddings (default) + - openai: OpenAI embeddings - cohere: Cohere embeddings - ollama: Ollama local embeddings - huggingface: HuggingFace embeddings @@ -77,7 +77,7 @@ def get_embedding_function( - onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB) Examples: - # Use default OpenAI with retry logic + # Use default OpenAI embedding >>> embedder = get_embedding_function() # Use Cohere with dict diff --git a/src/crewai/rag/storage/base_rag_storage.py b/src/crewai/rag/storage/base_rag_storage.py index 4ab9acb99..36b4020b7 100644 --- a/src/crewai/rag/storage/base_rag_storage.py +++ b/src/crewai/rag/storage/base_rag_storage.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any class BaseRAGStorage(ABC): @@ -13,7 +13,7 @@ class BaseRAGStorage(ABC): self, type: str, allow_reset: bool = True, - embedder_config: Optional[Dict[str, Any]] = None, + embedder_config: dict[str, Any] | None = None, crew: Any = None, ): self.type = type @@ -32,45 +32,21 @@ class BaseRAGStorage(ABC): @abstractmethod def _sanitize_role(self, role: str) -> str: """Sanitizes agent roles to ensure valid directory names.""" - pass @abstractmethod - def save(self, value: Any, metadata: Dict[str, Any]) -> None: + def save(self, value: Any, metadata: dict[str, Any]) -> None: """Save a value with metadata to the storage.""" - pass @abstractmethod def search( self, query: str, limit: int = 3, - filter: Optional[dict] = None, + filter: dict[str, Any] | None = None, score_threshold: float = 0.35, - ) -> List[Any]: + ) -> list[Any]: """Search for entries in the storage.""" - pass @abstractmethod def reset(self) -> None: """Reset the storage.""" - pass - - @abstractmethod - def _generate_embedding( - self, text: str, metadata: Optional[Dict[str, Any]] = None - ) -> Any: - """Generate an embedding for the given text and metadata.""" - pass - - @abstractmethod - def _initialize_app(self): - """Initialize the vector db.""" - pass - - def setup_config(self, config: Dict[str, Any]): - """Setup the config of the storage.""" - pass - - def initialize_client(self): - """Initialize the client of the storage. This should setup the app and the db collection""" - pass diff --git a/src/crewai/utilities/chromadb.py b/src/crewai/utilities/chromadb.py deleted file mode 100644 index 60da9988d..000000000 --- a/src/crewai/utilities/chromadb.py +++ /dev/null @@ -1,83 +0,0 @@ -import os -import re -import portalocker -from chromadb import PersistentClient -from hashlib import md5 -from typing import Optional -from crewai.utilities.paths import db_storage_path - -MIN_COLLECTION_LENGTH = 3 -MAX_COLLECTION_LENGTH = 63 -DEFAULT_COLLECTION = "default_collection" - -# Compiled regex patterns for better performance -INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]") -IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$") - - -def is_ipv4_pattern(name: str) -> bool: - """ - Check if a string matches an IPv4 address pattern. - - Args: - name: The string to check - - Returns: - True if the string matches an IPv4 pattern, False otherwise - """ - return bool(IPV4_PATTERN.match(name)) - - -def sanitize_collection_name( - name: Optional[str], max_collection_length: int = MAX_COLLECTION_LENGTH -) -> str: - """ - Sanitize a collection name to meet ChromaDB requirements: - 1. 3-63 characters long - 2. Starts and ends with alphanumeric character - 3. Contains only alphanumeric characters, underscores, or hyphens - 4. No consecutive periods - 5. Not a valid IPv4 address - - Args: - name: The original collection name to sanitize - - Returns: - A sanitized collection name that meets ChromaDB requirements - """ - if not name: - return DEFAULT_COLLECTION - - if is_ipv4_pattern(name): - name = f"ip_{name}" - - sanitized = INVALID_CHARS_PATTERN.sub("_", name) - - if not sanitized[0].isalnum(): - sanitized = "a" + sanitized - - if not sanitized[-1].isalnum(): - sanitized = sanitized[:-1] + "z" - - if len(sanitized) < MIN_COLLECTION_LENGTH: - sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized)) - if len(sanitized) > max_collection_length: - sanitized = sanitized[:max_collection_length] - if not sanitized[-1].isalnum(): - sanitized = sanitized[:-1] + "z" - - return sanitized - - -def create_persistent_client(path: str, **kwargs): - """ - Creates a persistent client for ChromaDB with a lock file to prevent - concurrent creations. Works for both multi-threads and multi-processes - environments. - """ - lock_id = md5(path.encode(), usedforsecurity=False).hexdigest() - lockfile = os.path.join(db_storage_path(), f"chromadb-{lock_id}.lock") - with portalocker.Lock(lockfile): - client = PersistentClient(path=path, **kwargs) - - return client diff --git a/tests/agents/test_agent.py b/tests/agents/test_agent.py index a52888069..4457f69cf 100644 --- a/tests/agents/test_agent.py +++ b/tests/agents/test_agent.py @@ -9,19 +9,19 @@ import pytest from crewai import Agent, Crew, Task from crewai.agents.cache import CacheHandler from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge_config import KnowledgeConfig from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource from crewai.llm import LLM +from crewai.process import Process from crewai.tools import tool from crewai.tools.tool_calling import InstructorToolCalling from crewai.tools.tool_usage import ToolUsage from crewai.utilities import RPMController from crewai.utilities.errors import AgentRepositoryError -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent -from crewai.process import Process def test_agent_llm_creation_with_env_vars(): @@ -445,7 +445,7 @@ def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool(): @pytest.mark.vcr(filter_headers=["authorization"]) def test_agent_powered_by_new_o_model_family_that_uses_tool(): @tool - def comapny_customer_data() -> float: + def comapny_customer_data() -> str: """Useful for getting customer related data.""" return "The company has 42 customers" @@ -559,9 +559,9 @@ def test_agent_repeated_tool_usage(capsys): expected_message = ( "I tried reusing the same input, I must stop using this action input." ) - assert ( - expected_message in output - ), f"Expected message not found in output. Output was: {output}" + assert expected_message in output, ( + f"Expected message not found in output. Output was: {output}" + ) @pytest.mark.vcr(filter_headers=["authorization"]) @@ -602,9 +602,9 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys): has_max_iterations = "maximum iterations reached" in output_lower has_final_answer = "final answer" in output_lower or "42" in captured.out - assert ( - has_repeated_usage_message or (has_max_iterations and has_final_answer) - ), f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..." + assert has_repeated_usage_message or (has_max_iterations and has_final_answer), ( + f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..." + ) @pytest.mark.vcr(filter_headers=["authorization"]) @@ -880,7 +880,7 @@ def test_agent_step_callback(): with patch.object(StepCallback, "callback") as callback: @tool - def learn_about_AI() -> str: + def learn_about_ai() -> str: """Useful for when you need to learn about AI to write an paragraph about it.""" return "AI is a very broad field." @@ -888,7 +888,7 @@ def test_agent_step_callback(): role="test role", goal="test goal", backstory="test backstory", - tools=[learn_about_AI], + tools=[learn_about_ai], step_callback=StepCallback().callback, ) @@ -910,7 +910,7 @@ def test_agent_function_calling_llm(): llm = "gpt-4o" @tool - def learn_about_AI() -> str: + def learn_about_ai() -> str: """Useful for when you need to learn about AI to write an paragraph about it.""" return "AI is a very broad field." @@ -918,7 +918,7 @@ def test_agent_function_calling_llm(): role="test role", goal="test goal", backstory="test backstory", - tools=[learn_about_AI], + tools=[learn_about_ai], llm="gpt-4o", max_iter=2, function_calling_llm=llm, @@ -1356,7 +1356,7 @@ def test_agent_training_handler(crew_training_handler): verbose=True, ) crew_training_handler().load.return_value = { - f"{str(agent.id)}": {"0": {"human_feedback": "good"}} + f"{agent.id!s}": {"0": {"human_feedback": "good"}} } result = agent._training_handler(task_prompt=task_prompt) @@ -1473,7 +1473,7 @@ def test_agent_with_custom_stop_words(): ) assert isinstance(agent.llm, LLM) - assert set(agent.llm.stop) == set(stop_words + ["\nObservation:"]) + assert set(agent.llm.stop) == set([*stop_words, "\nObservation:"]) assert all(word in agent.llm.stop for word in stop_words) assert "\nObservation:" in agent.llm.stop @@ -1530,7 +1530,7 @@ def test_llm_call_with_error(): llm = LLM(model="non-existent-model") messages = [{"role": "user", "content": "This should fail"}] - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017 llm.call(messages) @@ -1830,11 +1830,11 @@ def test_agent_execute_task_with_ollama(): def test_agent_with_knowledge_sources(): content = "Brandon's favorite color is red and he likes Mexican food." string_source = StringKnowledgeSource(content=content) - with patch("crewai.knowledge") as MockKnowledge: - mock_knowledge_instance = MockKnowledge.return_value + with patch("crewai.knowledge") as mock_knowledge: + mock_knowledge_instance = mock_knowledge.return_value mock_knowledge_instance.sources = [string_source] mock_knowledge_instance.search.return_value = [{"content": content}] - MockKnowledge.add_sources.return_value = [string_source] + mock_knowledge.add_sources.return_value = [string_source] agent = Agent( role="Information Agent", @@ -1863,12 +1863,25 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold(): content = "Brandon's favorite color is red and he likes Mexican food." string_source = StringKnowledgeSource(content=content) knowledge_config = KnowledgeConfig(results_limit=10, score_threshold=0.5) - with patch( - "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" - ) as MockKnowledge: - mock_knowledge_instance = MockKnowledge.return_value - mock_knowledge_instance.sources = [string_source] - mock_knowledge_instance.query.return_value = [{"content": content}] + with ( + patch( + "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" + ) as mock_knowledge_storage, + patch( + "crewai.knowledge.source.base_knowledge_source.KnowledgeStorage" + ) as mock_base_knowledge_storage, + patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb, + ): + mock_storage_instance = mock_knowledge_storage.return_value + mock_storage_instance.sources = [string_source] + mock_storage_instance.query.return_value = [{"content": content}] + mock_storage_instance.save.return_value = None + + mock_chromadb_instance = mock_chromadb.return_value + mock_chromadb_instance.add_documents.return_value = None + + mock_base_knowledge_storage.return_value = mock_storage_instance + with patch.object(Knowledge, "query") as mock_knowledge_query: agent = Agent( role="Information Agent", @@ -1898,15 +1911,27 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_defau content = "Brandon's favorite color is red and he likes Mexican food." string_source = StringKnowledgeSource(content=content) knowledge_config = KnowledgeConfig() - with patch( - "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" - ) as MockKnowledge: - mock_knowledge_instance = MockKnowledge.return_value - mock_knowledge_instance.sources = [string_source] - mock_knowledge_instance.query.return_value = [{"content": content}] + + with ( + patch( + "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" + ) as mock_knowledge_storage, + patch( + "crewai.knowledge.source.base_knowledge_source.KnowledgeStorage" + ) as mock_base_knowledge_storage, + patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb, + ): + mock_storage_instance = mock_knowledge_storage.return_value + mock_storage_instance.sources = [string_source] + mock_storage_instance.query.return_value = [{"content": content}] + mock_storage_instance.save.return_value = None + + mock_chromadb_instance = mock_chromadb.return_value + mock_chromadb_instance.add_documents.return_value = None + + mock_base_knowledge_storage.return_value = mock_storage_instance + with patch.object(Knowledge, "query") as mock_knowledge_query: - string_source = StringKnowledgeSource(content=content) - knowledge_config = KnowledgeConfig() agent = Agent( role="Information Agent", goal="Provide information based on knowledge sources", @@ -1935,10 +1960,16 @@ def test_agent_with_knowledge_sources_extensive_role(): content = "Brandon's favorite color is red and he likes Mexican food." string_source = StringKnowledgeSource(content=content) - with patch("crewai.knowledge") as MockKnowledge: - mock_knowledge_instance = MockKnowledge.return_value + with ( + patch("crewai.knowledge") as mock_knowledge, + patch( + "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage.save" + ) as mock_save, + ): + mock_knowledge_instance = mock_knowledge.return_value mock_knowledge_instance.sources = [string_source] mock_knowledge_instance.query.return_value = [{"content": content}] + mock_save.return_value = None agent = Agent( role="Information Agent with extensive role description that is longer than 80 characters", @@ -1968,8 +1999,8 @@ def test_agent_with_knowledge_sources_works_with_copy(): with patch( "crewai.knowledge.source.base_knowledge_source.BaseKnowledgeSource", autospec=True, - ) as MockKnowledgeSource: - mock_knowledge_source_instance = MockKnowledgeSource.return_value + ) as mock_knowledge_source: + mock_knowledge_source_instance = mock_knowledge_source.return_value mock_knowledge_source_instance.__class__ = BaseKnowledgeSource mock_knowledge_source_instance.sources = [string_source] @@ -1983,9 +2014,9 @@ def test_agent_with_knowledge_sources_works_with_copy(): with patch( "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" - ) as MockKnowledgeStorage: - mock_knowledge_storage = MockKnowledgeStorage.return_value - agent.knowledge_storage = mock_knowledge_storage + ) as mock_knowledge_storage: + mock_knowledge_storage_instance = mock_knowledge_storage.return_value + agent.knowledge_storage = mock_knowledge_storage_instance agent_copy = agent.copy() @@ -2004,11 +2035,30 @@ def test_agent_with_knowledge_sources_generate_search_query(): content = "Brandon's favorite color is red and he likes Mexican food." string_source = StringKnowledgeSource(content=content) - with patch("crewai.knowledge") as MockKnowledge: - mock_knowledge_instance = MockKnowledge.return_value + with ( + patch("crewai.knowledge") as mock_knowledge, + patch( + "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" + ) as mock_knowledge_storage, + patch( + "crewai.knowledge.source.base_knowledge_source.KnowledgeStorage" + ) as mock_base_knowledge_storage, + patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb, + ): + mock_knowledge_instance = mock_knowledge.return_value mock_knowledge_instance.sources = [string_source] mock_knowledge_instance.query.return_value = [{"content": content}] + mock_storage_instance = mock_knowledge_storage.return_value + mock_storage_instance.sources = [string_source] + mock_storage_instance.query.return_value = [{"content": content}] + mock_storage_instance.save.return_value = None + + mock_chromadb_instance = mock_chromadb.return_value + mock_chromadb_instance.add_documents.return_value = None + + mock_base_knowledge_storage.return_value = mock_storage_instance + agent = Agent( role="Information Agent with extensive role description that is longer than 80 characters", goal="Provide information based on knowledge sources", @@ -2270,7 +2320,26 @@ def test_get_knowledge_search_query(): i18n = I18N() task_prompt = task.prompt() - with patch.object(agent, "_get_knowledge_search_query") as mock_get_query: + with ( + patch( + "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" + ) as mock_knowledge_storage, + patch( + "crewai.knowledge.source.base_knowledge_source.KnowledgeStorage" + ) as mock_base_knowledge_storage, + patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb, + patch.object(agent, "_get_knowledge_search_query") as mock_get_query, + ): + mock_storage_instance = mock_knowledge_storage.return_value + mock_storage_instance.sources = [string_source] + mock_storage_instance.query.return_value = [{"content": content}] + mock_storage_instance.save.return_value = None + + mock_chromadb_instance = mock_chromadb.return_value + mock_chromadb_instance.add_documents.return_value = None + + mock_base_knowledge_storage.return_value = mock_storage_instance + mock_get_query.return_value = "Capital of France" crew = Crew(agents=[agent], tasks=[task]) @@ -2312,9 +2381,9 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token): # Mock embedchain initialization to prevent race conditions in parallel CI execution with patch("embedchain.client.Client.setup"): from crewai_tools import ( - SerperDevTool, - FileReadTool, EnterpriseActionTool, + FileReadTool, + SerperDevTool, ) mock_get_response = MagicMock() @@ -2347,7 +2416,7 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token): tool_action = EnterpriseActionTool( name="test_name", description="test_description", - enterprise_action_token="test_token", + enterprise_action_token="test_token", # noqa: S106 action_name="test_action_name", action_schema={"test": "test"}, ) diff --git a/tests/knowledge/test_knowledge.py b/tests/knowledge/test_knowledge.py index 9cfc2bf53..67c2d68b0 100644 --- a/tests/knowledge/test_knowledge.py +++ b/tests/knowledge/test_knowledge.py @@ -1,7 +1,6 @@ """Test Knowledge creation and querying functionality.""" from pathlib import Path -from typing import List, Union from unittest.mock import patch import pytest @@ -23,7 +22,7 @@ def mock_vector_db(): instance = mock.return_value instance.query.return_value = [ { - "context": "Brandon's favorite color is blue and he likes Mexican food.", + "content": "Brandon's favorite color is blue and he likes Mexican food.", "score": 0.9, } ] @@ -44,13 +43,13 @@ def test_single_short_string(mock_vector_db): content=content, metadata={"preference": "personal"} ) mock_vector_db.sources = [string_source] - mock_vector_db.query.return_value = [{"context": content, "score": 0.9}] + mock_vector_db.query.return_value = [{"content": content, "score": 0.9}] # Perform a query query = "What is Brandon's favorite color?" results = mock_vector_db.query(query) # Assert that the results contain the expected information - assert any("blue" in result["context"].lower() for result in results) + assert any("blue" in result["content"].lower() for result in results) # Verify the mock was called mock_vector_db.query.assert_called_once() @@ -84,14 +83,14 @@ def test_single_2k_character_string(mock_vector_db): content=content, metadata={"preference": "personal"} ) mock_vector_db.sources = [string_source] - mock_vector_db.query.return_value = [{"context": content, "score": 0.9}] + mock_vector_db.query.return_value = [{"content": content, "score": 0.9}] # Perform a query query = "What is Brandon's favorite movie?" results = mock_vector_db.query(query) # Assert that the results contain the expected information - assert any("inception" in result["context"].lower() for result in results) + assert any("inception" in result["content"].lower() for result in results) mock_vector_db.query.assert_called_once() @@ -109,7 +108,7 @@ def test_multiple_short_strings(mock_vector_db): # Mock the vector db query response mock_vector_db.query.return_value = [ - {"context": "Brandon has a dog named Max.", "score": 0.9} + {"content": "Brandon has a dog named Max.", "score": 0.9} ] mock_vector_db.sources = string_sources @@ -119,7 +118,7 @@ def test_multiple_short_strings(mock_vector_db): results = mock_vector_db.query(query) # Assert that the correct information is retrieved - assert any("max" in result["context"].lower() for result in results) + assert any("max" in result["content"].lower() for result in results) # Verify the mock was called mock_vector_db.query.assert_called_once() @@ -180,7 +179,7 @@ def test_multiple_2k_character_strings(mock_vector_db): ] mock_vector_db.sources = string_sources - mock_vector_db.query.return_value = [{"context": contents[1], "score": 0.9}] + mock_vector_db.query.return_value = [{"content": contents[1], "score": 0.9}] # Perform a query query = "What is Brandon's favorite book?" @@ -188,7 +187,7 @@ def test_multiple_2k_character_strings(mock_vector_db): # Assert that the correct information is retrieved assert any( - "the hitchhiker's guide to the galaxy" in result["context"].lower() + "the hitchhiker's guide to the galaxy" in result["content"].lower() for result in results ) mock_vector_db.query.assert_called_once() @@ -205,13 +204,13 @@ def test_single_short_file(mock_vector_db, tmpdir): file_paths=[file_path], metadata={"preference": "personal"} ) mock_vector_db.sources = [file_source] - mock_vector_db.query.return_value = [{"context": content, "score": 0.9}] + mock_vector_db.query.return_value = [{"content": content, "score": 0.9}] # Perform a query query = "What sport does Brandon like?" results = mock_vector_db.query(query) # Assert that the results contain the expected information - assert any("basketball" in result["context"].lower() for result in results) + assert any("basketball" in result["content"].lower() for result in results) mock_vector_db.query.assert_called_once() @@ -247,13 +246,13 @@ def test_single_2k_character_file(mock_vector_db, tmpdir): file_paths=[file_path], metadata={"preference": "personal"} ) mock_vector_db.sources = [file_source] - mock_vector_db.query.return_value = [{"context": content, "score": 0.9}] + mock_vector_db.query.return_value = [{"content": content, "score": 0.9}] # Perform a query query = "What is Brandon's favorite movie?" results = mock_vector_db.query(query) # Assert that the results contain the expected information - assert any("inception" in result["context"].lower() for result in results) + assert any("inception" in result["content"].lower() for result in results) mock_vector_db.query.assert_called_once() @@ -286,13 +285,13 @@ def test_multiple_short_files(mock_vector_db, tmpdir): ] mock_vector_db.sources = file_sources mock_vector_db.query.return_value = [ - {"context": "Brandon lives in New York.", "score": 0.9} + {"content": "Brandon lives in New York.", "score": 0.9} ] # Perform a query query = "What city does he reside in?" results = mock_vector_db.query(query) # Assert that the correct information is retrieved - assert any("new york" in result["context"].lower() for result in results) + assert any("new york" in result["content"].lower() for result in results) mock_vector_db.query.assert_called_once() @@ -360,7 +359,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir): mock_vector_db.sources = file_sources mock_vector_db.query.return_value = [ { - "context": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.", + "content": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.", "score": 0.9, } ] @@ -370,7 +369,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir): # Assert that the correct information is retrieved assert any( - "the hitchhiker's guide to the galaxy" in result["context"].lower() + "the hitchhiker's guide to the galaxy" in result["content"].lower() for result in results ) mock_vector_db.query.assert_called_once() @@ -407,14 +406,14 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir): # Combine string and file sources mock_vector_db.sources = string_sources + file_sources - mock_vector_db.query.return_value = [{"context": file_contents[1], "score": 0.9}] + mock_vector_db.query.return_value = [{"content": file_contents[1], "score": 0.9}] # Perform a query query = "What is Brandon's favorite book?" results = mock_vector_db.query(query) # Assert that the correct information is retrieved - assert any("the alchemist" in result["context"].lower() for result in results) + assert any("the alchemist" in result["content"].lower() for result in results) mock_vector_db.query.assert_called_once() @@ -430,7 +429,7 @@ def test_pdf_knowledge_source(mock_vector_db): ) mock_vector_db.sources = [pdf_source] mock_vector_db.query.return_value = [ - {"context": "crewai create crew latest-ai-development", "score": 0.9} + {"content": "crewai create crew latest-ai-development", "score": 0.9} ] # Perform a query @@ -439,7 +438,7 @@ def test_pdf_knowledge_source(mock_vector_db): # Assert that the correct information is retrieved assert any( - "crewai create crew latest-ai-development" in result["context"].lower() + "crewai create crew latest-ai-development" in result["content"].lower() for result in results ) mock_vector_db.query.assert_called_once() @@ -467,7 +466,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir): ) mock_vector_db.sources = [csv_source] mock_vector_db.query.return_value = [ - {"context": "Brandon is 30 years old.", "score": 0.9} + {"content": "Brandon is 30 years old.", "score": 0.9} ] # Perform a query @@ -475,7 +474,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir): results = mock_vector_db.query(query) # Assert that the correct information is retrieved - assert any("30" in result["context"] for result in results) + assert any("30" in result["content"] for result in results) mock_vector_db.query.assert_called_once() @@ -502,7 +501,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir): ) mock_vector_db.sources = [json_source] mock_vector_db.query.return_value = [ - {"context": "Alice lives in Los Angeles.", "score": 0.9} + {"content": "Alice lives in Los Angeles.", "score": 0.9} ] # Perform a query @@ -510,7 +509,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir): results = mock_vector_db.query(query) # Assert that the correct information is retrieved - assert any("los angeles" in result["context"].lower() for result in results) + assert any("los angeles" in result["content"].lower() for result in results) mock_vector_db.query.assert_called_once() @@ -518,7 +517,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir): """Test ExcelKnowledgeSource with a simple Excel file.""" # Create an Excel file with sample data - import pandas as pd + import pandas as pd # type: ignore[import-untyped] excel_data = { "Name": ["Brandon", "Alice", "Bob"], @@ -535,7 +534,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir): ) mock_vector_db.sources = [excel_source] mock_vector_db.query.return_value = [ - {"context": "Brandon is 30 years old.", "score": 0.9} + {"content": "Brandon is 30 years old.", "score": 0.9} ] # Perform a query @@ -543,7 +542,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir): results = mock_vector_db.query(query) # Assert that the correct information is retrieved - assert any("30" in result["context"] for result in results) + assert any("30" in result["content"] for result in results) mock_vector_db.query.assert_called_once() @@ -557,20 +556,20 @@ def test_docling_source(mock_vector_db): mock_vector_db.sources = [docling_source] mock_vector_db.query.return_value = [ { - "context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.", + "content": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.", "score": 0.9, } ] # Perform a query query = "What is reward hacking?" results = mock_vector_db.query(query) - assert any("reward hacking" in result["context"].lower() for result in results) + assert any("reward hacking" in result["content"].lower() for result in results) mock_vector_db.query.assert_called_once() @pytest.mark.vcr -def test_multiple_docling_sources(): - urls: List[Union[Path, str]] = [ +def test_multiple_docling_sources() -> None: + urls: list[Path | str] = [ "https://lilianweng.github.io/posts/2024-11-28-reward-hacking/", "https://lilianweng.github.io/posts/2024-07-07-hallucination/", ] diff --git a/tests/knowledge/test_knowledge_searchresult.py b/tests/knowledge/test_knowledge_searchresult.py new file mode 100644 index 000000000..cea7c0367 --- /dev/null +++ b/tests/knowledge/test_knowledge_searchresult.py @@ -0,0 +1,191 @@ +"""Tests for Knowledge SearchResult type conversion and integration.""" + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from crewai.knowledge.knowledge import Knowledge # type: ignore[import-untyped] +from crewai.knowledge.source.string_knowledge_source import ( # type: ignore[import-untyped] + StringKnowledgeSource, +) +from crewai.knowledge.utils.knowledge_utils import ( # type: ignore[import-untyped] + extract_knowledge_context, +) + + +def test_knowledge_query_returns_searchresult() -> None: + """Test that Knowledge.query returns SearchResult format.""" + with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class: + mock_storage = MagicMock() + mock_storage_class.return_value = mock_storage + mock_storage.search.return_value = [ + { + "content": "AI is fascinating", + "score": 0.9, + "metadata": {"source": "doc1"}, + }, + { + "content": "Machine learning rocks", + "score": 0.8, + "metadata": {"source": "doc2"}, + }, + ] + + sources = [StringKnowledgeSource(content="Test knowledge content")] + knowledge = Knowledge(collection_name="test_collection", sources=sources) + + results = knowledge.query( + ["AI technology"], results_limit=5, score_threshold=0.3 + ) + + mock_storage.search.assert_called_once_with( + ["AI technology"], limit=5, score_threshold=0.3 + ) + + assert isinstance(results, list) + assert len(results) == 2 + + for result in results: + assert isinstance(result, dict) + assert "content" in result + assert "score" in result + assert "metadata" in result + + assert results[0]["content"] == "AI is fascinating" + assert results[0]["score"] == 0.9 + assert results[1]["content"] == "Machine learning rocks" + assert results[1]["score"] == 0.8 + + +def test_knowledge_query_with_empty_results() -> None: + """Test Knowledge.query with empty search results.""" + with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class: + mock_storage = MagicMock() + mock_storage_class.return_value = mock_storage + mock_storage.search.return_value = [] + + sources = [StringKnowledgeSource(content="Test content")] + knowledge = Knowledge(collection_name="empty_test", sources=sources) + + results = knowledge.query(["nonexistent query"]) + + assert isinstance(results, list) + assert len(results) == 0 + + +def test_extract_knowledge_context_with_searchresult() -> None: + """Test extract_knowledge_context works with SearchResult format.""" + search_results = [ + {"content": "Python is great for AI", "score": 0.95, "metadata": {}}, + {"content": "Machine learning algorithms", "score": 0.88, "metadata": {}}, + {"content": "Deep learning frameworks", "score": 0.82, "metadata": {}}, + ] + + context = extract_knowledge_context(search_results) + + assert "Additional Information:" in context + assert "Python is great for AI" in context + assert "Machine learning algorithms" in context + assert "Deep learning frameworks" in context + + expected_content = ( + "Python is great for AI\nMachine learning algorithms\nDeep learning frameworks" + ) + assert expected_content in context + + +def test_extract_knowledge_context_with_empty_content() -> None: + """Test extract_knowledge_context handles empty or invalid content.""" + search_results = [ + {"content": "", "score": 0.5, "metadata": {}}, + {"content": None, "score": 0.4, "metadata": {}}, + {"score": 0.3, "metadata": {}}, + ] + + context = extract_knowledge_context(search_results) + + assert context == "" + + +def test_extract_knowledge_context_filters_invalid_results() -> None: + """Test that extract_knowledge_context filters out invalid results.""" + search_results: list[dict[str, Any] | None] = [ + {"content": "Valid content 1", "score": 0.9, "metadata": {}}, + {"content": "", "score": 0.8, "metadata": {}}, + {"content": "Valid content 2", "score": 0.7, "metadata": {}}, + None, + {"content": None, "score": 0.6, "metadata": {}}, + ] + + context = extract_knowledge_context(search_results) + + assert "Additional Information:" in context + assert "Valid content 1" in context + assert "Valid content 2" in context + assert context.count("\n") == 1 + + +@patch("crewai.rag.config.utils.get_rag_client") +@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage") +def test_knowledge_storage_exception_handling( + mock_storage_class: MagicMock, mock_get_client: MagicMock +) -> None: + """Test Knowledge handles storage exceptions gracefully.""" + mock_storage = MagicMock() + mock_storage_class.return_value = mock_storage + mock_storage.search.side_effect = Exception("Storage error") + + sources = [StringKnowledgeSource(content="Test content")] + knowledge = Knowledge(collection_name="error_test", sources=sources) + + with pytest.raises(ValueError, match="Storage is not initialized"): + knowledge.storage = None + knowledge.query(["test query"]) + + +def test_knowledge_add_sources_integration() -> None: + """Test Knowledge.add_sources integrates properly with storage.""" + with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class: + mock_storage = MagicMock() + mock_storage_class.return_value = mock_storage + + sources = [ + StringKnowledgeSource(content="Content 1"), + StringKnowledgeSource(content="Content 2"), + ] + knowledge = Knowledge(collection_name="add_sources_test", sources=sources) + + knowledge.add_sources() + + for source in sources: + assert source.storage == mock_storage + + +def test_knowledge_reset_integration() -> None: + """Test Knowledge.reset integrates with storage.""" + with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class: + mock_storage = MagicMock() + mock_storage_class.return_value = mock_storage + + sources = [StringKnowledgeSource(content="Test content")] + knowledge = Knowledge(collection_name="reset_test", sources=sources) + + knowledge.reset() + + mock_storage.reset.assert_called_once() + + +@patch("crewai.rag.config.utils.get_rag_client") +@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage") +def test_knowledge_reset_without_storage( + mock_storage_class: MagicMock, mock_get_client: MagicMock +) -> None: + """Test Knowledge.reset raises error when storage is None.""" + sources = [StringKnowledgeSource(content="Test content")] + knowledge = Knowledge(collection_name="no_storage_test", sources=sources) + + knowledge.storage = None + + with pytest.raises(ValueError, match="Storage is not initialized"): + knowledge.reset() diff --git a/tests/knowledge/test_knowledge_storage_integration.py b/tests/knowledge/test_knowledge_storage_integration.py new file mode 100644 index 000000000..0f9581864 --- /dev/null +++ b/tests/knowledge/test_knowledge_storage_integration.py @@ -0,0 +1,196 @@ +"""Integration tests for KnowledgeStorage RAG client migration.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped] + KnowledgeStorage, +) + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +@patch("crewai.knowledge.storage.knowledge_storage.create_client") +@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function") +def test_knowledge_storage_uses_rag_client( + mock_get_embedding: MagicMock, + mock_create_client: MagicMock, + mock_get_client: MagicMock, +) -> None: + """Test that KnowledgeStorage properly integrates with RAG client.""" + mock_client = MagicMock() + mock_create_client.return_value = mock_client + mock_get_client.return_value = mock_client + mock_client.search.return_value = [ + {"content": "test content", "score": 0.9, "metadata": {"source": "test"}} + ] + + embedder_config = {"provider": "openai", "model": "text-embedding-3-small"} + storage = KnowledgeStorage( + embedder=embedder_config, collection_name="test_knowledge" + ) + + mock_create_client.assert_called_once() + + results = storage.search(["test query"], limit=5, score_threshold=0.3) + + mock_get_client.assert_not_called() + mock_client.search.assert_called_once_with( + collection_name="knowledge_test_knowledge", + query="test query", + limit=5, + metadata_filter=None, + score_threshold=0.3, + ) + + assert isinstance(results, list) + assert len(results) == 1 + assert isinstance(results[0], dict) + assert "content" in results[0] + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_collection_name_prefixing(mock_get_client: MagicMock) -> None: + """Test that collection names are properly prefixed.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.search.return_value = [] + + storage = KnowledgeStorage(collection_name="custom_knowledge") + storage.search(["test"], limit=1) + + mock_client.search.assert_called_once() + call_kwargs = mock_client.search.call_args.kwargs + assert call_kwargs["collection_name"] == "knowledge_custom_knowledge" + + mock_client.reset_mock() + storage_default = KnowledgeStorage() + storage_default.search(["test"], limit=1) + + call_kwargs = mock_client.search.call_args.kwargs + assert call_kwargs["collection_name"] == "knowledge" + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_save_documents_integration(mock_get_client: MagicMock) -> None: + """Test document saving through RAG client.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + storage = KnowledgeStorage(collection_name="test_docs") + documents = ["Document 1 content", "Document 2 content"] + + storage.save(documents) + + mock_client.get_or_create_collection.assert_called_once_with( + collection_name="knowledge_test_docs" + ) + mock_client.add_documents.assert_called_once() + + call_kwargs = mock_client.add_documents.call_args.kwargs + added_docs = call_kwargs["documents"] + assert len(added_docs) == 2 + assert added_docs[0]["content"] == "Document 1 content" + assert added_docs[1]["content"] == "Document 2 content" + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_reset_integration(mock_get_client: MagicMock) -> None: + """Test collection reset through RAG client.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + storage = KnowledgeStorage(collection_name="test_reset") + storage.reset() + + mock_client.delete_collection.assert_called_once_with( + collection_name="knowledge_test_reset" + ) + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_search_error_handling(mock_get_client: MagicMock) -> None: + """Test error handling during search operations.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.search.side_effect = Exception("RAG client error") + + storage = KnowledgeStorage(collection_name="error_test") + + results = storage.search(["test query"]) + assert results == [] + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function") +def test_embedding_configuration_flow( + mock_get_embedding: MagicMock, mock_get_client: MagicMock +) -> None: + """Test that embedding configuration flows properly to RAG client.""" + mock_embedding_func = MagicMock() + mock_get_embedding.return_value = mock_embedding_func + mock_get_client.return_value = MagicMock() + + embedder_config = { + "provider": "sentence-transformer", + "model_name": "all-MiniLM-L6-v2", + } + + KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test") + + mock_get_embedding.assert_called_once_with(embedder_config) + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_query_list_conversion(mock_get_client: MagicMock) -> None: + """Test that query list is properly converted to string.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.search.return_value = [] + + storage = KnowledgeStorage() + + storage.search(["single query"]) + call_kwargs = mock_client.search.call_args.kwargs + assert call_kwargs["query"] == "single query" + + mock_client.reset_mock() + storage.search(["query one", "query two"]) + call_kwargs = mock_client.search.call_args.kwargs + assert call_kwargs["query"] == "query one query two" + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_metadata_filter_handling(mock_get_client: MagicMock) -> None: + """Test metadata filter parameter handling.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.search.return_value = [] + + storage = KnowledgeStorage() + + metadata_filter = {"category": "technical", "priority": "high"} + storage.search(["test"], metadata_filter=metadata_filter) + + call_kwargs = mock_client.search.call_args.kwargs + assert call_kwargs["metadata_filter"] == metadata_filter + + mock_client.reset_mock() + storage.search(["test"], metadata_filter=None) + + call_kwargs = mock_client.search.call_args.kwargs + assert call_kwargs["metadata_filter"] is None + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None: + """Test specific handling of dimension mismatch errors.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.get_or_create_collection.return_value = None + mock_client.add_documents.side_effect = Exception("dimension mismatch detected") + + storage = KnowledgeStorage(collection_name="dimension_test") + + with pytest.raises(ValueError, match="Embedding dimension mismatch"): + storage.save(["test document"]) diff --git a/tests/memory/test_short_term_memory.py b/tests/memory/test_short_term_memory.py index b87077094..18dc28fa8 100644 --- a/tests/memory/test_short_term_memory.py +++ b/tests/memory/test_short_term_memory.py @@ -1,19 +1,20 @@ -from unittest.mock import patch, ANY from collections import defaultdict +from unittest.mock import ANY, patch + import pytest from crewai.agent import Agent from crewai.crew import Crew +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.memory_events import ( + MemoryQueryCompletedEvent, + MemoryQueryStartedEvent, + MemorySaveCompletedEvent, + MemorySaveStartedEvent, +) from crewai.memory.short_term.short_term_memory import ShortTermMemory from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem from crewai.task import Task -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.memory_events import ( - MemorySaveStartedEvent, - MemorySaveCompletedEvent, - MemoryQueryStartedEvent, - MemoryQueryCompletedEvent, -) @pytest.fixture @@ -38,22 +39,23 @@ def short_term_memory(): def test_short_term_memory_search_events(short_term_memory): events = defaultdict(list) - with crewai_event_bus.scoped_handlers(): + with patch("crewai.rag.chromadb.client.ChromaDBClient.search", return_value=[]): + with crewai_event_bus.scoped_handlers(): - @crewai_event_bus.on(MemoryQueryStartedEvent) - def on_search_started(source, event): - events["MemoryQueryStartedEvent"].append(event) + @crewai_event_bus.on(MemoryQueryStartedEvent) + def on_search_started(source, event): + events["MemoryQueryStartedEvent"].append(event) - @crewai_event_bus.on(MemoryQueryCompletedEvent) - def on_search_completed(source, event): - events["MemoryQueryCompletedEvent"].append(event) + @crewai_event_bus.on(MemoryQueryCompletedEvent) + def on_search_completed(source, event): + events["MemoryQueryCompletedEvent"].append(event) - # Call the save method - short_term_memory.search( - query="test value", - limit=3, - score_threshold=0.35, - ) + # Call the save method + short_term_memory.search( + query="test value", + limit=3, + score_threshold=0.35, + ) assert len(events["MemoryQueryStartedEvent"]) == 1 assert len(events["MemoryQueryCompletedEvent"]) == 1 @@ -173,12 +175,12 @@ def test_save_and_search(short_term_memory): expected_result = [ { - "context": memory.data, + "content": memory.data, "metadata": {"agent": "test_agent"}, "score": 0.95, } ] with patch.object(ShortTermMemory, "search", return_value=expected_result): find = short_term_memory.search("test value", score_threshold=0.01)[0] - assert find["context"] == memory.data, "Data value mismatch." + assert find["content"] == memory.data, "Data value mismatch." assert find["metadata"]["agent"] == "test_agent", "Agent value mismatch." diff --git a/tests/rag/chromadb/test_client.py b/tests/rag/chromadb/test_client.py index 88742a711..8e0cc66a1 100644 --- a/tests/rag/chromadb/test_client.py +++ b/tests/rag/chromadb/test_client.py @@ -285,6 +285,43 @@ class TestChromaDBClient: metadatas=[{"source": "test1"}, {"source": "test2"}], ) + def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None: + """Test add_documents with documents that have no metadata.""" + mock_collection = Mock() + mock_chromadb_client.get_collection.return_value = mock_collection + + documents: list[BaseRecord] = [ + {"content": "Document without metadata"}, + {"content": "Another document", "metadata": None}, + {"content": "Document with metadata", "metadata": {"key": "value"}}, + ] + + client.add_documents(collection_name="test_collection", documents=documents) + + # Verify upsert was called with empty dicts for missing metadata + mock_collection.upsert.assert_called_once() + call_args = mock_collection.upsert.call_args + assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}] + + def test_add_documents_all_without_metadata( + self, client, mock_chromadb_client + ) -> None: + """Test add_documents when all documents have no metadata.""" + mock_collection = Mock() + mock_chromadb_client.get_collection.return_value = mock_collection + + documents: list[BaseRecord] = [ + {"content": "Document 1"}, + {"content": "Document 2"}, + {"content": "Document 3"}, + ] + + client.add_documents(collection_name="test_collection", documents=documents) + + mock_collection.upsert.assert_called_once() + call_args = mock_collection.upsert.call_args + assert call_args[1]["metadatas"] is None + def test_add_documents_empty_list_raises_error( self, client, mock_chromadb_client ) -> None: @@ -358,6 +395,31 @@ class TestChromaDBClient: metadatas=[{"source": "test1"}, {"source": "test2"}], ) + @pytest.mark.asyncio + async def test_aadd_documents_without_metadata( + self, async_client, mock_async_chromadb_client + ) -> None: + """Test aadd_documents with documents that have no metadata.""" + mock_collection = AsyncMock() + mock_async_chromadb_client.get_collection = AsyncMock( + return_value=mock_collection + ) + + documents: list[BaseRecord] = [ + {"content": "Document without metadata"}, + {"content": "Another document", "metadata": None}, + {"content": "Document with metadata", "metadata": {"key": "value"}}, + ] + + await async_client.aadd_documents( + collection_name="test_collection", documents=documents + ) + + # Verify upsert was called with empty dicts for missing metadata + mock_collection.upsert.assert_called_once() + call_args = mock_collection.upsert.call_args + assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}] + @pytest.mark.asyncio async def test_aadd_documents_empty_list_raises_error( self, async_client, mock_async_chromadb_client diff --git a/tests/rag/chromadb/test_utils.py b/tests/rag/chromadb/test_utils.py new file mode 100644 index 000000000..ac7a8f5a9 --- /dev/null +++ b/tests/rag/chromadb/test_utils.py @@ -0,0 +1,95 @@ +"""Tests for ChromaDB utility functions.""" + +from crewai.rag.chromadb.utils import ( + MAX_COLLECTION_LENGTH, + MIN_COLLECTION_LENGTH, + _is_ipv4_pattern, + _sanitize_collection_name, +) + + +class TestChromaDBUtils: + """Test suite for ChromaDB utility functions.""" + + def test_sanitize_collection_name_long_name(self) -> None: + """Test sanitizing a very long collection name.""" + long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name" + sanitized = _sanitize_collection_name(long_name) + assert len(sanitized) <= MAX_COLLECTION_LENGTH + assert sanitized[0].isalnum() + assert sanitized[-1].isalnum() + assert all(c.isalnum() or c in ["_", "-"] for c in sanitized) + + def test_sanitize_collection_name_special_chars(self) -> None: + """Test sanitizing a name with special characters.""" + special_chars = "Agent@123!#$%^&*()" + sanitized = _sanitize_collection_name(special_chars) + assert sanitized[0].isalnum() + assert sanitized[-1].isalnum() + assert all(c.isalnum() or c in ["_", "-"] for c in sanitized) + + def test_sanitize_collection_name_short_name(self) -> None: + """Test sanitizing a very short name.""" + short_name = "A" + sanitized = _sanitize_collection_name(short_name) + assert len(sanitized) >= MIN_COLLECTION_LENGTH + assert sanitized[0].isalnum() + assert sanitized[-1].isalnum() + + def test_sanitize_collection_name_bad_ends(self) -> None: + """Test sanitizing a name with non-alphanumeric start/end.""" + bad_ends = "_Agent_" + sanitized = _sanitize_collection_name(bad_ends) + assert sanitized[0].isalnum() + assert sanitized[-1].isalnum() + + def test_sanitize_collection_name_none(self) -> None: + """Test sanitizing a None value.""" + sanitized = _sanitize_collection_name(None) + assert sanitized == "default_collection" + + def test_sanitize_collection_name_ipv4_pattern(self) -> None: + """Test sanitizing an IPv4 address.""" + ipv4 = "192.168.1.1" + sanitized = _sanitize_collection_name(ipv4) + assert sanitized.startswith("ip_") + assert sanitized[0].isalnum() + assert sanitized[-1].isalnum() + assert all(c.isalnum() or c in ["_", "-"] for c in sanitized) + + def test_is_ipv4_pattern(self) -> None: + """Test IPv4 pattern detection.""" + assert _is_ipv4_pattern("192.168.1.1") is True + assert _is_ipv4_pattern("not.an.ip.address") is False + + def test_sanitize_collection_name_properties(self) -> None: + """Test that sanitized collection names always meet ChromaDB requirements.""" + test_cases: list[str] = [ + "A" * 100, # Very long name + "_start_with_underscore", + "end_with_underscore_", + "contains@special#characters", + "192.168.1.1", # IPv4 address + "a" * 2, # Too short + ] + for test_case in test_cases: + sanitized = _sanitize_collection_name(test_case) + assert len(sanitized) >= MIN_COLLECTION_LENGTH + assert len(sanitized) <= MAX_COLLECTION_LENGTH + assert sanitized[0].isalnum() + assert sanitized[-1].isalnum() + + def test_sanitize_collection_name_empty_string(self) -> None: + """Test sanitizing an empty string.""" + sanitized = _sanitize_collection_name("") + assert sanitized == "default_collection" + + def test_sanitize_collection_name_whitespace_only(self) -> None: + """Test sanitizing a string with only whitespace.""" + sanitized = _sanitize_collection_name(" ") + assert ( + sanitized == "a__z" + ) # Spaces become underscores, padded to meet requirements + assert len(sanitized) >= MIN_COLLECTION_LENGTH + assert sanitized[0].isalnum() + assert sanitized[-1].isalnum() diff --git a/tests/rag/embeddings/test_factory_enhanced.py b/tests/rag/embeddings/test_factory_enhanced.py new file mode 100644 index 000000000..489064826 --- /dev/null +++ b/tests/rag/embeddings/test_factory_enhanced.py @@ -0,0 +1,250 @@ +"""Enhanced tests for embedding function factory.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped] + get_embedding_function, +) +from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped] + + +def test_get_embedding_function_default() -> None: + """Test default embedding function when no config provided.""" + with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai: + mock_instance = MagicMock() + mock_openai.return_value = mock_instance + + with patch( + "crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key" + ): + result = get_embedding_function() + + mock_openai.assert_called_once_with( + api_key="test-api-key", model_name="text-embedding-3-small" + ) + assert result == mock_instance + + +def test_get_embedding_function_with_embedding_options() -> None: + """Test embedding function creation with EmbeddingOptions object.""" + with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai: + mock_instance = MagicMock() + mock_openai.return_value = mock_instance + + options = EmbeddingOptions( + provider="openai", api_key="test-key", model="text-embedding-3-large" + ) + + result = get_embedding_function(options) + + call_kwargs = mock_openai.call_args.kwargs + assert "api_key" in call_kwargs + assert call_kwargs["api_key"].get_secret_value() == "test-key" + # OpenAI uses model_name parameter, not model + assert result == mock_instance + + +def test_get_embedding_function_sentence_transformer() -> None: + """Test sentence transformer embedding function.""" + with patch( + "crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction" + ) as mock_st: + mock_instance = MagicMock() + mock_st.return_value = mock_instance + + config = {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"} + + result = get_embedding_function(config) + + mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2") + assert result == mock_instance + + +def test_get_embedding_function_ollama() -> None: + """Test Ollama embedding function.""" + with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as mock_ollama: + mock_instance = MagicMock() + mock_ollama.return_value = mock_instance + + config = { + "provider": "ollama", + "model_name": "nomic-embed-text", + "url": "http://localhost:11434", + } + + result = get_embedding_function(config) + + mock_ollama.assert_called_once_with( + model_name="nomic-embed-text", url="http://localhost:11434" + ) + assert result == mock_instance + + +def test_get_embedding_function_cohere() -> None: + """Test Cohere embedding function.""" + with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as mock_cohere: + mock_instance = MagicMock() + mock_cohere.return_value = mock_instance + + config = { + "provider": "cohere", + "api_key": "cohere-key", + "model_name": "embed-english-v3.0", + } + + result = get_embedding_function(config) + + mock_cohere.assert_called_once_with( + api_key="cohere-key", model_name="embed-english-v3.0" + ) + assert result == mock_instance + + +def test_get_embedding_function_huggingface() -> None: + """Test HuggingFace embedding function.""" + with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as mock_hf: + mock_instance = MagicMock() + mock_hf.return_value = mock_instance + + config = { + "provider": "huggingface", + "api_key": "hf-token", + "model_name": "sentence-transformers/all-MiniLM-L6-v2", + } + + result = get_embedding_function(config) + + mock_hf.assert_called_once_with( + api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2" + ) + assert result == mock_instance + + +def test_get_embedding_function_onnx() -> None: + """Test ONNX embedding function.""" + with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as mock_onnx: + mock_instance = MagicMock() + mock_onnx.return_value = mock_instance + + config = {"provider": "onnx"} + + result = get_embedding_function(config) + + mock_onnx.assert_called_once() + assert result == mock_instance + + +def test_get_embedding_function_google_palm() -> None: + """Test Google PaLM embedding function.""" + with patch( + "crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction" + ) as mock_palm: + mock_instance = MagicMock() + mock_palm.return_value = mock_instance + + config = {"provider": "google-palm", "api_key": "palm-key"} + + result = get_embedding_function(config) + + mock_palm.assert_called_once_with(api_key="palm-key") + assert result == mock_instance + + +def test_get_embedding_function_amazon_bedrock() -> None: + """Test Amazon Bedrock embedding function.""" + with patch( + "crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction" + ) as mock_bedrock: + mock_instance = MagicMock() + mock_bedrock.return_value = mock_instance + + config = { + "provider": "amazon-bedrock", + "region_name": "us-west-2", + "model_name": "amazon.titan-embed-text-v1", + } + + result = get_embedding_function(config) + + mock_bedrock.assert_called_once_with( + region_name="us-west-2", model_name="amazon.titan-embed-text-v1" + ) + assert result == mock_instance + + +def test_get_embedding_function_jina() -> None: + """Test Jina embedding function.""" + with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as mock_jina: + mock_instance = MagicMock() + mock_jina.return_value = mock_instance + + config = { + "provider": "jina", + "api_key": "jina-key", + "model_name": "jina-embeddings-v2-base-en", + } + + result = get_embedding_function(config) + + mock_jina.assert_called_once_with( + api_key="jina-key", model_name="jina-embeddings-v2-base-en" + ) + assert result == mock_instance + + +def test_get_embedding_function_unsupported_provider() -> None: + """Test handling of unsupported provider.""" + config = {"provider": "unsupported-provider"} + + with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"): + get_embedding_function(config) + + +def test_get_embedding_function_config_modification() -> None: + """Test that original config dict is not modified.""" + original_config = { + "provider": "openai", + "api_key": "test-key", + "model": "text-embedding-3-small", + } + config_copy = original_config.copy() + + with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"): + get_embedding_function(config_copy) + + assert config_copy == original_config + + +def test_get_embedding_function_exclude_none_values() -> None: + """Test that None values are excluded from embedding function calls.""" + with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai: + mock_instance = MagicMock() + mock_openai.return_value = mock_instance + + options = EmbeddingOptions(provider="openai", api_key="test-key", model=None) + + result = get_embedding_function(options) + + call_kwargs = mock_openai.call_args.kwargs + assert "api_key" in call_kwargs + assert call_kwargs["api_key"].get_secret_value() == "test-key" + assert "model" not in call_kwargs + assert result == mock_instance + + +def test_get_embedding_function_instructor() -> None: + """Test Instructor embedding function.""" + with patch( + "crewai.rag.embeddings.factory.InstructorEmbeddingFunction" + ) as mock_instructor: + mock_instance = MagicMock() + mock_instructor.return_value = mock_instance + + config = {"provider": "instructor", "model_name": "hkunlp/instructor-large"} + + result = get_embedding_function(config) + + mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large") + assert result == mock_instance diff --git a/tests/rag/test_error_handling.py b/tests/rag/test_error_handling.py new file mode 100644 index 000000000..ef2c8f7d5 --- /dev/null +++ b/tests/rag/test_error_handling.py @@ -0,0 +1,218 @@ +"""Tests for RAG client error handling scenarios.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped] + KnowledgeStorage, +) +from crewai.memory.storage.rag_storage import RAGStorage # type: ignore[import-untyped] + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_knowledge_storage_connection_failure(mock_get_client: MagicMock) -> None: + """Test KnowledgeStorage handles RAG client connection failures.""" + mock_get_client.side_effect = ConnectionError("Unable to connect to ChromaDB") + + storage = KnowledgeStorage(collection_name="connection_test") + + results = storage.search(["test query"]) + assert results == [] + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_knowledge_storage_search_timeout(mock_get_client: MagicMock) -> None: + """Test KnowledgeStorage handles search timeouts gracefully.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.search.side_effect = TimeoutError("Search operation timed out") + + storage = KnowledgeStorage(collection_name="timeout_test") + + results = storage.search(["test query"]) + assert results == [] + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_knowledge_storage_collection_not_found(mock_get_client: MagicMock) -> None: + """Test KnowledgeStorage handles missing collections.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.search.side_effect = ValueError( + "Collection 'knowledge_missing' does not exist" + ) + + storage = KnowledgeStorage(collection_name="missing_collection") + + results = storage.search(["test query"]) + assert results == [] + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock) -> None: + """Test KnowledgeStorage handles invalid embedding configurations.""" + mock_get_client.return_value = MagicMock() + + with patch( + "crewai.knowledge.storage.knowledge_storage.get_embedding_function" + ) as mock_get_embedding: + mock_get_embedding.side_effect = ValueError( + "Unsupported provider: invalid_provider" + ) + + with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"): + KnowledgeStorage( + embedder={"provider": "invalid_provider"}, + collection_name="invalid_embedding_test", + ) + + +@patch("crewai.memory.storage.rag_storage.get_rag_client") +def test_memory_rag_storage_client_failure(mock_get_client: MagicMock) -> None: + """Test RAGStorage handles RAG client failures in memory operations.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.search.side_effect = RuntimeError("ChromaDB server error") + + storage = RAGStorage("short_term", crew=None) + + results = storage.search("test query") + assert results == [] + + +@patch("crewai.memory.storage.rag_storage.get_rag_client") +def test_memory_rag_storage_save_failure(mock_get_client: MagicMock) -> None: + """Test RAGStorage handles save operation failures.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.add_documents.side_effect = Exception("Failed to add documents") + + storage = RAGStorage("long_term", crew=None) + + storage.save("test memory", {"key": "value"}) + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_knowledge_storage_reset_readonly_database(mock_get_client: MagicMock) -> None: + """Test KnowledgeStorage reset handles readonly database errors.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.delete_collection.side_effect = Exception( + "attempt to write a readonly database" + ) + + storage = KnowledgeStorage(collection_name="readonly_test") + + storage.reset() + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_knowledge_storage_reset_collection_does_not_exist( + mock_get_client: MagicMock, +) -> None: + """Test KnowledgeStorage reset handles non-existent collections.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.delete_collection.side_effect = Exception("Collection does not exist") + + storage = KnowledgeStorage(collection_name="nonexistent_test") + + storage.reset() + + +@patch("crewai.memory.storage.rag_storage.get_rag_client") +def test_memory_storage_reset_failure_propagation(mock_get_client: MagicMock) -> None: + """Test RAGStorage reset propagates unexpected errors.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.delete_collection.side_effect = Exception("Unexpected database error") + + storage = RAGStorage("entities", crew=None) + + with pytest.raises( + Exception, match="An error occurred while resetting the entities memory" + ): + storage.reset() + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_knowledge_storage_malformed_search_results(mock_get_client: MagicMock) -> None: + """Test KnowledgeStorage handles malformed search results.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.search.return_value = [ + {"content": "valid result", "metadata": {"source": "test"}}, + {"invalid": "missing content field", "metadata": {"source": "test"}}, + None, + {"content": None, "metadata": {"source": "test"}}, + ] + + storage = KnowledgeStorage(collection_name="malformed_test") + + results = storage.search(["test query"]) + + assert isinstance(results, list) + assert len(results) == 4 + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_knowledge_storage_network_interruption(mock_get_client: MagicMock) -> None: + """Test KnowledgeStorage handles network interruptions during operations.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + mock_client.search.side_effect = [ + ConnectionError("Network interruption"), + [{"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}], + ] + + storage = KnowledgeStorage(collection_name="network_test") + + first_attempt = storage.search(["test query"]) + assert first_attempt == [] + + mock_client.search.side_effect = None + mock_client.search.return_value = [ + {"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}} + ] + + second_attempt = storage.search(["test query"]) + assert len(second_attempt) == 1 + assert second_attempt[0]["content"] == "recovered result" + + +@patch("crewai.memory.storage.rag_storage.get_rag_client") +def test_memory_storage_collection_creation_failure(mock_get_client: MagicMock) -> None: + """Test RAGStorage handles collection creation failures.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.get_or_create_collection.side_effect = Exception( + "Failed to create collection" + ) + + storage = RAGStorage("user_memory", crew=None) + + storage.save("test data", {"metadata": "test"}) + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_knowledge_storage_embedding_dimension_mismatch_detailed( + mock_get_client: MagicMock, +) -> None: + """Test detailed handling of embedding dimension mismatch errors.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_client.get_or_create_collection.return_value = None + mock_client.add_documents.side_effect = Exception( + "Embedding dimension mismatch: expected 384, got 1536" + ) + + storage = KnowledgeStorage(collection_name="dimension_detailed_test") + + with pytest.raises(ValueError) as exc_info: + storage.save(["test document"]) + + assert "Embedding dimension mismatch" in str(exc_info.value) + assert "Make sure you're using the same embedding model" in str(exc_info.value) + assert "crewai reset-memories -a" in str(exc_info.value) diff --git a/tests/storage/test_mem0_storage.py b/tests/storage/test_mem0_storage.py index dae93e39e..11cfddb3a 100644 --- a/tests/storage/test_mem0_storage.py +++ b/tests/storage/test_mem0_storage.py @@ -1,8 +1,7 @@ from unittest.mock import MagicMock, patch import pytest -from mem0.client.main import MemoryClient -from mem0.memory.main import Memory +from mem0 import Memory, MemoryClient from crewai.memory.storage.mem0_storage import Mem0Storage @@ -13,6 +12,67 @@ class MockCrew: self.agents = [MagicMock(role="Test Agent")] +# Test data constants +SYSTEM_CONTENT = ( + "You are Friendly chatbot assistant. You are a kind and " + "knowledgeable chatbot assistant. You excel at understanding user needs, " + "providing helpful responses, and maintaining engaging conversations. " + "You remember previous interactions to provide a personalized experience.\n" + "Your personal goal is: Engage in useful and interesting conversations " + "with users while remembering context.\n" + "To give my best complete final answer to the task respond using the exact " + "following format:\n\n" + "Thought: I now can give a great answer\n" + "Final Answer: Your final answer must be the great and the most complete " + "as possible, it must be outcome described.\n\n" + "I MUST use these formats, my job depends on it!" +) + +USER_CONTENT = ( + "\nCurrent Task: Respond to user conversation. User message: " + "What do you know about me?\n\n" + "This is the expected criteria for your final answer: Contextually " + "appropriate, helpful, and friendly response.\n" + "you MUST return the actual complete content as the final answer, " + "not a summary.\n\n" + "# Useful context: \nExternal memories:\n" + "- User is from India\n" + "- User is interested in the solar system\n" + "- User name is Vidit Ostwal\n" + "- User is interested in French cuisine\n\n" + "Begin! This is VERY important to you, use the tools available and give " + "your best Final Answer, your job depends on it!\n\n" + "Thought:" +) + +ASSISTANT_CONTENT = ( + "I now can give a great answer \n" + "Final Answer: Hi Vidit! From our previous conversations, I know you're " + "from India and have a great interest in the solar system. It's fascinating " + "to explore the wonders of space, isn't it? Also, I remember you have a " + "passion for French cuisine, which has so many delightful dishes to explore. " + "If there's anything specific you'd like to discuss or learn about—whether " + "it's about the solar system or some great French recipes—feel free to let " + "me know! I'm here to help." +) + +TEST_DESCRIPTION = ( + "Respond to user conversation. User message: What do you know about me?" +) + +# Extracted content (after processing by _get_user_message and _get_assistant_message) +EXTRACTED_USER_CONTENT = "What do you know about me?" +EXTRACTED_ASSISTANT_CONTENT = ( + "Hi Vidit! From our previous conversations, I know you're " + "from India and have a great interest in the solar system. It's fascinating " + "to explore the wonders of space, isn't it? Also, I remember you have a " + "passion for French cuisine, which has so many delightful dishes to explore. " + "If there's anything specific you'd like to discuss or learn about—whether " + "it's about the solar system or some great French recipes—feel free to let " + "me know! I'm here to help." +) + + @pytest.fixture def mock_mem0_memory(): """Fixture to create a mock Memory instance""" @@ -24,7 +84,9 @@ def mem0_storage_with_mocked_config(mock_mem0_memory): """Fixture to create a Mem0Storage instance with mocked dependencies""" # Patch the Memory class to return our mock - with patch("mem0.memory.main.Memory.from_config", return_value=mock_mem0_memory) as mock_from_config: + with patch( + "mem0.Memory.from_config", return_value=mock_mem0_memory + ) as mock_from_config: config = { "vector_store": { "provider": "mock_vector_store", @@ -55,7 +117,14 @@ def mem0_storage_with_mocked_config(mock_mem0_memory): # Parameters like run_id, includes, and excludes doesn't matter in Memory OSS crew = MockCrew() - embedder_config={"user_id": "test_user", "local_mem0_config": config, "run_id": "my_run_id", "includes": "include1","excludes": "exclude1", "infer" : True} + embedder_config = { + "user_id": "test_user", + "local_mem0_config": config, + "run_id": "my_run_id", + "includes": "include1", + "excludes": "exclude1", + "infer": True, + } mem0_storage = Mem0Storage(type="short_term", crew=crew, config=embedder_config) return mem0_storage, mock_from_config, config @@ -83,28 +152,31 @@ def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_clie with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client): crew = MockCrew() - embedder_config={ - "user_id": "test_user", - "api_key": "ABCDEFGH", - "org_id": "my_org_id", - "project_id": "my_project_id", - "run_id": "my_run_id", - "includes": "include1", - "excludes": "exclude1", - "infer": True - } + embedder_config = { + "user_id": "test_user", + "api_key": "ABCDEFGH", + "org_id": "my_org_id", + "project_id": "my_project_id", + "run_id": "my_run_id", + "includes": "include1", + "excludes": "exclude1", + "infer": True, + } return Mem0Storage(type="short_term", crew=crew, config=embedder_config) @pytest.fixture -def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client, mock_mem0_memory): +def mem0_storage_with_memory_client_using_explictly_config( + mock_mem0_memory_client, mock_mem0_memory +): """Fixture to create a Mem0Storage instance with mocked dependencies""" # We need to patch both MemoryClient and Memory to prevent actual initialization - with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client), \ - patch.object(Memory, "__new__", return_value=mock_mem0_memory): - + with ( + patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client), + patch.object(Memory, "__new__", return_value=mock_mem0_memory), + ): crew = MockCrew() new_config = {"provider": "mem0", "config": {"api_key": "new-api-key"}} @@ -138,18 +210,23 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl mock_mem0_memory_client.update_project = MagicMock() new_categories = [ - {"lifestyle_management_concerns": "Tracks daily routines, habits, hobbies and interests including cooking, time management and work-life balance"}, + { + "lifestyle_management_concerns": ( + "Tracks daily routines, habits, hobbies and interests " + "including cooking, time management and work-life balance" + ) + }, ] crew = MockCrew() - config={ - "user_id": "test_user", - "api_key": "ABCDEFGH", - "org_id": "my_org_id", - "project_id": "my_project_id", - "custom_categories": new_categories - } + config = { + "user_id": "test_user", + "api_key": "ABCDEFGH", + "org_id": "my_org_id", + "project_id": "my_project_id", + "custom_categories": new_categories, + } with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client): _ = Mem0Storage(type="short_term", crew=crew, config=config) @@ -159,8 +236,6 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl ) - - def test_save_method_with_memory_oss(mem0_storage_with_mocked_config): """Test save method for different memory types""" mem0_storage, _, _ = mem0_storage_with_mocked_config @@ -168,68 +243,134 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config): # Test short_term memory type (already set in fixture) test_value = "This is a test memory" - test_metadata = {'description': 'Respond to user conversation. User message: What do you know about me?', 'messages': [{'role': 'system', 'content': 'You are Friendly chatbot assistant. You are a kind and knowledgeable chatbot assistant. You excel at understanding user needs, providing helpful responses, and maintaining engaging conversations. You remember previous interactions to provide a personalized experience.\nYour personal goal is: Engage in useful and interesting conversations with users while remembering context.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!'}, {'role': 'user', 'content': '\nCurrent Task: Respond to user conversation. User message: What do you know about me?\n\nThis is the expected criteria for your final answer: Contextually appropriate, helpful, and friendly response.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n- User is from India\n- User is interested in the solar system\n- User name is Vidit Ostwal\n- User is interested in French cuisine\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'}, {'role': 'assistant', 'content': "I now can give a great answer \nFinal Answer: Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}], 'agent': 'Friendly chatbot assistant'} + test_metadata = { + "description": TEST_DESCRIPTION, + "messages": [ + {"role": "system", "content": SYSTEM_CONTENT}, + {"role": "user", "content": USER_CONTENT}, + {"role": "assistant", "content": ASSISTANT_CONTENT}, + ], + "agent": "Friendly chatbot assistant", + } mem0_storage.save(test_value, test_metadata) mem0_storage.memory.add.assert_called_once_with( - [{'role': 'user', 'content': 'What do you know about me?'}, {'role': 'assistant', 'content': "Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}], + [ + {"role": "user", "content": EXTRACTED_USER_CONTENT}, + { + "role": "assistant", + "content": EXTRACTED_ASSISTANT_CONTENT, + }, + ], infer=True, - metadata={'type': 'short_term', 'description': 'Respond to user conversation. User message: What do you know about me?', 'agent': 'Friendly chatbot assistant'}, + metadata={ + "type": "short_term", + "description": TEST_DESCRIPTION, + "agent": "Friendly chatbot assistant", + }, run_id="my_run_id", user_id="test_user", - agent_id='Test_Agent' + agent_id="Test_Agent", ) + def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config): mem0_storage, _, _ = mem0_storage_with_mocked_config - mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")] + mem0_storage.crew.agents = [ + MagicMock(role="Test Agent"), + MagicMock(role="Test Agent 2"), + MagicMock(role="Test Agent 3"), + ] mem0_storage.memory.add = MagicMock() test_value = "This is a test memory" - test_metadata = {'description': 'Respond to user conversation. User message: What do you know about me?', 'messages': [{'role': 'system', 'content': 'You are Friendly chatbot assistant. You are a kind and knowledgeable chatbot assistant. You excel at understanding user needs, providing helpful responses, and maintaining engaging conversations. You remember previous interactions to provide a personalized experience.\nYour personal goal is: Engage in useful and interesting conversations with users while remembering context.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!'}, {'role': 'user', 'content': '\nCurrent Task: Respond to user conversation. User message: What do you know about me?\n\nThis is the expected criteria for your final answer: Contextually appropriate, helpful, and friendly response.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n- User is from India\n- User is interested in the solar system\n- User name is Vidit Ostwal\n- User is interested in French cuisine\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'}, {'role': 'assistant', 'content': "I now can give a great answer \nFinal Answer: Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}], 'agent': 'Friendly chatbot assistant'} + test_metadata = { + "description": TEST_DESCRIPTION, + "messages": [ + {"role": "system", "content": SYSTEM_CONTENT}, + {"role": "user", "content": USER_CONTENT}, + {"role": "assistant", "content": ASSISTANT_CONTENT}, + ], + "agent": "Friendly chatbot assistant", + } mem0_storage.save(test_value, test_metadata) mem0_storage.memory.add.assert_called_once_with( - [{'role': 'user', 'content': 'What do you know about me?'}, {'role': 'assistant', 'content': "Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}], + [ + {"role": "user", "content": EXTRACTED_USER_CONTENT}, + { + "role": "assistant", + "content": EXTRACTED_ASSISTANT_CONTENT, + }, + ], infer=True, - metadata={'type': 'short_term', 'description': 'Respond to user conversation. User message: What do you know about me?', 'agent': 'Friendly chatbot assistant'}, + metadata={ + "type": "short_term", + "description": TEST_DESCRIPTION, + "agent": "Friendly chatbot assistant", + }, run_id="my_run_id", user_id="test_user", - agent_id='Test_Agent_Test_Agent_2_Test_Agent_3' + agent_id="Test_Agent_Test_Agent_2_Test_Agent_3", ) -def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew): +def test_save_method_with_memory_client( + mem0_storage_with_memory_client_using_config_from_crew, +): """Test save method for different memory types""" mem0_storage = mem0_storage_with_memory_client_using_config_from_crew mem0_storage.memory.add = MagicMock() # Test short_term memory type (already set in fixture) test_value = "This is a test memory" - test_metadata = {'description': 'Respond to user conversation. User message: What do you know about me?', 'messages': [{'role': 'system', 'content': 'You are Friendly chatbot assistant. You are a kind and knowledgeable chatbot assistant. You excel at understanding user needs, providing helpful responses, and maintaining engaging conversations. You remember previous interactions to provide a personalized experience.\nYour personal goal is: Engage in useful and interesting conversations with users while remembering context.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!'}, {'role': 'user', 'content': '\nCurrent Task: Respond to user conversation. User message: What do you know about me?\n\nThis is the expected criteria for your final answer: Contextually appropriate, helpful, and friendly response.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n- User is from India\n- User is interested in the solar system\n- User name is Vidit Ostwal\n- User is interested in French cuisine\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'}, {'role': 'assistant', 'content': "I now can give a great answer \nFinal Answer: Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}], 'agent': 'Friendly chatbot assistant'} + test_metadata = { + "description": TEST_DESCRIPTION, + "messages": [ + {"role": "system", "content": SYSTEM_CONTENT}, + {"role": "user", "content": USER_CONTENT}, + {"role": "assistant", "content": ASSISTANT_CONTENT}, + ], + "agent": "Friendly chatbot assistant", + } mem0_storage.save(test_value, test_metadata) mem0_storage.memory.add.assert_called_once_with( - [{'role': 'user', 'content': 'What do you know about me?'}, {'role': 'assistant', 'content': "Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}], + [ + {"role": "user", "content": EXTRACTED_USER_CONTENT}, + { + "role": "assistant", + "content": EXTRACTED_ASSISTANT_CONTENT, + }, + ], infer=True, - metadata={'type': 'short_term', 'description': 'Respond to user conversation. User message: What do you know about me?', 'agent': 'Friendly chatbot assistant'}, + metadata={ + "type": "short_term", + "description": TEST_DESCRIPTION, + "agent": "Friendly chatbot assistant", + }, version="v2", run_id="my_run_id", includes="include1", excludes="exclude1", - output_format='v1.1', - user_id='test_user', - agent_id='Test_Agent' + output_format="v1.1", + user_id="test_user", + agent_id="Test_Agent", ) def test_search_method_with_memory_oss(mem0_storage_with_mocked_config): """Test search method for different memory types""" mem0_storage, _, _ = mem0_storage_with_mocked_config - mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]} + mock_results = { + "results": [ + {"score": 0.9, "memory": "Result 1"}, + {"score": 0.4, "memory": "Result 2"}, + ] + } mem0_storage.memory.search = MagicMock(return_value=mock_results) results = mem0_storage.search("test query", limit=5, score_threshold=0.5) @@ -238,18 +379,25 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config): query="test query", limit=5, user_id="test_user", - filters={'AND': [{'run_id': 'my_run_id'}]}, - threshold=0.5 + filters={"AND": [{"run_id": "my_run_id"}]}, + threshold=0.5, ) assert len(results) == 2 - assert results[0]["context"] == "Result 1" + assert results[0]["content"] == "Result 1" -def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew): +def test_search_method_with_memory_client( + mem0_storage_with_memory_client_using_config_from_crew, +): """Test search method for different memory types""" mem0_storage = mem0_storage_with_memory_client_using_config_from_crew - mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]} + mock_results = { + "results": [ + {"score": 0.9, "memory": "Result 1"}, + {"score": 0.4, "memory": "Result 2"}, + ] + } mem0_storage.memory.search = MagicMock(return_value=mock_results) results = mem0_storage.search("test query", limit=5, score_threshold=0.5) @@ -259,15 +407,15 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_ limit=5, metadata={"type": "short_term"}, user_id="test_user", - version='v2', + version="v2", run_id="my_run_id", - output_format='v1.1', - filters={'AND': [{'run_id': 'my_run_id'}]}, - threshold=0.5 + output_format="v1.1", + filters={"AND": [{"run_id": "my_run_id"}]}, + threshold=0.5, ) assert len(results) == 2 - assert results[0]["context"] == "Result 1" + assert results[0]["content"] == "Result 1" def test_mem0_storage_default_infer_value(mock_mem0_memory_client): @@ -275,14 +423,12 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client): with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client): crew = MockCrew() - config={ - "user_id": "test_user", - "api_key": "ABCDEFGH" - } + config = {"user_id": "test_user", "api_key": "ABCDEFGH"} mem0_storage = Mem0Storage(type="short_term", crew=crew, config=config) assert mem0_storage.infer is True + def test_save_memory_using_agent_entity(mock_mem0_memory_client): config = { "agent_id": "agent-123", @@ -293,19 +439,25 @@ def test_save_memory_using_agent_entity(mock_mem0_memory_client): mem0_storage = Mem0Storage(type="external", config=config) mem0_storage.save("test memory", {"key": "value"}) mem0_storage.memory.add.assert_called_once_with( - [{'role': 'assistant' , 'content': 'test memory'}], + [{"role": "assistant", "content": "test memory"}], infer=True, metadata={"type": "external", "key": "value"}, agent_id="agent-123", ) + def test_search_method_with_agent_entity(): config = { "agent_id": "agent-123", } mock_memory = MagicMock(spec=Memory) - mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]} + mock_results = { + "results": [ + {"score": 0.9, "memory": "Result 1"}, + {"score": 0.4, "memory": "Result 2"}, + ] + } with patch.object(Memory, "__new__", return_value=mock_memory): mem0_storage = Mem0Storage(type="external", config=config) @@ -314,22 +466,29 @@ def test_search_method_with_agent_entity(): results = mem0_storage.search("test query", limit=5, score_threshold=0.5) mem0_storage.memory.search.assert_called_once_with( - query="test query", - limit=5, - filters={"AND": [{"agent_id": "agent-123"}]}, - threshold=0.5, - ) + query="test query", + limit=5, + filters={"AND": [{"agent_id": "agent-123"}]}, + threshold=0.5, + ) assert len(results) == 2 - assert results[0]["context"] == "Result 1" + assert results[0]["content"] == "Result 1" def test_search_method_with_agent_id_and_user_id(): mock_memory = MagicMock(spec=Memory) - mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]} + mock_results = { + "results": [ + {"score": 0.9, "memory": "Result 1"}, + {"score": 0.4, "memory": "Result 2"}, + ] + } with patch.object(Memory, "__new__", return_value=mock_memory): - mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"}) + mem0_storage = Mem0Storage( + type="external", config={"agent_id": "agent-123", "user_id": "user-123"} + ) mem0_storage.memory.search = MagicMock(return_value=mock_results) results = mem0_storage.search("test query", limit=5, score_threshold=0.5) @@ -337,10 +496,10 @@ def test_search_method_with_agent_id_and_user_id(): mem0_storage.memory.search.assert_called_once_with( query="test query", limit=5, - user_id='user-123', + user_id="user-123", filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]}, threshold=0.5, ) assert len(results) == 2 - assert results[0]["context"] == "Result 1" + assert results[0]["content"] == "Result 1" diff --git a/tests/utilities/test_chromadb_utils.py b/tests/utilities/test_chromadb_utils.py deleted file mode 100644 index bf939f4c8..000000000 --- a/tests/utilities/test_chromadb_utils.py +++ /dev/null @@ -1,123 +0,0 @@ -import multiprocessing -import tempfile -import unittest - -from chromadb.config import Settings -from unittest.mock import patch, MagicMock - -from crewai.utilities.chromadb import ( - MAX_COLLECTION_LENGTH, - MIN_COLLECTION_LENGTH, - is_ipv4_pattern, - sanitize_collection_name, - create_persistent_client, -) - - -def persistent_client_worker(path, queue): - try: - create_persistent_client(path=path) - queue.put(None) - except Exception as e: - queue.put(e) - - -class TestChromadbUtils(unittest.TestCase): - def test_sanitize_collection_name_long_name(self): - """Test sanitizing a very long collection name.""" - long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name" - sanitized = sanitize_collection_name(long_name) - self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH) - self.assertTrue(sanitized[0].isalnum()) - self.assertTrue(sanitized[-1].isalnum()) - self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized)) - - def test_sanitize_collection_name_special_chars(self): - """Test sanitizing a name with special characters.""" - special_chars = "Agent@123!#$%^&*()" - sanitized = sanitize_collection_name(special_chars) - self.assertTrue(sanitized[0].isalnum()) - self.assertTrue(sanitized[-1].isalnum()) - self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized)) - - def test_sanitize_collection_name_short_name(self): - """Test sanitizing a very short name.""" - short_name = "A" - sanitized = sanitize_collection_name(short_name) - self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH) - self.assertTrue(sanitized[0].isalnum()) - self.assertTrue(sanitized[-1].isalnum()) - - def test_sanitize_collection_name_bad_ends(self): - """Test sanitizing a name with non-alphanumeric start/end.""" - bad_ends = "_Agent_" - sanitized = sanitize_collection_name(bad_ends) - self.assertTrue(sanitized[0].isalnum()) - self.assertTrue(sanitized[-1].isalnum()) - - def test_sanitize_collection_name_none(self): - """Test sanitizing a None value.""" - sanitized = sanitize_collection_name(None) - self.assertEqual(sanitized, "default_collection") - - def test_sanitize_collection_name_ipv4_pattern(self): - """Test sanitizing an IPv4 address.""" - ipv4 = "192.168.1.1" - sanitized = sanitize_collection_name(ipv4) - self.assertTrue(sanitized.startswith("ip_")) - self.assertTrue(sanitized[0].isalnum()) - self.assertTrue(sanitized[-1].isalnum()) - self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized)) - - def test_is_ipv4_pattern(self): - """Test IPv4 pattern detection.""" - self.assertTrue(is_ipv4_pattern("192.168.1.1")) - self.assertFalse(is_ipv4_pattern("not.an.ip.address")) - - def test_sanitize_collection_name_properties(self): - """Test that sanitized collection names always meet ChromaDB requirements.""" - test_cases = [ - "A" * 100, # Very long name - "_start_with_underscore", - "end_with_underscore_", - "contains@special#characters", - "192.168.1.1", # IPv4 address - "a" * 2, # Too short - ] - for test_case in test_cases: - sanitized = sanitize_collection_name(test_case) - self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH) - self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH) - self.assertTrue(sanitized[0].isalnum()) - self.assertTrue(sanitized[-1].isalnum()) - - def test_create_persistent_client_passes_args(self): - with patch( - "crewai.utilities.chromadb.PersistentClient" - ) as mock_persistent_client, tempfile.TemporaryDirectory() as tmpdir: - mock_instance = MagicMock() - mock_persistent_client.return_value = mock_instance - - settings = Settings(allow_reset=True) - client = create_persistent_client(path=tmpdir, settings=settings) - - mock_persistent_client.assert_called_once_with( - path=tmpdir, settings=settings - ) - self.assertIs(client, mock_instance) - - def test_create_persistent_client_process_safe(self): - with tempfile.TemporaryDirectory() as tmpdir: - queue = multiprocessing.Queue() - processes = [ - multiprocessing.Process( - target=persistent_client_worker, args=(tmpdir, queue) - ) - for _ in range(5) - ] - - [p.start() for p in processes] - [p.join() for p in processes] - - errors = [queue.get(timeout=5) for _ in processes] - self.assertTrue(all(err is None for err in errors)) diff --git a/tests/utilities/test_knowledge_planning.py b/tests/utilities/test_knowledge_planning.py index 37b6df69f..9ff29c573 100644 --- a/tests/utilities/test_knowledge_planning.py +++ b/tests/utilities/test_knowledge_planning.py @@ -29,13 +29,15 @@ def mock_knowledge_source(): """ return StringKnowledgeSource(content=content) -@patch('crewai.knowledge.storage.knowledge_storage.chromadb') -def test_knowledge_included_in_planning(mock_chroma): + +@patch("crewai.rag.config.utils.get_rag_client") +def test_knowledge_included_in_planning(mock_get_client): """Test that verifies knowledge sources are properly included in planning.""" - # Mock ChromaDB collection - mock_collection = mock_chroma.return_value.get_or_create_collection.return_value - mock_collection.add.return_value = None - + # Mock RAG client + mock_client = mock_get_client.return_value + mock_client.get_or_create_collection.return_value = None + mock_client.add_documents.return_value = None + # Create an agent with knowledge agent = Agent( role="AI Researcher", @@ -45,14 +47,14 @@ def test_knowledge_included_in_planning(mock_chroma): StringKnowledgeSource( content="AI systems require careful training and validation." ) - ] + ], ) # Create a task for the agent task = Task( description="Explain the basics of AI systems", expected_output="A clear explanation of AI fundamentals", - agent=agent + agent=agent, ) # Create a crew planner @@ -62,23 +64,29 @@ def test_knowledge_included_in_planning(mock_chroma): task_summary = planner._create_tasks_summary() # Verify that knowledge is included in planning when present - assert "AI systems require careful training" in task_summary, \ + assert "AI systems require careful training" in task_summary, ( "Knowledge content should be present in task summary when knowledge exists" - assert '"agent_knowledge"' in task_summary, \ + ) + assert '"agent_knowledge"' in task_summary, ( "agent_knowledge field should be present in task summary when knowledge exists" + ) # Verify that knowledge is properly formatted - assert isinstance(task.agent.knowledge_sources, list), \ + assert isinstance(task.agent.knowledge_sources, list), ( "Knowledge sources should be stored in a list" - assert len(task.agent.knowledge_sources) > 0, \ + ) + assert len(task.agent.knowledge_sources) > 0, ( "At least one knowledge source should be present" - assert task.agent.knowledge_sources[0].content in task_summary, \ + ) + assert task.agent.knowledge_sources[0].content in task_summary, ( "Knowledge source content should be included in task summary" + ) # Verify that other expected components are still present - assert task.description in task_summary, \ + assert task.description in task_summary, ( "Task description should be present in task summary" - assert task.expected_output in task_summary, \ + ) + assert task.expected_output in task_summary, ( "Expected output should be present in task summary" - assert agent.role in task_summary, \ - "Agent role should be present in task summary" + ) + assert agent.role in task_summary, "Agent role should be present in task summary"