Merge branch 'main' into gl/refactor/agent-adapter-typing-and-docs

This commit is contained in:
Greyson LaLonde
2025-09-19 17:15:47 -04:00
committed by GitHub
58 changed files with 6622 additions and 4886 deletions

102
.github/workflows/codeql.yml vendored Normal file
View File

@@ -0,0 +1,102 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL Advanced"
on:
push:
branches: [ "main" ]
paths-ignore:
- "src/crewai/cli/templates/**"
pull_request:
branches: [ "main" ]
paths-ignore:
- "src/crewai/cli/templates/**"
jobs:
analyze:
name: Analyze (${{ matrix.language }})
# Runner size impacts CodeQL analysis time. To learn more, please see:
# - https://gh.io/recommended-hardware-resources-for-running-codeql
# - https://gh.io/supported-runners-and-hardware-resources
# - https://gh.io/using-larger-runners (GitHub.com only)
# Consider using larger runners or machines with greater resources for possible analysis time improvements.
runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
permissions:
# required for all workflows
security-events: write
# required to fetch internal or private CodeQL packs
packages: read
# only required for workflows in private repositories
actions: read
contents: read
strategy:
fail-fast: false
matrix:
include:
- language: actions
build-mode: none
- language: python
build-mode: none
# CodeQL supports the following values keywords for 'language': 'actions', 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'rust', 'swift'
# Use `c-cpp` to analyze code written in C, C++ or both
# Use 'java-kotlin' to analyze code written in Java, Kotlin or both
# Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
# To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
# see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
# If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v4
# Add any setup steps before running the `github/codeql-action/init` action.
# This includes steps like installing compilers or runtimes (`actions/setup-node`
# or others). This is typically only required for manual builds.
# - name: Setup runtime (example)
# uses: actions/setup-example@v1
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
# If the analyze step fails for one of the languages you are analyzing with
# "We were unable to automatically build your code", modify the matrix above
# to set the build mode to "manual" for that language. Then modify this step
# to build your code.
# Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
- if: matrix.build-mode == 'manual'
shell: bash
run: |
echo 'If you are using a "manual" build mode for one or more of the' \
'languages you are analyzing, replace this with the commands to build' \
'your code, for example:'
echo ' make bootstrap'
echo ' make release'
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"

View File

@@ -9,7 +9,7 @@ mode: "wide"
## Description
The `RagTool` is designed to answer questions by leveraging the power of Retrieval-Augmented Generation (RAG) through EmbedChain.
The `RagTool` is designed to answer questions by leveraging the power of Retrieval-Augmented Generation (RAG) through CrewAI's native RAG system.
It provides a dynamic knowledge base that can be queried to retrieve relevant information from various data sources.
This tool is particularly useful for applications that require access to a vast array of information and need to provide contextually relevant answers.
@@ -76,8 +76,8 @@ The `RagTool` can be used with a wide variety of data sources, including:
The `RagTool` accepts the following parameters:
- **summarize**: Optional. Whether to summarize the retrieved content. Default is `False`.
- **adapter**: Optional. A custom adapter for the knowledge base. If not provided, an EmbedchainAdapter will be used.
- **config**: Optional. Configuration for the underlying EmbedChain App.
- **adapter**: Optional. A custom adapter for the knowledge base. If not provided, a CrewAIRagAdapter will be used.
- **config**: Optional. Configuration for the underlying CrewAI RAG system.
## Adding Content
@@ -130,44 +130,23 @@ from crewai_tools import RagTool
# Create a RAG tool with custom configuration
config = {
"app": {
"name": "custom_app",
},
"llm": {
"provider": "openai",
"vectordb": {
"provider": "qdrant",
"config": {
"model": "gpt-4",
"collection_name": "my-collection"
}
},
"embedding_model": {
"provider": "openai",
"config": {
"model": "text-embedding-ada-002"
"model": "text-embedding-3-small"
}
},
"vectordb": {
"provider": "elasticsearch",
"config": {
"collection_name": "my-collection",
"cloud_id": "deployment-name:xxxx",
"api_key": "your-key",
"verify_certs": False
}
},
"chunker": {
"chunk_size": 400,
"chunk_overlap": 100,
"length_function": "len",
"min_chunk_size": 0
}
}
rag_tool = RagTool(config=config, summarize=True)
```
The internal RAG tool utilizes the Embedchain adapter, allowing you to pass any configuration options that are supported by Embedchain.
You can refer to the [Embedchain documentation](https://docs.embedchain.ai/components/introduction) for details.
Make sure to review the configuration options available in the .yaml file.
## Conclusion
The `RagTool` provides a powerful way to create and query knowledge bases from various data sources. By leveraging Retrieval-Augmented Generation, it enables agents to access and retrieve relevant information efficiently, enhancing their ability to provide accurate and contextually appropriate responses.

View File

@@ -48,7 +48,7 @@ Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = ["crewai-tools~=0.71.0"]
tools = ["crewai-tools~=0.73.0"]
embeddings = [
"tiktoken~=0.8.0"
]
@@ -131,13 +131,14 @@ select = [
"I001", # sort imports
"I002", # remove unused imports
]
ignore = ["E501"] # ignore line too long
ignore = ["E501"] # ignore line too long globally
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["S101"] # Allow assert statements in tests
"tests/**/*.py" = ["S101", "RET504"] # Allow assert statements and unnecessary assignments before return in tests
[tool.mypy]
exclude = ["src/crewai/cli/templates", "tests"]
exclude = ["src/crewai/cli/templates", "tests/"]
[tool.bandit]
exclude_dirs = ["src/crewai/cli/templates"]

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "0.186.1"
__version__ = "0.193.0"
_telemetry_submitted = False

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.186.1,<1.0.0"
"crewai[tools]>=0.193.0,<1.0.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.186.1,<1.0.0",
"crewai[tools]>=0.193.0,<1.0.0",
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.186.1"
"crewai[tools]>=0.193.0"
]
[tool.crewai]

View File

@@ -3,26 +3,17 @@ import json
import re
import uuid
import warnings
from collections.abc import Callable
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from opentelemetry import baggage
from opentelemetry.context import attach, detach
from crewai.utilities.crew.models import CrewContext
from pydantic import (
UUID4,
BaseModel,
@@ -39,26 +30,15 @@ from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM, BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
from crewai.security import Fingerprint, SecurityConfig
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
should_auto_collect_first_time_traces,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
@@ -70,16 +50,28 @@ from crewai.events.types.crew_events import (
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
)
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM, BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
from crewai.rag.types import SearchResult
from crewai.security import Fingerprint, SecurityConfig
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.crew.models import CrewContext
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.formatter import (
aggregate_raw_outputs_from_task_outputs,
aggregate_raw_outputs_from_tasks,
@@ -94,28 +86,40 @@ warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
class Crew(FlowTrackable, BaseModel):
"""
Represents a group of agents, defining how they should collaborate and the tasks they should perform.
Represents a group of agents, defining how they should collaborate and the
tasks they should perform.
Attributes:
tasks: List of tasks assigned to the crew.
agents: List of agents part of this crew.
tasks: list of tasks assigned to the crew.
agents: list of agents part of this crew.
manager_llm: The language model that will run manager agent.
manager_agent: Custom agent that will be used as manager.
memory: Whether the crew should use memory to store memories of it's execution.
cache: Whether the crew should use a cache to store the results of the tools execution.
function_calling_llm: The language model that will run the tool calling for all the agents.
process: The process flow that the crew will follow (e.g., sequential, hierarchical).
memory: Whether the crew should use memory to store memories of it's
execution.
cache: Whether the crew should use a cache to store the results of the
tools execution.
function_calling_llm: The language model that will run the tool calling
for all the agents.
process: The process flow that the crew will follow (e.g., sequential,
hierarchical).
verbose: Indicates the verbosity level for logging during execution.
config: Configuration settings for the crew.
max_rpm: Maximum number of requests per minute for the crew execution to be respected.
max_rpm: Maximum number of requests per minute for the crew execution to
be respected.
prompt_file: Path to the prompt json file to be used for the crew.
id: A unique identifier for the crew instance.
task_callback: Callback to be executed after each task for every agents execution.
step_callback: Callback to be executed after each step for every agents execution.
share_crew: Whether you want to share the complete crew information and execution with crewAI to make the library better, and allow us to train models.
task_callback: Callback to be executed after each task for every agents
execution.
step_callback: Callback to be executed after each step for every agents
execution.
share_crew: Whether you want to share the complete crew information and
execution with crewAI to make the library better, and allow us to
train models.
planning: Plan the crew execution and add the plan to the crew.
chat_llm: The language model used for orchestrating chat interactions with the crew.
security_config: Security configuration for the crew, including fingerprinting.
chat_llm: The language model used for orchestrating chat interactions
with the crew.
security_config: Security configuration for the crew, including
fingerprinting.
"""
__hash__ = object.__hash__ # type: ignore
@@ -124,13 +128,13 @@ class Crew(FlowTrackable, BaseModel):
_logger: Logger = PrivateAttr()
_file_handler: FileHandler = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
_short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
_long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
_entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr()
_external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr()
_train: Optional[bool] = PrivateAttr(default=False)
_train_iteration: Optional[int] = PrivateAttr()
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_short_term_memory: InstanceOf[ShortTermMemory] | None = PrivateAttr()
_long_term_memory: InstanceOf[LongTermMemory] | None = PrivateAttr()
_entity_memory: InstanceOf[EntityMemory] | None = PrivateAttr()
_external_memory: InstanceOf[ExternalMemory] | None = PrivateAttr()
_train: bool | None = PrivateAttr(default=False)
_train_iteration: int | None = PrivateAttr()
_inputs: dict[str, Any] | None = PrivateAttr(default=None)
_logging_color: str = PrivateAttr(
default="bold_purple",
)
@@ -138,107 +142,121 @@ class Crew(FlowTrackable, BaseModel):
default_factory=TaskOutputStorageHandler
)
name: Optional[str] = Field(default="crew")
name: str | None = Field(default="crew")
cache: bool = Field(default=True)
tasks: List[Task] = Field(default_factory=list)
agents: List[BaseAgent] = Field(default_factory=list)
tasks: list[Task] = Field(default_factory=list)
agents: list[BaseAgent] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
verbose: bool = Field(default=False)
memory: bool = Field(
default=False,
description="Whether the crew should use memory to store memories of it's execution",
description="If crew should use memory to store memories of it's execution",
)
short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field(
short_term_memory: InstanceOf[ShortTermMemory] | None = Field(
default=None,
description="An Instance of the ShortTermMemory to be used by the Crew",
)
long_term_memory: Optional[InstanceOf[LongTermMemory]] = Field(
long_term_memory: InstanceOf[LongTermMemory] | None = Field(
default=None,
description="An Instance of the LongTermMemory to be used by the Crew",
)
entity_memory: Optional[InstanceOf[EntityMemory]] = Field(
entity_memory: InstanceOf[EntityMemory] | None = Field(
default=None,
description="An Instance of the EntityMemory to be used by the Crew",
)
external_memory: Optional[InstanceOf[ExternalMemory]] = Field(
external_memory: InstanceOf[ExternalMemory] | None = Field(
default=None,
description="An Instance of the ExternalMemory to be used by the Crew",
)
embedder: Optional[dict] = Field(
embedder: dict | None = Field(
default=None,
description="Configuration for the embedder to be used for the crew.",
)
usage_metrics: Optional[UsageMetrics] = Field(
usage_metrics: UsageMetrics | None = Field(
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: Optional[BaseAgent] = Field(
manager_agent: BaseAgent | None = Field(
description="Custom agent that will be used as manager.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
description="Language model that will run the agent.", default=None
)
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
config: Json | dict[str, Any] | None = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
share_crew: Optional[bool] = Field(default=False)
step_callback: Optional[Any] = Field(
share_crew: bool | None = Field(default=False)
step_callback: Any | None = Field(
default=None,
description="Callback to be executed after each step for all agents execution.",
)
task_callback: Optional[Any] = Field(
task_callback: Any | None = Field(
default=None,
description="Callback to be executed after each task for all agents execution.",
)
before_kickoff_callbacks: List[
Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]]
before_kickoff_callbacks: list[
Callable[[dict[str, Any] | None], dict[str, Any] | None]
] = Field(
default_factory=list,
description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.",
description=(
"List of callbacks to be executed before crew kickoff. "
"It may be used to adjust inputs before the crew is executed."
),
)
after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field(
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
default_factory=list,
description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.",
description=(
"List of callbacks to be executed after crew kickoff. "
"It may be used to adjust the output of the crew."
),
)
max_rpm: Optional[int] = Field(
max_rpm: int | None = Field(
default=None,
description="Maximum number of requests per minute for the crew execution to be respected.",
description=(
"Maximum number of requests per minute for the crew execution "
"to be respected."
),
)
prompt_file: Optional[str] = Field(
prompt_file: str | None = Field(
default=None,
description="Path to the prompt json file to be used for the crew.",
)
output_log_file: Optional[Union[bool, str]] = Field(
output_log_file: bool | str | None = Field(
default=None,
description="Path to the log file to be saved",
)
planning: Optional[bool] = Field(
planning: bool | None = Field(
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
default=None,
description="Language model that will run the AgentPlanner if planning is True.",
description=(
"Language model that will run the AgentPlanner if planning is True."
),
)
task_execution_output_json_files: Optional[List[str]] = Field(
task_execution_output_json_files: list[str] | None = Field(
default=None,
description="List of file paths for task execution JSON files.",
description="list of file paths for task execution JSON files.",
)
execution_logs: List[Dict[str, Any]] = Field(
execution_logs: list[dict[str, Any]] = Field(
default=[],
description="List of execution logs for tasks",
description="list of execution logs for tasks",
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
default=None,
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
description=(
"Knowledge sources for the crew. Add knowledge sources to the "
"knowledge object."
),
)
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
knowledge: Optional[Knowledge] = Field(
knowledge: Knowledge | None = Field(
default=None,
description="Knowledge for the crew.",
)
@@ -246,18 +264,18 @@ class Crew(FlowTrackable, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the crew, including fingerprinting.",
)
token_usage: Optional[UsageMetrics] = Field(
token_usage: UsageMetrics | None = Field(
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
tracing: Optional[bool] = Field(
tracing: bool | None = Field(
default=False,
description="Whether to enable tracing for the crew.",
)
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
"""Prevent manual setting of the 'id' field by users."""
if v:
raise PydanticCustomError(
@@ -266,9 +284,7 @@ class Crew(FlowTrackable, BaseModel):
@field_validator("config", mode="before")
@classmethod
def check_config_type(
cls, v: Union[Json, Dict[str, Any]]
) -> Union[Json, Dict[str, Any]]:
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
"""Validates that the config is a valid type.
Args:
v: The config to be validated.
@@ -281,12 +297,16 @@ class Crew(FlowTrackable, BaseModel):
@model_validator(mode="after")
def set_private_attrs(self) -> "Crew":
"""Set private attributes."""
"""set private attributes."""
self._cache_handler = CacheHandler()
event_listener = EventListener()
if is_tracing_enabled() or self.tracing:
if (
is_tracing_enabled()
or self.tracing
or should_auto_collect_first_time_traces()
):
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
event_listener.verbose = self.verbose
@@ -314,7 +334,8 @@ class Crew(FlowTrackable, BaseModel):
def create_crew_memory(self) -> "Crew":
"""Initialize private memory attributes."""
self._external_memory = (
# External memory doesnt support a default value since it was designed to be managed entirely externally
# External memory does not support a default value since it was
# designed to be managed entirely externally
self.external_memory.set_crew(self) if self.external_memory else None
)
@@ -355,7 +376,10 @@ class Crew(FlowTrackable, BaseModel):
if not self.manager_llm and not self.manager_agent:
raise PydanticCustomError(
"missing_manager_llm_or_manager_agent",
"Attribute `manager_llm` or `manager_agent` is required when using hierarchical process.",
(
"Attribute `manager_llm` or `manager_agent` is required "
"when using hierarchical process."
),
{},
)
@@ -398,7 +422,10 @@ class Crew(FlowTrackable, BaseModel):
if task.agent is None:
raise PydanticCustomError(
"missing_agent_in_task",
f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
(
f"Sequential process error: Agent is missing in the task "
f"with the following description: {task.description}"
), # type: ignore # Dynamic string in error message
{},
)
@@ -459,7 +486,10 @@ class Crew(FlowTrackable, BaseModel):
if task.async_execution and isinstance(task, ConditionalTask):
raise PydanticCustomError(
"invalid_async_conditional_task",
f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
(
f"Conditional Task: {task.description}, "
f"cannot be executed asynchronously."
),
{},
)
return self
@@ -478,7 +508,9 @@ class Crew(FlowTrackable, BaseModel):
for j in range(i - 1, -1, -1):
if self.tasks[j] == context_task:
raise ValueError(
f"Task '{task.description}' is asynchronous and cannot include other sequential asynchronous tasks in its context."
f"Task '{task.description}' is asynchronous and "
f"cannot include other sequential asynchronous "
f"tasks in its context."
)
if not self.tasks[j].async_execution:
break
@@ -496,13 +528,15 @@ class Crew(FlowTrackable, BaseModel):
continue # Skip context tasks not in the main tasks list
if task_indices[id(context_task)] > task_indices[id(task)]:
raise ValueError(
f"Task '{task.description}' has a context dependency on a future task '{context_task.description}', which is not allowed."
f"Task '{task.description}' has a context dependency "
f"on a future task '{context_task.description}', "
f"which is not allowed."
)
return self
@property
def key(self) -> str:
source: List[str] = [agent.key for agent in self.agents] + [
source: list[str] = [agent.key for agent in self.agents] + [
task.key for task in self.tasks
]
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
@@ -518,9 +552,9 @@ class Crew(FlowTrackable, BaseModel):
return self.security_config.fingerprint
def _setup_from_config(self):
assert self.config is not None, "Config should not be None."
"""Initializes agents and tasks from the provided config."""
if self.config is None:
raise ValueError("Config should not be None.")
if not self.config.get("agents") or not self.config.get("tasks"):
raise PydanticCustomError(
"missing_keys_in_config", "Config should have 'agents' and 'tasks'.", {}
@@ -530,7 +564,7 @@ class Crew(FlowTrackable, BaseModel):
self.agents = [Agent(**agent) for agent in self.config["agents"]]
self.tasks = [self._create_task(task) for task in self.config["tasks"]]
def _create_task(self, task_config: Dict[str, Any]) -> Task:
def _create_task(self, task_config: dict[str, Any]) -> Task:
"""Creates a task instance from its configuration.
Args:
@@ -559,7 +593,7 @@ class Crew(FlowTrackable, BaseModel):
CrewTrainingHandler(filename).initialize_file()
def train(
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = None
self, n_iterations: int, filename: str, inputs: dict[str, Any] | None = None
) -> None:
"""Trains the crew for a given number of iterations."""
inputs = inputs or {}
@@ -611,7 +645,7 @@ class Crew(FlowTrackable, BaseModel):
def kickoff(
self,
inputs: Optional[Dict[str, Any]] = None,
inputs: dict[str, Any] | None = None,
) -> CrewOutput:
ctx = baggage.set_baggage(
"crew_context", CrewContext(id=str(self.id), key=self.key)
@@ -682,9 +716,9 @@ class Crew(FlowTrackable, BaseModel):
finally:
detach(token)
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
"""Executes the Crew's workflow for each input in the list and aggregates results."""
results: List[CrewOutput] = []
def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
"""Executes the Crew's workflow for each input and aggregates results."""
results: list[CrewOutput] = []
# Initialize the parent crew's usage metrics
total_usage_metrics = UsageMetrics()
@@ -703,14 +737,12 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset()
return results
async def kickoff_async(
self, inputs: Optional[Dict[str, Any]] = None
) -> CrewOutput:
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> CrewOutput:
"""Asynchronous kickoff method to start the crew execution."""
inputs = inputs or {}
return await asyncio.to_thread(self.kickoff, inputs)
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]:
async def kickoff_for_each_async(self, inputs: list[dict]) -> list[CrewOutput]:
crew_copies = [self.copy() for _ in inputs]
async def run_crew(crew, input_data):
@@ -739,7 +771,9 @@ class Crew(FlowTrackable, BaseModel):
tasks=self.tasks, planning_agent_llm=self.planning_llm
)._handle_crew_planning()
for task, step_plan in zip(self.tasks, result.list_of_plans_per_task):
for task, step_plan in zip(
self.tasks, result.list_of_plans_per_task, strict=False
):
task.description += step_plan.plan
def _store_execution_log(
@@ -776,7 +810,7 @@ class Crew(FlowTrackable, BaseModel):
return self._execute_tasks(self.tasks)
def _run_hierarchical_process(self) -> CrewOutput:
"""Creates and assigns a manager agent to make sure the crew completes the tasks."""
"""Creates and assigns a manager agent to complete the tasks."""
self._create_manager_agent()
return self._execute_tasks(self.tasks)
@@ -807,23 +841,24 @@ class Crew(FlowTrackable, BaseModel):
def _execute_tasks(
self,
tasks: List[Task],
start_index: Optional[int] = 0,
tasks: list[Task],
start_index: int | None = 0,
was_replayed: bool = False,
) -> CrewOutput:
"""Executes tasks sequentially and returns the final output.
Args:
tasks (List[Task]): List of tasks to execute
manager (Optional[BaseAgent], optional): Manager agent to use for delegation. Defaults to None.
manager (Optional[BaseAgent], optional): Manager agent to use for
delegation. Defaults to None.
Returns:
CrewOutput: Final output of the crew
"""
task_outputs: List[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
last_sync_output: Optional[TaskOutput] = None
task_outputs: list[TaskOutput] = []
futures: list[tuple[Task, Future[TaskOutput], int]] = []
last_sync_output: TaskOutput | None = None
for task_index, task in enumerate(tasks):
if start_index is not None and task_index < start_index:
@@ -838,7 +873,9 @@ class Crew(FlowTrackable, BaseModel):
agent_to_use = self._get_agent_to_use(task)
if agent_to_use is None:
raise ValueError(
f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided."
f"No agent available for task: {task.description}. "
f"Ensure that either the task has an assigned agent "
f"or a manager agent is provided."
)
# Determine which tools to use - task tools take precedence over agent tools
@@ -847,7 +884,7 @@ class Crew(FlowTrackable, BaseModel):
tools_for_task = self._prepare_tools(
agent_to_use,
task,
cast(Union[List[Tool], List[BaseTool]], tools_for_task),
cast(list[Tool] | list[BaseTool], tools_for_task),
)
self._log_task_start(task, agent_to_use.role)
@@ -867,7 +904,7 @@ class Crew(FlowTrackable, BaseModel):
future = task.execute_async(
agent=agent_to_use,
context=context,
tools=cast(List[BaseTool], tools_for_task),
tools=cast(list[BaseTool], tools_for_task),
)
futures.append((task, future, task_index))
else:
@@ -879,7 +916,7 @@ class Crew(FlowTrackable, BaseModel):
task_output = task.execute_sync(
agent=agent_to_use,
context=context,
tools=cast(List[BaseTool], tools_for_task),
tools=cast(list[BaseTool], tools_for_task),
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
@@ -893,11 +930,11 @@ class Crew(FlowTrackable, BaseModel):
def _handle_conditional_task(
self,
task: ConditionalTask,
task_outputs: List[TaskOutput],
futures: List[Tuple[Task, Future[TaskOutput], int]],
task_outputs: list[TaskOutput],
futures: list[tuple[Task, Future[TaskOutput], int]],
task_index: int,
was_replayed: bool,
) -> Optional[TaskOutput]:
) -> TaskOutput | None:
if futures:
task_outputs = self._process_async_tasks(futures, was_replayed)
futures.clear()
@@ -917,8 +954,8 @@ class Crew(FlowTrackable, BaseModel):
return None
def _prepare_tools(
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
# Add delegation tools if agent allows delegation
if hasattr(agent, "allow_delegation") and getattr(
agent, "allow_delegation", False
@@ -947,22 +984,22 @@ class Crew(FlowTrackable, BaseModel):
):
tools = self._add_multimodal_tools(agent, tools)
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
return cast(List[BaseTool], tools)
# Return a List[BaseTool] compatible with Task.execute_sync and execute_async
return cast(list[BaseTool], tools)
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
def _get_agent_to_use(self, task: Task) -> BaseAgent | None:
if self.process == Process.hierarchical:
return self.manager_agent
return task.agent
def _merge_tools(
self,
existing_tools: Union[List[Tool], List[BaseTool]],
new_tools: Union[List[Tool], List[BaseTool]],
) -> List[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
existing_tools: list[Tool] | list[BaseTool],
new_tools: list[Tool] | list[BaseTool],
) -> list[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates."""
if not new_tools:
return cast(List[BaseTool], existing_tools)
return cast(list[BaseTool], existing_tools)
# Create mapping of tool names to new tools
new_tool_map = {tool.name: tool for tool in new_tools}
@@ -973,41 +1010,41 @@ class Crew(FlowTrackable, BaseModel):
# Add all new tools
tools.extend(new_tools)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _inject_delegation_tools(
self,
tools: Union[List[Tool], List[BaseTool]],
tools: list[Tool] | list[BaseTool],
task_agent: BaseAgent,
agents: List[BaseAgent],
) -> List[BaseTool]:
agents: list[BaseAgent],
) -> list[BaseTool]:
if hasattr(task_agent, "get_delegation_tools"):
delegation_tools = task_agent.get_delegation_tools(agents)
# Cast delegation_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], delegation_tools))
return cast(list[BaseTool], tools)
def _add_multimodal_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if hasattr(agent, "get_multimodal_tools"):
multimodal_tools = agent.get_multimodal_tools()
# Cast multimodal_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
return cast(list[BaseTool], tools)
def _add_code_execution_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if hasattr(agent, "get_code_execution_tools"):
code_tools = agent.get_code_execution_tools()
# Cast code_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
return cast(list[BaseTool], tools)
def _add_delegation_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
if not tools:
@@ -1015,7 +1052,7 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools(
tools, task.agent, agents_for_delegation
)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _log_task_start(self, task: Task, role: str = "None"):
if self.output_log_file:
@@ -1024,8 +1061,8 @@ class Crew(FlowTrackable, BaseModel):
)
def _update_manager_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if self.manager_agent:
if task.agent:
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
@@ -1033,18 +1070,17 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools(
tools, self.manager_agent, self.agents
)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
if not task.context:
return ""
context = (
return (
aggregate_raw_outputs_from_task_outputs(task_outputs)
if task.context is NOT_SPECIFIED
else aggregate_raw_outputs_from_tasks(task.context)
)
return context
def _process_task_result(self, task: Task, output: TaskOutput) -> None:
role = task.agent.role if task.agent is not None else "None"
@@ -1057,7 +1093,7 @@ class Crew(FlowTrackable, BaseModel):
output=output.raw,
)
def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput:
def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput:
if not task_outputs:
raise ValueError("No task outputs available to create crew output.")
@@ -1088,10 +1124,10 @@ class Crew(FlowTrackable, BaseModel):
def _process_async_tasks(
self,
futures: List[Tuple[Task, Future[TaskOutput], int]],
futures: list[tuple[Task, Future[TaskOutput], int]],
was_replayed: bool = False,
) -> List[TaskOutput]:
task_outputs: List[TaskOutput] = []
) -> list[TaskOutput]:
task_outputs: list[TaskOutput] = []
for future_task, future, task_index in futures:
task_output = future.result()
task_outputs.append(task_output)
@@ -1101,9 +1137,7 @@ class Crew(FlowTrackable, BaseModel):
)
return task_outputs
def _find_task_index(
self, task_id: str, stored_outputs: List[Any]
) -> Optional[int]:
def _find_task_index(self, task_id: str, stored_outputs: list[Any]) -> int | None:
return next(
(
index
@@ -1113,9 +1147,8 @@ class Crew(FlowTrackable, BaseModel):
None,
)
def replay(
self, task_id: str, inputs: Optional[Dict[str, Any]] = None
) -> CrewOutput:
def replay(self, task_id: str, inputs: dict[str, Any] | None = None) -> CrewOutput:
"""Replay the crew execution from a specific task."""
stored_outputs = self._task_output_handler.load()
if not stored_outputs:
raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")
@@ -1151,19 +1184,19 @@ class Crew(FlowTrackable, BaseModel):
self.tasks[i].output = task_output
self._logging_color = "bold_blue"
result = self._execute_tasks(self.tasks, start_index, True)
return result
return self._execute_tasks(self.tasks, start_index, True)
def query_knowledge(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> Union[List[Dict[str, Any]], None]:
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[SearchResult] | None:
"""Query the crew's knowledge base for relevant information."""
if self.knowledge:
return self.knowledge.query(
query, results_limit=results_limit, score_threshold=score_threshold
)
return None
def fetch_inputs(self) -> Set[str]:
def fetch_inputs(self) -> set[str]:
"""
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
Scans each task's 'description' + 'expected_output', and each agent's
@@ -1172,11 +1205,11 @@ class Crew(FlowTrackable, BaseModel):
Returns a set of all discovered placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)\}")
required_inputs: Set[str] = set()
required_inputs: set[str] = set()
# Scan tasks for inputs
for task in self.tasks:
# description and expected_output might contain e.g. {topic}, {user_name}, etc.
# description and expected_output might contain e.g. {topic}, {user_name}
text = f"{task.description or ''} {task.expected_output or ''}"
required_inputs.update(placeholder_pattern.findall(text))
@@ -1230,7 +1263,7 @@ class Crew(FlowTrackable, BaseModel):
cloned_tasks.append(cloned_task)
task_mapping[task.key] = cloned_task
for cloned_task, original_task in zip(cloned_tasks, self.tasks):
for cloned_task, original_task in zip(cloned_tasks, self.tasks, strict=False):
if isinstance(original_task.context, list):
cloned_context = [
task_mapping[context_task.key]
@@ -1256,7 +1289,7 @@ class Crew(FlowTrackable, BaseModel):
copied_data.pop("agents", None)
copied_data.pop("tasks", None)
copied_crew = Crew(
return Crew(
**copied_data,
agents=cloned_agents,
tasks=cloned_tasks,
@@ -1266,15 +1299,13 @@ class Crew(FlowTrackable, BaseModel):
manager_llm=manager_llm,
)
return copied_crew
def _set_tasks_callbacks(self) -> None:
"""Sets callback for every task suing task_callback"""
for task in self.tasks:
if not task.callback:
task.callback = self.task_callback
def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
def _interpolate_inputs(self, inputs: dict[str, Any]) -> None:
"""Interpolates the inputs in the tasks and agents."""
[
task.interpolate_inputs_and_add_conversation_history(
@@ -1307,10 +1338,13 @@ class Crew(FlowTrackable, BaseModel):
def test(
self,
n_iterations: int,
eval_llm: Union[str, InstanceOf[BaseLLM]],
inputs: Optional[Dict[str, Any]] = None,
eval_llm: str | InstanceOf[BaseLLM],
inputs: dict[str, Any] | None = None,
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
"""Test and evaluate the Crew with the given inputs for n iterations.
Uses concurrent.futures for concurrent execution.
"""
try:
# Create LLM instance and ensure it's of type LLM for CrewEvaluator
llm_instance = create_llm(eval_llm)
@@ -1350,7 +1384,11 @@ class Crew(FlowTrackable, BaseModel):
raise
def __repr__(self):
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
return (
f"Crew(id={self.id}, process={self.process}, "
f"number_of_agents={len(self.agents)}, "
f"number_of_tasks={len(self.tasks)})"
)
def reset_memories(self, command_type: str) -> None:
"""Reset specific or all memories for the crew.
@@ -1364,7 +1402,7 @@ class Crew(FlowTrackable, BaseModel):
ValueError: If an invalid command type is provided.
RuntimeError: If memory reset operation fails.
"""
VALID_TYPES = frozenset(
valid_types = frozenset(
[
"long",
"short",
@@ -1377,9 +1415,10 @@ class Crew(FlowTrackable, BaseModel):
]
)
if command_type not in VALID_TYPES:
if command_type not in valid_types:
raise ValueError(
f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}"
f"Invalid command type. Must be one of: "
f"{', '.join(sorted(valid_types))}"
)
try:
@@ -1389,7 +1428,7 @@ class Crew(FlowTrackable, BaseModel):
self._reset_specific_memory(command_type)
except Exception as e:
error_msg = f"Failed to reset {command_type} memory: {str(e)}"
error_msg = f"Failed to reset {command_type} memory: {e!s}"
self._logger.log("error", error_msg)
raise RuntimeError(error_msg) from e
@@ -1397,7 +1436,7 @@ class Crew(FlowTrackable, BaseModel):
"""Reset all available memory systems."""
memory_systems = self._get_memory_systems()
for memory_type, config in memory_systems.items():
for config in memory_systems.values():
if (system := config.get("system")) is not None:
name = config.get("name")
try:
@@ -1405,11 +1444,13 @@ class Crew(FlowTrackable, BaseModel):
reset_fn(system)
self._logger.log(
"info",
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
f"[Crew ({self.name if self.name else self.id})] "
f"{name} memory has been reset",
)
except Exception as e:
raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
f"[Crew ({self.name if self.name else self.id})] "
f"Failed to reset {name} memory: {e!s}"
) from e
def _reset_specific_memory(self, memory_type: str) -> None:
@@ -1434,18 +1475,21 @@ class Crew(FlowTrackable, BaseModel):
reset_fn(system)
self._logger.log(
"info",
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
f"[Crew ({self.name if self.name else self.id})] "
f"{name} memory has been reset",
)
except Exception as e:
raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
f"[Crew ({self.name if self.name else self.id})] "
f"Failed to reset {name} memory: {e!s}"
) from e
def _get_memory_systems(self):
"""Get all available memory systems with their configuration.
Returns:
Dict containing all memory systems with their reset functions and display names.
Dict containing all memory systems with their reset functions and
display names.
"""
def default_reset(memory):
@@ -1506,7 +1550,7 @@ class Crew(FlowTrackable, BaseModel):
},
}
def reset_knowledge(self, knowledges: List[Knowledge]) -> None:
def reset_knowledge(self, knowledges: list[Knowledge]) -> None:
"""Reset crew and agent knowledge storage."""
for ks in knowledges:
ks.reset()

View File

@@ -9,48 +9,158 @@ This module provides the event infrastructure that allows users to:
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentEvaluationCompletedEvent,
AgentEvaluationFailedEvent,
AgentEvaluationStartedEvent,
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestResultEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowEvent,
FlowFinishedEvent,
FlowPlotEvent,
FlowStartedEvent,
MethodExecutionFailedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMStreamChunkEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.events.types.logging_events import (
AgentLogsExecutionEvent,
AgentLogsStartedEvent,
)
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeRetrievalStartedEvent,
KnowledgeRetrievalCompletedEvent,
from crewai.events.types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
ReasoningEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffStartedEvent,
CrewKickoffCompletedEvent,
from crewai.events.types.task_events import (
TaskCompletedEvent,
TaskEvaluationEvent,
TaskFailedEvent,
TaskStartedEvent,
)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
)
from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
from crewai.events.types.tool_usage_events import (
ToolExecutionErrorEvent,
ToolSelectionErrorEvent,
ToolUsageErrorEvent,
ToolUsageEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
ToolValidateInputErrorEvent,
)
__all__ = [
"AgentEvaluationCompletedEvent",
"AgentEvaluationFailedEvent",
"AgentEvaluationStartedEvent",
"AgentExecutionCompletedEvent",
"AgentExecutionErrorEvent",
"AgentExecutionStartedEvent",
"AgentLogsExecutionEvent",
"AgentLogsStartedEvent",
"AgentReasoningCompletedEvent",
"AgentReasoningFailedEvent",
"AgentReasoningStartedEvent",
"BaseEventListener",
"crewai_event_bus",
"CrewKickoffCompletedEvent",
"CrewKickoffFailedEvent",
"CrewKickoffStartedEvent",
"CrewTestCompletedEvent",
"CrewTestFailedEvent",
"CrewTestResultEvent",
"CrewTestStartedEvent",
"CrewTrainCompletedEvent",
"CrewTrainFailedEvent",
"CrewTrainStartedEvent",
"FlowCreatedEvent",
"FlowEvent",
"FlowFinishedEvent",
"FlowPlotEvent",
"FlowStartedEvent",
"KnowledgeQueryCompletedEvent",
"KnowledgeQueryFailedEvent",
"KnowledgeQueryStartedEvent",
"KnowledgeRetrievalCompletedEvent",
"KnowledgeRetrievalStartedEvent",
"KnowledgeSearchQueryFailedEvent",
"LLMCallCompletedEvent",
"LLMCallFailedEvent",
"LLMCallStartedEvent",
"LLMGuardrailCompletedEvent",
"LLMGuardrailStartedEvent",
"LLMStreamChunkEvent",
"LiteAgentExecutionCompletedEvent",
"LiteAgentExecutionErrorEvent",
"LiteAgentExecutionStartedEvent",
"MemoryQueryCompletedEvent",
"MemorySaveCompletedEvent",
"MemorySaveStartedEvent",
"MemoryQueryFailedEvent",
"MemoryQueryStartedEvent",
"MemoryRetrievalCompletedEvent",
"MemoryRetrievalStartedEvent",
"MemorySaveCompletedEvent",
"MemorySaveFailedEvent",
"MemoryQueryFailedEvent",
"KnowledgeRetrievalStartedEvent",
"KnowledgeRetrievalCompletedEvent",
"CrewKickoffStartedEvent",
"CrewKickoffCompletedEvent",
"AgentExecutionCompletedEvent",
"LLMStreamChunkEvent",
]
"MemorySaveStartedEvent",
"MethodExecutionFailedEvent",
"MethodExecutionFinishedEvent",
"MethodExecutionStartedEvent",
"ReasoningEvent",
"TaskCompletedEvent",
"TaskEvaluationEvent",
"TaskFailedEvent",
"TaskStartedEvent",
"ToolExecutionErrorEvent",
"ToolSelectionErrorEvent",
"ToolUsageErrorEvent",
"ToolUsageEvent",
"ToolUsageFinishedEvent",
"ToolUsageStartedEvent",
"ToolValidateInputErrorEvent",
"crewai_event_bus",
]

View File

@@ -0,0 +1,229 @@
import logging
import uuid
import webbrowser
from pathlib import Path
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager
from crewai.events.listeners.tracing.utils import (
mark_first_execution_completed,
prompt_user_for_trace_viewing,
should_auto_collect_first_time_traces,
)
logger = logging.getLogger(__name__)
def _update_or_create_env_file():
"""Update or create .env file with CREWAI_TRACING_ENABLED=true."""
env_path = Path(".env")
env_content = ""
variable_name = "CREWAI_TRACING_ENABLED"
variable_value = "true"
# Read existing content if file exists
if env_path.exists():
with open(env_path, "r") as f:
env_content = f.read()
# Check if CREWAI_TRACING_ENABLED is already set
lines = env_content.splitlines()
variable_exists = False
updated_lines = []
for line in lines:
if line.strip().startswith(f"{variable_name}="):
# Update existing variable
updated_lines.append(f"{variable_name}={variable_value}")
variable_exists = True
else:
updated_lines.append(line)
# Add variable if it doesn't exist
if not variable_exists:
if updated_lines and not updated_lines[-1].strip():
# If last line is empty, replace it
updated_lines[-1] = f"{variable_name}={variable_value}"
else:
# Add new line and then the variable
updated_lines.append(f"{variable_name}={variable_value}")
# Write updated content
with open(env_path, "w") as f:
f.write("\n".join(updated_lines))
if updated_lines: # Add final newline if there's content
f.write("\n")
class FirstTimeTraceHandler:
"""Handles the first-time user trace collection and display flow."""
def __init__(self):
self.is_first_time: bool = False
self.collected_events: bool = False
self.trace_batch_id: str | None = None
self.ephemeral_url: str | None = None
self.batch_manager: TraceBatchManager | None = None
def initialize_for_first_time_user(self) -> bool:
"""Check if this is first time and initialize collection."""
self.is_first_time = should_auto_collect_first_time_traces()
return self.is_first_time
def set_batch_manager(self, batch_manager: TraceBatchManager):
"""Set reference to batch manager for sending events."""
self.batch_manager = batch_manager
def mark_events_collected(self):
"""Mark that events have been collected during execution."""
self.collected_events = True
def handle_execution_completion(self):
"""Handle the completion flow as shown in your diagram."""
if not self.is_first_time or not self.collected_events:
return
try:
user_wants_traces = prompt_user_for_trace_viewing(timeout_seconds=20)
if user_wants_traces:
self._initialize_backend_and_send_events()
# Enable tracing for future runs by updating .env file
try:
_update_or_create_env_file()
except Exception: # noqa: S110
pass
if self.ephemeral_url:
self._display_ephemeral_trace_link()
mark_first_execution_completed()
except Exception as e:
self._gracefully_fail(f"Error in trace handling: {e}")
mark_first_execution_completed()
def _initialize_backend_and_send_events(self):
"""Initialize backend batch and send collected events."""
if not self.batch_manager:
return
try:
if not self.batch_manager.backend_initialized:
original_metadata = (
self.batch_manager.current_batch.execution_metadata
if self.batch_manager.current_batch
else {}
)
user_context = {
"privacy_level": "standard",
"user_id": "first_time_user",
"session_id": str(uuid.uuid4()),
"trace_id": self.batch_manager.trace_batch_id,
}
execution_metadata = {
"execution_type": original_metadata.get("execution_type", "crew"),
"crew_name": original_metadata.get(
"crew_name", "First Time Execution"
),
"flow_name": original_metadata.get("flow_name"),
"agent_count": original_metadata.get("agent_count", 1),
"task_count": original_metadata.get("task_count", 1),
"crewai_version": original_metadata.get("crewai_version"),
}
self.batch_manager._initialize_backend_batch(
user_context=user_context,
execution_metadata=execution_metadata,
use_ephemeral=True,
)
self.batch_manager.backend_initialized = True
if self.batch_manager.event_buffer:
self.batch_manager._send_events_to_backend()
self.batch_manager.finalize_batch()
self.ephemeral_url = self.batch_manager.ephemeral_trace_url
if not self.ephemeral_url:
self._show_local_trace_message()
except Exception as e:
self._gracefully_fail(f"Backend initialization failed: {e}")
def _display_ephemeral_trace_link(self):
"""Display the ephemeral trace link to the user and automatically open browser."""
console = Console()
try:
webbrowser.open(self.ephemeral_url)
except Exception: # noqa: S110
pass
panel_content = f"""
🎉 Your First CrewAI Execution Trace is Ready!
View your execution details here:
{self.ephemeral_url}
This trace shows:
• Agent decisions and interactions
• Task execution timeline
• Tool usage and results
• LLM calls and responses
✅ Tracing has been enabled for future runs! (CREWAI_TRACING_ENABLED=true added to .env)
You can also add tracing=True to your Crew(tracing=True) / Flow(tracing=True) for more control.
📝 Note: This link will expire in 24 hours.
""".strip()
panel = Panel(
panel_content,
title="🔍 Execution Trace Generated",
border_style="bright_green",
padding=(1, 2),
)
console.print("\n")
console.print(panel)
console.print()
def _gracefully_fail(self, error_message: str):
"""Handle errors gracefully without disrupting user experience."""
console = Console()
console.print(f"[yellow]Note: {error_message}[/yellow]")
logger.debug(f"First-time trace error: {error_message}")
def _show_local_trace_message(self):
"""Show message when traces were collected locally but couldn't be uploaded."""
console = Console()
panel_content = f"""
📊 Your execution traces were collected locally!
Unfortunately, we couldn't upload them to the server right now, but here's what we captured:
{len(self.batch_manager.event_buffer)} trace events
• Execution duration: {self.batch_manager.calculate_duration("execution")}ms
• Batch ID: {self.batch_manager.trace_batch_id}
Tracing has been enabled for future runs! (CREWAI_TRACING_ENABLED=true added to .env)
The traces include agent decisions, task execution, and tool usage.
""".strip()
panel = Panel(
panel_content,
title="🔍 Local Traces Collected",
border_style="yellow",
padding=(1, 2),
)
console.print("\n")
console.print(panel)
console.print()

View File

@@ -1,18 +1,18 @@
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime, timezone
from logging import getLogger
from typing import Any
from crewai.utilities.constants import CREWAI_BASE_URL
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
from crewai.cli.plus_api import PlusAPI
from rich.console import Console
from rich.panel import Panel
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.plus_api import PlusAPI
from crewai.cli.version import get_crewai_version
from crewai.events.listeners.tracing.types import TraceEvent
from logging import getLogger
from crewai.events.listeners.tracing.utils import should_auto_collect_first_time_traces
from crewai.utilities.constants import CREWAI_BASE_URL
logger = getLogger(__name__)
@@ -23,11 +23,11 @@ class TraceBatch:
version: str = field(default_factory=get_crewai_version)
batch_id: str = field(default_factory=lambda: str(uuid.uuid4()))
user_context: Dict[str, str] = field(default_factory=dict)
execution_metadata: Dict[str, Any] = field(default_factory=dict)
events: List[TraceEvent] = field(default_factory=list)
user_context: dict[str, str] = field(default_factory=dict)
execution_metadata: dict[str, Any] = field(default_factory=dict)
events: list[TraceEvent] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return {
"version": self.version,
"batch_id": self.batch_id,
@@ -40,26 +40,28 @@ class TraceBatch:
class TraceBatchManager:
"""Single responsibility: Manage batches and event buffering"""
is_current_batch_ephemeral: bool = False
trace_batch_id: Optional[str] = None
current_batch: Optional[TraceBatch] = None
event_buffer: List[TraceEvent] = []
execution_start_times: Dict[str, datetime] = {}
batch_owner_type: Optional[str] = None
batch_owner_id: Optional[str] = None
def __init__(self):
self.is_current_batch_ephemeral: bool = False
self.trace_batch_id: str | None = None
self.current_batch: TraceBatch | None = None
self.event_buffer: list[TraceEvent] = []
self.execution_start_times: dict[str, datetime] = {}
self.batch_owner_type: str | None = None
self.batch_owner_id: str | None = None
self.backend_initialized: bool = False
self.ephemeral_trace_url: str | None = None
try:
self.plus_api = PlusAPI(
api_key=get_auth_token(),
)
except AuthError:
self.plus_api = PlusAPI(api_key="")
self.ephemeral_trace_url = None
def initialize_batch(
self,
user_context: Dict[str, str],
execution_metadata: Dict[str, Any],
user_context: dict[str, str],
execution_metadata: dict[str, Any],
use_ephemeral: bool = False,
) -> TraceBatch:
"""Initialize a new trace batch"""
@@ -70,14 +72,21 @@ class TraceBatchManager:
self.is_current_batch_ephemeral = use_ephemeral
self.record_start_time("execution")
self._initialize_backend_batch(user_context, execution_metadata, use_ephemeral)
if should_auto_collect_first_time_traces():
self.trace_batch_id = self.current_batch.batch_id
else:
self._initialize_backend_batch(
user_context, execution_metadata, use_ephemeral
)
self.backend_initialized = True
return self.current_batch
def _initialize_backend_batch(
self,
user_context: Dict[str, str],
execution_metadata: Dict[str, Any],
user_context: dict[str, str],
execution_metadata: dict[str, Any],
use_ephemeral: bool = False,
):
"""Send batch initialization to backend"""
@@ -129,13 +138,6 @@ class TraceBatchManager:
if not use_ephemeral
else response_data["ephemeral_trace_id"]
)
console = Console()
panel = Panel(
f"✅ Trace batch initialized with session ID: {self.trace_batch_id}",
title="Trace Batch Initialization",
border_style="green",
)
console.print(panel)
else:
logger.warning(
f"Trace batch initialization returned status {response.status_code}. Continuing without tracing."
@@ -143,7 +145,7 @@ class TraceBatchManager:
except Exception as e:
logger.warning(
f"Error initializing trace batch: {str(e)}. Continuing without tracing."
f"Error initializing trace batch: {e}. Continuing without tracing."
)
def add_event(self, trace_event: TraceEvent):
@@ -154,7 +156,6 @@ class TraceBatchManager:
"""Send buffered events to backend with graceful failure handling"""
if not self.plus_api or not self.trace_batch_id or not self.event_buffer:
return 500
try:
payload = {
"events": [event.to_dict() for event in self.event_buffer],
@@ -178,19 +179,19 @@ class TraceBatchManager:
if response.status_code in [200, 201]:
self.event_buffer.clear()
return 200
else:
logger.warning(
f"Failed to send events: {response.status_code}. Events will be lost."
)
return 500
except Exception as e:
logger.warning(
f"Error sending events to backend: {str(e)}. Events will be lost."
f"Failed to send events: {response.status_code}. Events will be lost."
)
return 500
def finalize_batch(self) -> Optional[TraceBatch]:
except Exception as e:
logger.warning(
f"Error sending events to backend: {e}. Events will be lost."
)
return 500
def finalize_batch(self) -> TraceBatch | None:
"""Finalize batch and return it for sending"""
if not self.current_batch:
return None
@@ -246,12 +247,27 @@ class TraceBatchManager:
if not self.is_current_batch_ephemeral and access_code is None
else f"{CREWAI_BASE_URL}/crewai_plus/ephemeral_trace_batches/{self.trace_batch_id}?access_code={access_code}"
)
if self.is_current_batch_ephemeral:
self.ephemeral_trace_url = return_link
# Create a properly formatted message with URL on its own line
message_parts = [
f"✅ Trace batch finalized with session ID: {self.trace_batch_id}",
"",
f"🔗 View here: {return_link}",
]
if access_code:
message_parts.append(f"🔑 Access Code: {access_code}")
panel = Panel(
f"✅ Trace batch finalized with session ID: {self.trace_batch_id}. View here: {return_link} {f', Access Code: {access_code}' if access_code else ''}",
"\n".join(message_parts),
title="Trace Batch Finalization",
border_style="green",
)
console.print(panel)
if not should_auto_collect_first_time_traces():
console.print(panel)
else:
logger.error(
@@ -259,8 +275,8 @@ class TraceBatchManager:
)
except Exception as e:
logger.error(f"❌ Error finalizing trace batch: {str(e)}")
# TODO: send error to app
logger.error(f"❌ Error finalizing trace batch: {e}")
# TODO: send error to app marking as failed
def _cleanup_batch_data(self):
"""Clean up batch data after successful finalization to free memory"""
@@ -277,7 +293,7 @@ class TraceBatchManager:
self.batch_sequence = 0
except Exception as e:
logger.error(f"Warning: Error during cleanup: {str(e)}")
logger.error(f"Warning: Error during cleanup: {e}")
def has_events(self) -> bool:
"""Check if there are events in the buffer"""
@@ -306,7 +322,7 @@ class TraceBatchManager:
return duration_ms
return 0
def get_trace_id(self) -> Optional[str]:
def get_trace_id(self) -> str | None:
"""Get current trace ID"""
if self.current_batch:
return self.current_batch.user_context.get("trace_id")

View File

@@ -1,28 +1,59 @@
import os
import uuid
from typing import Any, ClassVar
from typing import Dict, Any, Optional
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
AgentExecutionErrorEvent,
from crewai.events.listeners.tracing.first_time_trace_handler import (
FirstTimeTraceHandler,
)
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.types.reasoning_events import (
AgentReasoningStartedEvent,
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
from crewai.events.listeners.tracing.utils import safe_serialize_to_dict
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
)
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowFinishedEvent,
FlowPlotEvent,
FlowStartedEvent,
MethodExecutionFailedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.events.types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
)
from crewai.events.types.task_events import (
TaskCompletedEvent,
TaskFailedEvent,
@@ -33,49 +64,16 @@ from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
)
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowStartedEvent,
FlowFinishedEvent,
MethodExecutionStartedEvent,
MethodExecutionFinishedEvent,
MethodExecutionFailedEvent,
FlowPlotEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
)
from crewai.utilities.serialization import to_serializable
from .trace_batch_manager import TraceBatchManager
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
class TraceCollectionListener(BaseEventListener):
"""
Trace collection listener that orchestrates trace collection
"""
complex_events = [
complex_events: ClassVar[list[str]] = [
"task_started",
"task_completed",
"llm_call_started",
@@ -88,14 +86,14 @@ class TraceCollectionListener(BaseEventListener):
_initialized = False
_listeners_setup = False
def __new__(cls, batch_manager=None):
def __new__(cls, batch_manager: TraceBatchManager | None = None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(
self,
batch_manager: Optional[TraceBatchManager] = None,
batch_manager: TraceBatchManager | None = None,
):
if self._initialized:
return
@@ -103,16 +101,19 @@ class TraceCollectionListener(BaseEventListener):
super().__init__()
self.batch_manager = batch_manager or TraceBatchManager()
self._initialized = True
self.first_time_handler = FirstTimeTraceHandler()
if self.first_time_handler.initialize_for_first_time_user():
self.first_time_handler.set_batch_manager(self.batch_manager)
def _check_authenticated(self) -> bool:
"""Check if tracing should be enabled"""
try:
res = bool(get_auth_token())
return res
return bool(get_auth_token())
except AuthError:
return False
def _get_user_context(self) -> Dict[str, str]:
def _get_user_context(self) -> dict[str, str]:
"""Extract user context for tracing"""
return {
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
@@ -161,8 +162,14 @@ class TraceCollectionListener(BaseEventListener):
@event_bus.on(FlowFinishedEvent)
def on_flow_finished(source, event):
self._handle_trace_event("flow_finished", source, event)
if self.batch_manager.batch_owner_type == "flow":
self.batch_manager.finalize_batch()
if self.first_time_handler.is_first_time:
self.first_time_handler.mark_events_collected()
self.first_time_handler.handle_execution_completion()
else:
# Normal flow finalization
self.batch_manager.finalize_batch()
@event_bus.on(FlowPlotEvent)
def on_flow_plot(source, event):
@@ -181,12 +188,20 @@ class TraceCollectionListener(BaseEventListener):
def on_crew_completed(source, event):
self._handle_trace_event("crew_kickoff_completed", source, event)
if self.batch_manager.batch_owner_type == "crew":
self.batch_manager.finalize_batch()
if self.first_time_handler.is_first_time:
self.first_time_handler.mark_events_collected()
self.first_time_handler.handle_execution_completion()
else:
self.batch_manager.finalize_batch()
@event_bus.on(CrewKickoffFailedEvent)
def on_crew_failed(source, event):
self._handle_trace_event("crew_kickoff_failed", source, event)
self.batch_manager.finalize_batch()
if self.first_time_handler.is_first_time:
self.first_time_handler.mark_events_collected()
self.first_time_handler.handle_execution_completion()
else:
self.batch_manager.finalize_batch()
@event_bus.on(TaskStartedEvent)
def on_task_started(source, event):
@@ -325,17 +340,19 @@ class TraceCollectionListener(BaseEventListener):
self._initialize_batch(user_context, execution_metadata)
def _initialize_batch(
self, user_context: Dict[str, str], execution_metadata: Dict[str, Any]
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
):
"""Initialize trace batch if ephemeral"""
if not self._check_authenticated():
self.batch_manager.initialize_batch(
"""Initialize trace batch - auto-enable ephemeral for first-time users."""
if self.first_time_handler.is_first_time:
return self.batch_manager.initialize_batch(
user_context, execution_metadata, use_ephemeral=True
)
else:
self.batch_manager.initialize_batch(
user_context, execution_metadata, use_ephemeral=False
)
use_ephemeral = not self._check_authenticated()
return self.batch_manager.initialize_batch(
user_context, execution_metadata, use_ephemeral=use_ephemeral
)
def _handle_trace_event(self, event_type: str, source: Any, event: Any):
"""Generic handler for context end events"""
@@ -371,11 +388,11 @@ class TraceCollectionListener(BaseEventListener):
def _build_event_data(
self, event_type: str, event: Any, source: Any
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Build event data"""
if event_type not in self.complex_events:
return self._safe_serialize_to_dict(event)
elif event_type == "task_started":
return safe_serialize_to_dict(event)
if event_type == "task_started":
return {
"task_description": event.task.description,
"expected_output": event.task.expected_output,
@@ -384,7 +401,7 @@ class TraceCollectionListener(BaseEventListener):
"agent_role": source.agent.role,
"task_id": str(event.task.id),
}
elif event_type == "task_completed":
if event_type == "task_completed":
return {
"task_description": event.task.description if event.task else None,
"task_name": event.task.name or event.task.description
@@ -397,63 +414,31 @@ class TraceCollectionListener(BaseEventListener):
else None,
"agent_role": event.output.agent if event.output else None,
}
elif event_type == "agent_execution_started":
if event_type == "agent_execution_started":
return {
"agent_role": event.agent.role,
"agent_goal": event.agent.goal,
"agent_backstory": event.agent.backstory,
}
elif event_type == "agent_execution_completed":
if event_type == "agent_execution_completed":
return {
"agent_role": event.agent.role,
"agent_goal": event.agent.goal,
"agent_backstory": event.agent.backstory,
}
elif event_type == "llm_call_started":
event_data = self._safe_serialize_to_dict(event)
if event_type == "llm_call_started":
event_data = safe_serialize_to_dict(event)
event_data["task_name"] = (
event.task_name or event.task_description
if hasattr(event, "task_name") and event.task_name
else None
)
return event_data
elif event_type == "llm_call_completed":
return self._safe_serialize_to_dict(event)
else:
return {
"event_type": event_type,
"event": self._safe_serialize_to_dict(event),
"source": source,
}
if event_type == "llm_call_completed":
return safe_serialize_to_dict(event)
# TODO: move to utils
def _safe_serialize_to_dict(
self, obj, exclude: set[str] | None = None
) -> Dict[str, Any]:
"""Safely serialize an object to a dictionary for event data."""
try:
serialized = to_serializable(obj, exclude)
if isinstance(serialized, dict):
return serialized
else:
return {"serialized_data": serialized}
except Exception as e:
return {"serialization_error": str(e), "object_type": type(obj).__name__}
# TODO: move to utils
def _truncate_messages(self, messages, max_content_length=500, max_messages=5):
"""Truncate message content and limit number of messages"""
if not messages or not isinstance(messages, list):
return messages
# Limit number of messages
limited_messages = messages[:max_messages]
# Truncate each message content
for msg in limited_messages:
if isinstance(msg, dict) and "content" in msg:
content = msg["content"]
if len(content) > max_content_length:
msg["content"] = content[:max_content_length] + "..."
return limited_messages
return {
"event_type": event_type,
"event": safe_serialize_to_dict(event),
"source": source,
}

View File

@@ -1,17 +1,25 @@
import getpass
import hashlib
import json
import logging
import os
import platform
import uuid
import hashlib
import subprocess
import getpass
from pathlib import Path
from datetime import datetime
import re
import json
import subprocess
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any
import click
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from crewai.utilities.paths import db_storage_path
from crewai.utilities.serialization import to_serializable
logger = logging.getLogger(__name__)
def is_tracing_enabled() -> bool:
@@ -43,13 +51,11 @@ def _get_machine_id() -> str:
try:
mac = ":".join(
["{:02x}".format((uuid.getnode() >> b) & 0xFF) for b in range(0, 12, 2)][
::-1
]
[f"{(uuid.getnode() >> b) & 0xFF:02x}" for b in range(0, 12, 2)][::-1]
)
parts.append(mac)
except Exception:
pass
logger.warning("Error getting machine id for fingerprinting")
sysname = platform.system()
parts.append(sysname)
@@ -57,7 +63,7 @@ def _get_machine_id() -> str:
try:
if sysname == "Darwin":
res = subprocess.run(
["system_profiler", "SPHardwareDataType"],
["/usr/sbin/system_profiler", "SPHardwareDataType"],
capture_output=True,
text=True,
timeout=2,
@@ -72,7 +78,7 @@ def _get_machine_id() -> str:
parts.append(Path("/sys/class/dmi/id/product_uuid").read_text().strip())
elif sysname == "Windows":
res = subprocess.run(
["wmic", "csproduct", "get", "UUID"],
["C:\\Windows\\System32\\wbem\\wmic.exe", "csproduct", "get", "UUID"],
capture_output=True,
text=True,
timeout=2,
@@ -81,7 +87,7 @@ def _get_machine_id() -> str:
if len(lines) >= 2:
parts.append(lines[1])
except Exception:
pass
logger.exception("Error getting machine ID")
return hashlib.sha256("".join(parts).encode()).hexdigest()
@@ -97,8 +103,8 @@ def _load_user_data() -> dict:
if p.exists():
try:
return json.loads(p.read_text())
except Exception:
pass
except (json.JSONDecodeError, OSError, PermissionError) as e:
logger.warning(f"Failed to load user data: {e}")
return {}
@@ -106,8 +112,8 @@ def _save_user_data(data: dict) -> None:
try:
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
except Exception:
pass
except (OSError, PermissionError) as e:
logger.warning(f"Failed to save user data: {e}")
def get_user_id() -> str:
@@ -151,3 +157,103 @@ def mark_first_execution_done() -> None:
}
)
_save_user_data(data)
def safe_serialize_to_dict(obj, exclude: set[str] | None = None) -> dict[str, Any]:
"""Safely serialize an object to a dictionary for event data."""
try:
serialized = to_serializable(obj, exclude)
if isinstance(serialized, dict):
return serialized
return {"serialized_data": serialized}
except Exception as e:
return {"serialization_error": str(e), "object_type": type(obj).__name__}
def truncate_messages(messages, max_content_length=500, max_messages=5):
"""Truncate message content and limit number of messages"""
if not messages or not isinstance(messages, list):
return messages
limited_messages = messages[:max_messages]
for msg in limited_messages:
if isinstance(msg, dict) and "content" in msg:
content = msg["content"]
if len(content) > max_content_length:
msg["content"] = content[:max_content_length] + "..."
return limited_messages
def should_auto_collect_first_time_traces() -> bool:
"""True if we should auto-collect traces for first-time user."""
if _is_test_environment():
return False
return is_first_execution()
def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
"""
Prompt user if they want to see their traces with timeout.
Returns True if user wants to see traces, False otherwise.
"""
if _is_test_environment():
return False
try:
import threading
console = Console()
content = Text()
content.append("🔍 ", style="cyan bold")
content.append(
"Detailed execution traces are available!\n\n", style="cyan bold"
)
content.append("View insights including:\n", style="white")
content.append(" • Agent decision-making process\n", style="bright_blue")
content.append(" • Task execution flow and timing\n", style="bright_blue")
content.append(" • Tool usage details", style="bright_blue")
panel = Panel(
content,
title="[bold cyan]Execution Traces[/bold cyan]",
border_style="cyan",
padding=(1, 2),
)
console.print("\n")
console.print(panel)
prompt_text = click.style(
f"Would you like to view your execution traces? [y/N] ({timeout_seconds}s timeout): ",
fg="white",
bold=True,
)
click.echo(prompt_text, nl=False)
result = [False]
def get_input():
try:
response = input().strip().lower()
result[0] = response in ["y", "yes"]
except (EOFError, KeyboardInterrupt):
result[0] = False
input_thread = threading.Thread(target=get_input, daemon=True)
input_thread.start()
input_thread.join(timeout=timeout_seconds)
if input_thread.is_alive():
return False
return result[0]
except Exception:
return False
def mark_first_execution_completed() -> None:
"""Mark first execution as completed (called after trace prompt)."""
mark_first_execution_done()

View File

@@ -2,30 +2,22 @@ import asyncio
import copy
import inspect
import logging
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
Set,
Type,
TypeVar,
Union,
cast,
)
from collections.abc import Callable
from typing import Any, ClassVar, Generic, TypeVar, cast
from uuid import uuid4
from opentelemetry import baggage
from opentelemetry.context import attach, detach
from pydantic import BaseModel, Field, ValidationError
from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData
from crewai.flow.utils import get_possible_return_constants
from crewai.events.event_bus import crewai_event_bus
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
should_auto_collect_first_time_traces,
)
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowFinishedEvent,
@@ -35,12 +27,10 @@ from crewai.events.types.flow_events import (
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
)
from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData
from crewai.flow.utils import get_possible_return_constants
from crewai.utilities.printer import Printer
logger = logging.getLogger(__name__)
@@ -55,16 +45,14 @@ class FlowState(BaseModel):
)
# Type variables with explicit bounds
T = TypeVar(
"T", bound=Union[Dict[str, Any], BaseModel]
) # Generic flow state type parameter
# type variables with explicit bounds
T = TypeVar("T", bound=dict[str, Any] | BaseModel) # Generic flow state type parameter
StateT = TypeVar(
"StateT", bound=Union[Dict[str, Any], BaseModel]
"StateT", bound=dict[str, Any] | BaseModel
) # State validation type parameter
def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
"""Ensure state matches expected type with proper validation.
Args:
@@ -104,7 +92,7 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
raise TypeError(f"Invalid expected_type: {expected_type}")
def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
def start(condition: str | dict | Callable | None = None) -> Callable:
"""
Marks a method as a flow's starting point.
@@ -171,7 +159,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
return decorator
def listen(condition: Union[str, dict, Callable]) -> Callable:
def listen(condition: str | dict | Callable) -> Callable:
"""
Creates a listener that executes when specified conditions are met.
@@ -231,7 +219,7 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
return decorator
def router(condition: Union[str, dict, Callable]) -> Callable:
def router(condition: str | dict | Callable) -> Callable:
"""
Creates a routing method that directs flow execution based on conditions.
@@ -297,7 +285,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
return decorator
def or_(*conditions: Union[str, dict, Callable]) -> dict:
def or_(*conditions: str | dict | Callable) -> dict:
"""
Combines multiple conditions with OR logic for flow control.
@@ -343,7 +331,7 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
return {"type": "OR", "methods": methods}
def and_(*conditions: Union[str, dict, Callable]) -> dict:
def and_(*conditions: str | dict | Callable) -> dict:
"""
Combines multiple conditions with AND logic for flow control.
@@ -425,10 +413,10 @@ class FlowMeta(type):
if possible_returns:
router_paths[attr_name] = possible_returns
setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners)
setattr(cls, "_routers", routers)
setattr(cls, "_router_paths", router_paths)
cls._start_methods = start_methods
cls._listeners = listeners
cls._routers = routers
cls._router_paths = router_paths
return cls
@@ -436,29 +424,29 @@ class FlowMeta(type):
class Flow(Generic[T], metaclass=FlowMeta):
"""Base class for all flows.
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
type parameter T must be either dict[str, Any] or a subclass of BaseModel."""
_printer = Printer()
_start_methods: List[str] = []
_listeners: Dict[str, tuple[str, List[str]]] = {}
_routers: Set[str] = set()
_router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None
name: Optional[str] = None
tracing: Optional[bool] = False
_start_methods: ClassVar[list[str]] = []
_listeners: ClassVar[dict[str, tuple[str, list[str]]]] = {}
_routers: ClassVar[set[str]] = set()
_router_paths: ClassVar[dict[str, list[str]]] = {}
initial_state: type[T] | T | None = None
name: str | None = None
tracing: bool | None = False
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
def __class_getitem__(cls: type["Flow"], item: type[T]) -> type["Flow"]:
class _FlowGeneric(cls): # type: ignore
_initial_state_T = item # type: ignore
_initial_state_t = item # type: ignore
_FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]"
return _FlowGeneric
def __init__(
self,
persistence: Optional[FlowPersistence] = None,
tracing: Optional[bool] = False,
persistence: FlowPersistence | None = None,
tracing: bool | None = False,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.
@@ -468,18 +456,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
**kwargs: Additional state values to initialize or override
"""
# Initialize basic instance attributes
self._methods: Dict[str, Callable] = {}
self._method_execution_counts: Dict[str, int] = {}
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs
self._completed_methods: Set[str] = set() # Track completed methods for reload
self._persistence: Optional[FlowPersistence] = persistence
self._methods: dict[str, Callable] = {}
self._method_execution_counts: dict[str, int] = {}
self._pending_and_listeners: dict[str, set[str]] = {}
self._method_outputs: list[Any] = [] # list to store all method outputs
self._completed_methods: set[str] = set() # Track completed methods for reload
self._persistence: FlowPersistence | None = persistence
self._is_execution_resuming: bool = False
# Initialize state with initial values
self._state = self._create_initial_state()
self.tracing = tracing
if is_tracing_enabled() or self.tracing:
if (
is_tracing_enabled()
or self.tracing
or should_auto_collect_first_time_traces()
):
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
# Apply any additional kwargs
@@ -521,25 +513,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
TypeError: If state is neither BaseModel nor dictionary
"""
# Handle case where initial_state is None but we have a type parameter
if self.initial_state is None and hasattr(self, "_initial_state_T"):
state_type = getattr(self, "_initial_state_T")
if self.initial_state is None and hasattr(self, "_initial_state_t"):
state_type = self._initial_state_t
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
# Create instance without id, then set it
instance = state_type()
if not hasattr(instance, "id"):
setattr(instance, "id", str(uuid4()))
instance.id = str(uuid4())
return cast(T, instance)
elif issubclass(state_type, BaseModel):
if issubclass(state_type, BaseModel):
# Create a new type that includes the ID field
class StateWithId(state_type, FlowState): # type: ignore
pass
instance = StateWithId()
if not hasattr(instance, "id"):
setattr(instance, "id", str(uuid4()))
instance.id = str(uuid4())
return cast(T, instance)
elif state_type is dict:
if state_type is dict:
return cast(T, {"id": str(uuid4())})
# Handle case where no initial state is provided
@@ -550,13 +542,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
return cast(T, self.initial_state()) # Uses model defaults
elif issubclass(self.initial_state, BaseModel):
if issubclass(self.initial_state, BaseModel):
# Validate that the model has an id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state()) # Uses model defaults
elif self.initial_state is dict:
if self.initial_state is dict:
return cast(T, {"id": str(uuid4())})
# Handle dictionary instance case
@@ -600,7 +592,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
return self._state
@property
def method_outputs(self) -> List[Any]:
def method_outputs(self) -> list[Any]:
"""Returns the list of all outputs from executed methods."""
return self._method_outputs
@@ -631,13 +623,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self._state, dict):
return str(self._state.get("id", ""))
elif isinstance(self._state, BaseModel):
if isinstance(self._state, BaseModel):
return str(getattr(self._state, "id", ""))
return ""
except (AttributeError, TypeError):
return "" # Safely handle any unexpected attribute access issues
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
def _initialize_state(self, inputs: dict[str, Any]) -> None:
"""Initialize or update flow state with new inputs.
Args:
@@ -691,7 +683,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
else:
raise TypeError("State must be a BaseModel instance or a dictionary.")
def _restore_state(self, stored_state: Dict[str, Any]) -> None:
def _restore_state(self, stored_state: dict[str, Any]) -> None:
"""Restore flow state from persistence.
Args:
@@ -735,7 +727,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
execution_data: Flow execution data containing:
- id: Flow execution ID
- flow: Flow structure
- completed_methods: List of successfully completed methods
- completed_methods: list of successfully completed methods
- execution_methods: All execution methods with their status
"""
flow_id = execution_data.get("id")
@@ -771,7 +763,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if state_to_apply:
self._apply_state_updates(state_to_apply)
for i, method in enumerate(sorted_methods[:-1]):
for method in sorted_methods[:-1]:
method_name = method.get("flow_method", {}).get("name")
if method_name:
self._completed_methods.add(method_name)
@@ -783,7 +775,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
elif hasattr(self._state, field_name):
object.__setattr__(self._state, field_name, value)
def _apply_state_updates(self, updates: Dict[str, Any]) -> None:
def _apply_state_updates(self, updates: dict[str, Any]) -> None:
"""Apply multiple state updates efficiently."""
if isinstance(self._state, dict):
self._state.update(updates)
@@ -792,7 +784,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if hasattr(self._state, key):
object.__setattr__(self._state, key, value)
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
def kickoff(self, inputs: dict[str, Any] | None = None) -> Any:
"""
Start the flow execution in a synchronous context.
@@ -805,7 +797,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
return asyncio.run(run_flow())
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any:
"""
Start the flow execution asynchronously.
@@ -840,7 +832,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self._state, dict):
self._state["id"] = inputs["id"]
elif isinstance(self._state, BaseModel):
setattr(self._state, "id", inputs["id"])
setattr(self._state, "id", inputs["id"]) # noqa: B010
# If persistence is enabled, attempt to restore the stored state using the provided id.
if "id" in inputs and self._persistence is not None:
@@ -1075,7 +1067,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
# Now execute normal listeners for all router results and the original trigger
all_triggers = [trigger_method] + router_results
all_triggers = [trigger_method, *router_results]
for current_trigger in all_triggers:
if current_trigger: # Skip None results
@@ -1109,7 +1101,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> List[str]:
) -> list[str]:
"""
Finds all methods that should be triggered based on conditions.
@@ -1126,7 +1118,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Returns
-------
List[str]
list[str]
Names of methods that should be triggered.
Notes

View File

@@ -1,10 +1,11 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.rag.types import SearchResult
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
@@ -13,23 +14,23 @@ class Knowledge(BaseModel):
"""
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
storage: KnowledgeStorage | None = Field(default=None)
embedder: dict[str, Any] | None = None
"""
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None
storage: KnowledgeStorage | None = Field(default=None)
embedder: dict[str, Any] | None = None
collection_name: str | None = None
def __init__(
self,
collection_name: str,
sources: List[BaseKnowledgeSource],
embedder: Optional[Dict[str, Any]] = None,
storage: Optional[KnowledgeStorage] = None,
sources: list[BaseKnowledgeSource],
embedder: dict[str, Any] | None = None,
storage: KnowledgeStorage | None = None,
**data,
):
super().__init__(**data)
@@ -40,11 +41,10 @@ class Knowledge(BaseModel):
embedder=embedder, collection_name=collection_name
)
self.sources = sources
self.storage.initialize_knowledge_storage()
def query(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> List[Dict[str, Any]]:
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
) -> list[SearchResult]:
"""
Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.
@@ -55,12 +55,11 @@ class Knowledge(BaseModel):
if self.storage is None:
raise ValueError("Storage is not initialized.")
results = self.storage.search(
return self.storage.search(
query,
limit=results_limit,
score_threshold=score_threshold,
)
return results
def add_sources(self):
try:

View File

@@ -9,8 +9,8 @@ class KnowledgeConfig(BaseModel):
score_threshold (float): The minimum score for a document to be considered relevant.
"""
results_limit: int = Field(default=3, description="The number of results to return")
results_limit: int = Field(default=5, description="The number of results to return")
score_threshold: float = Field(
default=0.35,
default=0.6,
description="The minimum score for a result to be considered relevant",
)

View File

@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
from crewai.rag.types import SearchResult
class BaseKnowledgeStorage(ABC):
@@ -8,22 +10,17 @@ class BaseKnowledgeStorage(ABC):
@abstractmethod
def search(
self,
query: List[str],
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
query: list[str],
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[SearchResult]:
"""Search for documents in the knowledge base."""
pass
@abstractmethod
def save(
self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
) -> None:
def save(self, documents: list[str]) -> None:
"""Save documents to the knowledge base."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the knowledge base."""
pass

View File

@@ -1,24 +1,17 @@
import hashlib
import logging
import os
import shutil
from typing import Any, Dict, List, Optional, Union
import chromadb
import chromadb.errors
from chromadb.api import ClientAPI
from chromadb.api.types import OneOrMany
from chromadb.config import Settings
import traceback
import warnings
from typing import Any, cast
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import sanitize_collection_name
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path
from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.logger_utils import suppress_logging
class KnowledgeStorage(BaseKnowledgeStorage):
@@ -27,167 +20,105 @@ class KnowledgeStorage(BaseKnowledgeStorage):
search efficiency.
"""
collection: Optional[chromadb.Collection] = None
collection_name: Optional[str] = "knowledge"
app: Optional[ClientAPI] = None
def __init__(
self,
embedder: Optional[Dict[str, Any]] = None,
collection_name: Optional[str] = None,
):
embedder: dict[str, Any] | None = None,
collection_name: str | None = None,
) -> None:
self.collection_name = collection_name
self._set_embedder_config(embedder)
self._client: BaseClient | None = None
def search(
self,
query: List[str],
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
if self.collection:
fetched = self.collection.query(
query_texts=query,
n_results=limit,
where=filter,
)
results = []
for i in range(len(fetched["ids"][0])): # type: ignore
result = {
"id": fetched["ids"][0][i], # type: ignore
"metadata": fetched["metadatas"][0][i], # type: ignore
"context": fetched["documents"][0][i], # type: ignore
"score": fetched["distances"][0][i], # type: ignore
}
if result["score"] >= score_threshold:
results.append(result)
return results
else:
raise Exception("Collection not initialized")
def initialize_knowledge_storage(self):
# Suppress deprecation warnings from chromadb, which are not relevant to us
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
self.app = create_persistent_client(
path=os.path.join(db_storage_path(), "knowledge"),
settings=Settings(allow_reset=True),
)
if embedder:
embedding_function = get_embedding_function(embedder)
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
self._client = create_client(config)
def _get_client(self) -> BaseClient:
"""Get the appropriate client - instance-specific or global."""
return self._client if self._client else get_rag_client()
def search(
self,
query: list[str],
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[SearchResult]:
try:
if not query:
raise ValueError("Query cannot be empty")
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
if self.app:
self.collection = self.app.get_or_create_collection(
name=sanitize_collection_name(collection_name),
embedding_function=self.embedder,
)
else:
raise Exception("Vector Database Client not initialized")
except Exception:
raise Exception("Failed to create or get collection")
query_text = " ".join(query) if len(query) > 1 else query[0]
def reset(self):
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
if not self.app:
self.app = create_persistent_client(
path=base_path, settings=Settings(allow_reset=True)
return client.search(
collection_name=collection_name,
query=query_text,
limit=limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)
self.app.reset()
shutil.rmtree(base_path)
self.app = None
self.collection = None
def save(
self,
documents: List[str],
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
):
if not self.collection:
raise Exception("Collection not initialized")
try:
# Create a dictionary to store unique documents
unique_docs = {}
# Generate IDs and create a mapping of id -> (document, metadata)
for idx, doc in enumerate(documents):
doc_id = hashlib.sha256(doc.encode("utf-8")).hexdigest()
doc_metadata = None
if metadata is not None:
if isinstance(metadata, list):
doc_metadata = metadata[idx]
else:
doc_metadata = metadata
unique_docs[doc_id] = (doc, doc_metadata)
# Prepare filtered lists for ChromaDB
filtered_docs = []
filtered_metadata = []
filtered_ids = []
# Build the filtered lists
for doc_id, (doc, meta) in unique_docs.items():
filtered_docs.append(doc)
filtered_metadata.append(meta)
filtered_ids.append(doc_id)
# If we have no metadata at all, set it to None
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
None if all(m is None for m in filtered_metadata) else filtered_metadata
)
self.collection.upsert(
documents=filtered_docs,
metadatas=final_metadata,
ids=filtered_ids,
)
except chromadb.errors.InvalidDimensionException as e:
Logger(verbose=True).log(
"error",
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
) from e
except Exception as e:
logging.error(
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
)
return []
def reset(self) -> None:
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
client.delete_collection(collection_name=collection_name)
except Exception as e:
logging.error(
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
)
def save(self, documents: list[str]) -> None:
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
client.get_or_create_collection(collection_name=collection_name)
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
client.add_documents(
collection_name=collection_name, documents=rag_documents
)
except Exception as e:
if "dimension mismatch" in str(e).lower():
Logger(verbose=True).log(
"error",
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
) from e
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
"""Set the embedding configuration for the knowledge storage.
Args:
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
If None or empty, defaults to the default embedding function.
"""
self.embedder = (
EmbeddingConfigurator().configure_embedder(embedder)
if embedder
else self._create_default_embedding_function()
)

View File

@@ -1,12 +1,12 @@
from typing import Any, Dict, List
from crewai.rag.types import SearchResult
def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str:
def extract_knowledge_context(knowledge_snippets: list[SearchResult]) -> str:
"""Extract knowledge from the task prompt."""
valid_snippets = [
result["context"]
result["content"]
for result in knowledge_snippets
if result and result.get("context")
if result and result.get("content")
]
snippet = "\n".join(valid_snippets)
return f"Additional Information: {snippet}" if valid_snippets else ""

View File

@@ -1,4 +1,6 @@
from typing import Optional, TYPE_CHECKING
from __future__ import annotations
from typing import TYPE_CHECKING
from crewai.memory import (
EntityMemory,
@@ -19,9 +21,9 @@ class ContextualMemory:
ltm: LongTermMemory,
em: EntityMemory,
exm: ExternalMemory,
agent: Optional["Agent"] = None,
task: Optional["Task"] = None,
):
agent: Agent | None = None,
task: Task | None = None,
) -> None:
self.stm = stm
self.ltm = ltm
self.em = em
@@ -42,7 +44,7 @@ class ContextualMemory:
self.exm.agent = self.agent
self.exm.task = self.task
def build_context_for_task(self, task, context) -> str:
def build_context_for_task(self, task: Task, context: str) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
@@ -52,14 +54,15 @@ class ContextualMemory:
if query == "":
return ""
context = []
context.append(self._fetch_ltm_context(task.description))
context.append(self._fetch_stm_context(query))
context.append(self._fetch_entity_context(query))
context.append(self._fetch_external_context(query))
return "\n".join(filter(None, context))
context_parts = [
self._fetch_ltm_context(task.description),
self._fetch_stm_context(query),
self._fetch_entity_context(query),
self._fetch_external_context(query),
]
return "\n".join(filter(None, context_parts))
def _fetch_stm_context(self, query) -> str:
def _fetch_stm_context(self, query: str) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points.
@@ -70,11 +73,11 @@ class ContextualMemory:
stm_results = self.stm.search(query)
formatted_results = "\n".join(
[f"- {result['context']}" for result in stm_results]
[f"- {result['content']}" for result in stm_results]
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]:
def _fetch_ltm_context(self, task: str) -> str | None:
"""
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
@@ -90,14 +93,14 @@ class ContextualMemory:
formatted_results = [
suggestion
for result in ltm_results
for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
for suggestion in result["metadata"]["suggestions"]
]
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str:
def _fetch_entity_context(self, query: str) -> str:
"""
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points.
@@ -107,7 +110,7 @@ class ContextualMemory:
em_results = self.em.search(query)
formatted_results = "\n".join(
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
[f"- {result['content']}" for result in em_results]
)
return f"Entities:\n{formatted_results}" if em_results else ""
@@ -128,6 +131,6 @@ class ContextualMemory:
return ""
formatted_memories = "\n".join(
f"- {result['context']}" for result in external_memories
f"- {result['content']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"

View File

@@ -1,20 +1,20 @@
from typing import Any
import time
from typing import Any
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class EntityMemory(Memory):
@@ -31,10 +31,10 @@ class EntityMemory(Memory):
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
except ImportError as e:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
) from e
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
@@ -90,23 +90,31 @@ class EntityMemory(Memory):
saved_count = 0
errors = []
def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
"""Save a single item and return success status."""
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super(EntityMemory, self).save(data, item.metadata)
return True, None
except Exception as e:
return False, f"{item.name}: {e!s}"
try:
for item in items:
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
success, error = save_single_item(item)
if success:
saved_count += 1
except Exception as e:
errors.append(f"{item.name}: {str(e)}")
else:
errors.append(error)
if is_batch:
emit_value = f"Saved {saved_count} entities"
@@ -153,8 +161,8 @@ class EntityMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,
@@ -206,4 +214,6 @@ class EntityMemory(Memory):
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the entity memory: {e}")
raise Exception(
f"An error occurred while resetting the entity memory: {e}"
) from e

View File

@@ -1,41 +1,41 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
import time
from typing import TYPE_CHECKING, Any
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.interface import Storage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
if TYPE_CHECKING:
from crewai.memory.storage.mem0_storage import Mem0Storage
class ExternalMemory(Memory):
def __init__(self, storage: Optional[Storage] = None, **data: Any):
def __init__(self, storage: Storage | None = None, **data: Any):
super().__init__(storage=storage, **data)
@staticmethod
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
@staticmethod
def external_supported_storages() -> Dict[str, Any]:
def external_supported_storages() -> dict[str, Any]:
return {
"mem0": ExternalMemory._configure_mem0,
}
@staticmethod
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
if not embedder_config:
raise ValueError("embedder_config is required")
@@ -52,7 +52,7 @@ class ExternalMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Saves a value into the external storage."""
crewai_event_bus.emit(
@@ -103,8 +103,8 @@ class ExternalMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel
@@ -12,8 +12,8 @@ class Memory(BaseModel):
Base class for memory, now supporting agent tags and generic metadata.
"""
embedder_config: Optional[Dict[str, Any]] = None
crew: Optional[Any] = None
embedder_config: dict[str, Any] | None = None
crew: Any | None = None
storage: Any
_agent: Optional["Agent"] = None
@@ -45,7 +45,7 @@ class Memory(BaseModel):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
metadata = metadata or {}
@@ -54,9 +54,9 @@ class Memory(BaseModel):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
)

View File

@@ -1,20 +1,20 @@
from typing import Any, Dict, Optional
import time
from typing import Any
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class ShortTermMemory(Memory):
@@ -26,17 +26,17 @@ class ShortTermMemory(Memory):
MemoryItem instances.
"""
_memory_provider: Optional[str] = PrivateAttr()
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
except ImportError as e:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
) from e
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
@@ -56,7 +56,7 @@ class ShortTermMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
crewai_event_bus.emit(
self,
@@ -112,8 +112,8 @@ class ShortTermMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,
@@ -167,4 +167,4 @@ class ShortTermMemory(Memory):
except Exception as e:
raise Exception(
f"An error occurred while resetting the short-term memory: {e}"
)
) from e

View File

@@ -1,12 +1,13 @@
import os
import re
from collections import defaultdict
from typing import Any, Iterable
from collections.abc import Iterable
from typing import Any
from mem0 import Memory, MemoryClient # type: ignore[import-untyped]
from mem0 import Memory, MemoryClient # type: ignore[import-untyped,import-not-found]
from crewai.memory.storage.interface import Storage
from crewai.utilities.chromadb import sanitize_collection_name
from crewai.rag.chromadb.utils import _sanitize_collection_name
MAX_AGENT_ID_LENGTH_MEM0 = 255
@@ -15,6 +16,7 @@ class Mem0Storage(Storage):
"""
Extends Storage to handle embedding and searching across entities using Mem0.
"""
def __init__(self, type, crew=None, config=None):
super().__init__()
@@ -30,7 +32,8 @@ class Mem0Storage(Storage):
supported_types = {"short_term", "long_term", "entities", "external"}
if type not in supported_types:
raise ValueError(
f"Invalid type '{type}' for Mem0Storage. Must be one of: {', '.join(supported_types)}"
f"Invalid type '{type}' for Mem0Storage. "
f"Must be one of: {', '.join(supported_types)}"
)
def _extract_config_values(self):
@@ -68,7 +71,8 @@ class Mem0Storage(Storage):
- Includes user_id and agent_id if both are present.
- Includes user_id if only user_id is present.
- Includes agent_id if only agent_id is present.
- Includes run_id if memory_type is 'short_term' and mem0_run_id is present.
- Includes run_id if memory_type is 'short_term' and
mem0_run_id is present.
"""
filter = defaultdict(list)
@@ -91,10 +95,14 @@ class Mem0Storage(Storage):
def save(self, value: Any, metadata: dict[str, Any]) -> None:
def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str:
return next(
(m.get("content", "") for m in reversed(list(messages)) if m.get("role") == role),
""
(
m.get("content", "")
for m in reversed(list(messages))
if m.get("role") == role
),
"",
)
conversations = []
messages = metadata.pop("messages", None)
if messages:
@@ -103,7 +111,7 @@ class Mem0Storage(Storage):
if user_msg := self._get_user_message(last_user):
conversations.append({"role": "user", "content": user_msg})
if assistant_msg := self._get_assistant_message(last_assistant):
conversations.append({"role": "assistant", "content": assistant_msg})
else:
@@ -115,13 +123,13 @@ class Mem0Storage(Storage):
"short_term": "short_term",
"long_term": "long_term",
"entities": "entity",
"external": "external"
"external": "external",
}
# Shared base params
params: dict[str, Any] = {
"metadata": {"type": base_metadata[self.memory_type], **metadata},
"infer": self.infer
"infer": self.infer,
}
# MemoryClient-specific overrides
@@ -142,13 +150,15 @@ class Mem0Storage(Storage):
self.memory.add(conversations, **params)
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> list[Any]:
def search(
self, query: str, limit: int = 5, score_threshold: float = 0.6
) -> list[Any]:
params = {
"query": query,
"limit": limit,
"version": "v2",
"output_format": "v1.1"
}
"output_format": "v1.1",
}
if user_id := self.config.get("user_id", ""):
params["user_id"] = user_id
@@ -169,10 +179,10 @@ class Mem0Storage(Storage):
# automatically when the crew is created.
params["filters"] = self._create_filter_for_search()
params['threshold'] = score_threshold
params["threshold"] = score_threshold
if isinstance(self.memory, Memory):
del params["metadata"], params["version"], params['output_format']
del params["metadata"], params["version"], params["output_format"]
if params.get("run_id"):
del params["run_id"]
@@ -180,7 +190,7 @@ class Mem0Storage(Storage):
# This makes it compatible for Contextual Memory to retrieve
for result in results["results"]:
result["context"] = result["memory"]
result["content"] = result["memory"]
return [r for r in results["results"]]
@@ -201,7 +211,9 @@ class Mem0Storage(Storage):
agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
return _sanitize_collection_name(
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
)
def _get_assistant_message(self, text: str) -> str:
marker = "Final Answer:"

View File

@@ -1,17 +1,17 @@
import logging
import os
import shutil
import uuid
import traceback
import warnings
from typing import Any
from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import create_persistent_client
from crewai.rag.types import BaseRecord
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
from crewai.utilities.logger_utils import suppress_logging
import warnings
class RAGStorage(BaseRAGStorage):
@@ -20,8 +20,6 @@ class RAGStorage(BaseRAGStorage):
search efficiency.
"""
app: ClientAPI | None = None
def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
):
@@ -33,37 +31,25 @@ class RAGStorage(BaseRAGStorage):
self.storage_file_name = self._build_storage_file_name(type, agents)
self.type = type
self._client: BaseClient | None = None
self.allow_reset = allow_reset
self.path = path
self._initialize_app()
def _set_embedder_config(self):
configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self):
from chromadb.config import Settings
# Suppress deprecation warnings from chromadb, which are not relevant to us
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
self._set_embedder_config()
if self.embedder_config:
embedding_function = get_embedding_function(self.embedder_config)
config = ChromaDBConfig(embedding_function=embedding_function)
self._client = create_client(config)
self.app = create_persistent_client(
path=self.path if self.path else self.storage_file_name,
settings=Settings(allow_reset=self.allow_reset),
)
self.collection = self.app.get_or_create_collection(
name=self.type, embedding_function=self.embedder_config
)
logging.info(f"Collection found or created: {self.collection}")
def _get_client(self) -> BaseClient:
"""Get the appropriate client - instance-specific or global."""
return self._client if self._client else get_rag_client()
def _sanitize_role(self, role: str) -> str:
"""
@@ -85,77 +71,69 @@ class RAGStorage(BaseRAGStorage):
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
def save(self, value: Any, metadata: dict[str, Any]) -> None:
try:
self._generate_embedding(value, metadata)
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
client.get_or_create_collection(collection_name=collection_name)
document: BaseRecord = {"content": value}
if metadata:
document["metadata"] = metadata
client.add_documents(collection_name=collection_name, documents=[document])
except Exception as e:
logging.error(f"Error during {self.type} save: {str(e)}")
logging.error(
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
)
def search(
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
if not hasattr(self, "app"):
self._initialize_app()
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
try:
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
response = self.collection.query(query_texts=query, n_results=limit)
results = []
for i in range(len(response["ids"][0])):
result = {
"id": response["ids"][0][i],
"metadata": response["metadatas"][0][i],
"context": response["documents"][0][i],
"score": response["distances"][0][i],
}
if result["score"] >= score_threshold:
results.append(result)
return results
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
return client.search(
collection_name=collection_name,
query=query,
limit=limit,
metadata_filter=filter,
score_threshold=score_threshold,
)
except Exception as e:
logging.error(f"Error during {self.type} search: {str(e)}")
logging.error(
f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
)
return []
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
self.collection.add(
documents=[text],
metadatas=[metadata or {}],
ids=[str(uuid.uuid4())],
)
def reset(self) -> None:
try:
if self.app:
self.app.reset()
shutil.rmtree(f"{db_storage_path()}/{self.type}")
self.app = None
self.collection = None
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
client.delete_collection(collection_name=collection_name)
except Exception as e:
if "attempt to write a readonly database" in str(e):
# Ignore this specific error
if "attempt to write a readonly database" in str(
e
) or "does not exist" in str(e):
# Ignore readonly database and collection not found errors (already reset)
pass
else:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
) from e

View File

@@ -4,8 +4,9 @@ import logging
from typing import Any
from chromadb.api.types import (
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction,
)
from chromadb.api.types import (
QueryResult,
)
from typing_extensions import Unpack
@@ -23,13 +24,13 @@ from crewai.rag.chromadb.utils import (
_process_query_results,
_sanitize_collection_name,
)
from crewai.utilities.logger_utils import suppress_logging
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
BaseCollectionParams,
)
from crewai.rag.types import SearchResult
from crewai.utilities.logger_utils import suppress_logging
class ChromaDBClient(BaseClient):
@@ -41,21 +42,29 @@ class ChromaDBClient(BaseClient):
Attributes:
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
embedding_function: Function to generate embeddings for documents.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
def __init__(
self,
client: ChromaDBClientType,
embedding_function: ChromaEmbeddingFunction[Embeddable],
embedding_function: ChromaEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
) -> None:
"""Initialize ChromaDBClient with client and embedding function.
Args:
client: Pre-configured ChromaDB client instance.
embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -300,16 +309,18 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = self.client.get_collection(
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=prepared.metadatas,
metadatas=metadatas,
)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
@@ -342,15 +353,17 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = await self.client.get_collection(
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
await collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=prepared.metadatas,
metadatas=metadatas,
)
def search(
@@ -385,9 +398,14 @@ class ChromaDBClient(BaseClient):
"Use asearch() for AsyncClientAPI."
)
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs)
collection = self.client.get_collection(
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
@@ -443,9 +461,14 @@ class ChromaDBClient(BaseClient):
"Use search() for ClientAPI."
)
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs)
collection = await self.client.get_collection(
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)

View File

@@ -1,20 +1,20 @@
"""ChromaDB configuration model."""
import os
import warnings
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig
from chromadb.config import Settings
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.chromadb.constants import (
DEFAULT_TENANT,
DEFAULT_DATABASE,
DEFAULT_STORAGE_PATH,
DEFAULT_TENANT,
)
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig
warnings.filterwarnings(
"ignore",
@@ -49,7 +49,17 @@ def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
Returns:
Default embedding function using all-MiniLM-L6-v2 via ONNX.
"""
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return cast(
ChromaEmbeddingFunctionWrapper,
OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="text-embedding-3-small",
),
)
@pyd_dataclass(frozen=True)

View File

@@ -2,11 +2,12 @@
import os
from hashlib import md5
import portalocker
from chromadb import PersistentClient
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.client import ChromaDBClient
from crewai.rag.chromadb.config import ChromaDBConfig
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
@@ -23,6 +24,7 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
"""
persist_dir = config.settings.persist_directory
os.makedirs(persist_dir, exist_ok=True)
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
@@ -37,4 +39,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
return ChromaDBClient(
client=client,
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
)

View File

@@ -3,27 +3,28 @@
from collections.abc import Mapping
from typing import Any, NamedTuple
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from chromadb.api import ClientAPI, AsyncClientAPI
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.configuration import CollectionConfigurationInterface
from chromadb.api.types import (
CollectionMetadata,
DataLoader,
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction,
Include,
Loadable,
Where,
WhereDocument,
)
from chromadb.api.types import (
EmbeddingFunction as ChromaEmbeddingFunction,
)
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
ChromaDBClientType = ClientAPI | AsyncClientAPI
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]):
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
@classmethod
@@ -44,7 +45,7 @@ class PreparedDocuments(NamedTuple):
Attributes:
ids: List of document IDs
texts: List of document texts
metadatas: List of document metadata mappings
metadatas: List of document metadata mappings (empty dict for no metadata)
"""
ids: list[str]
@@ -85,7 +86,7 @@ class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
configuration: CollectionConfigurationInterface
metadata: CollectionMetadata
embedding_function: ChromaEmbeddingFunction[Embeddable]
embedding_function: ChromaEmbeddingFunction
data_loader: DataLoader[Loadable]
get_or_create: bool

View File

@@ -5,13 +5,14 @@ from collections.abc import Mapping
from typing import Literal, TypeGuard, cast
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
Include,
IncludeEnum,
QueryResult,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from crewai.rag.chromadb.constants import (
DEFAULT_COLLECTION,
INVALID_CHARS_PATTERN,
@@ -78,7 +79,7 @@ def _prepare_documents_for_chromadb(
metadata = doc.get("metadata")
if metadata:
if isinstance(metadata, list):
metadatas.append(metadata[0] if metadata else {})
metadatas.append(metadata[0] if metadata and metadata[0] else {})
else:
metadatas.append(metadata)
else:
@@ -132,6 +133,9 @@ def _convert_distance_to_score(
if distance_metric == "cosine":
score = 1.0 - 0.5 * distance
return max(0.0, min(1.0, score))
if distance_metric == "l2":
score = 1.0 / (1.0 + distance)
return max(0.0, min(1.0, score))
raise ValueError(f"Unsupported distance metric: {distance_metric}")
@@ -154,7 +158,7 @@ def _convert_chromadb_results_to_search_results(
"""
search_results: list[SearchResult] = []
include_strings = [item.value for item in include]
include_strings = [item.value for item in include] if include else []
ids = results["ids"][0] if results.get("ids") else []
@@ -188,7 +192,9 @@ def _convert_chromadb_results_to_search_results(
result: SearchResult = {
"id": doc_id,
"content": documents[i] if documents and i < len(documents) else "",
"metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {},
"metadata": dict(metadatas[i])
if metadatas and i < len(metadatas) and metadatas[i] is not None
else {},
"score": score,
}
search_results.append(result)
@@ -271,7 +277,7 @@ def _sanitize_collection_name(
sanitized = sanitized[:-1] + "z"
if len(sanitized) < MIN_COLLECTION_LENGTH:
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
sanitized += "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
if len(sanitized) > max_collection_length:
sanitized = sanitized[:max_collection_length]
if not sanitized[-1].isalnum():

View File

@@ -14,3 +14,5 @@ class BaseRagConfig:
provider: SupportedProvider = field(init=False)
embedding_function: Any | None = field(default=None)
limit: int = field(default=5)
score_threshold: float = field(default=0.6)

View File

@@ -1,15 +1,15 @@
"""Protocol for vector database client implementations."""
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable, Annotated
from typing_extensions import Unpack, Required, TypedDict
from typing import Annotated, Any, Protocol, runtime_checkable
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from typing_extensions import Required, TypedDict, Unpack
from crewai.rag.types import (
EmbeddingFunction,
BaseRecord,
EmbeddingFunction,
SearchResult,
)
@@ -57,7 +57,7 @@ class BaseCollectionSearchParams(BaseCollectionParams, total=False):
query: Required[str]
limit: int
metadata_filter: dict[str, Any]
metadata_filter: dict[str, Any] | None
score_threshold: float

View File

@@ -10,8 +10,8 @@ from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GooglePalmEmbeddingFunction,
GoogleGenerativeAiEmbeddingFunction,
GooglePalmEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
@@ -60,7 +60,7 @@ def get_embedding_function(
EmbeddingFunction instance ready for use with ChromaDB
Supported providers:
- openai: OpenAI embeddings (default)
- openai: OpenAI embeddings
- cohere: Cohere embeddings
- ollama: Ollama local embeddings
- huggingface: HuggingFace embeddings
@@ -77,7 +77,7 @@ def get_embedding_function(
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
Examples:
# Use default OpenAI with retry logic
# Use default OpenAI embedding
>>> embedder = get_embedding_function()
# Use Cohere with dict

View File

@@ -6,8 +6,8 @@ from typing_extensions import Unpack
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
BaseCollectionParams,
BaseCollectionSearchParams,
)
from crewai.rag.core.exceptions import ClientMethodMismatchError
@@ -18,11 +18,11 @@ from crewai.rag.qdrant.types import (
QdrantCollectionCreateParams,
)
from crewai.rag.qdrant.utils import (
_create_point_from_document,
_get_collection_params,
_is_async_client,
_is_async_embedding_function,
_is_sync_client,
_create_point_from_document,
_get_collection_params,
_prepare_search_params,
_process_search_results,
)
@@ -38,21 +38,29 @@ class QdrantClient(BaseClient):
Attributes:
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
embedding_function: Function to generate embeddings for documents.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
def __init__(
self,
client: QdrantClientType,
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
) -> None:
"""Initialize QdrantClient with client and embedding function.
Args:
client: Pre-configured Qdrant client instance.
embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
"""Create a new collection in Qdrant.
@@ -332,9 +340,9 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
limit = kwargs.get("limit", self.default_limit)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
@@ -387,9 +395,9 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
limit = kwargs.get("limit", self.default_limit)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
if not await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")

View File

@@ -1,6 +1,7 @@
"""Factory functions for creating Qdrant clients from configuration."""
from qdrant_client import QdrantClient as SyncQdrantClientBase
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
@@ -17,5 +18,8 @@ def create_client(config: QdrantConfig) -> QdrantClient:
qdrant_client = SyncQdrantClientBase(**config.options)
return QdrantClient(
client=qdrant_client, embedding_function=config.embedding_function
client=qdrant_client,
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
)

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
class BaseRAGStorage(ABC):
@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: Optional[Dict[str, Any]] = None,
embedder_config: dict[str, Any] | None = None,
crew: Any = None,
):
self.type = type
@@ -32,45 +32,21 @@ class BaseRAGStorage(ABC):
@abstractmethod
def _sanitize_role(self, role: str) -> str:
"""Sanitizes agent roles to ensure valid directory names."""
pass
@abstractmethod
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value with metadata to the storage."""
pass
@abstractmethod
def search(
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for entries in the storage."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the storage."""
pass
@abstractmethod
def _generate_embedding(
self, text: str, metadata: Optional[Dict[str, Any]] = None
) -> Any:
"""Generate an embedding for the given text and metadata."""
pass
@abstractmethod
def _initialize_app(self):
"""Initialize the vector db."""
pass
def setup_config(self, config: Dict[str, Any]):
"""Setup the config of the storage."""
pass
def initialize_client(self):
"""Initialize the client of the storage. This should setup the app and the db collection"""
pass

View File

@@ -1,83 +0,0 @@
import os
import re
import portalocker
from chromadb import PersistentClient
from hashlib import md5
from typing import Optional
from crewai.utilities.paths import db_storage_path
MIN_COLLECTION_LENGTH = 3
MAX_COLLECTION_LENGTH = 63
DEFAULT_COLLECTION = "default_collection"
# Compiled regex patterns for better performance
INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
def is_ipv4_pattern(name: str) -> bool:
"""
Check if a string matches an IPv4 address pattern.
Args:
name: The string to check
Returns:
True if the string matches an IPv4 pattern, False otherwise
"""
return bool(IPV4_PATTERN.match(name))
def sanitize_collection_name(
name: Optional[str], max_collection_length: int = MAX_COLLECTION_LENGTH
) -> str:
"""
Sanitize a collection name to meet ChromaDB requirements:
1. 3-63 characters long
2. Starts and ends with alphanumeric character
3. Contains only alphanumeric characters, underscores, or hyphens
4. No consecutive periods
5. Not a valid IPv4 address
Args:
name: The original collection name to sanitize
Returns:
A sanitized collection name that meets ChromaDB requirements
"""
if not name:
return DEFAULT_COLLECTION
if is_ipv4_pattern(name):
name = f"ip_{name}"
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
if not sanitized[0].isalnum():
sanitized = "a" + sanitized
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
if len(sanitized) < MIN_COLLECTION_LENGTH:
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
if len(sanitized) > max_collection_length:
sanitized = sanitized[:max_collection_length]
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
return sanitized
def create_persistent_client(path: str, **kwargs):
"""
Creates a persistent client for ChromaDB with a lock file to prevent
concurrent creations. Works for both multi-threads and multi-processes
environments.
"""
lock_id = md5(path.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(db_storage_path(), f"chromadb-{lock_id}.lock")
with portalocker.Lock(lockfile):
client = PersistentClient(path=path, **kwargs)
return client

View File

@@ -9,19 +9,19 @@ import pytest
from crewai import Agent, Crew, Task
from crewai.agents.cache import CacheHandler
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.llm import LLM
from crewai.process import Process
from crewai.tools import tool
from crewai.tools.tool_calling import InstructorToolCalling
from crewai.tools.tool_usage import ToolUsage
from crewai.utilities import RPMController
from crewai.utilities.errors import AgentRepositoryError
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.process import Process
def test_agent_llm_creation_with_env_vars():
@@ -137,35 +137,6 @@ def test_custom_llm():
assert agent.llm.model == "gpt-4"
def test_custom_llm_with_langchain():
from langchain_openai import ChatOpenAI
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
llm=ChatOpenAI(temperature=0, model="gpt-4"),
)
assert agent.llm.model == "gpt-4"
def test_custom_llm_temperature_preservation():
from langchain_openai import ChatOpenAI
langchain_llm = ChatOpenAI(temperature=0.7, model="gpt-4")
agent = Agent(
role="temperature test role",
goal="temperature test goal",
backstory="temperature test backstory",
llm=langchain_llm,
)
assert isinstance(agent.llm, LLM)
assert agent.llm.model == "gpt-4"
assert agent.llm.temperature == 0.7
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execution():
agent = Agent(
@@ -445,7 +416,7 @@ def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_powered_by_new_o_model_family_that_uses_tool():
@tool
def comapny_customer_data() -> float:
def comapny_customer_data() -> str:
"""Useful for getting customer related data."""
return "The company has 42 customers"
@@ -500,6 +471,15 @@ def test_agent_custom_max_iterations():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_repeated_tool_usage(capsys):
"""Test that agents handle repeated tool usage appropriately.
Notes:
Investigate whether to pin down the specific execution flow by examining
src/crewai/agents/crew_agent_executor.py:177-186 (max iterations check)
and src/crewai/tools/tool_usage.py:152-157 (repeated usage detection)
to ensure deterministic behavior.
"""
@tool
def get_final_answer() -> float:
"""Get the final answer but don't give it yet, just re-use this tool non-stop."""
@@ -527,41 +507,15 @@ def test_agent_repeated_tool_usage(capsys):
)
captured = capsys.readouterr()
output = (
captured.out.replace("\n", " ")
.replace(" ", " ")
.strip()
.replace("", "")
.replace("", "")
.replace("", "")
.replace("", "")
.replace("", "")
.replace("", "")
.replace("[", "")
.replace("]", "")
.replace("bold", "")
.replace("blue", "")
.replace("yellow", "")
.replace("green", "")
.replace("red", "")
.replace("dim", "")
.replace("🤖", "")
.replace("🔧", "")
.replace("", "")
.replace("\x1b[93m", "")
.replace("\x1b[00m", "")
.replace("\\", "")
.replace('"', "")
.replace("'", "")
)
output_lower = captured.out.lower()
# Look for the message in the normalized output, handling the apostrophe difference
expected_message = (
"I tried reusing the same input, I must stop using this action input."
has_repeated_usage_message = "tried reusing the same input" in output_lower
has_max_iterations = "maximum iterations reached" in output_lower
has_final_answer = "final answer" in output_lower or "42" in captured.out
assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
)
assert (
expected_message in output
), f"Expected message not found in output. Output was: {output}"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -602,9 +556,9 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
has_max_iterations = "maximum iterations reached" in output_lower
has_final_answer = "final answer" in output_lower or "42" in captured.out
assert (
has_repeated_usage_message or (has_max_iterations and has_final_answer)
), f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -880,7 +834,7 @@ def test_agent_step_callback():
with patch.object(StepCallback, "callback") as callback:
@tool
def learn_about_AI() -> str:
def learn_about_ai() -> str:
"""Useful for when you need to learn about AI to write an paragraph about it."""
return "AI is a very broad field."
@@ -888,7 +842,7 @@ def test_agent_step_callback():
role="test role",
goal="test goal",
backstory="test backstory",
tools=[learn_about_AI],
tools=[learn_about_ai],
step_callback=StepCallback().callback,
)
@@ -910,7 +864,7 @@ def test_agent_function_calling_llm():
llm = "gpt-4o"
@tool
def learn_about_AI() -> str:
def learn_about_ai() -> str:
"""Useful for when you need to learn about AI to write an paragraph about it."""
return "AI is a very broad field."
@@ -918,7 +872,7 @@ def test_agent_function_calling_llm():
role="test role",
goal="test goal",
backstory="test backstory",
tools=[learn_about_AI],
tools=[learn_about_ai],
llm="gpt-4o",
max_iter=2,
function_calling_llm=llm,
@@ -1356,7 +1310,7 @@ def test_agent_training_handler(crew_training_handler):
verbose=True,
)
crew_training_handler().load.return_value = {
f"{str(agent.id)}": {"0": {"human_feedback": "good"}}
f"{agent.id!s}": {"0": {"human_feedback": "good"}}
}
result = agent._training_handler(task_prompt=task_prompt)
@@ -1473,7 +1427,7 @@ def test_agent_with_custom_stop_words():
)
assert isinstance(agent.llm, LLM)
assert set(agent.llm.stop) == set(stop_words + ["\nObservation:"])
assert set(agent.llm.stop) == set([*stop_words, "\nObservation:"])
assert all(word in agent.llm.stop for word in stop_words)
assert "\nObservation:" in agent.llm.stop
@@ -1530,7 +1484,7 @@ def test_llm_call_with_error():
llm = LLM(model="non-existent-model")
messages = [{"role": "user", "content": "This should fail"}]
with pytest.raises(Exception):
with pytest.raises(Exception): # noqa: B017
llm.call(messages)
@@ -1830,11 +1784,11 @@ def test_agent_execute_task_with_ollama():
def test_agent_with_knowledge_sources():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
with patch("crewai.knowledge") as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
with patch("crewai.knowledge") as mock_knowledge:
mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.search.return_value = [{"content": content}]
MockKnowledge.add_sources.return_value = [string_source]
mock_knowledge.add_sources.return_value = [string_source]
agent = Agent(
role="Information Agent",
@@ -1863,12 +1817,25 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig(results_limit=10, score_threshold=0.5)
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
with (
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
with patch.object(Knowledge, "query") as mock_knowledge_query:
agent = Agent(
role="Information Agent",
@@ -1898,15 +1865,27 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_defau
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig()
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
with (
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
with patch.object(Knowledge, "query") as mock_knowledge_query:
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig()
agent = Agent(
role="Information Agent",
goal="Provide information based on knowledge sources",
@@ -1935,10 +1914,16 @@ def test_agent_with_knowledge_sources_extensive_role():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
with patch("crewai.knowledge") as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
with (
patch("crewai.knowledge") as mock_knowledge,
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage.save"
) as mock_save,
):
mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
mock_save.return_value = None
agent = Agent(
role="Information Agent with extensive role description that is longer than 80 characters",
@@ -1968,8 +1953,8 @@ def test_agent_with_knowledge_sources_works_with_copy():
with patch(
"crewai.knowledge.source.base_knowledge_source.BaseKnowledgeSource",
autospec=True,
) as MockKnowledgeSource:
mock_knowledge_source_instance = MockKnowledgeSource.return_value
) as mock_knowledge_source:
mock_knowledge_source_instance = mock_knowledge_source.return_value
mock_knowledge_source_instance.__class__ = BaseKnowledgeSource
mock_knowledge_source_instance.sources = [string_source]
@@ -1983,9 +1968,9 @@ def test_agent_with_knowledge_sources_works_with_copy():
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as MockKnowledgeStorage:
mock_knowledge_storage = MockKnowledgeStorage.return_value
agent.knowledge_storage = mock_knowledge_storage
) as mock_knowledge_storage:
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
agent.knowledge_storage = mock_knowledge_storage_instance
agent_copy = agent.copy()
@@ -2004,11 +1989,30 @@ def test_agent_with_knowledge_sources_generate_search_query():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
with patch("crewai.knowledge") as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
with (
patch("crewai.knowledge") as mock_knowledge,
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
agent = Agent(
role="Information Agent with extensive role description that is longer than 80 characters",
goal="Provide information based on knowledge sources",
@@ -2270,7 +2274,26 @@ def test_get_knowledge_search_query():
i18n = I18N()
task_prompt = task.prompt()
with patch.object(agent, "_get_knowledge_search_query") as mock_get_query:
with (
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
patch.object(agent, "_get_knowledge_search_query") as mock_get_query,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
mock_get_query.return_value = "Capital of France"
crew = Crew(agents=[agent], tasks=[task])
@@ -2309,13 +2332,11 @@ def mock_get_auth_token():
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
# Mock embedchain initialization to prevent race conditions in parallel CI execution
with patch("embedchain.client.Client.setup"):
from crewai_tools import (
SerperDevTool,
FileReadTool,
EnterpriseActionTool,
)
from crewai_tools import (
EnterpriseActionTool,
FileReadTool,
SerperDevTool,
)
mock_get_response = MagicMock()
mock_get_response.status_code = 200
@@ -2347,7 +2368,7 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
tool_action = EnterpriseActionTool(
name="test_name",
description="test_description",
enterprise_action_token="test_token",
enterprise_action_token="test_token", # noqa: S106
action_name="test_action_name",
action_schema={"test": "test"},
)
@@ -2371,9 +2392,7 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth_token):
# Mock embedchain initialization to prevent race conditions in parallel CI execution
with patch("embedchain.client.Client.setup"):
from crewai_tools import SerperDevTool
from crewai_tools import SerperDevTool
mock_get_response = MagicMock()
mock_get_response.status_code = 200

View File

@@ -0,0 +1,130 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
personal goal is: Test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
the world\n\nThis is the expected criteria for your final answer: hello world\nyou
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
This is VERY important to you, use the tools available and give your best Final
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '825'
content-type:
- application/json
cookie:
- _cfuvid=NaXWifUGChHp6Ap1mvfMrNzmO4HdzddrqXkSR9T.hYo-1754508545647-0.0.1.1-604800000
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.93.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.93.0
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.9
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFLBbtswDL37Kzid4yFx46bxbVixtsfssB22wlAl2lEri5okJ+uK/Psg
OY3dtQV2MWA+vqf3SD5lAExJVgETWx5EZ3X++SrcY/Hrcec3l+SKP5frm/16Yx92m6/f9mwWGXR3
jyI8sz4K6qzGoMgMsHDIA0bVxaq8WCzPyrN5AjqSqCOttSFfUt4po/JiXizz+SpfXBzZW1ICPavg
RwYA8JS+0aeR+JtVkLRSpUPveYusOjUBMEc6Vhj3XvnATWCzERRkAppk/QYM7UFwA63aIXBoo23g
xu/RAfw0X5ThGj6l/wquUWuawXdyWn6YSjpses9jLNNrPQG4MRR4HEsKc3tEDif7mlrr6M7/Q2WN
Mspva4fck4lWfSDLEnrIAG7TmPoXyZl11NlQB3rA9NyiXA16bNzOFD2CgQLXk/qqmL2hV0sMXGk/
GTQTXGxRjtRxK7yXiiZANkn92s1b2kNyZdr/kR8BIdAGlLV1KJV4mXhscxiP972205STYebR7ZTA
Oih0cRMSG97r4aSYf/QBu7pRpkVnnRruqrF1eT7nzTmW5Zplh+wvAAAA//8DAGKunMhlAwAA
headers:
CF-RAY:
- 980b99a73c1c22c6-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 17 Sep 2025 21:12:11 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=Ahwkw3J9CDiluZudRgDmybz4FO07eXLz2MQDtkgfct4-1758143531-1.0.1.1-_3e8agfTZW.FPpRMLb1A2nET4OHQEGKNZeGeWT8LIiuSi8R2HWsGsJyueUyzYBYnfHqsfBUO16K1.TkEo2XiqVCaIi6pymeeQxwtXFF1wj8;
path=/; expires=Wed, 17-Sep-25 21:42:11 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=iHqLoc_2sNQLMyzfGCLtGol8vf1Y44xirzQJUuUF_TI-1758143531242-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '419'
openai-project:
- proj_xitITlrFeen7zjNSzML82h9x
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '609'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-tokens:
- '150000000'
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-project-tokens:
- '149999827'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999830'
x-ratelimit-reset-project-tokens:
- 0s
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_ece5f999e09e4c189d38e5bc08b2fad9
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,128 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
personal goal is: Test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
the world\n\nThis is the expected criteria for your final answer: hello world\nyou
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
This is VERY important to you, use the tools available and give your best Final
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '825'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.93.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.93.0
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.9
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFJNj9MwEL3nV4x8blBSmrabG0ICViqCCydYRbPOJDHreIztbIFV/zty
0m1SPiQukTJv3vN7M/OUAAhVixKE7DDI3ur09dvgfa8P/FO+e/9wa/aHb4cPH9viEw5yK1aRwfdf
SYZn1gvJvdUUFJsJlo4wUFTNd8U+32zybD0CPdekI621Id1w2iuj0nW23qTZLs33Z3bHSpIXJXxO
AACexm/0aWr6LkrIVs+VnrzHlkR5aQIQjnWsCPRe+YAmiNUMSjaBzGj9FgwfQaKBVj0SILTRNqDx
R3IAX8wbZVDDq/G/hI60Zjiy0/VS0FEzeIyhzKD1AkBjOGAcyhjl7oycLuY1t9bxvf+NKhpllO8q
R+jZRKM+sBUjekoA7sYhDVe5hXXc21AFfqDxubzYTXpi3s0CfXkGAwfUi/ruPNprvaqmgEr7xZiF
RNlRPVPnneBQK14AySL1n27+pj0lV6b9H/kZkJJsoLqyjmolrxPPbY7i6f6r7TLl0bDw5B6VpCoo
cnETNTU46OmghP/hA/VVo0xLzjo1XVVjq2KbYbOlorgRySn5BQAA//8DALxsmCBjAwAA
headers:
CF-RAY:
- 980ba79a4ab5f555-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 17 Sep 2025 21:21:42 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=aMMf0fLckKHz0BLW_2lATxD.7R61uYo1ZVW8aeFbruA-1758144102-1.0.1.1-6EKM3UxpdczoiQ6VpPpqqVnY7ftnXndFRWE4vyTzVcy.CQ4N539D97Wh8Ye9EUAvpUuukhW.r5MznkXq4tPXgCCmEv44RvVz2GBAz_e31h8;
path=/; expires=Wed, 17-Sep-25 21:51:42 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=VqrtvU8.QdEHc4.1XXUVmccaCcoj_CiNfI2zhKJoGRs-1758144102566-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '308'
openai-project:
- proj_xitITlrFeen7zjNSzML82h9x
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '620'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-tokens:
- '150000000'
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-project-tokens:
- '149999827'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999830'
x-ratelimit-reset-project-tokens:
- 0s
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_fa896433021140238115972280c05651
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,127 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
personal goal is: Test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Test task\n\nThis
is the expected criteria for your final answer: test output\nyou MUST return
the actual complete content as the final answer, not a summary.\n\nBegin! This
is VERY important to you, use the tools available and give your best Final Answer,
your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop": ["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '812'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.93.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.93.0
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.9
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFLLbtswELzrKxY8W4WV+JHoVgR95NJD4UvaBgJDrSS2FJclV3bSwP9e
kHYsuU2BXghwZ2c4s8vnDEDoWpQgVCdZ9c7kNx+4v7vb7jafnrrPX25/vtObX48f1Q31m+uFmEUG
PXxHxS+sN4p6Z5A12QOsPErGqFqsl1fF4nJdzBPQU40m0lrH+YLyXludX8wvFvl8nRdXR3ZHWmEQ
JXzNAACe0xl92hofRQlJK1V6DEG2KMpTE4DwZGJFyBB0YGlZzEZQkWW0yfotWNqBkhZavUWQ0Ebb
IG3YoQf4Zt9rKw28TfcSNhgYaGA3nAl6bIYgYyg7GDMBpLXEMg4lRbk/IvuTeUOt8/QQ/qCKRlsd
usqjDGSj0cDkREL3GcB9GtJwlls4T73jiukHpueK5eKgJ8bdTNDLI8jE0kzqq/XsFb2qRpbahMmY
hZKqw3qkjjuRQ61pAmST1H+7eU37kFzb9n/kR0ApdIx15TzWWp0nHts8xq/7r7bTlJNhEdBvtcKK
Nfq4iRobOZjD/kV4Cox91WjbondeH35V46rlai6bFS6X1yLbZ78BAAD//wMAZdfoWWMDAAA=
headers:
CF-RAY:
- 980b9e0c5fa516a0-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 17 Sep 2025 21:15:11 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=w6UZxbAZgYg9EFkKPfrSbMK97MB4jfs7YyvcEmgkvak-1758143711-1.0.1.1-j7YC1nvoMKxYK0T.5G2XDF6TXUCPu_HUs4YO9v65r3NHQFIcOaHbQXX4vqabSgynL2tZy23pbZgD8Cdmxhdw9dp4zkAXhU.imP43_pw4dSE;
path=/; expires=Wed, 17-Sep-25 21:45:11 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=ij9Q8tB7sj2GczANlJ7gbXVjj6hMhz1iVb6oGHuRYu8-1758143711202-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '462'
openai-project:
- proj_xitITlrFeen7zjNSzML82h9x
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '665'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-tokens:
- '150000000'
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-project-tokens:
- '149999830'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999830'
x-ratelimit-reset-project-tokens:
- 0s
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_04536db97c8c4768a200e38c1368c176
status:
code: 200
message: OK
version: 1

View File

@@ -1,7 +1,6 @@
"""Test Knowledge creation and querying functionality."""
from pathlib import Path
from typing import List, Union
from unittest.mock import patch
import pytest
@@ -23,7 +22,7 @@ def mock_vector_db():
instance = mock.return_value
instance.query.return_value = [
{
"context": "Brandon's favorite color is blue and he likes Mexican food.",
"content": "Brandon's favorite color is blue and he likes Mexican food.",
"score": 0.9,
}
]
@@ -44,13 +43,13 @@ def test_single_short_string(mock_vector_db):
content=content, metadata={"preference": "personal"}
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite color?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("blue" in result["context"].lower() for result in results)
assert any("blue" in result["content"].lower() for result in results)
# Verify the mock was called
mock_vector_db.query.assert_called_once()
@@ -84,14 +83,14 @@ def test_single_2k_character_string(mock_vector_db):
content=content, metadata={"preference": "personal"}
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite movie?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("inception" in result["context"].lower() for result in results)
assert any("inception" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -109,7 +108,7 @@ def test_multiple_short_strings(mock_vector_db):
# Mock the vector db query response
mock_vector_db.query.return_value = [
{"context": "Brandon has a dog named Max.", "score": 0.9}
{"content": "Brandon has a dog named Max.", "score": 0.9}
]
mock_vector_db.sources = string_sources
@@ -119,7 +118,7 @@ def test_multiple_short_strings(mock_vector_db):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("max" in result["context"].lower() for result in results)
assert any("max" in result["content"].lower() for result in results)
# Verify the mock was called
mock_vector_db.query.assert_called_once()
@@ -180,7 +179,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
]
mock_vector_db.sources = string_sources
mock_vector_db.query.return_value = [{"context": contents[1], "score": 0.9}]
mock_vector_db.query.return_value = [{"content": contents[1], "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite book?"
@@ -188,7 +187,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
# Assert that the correct information is retrieved
assert any(
"the hitchhiker's guide to the galaxy" in result["context"].lower()
"the hitchhiker's guide to the galaxy" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -205,13 +204,13 @@ def test_single_short_file(mock_vector_db, tmpdir):
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What sport does Brandon like?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("basketball" in result["context"].lower() for result in results)
assert any("basketball" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -247,13 +246,13 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite movie?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("inception" in result["context"].lower() for result in results)
assert any("inception" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -286,13 +285,13 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
]
mock_vector_db.sources = file_sources
mock_vector_db.query.return_value = [
{"context": "Brandon lives in New York.", "score": 0.9}
{"content": "Brandon lives in New York.", "score": 0.9}
]
# Perform a query
query = "What city does he reside in?"
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("new york" in result["context"].lower() for result in results)
assert any("new york" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -360,7 +359,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
mock_vector_db.sources = file_sources
mock_vector_db.query.return_value = [
{
"context": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
"content": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
"score": 0.9,
}
]
@@ -370,7 +369,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
# Assert that the correct information is retrieved
assert any(
"the hitchhiker's guide to the galaxy" in result["context"].lower()
"the hitchhiker's guide to the galaxy" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -407,14 +406,14 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
# Combine string and file sources
mock_vector_db.sources = string_sources + file_sources
mock_vector_db.query.return_value = [{"context": file_contents[1], "score": 0.9}]
mock_vector_db.query.return_value = [{"content": file_contents[1], "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite book?"
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("the alchemist" in result["context"].lower() for result in results)
assert any("the alchemist" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -430,7 +429,7 @@ def test_pdf_knowledge_source(mock_vector_db):
)
mock_vector_db.sources = [pdf_source]
mock_vector_db.query.return_value = [
{"context": "crewai create crew latest-ai-development", "score": 0.9}
{"content": "crewai create crew latest-ai-development", "score": 0.9}
]
# Perform a query
@@ -439,7 +438,7 @@ def test_pdf_knowledge_source(mock_vector_db):
# Assert that the correct information is retrieved
assert any(
"crewai create crew latest-ai-development" in result["context"].lower()
"crewai create crew latest-ai-development" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -467,7 +466,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [csv_source]
mock_vector_db.query.return_value = [
{"context": "Brandon is 30 years old.", "score": 0.9}
{"content": "Brandon is 30 years old.", "score": 0.9}
]
# Perform a query
@@ -475,7 +474,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("30" in result["context"] for result in results)
assert any("30" in result["content"] for result in results)
mock_vector_db.query.assert_called_once()
@@ -502,7 +501,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [json_source]
mock_vector_db.query.return_value = [
{"context": "Alice lives in Los Angeles.", "score": 0.9}
{"content": "Alice lives in Los Angeles.", "score": 0.9}
]
# Perform a query
@@ -510,7 +509,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("los angeles" in result["context"].lower() for result in results)
assert any("los angeles" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -518,7 +517,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
"""Test ExcelKnowledgeSource with a simple Excel file."""
# Create an Excel file with sample data
import pandas as pd
import pandas as pd # type: ignore[import-untyped]
excel_data = {
"Name": ["Brandon", "Alice", "Bob"],
@@ -535,7 +534,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [excel_source]
mock_vector_db.query.return_value = [
{"context": "Brandon is 30 years old.", "score": 0.9}
{"content": "Brandon is 30 years old.", "score": 0.9}
]
# Perform a query
@@ -543,7 +542,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("30" in result["context"] for result in results)
assert any("30" in result["content"] for result in results)
mock_vector_db.query.assert_called_once()
@@ -557,20 +556,20 @@ def test_docling_source(mock_vector_db):
mock_vector_db.sources = [docling_source]
mock_vector_db.query.return_value = [
{
"context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
"content": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
"score": 0.9,
}
]
# Perform a query
query = "What is reward hacking?"
results = mock_vector_db.query(query)
assert any("reward hacking" in result["context"].lower() for result in results)
assert any("reward hacking" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@pytest.mark.vcr
def test_multiple_docling_sources():
urls: List[Union[Path, str]] = [
def test_multiple_docling_sources() -> None:
urls: list[Path | str] = [
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
"https://lilianweng.github.io/posts/2024-07-07-hallucination/",
]

View File

@@ -0,0 +1,191 @@
"""Tests for Knowledge SearchResult type conversion and integration."""
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.knowledge import Knowledge # type: ignore[import-untyped]
from crewai.knowledge.source.string_knowledge_source import ( # type: ignore[import-untyped]
StringKnowledgeSource,
)
from crewai.knowledge.utils.knowledge_utils import ( # type: ignore[import-untyped]
extract_knowledge_context,
)
def test_knowledge_query_returns_searchresult() -> None:
"""Test that Knowledge.query returns SearchResult format."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.return_value = [
{
"content": "AI is fascinating",
"score": 0.9,
"metadata": {"source": "doc1"},
},
{
"content": "Machine learning rocks",
"score": 0.8,
"metadata": {"source": "doc2"},
},
]
sources = [StringKnowledgeSource(content="Test knowledge content")]
knowledge = Knowledge(collection_name="test_collection", sources=sources)
results = knowledge.query(
["AI technology"], results_limit=5, score_threshold=0.3
)
mock_storage.search.assert_called_once_with(
["AI technology"], limit=5, score_threshold=0.3
)
assert isinstance(results, list)
assert len(results) == 2
for result in results:
assert isinstance(result, dict)
assert "content" in result
assert "score" in result
assert "metadata" in result
assert results[0]["content"] == "AI is fascinating"
assert results[0]["score"] == 0.9
assert results[1]["content"] == "Machine learning rocks"
assert results[1]["score"] == 0.8
def test_knowledge_query_with_empty_results() -> None:
"""Test Knowledge.query with empty search results."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.return_value = []
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="empty_test", sources=sources)
results = knowledge.query(["nonexistent query"])
assert isinstance(results, list)
assert len(results) == 0
def test_extract_knowledge_context_with_searchresult() -> None:
"""Test extract_knowledge_context works with SearchResult format."""
search_results = [
{"content": "Python is great for AI", "score": 0.95, "metadata": {}},
{"content": "Machine learning algorithms", "score": 0.88, "metadata": {}},
{"content": "Deep learning frameworks", "score": 0.82, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert "Additional Information:" in context
assert "Python is great for AI" in context
assert "Machine learning algorithms" in context
assert "Deep learning frameworks" in context
expected_content = (
"Python is great for AI\nMachine learning algorithms\nDeep learning frameworks"
)
assert expected_content in context
def test_extract_knowledge_context_with_empty_content() -> None:
"""Test extract_knowledge_context handles empty or invalid content."""
search_results = [
{"content": "", "score": 0.5, "metadata": {}},
{"content": None, "score": 0.4, "metadata": {}},
{"score": 0.3, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert context == ""
def test_extract_knowledge_context_filters_invalid_results() -> None:
"""Test that extract_knowledge_context filters out invalid results."""
search_results: list[dict[str, Any] | None] = [
{"content": "Valid content 1", "score": 0.9, "metadata": {}},
{"content": "", "score": 0.8, "metadata": {}},
{"content": "Valid content 2", "score": 0.7, "metadata": {}},
None,
{"content": None, "score": 0.6, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert "Additional Information:" in context
assert "Valid content 1" in context
assert "Valid content 2" in context
assert context.count("\n") == 1
@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_storage_exception_handling(
mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test Knowledge handles storage exceptions gracefully."""
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.side_effect = Exception("Storage error")
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="error_test", sources=sources)
with pytest.raises(ValueError, match="Storage is not initialized"):
knowledge.storage = None
knowledge.query(["test query"])
def test_knowledge_add_sources_integration() -> None:
"""Test Knowledge.add_sources integrates properly with storage."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
sources = [
StringKnowledgeSource(content="Content 1"),
StringKnowledgeSource(content="Content 2"),
]
knowledge = Knowledge(collection_name="add_sources_test", sources=sources)
knowledge.add_sources()
for source in sources:
assert source.storage == mock_storage
def test_knowledge_reset_integration() -> None:
"""Test Knowledge.reset integrates with storage."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="reset_test", sources=sources)
knowledge.reset()
mock_storage.reset.assert_called_once()
@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_reset_without_storage(
mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test Knowledge.reset raises error when storage is None."""
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="no_storage_test", sources=sources)
knowledge.storage = None
with pytest.raises(ValueError, match="Storage is not initialized"):
knowledge.reset()

View File

@@ -0,0 +1,196 @@
"""Integration tests for KnowledgeStorage RAG client migration."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
KnowledgeStorage,
)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.create_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_knowledge_storage_uses_rag_client(
mock_get_embedding: MagicMock,
mock_create_client: MagicMock,
mock_get_client: MagicMock,
) -> None:
"""Test that KnowledgeStorage properly integrates with RAG client."""
mock_client = MagicMock()
mock_create_client.return_value = mock_client
mock_get_client.return_value = mock_client
mock_client.search.return_value = [
{"content": "test content", "score": 0.9, "metadata": {"source": "test"}}
]
embedder_config = {"provider": "openai", "model": "text-embedding-3-small"}
storage = KnowledgeStorage(
embedder=embedder_config, collection_name="test_knowledge"
)
mock_create_client.assert_called_once()
results = storage.search(["test query"], limit=5, score_threshold=0.3)
mock_get_client.assert_not_called()
mock_client.search.assert_called_once_with(
collection_name="knowledge_test_knowledge",
query="test query",
limit=5,
metadata_filter=None,
score_threshold=0.3,
)
assert isinstance(results, list)
assert len(results) == 1
assert isinstance(results[0], dict)
assert "content" in results[0]
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_collection_name_prefixing(mock_get_client: MagicMock) -> None:
"""Test that collection names are properly prefixed."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage(collection_name="custom_knowledge")
storage.search(["test"], limit=1)
mock_client.search.assert_called_once()
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["collection_name"] == "knowledge_custom_knowledge"
mock_client.reset_mock()
storage_default = KnowledgeStorage()
storage_default.search(["test"], limit=1)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["collection_name"] == "knowledge"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_save_documents_integration(mock_get_client: MagicMock) -> None:
"""Test document saving through RAG client."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
storage = KnowledgeStorage(collection_name="test_docs")
documents = ["Document 1 content", "Document 2 content"]
storage.save(documents)
mock_client.get_or_create_collection.assert_called_once_with(
collection_name="knowledge_test_docs"
)
mock_client.add_documents.assert_called_once()
call_kwargs = mock_client.add_documents.call_args.kwargs
added_docs = call_kwargs["documents"]
assert len(added_docs) == 2
assert added_docs[0]["content"] == "Document 1 content"
assert added_docs[1]["content"] == "Document 2 content"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_reset_integration(mock_get_client: MagicMock) -> None:
"""Test collection reset through RAG client."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
storage = KnowledgeStorage(collection_name="test_reset")
storage.reset()
mock_client.delete_collection.assert_called_once_with(
collection_name="knowledge_test_reset"
)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_search_error_handling(mock_get_client: MagicMock) -> None:
"""Test error handling during search operations."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = Exception("RAG client error")
storage = KnowledgeStorage(collection_name="error_test")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_embedding_configuration_flow(
mock_get_embedding: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test that embedding configuration flows properly to RAG client."""
mock_embedding_func = MagicMock()
mock_get_embedding.return_value = mock_embedding_func
mock_get_client.return_value = MagicMock()
embedder_config = {
"provider": "sentence-transformer",
"model_name": "all-MiniLM-L6-v2",
}
KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
mock_get_embedding.assert_called_once_with(embedder_config)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_query_list_conversion(mock_get_client: MagicMock) -> None:
"""Test that query list is properly converted to string."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage()
storage.search(["single query"])
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["query"] == "single query"
mock_client.reset_mock()
storage.search(["query one", "query two"])
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["query"] == "query one query two"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_metadata_filter_handling(mock_get_client: MagicMock) -> None:
"""Test metadata filter parameter handling."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage()
metadata_filter = {"category": "technical", "priority": "high"}
storage.search(["test"], metadata_filter=metadata_filter)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["metadata_filter"] == metadata_filter
mock_client.reset_mock()
storage.search(["test"], metadata_filter=None)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["metadata_filter"] is None
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None:
"""Test specific handling of dimension mismatch errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.get_or_create_collection.return_value = None
mock_client.add_documents.side_effect = Exception("dimension mismatch detected")
storage = KnowledgeStorage(collection_name="dimension_test")
with pytest.raises(ValueError, match="Embedding dimension mismatch"):
storage.save(["test document"])

View File

@@ -1,19 +1,20 @@
from unittest.mock import patch, ANY
from collections import defaultdict
from unittest.mock import ANY, patch
import pytest
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.task import Task
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
)
@pytest.fixture
@@ -38,22 +39,23 @@ def short_term_memory():
def test_short_term_memory_search_events(short_term_memory):
events = defaultdict(list)
with crewai_event_bus.scoped_handlers():
with patch("crewai.rag.chromadb.client.ChromaDBClient.search", return_value=[]):
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
# Call the save method
short_term_memory.search(
query="test value",
limit=3,
score_threshold=0.35,
)
# Call the save method
short_term_memory.search(
query="test value",
limit=3,
score_threshold=0.35,
)
assert len(events["MemoryQueryStartedEvent"]) == 1
assert len(events["MemoryQueryCompletedEvent"]) == 1
@@ -173,12 +175,12 @@ def test_save_and_search(short_term_memory):
expected_result = [
{
"context": memory.data,
"content": memory.data,
"metadata": {"agent": "test_agent"},
"score": 0.95,
}
]
with patch.object(ShortTermMemory, "search", return_value=expected_result):
find = short_term_memory.search("test value", score_threshold=0.01)[0]
assert find["context"] == memory.data, "Data value mismatch."
assert find["content"] == memory.data, "Data value mismatch."
assert find["metadata"]["agent"] == "test_agent", "Agent value mismatch."

View File

@@ -236,7 +236,7 @@ class TestChromaDBClient:
def test_add_documents(self, client, mock_chromadb_client) -> None:
"""Test that add_documents adds documents to collection."""
mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{
@@ -247,7 +247,7 @@ class TestChromaDBClient:
client.add_documents(collection_name="test_collection", documents=documents)
mock_chromadb_client.get_collection.assert_called_once_with(
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=client.embedding_function,
)
@@ -262,7 +262,7 @@ class TestChromaDBClient:
def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
"""Test add_documents with custom document IDs."""
mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{
@@ -285,6 +285,43 @@ class TestChromaDBClient:
metadatas=[{"source": "test1"}, {"source": "test2"}],
)
def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
"""Test add_documents with documents that have no metadata."""
mock_collection = Mock()
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"content": "Document without metadata"},
{"content": "Another document", "metadata": None},
{"content": "Document with metadata", "metadata": {"key": "value"}},
]
client.add_documents(collection_name="test_collection", documents=documents)
# Verify upsert was called with empty dicts for missing metadata
mock_collection.upsert.assert_called_once()
call_args = mock_collection.upsert.call_args
assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]
def test_add_documents_all_without_metadata(
self, client, mock_chromadb_client
) -> None:
"""Test add_documents when all documents have no metadata."""
mock_collection = Mock()
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"content": "Document 1"},
{"content": "Document 2"},
{"content": "Document 3"},
]
client.add_documents(collection_name="test_collection", documents=documents)
mock_collection.upsert.assert_called_once()
call_args = mock_collection.upsert.call_args
assert call_args[1]["metadatas"] is None
def test_add_documents_empty_list_raises_error(
self, client, mock_chromadb_client
) -> None:
@@ -298,7 +335,7 @@ class TestChromaDBClient:
) -> None:
"""Test that aadd_documents adds documents to collection asynchronously."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
@@ -313,7 +350,7 @@ class TestChromaDBClient:
collection_name="test_collection", documents=documents
)
mock_async_chromadb_client.get_collection.assert_called_once_with(
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=async_client.embedding_function,
)
@@ -331,7 +368,7 @@ class TestChromaDBClient:
) -> None:
"""Test aadd_documents with custom document IDs."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
@@ -358,6 +395,31 @@ class TestChromaDBClient:
metadatas=[{"source": "test1"}, {"source": "test2"}],
)
@pytest.mark.asyncio
async def test_aadd_documents_without_metadata(
self, async_client, mock_async_chromadb_client
) -> None:
"""Test aadd_documents with documents that have no metadata."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
documents: list[BaseRecord] = [
{"content": "Document without metadata"},
{"content": "Another document", "metadata": None},
{"content": "Document with metadata", "metadata": {"key": "value"}},
]
await async_client.aadd_documents(
collection_name="test_collection", documents=documents
)
# Verify upsert was called with empty dicts for missing metadata
mock_collection.upsert.assert_called_once()
call_args = mock_collection.upsert.call_args
assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]
@pytest.mark.asyncio
async def test_aadd_documents_empty_list_raises_error(
self, async_client, mock_async_chromadb_client
@@ -372,7 +434,7 @@ class TestChromaDBClient:
"""Test that search queries the collection correctly."""
mock_collection = Mock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
mock_collection.query.return_value = {
"ids": [["doc1", "doc2"]],
"documents": [["Document 1", "Document 2"]],
@@ -382,13 +444,13 @@ class TestChromaDBClient:
results = client.search(collection_name="test_collection", query="test query")
mock_chromadb_client.get_collection.assert_called_once_with(
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=client.embedding_function,
)
mock_collection.query.assert_called_once_with(
query_texts=["test query"],
n_results=10,
n_results=5,
where=None,
where_document=None,
include=["metadatas", "documents", "distances"],
@@ -404,7 +466,7 @@ class TestChromaDBClient:
"""Test search with optional parameters."""
mock_collection = Mock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
mock_collection.query.return_value = {
"ids": [["doc1", "doc2", "doc3"]],
"documents": [["Document 1", "Document 2", "Document 3"]],
@@ -437,7 +499,7 @@ class TestChromaDBClient:
"""Test that asearch queries the collection correctly."""
mock_collection = AsyncMock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
mock_collection.query = AsyncMock(
@@ -453,13 +515,13 @@ class TestChromaDBClient:
collection_name="test_collection", query="test query"
)
mock_async_chromadb_client.get_collection.assert_called_once_with(
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=async_client.embedding_function,
)
mock_collection.query.assert_called_once_with(
query_texts=["test query"],
n_results=10,
n_results=5,
where=None,
where_document=None,
include=["metadatas", "documents", "distances"],
@@ -478,7 +540,7 @@ class TestChromaDBClient:
"""Test asearch with optional parameters."""
mock_collection = AsyncMock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
mock_collection.query = AsyncMock(

View File

@@ -0,0 +1,95 @@
"""Tests for ChromaDB utility functions."""
from crewai.rag.chromadb.utils import (
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
_is_ipv4_pattern,
_sanitize_collection_name,
)
class TestChromaDBUtils:
"""Test suite for ChromaDB utility functions."""
def test_sanitize_collection_name_long_name(self) -> None:
"""Test sanitizing a very long collection name."""
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
sanitized = _sanitize_collection_name(long_name)
assert len(sanitized) <= MAX_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_sanitize_collection_name_special_chars(self) -> None:
"""Test sanitizing a name with special characters."""
special_chars = "Agent@123!#$%^&*()"
sanitized = _sanitize_collection_name(special_chars)
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_sanitize_collection_name_short_name(self) -> None:
"""Test sanitizing a very short name."""
short_name = "A"
sanitized = _sanitize_collection_name(short_name)
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
def test_sanitize_collection_name_bad_ends(self) -> None:
"""Test sanitizing a name with non-alphanumeric start/end."""
bad_ends = "_Agent_"
sanitized = _sanitize_collection_name(bad_ends)
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
def test_sanitize_collection_name_none(self) -> None:
"""Test sanitizing a None value."""
sanitized = _sanitize_collection_name(None)
assert sanitized == "default_collection"
def test_sanitize_collection_name_ipv4_pattern(self) -> None:
"""Test sanitizing an IPv4 address."""
ipv4 = "192.168.1.1"
sanitized = _sanitize_collection_name(ipv4)
assert sanitized.startswith("ip_")
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_is_ipv4_pattern(self) -> None:
"""Test IPv4 pattern detection."""
assert _is_ipv4_pattern("192.168.1.1") is True
assert _is_ipv4_pattern("not.an.ip.address") is False
def test_sanitize_collection_name_properties(self) -> None:
"""Test that sanitized collection names always meet ChromaDB requirements."""
test_cases: list[str] = [
"A" * 100, # Very long name
"_start_with_underscore",
"end_with_underscore_",
"contains@special#characters",
"192.168.1.1", # IPv4 address
"a" * 2, # Too short
]
for test_case in test_cases:
sanitized = _sanitize_collection_name(test_case)
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert len(sanitized) <= MAX_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
def test_sanitize_collection_name_empty_string(self) -> None:
"""Test sanitizing an empty string."""
sanitized = _sanitize_collection_name("")
assert sanitized == "default_collection"
def test_sanitize_collection_name_whitespace_only(self) -> None:
"""Test sanitizing a string with only whitespace."""
sanitized = _sanitize_collection_name(" ")
assert (
sanitized == "a__z"
) # Spaces become underscores, padded to meet requirements
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()

View File

@@ -0,0 +1,250 @@
"""Enhanced tests for embedding function factory."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
get_embedding_function,
)
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
def test_get_embedding_function_default() -> None:
"""Test default embedding function when no config provided."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
with patch(
"crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
):
result = get_embedding_function()
mock_openai.assert_called_once_with(
api_key="test-api-key", model_name="text-embedding-3-small"
)
assert result == mock_instance
def test_get_embedding_function_with_embedding_options() -> None:
"""Test embedding function creation with EmbeddingOptions object."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
options = EmbeddingOptions(
provider="openai", api_key="test-key", model="text-embedding-3-large"
)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
# OpenAI uses model_name parameter, not model
assert result == mock_instance
def test_get_embedding_function_sentence_transformer() -> None:
"""Test sentence transformer embedding function."""
with patch(
"crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction"
) as mock_st:
mock_instance = MagicMock()
mock_st.return_value = mock_instance
config = {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"}
result = get_embedding_function(config)
mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2")
assert result == mock_instance
def test_get_embedding_function_ollama() -> None:
"""Test Ollama embedding function."""
with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as mock_ollama:
mock_instance = MagicMock()
mock_ollama.return_value = mock_instance
config = {
"provider": "ollama",
"model_name": "nomic-embed-text",
"url": "http://localhost:11434",
}
result = get_embedding_function(config)
mock_ollama.assert_called_once_with(
model_name="nomic-embed-text", url="http://localhost:11434"
)
assert result == mock_instance
def test_get_embedding_function_cohere() -> None:
"""Test Cohere embedding function."""
with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as mock_cohere:
mock_instance = MagicMock()
mock_cohere.return_value = mock_instance
config = {
"provider": "cohere",
"api_key": "cohere-key",
"model_name": "embed-english-v3.0",
}
result = get_embedding_function(config)
mock_cohere.assert_called_once_with(
api_key="cohere-key", model_name="embed-english-v3.0"
)
assert result == mock_instance
def test_get_embedding_function_huggingface() -> None:
"""Test HuggingFace embedding function."""
with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as mock_hf:
mock_instance = MagicMock()
mock_hf.return_value = mock_instance
config = {
"provider": "huggingface",
"api_key": "hf-token",
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
}
result = get_embedding_function(config)
mock_hf.assert_called_once_with(
api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
)
assert result == mock_instance
def test_get_embedding_function_onnx() -> None:
"""Test ONNX embedding function."""
with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as mock_onnx:
mock_instance = MagicMock()
mock_onnx.return_value = mock_instance
config = {"provider": "onnx"}
result = get_embedding_function(config)
mock_onnx.assert_called_once()
assert result == mock_instance
def test_get_embedding_function_google_palm() -> None:
"""Test Google PaLM embedding function."""
with patch(
"crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction"
) as mock_palm:
mock_instance = MagicMock()
mock_palm.return_value = mock_instance
config = {"provider": "google-palm", "api_key": "palm-key"}
result = get_embedding_function(config)
mock_palm.assert_called_once_with(api_key="palm-key")
assert result == mock_instance
def test_get_embedding_function_amazon_bedrock() -> None:
"""Test Amazon Bedrock embedding function."""
with patch(
"crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction"
) as mock_bedrock:
mock_instance = MagicMock()
mock_bedrock.return_value = mock_instance
config = {
"provider": "amazon-bedrock",
"region_name": "us-west-2",
"model_name": "amazon.titan-embed-text-v1",
}
result = get_embedding_function(config)
mock_bedrock.assert_called_once_with(
region_name="us-west-2", model_name="amazon.titan-embed-text-v1"
)
assert result == mock_instance
def test_get_embedding_function_jina() -> None:
"""Test Jina embedding function."""
with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as mock_jina:
mock_instance = MagicMock()
mock_jina.return_value = mock_instance
config = {
"provider": "jina",
"api_key": "jina-key",
"model_name": "jina-embeddings-v2-base-en",
}
result = get_embedding_function(config)
mock_jina.assert_called_once_with(
api_key="jina-key", model_name="jina-embeddings-v2-base-en"
)
assert result == mock_instance
def test_get_embedding_function_unsupported_provider() -> None:
"""Test handling of unsupported provider."""
config = {"provider": "unsupported-provider"}
with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
get_embedding_function(config)
def test_get_embedding_function_config_modification() -> None:
"""Test that original config dict is not modified."""
original_config = {
"provider": "openai",
"api_key": "test-key",
"model": "text-embedding-3-small",
}
config_copy = original_config.copy()
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
get_embedding_function(config_copy)
assert config_copy == original_config
def test_get_embedding_function_exclude_none_values() -> None:
"""Test that None values are excluded from embedding function calls."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
options = EmbeddingOptions(provider="openai", api_key="test-key", model=None)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
assert "model" not in call_kwargs
assert result == mock_instance
def test_get_embedding_function_instructor() -> None:
"""Test Instructor embedding function."""
with patch(
"crewai.rag.embeddings.factory.InstructorEmbeddingFunction"
) as mock_instructor:
mock_instance = MagicMock()
mock_instructor.return_value = mock_instance
config = {"provider": "instructor", "model_name": "hkunlp/instructor-large"}
result = get_embedding_function(config)
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
assert result == mock_instance

View File

@@ -3,7 +3,8 @@
from unittest.mock import AsyncMock, Mock
import pytest
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client import AsyncQdrantClient
from qdrant_client import QdrantClient as SyncQdrantClient
from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.qdrant.client import QdrantClient
@@ -435,7 +436,7 @@ class TestQdrantClient:
call_args = mock_qdrant_client.query_points.call_args
assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
assert call_args.kwargs["limit"] == 10
assert call_args.kwargs["limit"] == 5
assert call_args.kwargs["with_payload"] is True
assert call_args.kwargs["with_vectors"] is False
@@ -540,7 +541,7 @@ class TestQdrantClient:
call_args = mock_async_qdrant_client.query_points.call_args
assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
assert call_args.kwargs["limit"] == 10
assert call_args.kwargs["limit"] == 5
assert call_args.kwargs["with_payload"] is True
assert call_args.kwargs["with_vectors"] is False

View File

@@ -0,0 +1,218 @@
"""Tests for RAG client error handling scenarios."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
KnowledgeStorage,
)
from crewai.memory.storage.rag_storage import RAGStorage # type: ignore[import-untyped]
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_connection_failure(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles RAG client connection failures."""
mock_get_client.side_effect = ConnectionError("Unable to connect to ChromaDB")
storage = KnowledgeStorage(collection_name="connection_test")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_search_timeout(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles search timeouts gracefully."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = TimeoutError("Search operation timed out")
storage = KnowledgeStorage(collection_name="timeout_test")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_collection_not_found(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles missing collections."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = ValueError(
"Collection 'knowledge_missing' does not exist"
)
storage = KnowledgeStorage(collection_name="missing_collection")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles invalid embedding configurations."""
mock_get_client.return_value = MagicMock()
with patch(
"crewai.knowledge.storage.knowledge_storage.get_embedding_function"
) as mock_get_embedding:
mock_get_embedding.side_effect = ValueError(
"Unsupported provider: invalid_provider"
)
with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
KnowledgeStorage(
embedder={"provider": "invalid_provider"},
collection_name="invalid_embedding_test",
)
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_client_failure(mock_get_client: MagicMock) -> None:
"""Test RAGStorage handles RAG client failures in memory operations."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = RuntimeError("ChromaDB server error")
storage = RAGStorage("short_term", crew=None)
results = storage.search("test query")
assert results == []
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_save_failure(mock_get_client: MagicMock) -> None:
"""Test RAGStorage handles save operation failures."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.add_documents.side_effect = Exception("Failed to add documents")
storage = RAGStorage("long_term", crew=None)
storage.save("test memory", {"key": "value"})
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_readonly_database(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage reset handles readonly database errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.delete_collection.side_effect = Exception(
"attempt to write a readonly database"
)
storage = KnowledgeStorage(collection_name="readonly_test")
storage.reset()
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_collection_does_not_exist(
mock_get_client: MagicMock,
) -> None:
"""Test KnowledgeStorage reset handles non-existent collections."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.delete_collection.side_effect = Exception("Collection does not exist")
storage = KnowledgeStorage(collection_name="nonexistent_test")
storage.reset()
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_reset_failure_propagation(mock_get_client: MagicMock) -> None:
"""Test RAGStorage reset propagates unexpected errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.delete_collection.side_effect = Exception("Unexpected database error")
storage = RAGStorage("entities", crew=None)
with pytest.raises(
Exception, match="An error occurred while resetting the entities memory"
):
storage.reset()
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_malformed_search_results(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles malformed search results."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = [
{"content": "valid result", "metadata": {"source": "test"}},
{"invalid": "missing content field", "metadata": {"source": "test"}},
None,
{"content": None, "metadata": {"source": "test"}},
]
storage = KnowledgeStorage(collection_name="malformed_test")
results = storage.search(["test query"])
assert isinstance(results, list)
assert len(results) == 4
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_network_interruption(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles network interruptions during operations."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = [
ConnectionError("Network interruption"),
[{"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}],
]
storage = KnowledgeStorage(collection_name="network_test")
first_attempt = storage.search(["test query"])
assert first_attempt == []
mock_client.search.side_effect = None
mock_client.search.return_value = [
{"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}
]
second_attempt = storage.search(["test query"])
assert len(second_attempt) == 1
assert second_attempt[0]["content"] == "recovered result"
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_collection_creation_failure(mock_get_client: MagicMock) -> None:
"""Test RAGStorage handles collection creation failures."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.get_or_create_collection.side_effect = Exception(
"Failed to create collection"
)
storage = RAGStorage("user_memory", crew=None)
storage.save("test data", {"metadata": "test"})
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_embedding_dimension_mismatch_detailed(
mock_get_client: MagicMock,
) -> None:
"""Test detailed handling of embedding dimension mismatch errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.get_or_create_collection.return_value = None
mock_client.add_documents.side_effect = Exception(
"Embedding dimension mismatch: expected 384, got 1536"
)
storage = KnowledgeStorage(collection_name="dimension_detailed_test")
with pytest.raises(ValueError) as exc_info:
storage.save(["test document"])
assert "Embedding dimension mismatch" in str(exc_info.value)
assert "Make sure you're using the same embedding model" in str(exc_info.value)
assert "crewai reset-memories -a" in str(exc_info.value)

View File

@@ -1,8 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from mem0.client.main import MemoryClient
from mem0.memory.main import Memory
from mem0 import Memory, MemoryClient
from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -13,6 +12,67 @@ class MockCrew:
self.agents = [MagicMock(role="Test Agent")]
# Test data constants
SYSTEM_CONTENT = (
"You are Friendly chatbot assistant. You are a kind and "
"knowledgeable chatbot assistant. You excel at understanding user needs, "
"providing helpful responses, and maintaining engaging conversations. "
"You remember previous interactions to provide a personalized experience.\n"
"Your personal goal is: Engage in useful and interesting conversations "
"with users while remembering context.\n"
"To give my best complete final answer to the task respond using the exact "
"following format:\n\n"
"Thought: I now can give a great answer\n"
"Final Answer: Your final answer must be the great and the most complete "
"as possible, it must be outcome described.\n\n"
"I MUST use these formats, my job depends on it!"
)
USER_CONTENT = (
"\nCurrent Task: Respond to user conversation. User message: "
"What do you know about me?\n\n"
"This is the expected criteria for your final answer: Contextually "
"appropriate, helpful, and friendly response.\n"
"you MUST return the actual complete content as the final answer, "
"not a summary.\n\n"
"# Useful context: \nExternal memories:\n"
"- User is from India\n"
"- User is interested in the solar system\n"
"- User name is Vidit Ostwal\n"
"- User is interested in French cuisine\n\n"
"Begin! This is VERY important to you, use the tools available and give "
"your best Final Answer, your job depends on it!\n\n"
"Thought:"
)
ASSISTANT_CONTENT = (
"I now can give a great answer \n"
"Final Answer: Hi Vidit! From our previous conversations, I know you're "
"from India and have a great interest in the solar system. It's fascinating "
"to explore the wonders of space, isn't it? Also, I remember you have a "
"passion for French cuisine, which has so many delightful dishes to explore. "
"If there's anything specific you'd like to discuss or learn about—whether "
"it's about the solar system or some great French recipes—feel free to let "
"me know! I'm here to help."
)
TEST_DESCRIPTION = (
"Respond to user conversation. User message: What do you know about me?"
)
# Extracted content (after processing by _get_user_message and _get_assistant_message)
EXTRACTED_USER_CONTENT = "What do you know about me?"
EXTRACTED_ASSISTANT_CONTENT = (
"Hi Vidit! From our previous conversations, I know you're "
"from India and have a great interest in the solar system. It's fascinating "
"to explore the wonders of space, isn't it? Also, I remember you have a "
"passion for French cuisine, which has so many delightful dishes to explore. "
"If there's anything specific you'd like to discuss or learn about—whether "
"it's about the solar system or some great French recipes—feel free to let "
"me know! I'm here to help."
)
@pytest.fixture
def mock_mem0_memory():
"""Fixture to create a mock Memory instance"""
@@ -24,7 +84,9 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
# Patch the Memory class to return our mock
with patch("mem0.memory.main.Memory.from_config", return_value=mock_mem0_memory) as mock_from_config:
with patch(
"mem0.Memory.from_config", return_value=mock_mem0_memory
) as mock_from_config:
config = {
"vector_store": {
"provider": "mock_vector_store",
@@ -55,7 +117,14 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
# Parameters like run_id, includes, and excludes doesn't matter in Memory OSS
crew = MockCrew()
embedder_config={"user_id": "test_user", "local_mem0_config": config, "run_id": "my_run_id", "includes": "include1","excludes": "exclude1", "infer" : True}
embedder_config = {
"user_id": "test_user",
"local_mem0_config": config,
"run_id": "my_run_id",
"includes": "include1",
"excludes": "exclude1",
"infer": True,
}
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=embedder_config)
return mem0_storage, mock_from_config, config
@@ -83,28 +152,31 @@ def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_clie
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
crew = MockCrew()
embedder_config={
"user_id": "test_user",
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"run_id": "my_run_id",
"includes": "include1",
"excludes": "exclude1",
"infer": True
}
embedder_config = {
"user_id": "test_user",
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"run_id": "my_run_id",
"includes": "include1",
"excludes": "exclude1",
"infer": True,
}
return Mem0Storage(type="short_term", crew=crew, config=embedder_config)
@pytest.fixture
def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client, mock_mem0_memory):
def mem0_storage_with_memory_client_using_explictly_config(
mock_mem0_memory_client, mock_mem0_memory
):
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
# We need to patch both MemoryClient and Memory to prevent actual initialization
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client), \
patch.object(Memory, "__new__", return_value=mock_mem0_memory):
with (
patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client),
patch.object(Memory, "__new__", return_value=mock_mem0_memory),
):
crew = MockCrew()
new_config = {"provider": "mem0", "config": {"api_key": "new-api-key"}}
@@ -138,18 +210,23 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
mock_mem0_memory_client.update_project = MagicMock()
new_categories = [
{"lifestyle_management_concerns": "Tracks daily routines, habits, hobbies and interests including cooking, time management and work-life balance"},
{
"lifestyle_management_concerns": (
"Tracks daily routines, habits, hobbies and interests "
"including cooking, time management and work-life balance"
)
},
]
crew = MockCrew()
config={
"user_id": "test_user",
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"custom_categories": new_categories
}
config = {
"user_id": "test_user",
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"custom_categories": new_categories,
}
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
_ = Mem0Storage(type="short_term", crew=crew, config=config)
@@ -159,8 +236,6 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
)
def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test save method for different memory types"""
mem0_storage, _, _ = mem0_storage_with_mocked_config
@@ -168,68 +243,134 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {'description': 'Respond to user conversation. User message: What do you know about me?', 'messages': [{'role': 'system', 'content': 'You are Friendly chatbot assistant. You are a kind and knowledgeable chatbot assistant. You excel at understanding user needs, providing helpful responses, and maintaining engaging conversations. You remember previous interactions to provide a personalized experience.\nYour personal goal is: Engage in useful and interesting conversations with users while remembering context.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!'}, {'role': 'user', 'content': '\nCurrent Task: Respond to user conversation. User message: What do you know about me?\n\nThis is the expected criteria for your final answer: Contextually appropriate, helpful, and friendly response.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n- User is from India\n- User is interested in the solar system\n- User name is Vidit Ostwal\n- User is interested in French cuisine\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'}, {'role': 'assistant', 'content': "I now can give a great answer \nFinal Answer: Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}], 'agent': 'Friendly chatbot assistant'}
test_metadata = {
"description": TEST_DESCRIPTION,
"messages": [
{"role": "system", "content": SYSTEM_CONTENT},
{"role": "user", "content": USER_CONTENT},
{"role": "assistant", "content": ASSISTANT_CONTENT},
],
"agent": "Friendly chatbot assistant",
}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{'role': 'user', 'content': 'What do you know about me?'}, {'role': 'assistant', 'content': "Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}],
[
{"role": "user", "content": EXTRACTED_USER_CONTENT},
{
"role": "assistant",
"content": EXTRACTED_ASSISTANT_CONTENT,
},
],
infer=True,
metadata={'type': 'short_term', 'description': 'Respond to user conversation. User message: What do you know about me?', 'agent': 'Friendly chatbot assistant'},
metadata={
"type": "short_term",
"description": TEST_DESCRIPTION,
"agent": "Friendly chatbot assistant",
},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent'
agent_id="Test_Agent",
)
def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config):
mem0_storage, _, _ = mem0_storage_with_mocked_config
mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")]
mem0_storage.crew.agents = [
MagicMock(role="Test Agent"),
MagicMock(role="Test Agent 2"),
MagicMock(role="Test Agent 3"),
]
mem0_storage.memory.add = MagicMock()
test_value = "This is a test memory"
test_metadata = {'description': 'Respond to user conversation. User message: What do you know about me?', 'messages': [{'role': 'system', 'content': 'You are Friendly chatbot assistant. You are a kind and knowledgeable chatbot assistant. You excel at understanding user needs, providing helpful responses, and maintaining engaging conversations. You remember previous interactions to provide a personalized experience.\nYour personal goal is: Engage in useful and interesting conversations with users while remembering context.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!'}, {'role': 'user', 'content': '\nCurrent Task: Respond to user conversation. User message: What do you know about me?\n\nThis is the expected criteria for your final answer: Contextually appropriate, helpful, and friendly response.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n- User is from India\n- User is interested in the solar system\n- User name is Vidit Ostwal\n- User is interested in French cuisine\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'}, {'role': 'assistant', 'content': "I now can give a great answer \nFinal Answer: Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}], 'agent': 'Friendly chatbot assistant'}
test_metadata = {
"description": TEST_DESCRIPTION,
"messages": [
{"role": "system", "content": SYSTEM_CONTENT},
{"role": "user", "content": USER_CONTENT},
{"role": "assistant", "content": ASSISTANT_CONTENT},
],
"agent": "Friendly chatbot assistant",
}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{'role': 'user', 'content': 'What do you know about me?'}, {'role': 'assistant', 'content': "Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}],
[
{"role": "user", "content": EXTRACTED_USER_CONTENT},
{
"role": "assistant",
"content": EXTRACTED_ASSISTANT_CONTENT,
},
],
infer=True,
metadata={'type': 'short_term', 'description': 'Respond to user conversation. User message: What do you know about me?', 'agent': 'Friendly chatbot assistant'},
metadata={
"type": "short_term",
"description": TEST_DESCRIPTION,
"agent": "Friendly chatbot assistant",
},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent_Test_Agent_2_Test_Agent_3'
agent_id="Test_Agent_Test_Agent_2_Test_Agent_3",
)
def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
def test_save_method_with_memory_client(
mem0_storage_with_memory_client_using_config_from_crew,
):
"""Test save method for different memory types"""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mem0_storage.memory.add = MagicMock()
# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {'description': 'Respond to user conversation. User message: What do you know about me?', 'messages': [{'role': 'system', 'content': 'You are Friendly chatbot assistant. You are a kind and knowledgeable chatbot assistant. You excel at understanding user needs, providing helpful responses, and maintaining engaging conversations. You remember previous interactions to provide a personalized experience.\nYour personal goal is: Engage in useful and interesting conversations with users while remembering context.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!'}, {'role': 'user', 'content': '\nCurrent Task: Respond to user conversation. User message: What do you know about me?\n\nThis is the expected criteria for your final answer: Contextually appropriate, helpful, and friendly response.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n- User is from India\n- User is interested in the solar system\n- User name is Vidit Ostwal\n- User is interested in French cuisine\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'}, {'role': 'assistant', 'content': "I now can give a great answer \nFinal Answer: Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}], 'agent': 'Friendly chatbot assistant'}
test_metadata = {
"description": TEST_DESCRIPTION,
"messages": [
{"role": "system", "content": SYSTEM_CONTENT},
{"role": "user", "content": USER_CONTENT},
{"role": "assistant", "content": ASSISTANT_CONTENT},
],
"agent": "Friendly chatbot assistant",
}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{'role': 'user', 'content': 'What do you know about me?'}, {'role': 'assistant', 'content': "Hi Vidit! From our previous conversations, I know you're from India and have a great interest in the solar system. It's fascinating to explore the wonders of space, isn't it? Also, I remember you have a passion for French cuisine, which has so many delightful dishes to explore. If there's anything specific you'd like to discuss or learn about—whether it's about the solar system or some great French recipes—feel free to let me know! I'm here to help."}],
[
{"role": "user", "content": EXTRACTED_USER_CONTENT},
{
"role": "assistant",
"content": EXTRACTED_ASSISTANT_CONTENT,
},
],
infer=True,
metadata={'type': 'short_term', 'description': 'Respond to user conversation. User message: What do you know about me?', 'agent': 'Friendly chatbot assistant'},
metadata={
"type": "short_term",
"description": TEST_DESCRIPTION,
"agent": "Friendly chatbot assistant",
},
version="v2",
run_id="my_run_id",
includes="include1",
excludes="exclude1",
output_format='v1.1',
user_id='test_user',
agent_id='Test_Agent'
output_format="v1.1",
user_id="test_user",
agent_id="Test_Agent",
)
def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test search method for different memory types"""
mem0_storage, _, _ = mem0_storage_with_mocked_config
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
@@ -238,18 +379,25 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
query="test query",
limit=5,
user_id="test_user",
filters={'AND': [{'run_id': 'my_run_id'}]},
threshold=0.5
filters={"AND": [{"run_id": "my_run_id"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["context"] == "Result 1"
assert results[0]["content"] == "Result 1"
def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
def test_search_method_with_memory_client(
mem0_storage_with_memory_client_using_config_from_crew,
):
"""Test search method for different memory types"""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
@@ -259,15 +407,15 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_
limit=5,
metadata={"type": "short_term"},
user_id="test_user",
version='v2',
version="v2",
run_id="my_run_id",
output_format='v1.1',
filters={'AND': [{'run_id': 'my_run_id'}]},
threshold=0.5
output_format="v1.1",
filters={"AND": [{"run_id": "my_run_id"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["context"] == "Result 1"
assert results[0]["content"] == "Result 1"
def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
@@ -275,14 +423,12 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
crew = MockCrew()
config={
"user_id": "test_user",
"api_key": "ABCDEFGH"
}
config = {"user_id": "test_user", "api_key": "ABCDEFGH"}
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=config)
assert mem0_storage.infer is True
def test_save_memory_using_agent_entity(mock_mem0_memory_client):
config = {
"agent_id": "agent-123",
@@ -293,19 +439,25 @@ def test_save_memory_using_agent_entity(mock_mem0_memory_client):
mem0_storage = Mem0Storage(type="external", config=config)
mem0_storage.save("test memory", {"key": "value"})
mem0_storage.memory.add.assert_called_once_with(
[{'role': 'assistant' , 'content': 'test memory'}],
[{"role": "assistant", "content": "test memory"}],
infer=True,
metadata={"type": "external", "key": "value"},
agent_id="agent-123",
)
def test_search_method_with_agent_entity():
config = {
"agent_id": "agent-123",
}
mock_memory = MagicMock(spec=Memory)
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
with patch.object(Memory, "__new__", return_value=mock_memory):
mem0_storage = Mem0Storage(type="external", config=config)
@@ -314,22 +466,29 @@ def test_search_method_with_agent_entity():
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
filters={"AND": [{"agent_id": "agent-123"}]},
threshold=0.5,
)
query="test query",
limit=5,
filters={"AND": [{"agent_id": "agent-123"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["context"] == "Result 1"
assert results[0]["content"] == "Result 1"
def test_search_method_with_agent_id_and_user_id():
mock_memory = MagicMock(spec=Memory)
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
with patch.object(Memory, "__new__", return_value=mock_memory):
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"})
mem0_storage = Mem0Storage(
type="external", config={"agent_id": "agent-123", "user_id": "user-123"}
)
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
@@ -337,10 +496,10 @@ def test_search_method_with_agent_id_and_user_id():
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
user_id='user-123',
user_id="user-123",
filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["context"] == "Result 1"
assert results[0]["content"] == "Result 1"

View File

@@ -3818,10 +3818,7 @@ def test_task_tools_preserve_code_execution_tools():
"""
Test that task tools don't override code execution tools when allow_code_execution=True
"""
# Mock embedchain initialization to prevent race conditions in parallel CI execution
with patch("embedchain.client.Client.setup"):
from crewai_tools import CodeInterpreterTool
from crewai_tools import CodeInterpreterTool
from pydantic import BaseModel, Field
from crewai.tools import BaseTool

View File

@@ -112,9 +112,9 @@ def test_agent_memoization():
first_call_result = crew.simple_agent()
second_call_result = crew.simple_agent()
assert (
first_call_result is second_call_result
), "Agent memoization is not working as expected"
assert first_call_result is second_call_result, (
"Agent memoization is not working as expected"
)
def test_task_memoization():
@@ -122,9 +122,9 @@ def test_task_memoization():
first_call_result = crew.simple_task()
second_call_result = crew.simple_task()
assert (
first_call_result is second_call_result
), "Task memoization is not working as expected"
assert first_call_result is second_call_result, (
"Task memoization is not working as expected"
)
def test_crew_memoization():
@@ -132,35 +132,35 @@ def test_crew_memoization():
first_call_result = crew.crew()
second_call_result = crew.crew()
assert (
first_call_result is second_call_result
), "Crew references should point to the same object"
assert first_call_result is second_call_result, (
"Crew references should point to the same object"
)
def test_task_name():
simple_task = SimpleCrew().simple_task()
assert (
simple_task.name == "simple_task"
), "Task name is not inferred from function name as expected"
assert simple_task.name == "simple_task", (
"Task name is not inferred from function name as expected"
)
custom_named_task = SimpleCrew().custom_named_task()
assert (
custom_named_task.name == "Custom"
), "Custom task name is not being set as expected"
assert custom_named_task.name == "Custom", (
"Custom task name is not being set as expected"
)
def test_agent_function_calling_llm():
crew = InternalCrew()
llm = crew.local_llm()
obj_llm_agent = crew.researcher()
assert (
obj_llm_agent.function_calling_llm is llm
), "agent's function_calling_llm is incorrect"
assert obj_llm_agent.function_calling_llm is llm, (
"agent's function_calling_llm is incorrect"
)
str_llm_agent = crew.reporting_analyst()
assert (
str_llm_agent.function_calling_llm.model == "online_llm"
), "agent's function_calling_llm is incorrect"
assert str_llm_agent.function_calling_llm.model == "online_llm", (
"agent's function_calling_llm is incorrect"
)
def test_task_guardrail():
@@ -186,9 +186,9 @@ def test_after_kickoff_modification():
# Assuming the crew execution returns a dict
result = crew.crew().kickoff({"topic": "LLMs"})
assert (
"post processed" in result.raw
), "After kickoff function did not modify outputs"
assert "post processed" in result.raw, (
"After kickoff function did not modify outputs"
)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -274,10 +274,8 @@ def another_simple_tool():
def test_internal_crew_with_mcp():
# Mock embedchain initialization to prevent race conditions in parallel CI execution
with patch("embedchain.client.Client.setup"):
from crewai_tools import MCPServerAdapter
from crewai_tools.adapters.mcp_adapter import ToolCollection
from crewai_tools import MCPServerAdapter
from crewai_tools.adapters.mcp_adapter import ToolCollection
mock = Mock(spec=MCPServerAdapter)
mock.tools = ToolCollection([simple_tool, another_simple_tool])
@@ -287,6 +285,5 @@ def test_internal_crew_with_mcp():
assert crew.researcher().tools == [simple_tool]
adapter_mock.assert_called_once_with(
{"host": "localhost", "port": 8000},
connect_timeout=120
{"host": "localhost", "port": 8000}, connect_timeout=120
)

View File

@@ -1,17 +1,20 @@
import os
from unittest.mock import MagicMock, Mock, patch
import pytest
from unittest.mock import patch, MagicMock
from crewai import Agent, Task, Crew
from crewai.flow.flow import Flow, start
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
from crewai import Agent, Crew, Task
from crewai.events.listeners.tracing.first_time_trace_handler import (
FirstTimeTraceHandler,
)
from crewai.events.listeners.tracing.trace_batch_manager import (
TraceBatchManager,
)
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.flow.flow import Flow, start
class TestTraceListenerSetup:
@@ -281,9 +284,9 @@ class TestTraceListenerSetup:
):
trace_handlers.append(handler)
assert (
len(trace_handlers) == 0
), f"Found {len(trace_handlers)} trace handlers when tracing should be disabled"
assert len(trace_handlers) == 0, (
f"Found {len(trace_handlers)} trace handlers when tracing should be disabled"
)
def test_trace_listener_setup_correctly_for_crew(self):
"""Test that trace listener is set up correctly when enabled"""
@@ -403,3 +406,254 @@ class TestTraceListenerSetup:
from crewai.events.event_bus import crewai_event_bus
crewai_event_bus._handlers.clear()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_first_time_user_trace_collection_with_timeout(self, mock_plus_api_calls):
"""Test first-time user trace collection logic with timeout behavior"""
with (
patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
patch(
"crewai.events.listeners.tracing.utils._is_test_environment",
return_value=False,
),
patch(
"crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.utils.is_first_execution",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
return_value=False,
) as mock_prompt,
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
) as mock_mark_completed,
):
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm="gpt-4o-mini",
)
task = Task(
description="Say hello to the world",
expected_output="hello world",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task], verbose=True)
from crewai.events.event_bus import crewai_event_bus
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
assert trace_listener.first_time_handler.is_first_time is True
assert trace_listener.first_time_handler.collected_events is False
with (
patch.object(
trace_listener.first_time_handler,
"handle_execution_completion",
wraps=trace_listener.first_time_handler.handle_execution_completion,
) as mock_handle_completion,
patch.object(
trace_listener.batch_manager,
"add_event",
wraps=trace_listener.batch_manager.add_event,
) as mock_add_event,
):
result = crew.kickoff()
assert result is not None
assert mock_handle_completion.call_count >= 1
assert mock_add_event.call_count >= 1
assert trace_listener.first_time_handler.collected_events is True
mock_prompt.assert_called_once_with(timeout_seconds=20)
mock_mark_completed.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_first_time_user_trace_collection_user_accepts(self, mock_plus_api_calls):
"""Test first-time user trace collection when user accepts viewing traces"""
with (
patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
patch(
"crewai.events.listeners.tracing.utils._is_test_environment",
return_value=False,
),
patch(
"crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.utils.is_first_execution",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
) as mock_mark_completed,
):
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm="gpt-4o-mini",
)
task = Task(
description="Say hello to the world",
expected_output="hello world",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task], verbose=True)
from crewai.events.event_bus import crewai_event_bus
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
assert trace_listener.first_time_handler.is_first_time is True
with (
patch.object(
trace_listener.first_time_handler,
"_initialize_backend_and_send_events",
wraps=trace_listener.first_time_handler._initialize_backend_and_send_events,
) as mock_init_backend,
patch.object(
trace_listener.first_time_handler, "_display_ephemeral_trace_link"
) as mock_display_link,
patch.object(
trace_listener.first_time_handler,
"handle_execution_completion",
wraps=trace_listener.first_time_handler.handle_execution_completion,
) as mock_handle_completion,
):
trace_listener.batch_manager.ephemeral_trace_url = (
"https://crewai.com/trace/mock-id"
)
crew.kickoff()
assert mock_handle_completion.call_count >= 1, (
"handle_execution_completion should be called"
)
assert trace_listener.first_time_handler.collected_events is True, (
"Events should be marked as collected"
)
mock_init_backend.assert_called_once()
mock_display_link.assert_called_once()
mock_mark_completed.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_first_time_user_trace_consolidation_logic(self, mock_plus_api_calls):
"""Test the consolidation logic for first-time users vs regular tracing"""
with (
patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
patch(
"crewai.events.listeners.tracing.utils._is_test_environment",
return_value=False,
),
patch(
"crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.utils.is_first_execution",
return_value=True,
),
):
from crewai.events.event_bus import crewai_event_bus
crewai_event_bus._handlers.clear()
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
assert trace_listener.first_time_handler.is_first_time is True
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm="gpt-4o-mini",
)
task = Task(
description="Test task", expected_output="test output", agent=agent
)
crew = Crew(agents=[agent], tasks=[task])
with patch.object(TraceBatchManager, "initialize_batch") as mock_initialize:
result = crew.kickoff()
assert mock_initialize.call_count >= 1
assert mock_initialize.call_args_list[0][1]["use_ephemeral"] is True
assert result is not None
def test_first_time_handler_timeout_behavior(self):
"""Test the timeout behavior of the first-time trace prompt"""
with (
patch(
"crewai.events.listeners.tracing.utils._is_test_environment",
return_value=False,
),
patch("threading.Thread") as mock_thread,
):
from crewai.events.listeners.tracing.utils import (
prompt_user_for_trace_viewing,
)
mock_thread_instance = Mock()
mock_thread_instance.is_alive.return_value = True
mock_thread.return_value = mock_thread_instance
result = prompt_user_for_trace_viewing(timeout_seconds=5)
assert result is False
mock_thread.assert_called_once()
call_args = mock_thread.call_args
assert call_args[1]["daemon"] is True
mock_thread_instance.start.assert_called_once()
mock_thread_instance.join.assert_called_once_with(timeout=5)
mock_thread_instance.is_alive.assert_called_once()
def test_first_time_handler_graceful_error_handling(self):
"""Test graceful error handling in first-time trace logic"""
with (
patch(
"crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
side_effect=Exception("Prompt failed"),
),
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
) as mock_mark_completed,
):
handler = FirstTimeTraceHandler()
handler.is_first_time = True
handler.collected_events = True
handler.handle_execution_completion()
mock_mark_completed.assert_called_once()

View File

@@ -1,123 +0,0 @@
import multiprocessing
import tempfile
import unittest
from chromadb.config import Settings
from unittest.mock import patch, MagicMock
from crewai.utilities.chromadb import (
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
is_ipv4_pattern,
sanitize_collection_name,
create_persistent_client,
)
def persistent_client_worker(path, queue):
try:
create_persistent_client(path=path)
queue.put(None)
except Exception as e:
queue.put(e)
class TestChromadbUtils(unittest.TestCase):
def test_sanitize_collection_name_long_name(self):
"""Test sanitizing a very long collection name."""
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
sanitized = sanitize_collection_name(long_name)
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_sanitize_collection_name_special_chars(self):
"""Test sanitizing a name with special characters."""
special_chars = "Agent@123!#$%^&*()"
sanitized = sanitize_collection_name(special_chars)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_sanitize_collection_name_short_name(self):
"""Test sanitizing a very short name."""
short_name = "A"
sanitized = sanitize_collection_name(short_name)
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_sanitize_collection_name_bad_ends(self):
"""Test sanitizing a name with non-alphanumeric start/end."""
bad_ends = "_Agent_"
sanitized = sanitize_collection_name(bad_ends)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_sanitize_collection_name_none(self):
"""Test sanitizing a None value."""
sanitized = sanitize_collection_name(None)
self.assertEqual(sanitized, "default_collection")
def test_sanitize_collection_name_ipv4_pattern(self):
"""Test sanitizing an IPv4 address."""
ipv4 = "192.168.1.1"
sanitized = sanitize_collection_name(ipv4)
self.assertTrue(sanitized.startswith("ip_"))
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_is_ipv4_pattern(self):
"""Test IPv4 pattern detection."""
self.assertTrue(is_ipv4_pattern("192.168.1.1"))
self.assertFalse(is_ipv4_pattern("not.an.ip.address"))
def test_sanitize_collection_name_properties(self):
"""Test that sanitized collection names always meet ChromaDB requirements."""
test_cases = [
"A" * 100, # Very long name
"_start_with_underscore",
"end_with_underscore_",
"contains@special#characters",
"192.168.1.1", # IPv4 address
"a" * 2, # Too short
]
for test_case in test_cases:
sanitized = sanitize_collection_name(test_case)
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_create_persistent_client_passes_args(self):
with patch(
"crewai.utilities.chromadb.PersistentClient"
) as mock_persistent_client, tempfile.TemporaryDirectory() as tmpdir:
mock_instance = MagicMock()
mock_persistent_client.return_value = mock_instance
settings = Settings(allow_reset=True)
client = create_persistent_client(path=tmpdir, settings=settings)
mock_persistent_client.assert_called_once_with(
path=tmpdir, settings=settings
)
self.assertIs(client, mock_instance)
def test_create_persistent_client_process_safe(self):
with tempfile.TemporaryDirectory() as tmpdir:
queue = multiprocessing.Queue()
processes = [
multiprocessing.Process(
target=persistent_client_worker, args=(tmpdir, queue)
)
for _ in range(5)
]
[p.start() for p in processes]
[p.join() for p in processes]
errors = [queue.get(timeout=5) for _ in processes]
self.assertTrue(all(err is None for err in errors))

View File

@@ -29,13 +29,15 @@ def mock_knowledge_source():
"""
return StringKnowledgeSource(content=content)
@patch('crewai.knowledge.storage.knowledge_storage.chromadb')
def test_knowledge_included_in_planning(mock_chroma):
@patch("crewai.rag.config.utils.get_rag_client")
def test_knowledge_included_in_planning(mock_get_client):
"""Test that verifies knowledge sources are properly included in planning."""
# Mock ChromaDB collection
mock_collection = mock_chroma.return_value.get_or_create_collection.return_value
mock_collection.add.return_value = None
# Mock RAG client
mock_client = mock_get_client.return_value
mock_client.get_or_create_collection.return_value = None
mock_client.add_documents.return_value = None
# Create an agent with knowledge
agent = Agent(
role="AI Researcher",
@@ -45,14 +47,14 @@ def test_knowledge_included_in_planning(mock_chroma):
StringKnowledgeSource(
content="AI systems require careful training and validation."
)
]
],
)
# Create a task for the agent
task = Task(
description="Explain the basics of AI systems",
expected_output="A clear explanation of AI fundamentals",
agent=agent
agent=agent,
)
# Create a crew planner
@@ -62,23 +64,29 @@ def test_knowledge_included_in_planning(mock_chroma):
task_summary = planner._create_tasks_summary()
# Verify that knowledge is included in planning when present
assert "AI systems require careful training" in task_summary, \
assert "AI systems require careful training" in task_summary, (
"Knowledge content should be present in task summary when knowledge exists"
assert '"agent_knowledge"' in task_summary, \
)
assert '"agent_knowledge"' in task_summary, (
"agent_knowledge field should be present in task summary when knowledge exists"
)
# Verify that knowledge is properly formatted
assert isinstance(task.agent.knowledge_sources, list), \
assert isinstance(task.agent.knowledge_sources, list), (
"Knowledge sources should be stored in a list"
assert len(task.agent.knowledge_sources) > 0, \
)
assert len(task.agent.knowledge_sources) > 0, (
"At least one knowledge source should be present"
assert task.agent.knowledge_sources[0].content in task_summary, \
)
assert task.agent.knowledge_sources[0].content in task_summary, (
"Knowledge source content should be included in task summary"
)
# Verify that other expected components are still present
assert task.description in task_summary, \
assert task.description in task_summary, (
"Task description should be present in task summary"
assert task.expected_output in task_summary, \
)
assert task.expected_output in task_summary, (
"Expected output should be present in task summary"
assert agent.role in task_summary, \
"Agent role should be present in task summary"
)
assert agent.role in task_summary, "Agent role should be present in task summary"

6317
uv.lock generated

File diff suppressed because it is too large Load Diff