refactor: unify rag storage with instance-specific client support (#3455)

- ignore line length errors globally
- migrate knowledge/memory and crew query_knowledge to `SearchResult` (usage sketch below)
- remove legacy chromadb utils; fix empty metadata handling
- restore openai as default embedding provider; support instance-specific clients
- update and fix tests for `SearchResult` migration and rag changes
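
A minimal usage sketch of the `SearchResult` migration (the `content`/`metadata`/`score` keys are inferred from this diff, not a documented contract):

# query_knowledge now returns list[SearchResult] | None instead of
# list[dict[str, Any]] | None, and results expose "content" rather than
# the old "context" key.
results = crew.query_knowledge(["vector databases"], results_limit=3)
if results is not None:
    for result in results:
        print(result["content"], result.get("score"))
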
Greyson LaLonde
2025-09-17 14:46:54 -04:00
committed by GitHub
parent 81bd81e5f5
commit f28e78c5ba
30 changed files with 1956 additions and 976 deletions

View File

@@ -3,26 +3,17 @@ import json
import re
import uuid
import warnings
from collections.abc import Callable
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from opentelemetry import baggage
from opentelemetry.context import attach, detach
from crewai.utilities.crew.models import CrewContext
from pydantic import (
UUID4,
BaseModel,
@@ -39,26 +30,14 @@ from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM, BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
from crewai.security import Fingerprint, SecurityConfig
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
@@ -70,16 +49,28 @@ from crewai.events.types.crew_events import (
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
)
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM, BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
from crewai.rag.types import SearchResult
from crewai.security import Fingerprint, SecurityConfig
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.crew.models import CrewContext
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.formatter import (
aggregate_raw_outputs_from_task_outputs,
aggregate_raw_outputs_from_tasks,
@@ -94,28 +85,40 @@ warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
class Crew(FlowTrackable, BaseModel):
"""
Represents a group of agents, defining how they should collaborate and the tasks they should perform.
Represents a group of agents, defining how they should collaborate and the
tasks they should perform.
Attributes:
tasks: List of tasks assigned to the crew.
agents: List of agents part of this crew.
manager_llm: The language model that will run manager agent.
manager_agent: Custom agent that will be used as manager.
memory: Whether the crew should use memory to store memories of its execution.
cache: Whether the crew should use a cache to store the results of the tools execution.
function_calling_llm: The language model that will run the tool calling for all the agents.
process: The process flow that the crew will follow (e.g., sequential, hierarchical).
memory: Whether the crew should use memory to store memories of its
execution.
cache: Whether the crew should use a cache to store the results of the
tools execution.
function_calling_llm: The language model that will run the tool calling
for all the agents.
process: The process flow that the crew will follow (e.g., sequential,
hierarchical).
verbose: Indicates the verbosity level for logging during execution.
config: Configuration settings for the crew.
max_rpm: Maximum number of requests per minute for the crew execution to be respected.
max_rpm: Maximum number of requests per minute for the crew execution to
be respected.
prompt_file: Path to the prompt json file to be used for the crew.
id: A unique identifier for the crew instance.
task_callback: Callback to be executed after each task for every agent's execution.
step_callback: Callback to be executed after each step for every agent's execution.
share_crew: Whether you want to share the complete crew information and execution with crewAI to make the library better, and allow us to train models.
task_callback: Callback to be executed after each task for every agent's
execution.
step_callback: Callback to be executed after each step for every agent's
execution.
share_crew: Whether you want to share the complete crew information and
execution with crewAI to make the library better, and allow us to
train models.
planning: Plan the crew execution and add the plan to the crew.
chat_llm: The language model used for orchestrating chat interactions with the crew.
security_config: Security configuration for the crew, including fingerprinting.
chat_llm: The language model used for orchestrating chat interactions
with the crew.
security_config: Security configuration for the crew, including
fingerprinting.
"""
__hash__ = object.__hash__ # type: ignore
@@ -124,13 +127,13 @@ class Crew(FlowTrackable, BaseModel):
_logger: Logger = PrivateAttr()
_file_handler: FileHandler = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
_short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
_long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
_entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr()
_external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr()
_train: Optional[bool] = PrivateAttr(default=False)
_train_iteration: Optional[int] = PrivateAttr()
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_short_term_memory: InstanceOf[ShortTermMemory] | None = PrivateAttr()
_long_term_memory: InstanceOf[LongTermMemory] | None = PrivateAttr()
_entity_memory: InstanceOf[EntityMemory] | None = PrivateAttr()
_external_memory: InstanceOf[ExternalMemory] | None = PrivateAttr()
_train: bool | None = PrivateAttr(default=False)
_train_iteration: int | None = PrivateAttr()
_inputs: dict[str, Any] | None = PrivateAttr(default=None)
_logging_color: str = PrivateAttr(
default="bold_purple",
)
@@ -138,107 +141,121 @@ class Crew(FlowTrackable, BaseModel):
default_factory=TaskOutputStorageHandler
)
name: Optional[str] = Field(default="crew")
name: str | None = Field(default="crew")
cache: bool = Field(default=True)
tasks: List[Task] = Field(default_factory=list)
agents: List[BaseAgent] = Field(default_factory=list)
tasks: list[Task] = Field(default_factory=list)
agents: list[BaseAgent] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
verbose: bool = Field(default=False)
memory: bool = Field(
default=False,
description="Whether the crew should use memory to store memories of it's execution",
description="If crew should use memory to store memories of it's execution",
)
short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field(
short_term_memory: InstanceOf[ShortTermMemory] | None = Field(
default=None,
description="An Instance of the ShortTermMemory to be used by the Crew",
)
long_term_memory: Optional[InstanceOf[LongTermMemory]] = Field(
long_term_memory: InstanceOf[LongTermMemory] | None = Field(
default=None,
description="An Instance of the LongTermMemory to be used by the Crew",
)
entity_memory: Optional[InstanceOf[EntityMemory]] = Field(
entity_memory: InstanceOf[EntityMemory] | None = Field(
default=None,
description="An Instance of the EntityMemory to be used by the Crew",
)
external_memory: Optional[InstanceOf[ExternalMemory]] = Field(
external_memory: InstanceOf[ExternalMemory] | None = Field(
default=None,
description="An Instance of the ExternalMemory to be used by the Crew",
)
embedder: Optional[dict] = Field(
embedder: dict | None = Field(
default=None,
description="Configuration for the embedder to be used for the crew.",
)
usage_metrics: Optional[UsageMetrics] = Field(
usage_metrics: UsageMetrics | None = Field(
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: Optional[BaseAgent] = Field(
manager_agent: BaseAgent | None = Field(
description="Custom agent that will be used as manager.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
description="Language model that will run the agent.", default=None
)
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
config: Json | dict[str, Any] | None = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
share_crew: Optional[bool] = Field(default=False)
step_callback: Optional[Any] = Field(
share_crew: bool | None = Field(default=False)
step_callback: Any | None = Field(
default=None,
description="Callback to be executed after each step for all agents execution.",
)
task_callback: Optional[Any] = Field(
task_callback: Any | None = Field(
default=None,
description="Callback to be executed after each task for all agents execution.",
)
before_kickoff_callbacks: List[
Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]]
before_kickoff_callbacks: list[
Callable[[dict[str, Any] | None], dict[str, Any] | None]
] = Field(
default_factory=list,
description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.",
description=(
"List of callbacks to be executed before crew kickoff. "
"It may be used to adjust inputs before the crew is executed."
),
)
after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field(
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
default_factory=list,
description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.",
description=(
"List of callbacks to be executed after crew kickoff. "
"It may be used to adjust the output of the crew."
),
)
max_rpm: Optional[int] = Field(
max_rpm: int | None = Field(
default=None,
description="Maximum number of requests per minute for the crew execution to be respected.",
description=(
"Maximum number of requests per minute for the crew execution "
"to be respected."
),
)
prompt_file: Optional[str] = Field(
prompt_file: str | None = Field(
default=None,
description="Path to the prompt json file to be used for the crew.",
)
output_log_file: Optional[Union[bool, str]] = Field(
output_log_file: bool | str | None = Field(
default=None,
description="Path to the log file to be saved",
)
planning: Optional[bool] = Field(
planning: bool | None = Field(
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
default=None,
description="Language model that will run the AgentPlanner if planning is True.",
description=(
"Language model that will run the AgentPlanner if planning is True."
),
)
task_execution_output_json_files: Optional[List[str]] = Field(
task_execution_output_json_files: list[str] | None = Field(
default=None,
description="List of file paths for task execution JSON files.",
)
execution_logs: List[Dict[str, Any]] = Field(
execution_logs: list[dict[str, Any]] = Field(
default=[],
description="List of execution logs for tasks",
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
default=None,
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
description=(
"Knowledge sources for the crew. Add knowledge sources to the "
"knowledge object."
),
)
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
knowledge: Optional[Knowledge] = Field(
knowledge: Knowledge | None = Field(
default=None,
description="Knowledge for the crew.",
)
@@ -246,18 +263,18 @@ class Crew(FlowTrackable, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the crew, including fingerprinting.",
)
token_usage: Optional[UsageMetrics] = Field(
token_usage: UsageMetrics | None = Field(
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
tracing: Optional[bool] = Field(
tracing: bool | None = Field(
default=False,
description="Whether to enable tracing for the crew.",
)
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
"""Prevent manual setting of the 'id' field by users."""
if v:
raise PydanticCustomError(
@@ -266,9 +283,7 @@ class Crew(FlowTrackable, BaseModel):
@field_validator("config", mode="before")
@classmethod
def check_config_type(
cls, v: Union[Json, Dict[str, Any]]
) -> Union[Json, Dict[str, Any]]:
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
"""Validates that the config is a valid type.
Args:
v: The config to be validated.
@@ -314,7 +329,8 @@ class Crew(FlowTrackable, BaseModel):
def create_crew_memory(self) -> "Crew":
"""Initialize private memory attributes."""
self._external_memory = (
# External memory doesnt support a default value since it was designed to be managed entirely externally
# External memory does not support a default value since it was
# designed to be managed entirely externally
self.external_memory.set_crew(self) if self.external_memory else None
)
@@ -355,7 +371,10 @@ class Crew(FlowTrackable, BaseModel):
if not self.manager_llm and not self.manager_agent:
raise PydanticCustomError(
"missing_manager_llm_or_manager_agent",
"Attribute `manager_llm` or `manager_agent` is required when using hierarchical process.",
(
"Attribute `manager_llm` or `manager_agent` is required "
"when using hierarchical process."
),
{},
)
@@ -398,7 +417,10 @@ class Crew(FlowTrackable, BaseModel):
if task.agent is None:
raise PydanticCustomError(
"missing_agent_in_task",
f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
(
f"Sequential process error: Agent is missing in the task "
f"with the following description: {task.description}"
), # type: ignore # Dynamic string in error message
{},
)
@@ -459,7 +481,10 @@ class Crew(FlowTrackable, BaseModel):
if task.async_execution and isinstance(task, ConditionalTask):
raise PydanticCustomError(
"invalid_async_conditional_task",
f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
(
f"Conditional Task: {task.description}, "
f"cannot be executed asynchronously."
),
{},
)
return self
@@ -478,7 +503,9 @@ class Crew(FlowTrackable, BaseModel):
for j in range(i - 1, -1, -1):
if self.tasks[j] == context_task:
raise ValueError(
f"Task '{task.description}' is asynchronous and cannot include other sequential asynchronous tasks in its context."
f"Task '{task.description}' is asynchronous and "
f"cannot include other sequential asynchronous "
f"tasks in its context."
)
if not self.tasks[j].async_execution:
break
@@ -496,13 +523,15 @@ class Crew(FlowTrackable, BaseModel):
continue # Skip context tasks not in the main tasks list
if task_indices[id(context_task)] > task_indices[id(task)]:
raise ValueError(
f"Task '{task.description}' has a context dependency on a future task '{context_task.description}', which is not allowed."
f"Task '{task.description}' has a context dependency "
f"on a future task '{context_task.description}', "
f"which is not allowed."
)
return self
@property
def key(self) -> str:
source: List[str] = [agent.key for agent in self.agents] + [
source: list[str] = [agent.key for agent in self.agents] + [
task.key for task in self.tasks
]
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
@@ -518,9 +547,9 @@ class Crew(FlowTrackable, BaseModel):
return self.security_config.fingerprint
def _setup_from_config(self):
assert self.config is not None, "Config should not be None."
"""Initializes agents and tasks from the provided config."""
if self.config is None:
raise ValueError("Config should not be None.")
if not self.config.get("agents") or not self.config.get("tasks"):
raise PydanticCustomError(
"missing_keys_in_config", "Config should have 'agents' and 'tasks'.", {}
@@ -530,7 +559,7 @@ class Crew(FlowTrackable, BaseModel):
self.agents = [Agent(**agent) for agent in self.config["agents"]]
self.tasks = [self._create_task(task) for task in self.config["tasks"]]
def _create_task(self, task_config: Dict[str, Any]) -> Task:
def _create_task(self, task_config: dict[str, Any]) -> Task:
"""Creates a task instance from its configuration.
Args:
@@ -559,7 +588,7 @@ class Crew(FlowTrackable, BaseModel):
CrewTrainingHandler(filename).initialize_file()
def train(
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = None
self, n_iterations: int, filename: str, inputs: dict[str, Any] | None = None
) -> None:
"""Trains the crew for a given number of iterations."""
inputs = inputs or {}
@@ -611,7 +640,7 @@ class Crew(FlowTrackable, BaseModel):
def kickoff(
self,
inputs: Optional[Dict[str, Any]] = None,
inputs: dict[str, Any] | None = None,
) -> CrewOutput:
ctx = baggage.set_baggage(
"crew_context", CrewContext(id=str(self.id), key=self.key)
@@ -682,9 +711,9 @@ class Crew(FlowTrackable, BaseModel):
finally:
detach(token)
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
"""Executes the Crew's workflow for each input in the list and aggregates results."""
results: List[CrewOutput] = []
def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
"""Executes the Crew's workflow for each input and aggregates results."""
results: list[CrewOutput] = []
# Initialize the parent crew's usage metrics
total_usage_metrics = UsageMetrics()
@@ -703,14 +732,12 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset()
return results
async def kickoff_async(
self, inputs: Optional[Dict[str, Any]] = None
) -> CrewOutput:
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> CrewOutput:
"""Asynchronous kickoff method to start the crew execution."""
inputs = inputs or {}
return await asyncio.to_thread(self.kickoff, inputs)
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]:
async def kickoff_for_each_async(self, inputs: list[dict]) -> list[CrewOutput]:
crew_copies = [self.copy() for _ in inputs]
async def run_crew(crew, input_data):
@@ -739,7 +766,9 @@ class Crew(FlowTrackable, BaseModel):
tasks=self.tasks, planning_agent_llm=self.planning_llm
)._handle_crew_planning()
for task, step_plan in zip(self.tasks, result.list_of_plans_per_task):
for task, step_plan in zip(
self.tasks, result.list_of_plans_per_task, strict=False
):
task.description += step_plan.plan
def _store_execution_log(
@@ -776,7 +805,7 @@ class Crew(FlowTrackable, BaseModel):
return self._execute_tasks(self.tasks)
def _run_hierarchical_process(self) -> CrewOutput:
"""Creates and assigns a manager agent to make sure the crew completes the tasks."""
"""Creates and assigns a manager agent to complete the tasks."""
self._create_manager_agent()
return self._execute_tasks(self.tasks)
@@ -807,23 +836,24 @@ class Crew(FlowTrackable, BaseModel):
def _execute_tasks(
self,
tasks: List[Task],
start_index: Optional[int] = 0,
tasks: list[Task],
start_index: int | None = 0,
was_replayed: bool = False,
) -> CrewOutput:
"""Executes tasks sequentially and returns the final output.
Args:
tasks (List[Task]): List of tasks to execute
manager (Optional[BaseAgent], optional): Manager agent to use for delegation. Defaults to None.
start_index (int | None): Index of the task to start execution
from. Defaults to 0.
Returns:
CrewOutput: Final output of the crew
"""
task_outputs: List[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
last_sync_output: Optional[TaskOutput] = None
task_outputs: list[TaskOutput] = []
futures: list[tuple[Task, Future[TaskOutput], int]] = []
last_sync_output: TaskOutput | None = None
for task_index, task in enumerate(tasks):
if start_index is not None and task_index < start_index:
@@ -838,7 +868,9 @@ class Crew(FlowTrackable, BaseModel):
agent_to_use = self._get_agent_to_use(task)
if agent_to_use is None:
raise ValueError(
f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided."
f"No agent available for task: {task.description}. "
f"Ensure that either the task has an assigned agent "
f"or a manager agent is provided."
)
# Determine which tools to use - task tools take precedence over agent tools
@@ -847,7 +879,7 @@ class Crew(FlowTrackable, BaseModel):
tools_for_task = self._prepare_tools(
agent_to_use,
task,
cast(Union[List[Tool], List[BaseTool]], tools_for_task),
cast(list[Tool] | list[BaseTool], tools_for_task),
)
self._log_task_start(task, agent_to_use.role)
@@ -867,7 +899,7 @@ class Crew(FlowTrackable, BaseModel):
future = task.execute_async(
agent=agent_to_use,
context=context,
tools=cast(List[BaseTool], tools_for_task),
tools=cast(list[BaseTool], tools_for_task),
)
futures.append((task, future, task_index))
else:
@@ -879,7 +911,7 @@ class Crew(FlowTrackable, BaseModel):
task_output = task.execute_sync(
agent=agent_to_use,
context=context,
tools=cast(List[BaseTool], tools_for_task),
tools=cast(list[BaseTool], tools_for_task),
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
@@ -893,11 +925,11 @@ class Crew(FlowTrackable, BaseModel):
def _handle_conditional_task(
self,
task: ConditionalTask,
task_outputs: List[TaskOutput],
futures: List[Tuple[Task, Future[TaskOutput], int]],
task_outputs: list[TaskOutput],
futures: list[tuple[Task, Future[TaskOutput], int]],
task_index: int,
was_replayed: bool,
) -> Optional[TaskOutput]:
) -> TaskOutput | None:
if futures:
task_outputs = self._process_async_tasks(futures, was_replayed)
futures.clear()
@@ -917,8 +949,8 @@ class Crew(FlowTrackable, BaseModel):
return None
def _prepare_tools(
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
# Add delegation tools if agent allows delegation
if hasattr(agent, "allow_delegation") and getattr(
agent, "allow_delegation", False
@@ -947,22 +979,22 @@ class Crew(FlowTrackable, BaseModel):
):
tools = self._add_multimodal_tools(agent, tools)
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
return cast(List[BaseTool], tools)
# Return a List[BaseTool] compatible with Task.execute_sync and execute_async
return cast(list[BaseTool], tools)
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
def _get_agent_to_use(self, task: Task) -> BaseAgent | None:
if self.process == Process.hierarchical:
return self.manager_agent
return task.agent
def _merge_tools(
self,
existing_tools: Union[List[Tool], List[BaseTool]],
new_tools: Union[List[Tool], List[BaseTool]],
) -> List[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
existing_tools: list[Tool] | list[BaseTool],
new_tools: list[Tool] | list[BaseTool],
) -> list[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates."""
if not new_tools:
return cast(List[BaseTool], existing_tools)
return cast(list[BaseTool], existing_tools)
# Create mapping of tool names to new tools
new_tool_map = {tool.name: tool for tool in new_tools}
@@ -973,41 +1005,41 @@ class Crew(FlowTrackable, BaseModel):
# Add all new tools
tools.extend(new_tools)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _inject_delegation_tools(
self,
tools: Union[List[Tool], List[BaseTool]],
tools: list[Tool] | list[BaseTool],
task_agent: BaseAgent,
agents: List[BaseAgent],
) -> List[BaseTool]:
agents: list[BaseAgent],
) -> list[BaseTool]:
if hasattr(task_agent, "get_delegation_tools"):
delegation_tools = task_agent.get_delegation_tools(agents)
# Cast delegation_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], delegation_tools))
return cast(list[BaseTool], tools)
def _add_multimodal_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if hasattr(agent, "get_multimodal_tools"):
multimodal_tools = agent.get_multimodal_tools()
# Cast multimodal_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
return cast(list[BaseTool], tools)
def _add_code_execution_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if hasattr(agent, "get_code_execution_tools"):
code_tools = agent.get_code_execution_tools()
# Cast code_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
return cast(list[BaseTool], tools)
def _add_delegation_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
if not tools:
@@ -1015,7 +1047,7 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools(
tools, task.agent, agents_for_delegation
)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _log_task_start(self, task: Task, role: str = "None"):
if self.output_log_file:
@@ -1024,8 +1056,8 @@ class Crew(FlowTrackable, BaseModel):
)
def _update_manager_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if self.manager_agent:
if task.agent:
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
@@ -1033,18 +1065,17 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools(
tools, self.manager_agent, self.agents
)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
if not task.context:
return ""
context = (
return (
aggregate_raw_outputs_from_task_outputs(task_outputs)
if task.context is NOT_SPECIFIED
else aggregate_raw_outputs_from_tasks(task.context)
)
return context
def _process_task_result(self, task: Task, output: TaskOutput) -> None:
role = task.agent.role if task.agent is not None else "None"
@@ -1057,7 +1088,7 @@ class Crew(FlowTrackable, BaseModel):
output=output.raw,
)
def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput:
def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput:
if not task_outputs:
raise ValueError("No task outputs available to create crew output.")
@@ -1088,10 +1119,10 @@ class Crew(FlowTrackable, BaseModel):
def _process_async_tasks(
self,
futures: List[Tuple[Task, Future[TaskOutput], int]],
futures: list[tuple[Task, Future[TaskOutput], int]],
was_replayed: bool = False,
) -> List[TaskOutput]:
task_outputs: List[TaskOutput] = []
) -> list[TaskOutput]:
task_outputs: list[TaskOutput] = []
for future_task, future, task_index in futures:
task_output = future.result()
task_outputs.append(task_output)
@@ -1101,9 +1132,7 @@ class Crew(FlowTrackable, BaseModel):
)
return task_outputs
def _find_task_index(
self, task_id: str, stored_outputs: List[Any]
) -> Optional[int]:
def _find_task_index(self, task_id: str, stored_outputs: list[Any]) -> int | None:
return next(
(
index
@@ -1113,9 +1142,8 @@ class Crew(FlowTrackable, BaseModel):
None,
)
def replay(
self, task_id: str, inputs: Optional[Dict[str, Any]] = None
) -> CrewOutput:
def replay(self, task_id: str, inputs: dict[str, Any] | None = None) -> CrewOutput:
"""Replay the crew execution from a specific task."""
stored_outputs = self._task_output_handler.load()
if not stored_outputs:
raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")
@@ -1151,19 +1179,19 @@ class Crew(FlowTrackable, BaseModel):
self.tasks[i].output = task_output
self._logging_color = "bold_blue"
result = self._execute_tasks(self.tasks, start_index, True)
return result
return self._execute_tasks(self.tasks, start_index, True)
def query_knowledge(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> Union[List[Dict[str, Any]], None]:
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[SearchResult] | None:
"""Query the crew's knowledge base for relevant information."""
if self.knowledge:
return self.knowledge.query(
query, results_limit=results_limit, score_threshold=score_threshold
)
return None
def fetch_inputs(self) -> Set[str]:
def fetch_inputs(self) -> set[str]:
"""
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
Scans each task's 'description' + 'expected_output', and each agent's
@@ -1172,11 +1200,11 @@ class Crew(FlowTrackable, BaseModel):
Returns a set of all discovered placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)\}")
required_inputs: Set[str] = set()
required_inputs: set[str] = set()
# Scan tasks for inputs
for task in self.tasks:
# description and expected_output might contain e.g. {topic}, {user_name}, etc.
# description and expected_output might contain e.g. {topic}, {user_name}
text = f"{task.description or ''} {task.expected_output or ''}"
required_inputs.update(placeholder_pattern.findall(text))
@@ -1230,7 +1258,7 @@ class Crew(FlowTrackable, BaseModel):
cloned_tasks.append(cloned_task)
task_mapping[task.key] = cloned_task
for cloned_task, original_task in zip(cloned_tasks, self.tasks):
for cloned_task, original_task in zip(cloned_tasks, self.tasks, strict=False):
if isinstance(original_task.context, list):
cloned_context = [
task_mapping[context_task.key]
@@ -1256,7 +1284,7 @@ class Crew(FlowTrackable, BaseModel):
copied_data.pop("agents", None)
copied_data.pop("tasks", None)
copied_crew = Crew(
return Crew(
**copied_data,
agents=cloned_agents,
tasks=cloned_tasks,
@@ -1266,15 +1294,13 @@ class Crew(FlowTrackable, BaseModel):
manager_llm=manager_llm,
)
return copied_crew
def _set_tasks_callbacks(self) -> None:
"""Sets callback for every task suing task_callback"""
for task in self.tasks:
if not task.callback:
task.callback = self.task_callback
def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
def _interpolate_inputs(self, inputs: dict[str, Any]) -> None:
"""Interpolates the inputs in the tasks and agents."""
[
task.interpolate_inputs_and_add_conversation_history(
@@ -1307,10 +1333,13 @@ class Crew(FlowTrackable, BaseModel):
def test(
self,
n_iterations: int,
eval_llm: Union[str, InstanceOf[BaseLLM]],
inputs: Optional[Dict[str, Any]] = None,
eval_llm: str | InstanceOf[BaseLLM],
inputs: dict[str, Any] | None = None,
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
"""Test and evaluate the Crew with the given inputs for n iterations.
Uses concurrent.futures for concurrent execution.
"""
try:
# Create LLM instance and ensure it's of type LLM for CrewEvaluator
llm_instance = create_llm(eval_llm)
@@ -1350,7 +1379,11 @@ class Crew(FlowTrackable, BaseModel):
raise
def __repr__(self):
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
return (
f"Crew(id={self.id}, process={self.process}, "
f"number_of_agents={len(self.agents)}, "
f"number_of_tasks={len(self.tasks)})"
)
def reset_memories(self, command_type: str) -> None:
"""Reset specific or all memories for the crew.
@@ -1364,7 +1397,7 @@ class Crew(FlowTrackable, BaseModel):
ValueError: If an invalid command type is provided.
RuntimeError: If memory reset operation fails.
"""
VALID_TYPES = frozenset(
valid_types = frozenset(
[
"long",
"short",
@@ -1377,9 +1410,10 @@ class Crew(FlowTrackable, BaseModel):
]
)
if command_type not in VALID_TYPES:
if command_type not in valid_types:
raise ValueError(
f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}"
f"Invalid command type. Must be one of: "
f"{', '.join(sorted(valid_types))}"
)
try:
@@ -1389,7 +1423,7 @@ class Crew(FlowTrackable, BaseModel):
self._reset_specific_memory(command_type)
except Exception as e:
error_msg = f"Failed to reset {command_type} memory: {str(e)}"
error_msg = f"Failed to reset {command_type} memory: {e!s}"
self._logger.log("error", error_msg)
raise RuntimeError(error_msg) from e
@@ -1397,7 +1431,7 @@ class Crew(FlowTrackable, BaseModel):
"""Reset all available memory systems."""
memory_systems = self._get_memory_systems()
for memory_type, config in memory_systems.items():
for config in memory_systems.values():
if (system := config.get("system")) is not None:
name = config.get("name")
try:
@@ -1405,11 +1439,13 @@ class Crew(FlowTrackable, BaseModel):
reset_fn(system)
self._logger.log(
"info",
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
f"[Crew ({self.name if self.name else self.id})] "
f"{name} memory has been reset",
)
except Exception as e:
raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
f"[Crew ({self.name if self.name else self.id})] "
f"Failed to reset {name} memory: {e!s}"
) from e
def _reset_specific_memory(self, memory_type: str) -> None:
@@ -1434,18 +1470,21 @@ class Crew(FlowTrackable, BaseModel):
reset_fn(system)
self._logger.log(
"info",
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
f"[Crew ({self.name if self.name else self.id})] "
f"{name} memory has been reset",
)
except Exception as e:
raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
f"[Crew ({self.name if self.name else self.id})] "
f"Failed to reset {name} memory: {e!s}"
) from e
def _get_memory_systems(self):
"""Get all available memory systems with their configuration.
Returns:
Dict containing all memory systems with their reset functions and display names.
Dict containing all memory systems with their reset functions and
display names.
"""
def default_reset(memory):
@@ -1506,7 +1545,7 @@ class Crew(FlowTrackable, BaseModel):
},
}
def reset_knowledge(self, knowledges: List[Knowledge]) -> None:
def reset_knowledge(self, knowledges: list[Knowledge]) -> None:
"""Reset crew and agent knowledge storage."""
for ks in knowledges:
ks.reset()

View File

@@ -1,10 +1,11 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.rag.types import SearchResult
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
@@ -13,23 +14,23 @@ class Knowledge(BaseModel):
"""
Knowledge is a collection of sources and the vector store setup used to save and query relevant context.
Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
storage: KnowledgeStorage | None = Field(default=None)
embedder: dict[str, Any] | None = None
"""
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None
storage: KnowledgeStorage | None = Field(default=None)
embedder: dict[str, Any] | None = None
collection_name: str | None = None
def __init__(
self,
collection_name: str,
sources: List[BaseKnowledgeSource],
embedder: Optional[Dict[str, Any]] = None,
storage: Optional[KnowledgeStorage] = None,
sources: list[BaseKnowledgeSource],
embedder: dict[str, Any] | None = None,
storage: KnowledgeStorage | None = None,
**data,
):
super().__init__(**data)
@@ -40,11 +41,10 @@ class Knowledge(BaseModel):
embedder=embedder, collection_name=collection_name
)
self.sources = sources
self.storage.initialize_knowledge_storage()
def query(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> List[Dict[str, Any]]:
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[SearchResult]:
"""
Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.
@@ -55,12 +55,11 @@ class Knowledge(BaseModel):
if self.storage is None:
raise ValueError("Storage is not initialized.")
results = self.storage.search(
return self.storage.search(
query,
limit=results_limit,
score_threshold=score_threshold,
)
return results
def add_sources(self):
try:

View File

@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
from crewai.rag.types import SearchResult
class BaseKnowledgeStorage(ABC):
@@ -8,22 +10,17 @@ class BaseKnowledgeStorage(ABC):
@abstractmethod
def search(
self,
query: List[str],
query: list[str],
limit: int = 3,
filter: Optional[dict] = None,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
) -> list[SearchResult]:
"""Search for documents in the knowledge base."""
pass
@abstractmethod
def save(
self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
) -> None:
def save(self, documents: list[str]) -> None:
"""Save documents to the knowledge base."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the knowledge base."""
pass
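
A hedged sketch of a subclass implementing the updated interface (in-memory stand-in; the literal result dicts assume `SearchResult` is a TypedDict-style mapping with `content`/`metadata`/`score` keys):

from typing import Any

from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.types import SearchResult


class InMemoryKnowledgeStorage(BaseKnowledgeStorage):
    """Toy backend illustrating the new signatures; not a real vector store."""

    def __init__(self) -> None:
        self._documents: list[str] = []

    def search(
        self,
        query: list[str],
        limit: int = 3,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.35,
    ) -> list[SearchResult]:
        # Naive substring match standing in for embedding similarity.
        needle = " ".join(query).lower()
        hits = [doc for doc in self._documents if needle in doc.lower()]
        return [{"content": doc, "metadata": {}, "score": 1.0} for doc in hits[:limit]]

    def save(self, documents: list[str]) -> None:
        # Note: per-document metadata was dropped from save() in this refactor.
        self._documents.extend(documents)

    def reset(self) -> None:
        self._documents.clear()

Note the narrowed save() signature: metadata now flows through the RAG client rather than this interface.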

View File

@@ -1,24 +1,16 @@
import hashlib
import logging
import os
import shutil
from typing import Any, Dict, List, Optional, Union
import chromadb
import chromadb.errors
from chromadb.api import ClientAPI
from chromadb.api.types import OneOrMany
from chromadb.config import Settings
import warnings
from typing import Any, cast
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import sanitize_collection_name
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path
from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.logger_utils import suppress_logging
class KnowledgeStorage(BaseKnowledgeStorage):
@@ -27,167 +19,101 @@ class KnowledgeStorage(BaseKnowledgeStorage):
search efficiency.
"""
collection: Optional[chromadb.Collection] = None
collection_name: Optional[str] = "knowledge"
app: Optional[ClientAPI] = None
def __init__(
self,
embedder: Optional[Dict[str, Any]] = None,
collection_name: Optional[str] = None,
):
embedder: dict[str, Any] | None = None,
collection_name: str | None = None,
) -> None:
self.collection_name = collection_name
self._set_embedder_config(embedder)
self._client: BaseClient | None = None
def search(
self,
query: List[str],
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
if self.collection:
fetched = self.collection.query(
query_texts=query,
n_results=limit,
where=filter,
)
results = []
for i in range(len(fetched["ids"][0])): # type: ignore
result = {
"id": fetched["ids"][0][i], # type: ignore
"metadata": fetched["metadatas"][0][i], # type: ignore
"context": fetched["documents"][0][i], # type: ignore
"score": fetched["distances"][0][i], # type: ignore
}
if result["score"] >= score_threshold:
results.append(result)
return results
else:
raise Exception("Collection not initialized")
def initialize_knowledge_storage(self):
# Suppress deprecation warnings from chromadb, which are not relevant to us
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
self.app = create_persistent_client(
path=os.path.join(db_storage_path(), "knowledge"),
settings=Settings(allow_reset=True),
)
if embedder:
embedding_function = get_embedding_function(embedder)
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
self._client = create_client(config)
def _get_client(self) -> BaseClient:
"""Get the appropriate client - instance-specific or global."""
return self._client if self._client else get_rag_client()
def search(
self,
query: list[str],
limit: int = 3,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.35,
) -> list[SearchResult]:
try:
if not query:
raise ValueError("Query cannot be empty")
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
if self.app:
self.collection = self.app.get_or_create_collection(
name=sanitize_collection_name(collection_name),
embedding_function=self.embedder,
)
else:
raise Exception("Vector Database Client not initialized")
except Exception:
raise Exception("Failed to create or get collection")
query_text = " ".join(query) if len(query) > 1 else query[0]
def reset(self):
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
if not self.app:
self.app = create_persistent_client(
path=base_path, settings=Settings(allow_reset=True)
return client.search(
collection_name=collection_name,
query=query_text,
limit=limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)
self.app.reset()
shutil.rmtree(base_path)
self.app = None
self.collection = None
def save(
self,
documents: List[str],
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
):
if not self.collection:
raise Exception("Collection not initialized")
try:
# Create a dictionary to store unique documents
unique_docs = {}
# Generate IDs and create a mapping of id -> (document, metadata)
for idx, doc in enumerate(documents):
doc_id = hashlib.sha256(doc.encode("utf-8")).hexdigest()
doc_metadata = None
if metadata is not None:
if isinstance(metadata, list):
doc_metadata = metadata[idx]
else:
doc_metadata = metadata
unique_docs[doc_id] = (doc, doc_metadata)
# Prepare filtered lists for ChromaDB
filtered_docs = []
filtered_metadata = []
filtered_ids = []
# Build the filtered lists
for doc_id, (doc, meta) in unique_docs.items():
filtered_docs.append(doc)
filtered_metadata.append(meta)
filtered_ids.append(doc_id)
# If we have no metadata at all, set it to None
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
None if all(m is None for m in filtered_metadata) else filtered_metadata
)
self.collection.upsert(
documents=filtered_docs,
metadatas=final_metadata,
ids=filtered_ids,
)
except chromadb.errors.InvalidDimensionException as e:
Logger(verbose=True).log(
"error",
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
) from e
except Exception as e:
logging.error(f"Error during knowledge search: {e!s}")
return []
def reset(self) -> None:
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
client.delete_collection(collection_name=collection_name)
except Exception as e:
logging.error(f"Error during knowledge reset: {e!s}")
def save(self, documents: list[str]) -> None:
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
client.get_or_create_collection(collection_name=collection_name)
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
client.add_documents(
collection_name=collection_name, documents=rag_documents
)
except Exception as e:
if "dimension mismatch" in str(e).lower():
Logger(verbose=True).log(
"error",
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
) from e
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
"""Set the embedding configuration for the knowledge storage.
Args:
embedder (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
If None or empty, defaults to the default embedding function.
"""
self.embedder = (
EmbeddingConfigurator().configure_embedder(embedder)
if embedder
else self._create_default_embedding_function()
)
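
A usage sketch of the two client paths this storage now supports (the embedder dict shape is an assumption; the collection naming and fallback mirror `_get_client` and `search` above):

from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage

# With an embedder config, __init__ builds an instance-specific ChromaDB
# client via create_client(ChromaDBConfig(...)).
scoped = KnowledgeStorage(
    embedder={"provider": "openai", "config": {"model": "text-embedding-3-small"}},
    collection_name="docs",  # stored as collection "knowledge_docs"
)
scoped.save(["CrewAI unifies RAG storage behind a common client."])
hits = scoped.search(["rag storage"], limit=3, score_threshold=0.35)

# Without an embedder, _get_client() falls back to the process-global
# client from get_rag_client() and uses the default "knowledge" collection.
shared = KnowledgeStorage()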

View File

@@ -1,12 +1,12 @@
from typing import Any, Dict, List
from crewai.rag.types import SearchResult
def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str:
def extract_knowledge_context(knowledge_snippets: list[SearchResult]) -> str:
"""Extract knowledge from the task prompt."""
valid_snippets = [
result["context"]
result["content"]
for result in knowledge_snippets
if result and result.get("context")
if result and result.get("content")
]
snippet = "\n".join(valid_snippets)
return f"Additional Information: {snippet}" if valid_snippets else ""

View File

@@ -1,4 +1,6 @@
from typing import Optional, TYPE_CHECKING
from __future__ import annotations
from typing import TYPE_CHECKING
from crewai.memory import (
EntityMemory,
@@ -19,9 +21,9 @@ class ContextualMemory:
ltm: LongTermMemory,
em: EntityMemory,
exm: ExternalMemory,
agent: Optional["Agent"] = None,
task: Optional["Task"] = None,
):
agent: Agent | None = None,
task: Task | None = None,
) -> None:
self.stm = stm
self.ltm = ltm
self.em = em
@@ -42,7 +44,7 @@ class ContextualMemory:
self.exm.agent = self.agent
self.exm.task = self.task
def build_context_for_task(self, task, context) -> str:
def build_context_for_task(self, task: Task, context: str) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
@@ -52,14 +54,15 @@ class ContextualMemory:
if query == "":
return ""
context = []
context.append(self._fetch_ltm_context(task.description))
context.append(self._fetch_stm_context(query))
context.append(self._fetch_entity_context(query))
context.append(self._fetch_external_context(query))
return "\n".join(filter(None, context))
context_parts = [
self._fetch_ltm_context(task.description),
self._fetch_stm_context(query),
self._fetch_entity_context(query),
self._fetch_external_context(query),
]
return "\n".join(filter(None, context_parts))
def _fetch_stm_context(self, query) -> str:
def _fetch_stm_context(self, query: str) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points.
@@ -70,11 +73,11 @@ class ContextualMemory:
stm_results = self.stm.search(query)
formatted_results = "\n".join(
[f"- {result['context']}" for result in stm_results]
[f"- {result['content']}" for result in stm_results]
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]:
def _fetch_ltm_context(self, task: str) -> str | None:
"""
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
@@ -90,14 +93,14 @@ class ContextualMemory:
formatted_results = [
suggestion
for result in ltm_results
for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
for suggestion in result["metadata"]["suggestions"]
]
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str:
def _fetch_entity_context(self, query: str) -> str:
"""
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points.
@@ -107,7 +110,7 @@ class ContextualMemory:
em_results = self.em.search(query)
formatted_results = "\n".join(
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
[f"- {result['content']}" for result in em_results]
)
return f"Entities:\n{formatted_results}" if em_results else ""
@@ -128,6 +131,6 @@ class ContextualMemory:
return ""
formatted_memories = "\n".join(
f"- {result['context']}" for result in external_memories
f"- {result['content']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"

View File

@@ -1,12 +1,13 @@
import os
import re
from collections import defaultdict
from typing import Any, Iterable
from collections.abc import Iterable
from typing import Any
from mem0 import Memory, MemoryClient # type: ignore[import-untyped]
from mem0 import Memory, MemoryClient # type: ignore[import-untyped,import-not-found]
from crewai.memory.storage.interface import Storage
from crewai.utilities.chromadb import sanitize_collection_name
from crewai.rag.chromadb.utils import _sanitize_collection_name
MAX_AGENT_ID_LENGTH_MEM0 = 255
@@ -15,6 +16,7 @@ class Mem0Storage(Storage):
"""
Extends Storage to handle embedding and searching across entities using Mem0.
"""
def __init__(self, type, crew=None, config=None):
super().__init__()
@@ -30,7 +32,8 @@ class Mem0Storage(Storage):
supported_types = {"short_term", "long_term", "entities", "external"}
if type not in supported_types:
raise ValueError(
f"Invalid type '{type}' for Mem0Storage. Must be one of: {', '.join(supported_types)}"
f"Invalid type '{type}' for Mem0Storage. "
f"Must be one of: {', '.join(supported_types)}"
)
def _extract_config_values(self):
@@ -68,7 +71,8 @@ class Mem0Storage(Storage):
- Includes user_id and agent_id if both are present.
- Includes user_id if only user_id is present.
- Includes agent_id if only agent_id is present.
- Includes run_id if memory_type is 'short_term' and mem0_run_id is present.
- Includes run_id if memory_type is 'short_term' and
mem0_run_id is present.
"""
filter = defaultdict(list)
@@ -91,10 +95,14 @@ class Mem0Storage(Storage):
def save(self, value: Any, metadata: dict[str, Any]) -> None:
def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str:
return next(
(m.get("content", "") for m in reversed(list(messages)) if m.get("role") == role),
""
(
m.get("content", "")
for m in reversed(list(messages))
if m.get("role") == role
),
"",
)
conversations = []
messages = metadata.pop("messages", None)
if messages:
@@ -103,7 +111,7 @@ class Mem0Storage(Storage):
if user_msg := self._get_user_message(last_user):
conversations.append({"role": "user", "content": user_msg})
if assistant_msg := self._get_assistant_message(last_assistant):
conversations.append({"role": "assistant", "content": assistant_msg})
else:
@@ -115,13 +123,13 @@ class Mem0Storage(Storage):
"short_term": "short_term",
"long_term": "long_term",
"entities": "entity",
"external": "external"
"external": "external",
}
# Shared base params
params: dict[str, Any] = {
"metadata": {"type": base_metadata[self.memory_type], **metadata},
"infer": self.infer
"infer": self.infer,
}
# MemoryClient-specific overrides
@@ -142,13 +150,15 @@ class Mem0Storage(Storage):
self.memory.add(conversations, **params)
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> list[Any]:
def search(
self, query: str, limit: int = 3, score_threshold: float = 0.35
) -> list[Any]:
params = {
"query": query,
"limit": limit,
"version": "v2",
"output_format": "v1.1"
}
"output_format": "v1.1",
}
if user_id := self.config.get("user_id", ""):
params["user_id"] = user_id
@@ -169,10 +179,10 @@ class Mem0Storage(Storage):
# automatically when the crew is created.
params["filters"] = self._create_filter_for_search()
params['threshold'] = score_threshold
params["threshold"] = score_threshold
if isinstance(self.memory, Memory):
del params["metadata"], params["version"], params['output_format']
del params["metadata"], params["version"], params["output_format"]
if params.get("run_id"):
del params["run_id"]
@@ -180,7 +190,7 @@ class Mem0Storage(Storage):
# This keeps the results compatible with Contextual Memory retrieval
for result in results["results"]:
result["context"] = result["memory"]
result["content"] = result["memory"]
return [r for r in results["results"]]
@@ -201,7 +211,9 @@ class Mem0Storage(Storage):
agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
return _sanitize_collection_name(
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
)
def _get_assistant_message(self, text: str) -> str:
marker = "Final Answer:"

View File

@@ -1,17 +1,16 @@
import logging
import os
import shutil
import uuid
import warnings
from typing import Any
from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import create_persistent_client
from crewai.rag.types import BaseRecord
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
from crewai.utilities.logger_utils import suppress_logging
import warnings
class RAGStorage(BaseRAGStorage):
@@ -20,8 +19,6 @@ class RAGStorage(BaseRAGStorage):
search efficiency.
"""
app: ClientAPI | None = None
def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
):
@@ -33,37 +30,25 @@ class RAGStorage(BaseRAGStorage):
self.storage_file_name = self._build_storage_file_name(type, agents)
self.type = type
self._client: BaseClient | None = None
self.allow_reset = allow_reset
self.path = path
self._initialize_app()
def _set_embedder_config(self):
configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self):
from chromadb.config import Settings
# Suppress deprecation warnings from chromadb, which are not relevant to us
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
self._set_embedder_config()
if self.embedder_config:
embedding_function = get_embedding_function(self.embedder_config)
config = ChromaDBConfig(embedding_function=embedding_function)
self._client = create_client(config)
self.app = create_persistent_client(
path=self.path if self.path else self.storage_file_name,
settings=Settings(allow_reset=self.allow_reset),
)
self.collection = self.app.get_or_create_collection(
name=self.type, embedding_function=self.embedder_config
)
logging.info(f"Collection found or created: {self.collection}")
def _get_client(self) -> BaseClient:
"""Get the appropriate client - instance-specific or global."""
return self._client if self._client else get_rag_client()
def _sanitize_role(self, role: str) -> str:
"""
@@ -85,77 +70,65 @@ class RAGStorage(BaseRAGStorage):
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
def save(self, value: Any, metadata: dict[str, Any]) -> None:
try:
self._generate_embedding(value, metadata)
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
client.get_or_create_collection(collection_name=collection_name)
document: BaseRecord = {"content": value}
if metadata:
document["metadata"] = metadata
client.add_documents(collection_name=collection_name, documents=[document])
except Exception as e:
logging.error(f"Error during {self.type} save: {str(e)}")
logging.error(f"Error during {self.type} save: {e!s}")
def search(
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.35,
) -> List[Any]:
if not hasattr(self, "app"):
self._initialize_app()
) -> list[Any]:
try:
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
response = self.collection.query(query_texts=query, n_results=limit)
results = []
for i in range(len(response["ids"][0])):
result = {
"id": response["ids"][0][i],
"metadata": response["metadatas"][0][i],
"context": response["documents"][0][i],
"score": response["distances"][0][i],
}
if result["score"] >= score_threshold:
results.append(result)
return results
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
return client.search(
collection_name=collection_name,
query=query,
limit=limit,
metadata_filter=filter,
score_threshold=score_threshold,
)
except Exception as e:
logging.error(f"Error during {self.type} search: {str(e)}")
logging.error(f"Error during {self.type} search: {e!s}")
return []
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
self.collection.add(
documents=[text],
metadatas=[metadata or {}],
ids=[str(uuid.uuid4())],
)
def reset(self) -> None:
try:
if self.app:
self.app.reset()
shutil.rmtree(f"{db_storage_path()}/{self.type}")
self.app = None
self.collection = None
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
client.delete_collection(collection_name=collection_name)
except Exception as e:
if "attempt to write a readonly database" in str(e):
# Ignore this specific error
if "attempt to write a readonly database" in str(
e
) or "does not exist" in str(e):
# Ignore readonly database and collection not found errors (already reset)
pass
else:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
) from e
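
The rewritten `RAGStorage` resolves a client per call: an instance-specific `ChromaDBClient` when an `embedder_config` was supplied, otherwise the global RAG client. A hedged usage sketch; the module path and the `provider`/`config` dict shape are assumptions based on the surrounding code:

from crewai.memory.storage.rag_storage import RAGStorage  # path assumed

storage = RAGStorage(
    type="short_term",
    embedder_config={  # shape assumed; handed to get_embedding_function
        "provider": "openai",
        "config": {"model": "text-embedding-3-small"},
    },
)
storage.save("User prefers concise answers", metadata={"source": "chat"})
for hit in storage.search(query="user preferences", limit=3):
    # Each hit is a SearchResult-shaped dict: id, content, metadata, score.
    print(hit["content"], hit["score"])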

View File

@@ -4,8 +4,9 @@ import logging
from typing import Any
from chromadb.api.types import (
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction,
)
from chromadb.api.types import (
QueryResult,
)
from typing_extensions import Unpack
@@ -23,13 +24,13 @@ from crewai.rag.chromadb.utils import (
_process_query_results,
_sanitize_collection_name,
)
from crewai.utilities.logger_utils import suppress_logging
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
BaseCollectionParams,
)
from crewai.rag.types import SearchResult
from crewai.utilities.logger_utils import suppress_logging
class ChromaDBClient(BaseClient):
@@ -46,7 +47,7 @@ class ChromaDBClient(BaseClient):
def __init__(
self,
client: ChromaDBClientType,
embedding_function: ChromaEmbeddingFunction[Embeddable],
embedding_function: ChromaEmbeddingFunction,
) -> None:
"""Initialize ChromaDBClient with client and embedding function.
@@ -306,10 +307,12 @@ class ChromaDBClient(BaseClient):
)
prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=prepared.metadatas,
metadatas=metadatas,
)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
@@ -347,10 +350,12 @@ class ChromaDBClient(BaseClient):
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
await collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=prepared.metadatas,
metadatas=metadatas,
)
def search(
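
Both the sync and async `add_documents` paths now collapse all-empty metadata to `None` before upserting, since ChromaDB rejects empty metadata dicts. The guard in isolation:

def normalize_metadatas(metadatas: list[dict]) -> list[dict] | None:
    # Pass None to ChromaDB when no document carries real metadata.
    return metadatas if any(m for m in metadatas) else None

assert normalize_metadatas([{}, {}]) is None
assert normalize_metadatas([{"k": "v"}, {}]) == [{"k": "v"}, {}]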

View File

@@ -3,18 +3,18 @@
import warnings
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.chromadb.constants import (
DEFAULT_TENANT,
DEFAULT_DATABASE,
DEFAULT_STORAGE_PATH,
DEFAULT_TENANT,
)
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig
warnings.filterwarnings(
"ignore",

View File

@@ -2,11 +2,12 @@
import os
from hashlib import md5
import portalocker
from chromadb import PersistentClient
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.client import ChromaDBClient
from crewai.rag.chromadb.config import ChromaDBConfig
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
@@ -23,6 +24,7 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
"""
persist_dir = config.settings.persist_directory
os.makedirs(persist_dir, exist_ok=True)
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
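
The factory now keys its lock file off the persist directory itself rather than the global storage path. A self-contained sketch of the lock-guarded creation, assuming only the behavior visible above:

import os
from hashlib import md5

import portalocker
from chromadb import PersistentClient

def locked_persistent_client(persist_dir: str):
    os.makedirs(persist_dir, exist_ok=True)
    # One lock file per persist directory, keyed by an MD5 of its path.
    lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
    lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
    with portalocker.Lock(lockfile):
        # Only one process or thread at a time creates the client.
        return PersistentClient(path=persist_dir)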

View File

@@ -3,27 +3,28 @@
from collections.abc import Mapping
from typing import Any, NamedTuple
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from chromadb.api import ClientAPI, AsyncClientAPI
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.configuration import CollectionConfigurationInterface
from chromadb.api.types import (
CollectionMetadata,
DataLoader,
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction,
Include,
Loadable,
Where,
WhereDocument,
)
from chromadb.api.types import (
EmbeddingFunction as ChromaEmbeddingFunction,
)
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
ChromaDBClientType = ClientAPI | AsyncClientAPI
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]):
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
@classmethod
@@ -44,7 +45,7 @@ class PreparedDocuments(NamedTuple):
Attributes:
ids: List of document IDs
texts: List of document texts
metadatas: List of document metadata mappings
metadatas: List of document metadata mappings (empty dict for no metadata)
"""
ids: list[str]
@@ -85,7 +86,7 @@ class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
configuration: CollectionConfigurationInterface
metadata: CollectionMetadata
embedding_function: ChromaEmbeddingFunction[Embeddable]
embedding_function: ChromaEmbeddingFunction
data_loader: DataLoader[Loadable]
get_or_create: bool
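
`PreparedDocuments` now documents that an empty dict means "no metadata" for that document. A tiny illustration (field types here are illustrative, not copied from the source):

from collections.abc import Mapping
from typing import NamedTuple

class PreparedDocuments(NamedTuple):
    ids: list[str]
    texts: list[str]
    metadatas: list[Mapping[str, str]]

prepared = PreparedDocuments(
    ids=["doc-1", "doc-2"],
    texts=["hello", "world"],
    metadatas=[{"source": "chat"}, {}],  # {} -> doc-2 has no metadata
)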

View File

@@ -5,13 +5,14 @@ from collections.abc import Mapping
from typing import Literal, TypeGuard, cast
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
Include,
IncludeEnum,
QueryResult,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from crewai.rag.chromadb.constants import (
DEFAULT_COLLECTION,
INVALID_CHARS_PATTERN,
@@ -78,7 +79,7 @@ def _prepare_documents_for_chromadb(
metadata = doc.get("metadata")
if metadata:
if isinstance(metadata, list):
metadatas.append(metadata[0] if metadata else {})
metadatas.append(metadata[0] if metadata and metadata[0] else {})
else:
metadatas.append(metadata)
else:
@@ -154,7 +155,7 @@ def _convert_chromadb_results_to_search_results(
"""
search_results: list[SearchResult] = []
include_strings = [item.value for item in include]
include_strings = [item.value for item in include] if include else []
ids = results["ids"][0] if results.get("ids") else []
@@ -188,7 +189,9 @@ def _convert_chromadb_results_to_search_results(
result: SearchResult = {
"id": doc_id,
"content": documents[i] if documents and i < len(documents) else "",
"metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {},
"metadata": dict(metadatas[i])
if metadatas and i < len(metadatas) and metadatas[i] is not None
else {},
"score": score,
}
search_results.append(result)
@@ -271,7 +274,7 @@ def _sanitize_collection_name(
sanitized = sanitized[:-1] + "z"
if len(sanitized) < MIN_COLLECTION_LENGTH:
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
sanitized += "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
if len(sanitized) > max_collection_length:
sanitized = sanitized[:max_collection_length]
if not sanitized[-1].isalnum():
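
The padding fix above is behavior-preserving: short names are still padded with "x" up to the minimum length. Expected outputs below assume the same rules as the legacy helper removed later in this commit:

from crewai.rag.chromadb.utils import _sanitize_collection_name

print(_sanitize_collection_name("a"))              # "axx": padded to 3 chars
print(_sanitize_collection_name("my collection"))  # "my_collection": space replaced
print(_sanitize_collection_name("192.168.0.1"))    # "ip_192_168_0_1": IPv4 prefixed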

View File

@@ -1,15 +1,15 @@
"""Protocol for vector database client implementations."""
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable, Annotated
from typing_extensions import Unpack, Required, TypedDict
from typing import Annotated, Any, Protocol, runtime_checkable
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from typing_extensions import Required, TypedDict, Unpack
from crewai.rag.types import (
EmbeddingFunction,
BaseRecord,
EmbeddingFunction,
SearchResult,
)
@@ -57,7 +57,7 @@ class BaseCollectionSearchParams(BaseCollectionParams, total=False):
query: Required[str]
limit: int
metadata_filter: dict[str, Any]
metadata_filter: dict[str, Any] | None
score_threshold: float
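
With `metadata_filter` widened to `dict[str, Any] | None`, callers may pass an explicit `None` to mean "no filter". A sketch of the kwargs this TypedDict describes (the commented `client.search` call stands in for any `BaseClient` implementation):

from typing import Any

params: dict[str, Any] = {
    "collection_name": "memory_short_term",
    "query": "user preferences",
    "limit": 3,
    "metadata_filter": None,  # now a valid value meaning "no filter"
    "score_threshold": 0.35,
}
# results = client.search(**params)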

View File

@@ -10,8 +10,8 @@ from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GooglePalmEmbeddingFunction,
GoogleGenerativeAiEmbeddingFunction,
GooglePalmEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
@@ -60,7 +60,7 @@ def get_embedding_function(
EmbeddingFunction instance ready for use with ChromaDB
Supported providers:
- openai: OpenAI embeddings (default)
- openai: OpenAI embeddings
- cohere: Cohere embeddings
- ollama: Ollama local embeddings
- huggingface: HuggingFace embeddings
@@ -77,7 +77,7 @@ def get_embedding_function(
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
Examples:
# Use default OpenAI with retry logic
# Use default OpenAI embedding
>>> embedder = get_embedding_function()
# Use Cohere with dict
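
A hedged usage sketch for the factory: the no-argument call returning the OpenAI default is documented above, while the Cohere dict shape is an assumption and may differ from the real schema:

from crewai.rag.embeddings.factory import get_embedding_function

default_ef = get_embedding_function()  # OpenAI, per the default above

cohere_ef = get_embedding_function(  # dict shape assumed, not confirmed
    {"provider": "cohere", "config": {"api_key": "...", "model": "embed-english-v3.0"}}
)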

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
class BaseRAGStorage(ABC):
@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: Optional[Dict[str, Any]] = None,
embedder_config: dict[str, Any] | None = None,
crew: Any = None,
):
self.type = type
@@ -32,45 +32,21 @@ class BaseRAGStorage(ABC):
@abstractmethod
def _sanitize_role(self, role: str) -> str:
"""Sanitizes agent roles to ensure valid directory names."""
pass
@abstractmethod
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value with metadata to the storage."""
pass
@abstractmethod
def search(
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.35,
) -> List[Any]:
) -> list[Any]:
"""Search for entries in the storage."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the storage."""
pass
@abstractmethod
def _generate_embedding(
self, text: str, metadata: Optional[Dict[str, Any]] = None
) -> Any:
"""Generate an embedding for the given text and metadata."""
pass
@abstractmethod
def _initialize_app(self):
"""Initialize the vector db."""
pass
def setup_config(self, config: Dict[str, Any]):
"""Setup the config of the storage."""
pass
def initialize_client(self):
"""Initialize the client of the storage. This should setup the app and the db collection"""
pass
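
After this trim, a concrete storage implements only `_sanitize_role`, `save`, `search`, and `reset` (assuming those four are the only abstract methods left). A minimal in-memory sketch against the slimmed ABC, with a naive substring match standing in for vector search:

from typing import Any

from crewai.rag.storage.base_rag_storage import BaseRAGStorage

class InMemoryRAGStorage(BaseRAGStorage):
    def __init__(self, type: str) -> None:
        super().__init__(type)
        self._rows: list[tuple[Any, dict[str, Any]]] = []

    def _sanitize_role(self, role: str) -> str:
        return role.lower().replace(" ", "_")

    def save(self, value: Any, metadata: dict[str, Any]) -> None:
        self._rows.append((value, metadata))

    def search(
        self,
        query: str,
        limit: int = 3,
        filter: dict[str, Any] | None = None,
        score_threshold: float = 0.35,
    ) -> list[Any]:
        # Substring match as a stand-in for real similarity search.
        return [v for v, _ in self._rows if query.lower() in str(v).lower()][:limit]

    def reset(self) -> None:
        self._rows.clear()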

View File

@@ -1,83 +0,0 @@
import os
import re
import portalocker
from chromadb import PersistentClient
from hashlib import md5
from typing import Optional
from crewai.utilities.paths import db_storage_path
MIN_COLLECTION_LENGTH = 3
MAX_COLLECTION_LENGTH = 63
DEFAULT_COLLECTION = "default_collection"
# Compiled regex patterns for better performance
INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
def is_ipv4_pattern(name: str) -> bool:
"""
Check if a string matches an IPv4 address pattern.
Args:
name: The string to check
Returns:
True if the string matches an IPv4 pattern, False otherwise
"""
return bool(IPV4_PATTERN.match(name))
def sanitize_collection_name(
name: Optional[str], max_collection_length: int = MAX_COLLECTION_LENGTH
) -> str:
"""
Sanitize a collection name to meet ChromaDB requirements:
1. 3-63 characters long
2. Starts and ends with alphanumeric character
3. Contains only alphanumeric characters, underscores, or hyphens
4. No consecutive periods
5. Not a valid IPv4 address
Args:
name: The original collection name to sanitize
Returns:
A sanitized collection name that meets ChromaDB requirements
"""
if not name:
return DEFAULT_COLLECTION
if is_ipv4_pattern(name):
name = f"ip_{name}"
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
if not sanitized[0].isalnum():
sanitized = "a" + sanitized
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
if len(sanitized) < MIN_COLLECTION_LENGTH:
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
if len(sanitized) > max_collection_length:
sanitized = sanitized[:max_collection_length]
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
return sanitized
def create_persistent_client(path: str, **kwargs):
"""
Creates a persistent ChromaDB client guarded by a lock file to prevent
concurrent creation. Works in both multi-threaded and multi-process
environments.
"""
lock_id = md5(path.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(db_storage_path(), f"chromadb-{lock_id}.lock")
with portalocker.Lock(lockfile):
client = PersistentClient(path=path, **kwargs)
return client
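
The deleted helpers are superseded rather than lost; per the import changes earlier in this commit:

from crewai.rag.chromadb.utils import _sanitize_collection_name  # replaces sanitize_collection_name
from crewai.rag.factory import create_client  # client creation now goes through the factory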