Compare commits

...

17 Commits

Author SHA1 Message Date
Joao Moura
c47ff15bf6 set tracing if user enables it 2025-09-18 17:01:52 -07:00
Joao Moura
270e0b6edd avoiding line breaking 2025-09-18 16:40:51 -07:00
Joao Moura
a0cbb5cfdb feat(tracing): enhance first-time trace display and auto-open browser 2025-09-18 16:37:48 -07:00
Lorenze Jay
2f682e1564 feat: update ChromaDB embedding function to use OpenAI API (#3538)
- Refactor the default embedding function to utilize OpenAI's embedding function with API key support.
- Import necessary OpenAI embedding function and configure it with the environment variable for the API key.
- Ensure compatibility with existing ChromaDB configuration model.
2025-09-18 14:50:35 -07:00
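As an illustration of the change described in the commit above, here is a minimal sketch of an OpenAI-backed default embedding function for ChromaDB, assuming the standard `chromadb` utilities and the `OPENAI_API_KEY` environment variable; the model name is an assumption, not something the commit confirms.

```python
import os

from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction


def default_embedding_function() -> OpenAIEmbeddingFunction:
    """Default embedder backed by OpenAI, keyed from the environment (illustrative)."""
    return OpenAIEmbeddingFunction(
        api_key=os.environ["OPENAI_API_KEY"],  # assumes the key is exported
        model_name="text-embedding-3-small",   # assumed model, not confirmed by the commit
    )
```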
Greyson LaLonde
d4aa676195 feat: add configurable search parameters for RAG, knowledge, and memory (#3531)
- Add limit and score_threshold to BaseRagConfig, propagate to clients  
- Update default search params in RAG storage, knowledge, and memory (limit=5, threshold=0.6)  
- Fix linting (ruff, mypy, PERF203) and refactor save logic  
- Update tests for new defaults and ChromaDB behavior
2025-09-18 16:58:03 -04:00
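A hedged sketch of how a `limit`/`score_threshold` pair might be carried on a Pydantic config and applied at query time, matching the defaults named in the commit above (limit=5, threshold=0.6). `BaseRagConfig` and `filter_results` here are stand-ins, not the actual CrewAI classes.

```python
from typing import Any

from pydantic import BaseModel, Field


class BaseRagConfig(BaseModel):  # illustrative stand-in for the real config class
    limit: int = Field(default=5, ge=1, description="Max results returned per search")
    score_threshold: float = Field(
        default=0.6, ge=0.0, le=1.0, description="Minimum similarity score to keep a hit"
    )


def filter_results(
    results: list[dict[str, Any]], config: BaseRagConfig
) -> list[dict[str, Any]]:
    """Apply the configured threshold and limit to raw search hits."""
    kept = [r for r in results if r.get("score", 0.0) >= config.score_threshold]
    return kept[: config.limit]
```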
Lorenze Jay
578fa8c2e4 Lorenze/ephemeral trace ask (#3530)
* feat(tracing): implement first-time trace handling and improve event management

- Added FirstTimeTraceHandler for managing first-time user trace collection and display.
- Enhanced TraceBatchManager to support ephemeral trace URLs and improved event buffering.
- Updated TraceCollectionListener to utilize the new FirstTimeTraceHandler.
- Refactored type annotations across multiple files for consistency and clarity.
- Improved error handling and logging for trace-related operations.
- Introduced utility functions for trace viewing prompts and first execution checks.

* brought back crew finalize batch events

* refactor(trace): move instance variables to __init__ in TraceBatchManager

- Refactored TraceBatchManager to initialize instance variables in the constructor instead of as class variables.
- Improved clarity and encapsulation of the class state.

* fix(tracing): improve error handling in user data loading and saving

- Enhanced error handling in _load_user_data and _save_user_data functions to log warnings for JSON decoding and file access issues.
- Updated documentation for trace usage to clarify the addition of tracing parameters in Crew and Flow initialization.
- Refined state management in Flow class to ensure proper handling of state IDs when persistence is enabled.
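A small sketch, under an assumed file path and logger, of the defensive load/save pattern these bullet points describe: JSON-decoding and file-access problems become logged warnings instead of crashes.

```python
import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)
USER_DATA_FILE = Path.home() / ".crewai" / "user_data.json"  # assumed location


def _load_user_data() -> dict:
    """Load persisted user data, warning (not raising) on corrupt or unreadable files."""
    try:
        return json.loads(USER_DATA_FILE.read_text())
    except FileNotFoundError:
        return {}
    except (json.JSONDecodeError, OSError) as e:
        logger.warning("Could not read user data (%s); starting with empty state", e)
        return {}


def _save_user_data(data: dict) -> None:
    """Persist user data, downgrading file-access errors to warnings."""
    try:
        USER_DATA_FILE.parent.mkdir(parents=True, exist_ok=True)
        USER_DATA_FILE.write_text(json.dumps(data))
    except OSError as e:
        logger.warning("Could not save user data: %s", e)
```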

* add some tests

* fix test

* fix tests

* refactor(tracing): enhance user input handling for trace viewing

- Replaced signal-based timeout handling with threading for user input in prompt_user_for_trace_viewing function.
- Improved user experience by allowing a configurable timeout for viewing execution traces.
- Updated tests to mock threading behavior and verify timeout handling correctly.
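A minimal sketch of the threading-based timeout this refactor describes: read user input on a worker thread and give up after a configurable number of seconds, instead of using signals (which only work on the main thread of Unix processes). The function name is illustrative; `prompt_user_for_trace_viewing` could treat a `None` return as "don't open the trace".

```python
import threading


def prompt_with_timeout(prompt: str, timeout: float = 10.0) -> str | None:
    """Return the user's answer, or None if no input arrives before the timeout."""
    answer: list[str] = []

    def _read() -> None:
        answer.append(input(prompt))

    # Daemon thread so an unanswered prompt never blocks interpreter shutdown.
    reader = threading.Thread(target=_read, daemon=True)
    reader.start()
    reader.join(timeout)
    return answer[0] if answer else None
```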

* fix(tracing): improve machine ID retrieval with error handling

- Added error handling to the _get_machine_id function to log warnings when retrieving the machine ID fails.
- Ensured that the function continues to provide a stable, privacy-preserving machine fingerprint even in case of errors.
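A sketch of a stable, privacy-preserving machine fingerprint with the warning-on-failure behaviour described above; the hashing inputs are assumptions, not the exact fields CrewAI uses.

```python
import hashlib
import logging
import platform
import uuid

logger = logging.getLogger(__name__)


def _get_machine_id() -> str:
    """Return a stable, anonymised machine fingerprint (illustrative sketch)."""
    try:
        raw = f"{uuid.getnode()}-{platform.system()}-{platform.machine()}"
    except Exception as e:  # any failure degrades to a constant, with a warning
        logger.warning("Could not derive machine ID: %s", e)
        raw = "unknown-machine"
    # Hashing keeps the raw hardware identifiers from ever leaving the machine.
    return hashlib.sha256(raw.encode()).hexdigest()[:16]
```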

* refactor(flow): streamline state ID assignment in Flow class

- Replaced direct attribute assignment with setattr for improved flexibility in handling state IDs.
- Enhanced code readability by simplifying the logic for setting the state ID when persistence is enabled.
2025-09-18 10:17:34 -07:00
Rip&Tear
6f5af2b27c Update CodeQL workflow to ignore specific paths (#3534)
CodeQL, when configured through the GUI, does not allow for advanced configuration. This PR moves to an advanced file-based config, which allows us to exclude certain paths.
2025-09-18 23:26:15 +08:00
Greyson LaLonde
8ee3cf4874 test: fix flaky agent repeated tool usage test (#3533)
- Make assertion resilient to race condition with max iterations in CI  
- Add investigation notes and TODOs for deterministic executor flow
2025-09-17 22:00:32 -04:00
Greyson LaLonde
f2d3fd0c0f fix(events): add missing event exports to __init__.py (#3532) 2025-09-17 21:50:27 -04:00
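The fix above is the usual re-export pattern for a package `__init__.py`; a hedged illustration follows (these event names appear elsewhere in this compare view, but the commit does not say which exports were missing).

```python
# events/__init__.py -- illustrative re-export pattern only
from crewai.events.types.agent_events import (
    AgentExecutionCompletedEvent,
    AgentExecutionErrorEvent,
    AgentExecutionStartedEvent,
)

__all__ = [
    "AgentExecutionCompletedEvent",
    "AgentExecutionErrorEvent",
    "AgentExecutionStartedEvent",
]
```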
Greyson LaLonde
f28e78c5ba refactor: unify rag storage with instance-specific client support (#3455)
- ignore line length errors globally
- migrate knowledge/memory and crew query_knowledge to `SearchResult`
- remove legacy chromadb utils; fix empty metadata handling
- restore openai as default embedding provider; support instance-specific clients
- update and fix tests for `SearchResult` migration and rag changes
2025-09-17 14:46:54 -04:00
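The diffs further down import `SearchResult` from `crewai.rag.types`; a hedged sketch of consuming such a result type is below. The field names are assumptions made for illustration, not the confirmed shape of the class.

```python
from typing import Any, TypedDict


class SearchResult(TypedDict):  # assumed shape, for illustration only
    content: str
    metadata: dict[str, Any]
    score: float


def format_hits(hits: list[SearchResult]) -> str:
    """Render knowledge/memory search hits into a single context string."""
    return "\n".join(f"[{hit['score']:.2f}] {hit['content']}" for hit in hits)
```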
Greyson LaLonde
81bd81e5f5 fix: handle model parameter in OpenAI adapter initialization (#3510)
2025-09-12 17:31:53 -04:00
Vidit Ostwal
1b00cc71ef Dropping messages from metadata in Mem0 Storage (#3390)
* Dropped messages from metadata and added user-assistant interaction directly

* Fixed test cases for this

* Fixed static type checking issue

* Changed logic to take latest user and assistant messages

* Added default value to be string

* Linting checks

* Removed duplication of tool calling

* Fixed Linting Changes

* Ruff check

* Removed console formatter file from commit

* Linting fixed

* Linting checks

* Ignoring missing imports error

* Added suggested changes

* Fixed import untyped error
2025-09-12 15:25:29 -04:00
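A minimal sketch of the "take the latest user and assistant messages" logic described in the commit above, with a string default when a role is missing; the role/content dict format is assumed.

```python
def latest_interaction(messages: list[dict[str, str]]) -> tuple[str, str]:
    """Return the most recent (user, assistant) message pair, defaulting to empty strings."""
    latest = {"user": "", "assistant": ""}
    for message in messages:
        role = message.get("role", "")
        if role in latest:
            latest[role] = message.get("content", "")
    return latest["user"], latest["assistant"]
```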
Greyson LaLonde
45d0c9912c chore: add type annotations and docstrings to openai agent adapters (#3505) 2025-09-12 10:41:39 -04:00
Greyson LaLonde
1f1ab14b07 fix: resolve test duration cache issues in CI workflows (#3506)
2025-09-12 08:38:47 -04:00
Lucas Gomide
1a70f1698e feat: add thread-safe platform context management (#3502)
Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
2025-09-11 17:32:51 -04:00
Greyson LaLonde
8883fb656b feat(tests): add duration caching for pytest-split
- Cache test durations for optimized splitting
2025-09-11 15:16:05 -04:00
Greyson LaLonde
79d65e55a1 chore: add type annotations and docstrings to langgraph adapters (#3503) 2025-09-11 13:06:44 -04:00
62 changed files with 4708 additions and 1671 deletions

.github/workflows/codeql.yml (vendored, new file, 102 lines)
View File

@@ -0,0 +1,102 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL Advanced"
on:
push:
branches: [ "main" ]
paths-ignore:
- "src/crewai/cli/templates/**"
pull_request:
branches: [ "main" ]
paths-ignore:
- "src/crewai/cli/templates/**"
jobs:
analyze:
name: Analyze (${{ matrix.language }})
# Runner size impacts CodeQL analysis time. To learn more, please see:
# - https://gh.io/recommended-hardware-resources-for-running-codeql
# - https://gh.io/supported-runners-and-hardware-resources
# - https://gh.io/using-larger-runners (GitHub.com only)
# Consider using larger runners or machines with greater resources for possible analysis time improvements.
runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
permissions:
# required for all workflows
security-events: write
# required to fetch internal or private CodeQL packs
packages: read
# only required for workflows in private repositories
actions: read
contents: read
strategy:
fail-fast: false
matrix:
include:
- language: actions
build-mode: none
- language: python
build-mode: none
# CodeQL supports the following values keywords for 'language': 'actions', 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'rust', 'swift'
# Use `c-cpp` to analyze code written in C, C++ or both
# Use 'java-kotlin' to analyze code written in Java, Kotlin or both
# Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
# To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
# see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
# If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v4
# Add any setup steps before running the `github/codeql-action/init` action.
# This includes steps like installing compilers or runtimes (`actions/setup-node`
# or others). This is typically only required for manual builds.
# - name: Setup runtime (example)
# uses: actions/setup-example@v1
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
# If the analyze step fails for one of the languages you are analyzing with
# "We were unable to automatically build your code", modify the matrix above
# to set the build mode to "manual" for that language. Then modify this step
# to build your code.
# Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
- if: matrix.build-mode == 'manual'
shell: bash
run: |
echo 'If you are using a "manual" build mode for one or more of the' \
'languages you are analyzing, replace this with the commands to build' \
'your code, for example:'
echo ' make bootstrap'
echo ' make release'
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"

View File

@@ -22,6 +22,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0 # Fetch all history for proper diff
- name: Restore global uv cache
id: cache-restore
@@ -45,14 +47,41 @@ jobs:
- name: Install the project
run: uv sync --all-groups --all-extras
- name: Restore test durations
uses: actions/cache/restore@v4
with:
path: .test_durations_py*
key: test-durations-py${{ matrix.python-version }}
- name: Run tests (group ${{ matrix.group }} of 8)
run: |
PYTHON_VERSION_SAFE=$(echo "${{ matrix.python-version }}" | tr '.' '_')
DURATION_FILE=".test_durations_py${PYTHON_VERSION_SAFE}"
# Temporarily always skip cached durations to fix test splitting
# When durations don't match, pytest-split runs duplicate tests instead of splitting
echo "Using even test splitting (duration cache disabled until fix merged)"
DURATIONS_ARG=""
# Original logic (disabled temporarily):
# if [ ! -f "$DURATION_FILE" ]; then
# echo "No cached durations found, tests will be split evenly"
# DURATIONS_ARG=""
# elif git diff origin/${{ github.base_ref }}...HEAD --name-only 2>/dev/null | grep -q "^tests/.*\.py$"; then
# echo "Test files have changed, skipping cached durations to avoid mismatches"
# DURATIONS_ARG=""
# else
# echo "No test changes detected, using cached test durations for optimal splitting"
# DURATIONS_ARG="--durations-path=${DURATION_FILE}"
# fi
uv run pytest \
--block-network \
--timeout=30 \
-vv \
--splits 8 \
--group ${{ matrix.group }} \
$DURATIONS_ARG \
--durations=10 \
-n auto \
--maxfail=3

View File

@@ -0,0 +1,71 @@
name: Update Test Durations
on:
push:
branches:
- main
paths:
- 'tests/**/*.py'
workflow_dispatch:
permissions:
contents: read
jobs:
update-durations:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12', '3.13']
env:
OPENAI_API_KEY: fake-api-key
PYTHONUNBUFFERED: 1
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Restore global uv cache
id: cache-restore
uses: actions/cache/restore@v4
with:
path: |
~/.cache/uv
~/.local/share/uv
.venv
key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
restore-keys: |
uv-main-py${{ matrix.python-version }}-
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
version: "0.8.4"
python-version: ${{ matrix.python-version }}
enable-cache: false
- name: Install the project
run: uv sync --all-groups --all-extras
- name: Run all tests and store durations
run: |
PYTHON_VERSION_SAFE=$(echo "${{ matrix.python-version }}" | tr '.' '_')
uv run pytest --store-durations --durations-path=.test_durations_py${PYTHON_VERSION_SAFE} -n auto
continue-on-error: true
- name: Save durations to cache
if: always()
uses: actions/cache/save@v4
with:
path: .test_durations_py*
key: test-durations-py${{ matrix.python-version }}
- name: Save uv caches
if: steps.cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@v4
with:
path: |
~/.cache/uv
~/.local/share/uv
.venv
key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}

View File

@@ -131,13 +131,14 @@ select = [
"I001", # sort imports
"I002", # remove unused imports
]
ignore = ["E501"] # ignore line too long
ignore = ["E501"] # ignore line too long globally
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["S101"] # Allow assert statements in tests
"tests/**/*.py" = ["S101", "RET504"] # Allow assert statements and unnecessary assignments before return in tests
[tool.mypy]
exclude = ["src/crewai/cli/templates", "tests"]
exclude = ["src/crewai/cli/templates", "tests/"]
[tool.bandit]
exclude_dirs = ["src/crewai/cli/templates"]

View File

@@ -1,47 +1,56 @@
from typing import Any, Dict, List, Optional
"""LangGraph agent adapter for CrewAI integration.
from pydantic import Field, PrivateAttr
This module contains the LangGraphAgentAdapter class that integrates LangGraph ReAct agents
with CrewAI's agent system. Provides memory persistence, tool integration, and structured
output functionality.
"""
from collections.abc import Callable
from typing import Any, cast
from pydantic import ConfigDict, Field, PrivateAttr
from crewai.agents.agent_adapters.base_agent_adapter import BaseAgentAdapter
from crewai.agents.agent_adapters.langgraph.langgraph_tool_adapter import (
LangGraphToolAdapter,
)
from crewai.agents.agent_adapters.langgraph.protocols import (
LangGraphCheckPointMemoryModule,
LangGraphPrebuiltModule,
)
from crewai.agents.agent_adapters.langgraph.structured_output_converter import (
LangGraphConverterAdapter,
)
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool
from crewai.utilities import Logger
from crewai.utilities.converter import Converter
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
try:
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
LANGGRAPH_AVAILABLE = True
except ImportError:
LANGGRAPH_AVAILABLE = False
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool
from crewai.utilities import Logger
from crewai.utilities.converter import Converter
from crewai.utilities.import_utils import require
class LangGraphAgentAdapter(BaseAgentAdapter):
"""Adapter for LangGraph agents to work with CrewAI."""
"""Adapter for LangGraph agents to work with CrewAI.
model_config = {"arbitrary_types_allowed": True}
This adapter integrates LangGraph's ReAct agents with CrewAI's agent system,
providing memory persistence, tool integration, and structured output support.
"""
_logger: Logger = PrivateAttr(default_factory=lambda: Logger())
model_config = ConfigDict(arbitrary_types_allowed=True)
_logger: Logger = PrivateAttr(default_factory=Logger)
_tool_adapter: LangGraphToolAdapter = PrivateAttr()
_graph: Any = PrivateAttr(default=None)
_memory: Any = PrivateAttr(default=None)
_max_iterations: int = PrivateAttr(default=10)
function_calling_llm: Any = Field(default=None)
step_callback: Any = Field(default=None)
step_callback: Callable[..., Any] | None = Field(default=None)
model: str = Field(default="gpt-4o")
verbose: bool = Field(default=False)
@@ -51,17 +60,24 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
role: str,
goal: str,
backstory: str,
tools: Optional[List[BaseTool]] = None,
tools: list[BaseTool] | None = None,
llm: Any = None,
max_iterations: int = 10,
agent_config: Optional[Dict[str, Any]] = None,
agent_config: dict[str, Any] | None = None,
**kwargs,
):
"""Initialize the LangGraph agent adapter."""
if not LANGGRAPH_AVAILABLE:
raise ImportError(
"LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
)
) -> None:
"""Initialize the LangGraph agent adapter.
Args:
role: The role description for the agent.
goal: The primary goal the agent should achieve.
backstory: Background information about the agent.
tools: Optional list of tools available to the agent.
llm: Language model to use, defaults to gpt-4o.
max_iterations: Maximum number of iterations for task execution.
agent_config: Additional configuration for the LangGraph agent.
**kwargs: Additional arguments passed to the base adapter.
"""
super().__init__(
role=role,
goal=goal,
@@ -72,46 +88,65 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
**kwargs,
)
self._tool_adapter = LangGraphToolAdapter(tools=tools)
self._converter_adapter = LangGraphConverterAdapter(self)
self._converter_adapter: LangGraphConverterAdapter = LangGraphConverterAdapter(
self
)
self._max_iterations = max_iterations
self._setup_graph()
def _setup_graph(self) -> None:
"""Set up the LangGraph workflow graph."""
try:
self._memory = MemorySaver()
"""Set up the LangGraph workflow graph.
converted_tools: List[Any] = self._tool_adapter.tools()
if self._agent_config:
self._graph = create_react_agent(
model=self.llm,
tools=converted_tools,
checkpointer=self._memory,
debug=self.verbose,
**self._agent_config,
)
else:
self._graph = create_react_agent(
model=self.llm,
tools=converted_tools or [],
checkpointer=self._memory,
debug=self.verbose,
)
Initializes the memory saver and creates a ReAct agent with the configured
tools, memory checkpointer, and debug settings.
"""
except ImportError as e:
self._logger.log(
"error", f"Failed to import LangGraph dependencies: {str(e)}"
memory_saver: type[Any] = cast(
LangGraphCheckPointMemoryModule,
require(
"langgraph.checkpoint.memory",
purpose="LangGraph core functionality",
),
).MemorySaver
create_react_agent: Callable[..., Any] = cast(
LangGraphPrebuiltModule,
require(
"langgraph.prebuilt",
purpose="LangGraph core functionality",
),
).create_react_agent
self._memory = memory_saver()
converted_tools: list[Any] = self._tool_adapter.tools()
if self._agent_config:
self._graph = create_react_agent(
model=self.llm,
tools=converted_tools,
checkpointer=self._memory,
debug=self.verbose,
**self._agent_config,
)
else:
self._graph = create_react_agent(
model=self.llm,
tools=converted_tools or [],
checkpointer=self._memory,
debug=self.verbose,
)
raise
except Exception as e:
self._logger.log("error", f"Error setting up LangGraph agent: {str(e)}")
raise
def _build_system_prompt(self) -> str:
"""Build a system prompt for the LangGraph agent."""
"""Build a system prompt for the LangGraph agent.
Creates a prompt that includes the agent's role, goal, and backstory,
then enhances it through the converter adapter for structured output.
Returns:
The complete system prompt string.
"""
base_prompt = f"""
You are {self.role}.
Your goal is: {self.goal}
Your backstory: {self.backstory}
@@ -123,10 +158,25 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
def execute_task(
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task using the LangGraph workflow."""
"""Execute a task using the LangGraph workflow.
Configures the agent, processes the task through the LangGraph workflow,
and handles event emission for execution tracking.
Args:
task: The task object to execute.
context: Optional context information for the task.
tools: Optional additional tools for this specific execution.
Returns:
The final answer from the task execution.
Raises:
Exception: If task execution fails.
"""
self.create_agent_executor(tools)
self.configure_structured_output(task)
@@ -151,9 +201,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
session_id = f"task_{id(task)}"
config = {"configurable": {"thread_id": session_id}}
config: dict[str, dict[str, str]] = {
"configurable": {"thread_id": session_id}
}
result = self._graph.invoke(
result: dict[str, Any] = self._graph.invoke(
{
"messages": [
("system", self._build_system_prompt()),
@@ -163,10 +215,10 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
config,
)
messages = result.get("messages", [])
last_message = messages[-1] if messages else None
messages: list[Any] = result.get("messages", [])
last_message: Any = messages[-1] if messages else None
final_answer = ""
final_answer: str = ""
if isinstance(last_message, dict):
final_answer = last_message.get("content", "")
elif hasattr(last_message, "content"):
@@ -186,7 +238,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
return final_answer
except Exception as e:
self._logger.log("error", f"Error executing LangGraph task: {str(e)}")
self._logger.log("error", f"Error executing LangGraph task: {e!s}")
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -197,29 +249,67 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
)
raise
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
"""Configure the LangGraph agent for execution."""
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
"""Configure the LangGraph agent for execution.
Args:
tools: Optional tools to configure for the agent.
"""
self.configure_tools(tools)
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
"""Configure tools for the LangGraph agent."""
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
"""Configure tools for the LangGraph agent.
Merges additional tools with existing ones and updates the graph's
available tools through the tool adapter.
Args:
tools: Optional additional tools to configure.
"""
if tools:
all_tools = list(self.tools or []) + list(tools or [])
all_tools: list[BaseTool] = list(self.tools or []) + list(tools or [])
self._tool_adapter.configure_tools(all_tools)
available_tools = self._tool_adapter.tools()
available_tools: list[Any] = self._tool_adapter.tools()
self._graph.tools = available_tools
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
"""Implement delegation tools support for LangGraph."""
agent_tools = AgentTools(agents=agents)
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
"""Implement delegation tools support for LangGraph.
Creates delegation tools that allow this agent to delegate tasks to other agents.
Args:
agents: List of agents available for delegation.
Returns:
List of delegation tools.
"""
agent_tools: AgentTools = AgentTools(agents=agents)
return agent_tools.tools()
@staticmethod
def get_output_converter(
self, llm: Any, text: str, model: Any, instructions: str
) -> Any:
"""Convert output format if needed."""
llm: Any, text: str, model: Any, instructions: str
) -> Converter:
"""Convert output format if needed.
Args:
llm: Language model instance.
text: Text to convert.
model: Model configuration.
instructions: Conversion instructions.
Returns:
Converter instance for output transformation.
"""
return Converter(llm=llm, text=text, model=model, instructions=instructions)
def configure_structured_output(self, task) -> None:
"""Configure the structured output for LangGraph."""
def configure_structured_output(self, task: Any) -> None:
"""Configure the structured output for LangGraph.
Uses the converter adapter to set up structured output formatting
based on the task requirements.
Args:
task: Task object containing output requirements.
"""
self._converter_adapter.configure_structured_output(task)

View File

@@ -1,38 +1,72 @@
"""LangGraph tool adapter for CrewAI tool integration.
This module contains the LangGraphToolAdapter class that converts CrewAI tools
to LangGraph-compatible format using langchain_core.tools.
"""
import inspect
from typing import Any, List, Optional
from collections.abc import Awaitable
from typing import Any
from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter
from crewai.tools.base_tool import BaseTool
class LangGraphToolAdapter(BaseToolAdapter):
"""Adapts CrewAI tools to LangGraph agent tool compatible format"""
"""Adapts CrewAI tools to LangGraph agent tool compatible format.
def __init__(self, tools: Optional[List[BaseTool]] = None):
self.original_tools = tools or []
self.converted_tools = []
Converts CrewAI BaseTool instances to langchain_core.tools format
that can be used by LangGraph agents.
"""
def configure_tools(self, tools: List[BaseTool]) -> None:
def __init__(self, tools: list[BaseTool] | None = None) -> None:
"""Initialize the tool adapter.
Args:
tools: Optional list of CrewAI tools to adapt.
"""
Configure and convert CrewAI tools to LangGraph-compatible format.
LangGraph expects tools in langchain_core.tools format.
"""
from langchain_core.tools import BaseTool, StructuredTool
super().__init__()
self.original_tools: list[BaseTool] = tools or []
self.converted_tools: list[Any] = []
converted_tools = []
def configure_tools(self, tools: list[BaseTool]) -> None:
"""Configure and convert CrewAI tools to LangGraph-compatible format.
LangGraph expects tools in langchain_core.tools format. This method
converts CrewAI BaseTool instances to StructuredTool instances.
Args:
tools: List of CrewAI tools to convert.
"""
from langchain_core.tools import BaseTool as LangChainBaseTool
from langchain_core.tools import StructuredTool
converted_tools: list[Any] = []
if self.original_tools:
all_tools = tools + self.original_tools
all_tools: list[BaseTool] = tools + self.original_tools
else:
all_tools = tools
for tool in all_tools:
if isinstance(tool, BaseTool):
if isinstance(tool, LangChainBaseTool):
converted_tools.append(tool)
continue
sanitized_name = self.sanitize_tool_name(tool.name)
sanitized_name: str = self.sanitize_tool_name(tool.name)
async def tool_wrapper(*args, tool=tool, **kwargs):
output = None
async def tool_wrapper(
*args: Any, tool: BaseTool = tool, **kwargs: Any
) -> Any:
"""Wrapper function to adapt CrewAI tool calls to LangGraph format.
Args:
*args: Positional arguments for the tool.
tool: The CrewAI tool to wrap.
**kwargs: Keyword arguments for the tool.
Returns:
The result from the tool execution.
"""
output: Any | Awaitable[Any]
if len(args) > 0 and isinstance(args[0], str):
output = tool.run(args[0])
elif "input" in kwargs:
@@ -41,12 +75,12 @@ class LangGraphToolAdapter(BaseToolAdapter):
output = tool.run(**kwargs)
if inspect.isawaitable(output):
result = await output
result: Any = await output
else:
result = output
return result
converted_tool = StructuredTool(
converted_tool: StructuredTool = StructuredTool(
name=sanitized_name,
description=tool.description,
func=tool_wrapper,
@@ -57,5 +91,10 @@ class LangGraphToolAdapter(BaseToolAdapter):
self.converted_tools = converted_tools
def tools(self) -> List[Any]:
def tools(self) -> list[Any]:
"""Get the list of converted tools.
Returns:
List of LangGraph-compatible tools.
"""
return self.converted_tools or []

View File

@@ -0,0 +1,55 @@
"""Type protocols for LangGraph modules."""
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class LangGraphMemorySaver(Protocol):
"""Protocol for LangGraph MemorySaver.
Defines the interface for LangGraph's memory persistence mechanism.
"""
def __init__(self) -> None:
"""Initialize the memory saver."""
...
@runtime_checkable
class LangGraphCheckPointMemoryModule(Protocol):
"""Protocol for LangGraph checkpoint memory module.
Defines the interface for modules containing memory checkpoint functionality.
"""
MemorySaver: type[LangGraphMemorySaver]
@runtime_checkable
class LangGraphPrebuiltModule(Protocol):
"""Protocol for LangGraph prebuilt module.
Defines the interface for modules containing prebuilt agent factories.
"""
def create_react_agent(
self,
model: Any,
tools: list[Any],
checkpointer: Any,
debug: bool = False,
**kwargs: Any,
) -> Any:
"""Create a ReAct agent with the given configuration.
Args:
model: The language model to use for the agent.
tools: List of tools available to the agent.
checkpointer: Memory checkpointer for state persistence.
debug: Whether to enable debug mode.
**kwargs: Additional configuration options.
Returns:
The configured ReAct agent instance.
"""
...

View File

@@ -1,21 +1,45 @@
"""LangGraph structured output converter for CrewAI task integration.
This module contains the LangGraphConverterAdapter class that handles structured
output conversion for LangGraph agents, supporting JSON and Pydantic model formats.
"""
import json
import re
from typing import Any, Literal
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
from crewai.utilities.converter import generate_model_description
class LangGraphConverterAdapter(BaseConverterAdapter):
"""Adapter for handling structured output conversion in LangGraph agents"""
"""Adapter for handling structured output conversion in LangGraph agents.
def __init__(self, agent_adapter):
"""Initialize the converter adapter with a reference to the agent adapter"""
self.agent_adapter = agent_adapter
self._output_format = None
self._schema = None
self._system_prompt_appendix = None
Converts task output requirements into system prompt modifications and
post-processing logic to ensure agents return properly structured outputs.
"""
def configure_structured_output(self, task) -> None:
"""Configure the structured output for LangGraph."""
def __init__(self, agent_adapter: Any) -> None:
"""Initialize the converter adapter with a reference to the agent adapter.
Args:
agent_adapter: The LangGraph agent adapter instance.
"""
super().__init__(agent_adapter=agent_adapter)
self.agent_adapter: Any = agent_adapter
self._output_format: Literal["json", "pydantic"] | None = None
self._schema: str | None = None
self._system_prompt_appendix: str | None = None
def configure_structured_output(self, task: Any) -> None:
"""Configure the structured output for LangGraph.
Analyzes the task's output requirements and sets up the necessary
formatting and validation logic.
Args:
task: The task object containing output format specifications.
"""
if not (task.output_json or task.output_pydantic):
self._output_format = None
self._schema = None
@@ -32,7 +56,14 @@ class LangGraphConverterAdapter(BaseConverterAdapter):
self._system_prompt_appendix = self._generate_system_prompt_appendix()
def _generate_system_prompt_appendix(self) -> str:
"""Generate an appendix for the system prompt to enforce structured output"""
"""Generate an appendix for the system prompt to enforce structured output.
Creates instructions that are appended to the system prompt to guide
the agent in producing properly formatted output.
Returns:
System prompt appendix string, or empty string if no structured output.
"""
if not self._output_format or not self._schema:
return ""
@@ -41,19 +72,36 @@ Important: Your final answer MUST be provided in the following structured format
{self._schema}
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
The output should be raw JSON that exactly matches the specified schema.
"""
def enhance_system_prompt(self, original_prompt: str) -> str:
"""Add structured output instructions to the system prompt if needed"""
"""Add structured output instructions to the system prompt if needed.
Args:
original_prompt: The base system prompt.
Returns:
Enhanced system prompt with structured output instructions.
"""
if not self._system_prompt_appendix:
return original_prompt
return f"{original_prompt}\n{self._system_prompt_appendix}"
def post_process_result(self, result: str) -> str:
"""Post-process the result to ensure it matches the expected format"""
"""Post-process the result to ensure it matches the expected format.
Attempts to extract and validate JSON content from agent responses,
handling cases where JSON may be wrapped in markdown or other formatting.
Args:
result: The raw result string from the agent.
Returns:
Processed result string, ideally in valid JSON format.
"""
if not self._output_format:
return result
@@ -65,16 +113,16 @@ The output should be raw JSON that exactly matches the specified schema.
return result
except json.JSONDecodeError:
# Try to extract JSON from the text
import re
json_match = re.search(r"(\{.*\})", result, re.DOTALL)
json_match: re.Match[str] | None = re.search(
r"(\{.*})", result, re.DOTALL
)
if json_match:
try:
extracted = json_match.group(1)
extracted: str = json_match.group(1)
# Validate it's proper JSON
json.loads(extracted)
return extracted
except:
except json.JSONDecodeError:
pass
return result

View File

@@ -1,78 +1,99 @@
from typing import Any, List, Optional
"""OpenAI agents adapter for CrewAI integration.
from pydantic import Field, PrivateAttr
This module contains the OpenAIAgentAdapter class that integrates OpenAI Assistants
with CrewAI's agent system, providing tool integration and structured output support.
"""
from typing import Any, cast
from pydantic import ConfigDict, Field, PrivateAttr
from typing_extensions import Unpack
from crewai.agents.agent_adapters.base_agent_adapter import BaseAgentAdapter
from crewai.agents.agent_adapters.openai_agents.openai_agent_tool_adapter import (
OpenAIAgentToolAdapter,
)
from crewai.agents.agent_adapters.openai_agents.protocols import (
AgentKwargs,
OpenAIAgentsModule,
)
from crewai.agents.agent_adapters.openai_agents.protocols import (
OpenAIAgent as OpenAIAgentProtocol,
)
from crewai.agents.agent_adapters.openai_agents.structured_output_converter import (
OpenAIConverterAdapter,
)
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tools import BaseTool
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.utilities import Logger
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.tools import BaseTool
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.utilities import Logger
from crewai.utilities.import_utils import require
try:
from agents import Agent as OpenAIAgent # type: ignore
from agents import Runner, enable_verbose_stdout_logging # type: ignore
from .openai_agent_tool_adapter import OpenAIAgentToolAdapter
OPENAI_AVAILABLE = True
except ImportError:
OPENAI_AVAILABLE = False
openai_agents_module = cast(
OpenAIAgentsModule,
require(
"agents",
purpose="OpenAI agents functionality",
),
)
OpenAIAgent = openai_agents_module.Agent
Runner = openai_agents_module.Runner
enable_verbose_stdout_logging = openai_agents_module.enable_verbose_stdout_logging
class OpenAIAgentAdapter(BaseAgentAdapter):
"""Adapter for OpenAI Assistants"""
"""Adapter for OpenAI Assistants.
model_config = {"arbitrary_types_allowed": True}
Integrates OpenAI Assistants API with CrewAI's agent system, providing
tool configuration, structured output handling, and task execution.
"""
_openai_agent: "OpenAIAgent" = PrivateAttr()
_logger: Logger = PrivateAttr(default_factory=lambda: Logger())
_active_thread: Optional[str] = PrivateAttr(default=None)
model_config = ConfigDict(arbitrary_types_allowed=True)
_openai_agent: OpenAIAgentProtocol = PrivateAttr()
_logger: Logger = PrivateAttr(default_factory=Logger)
_active_thread: str | None = PrivateAttr(default=None)
function_calling_llm: Any = Field(default=None)
step_callback: Any = Field(default=None)
_tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr()
_tool_adapter: OpenAIAgentToolAdapter = PrivateAttr()
_converter_adapter: OpenAIConverterAdapter = PrivateAttr()
def __init__(
self,
model: str = "gpt-4o-mini",
tools: Optional[List[BaseTool]] = None,
agent_config: Optional[dict] = None,
**kwargs,
):
if not OPENAI_AVAILABLE:
raise ImportError(
"OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
)
else:
role = kwargs.pop("role", None)
goal = kwargs.pop("goal", None)
backstory = kwargs.pop("backstory", None)
super().__init__(
role=role,
goal=goal,
backstory=backstory,
tools=tools,
agent_config=agent_config,
**kwargs,
)
self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
self.llm = model
self._converter_adapter = OpenAIConverterAdapter(self)
**kwargs: Unpack[AgentKwargs],
) -> None:
"""Initialize the OpenAI agent adapter.
Args:
**kwargs: All initialization arguments including role, goal, backstory,
model, tools, and agent_config.
Raises:
ImportError: If OpenAI agent dependencies are not installed.
"""
self.llm = kwargs.pop("model", "gpt-4o-mini")
super().__init__(**kwargs)
self._tool_adapter = OpenAIAgentToolAdapter(tools=kwargs.get("tools"))
self._converter_adapter = OpenAIConverterAdapter(agent_adapter=self)
def _build_system_prompt(self) -> str:
"""Build a system prompt for the OpenAI agent."""
"""Build a system prompt for the OpenAI agent.
Creates a prompt containing the agent's role, goal, and backstory,
then enhances it with structured output instructions if needed.
Returns:
The complete system prompt string.
"""
base_prompt = f"""
You are {self.role}.
Your goal is: {self.goal}
Your backstory: {self.backstory}
@@ -84,10 +105,25 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
def execute_task(
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task using the OpenAI Assistant"""
"""Execute a task using the OpenAI Assistant.
Configures the assistant, processes the task, and handles event emission
for execution tracking.
Args:
task: The task object to execute.
context: Optional context information for the task.
tools: Optional additional tools for this execution.
Returns:
The final answer from the task execution.
Raises:
Exception: If task execution fails.
"""
self._converter_adapter.configure_structured_output(task)
self.create_agent_executor(tools)
@@ -95,7 +131,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
enable_verbose_stdout_logging()
try:
task_prompt = task.prompt()
task_prompt: str = task.prompt()
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context
@@ -109,8 +145,8 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
task=task,
),
)
result = self.agent_executor.run_sync(self._openai_agent, task_prompt)
final_answer = self.handle_execution_result(result)
result: Any = self.agent_executor.run_sync(self._openai_agent, task_prompt)
final_answer: str = self.handle_execution_result(result)
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(
@@ -120,7 +156,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
return final_answer
except Exception as e:
self._logger.log("error", f"Error executing OpenAI task: {str(e)}")
self._logger.log("error", f"Error executing OpenAI task: {e!s}")
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -131,15 +167,22 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
)
raise
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
"""
Configure the OpenAI agent for execution.
While OpenAI handles execution differently through Runner,
we can use this method to set up tools and configurations.
"""
all_tools = list(self.tools or []) + list(tools or [])
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
"""Configure the OpenAI agent for execution.
instructions = self._build_system_prompt()
While OpenAI handles execution differently through Runner,
this method sets up tools and agent configuration.
Args:
tools: Optional tools to configure for the agent.
Notes:
TODO: Properly type agent_executor in BaseAgent to avoid type issues
when assigning Runner class to this attribute.
"""
all_tools: list[BaseTool] = list(self.tools or []) + list(tools or [])
instructions: str = self._build_system_prompt()
self._openai_agent = OpenAIAgent(
name=self.role,
instructions=instructions,
@@ -152,27 +195,48 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
self.agent_executor = Runner
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
"""Configure tools for the OpenAI Assistant"""
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
"""Configure tools for the OpenAI Assistant.
Args:
tools: Optional tools to configure for the assistant.
"""
if tools:
self._tool_adapter.configure_tools(tools)
if self._tool_adapter.converted_tools:
self._openai_agent.tools = self._tool_adapter.converted_tools
def handle_execution_result(self, result: Any) -> str:
"""Process OpenAI Assistant execution result converting any structured output to a string"""
"""Process OpenAI Assistant execution result.
Converts any structured output to a string through the converter adapter.
Args:
result: The execution result from the OpenAI assistant.
Returns:
Processed result as a string.
"""
return self._converter_adapter.post_process_result(result.final_output)
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
"""Implement delegation tools support"""
agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools()
return tools
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
"""Implement delegation tools support.
def configure_structured_output(self, task) -> None:
Creates delegation tools that allow this agent to delegate tasks to other agents.
Args:
agents: List of agents available for delegation.
Returns:
List of delegation tools.
"""
agent_tools: AgentTools = AgentTools(agents=agents)
return agent_tools.tools()
def configure_structured_output(self, task: Any) -> None:
"""Configure the structured output for the specific agent implementation.
Args:
structured_output: The structured output to be configured
task: The task object containing output format specifications.
"""
self._converter_adapter.configure_structured_output(task)

View File

@@ -1,57 +1,125 @@
import inspect
from typing import Any, List, Optional
"""OpenAI agent tool adapter for CrewAI tool integration.
from agents import FunctionTool, Tool
This module contains the OpenAIAgentToolAdapter class that converts CrewAI tools
to OpenAI Assistant-compatible format using the agents library.
"""
import inspect
import json
import re
from collections.abc import Awaitable
from typing import Any, cast
from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter
from crewai.agents.agent_adapters.openai_agents.protocols import (
OpenAIFunctionTool,
OpenAITool,
)
from crewai.tools import BaseTool
from crewai.utilities.import_utils import require
agents_module = cast(
Any,
require(
"agents",
purpose="OpenAI agents functionality",
),
)
FunctionTool = agents_module.FunctionTool
Tool = agents_module.Tool
class OpenAIAgentToolAdapter(BaseToolAdapter):
"""Adapter for OpenAI Assistant tools"""
"""Adapter for OpenAI Assistant tools.
def __init__(self, tools: Optional[List[BaseTool]] = None):
self.original_tools = tools or []
Converts CrewAI BaseTool instances to OpenAI Assistant FunctionTool format
that can be used by OpenAI agents.
"""
def configure_tools(self, tools: List[BaseTool]) -> None:
"""Configure tools for the OpenAI Assistant"""
def __init__(self, tools: list[BaseTool] | None = None) -> None:
"""Initialize the tool adapter.
Args:
tools: Optional list of CrewAI tools to adapt.
"""
super().__init__()
self.original_tools: list[BaseTool] = tools or []
self.converted_tools: list[OpenAITool] = []
def configure_tools(self, tools: list[BaseTool]) -> None:
"""Configure tools for the OpenAI Assistant.
Merges provided tools with original tools and converts them to
OpenAI Assistant format.
Args:
tools: List of CrewAI tools to configure.
"""
if self.original_tools:
all_tools = tools + self.original_tools
all_tools: list[BaseTool] = tools + self.original_tools
else:
all_tools = tools
if all_tools:
self.converted_tools = self._convert_tools_to_openai_format(all_tools)
@staticmethod
def _convert_tools_to_openai_format(
self, tools: Optional[List[BaseTool]]
) -> List[Tool]:
"""Convert CrewAI tools to OpenAI Assistant tool format"""
tools: list[BaseTool] | None,
) -> list[OpenAITool]:
"""Convert CrewAI tools to OpenAI Assistant tool format.
Args:
tools: List of CrewAI tools to convert.
Returns:
List of OpenAI Assistant FunctionTool instances.
"""
if not tools:
return []
def sanitize_tool_name(name: str) -> str:
"""Convert tool name to match OpenAI's required pattern"""
import re
"""Convert tool name to match OpenAI's required pattern.
sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
return sanitized
Args:
name: Original tool name.
def create_tool_wrapper(tool: BaseTool):
"""Create a wrapper function that handles the OpenAI function tool interface"""
Returns:
Sanitized tool name matching OpenAI requirements.
"""
return re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
def create_tool_wrapper(tool: BaseTool) -> Any:
"""Create a wrapper function that handles the OpenAI function tool interface.
Args:
tool: The CrewAI tool to wrap.
Returns:
Async wrapper function for OpenAI agent integration.
"""
async def wrapper(context_wrapper: Any, arguments: Any) -> Any:
"""Wrapper function to adapt CrewAI tool calls to OpenAI format.
Args:
context_wrapper: OpenAI context wrapper.
arguments: Tool arguments from OpenAI.
Returns:
Tool execution result.
"""
# Get the parameter name from the schema
param_name = list(
tool.args_schema.model_json_schema()["properties"].keys()
)[0]
param_name: str = next(
iter(tool.args_schema.model_json_schema()["properties"].keys())
)
# Handle different argument types
args_dict: dict[str, Any]
if isinstance(arguments, dict):
args_dict = arguments
elif isinstance(arguments, str):
try:
import json
args_dict = json.loads(arguments)
except json.JSONDecodeError:
args_dict = {param_name: arguments}
@@ -59,11 +127,11 @@ class OpenAIAgentToolAdapter(BaseToolAdapter):
args_dict = {param_name: str(arguments)}
# Run the tool with the processed arguments
output = tool._run(**args_dict)
output: Any | Awaitable[Any] = tool._run(**args_dict)
# Await if the tool returned a coroutine
if inspect.isawaitable(output):
result = await output
result: Any = await output
else:
result = output
@@ -74,17 +142,20 @@ class OpenAIAgentToolAdapter(BaseToolAdapter):
return wrapper
openai_tools = []
openai_tools: list[OpenAITool] = []
for tool in tools:
schema = tool.args_schema.model_json_schema()
schema: dict[str, Any] = tool.args_schema.model_json_schema()
schema.update({"additionalProperties": False, "type": "object"})
openai_tool = FunctionTool(
name=sanitize_tool_name(tool.name),
description=tool.description,
params_json_schema=schema,
on_invoke_tool=create_tool_wrapper(tool),
openai_tool: OpenAIFunctionTool = cast(
OpenAIFunctionTool,
FunctionTool(
name=sanitize_tool_name(tool.name),
description=tool.description,
params_json_schema=schema,
on_invoke_tool=create_tool_wrapper(tool),
),
)
openai_tools.append(openai_tool)

View File

@@ -0,0 +1,74 @@
"""Type protocols for OpenAI agents modules."""
from collections.abc import Callable
from typing import Any, Protocol, TypedDict, runtime_checkable
from crewai.tools.base_tool import BaseTool
class AgentKwargs(TypedDict, total=False):
"""Typed dict for agent initialization kwargs."""
role: str
goal: str
backstory: str
model: str
tools: list[BaseTool] | None
agent_config: dict[str, Any] | None
@runtime_checkable
class OpenAIAgent(Protocol):
"""Protocol for OpenAI Agent."""
def __init__(
self,
name: str,
instructions: str,
model: str,
**kwargs: Any,
) -> None:
"""Initialize the OpenAI agent."""
...
tools: list[Any]
output_type: Any
@runtime_checkable
class OpenAIRunner(Protocol):
"""Protocol for OpenAI Runner."""
@classmethod
def run_sync(cls, agent: OpenAIAgent, message: str) -> Any:
"""Run agent synchronously with a message."""
...
@runtime_checkable
class OpenAIAgentsModule(Protocol):
"""Protocol for OpenAI agents module."""
Agent: type[OpenAIAgent]
Runner: type[OpenAIRunner]
enable_verbose_stdout_logging: Callable[[], None]
@runtime_checkable
class OpenAITool(Protocol):
"""Protocol for OpenAI Tool."""
@runtime_checkable
class OpenAIFunctionTool(Protocol):
"""Protocol for OpenAI FunctionTool."""
def __init__(
self,
name: str,
description: str,
params_json_schema: dict[str, Any],
on_invoke_tool: Any,
) -> None:
"""Initialize the function tool."""
...

View File

@@ -1,5 +1,12 @@
"""OpenAI structured output converter for CrewAI task integration.
This module contains the OpenAIConverterAdapter class that handles structured
output conversion for OpenAI agents, supporting JSON and Pydantic model formats.
"""
import json
import re
from typing import Any, Literal
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
from crewai.utilities.converter import generate_model_description
@@ -7,8 +14,7 @@ from crewai.utilities.i18n import I18N
class OpenAIConverterAdapter(BaseConverterAdapter):
"""
Adapter for handling structured output conversion in OpenAI agents.
"""Adapter for handling structured output conversion in OpenAI agents.
This adapter enhances the OpenAI agent to handle structured output formats
and post-processes the results when needed.
@@ -19,19 +25,23 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
_output_model: The Pydantic model for the output
"""
def __init__(self, agent_adapter):
"""Initialize the converter adapter with a reference to the agent adapter"""
self.agent_adapter = agent_adapter
self._output_format = None
self._schema = None
self._output_model = None
def configure_structured_output(self, task) -> None:
"""
Configure the structured output for OpenAI agent based on task requirements.
def __init__(self, agent_adapter: Any) -> None:
"""Initialize the converter adapter with a reference to the agent adapter.
Args:
task: The task containing output format requirements
agent_adapter: The OpenAI agent adapter instance.
"""
super().__init__(agent_adapter=agent_adapter)
self.agent_adapter: Any = agent_adapter
self._output_format: Literal["json", "pydantic"] | None = None
self._schema: str | None = None
self._output_model: Any = None
def configure_structured_output(self, task: Any) -> None:
"""Configure the structured output for OpenAI agent based on task requirements.
Args:
task: The task containing output format requirements.
"""
# Reset configuration
self._output_format = None
@@ -55,19 +65,18 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
self._output_model = task.output_pydantic
def enhance_system_prompt(self, base_prompt: str) -> str:
"""
Enhance the base system prompt with structured output requirements if needed.
"""Enhance the base system prompt with structured output requirements if needed.
Args:
base_prompt: The original system prompt
base_prompt: The original system prompt.
Returns:
Enhanced system prompt with output format instructions if needed
Enhanced system prompt with output format instructions if needed.
"""
if not self._output_format:
return base_prompt
output_schema = (
output_schema: str = (
I18N()
.slice("formatted_task_instructions")
.format(output_format=self._schema)
@@ -76,16 +85,15 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
return f"{base_prompt}\n\n{output_schema}"
def post_process_result(self, result: str) -> str:
"""
Post-process the result to ensure it matches the expected format.
"""Post-process the result to ensure it matches the expected format.
This method attempts to extract valid JSON from the result if necessary.
Args:
result: The raw result from the agent
result: The raw result from the agent.
Returns:
Processed result conforming to the expected output format
Processed result conforming to the expected output format.
"""
if not self._output_format:
return result
@@ -97,26 +105,30 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
return result
except json.JSONDecodeError:
# Try to extract JSON from markdown code blocks
code_block_pattern = r"```(?:json)?\s*([\s\S]*?)```"
code_blocks = re.findall(code_block_pattern, result)
code_block_pattern: str = r"```(?:json)?\s*([\s\S]*?)```"
code_blocks: list[str] = re.findall(code_block_pattern, result)
for block in code_blocks:
stripped_block = block.strip()
try:
json.loads(block.strip())
return block.strip()
json.loads(stripped_block)
return stripped_block
except json.JSONDecodeError:
continue
pass
# Try to extract any JSON-like structure
json_pattern = r"(\{[\s\S]*\})"
json_matches = re.findall(json_pattern, result, re.DOTALL)
json_pattern: str = r"(\{[\s\S]*\})"
json_matches: list[str] = re.findall(json_pattern, result, re.DOTALL)
for match in json_matches:
is_valid = True
try:
json.loads(match)
return match
except json.JSONDecodeError:
continue
is_valid = False
if is_valid:
return match
# If all extraction attempts fail, return the original
return str(result)

src/crewai/context.py (new file, 25 lines)
View File

@@ -0,0 +1,25 @@
import os
import contextvars
from typing import Optional
from contextlib import contextmanager
_platform_integration_token: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"platform_integration_token", default=None
)
def set_platform_integration_token(integration_token: str) -> None:
_platform_integration_token.set(integration_token)
def get_platform_integration_token() -> Optional[str]:
token = _platform_integration_token.get()
if token is None:
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
return token
@contextmanager
def platform_context(integration_token: str):
token = _platform_integration_token.set(integration_token)
try:
yield
finally:
_platform_integration_token.reset(token)
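A minimal sketch of how the new context helpers compose (everything here comes from the file above; the token values are illustrative):

import os
from crewai.context import get_platform_integration_token, platform_context

# Scoped override: the previous token is restored when the block exits.
with platform_context("token-123"):
    assert get_platform_integration_token() == "token-123"

# Outside any override, the getter falls back to the environment variable.
os.environ["CREWAI_PLATFORM_INTEGRATION_TOKEN"] = "env-token"
assert get_platform_integration_token() == "env-token"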


@@ -3,26 +3,17 @@ import json
import re
import uuid
import warnings
from collections.abc import Callable
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from opentelemetry import baggage
from opentelemetry.context import attach, detach
from crewai.utilities.crew.models import CrewContext
from pydantic import (
UUID4,
BaseModel,
@@ -39,26 +30,15 @@ from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM, BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
from crewai.security import Fingerprint, SecurityConfig
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
should_auto_collect_first_time_traces,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
@@ -70,16 +50,28 @@ from crewai.events.types.crew_events import (
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
)
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM, BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
from crewai.rag.types import SearchResult
from crewai.security import Fingerprint, SecurityConfig
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.crew.models import CrewContext
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.formatter import (
aggregate_raw_outputs_from_task_outputs,
aggregate_raw_outputs_from_tasks,
@@ -94,28 +86,40 @@ warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
class Crew(FlowTrackable, BaseModel):
"""
Represents a group of agents, defining how they should collaborate and the tasks they should perform.
Represents a group of agents, defining how they should collaborate and the
tasks they should perform.
Attributes:
tasks: List of tasks assigned to the crew.
agents: List of agents part of this crew.
tasks: list of tasks assigned to the crew.
agents: list of agents part of this crew.
manager_llm: The language model that will run manager agent.
manager_agent: Custom agent that will be used as manager.
memory: Whether the crew should use memory to store memories of its execution.
cache: Whether the crew should use a cache to store the results of the tools execution.
function_calling_llm: The language model that will run the tool calling for all the agents.
process: The process flow that the crew will follow (e.g., sequential, hierarchical).
memory: Whether the crew should use memory to store memories of its
execution.
cache: Whether the crew should use a cache to store the results of the
tools execution.
function_calling_llm: The language model that will run the tool calling
for all the agents.
process: The process flow that the crew will follow (e.g., sequential,
hierarchical).
verbose: Indicates the verbosity level for logging during execution.
config: Configuration settings for the crew.
max_rpm: Maximum number of requests per minute for the crew execution to be respected.
max_rpm: Maximum number of requests per minute for the crew execution to
be respected.
prompt_file: Path to the prompt json file to be used for the crew.
id: A unique identifier for the crew instance.
task_callback: Callback to be executed after each task for every agent's execution.
step_callback: Callback to be executed after each step for every agent's execution.
share_crew: Whether you want to share the complete crew information and execution with crewAI to make the library better, and allow us to train models.
task_callback: Callback to be executed after each task for every agent's
execution.
step_callback: Callback to be executed after each step for every agent's
execution.
share_crew: Whether you want to share the complete crew information and
execution with crewAI to make the library better, and allow us to
train models.
planning: Plan the crew execution and add the plan to the crew.
chat_llm: The language model used for orchestrating chat interactions with the crew.
security_config: Security configuration for the crew, including fingerprinting.
chat_llm: The language model used for orchestrating chat interactions
with the crew.
security_config: Security configuration for the crew, including
fingerprinting.
"""
__hash__ = object.__hash__ # type: ignore
@@ -124,13 +128,13 @@ class Crew(FlowTrackable, BaseModel):
_logger: Logger = PrivateAttr()
_file_handler: FileHandler = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
_short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
_long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
_entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr()
_external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr()
_train: Optional[bool] = PrivateAttr(default=False)
_train_iteration: Optional[int] = PrivateAttr()
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_short_term_memory: InstanceOf[ShortTermMemory] | None = PrivateAttr()
_long_term_memory: InstanceOf[LongTermMemory] | None = PrivateAttr()
_entity_memory: InstanceOf[EntityMemory] | None = PrivateAttr()
_external_memory: InstanceOf[ExternalMemory] | None = PrivateAttr()
_train: bool | None = PrivateAttr(default=False)
_train_iteration: int | None = PrivateAttr()
_inputs: dict[str, Any] | None = PrivateAttr(default=None)
_logging_color: str = PrivateAttr(
default="bold_purple",
)
@@ -138,107 +142,121 @@ class Crew(FlowTrackable, BaseModel):
default_factory=TaskOutputStorageHandler
)
name: Optional[str] = Field(default="crew")
name: str | None = Field(default="crew")
cache: bool = Field(default=True)
tasks: List[Task] = Field(default_factory=list)
agents: List[BaseAgent] = Field(default_factory=list)
tasks: list[Task] = Field(default_factory=list)
agents: list[BaseAgent] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
verbose: bool = Field(default=False)
memory: bool = Field(
default=False,
description="Whether the crew should use memory to store memories of it's execution",
description="If crew should use memory to store memories of it's execution",
)
short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field(
short_term_memory: InstanceOf[ShortTermMemory] | None = Field(
default=None,
description="An Instance of the ShortTermMemory to be used by the Crew",
)
long_term_memory: Optional[InstanceOf[LongTermMemory]] = Field(
long_term_memory: InstanceOf[LongTermMemory] | None = Field(
default=None,
description="An Instance of the LongTermMemory to be used by the Crew",
)
entity_memory: Optional[InstanceOf[EntityMemory]] = Field(
entity_memory: InstanceOf[EntityMemory] | None = Field(
default=None,
description="An Instance of the EntityMemory to be used by the Crew",
)
external_memory: Optional[InstanceOf[ExternalMemory]] = Field(
external_memory: InstanceOf[ExternalMemory] | None = Field(
default=None,
description="An Instance of the ExternalMemory to be used by the Crew",
)
embedder: Optional[dict] = Field(
embedder: dict | None = Field(
default=None,
description="Configuration for the embedder to be used for the crew.",
)
usage_metrics: Optional[UsageMetrics] = Field(
usage_metrics: UsageMetrics | None = Field(
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: Optional[BaseAgent] = Field(
manager_agent: BaseAgent | None = Field(
description="Custom agent that will be used as manager.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
description="Language model that will run the agent.", default=None
)
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
config: Json | dict[str, Any] | None = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
share_crew: Optional[bool] = Field(default=False)
step_callback: Optional[Any] = Field(
share_crew: bool | None = Field(default=False)
step_callback: Any | None = Field(
default=None,
description="Callback to be executed after each step for all agents execution.",
)
task_callback: Optional[Any] = Field(
task_callback: Any | None = Field(
default=None,
description="Callback to be executed after each task for all agents execution.",
)
before_kickoff_callbacks: List[
Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]]
before_kickoff_callbacks: list[
Callable[[dict[str, Any] | None], dict[str, Any] | None]
] = Field(
default_factory=list,
description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.",
description=(
"List of callbacks to be executed before crew kickoff. "
"It may be used to adjust inputs before the crew is executed."
),
)
after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field(
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
default_factory=list,
description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.",
description=(
"List of callbacks to be executed after crew kickoff. "
"It may be used to adjust the output of the crew."
),
)
max_rpm: Optional[int] = Field(
max_rpm: int | None = Field(
default=None,
description="Maximum number of requests per minute for the crew execution to be respected.",
description=(
"Maximum number of requests per minute for the crew execution "
"to be respected."
),
)
prompt_file: Optional[str] = Field(
prompt_file: str | None = Field(
default=None,
description="Path to the prompt json file to be used for the crew.",
)
output_log_file: Optional[Union[bool, str]] = Field(
output_log_file: bool | str | None = Field(
default=None,
description="Path to the log file to be saved",
)
planning: Optional[bool] = Field(
planning: bool | None = Field(
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
default=None,
description="Language model that will run the AgentPlanner if planning is True.",
description=(
"Language model that will run the AgentPlanner if planning is True."
),
)
task_execution_output_json_files: Optional[List[str]] = Field(
task_execution_output_json_files: list[str] | None = Field(
default=None,
description="List of file paths for task execution JSON files.",
description="list of file paths for task execution JSON files.",
)
execution_logs: List[Dict[str, Any]] = Field(
execution_logs: list[dict[str, Any]] = Field(
default=[],
description="List of execution logs for tasks",
description="list of execution logs for tasks",
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
default=None,
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
description=(
"Knowledge sources for the crew. Add knowledge sources to the "
"knowledge object."
),
)
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
knowledge: Optional[Knowledge] = Field(
knowledge: Knowledge | None = Field(
default=None,
description="Knowledge for the crew.",
)
@@ -246,18 +264,18 @@ class Crew(FlowTrackable, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the crew, including fingerprinting.",
)
token_usage: Optional[UsageMetrics] = Field(
token_usage: UsageMetrics | None = Field(
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
tracing: Optional[bool] = Field(
tracing: bool | None = Field(
default=False,
description="Whether to enable tracing for the crew.",
)
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
"""Prevent manual setting of the 'id' field by users."""
if v:
raise PydanticCustomError(
@@ -266,9 +284,7 @@ class Crew(FlowTrackable, BaseModel):
@field_validator("config", mode="before")
@classmethod
def check_config_type(
cls, v: Union[Json, Dict[str, Any]]
) -> Union[Json, Dict[str, Any]]:
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
"""Validates that the config is a valid type.
Args:
v: The config to be validated.
@@ -281,12 +297,16 @@ class Crew(FlowTrackable, BaseModel):
@model_validator(mode="after")
def set_private_attrs(self) -> "Crew":
"""Set private attributes."""
"""set private attributes."""
self._cache_handler = CacheHandler()
event_listener = EventListener()
if is_tracing_enabled() or self.tracing:
if (
is_tracing_enabled()
or self.tracing
or should_auto_collect_first_time_traces()
):
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
event_listener.verbose = self.verbose
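In practice the trace listener is registered when any of the three paths checked above is enabled; a hedged sketch of the explicit opt-in (the agent and task objects are placeholders, and the top-level exports are assumed):

from crewai import Agent, Crew, Task  # assuming the usual top-level exports

crew = Crew(
    agents=[researcher],      # a previously defined Agent
    tasks=[research_task],    # a previously defined Task
    tracing=True,             # registers TraceCollectionListener on the event bus
)
crew.kickoff()
# Alternatively, export CREWAI_TRACING_ENABLED=true before running.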
@@ -314,7 +334,8 @@ class Crew(FlowTrackable, BaseModel):
def create_crew_memory(self) -> "Crew":
"""Initialize private memory attributes."""
self._external_memory = (
# External memory doesnt support a default value since it was designed to be managed entirely externally
# External memory does not support a default value since it was
# designed to be managed entirely externally
self.external_memory.set_crew(self) if self.external_memory else None
)
@@ -355,7 +376,10 @@ class Crew(FlowTrackable, BaseModel):
if not self.manager_llm and not self.manager_agent:
raise PydanticCustomError(
"missing_manager_llm_or_manager_agent",
"Attribute `manager_llm` or `manager_agent` is required when using hierarchical process.",
(
"Attribute `manager_llm` or `manager_agent` is required "
"when using hierarchical process."
),
{},
)
@@ -398,7 +422,10 @@ class Crew(FlowTrackable, BaseModel):
if task.agent is None:
raise PydanticCustomError(
"missing_agent_in_task",
f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
(
f"Sequential process error: Agent is missing in the task "
f"with the following description: {task.description}"
), # type: ignore # Dynamic string in error message
{},
)
@@ -459,7 +486,10 @@ class Crew(FlowTrackable, BaseModel):
if task.async_execution and isinstance(task, ConditionalTask):
raise PydanticCustomError(
"invalid_async_conditional_task",
f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
(
f"Conditional Task: {task.description}, "
f"cannot be executed asynchronously."
),
{},
)
return self
@@ -478,7 +508,9 @@ class Crew(FlowTrackable, BaseModel):
for j in range(i - 1, -1, -1):
if self.tasks[j] == context_task:
raise ValueError(
f"Task '{task.description}' is asynchronous and cannot include other sequential asynchronous tasks in its context."
f"Task '{task.description}' is asynchronous and "
f"cannot include other sequential asynchronous "
f"tasks in its context."
)
if not self.tasks[j].async_execution:
break
@@ -496,13 +528,15 @@ class Crew(FlowTrackable, BaseModel):
continue # Skip context tasks not in the main tasks list
if task_indices[id(context_task)] > task_indices[id(task)]:
raise ValueError(
f"Task '{task.description}' has a context dependency on a future task '{context_task.description}', which is not allowed."
f"Task '{task.description}' has a context dependency "
f"on a future task '{context_task.description}', "
f"which is not allowed."
)
return self
@property
def key(self) -> str:
source: List[str] = [agent.key for agent in self.agents] + [
source: list[str] = [agent.key for agent in self.agents] + [
task.key for task in self.tasks
]
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
@@ -518,9 +552,9 @@ class Crew(FlowTrackable, BaseModel):
return self.security_config.fingerprint
def _setup_from_config(self):
assert self.config is not None, "Config should not be None."
"""Initializes agents and tasks from the provided config."""
if self.config is None:
raise ValueError("Config should not be None.")
if not self.config.get("agents") or not self.config.get("tasks"):
raise PydanticCustomError(
"missing_keys_in_config", "Config should have 'agents' and 'tasks'.", {}
@@ -530,7 +564,7 @@ class Crew(FlowTrackable, BaseModel):
self.agents = [Agent(**agent) for agent in self.config["agents"]]
self.tasks = [self._create_task(task) for task in self.config["tasks"]]
def _create_task(self, task_config: Dict[str, Any]) -> Task:
def _create_task(self, task_config: dict[str, Any]) -> Task:
"""Creates a task instance from its configuration.
Args:
@@ -559,7 +593,7 @@ class Crew(FlowTrackable, BaseModel):
CrewTrainingHandler(filename).initialize_file()
def train(
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = None
self, n_iterations: int, filename: str, inputs: dict[str, Any] | None = None
) -> None:
"""Trains the crew for a given number of iterations."""
inputs = inputs or {}
@@ -611,7 +645,7 @@ class Crew(FlowTrackable, BaseModel):
def kickoff(
self,
inputs: Optional[Dict[str, Any]] = None,
inputs: dict[str, Any] | None = None,
) -> CrewOutput:
ctx = baggage.set_baggage(
"crew_context", CrewContext(id=str(self.id), key=self.key)
@@ -682,9 +716,9 @@ class Crew(FlowTrackable, BaseModel):
finally:
detach(token)
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
"""Executes the Crew's workflow for each input in the list and aggregates results."""
results: List[CrewOutput] = []
def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
"""Executes the Crew's workflow for each input and aggregates results."""
results: list[CrewOutput] = []
# Initialize the parent crew's usage metrics
total_usage_metrics = UsageMetrics()
@@ -703,14 +737,12 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset()
return results
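For reference, a short usage sketch of the per-input kickoff (the input keys are illustrative and should match the crew's placeholders; CrewOutput.raw is assumed to carry the final text):

results = crew.kickoff_for_each(inputs=[{"topic": "LLMs"}, {"topic": "RAG"}])
for output in results:   # one CrewOutput per input dict
    print(output.raw)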
async def kickoff_async(
self, inputs: Optional[Dict[str, Any]] = None
) -> CrewOutput:
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> CrewOutput:
"""Asynchronous kickoff method to start the crew execution."""
inputs = inputs or {}
return await asyncio.to_thread(self.kickoff, inputs)
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]:
async def kickoff_for_each_async(self, inputs: list[dict]) -> list[CrewOutput]:
crew_copies = [self.copy() for _ in inputs]
async def run_crew(crew, input_data):
@@ -739,7 +771,9 @@ class Crew(FlowTrackable, BaseModel):
tasks=self.tasks, planning_agent_llm=self.planning_llm
)._handle_crew_planning()
for task, step_plan in zip(self.tasks, result.list_of_plans_per_task):
for task, step_plan in zip(
self.tasks, result.list_of_plans_per_task, strict=False
):
task.description += step_plan.plan
def _store_execution_log(
@@ -776,7 +810,7 @@ class Crew(FlowTrackable, BaseModel):
return self._execute_tasks(self.tasks)
def _run_hierarchical_process(self) -> CrewOutput:
"""Creates and assigns a manager agent to make sure the crew completes the tasks."""
"""Creates and assigns a manager agent to complete the tasks."""
self._create_manager_agent()
return self._execute_tasks(self.tasks)
@@ -807,23 +841,24 @@ class Crew(FlowTrackable, BaseModel):
def _execute_tasks(
self,
tasks: List[Task],
start_index: Optional[int] = 0,
tasks: list[Task],
start_index: int | None = 0,
was_replayed: bool = False,
) -> CrewOutput:
"""Executes tasks sequentially and returns the final output.
Args:
tasks (List[Task]): List of tasks to execute
manager (Optional[BaseAgent], optional): Manager agent to use for delegation. Defaults to None.
manager (Optional[BaseAgent], optional): Manager agent to use for
delegation. Defaults to None.
Returns:
CrewOutput: Final output of the crew
"""
task_outputs: List[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
last_sync_output: Optional[TaskOutput] = None
task_outputs: list[TaskOutput] = []
futures: list[tuple[Task, Future[TaskOutput], int]] = []
last_sync_output: TaskOutput | None = None
for task_index, task in enumerate(tasks):
if start_index is not None and task_index < start_index:
@@ -838,7 +873,9 @@ class Crew(FlowTrackable, BaseModel):
agent_to_use = self._get_agent_to_use(task)
if agent_to_use is None:
raise ValueError(
f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided."
f"No agent available for task: {task.description}. "
f"Ensure that either the task has an assigned agent "
f"or a manager agent is provided."
)
# Determine which tools to use - task tools take precedence over agent tools
@@ -847,7 +884,7 @@ class Crew(FlowTrackable, BaseModel):
tools_for_task = self._prepare_tools(
agent_to_use,
task,
cast(Union[List[Tool], List[BaseTool]], tools_for_task),
cast(list[Tool] | list[BaseTool], tools_for_task),
)
self._log_task_start(task, agent_to_use.role)
@@ -867,7 +904,7 @@ class Crew(FlowTrackable, BaseModel):
future = task.execute_async(
agent=agent_to_use,
context=context,
tools=cast(List[BaseTool], tools_for_task),
tools=cast(list[BaseTool], tools_for_task),
)
futures.append((task, future, task_index))
else:
@@ -879,7 +916,7 @@ class Crew(FlowTrackable, BaseModel):
task_output = task.execute_sync(
agent=agent_to_use,
context=context,
tools=cast(List[BaseTool], tools_for_task),
tools=cast(list[BaseTool], tools_for_task),
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
@@ -893,11 +930,11 @@ class Crew(FlowTrackable, BaseModel):
def _handle_conditional_task(
self,
task: ConditionalTask,
task_outputs: List[TaskOutput],
futures: List[Tuple[Task, Future[TaskOutput], int]],
task_outputs: list[TaskOutput],
futures: list[tuple[Task, Future[TaskOutput], int]],
task_index: int,
was_replayed: bool,
) -> Optional[TaskOutput]:
) -> TaskOutput | None:
if futures:
task_outputs = self._process_async_tasks(futures, was_replayed)
futures.clear()
@@ -917,8 +954,8 @@ class Crew(FlowTrackable, BaseModel):
return None
def _prepare_tools(
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
# Add delegation tools if agent allows delegation
if hasattr(agent, "allow_delegation") and getattr(
agent, "allow_delegation", False
@@ -947,22 +984,22 @@ class Crew(FlowTrackable, BaseModel):
):
tools = self._add_multimodal_tools(agent, tools)
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
return cast(List[BaseTool], tools)
# Return a List[BaseTool] compatible with Task.execute_sync and execute_async
return cast(list[BaseTool], tools)
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
def _get_agent_to_use(self, task: Task) -> BaseAgent | None:
if self.process == Process.hierarchical:
return self.manager_agent
return task.agent
def _merge_tools(
self,
existing_tools: Union[List[Tool], List[BaseTool]],
new_tools: Union[List[Tool], List[BaseTool]],
) -> List[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
existing_tools: list[Tool] | list[BaseTool],
new_tools: list[Tool] | list[BaseTool],
) -> list[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates."""
if not new_tools:
return cast(List[BaseTool], existing_tools)
return cast(list[BaseTool], existing_tools)
# Create mapping of tool names to new tools
new_tool_map = {tool.name: tool for tool in new_tools}
@@ -973,41 +1010,41 @@ class Crew(FlowTrackable, BaseModel):
# Add all new tools
tools.extend(new_tools)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _inject_delegation_tools(
self,
tools: Union[List[Tool], List[BaseTool]],
tools: list[Tool] | list[BaseTool],
task_agent: BaseAgent,
agents: List[BaseAgent],
) -> List[BaseTool]:
agents: list[BaseAgent],
) -> list[BaseTool]:
if hasattr(task_agent, "get_delegation_tools"):
delegation_tools = task_agent.get_delegation_tools(agents)
# Cast delegation_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], delegation_tools))
return cast(list[BaseTool], tools)
def _add_multimodal_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if hasattr(agent, "get_multimodal_tools"):
multimodal_tools = agent.get_multimodal_tools()
# Cast multimodal_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
return cast(list[BaseTool], tools)
def _add_code_execution_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if hasattr(agent, "get_code_execution_tools"):
code_tools = agent.get_code_execution_tools()
# Cast code_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
return cast(list[BaseTool], tools)
def _add_delegation_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
if not tools:
@@ -1015,7 +1052,7 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools(
tools, task.agent, agents_for_delegation
)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _log_task_start(self, task: Task, role: str = "None"):
if self.output_log_file:
@@ -1024,8 +1061,8 @@ class Crew(FlowTrackable, BaseModel):
)
def _update_manager_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if self.manager_agent:
if task.agent:
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
@@ -1033,18 +1070,17 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools(
tools, self.manager_agent, self.agents
)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
if not task.context:
return ""
context = (
return (
aggregate_raw_outputs_from_task_outputs(task_outputs)
if task.context is NOT_SPECIFIED
else aggregate_raw_outputs_from_tasks(task.context)
)
return context
def _process_task_result(self, task: Task, output: TaskOutput) -> None:
role = task.agent.role if task.agent is not None else "None"
@@ -1057,7 +1093,7 @@ class Crew(FlowTrackable, BaseModel):
output=output.raw,
)
def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput:
def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput:
if not task_outputs:
raise ValueError("No task outputs available to create crew output.")
@@ -1088,10 +1124,10 @@ class Crew(FlowTrackable, BaseModel):
def _process_async_tasks(
self,
futures: List[Tuple[Task, Future[TaskOutput], int]],
futures: list[tuple[Task, Future[TaskOutput], int]],
was_replayed: bool = False,
) -> List[TaskOutput]:
task_outputs: List[TaskOutput] = []
) -> list[TaskOutput]:
task_outputs: list[TaskOutput] = []
for future_task, future, task_index in futures:
task_output = future.result()
task_outputs.append(task_output)
@@ -1101,9 +1137,7 @@ class Crew(FlowTrackable, BaseModel):
)
return task_outputs
def _find_task_index(
self, task_id: str, stored_outputs: List[Any]
) -> Optional[int]:
def _find_task_index(self, task_id: str, stored_outputs: list[Any]) -> int | None:
return next(
(
index
@@ -1113,9 +1147,8 @@ class Crew(FlowTrackable, BaseModel):
None,
)
def replay(
self, task_id: str, inputs: Optional[Dict[str, Any]] = None
) -> CrewOutput:
def replay(self, task_id: str, inputs: dict[str, Any] | None = None) -> CrewOutput:
"""Replay the crew execution from a specific task."""
stored_outputs = self._task_output_handler.load()
if not stored_outputs:
raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")
@@ -1151,19 +1184,19 @@ class Crew(FlowTrackable, BaseModel):
self.tasks[i].output = task_output
self._logging_color = "bold_blue"
result = self._execute_tasks(self.tasks, start_index, True)
return result
return self._execute_tasks(self.tasks, start_index, True)
def query_knowledge(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> Union[List[Dict[str, Any]], None]:
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[SearchResult] | None:
"""Query the crew's knowledge base for relevant information."""
if self.knowledge:
return self.knowledge.query(
query, results_limit=results_limit, score_threshold=score_threshold
)
return None
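A hedged sketch of the updated knowledge query (the query string is illustrative; the defaults are the ones in the signature above):

hits = crew.query_knowledge(["vacation policy"], results_limit=3, score_threshold=0.35)
if hits is not None:   # None when the crew has no knowledge configured
    for hit in hits:
        print(hit)     # each entry is a SearchResult (crewai.rag.types)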
def fetch_inputs(self) -> Set[str]:
def fetch_inputs(self) -> set[str]:
"""
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
Scans each task's 'description' + 'expected_output', and each agent's
@@ -1172,11 +1205,11 @@ class Crew(FlowTrackable, BaseModel):
Returns a set of all discovered placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)\}")
required_inputs: Set[str] = set()
required_inputs: set[str] = set()
# Scan tasks for inputs
for task in self.tasks:
# description and expected_output might contain e.g. {topic}, {user_name}, etc.
# description and expected_output might contain e.g. {topic}, {user_name}
text = f"{task.description or ''} {task.expected_output or ''}"
required_inputs.update(placeholder_pattern.findall(text))
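For example (hypothetical task and agent), placeholders in task descriptions and expected outputs surface through fetch_inputs:

task = Task(
    description="Research {topic} for {audience}",
    expected_output="A short brief on {topic}",
    agent=researcher,            # hypothetical Agent with no placeholders of its own
)
crew = Crew(agents=[researcher], tasks=[task])
print(crew.fetch_inputs())       # -> {'topic', 'audience'}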
@@ -1230,7 +1263,7 @@ class Crew(FlowTrackable, BaseModel):
cloned_tasks.append(cloned_task)
task_mapping[task.key] = cloned_task
for cloned_task, original_task in zip(cloned_tasks, self.tasks):
for cloned_task, original_task in zip(cloned_tasks, self.tasks, strict=False):
if isinstance(original_task.context, list):
cloned_context = [
task_mapping[context_task.key]
@@ -1256,7 +1289,7 @@ class Crew(FlowTrackable, BaseModel):
copied_data.pop("agents", None)
copied_data.pop("tasks", None)
copied_crew = Crew(
return Crew(
**copied_data,
agents=cloned_agents,
tasks=cloned_tasks,
@@ -1266,15 +1299,13 @@ class Crew(FlowTrackable, BaseModel):
manager_llm=manager_llm,
)
return copied_crew
def _set_tasks_callbacks(self) -> None:
"""Sets callback for every task suing task_callback"""
for task in self.tasks:
if not task.callback:
task.callback = self.task_callback
def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
def _interpolate_inputs(self, inputs: dict[str, Any]) -> None:
"""Interpolates the inputs in the tasks and agents."""
[
task.interpolate_inputs_and_add_conversation_history(
@@ -1307,10 +1338,13 @@ class Crew(FlowTrackable, BaseModel):
def test(
self,
n_iterations: int,
eval_llm: Union[str, InstanceOf[BaseLLM]],
inputs: Optional[Dict[str, Any]] = None,
eval_llm: str | InstanceOf[BaseLLM],
inputs: dict[str, Any] | None = None,
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
"""Test and evaluate the Crew with the given inputs for n iterations.
Uses concurrent.futures for concurrent execution.
"""
try:
# Create LLM instance and ensure it's of type LLM for CrewEvaluator
llm_instance = create_llm(eval_llm)
@@ -1350,7 +1384,11 @@ class Crew(FlowTrackable, BaseModel):
raise
def __repr__(self):
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
return (
f"Crew(id={self.id}, process={self.process}, "
f"number_of_agents={len(self.agents)}, "
f"number_of_tasks={len(self.tasks)})"
)
def reset_memories(self, command_type: str) -> None:
"""Reset specific or all memories for the crew.
@@ -1364,7 +1402,7 @@ class Crew(FlowTrackable, BaseModel):
ValueError: If an invalid command type is provided.
RuntimeError: If memory reset operation fails.
"""
VALID_TYPES = frozenset(
valid_types = frozenset(
[
"long",
"short",
@@ -1377,9 +1415,10 @@ class Crew(FlowTrackable, BaseModel):
]
)
if command_type not in VALID_TYPES:
if command_type not in valid_types:
raise ValueError(
f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}"
f"Invalid command type. Must be one of: "
f"{', '.join(sorted(valid_types))}"
)
try:
@@ -1389,7 +1428,7 @@ class Crew(FlowTrackable, BaseModel):
self._reset_specific_memory(command_type)
except Exception as e:
error_msg = f"Failed to reset {command_type} memory: {str(e)}"
error_msg = f"Failed to reset {command_type} memory: {e!s}"
self._logger.log("error", error_msg)
raise RuntimeError(error_msg) from e
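Typical usage of the reset API, as a sketch: values must be members of the valid_types set above, and an "all" option appears to route to _reset_all_memories below.

crew.reset_memories("short")   # clears short-term memory only
crew.reset_memories("all")     # appears to reset every configured memory system
# Invalid values raise ValueError; failures during reset surface as RuntimeError.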
@@ -1397,7 +1436,7 @@ class Crew(FlowTrackable, BaseModel):
"""Reset all available memory systems."""
memory_systems = self._get_memory_systems()
for memory_type, config in memory_systems.items():
for config in memory_systems.values():
if (system := config.get("system")) is not None:
name = config.get("name")
try:
@@ -1405,11 +1444,13 @@ class Crew(FlowTrackable, BaseModel):
reset_fn(system)
self._logger.log(
"info",
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
f"[Crew ({self.name if self.name else self.id})] "
f"{name} memory has been reset",
)
except Exception as e:
raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
f"[Crew ({self.name if self.name else self.id})] "
f"Failed to reset {name} memory: {e!s}"
) from e
def _reset_specific_memory(self, memory_type: str) -> None:
@@ -1434,18 +1475,21 @@ class Crew(FlowTrackable, BaseModel):
reset_fn(system)
self._logger.log(
"info",
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
f"[Crew ({self.name if self.name else self.id})] "
f"{name} memory has been reset",
)
except Exception as e:
raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
f"[Crew ({self.name if self.name else self.id})] "
f"Failed to reset {name} memory: {e!s}"
) from e
def _get_memory_systems(self):
"""Get all available memory systems with their configuration.
Returns:
Dict containing all memory systems with their reset functions and display names.
Dict containing all memory systems with their reset functions and
display names.
"""
def default_reset(memory):
@@ -1506,7 +1550,7 @@ class Crew(FlowTrackable, BaseModel):
},
}
def reset_knowledge(self, knowledges: List[Knowledge]) -> None:
def reset_knowledge(self, knowledges: list[Knowledge]) -> None:
"""Reset crew and agent knowledge storage."""
for ks in knowledges:
ks.reset()


@@ -9,48 +9,158 @@ This module provides the event infrastructure that allows users to:
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentEvaluationCompletedEvent,
AgentEvaluationFailedEvent,
AgentEvaluationStartedEvent,
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestResultEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowEvent,
FlowFinishedEvent,
FlowPlotEvent,
FlowStartedEvent,
MethodExecutionFailedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMStreamChunkEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.events.types.logging_events import (
AgentLogsExecutionEvent,
AgentLogsStartedEvent,
)
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeRetrievalStartedEvent,
KnowledgeRetrievalCompletedEvent,
from crewai.events.types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
ReasoningEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffStartedEvent,
CrewKickoffCompletedEvent,
from crewai.events.types.task_events import (
TaskCompletedEvent,
TaskEvaluationEvent,
TaskFailedEvent,
TaskStartedEvent,
)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
)
from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
from crewai.events.types.tool_usage_events import (
ToolExecutionErrorEvent,
ToolSelectionErrorEvent,
ToolUsageErrorEvent,
ToolUsageEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
ToolValidateInputErrorEvent,
)
__all__ = [
"AgentEvaluationCompletedEvent",
"AgentEvaluationFailedEvent",
"AgentEvaluationStartedEvent",
"AgentExecutionCompletedEvent",
"AgentExecutionErrorEvent",
"AgentExecutionStartedEvent",
"AgentLogsExecutionEvent",
"AgentLogsStartedEvent",
"AgentReasoningCompletedEvent",
"AgentReasoningFailedEvent",
"AgentReasoningStartedEvent",
"BaseEventListener",
"crewai_event_bus",
"CrewKickoffCompletedEvent",
"CrewKickoffFailedEvent",
"CrewKickoffStartedEvent",
"CrewTestCompletedEvent",
"CrewTestFailedEvent",
"CrewTestResultEvent",
"CrewTestStartedEvent",
"CrewTrainCompletedEvent",
"CrewTrainFailedEvent",
"CrewTrainStartedEvent",
"FlowCreatedEvent",
"FlowEvent",
"FlowFinishedEvent",
"FlowPlotEvent",
"FlowStartedEvent",
"KnowledgeQueryCompletedEvent",
"KnowledgeQueryFailedEvent",
"KnowledgeQueryStartedEvent",
"KnowledgeRetrievalCompletedEvent",
"KnowledgeRetrievalStartedEvent",
"KnowledgeSearchQueryFailedEvent",
"LLMCallCompletedEvent",
"LLMCallFailedEvent",
"LLMCallStartedEvent",
"LLMGuardrailCompletedEvent",
"LLMGuardrailStartedEvent",
"LLMStreamChunkEvent",
"LiteAgentExecutionCompletedEvent",
"LiteAgentExecutionErrorEvent",
"LiteAgentExecutionStartedEvent",
"MemoryQueryCompletedEvent",
"MemorySaveCompletedEvent",
"MemorySaveStartedEvent",
"MemoryQueryFailedEvent",
"MemoryQueryStartedEvent",
"MemoryRetrievalCompletedEvent",
"MemoryRetrievalStartedEvent",
"MemorySaveCompletedEvent",
"MemorySaveFailedEvent",
"MemoryQueryFailedEvent",
"KnowledgeRetrievalStartedEvent",
"KnowledgeRetrievalCompletedEvent",
"CrewKickoffStartedEvent",
"CrewKickoffCompletedEvent",
"AgentExecutionCompletedEvent",
"LLMStreamChunkEvent",
]
"MemorySaveStartedEvent",
"MethodExecutionFailedEvent",
"MethodExecutionFinishedEvent",
"MethodExecutionStartedEvent",
"ReasoningEvent",
"TaskCompletedEvent",
"TaskEvaluationEvent",
"TaskFailedEvent",
"TaskStartedEvent",
"ToolExecutionErrorEvent",
"ToolSelectionErrorEvent",
"ToolUsageErrorEvent",
"ToolUsageEvent",
"ToolUsageFinishedEvent",
"ToolUsageStartedEvent",
"ToolValidateInputErrorEvent",
"crewai_event_bus",
]
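With these re-exports, a custom listener can be wired the same way the trace listener is registered in crew.py; a hedged sketch, assuming the decorator-style `on(...)` registration and the `(source, event)` handler signature used by crewai's event bus:

from crewai.events import (
    BaseEventListener,
    CrewKickoffCompletedEvent,
    crewai_event_bus,
)

class PrintOnKickoffDone(BaseEventListener):
    def setup_listeners(self, bus):
        @bus.on(CrewKickoffCompletedEvent)
        def _handle(source, event):          # handler signature assumed: (source, event)
            print(f"Crew finished: {event.output}")

PrintOnKickoffDone().setup_listeners(crewai_event_bus)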


@@ -0,0 +1,230 @@
import logging
import os
import uuid
import webbrowser
from pathlib import Path
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager
from crewai.events.listeners.tracing.utils import (
mark_first_execution_completed,
prompt_user_for_trace_viewing,
should_auto_collect_first_time_traces,
)
logger = logging.getLogger(__name__)
def _update_or_create_env_file():
"""Update or create .env file with CREWAI_TRACING_ENABLED=true."""
env_path = Path(".env")
env_content = ""
variable_name = "CREWAI_TRACING_ENABLED"
variable_value = "true"
# Read existing content if file exists
if env_path.exists():
with open(env_path, "r") as f:
env_content = f.read()
# Check if CREWAI_TRACING_ENABLED is already set
lines = env_content.splitlines()
variable_exists = False
updated_lines = []
for line in lines:
if line.strip().startswith(f"{variable_name}="):
# Update existing variable
updated_lines.append(f"{variable_name}={variable_value}")
variable_exists = True
else:
updated_lines.append(line)
# Add variable if it doesn't exist
if not variable_exists:
if updated_lines and not updated_lines[-1].strip():
# If last line is empty, replace it
updated_lines[-1] = f"{variable_name}={variable_value}"
else:
# Add new line and then the variable
updated_lines.append(f"{variable_name}={variable_value}")
# Write updated content
with open(env_path, "w") as f:
f.write("\n".join(updated_lines))
if updated_lines: # Add final newline if there's content
f.write("\n")
class FirstTimeTraceHandler:
"""Handles the first-time user trace collection and display flow."""
def __init__(self):
self.is_first_time: bool = False
self.collected_events: bool = False
self.trace_batch_id: str | None = None
self.ephemeral_url: str | None = None
self.batch_manager: TraceBatchManager | None = None
def initialize_for_first_time_user(self) -> bool:
"""Check if this is first time and initialize collection."""
self.is_first_time = should_auto_collect_first_time_traces()
return self.is_first_time
def set_batch_manager(self, batch_manager: TraceBatchManager):
"""Set reference to batch manager for sending events."""
self.batch_manager = batch_manager
def mark_events_collected(self):
"""Mark that events have been collected during execution."""
self.collected_events = True
def handle_execution_completion(self):
"""Handle the completion flow as shown in your diagram."""
if not self.is_first_time or not self.collected_events:
return
try:
user_wants_traces = prompt_user_for_trace_viewing(timeout_seconds=20)
if user_wants_traces:
self._initialize_backend_and_send_events()
# Enable tracing for future runs by updating .env file
try:
_update_or_create_env_file()
except Exception as e:
pass
if self.ephemeral_url:
self._display_ephemeral_trace_link()
mark_first_execution_completed()
except Exception as e:
self._gracefully_fail(f"Error in trace handling: {e}")
mark_first_execution_completed()
def _initialize_backend_and_send_events(self):
"""Initialize backend batch and send collected events."""
if not self.batch_manager:
return
try:
if not self.batch_manager.backend_initialized:
original_metadata = (
self.batch_manager.current_batch.execution_metadata
if self.batch_manager.current_batch
else {}
)
user_context = {
"privacy_level": "standard",
"user_id": "first_time_user",
"session_id": str(uuid.uuid4()),
"trace_id": self.batch_manager.trace_batch_id,
}
execution_metadata = {
"execution_type": original_metadata.get("execution_type", "crew"),
"crew_name": original_metadata.get(
"crew_name", "First Time Execution"
),
"flow_name": original_metadata.get("flow_name"),
"agent_count": original_metadata.get("agent_count", 1),
"task_count": original_metadata.get("task_count", 1),
"crewai_version": original_metadata.get("crewai_version"),
}
self.batch_manager._initialize_backend_batch(
user_context=user_context,
execution_metadata=execution_metadata,
use_ephemeral=True,
)
self.batch_manager.backend_initialized = True
if self.batch_manager.event_buffer:
self.batch_manager._send_events_to_backend()
self.batch_manager.finalize_batch()
self.ephemeral_url = self.batch_manager.ephemeral_trace_url
if not self.ephemeral_url:
self._show_local_trace_message()
except Exception as e:
self._gracefully_fail(f"Backend initialization failed: {e}")
def _display_ephemeral_trace_link(self):
"""Display the ephemeral trace link to the user and automatically open browser."""
console = Console()
try:
webbrowser.open(self.ephemeral_url)
except Exception as e:
pass
panel_content = f"""
🎉 Your First CrewAI Execution Trace is Ready!
View your execution details here:
{self.ephemeral_url}
This trace shows:
• Agent decisions and interactions
• Task execution timeline
• Tool usage and results
• LLM calls and responses
✅ Tracing has been enabled for future runs! (CREWAI_TRACING_ENABLED=true added to .env)
You can also pass tracing=True to your Crew or Flow for more control.
📝 Note: This link will expire in 24 hours.
""".strip()
panel = Panel(
panel_content,
title="🔍 Execution Trace Generated",
border_style="bright_green",
padding=(1, 2),
)
console.print("\n")
console.print(panel)
console.print()
def _gracefully_fail(self, error_message: str):
"""Handle errors gracefully without disrupting user experience."""
console = Console()
console.print(f"[yellow]Note: {error_message}[/yellow]")
logger.debug(f"First-time trace error: {error_message}")
def _show_local_trace_message(self):
"""Show message when traces were collected locally but couldn't be uploaded."""
console = Console()
panel_content = f"""
📊 Your execution traces were collected locally!
Unfortunately, we couldn't upload them to the server right now, but here's what we captured:
{len(self.batch_manager.event_buffer)} trace events
• Execution duration: {self.batch_manager.calculate_duration("execution")}ms
• Batch ID: {self.batch_manager.trace_batch_id}
Tracing has been enabled for future runs! (CREWAI_TRACING_ENABLED=true added to .env)
The traces include agent decisions, task execution, and tool usage.
""".strip()
panel = Panel(
panel_content,
title="🔍 Local Traces Collected",
border_style="yellow",
padding=(1, 2),
)
console.print("\n")
console.print(panel)
console.print()


@@ -1,18 +1,18 @@
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime, timezone
from logging import getLogger
from typing import Any
from crewai.utilities.constants import CREWAI_BASE_URL
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
from crewai.cli.plus_api import PlusAPI
from rich.console import Console
from rich.panel import Panel
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.plus_api import PlusAPI
from crewai.cli.version import get_crewai_version
from crewai.events.listeners.tracing.types import TraceEvent
from logging import getLogger
from crewai.events.listeners.tracing.utils import should_auto_collect_first_time_traces
from crewai.utilities.constants import CREWAI_BASE_URL
logger = getLogger(__name__)
@@ -23,11 +23,11 @@ class TraceBatch:
version: str = field(default_factory=get_crewai_version)
batch_id: str = field(default_factory=lambda: str(uuid.uuid4()))
user_context: Dict[str, str] = field(default_factory=dict)
execution_metadata: Dict[str, Any] = field(default_factory=dict)
events: List[TraceEvent] = field(default_factory=list)
user_context: dict[str, str] = field(default_factory=dict)
execution_metadata: dict[str, Any] = field(default_factory=dict)
events: list[TraceEvent] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return {
"version": self.version,
"batch_id": self.batch_id,
@@ -40,26 +40,28 @@ class TraceBatch:
class TraceBatchManager:
"""Single responsibility: Manage batches and event buffering"""
is_current_batch_ephemeral: bool = False
trace_batch_id: Optional[str] = None
current_batch: Optional[TraceBatch] = None
event_buffer: List[TraceEvent] = []
execution_start_times: Dict[str, datetime] = {}
batch_owner_type: Optional[str] = None
batch_owner_id: Optional[str] = None
def __init__(self):
self.is_current_batch_ephemeral: bool = False
self.trace_batch_id: str | None = None
self.current_batch: TraceBatch | None = None
self.event_buffer: list[TraceEvent] = []
self.execution_start_times: dict[str, datetime] = {}
self.batch_owner_type: str | None = None
self.batch_owner_id: str | None = None
self.backend_initialized: bool = False
self.ephemeral_trace_url: str | None = None
try:
self.plus_api = PlusAPI(
api_key=get_auth_token(),
)
except AuthError:
self.plus_api = PlusAPI(api_key="")
self.ephemeral_trace_url = None
def initialize_batch(
self,
user_context: Dict[str, str],
execution_metadata: Dict[str, Any],
user_context: dict[str, str],
execution_metadata: dict[str, Any],
use_ephemeral: bool = False,
) -> TraceBatch:
"""Initialize a new trace batch"""
@@ -70,14 +72,21 @@ class TraceBatchManager:
self.is_current_batch_ephemeral = use_ephemeral
self.record_start_time("execution")
self._initialize_backend_batch(user_context, execution_metadata, use_ephemeral)
if should_auto_collect_first_time_traces():
self.trace_batch_id = self.current_batch.batch_id
else:
self._initialize_backend_batch(
user_context, execution_metadata, use_ephemeral
)
self.backend_initialized = True
return self.current_batch
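A hedged sketch of batch initialization: the context keys mirror the ones the first-time handler builds above, and the values are illustrative only.

import uuid

manager = TraceBatchManager()
batch = manager.initialize_batch(
    user_context={
        "privacy_level": "standard",
        "user_id": "local_user",
        "session_id": str(uuid.uuid4()),
        "trace_id": str(uuid.uuid4()),
    },
    execution_metadata={"execution_type": "crew", "crew_name": "Demo Crew"},
    use_ephemeral=True,   # ephemeral batches later expose ephemeral_trace_url
)
# On first-time auto-collection the backend call is skipped here and
# trace_batch_id is just the local batch_id (see the branch above).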
def _initialize_backend_batch(
self,
user_context: Dict[str, str],
execution_metadata: Dict[str, Any],
user_context: dict[str, str],
execution_metadata: dict[str, Any],
use_ephemeral: bool = False,
):
"""Send batch initialization to backend"""
@@ -129,13 +138,6 @@ class TraceBatchManager:
if not use_ephemeral
else response_data["ephemeral_trace_id"]
)
console = Console()
panel = Panel(
f"✅ Trace batch initialized with session ID: {self.trace_batch_id}",
title="Trace Batch Initialization",
border_style="green",
)
console.print(panel)
else:
logger.warning(
f"Trace batch initialization returned status {response.status_code}. Continuing without tracing."
@@ -143,7 +145,7 @@ class TraceBatchManager:
except Exception as e:
logger.warning(
f"Error initializing trace batch: {str(e)}. Continuing without tracing."
f"Error initializing trace batch: {e}. Continuing without tracing."
)
def add_event(self, trace_event: TraceEvent):
@@ -154,7 +156,6 @@ class TraceBatchManager:
"""Send buffered events to backend with graceful failure handling"""
if not self.plus_api or not self.trace_batch_id or not self.event_buffer:
return 500
try:
payload = {
"events": [event.to_dict() for event in self.event_buffer],
@@ -178,19 +179,19 @@ class TraceBatchManager:
if response.status_code in [200, 201]:
self.event_buffer.clear()
return 200
else:
logger.warning(
f"Failed to send events: {response.status_code}. Events will be lost."
)
return 500
except Exception as e:
logger.warning(
f"Error sending events to backend: {str(e)}. Events will be lost."
f"Failed to send events: {response.status_code}. Events will be lost."
)
return 500
def finalize_batch(self) -> Optional[TraceBatch]:
except Exception as e:
logger.warning(
f"Error sending events to backend: {e}. Events will be lost."
)
return 500
def finalize_batch(self) -> TraceBatch | None:
"""Finalize batch and return it for sending"""
if not self.current_batch:
return None
@@ -246,12 +247,27 @@ class TraceBatchManager:
if not self.is_current_batch_ephemeral and access_code is None
else f"{CREWAI_BASE_URL}/crewai_plus/ephemeral_trace_batches/{self.trace_batch_id}?access_code={access_code}"
)
if self.is_current_batch_ephemeral:
self.ephemeral_trace_url = return_link
# Create a properly formatted message with URL on its own line
message_parts = [
f"✅ Trace batch finalized with session ID: {self.trace_batch_id}",
"",
f"🔗 View here: {return_link}",
]
if access_code:
message_parts.append(f"🔑 Access Code: {access_code}")
panel = Panel(
f"✅ Trace batch finalized with session ID: {self.trace_batch_id}. View here: {return_link} {f', Access Code: {access_code}' if access_code else ''}",
"\n".join(message_parts),
title="Trace Batch Finalization",
border_style="green",
)
console.print(panel)
if not should_auto_collect_first_time_traces():
console.print(panel)
else:
logger.error(
@@ -259,8 +275,8 @@ class TraceBatchManager:
)
except Exception as e:
logger.error(f"❌ Error finalizing trace batch: {str(e)}")
# TODO: send error to app
logger.error(f"❌ Error finalizing trace batch: {e}")
# TODO: send error to app marking as failed
def _cleanup_batch_data(self):
"""Clean up batch data after successful finalization to free memory"""
@@ -277,7 +293,7 @@ class TraceBatchManager:
self.batch_sequence = 0
except Exception as e:
logger.error(f"Warning: Error during cleanup: {str(e)}")
logger.error(f"Warning: Error during cleanup: {e}")
def has_events(self) -> bool:
"""Check if there are events in the buffer"""
@@ -306,7 +322,7 @@ class TraceBatchManager:
return duration_ms
return 0
def get_trace_id(self) -> Optional[str]:
def get_trace_id(self) -> str | None:
"""Get current trace ID"""
if self.current_batch:
return self.current_batch.user_context.get("trace_id")


@@ -1,28 +1,59 @@
import os
import uuid
from typing import Any, ClassVar
from typing import Dict, Any, Optional
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
AgentExecutionErrorEvent,
from crewai.events.listeners.tracing.first_time_trace_handler import (
FirstTimeTraceHandler,
)
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.types.reasoning_events import (
AgentReasoningStartedEvent,
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
from crewai.events.listeners.tracing.utils import safe_serialize_to_dict
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
)
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowFinishedEvent,
FlowPlotEvent,
FlowStartedEvent,
MethodExecutionFailedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.events.types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
)
from crewai.events.types.task_events import (
TaskCompletedEvent,
TaskFailedEvent,
@@ -33,49 +64,16 @@ from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
)
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowStartedEvent,
FlowFinishedEvent,
MethodExecutionStartedEvent,
MethodExecutionFinishedEvent,
MethodExecutionFailedEvent,
FlowPlotEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
)
from crewai.utilities.serialization import to_serializable
from .trace_batch_manager import TraceBatchManager
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
class TraceCollectionListener(BaseEventListener):
"""
Trace collection listener that orchestrates trace event collection and batching
"""
complex_events = [
complex_events: ClassVar[list[str]] = [
"task_started",
"task_completed",
"llm_call_started",
@@ -88,14 +86,14 @@ class TraceCollectionListener(BaseEventListener):
_initialized = False
_listeners_setup = False
def __new__(cls, batch_manager=None):
def __new__(cls, batch_manager: TraceBatchManager | None = None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(
self,
batch_manager: Optional[TraceBatchManager] = None,
batch_manager: TraceBatchManager | None = None,
):
if self._initialized:
return
@@ -103,16 +101,19 @@ class TraceCollectionListener(BaseEventListener):
super().__init__()
self.batch_manager = batch_manager or TraceBatchManager()
self._initialized = True
self.first_time_handler = FirstTimeTraceHandler()
if self.first_time_handler.initialize_for_first_time_user():
self.first_time_handler.set_batch_manager(self.batch_manager)
def _check_authenticated(self) -> bool:
"""Check if tracing should be enabled"""
try:
res = bool(get_auth_token())
return res
return bool(get_auth_token())
except AuthError:
return False
def _get_user_context(self) -> Dict[str, str]:
def _get_user_context(self) -> dict[str, str]:
"""Extract user context for tracing"""
return {
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
@@ -161,8 +162,14 @@ class TraceCollectionListener(BaseEventListener):
@event_bus.on(FlowFinishedEvent)
def on_flow_finished(source, event):
self._handle_trace_event("flow_finished", source, event)
if self.batch_manager.batch_owner_type == "flow":
self.batch_manager.finalize_batch()
if self.first_time_handler.is_first_time:
self.first_time_handler.mark_events_collected()
self.first_time_handler.handle_execution_completion()
else:
# Normal flow finalization
self.batch_manager.finalize_batch()
@event_bus.on(FlowPlotEvent)
def on_flow_plot(source, event):
@@ -181,12 +188,20 @@ class TraceCollectionListener(BaseEventListener):
def on_crew_completed(source, event):
self._handle_trace_event("crew_kickoff_completed", source, event)
if self.batch_manager.batch_owner_type == "crew":
self.batch_manager.finalize_batch()
if self.first_time_handler.is_first_time:
self.first_time_handler.mark_events_collected()
self.first_time_handler.handle_execution_completion()
else:
self.batch_manager.finalize_batch()
@event_bus.on(CrewKickoffFailedEvent)
def on_crew_failed(source, event):
self._handle_trace_event("crew_kickoff_failed", source, event)
self.batch_manager.finalize_batch()
if self.first_time_handler.is_first_time:
self.first_time_handler.mark_events_collected()
self.first_time_handler.handle_execution_completion()
else:
self.batch_manager.finalize_batch()
@event_bus.on(TaskStartedEvent)
def on_task_started(source, event):
@@ -325,17 +340,19 @@ class TraceCollectionListener(BaseEventListener):
self._initialize_batch(user_context, execution_metadata)
def _initialize_batch(
self, user_context: Dict[str, str], execution_metadata: Dict[str, Any]
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
):
"""Initialize trace batch if ephemeral"""
if not self._check_authenticated():
self.batch_manager.initialize_batch(
"""Initialize trace batch - auto-enable ephemeral for first-time users."""
if self.first_time_handler.is_first_time:
return self.batch_manager.initialize_batch(
user_context, execution_metadata, use_ephemeral=True
)
else:
self.batch_manager.initialize_batch(
user_context, execution_metadata, use_ephemeral=False
)
use_ephemeral = not self._check_authenticated()
return self.batch_manager.initialize_batch(
user_context, execution_metadata, use_ephemeral=use_ephemeral
)
def _handle_trace_event(self, event_type: str, source: Any, event: Any):
"""Generic handler for context end events"""
@@ -371,11 +388,11 @@ class TraceCollectionListener(BaseEventListener):
def _build_event_data(
self, event_type: str, event: Any, source: Any
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Build event data"""
if event_type not in self.complex_events:
return self._safe_serialize_to_dict(event)
elif event_type == "task_started":
return safe_serialize_to_dict(event)
if event_type == "task_started":
return {
"task_description": event.task.description,
"expected_output": event.task.expected_output,
@@ -384,7 +401,7 @@ class TraceCollectionListener(BaseEventListener):
"agent_role": source.agent.role,
"task_id": str(event.task.id),
}
elif event_type == "task_completed":
if event_type == "task_completed":
return {
"task_description": event.task.description if event.task else None,
"task_name": event.task.name or event.task.description
@@ -397,63 +414,31 @@ class TraceCollectionListener(BaseEventListener):
else None,
"agent_role": event.output.agent if event.output else None,
}
elif event_type == "agent_execution_started":
if event_type == "agent_execution_started":
return {
"agent_role": event.agent.role,
"agent_goal": event.agent.goal,
"agent_backstory": event.agent.backstory,
}
elif event_type == "agent_execution_completed":
if event_type == "agent_execution_completed":
return {
"agent_role": event.agent.role,
"agent_goal": event.agent.goal,
"agent_backstory": event.agent.backstory,
}
elif event_type == "llm_call_started":
event_data = self._safe_serialize_to_dict(event)
if event_type == "llm_call_started":
event_data = safe_serialize_to_dict(event)
event_data["task_name"] = (
event.task_name or event.task_description
if hasattr(event, "task_name") and event.task_name
else None
)
return event_data
elif event_type == "llm_call_completed":
return self._safe_serialize_to_dict(event)
else:
return {
"event_type": event_type,
"event": self._safe_serialize_to_dict(event),
"source": source,
}
if event_type == "llm_call_completed":
return safe_serialize_to_dict(event)
# TODO: move to utils
def _safe_serialize_to_dict(
self, obj, exclude: set[str] | None = None
) -> Dict[str, Any]:
"""Safely serialize an object to a dictionary for event data."""
try:
serialized = to_serializable(obj, exclude)
if isinstance(serialized, dict):
return serialized
else:
return {"serialized_data": serialized}
except Exception as e:
return {"serialization_error": str(e), "object_type": type(obj).__name__}
# TODO: move to utils
def _truncate_messages(self, messages, max_content_length=500, max_messages=5):
"""Truncate message content and limit number of messages"""
if not messages or not isinstance(messages, list):
return messages
# Limit number of messages
limited_messages = messages[:max_messages]
# Truncate each message content
for msg in limited_messages:
if isinstance(msg, dict) and "content" in msg:
content = msg["content"]
if len(content) > max_content_length:
msg["content"] = content[:max_content_length] + "..."
return limited_messages
return {
"event_type": event_type,
"event": safe_serialize_to_dict(event),
"source": source,
}
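The crew and flow completion handlers above repeat the same branching; a condensed sketch of that pattern, with the handler and manager instances passed in as stand-ins:

def _finalize(batch_manager, first_time_handler) -> None:
    # First-time users go through the ephemeral prompt flow; everyone else finalizes normally.
    if first_time_handler.is_first_time:
        first_time_handler.mark_events_collected()
        first_time_handler.handle_execution_completion()
    else:
        batch_manager.finalize_batch()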

View File

@@ -1,17 +1,25 @@
import getpass
import hashlib
import json
import logging
import os
import platform
import uuid
import hashlib
import subprocess
import getpass
from pathlib import Path
from datetime import datetime
import re
import json
import subprocess
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any
import click
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from crewai.utilities.paths import db_storage_path
from crewai.utilities.serialization import to_serializable
logger = logging.getLogger(__name__)
def is_tracing_enabled() -> bool:
@@ -43,13 +51,11 @@ def _get_machine_id() -> str:
try:
mac = ":".join(
["{:02x}".format((uuid.getnode() >> b) & 0xFF) for b in range(0, 12, 2)][
::-1
]
[f"{(uuid.getnode() >> b) & 0xFF:02x}" for b in range(0, 12, 2)][::-1]
)
parts.append(mac)
except Exception:
pass
logger.warning("Error getting machine id for fingerprinting")
sysname = platform.system()
parts.append(sysname)
@@ -57,7 +63,7 @@ def _get_machine_id() -> str:
try:
if sysname == "Darwin":
res = subprocess.run(
["system_profiler", "SPHardwareDataType"],
["/usr/sbin/system_profiler", "SPHardwareDataType"],
capture_output=True,
text=True,
timeout=2,
@@ -72,7 +78,7 @@ def _get_machine_id() -> str:
parts.append(Path("/sys/class/dmi/id/product_uuid").read_text().strip())
elif sysname == "Windows":
res = subprocess.run(
["wmic", "csproduct", "get", "UUID"],
["C:\\Windows\\System32\\wbem\\wmic.exe", "csproduct", "get", "UUID"],
capture_output=True,
text=True,
timeout=2,
@@ -81,7 +87,7 @@ def _get_machine_id() -> str:
if len(lines) >= 2:
parts.append(lines[1])
except Exception:
pass
logger.exception("Error getting machine ID")
return hashlib.sha256("".join(parts).encode()).hexdigest()
@@ -97,8 +103,8 @@ def _load_user_data() -> dict:
if p.exists():
try:
return json.loads(p.read_text())
except Exception:
pass
except (json.JSONDecodeError, OSError, PermissionError) as e:
logger.warning(f"Failed to load user data: {e}")
return {}
@@ -106,8 +112,8 @@ def _save_user_data(data: dict) -> None:
try:
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
except Exception:
pass
except (OSError, PermissionError) as e:
logger.warning(f"Failed to save user data: {e}")
def get_user_id() -> str:
@@ -151,3 +157,103 @@ def mark_first_execution_done() -> None:
}
)
_save_user_data(data)
def safe_serialize_to_dict(obj, exclude: set[str] | None = None) -> dict[str, Any]:
"""Safely serialize an object to a dictionary for event data."""
try:
serialized = to_serializable(obj, exclude)
if isinstance(serialized, dict):
return serialized
return {"serialized_data": serialized}
except Exception as e:
return {"serialization_error": str(e), "object_type": type(obj).__name__}
def truncate_messages(messages, max_content_length=500, max_messages=5):
"""Truncate message content and limit number of messages"""
if not messages or not isinstance(messages, list):
return messages
limited_messages = messages[:max_messages]
for msg in limited_messages:
if isinstance(msg, dict) and "content" in msg:
content = msg["content"]
if len(content) > max_content_length:
msg["content"] = content[:max_content_length] + "..."
return limited_messages
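A quick usage sketch for truncate_messages, importing it from the same module the listener uses above; the message payloads are made up:

from crewai.events.listeners.tracing.utils import truncate_messages

messages = [{"role": "user", "content": "x" * 1200} for _ in range(8)]
trimmed = truncate_messages(messages, max_content_length=500, max_messages=5)

assert len(trimmed) == 5
assert trimmed[0]["content"].endswith("...") and len(trimmed[0]["content"]) == 503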
def should_auto_collect_first_time_traces() -> bool:
"""True if we should auto-collect traces for first-time user."""
if _is_test_environment():
return False
return is_first_execution()
def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
"""
Prompt the user to choose whether to view their traces, with a timeout.
Returns True if the user wants to see traces, False otherwise.
"""
if _is_test_environment():
return False
try:
import threading
console = Console()
content = Text()
content.append("🔍 ", style="cyan bold")
content.append(
"Detailed execution traces are available!\n\n", style="cyan bold"
)
content.append("View insights including:\n", style="white")
content.append(" • Agent decision-making process\n", style="bright_blue")
content.append(" • Task execution flow and timing\n", style="bright_blue")
content.append(" • Tool usage details", style="bright_blue")
panel = Panel(
content,
title="[bold cyan]Execution Traces[/bold cyan]",
border_style="cyan",
padding=(1, 2),
)
console.print("\n")
console.print(panel)
prompt_text = click.style(
f"Would you like to view your execution traces? [y/N] ({timeout_seconds}s timeout): ",
fg="white",
bold=True,
)
click.echo(prompt_text, nl=False)
result = [False]
def get_input():
try:
response = input().strip().lower()
result[0] = response in ["y", "yes"]
except (EOFError, KeyboardInterrupt):
result[0] = False
input_thread = threading.Thread(target=get_input, daemon=True)
input_thread.start()
input_thread.join(timeout=timeout_seconds)
if input_thread.is_alive():
return False
return result[0]
except Exception:
return False
def mark_first_execution_completed() -> None:
"""Mark first execution as completed (called after trace prompt)."""
mark_first_execution_done()
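A small sketch of how a caller might wire these helpers together; the surrounding orchestration is assumed rather than taken from this diff:

from crewai.events.listeners.tracing.utils import (
    mark_first_execution_completed,
    prompt_user_for_trace_viewing,
    should_auto_collect_first_time_traces,
)

if should_auto_collect_first_time_traces():
    # Ask once, with a 20-second timeout; timeouts and Ctrl+C default to "no".
    if prompt_user_for_trace_viewing(timeout_seconds=20):
        print("User opted in; the ephemeral trace URL would be displayed here.")
    mark_first_execution_completed()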

View File

@@ -2,30 +2,22 @@ import asyncio
import copy
import inspect
import logging
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
Set,
Type,
TypeVar,
Union,
cast,
)
from collections.abc import Callable
from typing import Any, ClassVar, Generic, TypeVar, cast
from uuid import uuid4
from opentelemetry import baggage
from opentelemetry.context import attach, detach
from pydantic import BaseModel, Field, ValidationError
from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData
from crewai.flow.utils import get_possible_return_constants
from crewai.events.event_bus import crewai_event_bus
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
should_auto_collect_first_time_traces,
)
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowFinishedEvent,
@@ -35,12 +27,10 @@ from crewai.events.types.flow_events import (
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
)
from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData
from crewai.flow.utils import get_possible_return_constants
from crewai.utilities.printer import Printer
logger = logging.getLogger(__name__)
@@ -55,16 +45,14 @@ class FlowState(BaseModel):
)
# Type variables with explicit bounds
T = TypeVar(
"T", bound=Union[Dict[str, Any], BaseModel]
) # Generic flow state type parameter
# Type variables with explicit bounds
T = TypeVar("T", bound=dict[str, Any] | BaseModel) # Generic flow state type parameter
StateT = TypeVar(
"StateT", bound=Union[Dict[str, Any], BaseModel]
"StateT", bound=dict[str, Any] | BaseModel
) # State validation type parameter
def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
"""Ensure state matches expected type with proper validation.
Args:
@@ -104,7 +92,7 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
raise TypeError(f"Invalid expected_type: {expected_type}")
def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
def start(condition: str | dict | Callable | None = None) -> Callable:
"""
Marks a method as a flow's starting point.
@@ -171,7 +159,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
return decorator
def listen(condition: Union[str, dict, Callable]) -> Callable:
def listen(condition: str | dict | Callable) -> Callable:
"""
Creates a listener that executes when specified conditions are met.
@@ -231,7 +219,7 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
return decorator
def router(condition: Union[str, dict, Callable]) -> Callable:
def router(condition: str | dict | Callable) -> Callable:
"""
Creates a routing method that directs flow execution based on conditions.
@@ -297,7 +285,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
return decorator
def or_(*conditions: Union[str, dict, Callable]) -> dict:
def or_(*conditions: str | dict | Callable) -> dict:
"""
Combines multiple conditions with OR logic for flow control.
@@ -343,7 +331,7 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
return {"type": "OR", "methods": methods}
def and_(*conditions: Union[str, dict, Callable]) -> dict:
def and_(*conditions: str | dict | Callable) -> dict:
"""
Combines multiple conditions with AND logic for flow control.
@@ -425,10 +413,10 @@ class FlowMeta(type):
if possible_returns:
router_paths[attr_name] = possible_returns
setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners)
setattr(cls, "_routers", routers)
setattr(cls, "_router_paths", router_paths)
cls._start_methods = start_methods
cls._listeners = listeners
cls._routers = routers
cls._router_paths = router_paths
return cls
@@ -436,29 +424,29 @@ class FlowMeta(type):
class Flow(Generic[T], metaclass=FlowMeta):
"""Base class for all flows.
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
Type parameter T must be either dict[str, Any] or a subclass of BaseModel."""
_printer = Printer()
_start_methods: List[str] = []
_listeners: Dict[str, tuple[str, List[str]]] = {}
_routers: Set[str] = set()
_router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None
name: Optional[str] = None
tracing: Optional[bool] = False
_start_methods: ClassVar[list[str]] = []
_listeners: ClassVar[dict[str, tuple[str, list[str]]]] = {}
_routers: ClassVar[set[str]] = set()
_router_paths: ClassVar[dict[str, list[str]]] = {}
initial_state: type[T] | T | None = None
name: str | None = None
tracing: bool | None = False
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
def __class_getitem__(cls: type["Flow"], item: type[T]) -> type["Flow"]:
class _FlowGeneric(cls): # type: ignore
_initial_state_T = item # type: ignore
_initial_state_t = item # type: ignore
_FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]"
return _FlowGeneric
def __init__(
self,
persistence: Optional[FlowPersistence] = None,
tracing: Optional[bool] = False,
persistence: FlowPersistence | None = None,
tracing: bool | None = False,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.
@@ -468,18 +456,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
**kwargs: Additional state values to initialize or override
"""
# Initialize basic instance attributes
self._methods: Dict[str, Callable] = {}
self._method_execution_counts: Dict[str, int] = {}
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs
self._completed_methods: Set[str] = set() # Track completed methods for reload
self._persistence: Optional[FlowPersistence] = persistence
self._methods: dict[str, Callable] = {}
self._method_execution_counts: dict[str, int] = {}
self._pending_and_listeners: dict[str, set[str]] = {}
self._method_outputs: list[Any] = [] # list to store all method outputs
self._completed_methods: set[str] = set() # Track completed methods for reload
self._persistence: FlowPersistence | None = persistence
self._is_execution_resuming: bool = False
# Initialize state with initial values
self._state = self._create_initial_state()
self.tracing = tracing
if is_tracing_enabled() or self.tracing:
if (
is_tracing_enabled()
or self.tracing
or should_auto_collect_first_time_traces()
):
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
# Apply any additional kwargs
@@ -521,25 +513,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
TypeError: If state is neither BaseModel nor dictionary
"""
# Handle case where initial_state is None but we have a type parameter
if self.initial_state is None and hasattr(self, "_initial_state_T"):
state_type = getattr(self, "_initial_state_T")
if self.initial_state is None and hasattr(self, "_initial_state_t"):
state_type = self._initial_state_t
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
# Create instance without id, then set it
instance = state_type()
if not hasattr(instance, "id"):
setattr(instance, "id", str(uuid4()))
instance.id = str(uuid4())
return cast(T, instance)
elif issubclass(state_type, BaseModel):
if issubclass(state_type, BaseModel):
# Create a new type that includes the ID field
class StateWithId(state_type, FlowState): # type: ignore
pass
instance = StateWithId()
if not hasattr(instance, "id"):
setattr(instance, "id", str(uuid4()))
instance.id = str(uuid4())
return cast(T, instance)
elif state_type is dict:
if state_type is dict:
return cast(T, {"id": str(uuid4())})
# Handle case where no initial state is provided
@@ -550,13 +542,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
return cast(T, self.initial_state()) # Uses model defaults
elif issubclass(self.initial_state, BaseModel):
if issubclass(self.initial_state, BaseModel):
# Validate that the model has an id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state()) # Uses model defaults
elif self.initial_state is dict:
if self.initial_state is dict:
return cast(T, {"id": str(uuid4())})
# Handle dictionary instance case
@@ -600,7 +592,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
return self._state
@property
def method_outputs(self) -> List[Any]:
def method_outputs(self) -> list[Any]:
"""Returns the list of all outputs from executed methods."""
return self._method_outputs
@@ -631,13 +623,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self._state, dict):
return str(self._state.get("id", ""))
elif isinstance(self._state, BaseModel):
if isinstance(self._state, BaseModel):
return str(getattr(self._state, "id", ""))
return ""
except (AttributeError, TypeError):
return "" # Safely handle any unexpected attribute access issues
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
def _initialize_state(self, inputs: dict[str, Any]) -> None:
"""Initialize or update flow state with new inputs.
Args:
@@ -691,7 +683,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
else:
raise TypeError("State must be a BaseModel instance or a dictionary.")
def _restore_state(self, stored_state: Dict[str, Any]) -> None:
def _restore_state(self, stored_state: dict[str, Any]) -> None:
"""Restore flow state from persistence.
Args:
@@ -735,7 +727,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
execution_data: Flow execution data containing:
- id: Flow execution ID
- flow: Flow structure
- completed_methods: List of successfully completed methods
- completed_methods: List of successfully completed methods
- execution_methods: All execution methods with their status
"""
flow_id = execution_data.get("id")
@@ -771,7 +763,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if state_to_apply:
self._apply_state_updates(state_to_apply)
for i, method in enumerate(sorted_methods[:-1]):
for method in sorted_methods[:-1]:
method_name = method.get("flow_method", {}).get("name")
if method_name:
self._completed_methods.add(method_name)
@@ -783,7 +775,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
elif hasattr(self._state, field_name):
object.__setattr__(self._state, field_name, value)
def _apply_state_updates(self, updates: Dict[str, Any]) -> None:
def _apply_state_updates(self, updates: dict[str, Any]) -> None:
"""Apply multiple state updates efficiently."""
if isinstance(self._state, dict):
self._state.update(updates)
@@ -792,7 +784,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if hasattr(self._state, key):
object.__setattr__(self._state, key, value)
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
def kickoff(self, inputs: dict[str, Any] | None = None) -> Any:
"""
Start the flow execution in a synchronous context.
@@ -805,7 +797,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
return asyncio.run(run_flow())
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any:
"""
Start the flow execution asynchronously.
@@ -840,7 +832,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self._state, dict):
self._state["id"] = inputs["id"]
elif isinstance(self._state, BaseModel):
setattr(self._state, "id", inputs["id"])
setattr(self._state, "id", inputs["id"]) # noqa: B010
# If persistence is enabled, attempt to restore the stored state using the provided id.
if "id" in inputs and self._persistence is not None:
@@ -1075,7 +1067,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
# Now execute normal listeners for all router results and the original trigger
all_triggers = [trigger_method] + router_results
all_triggers = [trigger_method, *router_results]
for current_trigger in all_triggers:
if current_trigger: # Skip None results
@@ -1109,7 +1101,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> List[str]:
) -> list[str]:
"""
Finds all methods that should be triggered based on conditions.
@@ -1126,7 +1118,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Returns
-------
List[str]
list[str]
Names of methods that should be triggered.
Notes
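A minimal Flow exercising the modernized decorators and the tracing flag from this file; the import path, state model, and method bodies are illustrative assumptions:

from pydantic import BaseModel

from crewai.flow.flow import Flow, listen, start  # assumed public import path


class CounterState(BaseModel):
    id: str = ""   # flow state models must expose an 'id' field
    count: int = 0


class CounterFlow(Flow[CounterState]):
    @start()
    def begin(self):
        self.state.count += 1
        return "begun"

    @listen(begin)
    def finish(self, _previous):
        return self.state.count


result = CounterFlow(tracing=True).kickoff()  # tracing=True forces the trace listener on
print(result)  # expected: 1 (output of the last executed method)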

View File

@@ -1,10 +1,11 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.rag.types import SearchResult
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
@@ -13,23 +14,23 @@ class Knowledge(BaseModel):
"""
Knowledge is a collection of sources and the vector store configuration used to save and query relevant context.
Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
storage: KnowledgeStorage | None = Field(default=None)
embedder: dict[str, Any] | None = None
"""
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None
storage: KnowledgeStorage | None = Field(default=None)
embedder: dict[str, Any] | None = None
collection_name: str | None = None
def __init__(
self,
collection_name: str,
sources: List[BaseKnowledgeSource],
embedder: Optional[Dict[str, Any]] = None,
storage: Optional[KnowledgeStorage] = None,
sources: list[BaseKnowledgeSource],
embedder: dict[str, Any] | None = None,
storage: KnowledgeStorage | None = None,
**data,
):
super().__init__(**data)
@@ -40,11 +41,10 @@ class Knowledge(BaseModel):
embedder=embedder, collection_name=collection_name
)
self.sources = sources
self.storage.initialize_knowledge_storage()
def query(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> List[Dict[str, Any]]:
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
) -> list[SearchResult]:
"""
Query across all knowledge sources to find the most relevant information.
Returns the results_limit most relevant chunks.
@@ -55,12 +55,11 @@ class Knowledge(BaseModel):
if self.storage is None:
raise ValueError("Storage is not initialized.")
results = self.storage.search(
return self.storage.search(
query,
limit=results_limit,
score_threshold=score_threshold,
)
return results
def add_sources(self):
try:
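A hedged usage sketch of the updated query() defaults (limit 5, score threshold 0.6); it assumes an OpenAI API key is configured for the default embedder, and StringKnowledgeSource is used purely for illustration:

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

source = StringKnowledgeSource(content="CrewAI can emit ephemeral trace batches.")
kb = Knowledge(collection_name="docs", sources=[source])
kb.add_sources()  # chunk, embed, and persist the sources

for result in kb.query(["trace batches"]):   # results_limit=5, score_threshold=0.6
    print(result["content"])                 # SearchResult entries carry 'content'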

View File

@@ -9,8 +9,8 @@ class KnowledgeConfig(BaseModel):
score_threshold (float): The minimum score for a document to be considered relevant.
"""
results_limit: int = Field(default=3, description="The number of results to return")
results_limit: int = Field(default=5, description="The number of results to return")
score_threshold: float = Field(
default=0.35,
default=0.6,
description="The minimum score for a result to be considered relevant",
)

View File

@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
from crewai.rag.types import SearchResult
class BaseKnowledgeStorage(ABC):
@@ -8,22 +10,17 @@ class BaseKnowledgeStorage(ABC):
@abstractmethod
def search(
self,
query: List[str],
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
query: list[str],
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[SearchResult]:
"""Search for documents in the knowledge base."""
pass
@abstractmethod
def save(
self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
) -> None:
def save(self, documents: list[str]) -> None:
"""Save documents to the knowledge base."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the knowledge base."""
pass

View File

@@ -1,24 +1,17 @@
import hashlib
import logging
import os
import shutil
from typing import Any, Dict, List, Optional, Union
import chromadb
import chromadb.errors
from chromadb.api import ClientAPI
from chromadb.api.types import OneOrMany
from chromadb.config import Settings
import traceback
import warnings
from typing import Any, cast
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import sanitize_collection_name
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path
from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.logger_utils import suppress_logging
class KnowledgeStorage(BaseKnowledgeStorage):
@@ -27,167 +20,105 @@ class KnowledgeStorage(BaseKnowledgeStorage):
search efficiency.
"""
collection: Optional[chromadb.Collection] = None
collection_name: Optional[str] = "knowledge"
app: Optional[ClientAPI] = None
def __init__(
self,
embedder: Optional[Dict[str, Any]] = None,
collection_name: Optional[str] = None,
):
embedder: dict[str, Any] | None = None,
collection_name: str | None = None,
) -> None:
self.collection_name = collection_name
self._set_embedder_config(embedder)
self._client: BaseClient | None = None
def search(
self,
query: List[str],
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
if self.collection:
fetched = self.collection.query(
query_texts=query,
n_results=limit,
where=filter,
)
results = []
for i in range(len(fetched["ids"][0])): # type: ignore
result = {
"id": fetched["ids"][0][i], # type: ignore
"metadata": fetched["metadatas"][0][i], # type: ignore
"context": fetched["documents"][0][i], # type: ignore
"score": fetched["distances"][0][i], # type: ignore
}
if result["score"] >= score_threshold:
results.append(result)
return results
else:
raise Exception("Collection not initialized")
def initialize_knowledge_storage(self):
# Suppress deprecation warnings from chromadb, which are not relevant to us
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
self.app = create_persistent_client(
path=os.path.join(db_storage_path(), "knowledge"),
settings=Settings(allow_reset=True),
)
if embedder:
embedding_function = get_embedding_function(embedder)
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
self._client = create_client(config)
def _get_client(self) -> BaseClient:
"""Get the appropriate client - instance-specific or global."""
return self._client if self._client else get_rag_client()
def search(
self,
query: list[str],
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[SearchResult]:
try:
if not query:
raise ValueError("Query cannot be empty")
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
if self.app:
self.collection = self.app.get_or_create_collection(
name=sanitize_collection_name(collection_name),
embedding_function=self.embedder,
)
else:
raise Exception("Vector Database Client not initialized")
except Exception:
raise Exception("Failed to create or get collection")
query_text = " ".join(query) if len(query) > 1 else query[0]
def reset(self):
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
if not self.app:
self.app = create_persistent_client(
path=base_path, settings=Settings(allow_reset=True)
return client.search(
collection_name=collection_name,
query=query_text,
limit=limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)
self.app.reset()
shutil.rmtree(base_path)
self.app = None
self.collection = None
def save(
self,
documents: List[str],
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
):
if not self.collection:
raise Exception("Collection not initialized")
try:
# Create a dictionary to store unique documents
unique_docs = {}
# Generate IDs and create a mapping of id -> (document, metadata)
for idx, doc in enumerate(documents):
doc_id = hashlib.sha256(doc.encode("utf-8")).hexdigest()
doc_metadata = None
if metadata is not None:
if isinstance(metadata, list):
doc_metadata = metadata[idx]
else:
doc_metadata = metadata
unique_docs[doc_id] = (doc, doc_metadata)
# Prepare filtered lists for ChromaDB
filtered_docs = []
filtered_metadata = []
filtered_ids = []
# Build the filtered lists
for doc_id, (doc, meta) in unique_docs.items():
filtered_docs.append(doc)
filtered_metadata.append(meta)
filtered_ids.append(doc_id)
# If we have no metadata at all, set it to None
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
None if all(m is None for m in filtered_metadata) else filtered_metadata
)
self.collection.upsert(
documents=filtered_docs,
metadatas=final_metadata,
ids=filtered_ids,
)
except chromadb.errors.InvalidDimensionException as e:
Logger(verbose=True).log(
"error",
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
) from e
except Exception as e:
logging.error(
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
)
return []
def reset(self) -> None:
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
client.delete_collection(collection_name=collection_name)
except Exception as e:
logging.error(
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
)
def save(self, documents: list[str]) -> None:
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
client.get_or_create_collection(collection_name=collection_name)
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
client.add_documents(
collection_name=collection_name, documents=rag_documents
)
except Exception as e:
if "dimension mismatch" in str(e).lower():
Logger(verbose=True).log(
"error",
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
) from e
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
"""Set the embedding configuration for the knowledge storage.
Args:
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
If None or empty, defaults to the default embedding function.
"""
self.embedder = (
EmbeddingConfigurator().configure_embedder(embedder)
if embedder
else self._create_default_embedding_function()
)
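A compact sketch of the client-delegating storage defined above; it assumes the global RAG client (or an OPENAI_API_KEY for the default embedder) is configured, and the document text is illustrative:

from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage

storage = KnowledgeStorage(collection_name="handbook")    # -> collection "knowledge_handbook"
storage.save(["Refunds are processed within 14 days."])   # stored as {"content": ...} records

for hit in storage.search(["refund policy"], limit=5, score_threshold=0.6):
    print(hit["content"])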

View File

@@ -1,12 +1,12 @@
from typing import Any, Dict, List
from crewai.rag.types import SearchResult
def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str:
def extract_knowledge_context(knowledge_snippets: list[SearchResult]) -> str:
"""Extract knowledge from the task prompt."""
valid_snippets = [
result["context"]
result["content"]
for result in knowledge_snippets
if result and result.get("context")
if result and result.get("content")
]
snippet = "\n".join(valid_snippets)
return f"Additional Information: {snippet}" if valid_snippets else ""

View File

@@ -1,4 +1,6 @@
from typing import Optional, TYPE_CHECKING
from __future__ import annotations
from typing import TYPE_CHECKING
from crewai.memory import (
EntityMemory,
@@ -19,9 +21,9 @@ class ContextualMemory:
ltm: LongTermMemory,
em: EntityMemory,
exm: ExternalMemory,
agent: Optional["Agent"] = None,
task: Optional["Task"] = None,
):
agent: Agent | None = None,
task: Task | None = None,
) -> None:
self.stm = stm
self.ltm = ltm
self.em = em
@@ -42,7 +44,7 @@ class ContextualMemory:
self.exm.agent = self.agent
self.exm.task = self.task
def build_context_for_task(self, task, context) -> str:
def build_context_for_task(self, task: Task, context: str) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
@@ -52,14 +54,15 @@ class ContextualMemory:
if query == "":
return ""
context = []
context.append(self._fetch_ltm_context(task.description))
context.append(self._fetch_stm_context(query))
context.append(self._fetch_entity_context(query))
context.append(self._fetch_external_context(query))
return "\n".join(filter(None, context))
context_parts = [
self._fetch_ltm_context(task.description),
self._fetch_stm_context(query),
self._fetch_entity_context(query),
self._fetch_external_context(query),
]
return "\n".join(filter(None, context_parts))
def _fetch_stm_context(self, query) -> str:
def _fetch_stm_context(self, query: str) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points.
@@ -70,11 +73,11 @@ class ContextualMemory:
stm_results = self.stm.search(query)
formatted_results = "\n".join(
[f"- {result['context']}" for result in stm_results]
[f"- {result['content']}" for result in stm_results]
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]:
def _fetch_ltm_context(self, task: str) -> str | None:
"""
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
@@ -90,14 +93,14 @@ class ContextualMemory:
formatted_results = [
suggestion
for result in ltm_results
for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
for suggestion in result["metadata"]["suggestions"]
]
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str:
def _fetch_entity_context(self, query: str) -> str:
"""
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points.
@@ -107,7 +110,7 @@ class ContextualMemory:
em_results = self.em.search(query)
formatted_results = "\n".join(
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
[f"- {result['content']}" for result in em_results]
)
return f"Entities:\n{formatted_results}" if em_results else ""
@@ -128,6 +131,6 @@ class ContextualMemory:
return ""
formatted_memories = "\n".join(
f"- {result['context']}" for result in external_memories
f"- {result['content']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"

View File

@@ -1,20 +1,20 @@
from typing import Any
import time
from typing import Any
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class EntityMemory(Memory):
@@ -31,10 +31,10 @@ class EntityMemory(Memory):
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
except ImportError as e:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
) from e
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
@@ -90,23 +90,31 @@ class EntityMemory(Memory):
saved_count = 0
errors = []
def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
"""Save a single item and return success status."""
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super(EntityMemory, self).save(data, item.metadata)
return True, None
except Exception as e:
return False, f"{item.name}: {e!s}"
try:
for item in items:
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
success, error = save_single_item(item)
if success:
saved_count += 1
except Exception as e:
errors.append(f"{item.name}: {str(e)}")
else:
errors.append(error)
if is_batch:
emit_value = f"Saved {saved_count} entities"
@@ -153,8 +161,8 @@ class EntityMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,
@@ -206,4 +214,6 @@ class EntityMemory(Memory):
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the entity memory: {e}")
raise Exception(
f"An error occurred while resetting the entity memory: {e}"
) from e

View File

@@ -1,41 +1,41 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
import time
from typing import TYPE_CHECKING, Any
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.interface import Storage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
if TYPE_CHECKING:
from crewai.memory.storage.mem0_storage import Mem0Storage
class ExternalMemory(Memory):
def __init__(self, storage: Optional[Storage] = None, **data: Any):
def __init__(self, storage: Storage | None = None, **data: Any):
super().__init__(storage=storage, **data)
@staticmethod
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
@staticmethod
def external_supported_storages() -> Dict[str, Any]:
def external_supported_storages() -> dict[str, Any]:
return {
"mem0": ExternalMemory._configure_mem0,
}
@staticmethod
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
if not embedder_config:
raise ValueError("embedder_config is required")
@@ -52,7 +52,7 @@ class ExternalMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Saves a value into the external storage."""
crewai_event_bus.emit(
@@ -103,8 +103,8 @@ class ExternalMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel
@@ -12,8 +12,8 @@ class Memory(BaseModel):
Base class for memory, now supporting agent tags and generic metadata.
"""
embedder_config: Optional[Dict[str, Any]] = None
crew: Optional[Any] = None
embedder_config: dict[str, Any] | None = None
crew: Any | None = None
storage: Any
_agent: Optional["Agent"] = None
@@ -45,7 +45,7 @@ class Memory(BaseModel):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
metadata = metadata or {}
@@ -54,9 +54,9 @@ class Memory(BaseModel):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
)
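A sketch of the base Memory delegating search to whatever storage it wraps, using the new defaults; InMemoryStorage is a throwaway stand-in, not a class from this changeset:

from crewai.memory.memory import Memory


class InMemoryStorage:
    def __init__(self) -> None:
        self.items: list[dict] = []

    def search(self, query: str, limit: int = 5, score_threshold: float = 0.6) -> list[dict]:
        return [item for item in self.items if query in item["content"]][:limit]


storage = InMemoryStorage()
storage.items.append({"content": "kickoff finished in 4.2s"})

mem = Memory(storage=storage)
print(mem.search("kickoff"))  # delegates with limit=5, score_threshold=0.6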

View File

@@ -1,20 +1,20 @@
from typing import Any, Dict, Optional
import time
from typing import Any
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class ShortTermMemory(Memory):
@@ -26,17 +26,17 @@ class ShortTermMemory(Memory):
MemoryItem instances.
"""
_memory_provider: Optional[str] = PrivateAttr()
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
except ImportError as e:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
) from e
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
@@ -56,7 +56,7 @@ class ShortTermMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
crewai_event_bus.emit(
self,
@@ -112,8 +112,8 @@ class ShortTermMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,
@@ -167,4 +167,4 @@ class ShortTermMemory(Memory):
except Exception as e:
raise Exception(
f"An error occurred while resetting the short-term memory: {e}"
)
) from e

View File

@@ -1,10 +1,13 @@
import os
from typing import Any, Dict, List
import re
from collections import defaultdict
from mem0 import Memory, MemoryClient
from crewai.utilities.chromadb import sanitize_collection_name
from collections.abc import Iterable
from typing import Any
from mem0 import Memory, MemoryClient # type: ignore[import-untyped,import-not-found]
from crewai.memory.storage.interface import Storage
from crewai.rag.chromadb.utils import _sanitize_collection_name
MAX_AGENT_ID_LENGTH_MEM0 = 255
@@ -13,6 +16,7 @@ class Mem0Storage(Storage):
"""
Extends Storage to handle embedding and searching across entities using Mem0.
"""
def __init__(self, type, crew=None, config=None):
super().__init__()
@@ -28,7 +32,8 @@ class Mem0Storage(Storage):
supported_types = {"short_term", "long_term", "entities", "external"}
if type not in supported_types:
raise ValueError(
f"Invalid type '{type}' for Mem0Storage. Must be one of: {', '.join(supported_types)}"
f"Invalid type '{type}' for Mem0Storage. "
f"Must be one of: {', '.join(supported_types)}"
)
def _extract_config_values(self):
@@ -66,7 +71,8 @@ class Mem0Storage(Storage):
- Includes user_id and agent_id if both are present.
- Includes user_id if only user_id is present.
- Includes agent_id if only agent_id is present.
- Includes run_id if memory_type is 'short_term' and mem0_run_id is present.
- Includes run_id if memory_type is 'short_term' and
mem0_run_id is present.
"""
filter = defaultdict(list)
@@ -86,21 +92,44 @@ class Mem0Storage(Storage):
return filter
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str:
return next(
(
m.get("content", "")
for m in reversed(list(messages))
if m.get("role") == role
),
"",
)
conversations = []
messages = metadata.pop("messages", None)
if messages:
last_user = _last_content(messages, "user")
last_assistant = _last_content(messages, "assistant")
if user_msg := self._get_user_message(last_user):
conversations.append({"role": "user", "content": user_msg})
if assistant_msg := self._get_assistant_message(last_assistant):
conversations.append({"role": "assistant", "content": assistant_msg})
else:
conversations.append({"role": "assistant", "content": value})
user_id = self.config.get("user_id", "")
assistant_message = [{"role" : "assistant","content" : value}]
base_metadata = {
"short_term": "short_term",
"long_term": "long_term",
"entities": "entity",
"external": "external"
"external": "external",
}
# Shared base params
params: dict[str, Any] = {
"metadata": {"type": base_metadata[self.memory_type], **metadata},
"infer": self.infer
"infer": self.infer,
}
# MemoryClient-specific overrides
@@ -119,15 +148,17 @@ class Mem0Storage(Storage):
if agent_id := self.config.get("agent_id", self._get_agent_name()):
params["agent_id"] = agent_id
self.memory.add(assistant_message, **params)
self.memory.add(conversations, **params)
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
def search(
self, query: str, limit: int = 5, score_threshold: float = 0.6
) -> list[Any]:
params = {
"query": query,
"limit": limit,
"version": "v2",
"output_format": "v1.1"
}
"output_format": "v1.1",
}
if user_id := self.config.get("user_id", ""):
params["user_id"] = user_id
@@ -148,10 +179,10 @@ class Mem0Storage(Storage):
# automatically when the crew is created.
params["filters"] = self._create_filter_for_search()
params['threshold'] = score_threshold
params["threshold"] = score_threshold
if isinstance(self.memory, Memory):
del params["metadata"], params["version"], params['output_format']
del params["metadata"], params["version"], params["output_format"]
if params.get("run_id"):
del params["run_id"]
@@ -159,8 +190,8 @@ class Mem0Storage(Storage):
# This makes it compatible for Contextual Memory to retrieve
for result in results["results"]:
result["context"] = result["memory"]
result["content"] = result["memory"]
return [r for r in results["results"]]
def reset(self):
@@ -180,4 +211,19 @@ class Mem0Storage(Storage):
agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
return _sanitize_collection_name(
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
)
def _get_assistant_message(self, text: str) -> str:
marker = "Final Answer:"
if marker in text:
return text.split(marker, 1)[1].strip()
return text
def _get_user_message(self, text: str) -> str:
pattern = r"User message:\s*(.*)"
match = re.search(pattern, text)
if match:
return match.group(1).strip()
return text
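A quick check of the two extraction helpers added at the end of this file; Mem0Storage.__new__ is used only to reach the helpers without running __init__, and the prompt/response strings are made up (mem0ai must be installed for the import to succeed):

from crewai.memory.storage.mem0_storage import Mem0Storage

storage = Mem0Storage.__new__(Mem0Storage)  # skip __init__; the helpers touch no instance state

raw_assistant = "Thought: done.\nFinal Answer: The report is ready."
raw_user = "System prompt...\nUser message: Summarize Q3 sales."

print(storage._get_assistant_message(raw_assistant))  # The report is ready.
print(storage._get_user_message(raw_user))            # Summarize Q3 sales.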

View File

@@ -1,17 +1,17 @@
import logging
import os
import shutil
import uuid
import traceback
import warnings
from typing import Any
from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import create_persistent_client
from crewai.rag.types import BaseRecord
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
from crewai.utilities.logger_utils import suppress_logging
import warnings
class RAGStorage(BaseRAGStorage):
@@ -20,8 +20,6 @@ class RAGStorage(BaseRAGStorage):
search efficiency.
"""
app: ClientAPI | None = None
def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
):
@@ -33,37 +31,25 @@ class RAGStorage(BaseRAGStorage):
self.storage_file_name = self._build_storage_file_name(type, agents)
self.type = type
self._client: BaseClient | None = None
self.allow_reset = allow_reset
self.path = path
self._initialize_app()
def _set_embedder_config(self):
configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self):
from chromadb.config import Settings
# Suppress deprecation warnings from chromadb, which are not relevant to us
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
self._set_embedder_config()
if self.embedder_config:
embedding_function = get_embedding_function(self.embedder_config)
config = ChromaDBConfig(embedding_function=embedding_function)
self._client = create_client(config)
self.app = create_persistent_client(
path=self.path if self.path else self.storage_file_name,
settings=Settings(allow_reset=self.allow_reset),
)
self.collection = self.app.get_or_create_collection(
name=self.type, embedding_function=self.embedder_config
)
logging.info(f"Collection found or created: {self.collection}")
def _get_client(self) -> BaseClient:
"""Get the appropriate client - instance-specific or global."""
return self._client if self._client else get_rag_client()
def _sanitize_role(self, role: str) -> str:
"""
@@ -85,77 +71,69 @@ class RAGStorage(BaseRAGStorage):
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
def save(self, value: Any, metadata: dict[str, Any]) -> None:
try:
self._generate_embedding(value, metadata)
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
client.get_or_create_collection(collection_name=collection_name)
document: BaseRecord = {"content": value}
if metadata:
document["metadata"] = metadata
client.add_documents(collection_name=collection_name, documents=[document])
except Exception as e:
logging.error(f"Error during {self.type} save: {str(e)}")
logging.error(
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
)
def search(
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
if not hasattr(self, "app"):
self._initialize_app()
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
try:
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
response = self.collection.query(query_texts=query, n_results=limit)
results = []
for i in range(len(response["ids"][0])):
result = {
"id": response["ids"][0][i],
"metadata": response["metadatas"][0][i],
"context": response["documents"][0][i],
"score": response["distances"][0][i],
}
if result["score"] >= score_threshold:
results.append(result)
return results
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
return client.search(
collection_name=collection_name,
query=query,
limit=limit,
metadata_filter=filter,
score_threshold=score_threshold,
)
except Exception as e:
logging.error(f"Error during {self.type} search: {str(e)}")
logging.error(
f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
)
return []
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
self.collection.add(
documents=[text],
metadatas=[metadata or {}],
ids=[str(uuid.uuid4())],
)
def reset(self) -> None:
try:
if self.app:
self.app.reset()
shutil.rmtree(f"{db_storage_path()}/{self.type}")
self.app = None
self.collection = None
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
client.delete_collection(collection_name=collection_name)
except Exception as e:
if "attempt to write a readonly database" in str(e):
# Ignore this specific error
if "attempt to write a readonly database" in str(
e
) or "does not exist" in str(e):
# Ignore readonly database and collection not found errors (already reset)
pass
else:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
) from e

View File

@@ -4,8 +4,9 @@ import logging
from typing import Any
from chromadb.api.types import (
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction,
)
from chromadb.api.types import (
QueryResult,
)
from typing_extensions import Unpack
@@ -23,13 +24,13 @@ from crewai.rag.chromadb.utils import (
_process_query_results,
_sanitize_collection_name,
)
from crewai.utilities.logger_utils import suppress_logging
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
BaseCollectionParams,
)
from crewai.rag.types import SearchResult
from crewai.utilities.logger_utils import suppress_logging
class ChromaDBClient(BaseClient):
@@ -41,21 +42,29 @@ class ChromaDBClient(BaseClient):
Attributes:
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
embedding_function: Function to generate embeddings for documents.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
def __init__(
self,
client: ChromaDBClientType,
embedding_function: ChromaEmbeddingFunction[Embeddable],
embedding_function: ChromaEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
) -> None:
"""Initialize ChromaDBClient with client and embedding function.
Args:
client: Pre-configured ChromaDB client instance.
embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -300,16 +309,18 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = self.client.get_collection(
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=prepared.metadatas,
metadatas=metadatas,
)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
@@ -342,15 +353,17 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = await self.client.get_collection(
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
await collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=prepared.metadatas,
metadatas=metadatas,
)
def search(
@@ -385,9 +398,14 @@ class ChromaDBClient(BaseClient):
"Use asearch() for AsyncClientAPI."
)
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs)
collection = self.client.get_collection(
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
@@ -443,9 +461,14 @@ class ChromaDBClient(BaseClient):
"Use search() for ClientAPI."
)
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs)
collection = await self.client.get_collection(
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
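The new default_limit and default_score_threshold attributes make the per-call arguments optional, and get_or_create_collection means a first search no longer fails on a missing collection. A minimal sketch, assuming a throwaway persistent store and the stock ChromaDB ONNX embedder; the path and collection name are illustrative:

from chromadb import PersistentClient
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
from crewai.rag.chromadb.client import ChromaDBClient

client = ChromaDBClient(
    client=PersistentClient(path="/tmp/demo-chroma"),   # illustrative path
    embedding_function=DefaultEmbeddingFunction(),
    default_limit=5,
    default_score_threshold=0.6,
)
client.add_documents(
    collection_name="memory_demo",
    documents=[{"content": "Brandon's favorite color is red."}],
)
# limit and score_threshold fall back to the instance defaults when omitted:
results = client.search(collection_name="memory_demo", query="favorite color")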

View File

@@ -1,20 +1,20 @@
"""ChromaDB configuration model."""
import os
import warnings
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig
from chromadb.config import Settings
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.chromadb.constants import (
DEFAULT_TENANT,
DEFAULT_DATABASE,
DEFAULT_STORAGE_PATH,
DEFAULT_TENANT,
)
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig
warnings.filterwarnings(
"ignore",
@@ -49,7 +49,17 @@ def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
Returns:
Default embedding function using all-MiniLM-L6-v2 via ONNX.
"""
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return cast(
ChromaEmbeddingFunctionWrapper,
OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="text-embedding-3-small",
),
)
@pyd_dataclass(frozen=True)
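Because the default embedder now calls OpenAI, a config built without an explicit embedding_function needs OPENAI_API_KEY in the environment. A minimal sketch of both paths, reusing the wrapper cast from the removed default (assumed to still satisfy the Pydantic field):

import os
from typing import cast

from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper

# Default path: text-embedding-3-small via OpenAI, so the key must be present.
assert os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY is required by the default embedder"
remote_config = ChromaDBConfig()

# Keeping the previous local ONNX behaviour is still possible by passing an
# embedding function explicitly (no API key needed):
local_config = ChromaDBConfig(
    embedding_function=cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
)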

View File

@@ -2,11 +2,12 @@
import os
from hashlib import md5
import portalocker
from chromadb import PersistentClient
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.client import ChromaDBClient
from crewai.rag.chromadb.config import ChromaDBConfig
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
@@ -23,6 +24,7 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
"""
persist_dir = config.settings.persist_directory
os.makedirs(persist_dir, exist_ok=True)
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
@@ -37,4 +39,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
return ChromaDBClient(
client=client,
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
)

View File

@@ -3,27 +3,28 @@
from collections.abc import Mapping
from typing import Any, NamedTuple
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from chromadb.api import ClientAPI, AsyncClientAPI
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.configuration import CollectionConfigurationInterface
from chromadb.api.types import (
CollectionMetadata,
DataLoader,
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction,
Include,
Loadable,
Where,
WhereDocument,
)
from chromadb.api.types import (
EmbeddingFunction as ChromaEmbeddingFunction,
)
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
ChromaDBClientType = ClientAPI | AsyncClientAPI
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]):
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
@classmethod
@@ -44,7 +45,7 @@ class PreparedDocuments(NamedTuple):
Attributes:
ids: List of document IDs
texts: List of document texts
metadatas: List of document metadata mappings
metadatas: List of document metadata mappings (empty dict for no metadata)
"""
ids: list[str]
@@ -85,7 +86,7 @@ class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
configuration: CollectionConfigurationInterface
metadata: CollectionMetadata
embedding_function: ChromaEmbeddingFunction[Embeddable]
embedding_function: ChromaEmbeddingFunction
data_loader: DataLoader[Loadable]
get_or_create: bool

View File

@@ -5,13 +5,14 @@ from collections.abc import Mapping
from typing import Literal, TypeGuard, cast
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
Include,
IncludeEnum,
QueryResult,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from crewai.rag.chromadb.constants import (
DEFAULT_COLLECTION,
INVALID_CHARS_PATTERN,
@@ -78,7 +79,7 @@ def _prepare_documents_for_chromadb(
metadata = doc.get("metadata")
if metadata:
if isinstance(metadata, list):
metadatas.append(metadata[0] if metadata else {})
metadatas.append(metadata[0] if metadata and metadata[0] else {})
else:
metadatas.append(metadata)
else:
@@ -154,7 +155,7 @@ def _convert_chromadb_results_to_search_results(
"""
search_results: list[SearchResult] = []
include_strings = [item.value for item in include]
include_strings = [item.value for item in include] if include else []
ids = results["ids"][0] if results.get("ids") else []
@@ -188,7 +189,9 @@ def _convert_chromadb_results_to_search_results(
result: SearchResult = {
"id": doc_id,
"content": documents[i] if documents and i < len(documents) else "",
"metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {},
"metadata": dict(metadatas[i])
if metadatas and i < len(metadatas) and metadatas[i] is not None
else {},
"score": score,
}
search_results.append(result)
@@ -271,7 +274,7 @@ def _sanitize_collection_name(
sanitized = sanitized[:-1] + "z"
if len(sanitized) < MIN_COLLECTION_LENGTH:
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
sanitized += "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
if len(sanitized) > max_collection_length:
sanitized = sanitized[:max_collection_length]
if not sanitized[-1].isalnum():

View File

@@ -14,3 +14,5 @@ class BaseRagConfig:
provider: SupportedProvider = field(init=False)
embedding_function: Any | None = field(default=None)
limit: int = field(default=5)
score_threshold: float = field(default=0.6)
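The two new BaseRagConfig fields flow from the config through the factories shown above into the clients, so tuning retrieval is a one-liner. A brief sketch with illustrative values:

from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.factory import create_client

config = ChromaDBConfig(limit=10, score_threshold=0.4)   # overrides the 5 / 0.6 defaults
client = create_client(config)                           # factory forwards them as the client's search defaults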

View File

@@ -1,15 +1,15 @@
"""Protocol for vector database client implementations."""
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable, Annotated
from typing_extensions import Unpack, Required, TypedDict
from typing import Annotated, Any, Protocol, runtime_checkable
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from typing_extensions import Required, TypedDict, Unpack
from crewai.rag.types import (
EmbeddingFunction,
BaseRecord,
EmbeddingFunction,
SearchResult,
)
@@ -57,7 +57,7 @@ class BaseCollectionSearchParams(BaseCollectionParams, total=False):
query: Required[str]
limit: int
metadata_filter: dict[str, Any]
metadata_filter: dict[str, Any] | None
score_threshold: float

View File

@@ -10,8 +10,8 @@ from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GooglePalmEmbeddingFunction,
GoogleGenerativeAiEmbeddingFunction,
GooglePalmEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
@@ -60,7 +60,7 @@ def get_embedding_function(
EmbeddingFunction instance ready for use with ChromaDB
Supported providers:
- openai: OpenAI embeddings (default)
- openai: OpenAI embeddings
- cohere: Cohere embeddings
- ollama: Ollama local embeddings
- huggingface: HuggingFace embeddings
@@ -77,7 +77,7 @@ def get_embedding_function(
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
Examples:
# Use default OpenAI with retry logic
# Use default OpenAI embedding
>>> embedder = get_embedding_function()
# Use Cohere with dict

View File

@@ -6,8 +6,8 @@ from typing_extensions import Unpack
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
BaseCollectionParams,
BaseCollectionSearchParams,
)
from crewai.rag.core.exceptions import ClientMethodMismatchError
@@ -18,11 +18,11 @@ from crewai.rag.qdrant.types import (
QdrantCollectionCreateParams,
)
from crewai.rag.qdrant.utils import (
_create_point_from_document,
_get_collection_params,
_is_async_client,
_is_async_embedding_function,
_is_sync_client,
_create_point_from_document,
_get_collection_params,
_prepare_search_params,
_process_search_results,
)
@@ -38,21 +38,29 @@ class QdrantClient(BaseClient):
Attributes:
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
embedding_function: Function to generate embeddings for documents.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
def __init__(
self,
client: QdrantClientType,
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
) -> None:
"""Initialize QdrantClient with client and embedding function.
Args:
client: Pre-configured Qdrant client instance.
embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
"""Create a new collection in Qdrant.
@@ -332,9 +340,9 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
limit = kwargs.get("limit", self.default_limit)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
@@ -387,9 +395,9 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
limit = kwargs.get("limit", self.default_limit)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
if not await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")

View File

@@ -1,6 +1,7 @@
"""Factory functions for creating Qdrant clients from configuration."""
from qdrant_client import QdrantClient as SyncQdrantClientBase
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
@@ -17,5 +18,8 @@ def create_client(config: QdrantConfig) -> QdrantClient:
qdrant_client = SyncQdrantClientBase(**config.options)
return QdrantClient(
client=qdrant_client, embedding_function=config.embedding_function
client=qdrant_client,
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
)

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
class BaseRAGStorage(ABC):
@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: Optional[Dict[str, Any]] = None,
embedder_config: dict[str, Any] | None = None,
crew: Any = None,
):
self.type = type
@@ -32,45 +32,21 @@ class BaseRAGStorage(ABC):
@abstractmethod
def _sanitize_role(self, role: str) -> str:
"""Sanitizes agent roles to ensure valid directory names."""
pass
@abstractmethod
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value with metadata to the storage."""
pass
@abstractmethod
def search(
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for entries in the storage."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the storage."""
pass
@abstractmethod
def _generate_embedding(
self, text: str, metadata: Optional[Dict[str, Any]] = None
) -> Any:
"""Generate an embedding for the given text and metadata."""
pass
@abstractmethod
def _initialize_app(self):
"""Initialize the vector db."""
pass
def setup_config(self, config: Dict[str, Any]):
"""Setup the config of the storage."""
pass
def initialize_client(self):
"""Initialize the client of the storage. This should setup the app and the db collection"""
pass
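With _generate_embedding, _initialize_app, setup_config and initialize_client dropped from the abstract surface, a custom backend only has to implement four methods against the new search signature. A minimal in-memory sketch under that assumption:

from typing import Any

from crewai.rag.storage.base_rag_storage import BaseRAGStorage

class InMemoryRAGStorage(BaseRAGStorage):
    """Toy backend used purely to illustrate the slimmed-down interface."""

    def __init__(self) -> None:
        super().__init__(type="short_term")
        self._items: list[dict[str, Any]] = []

    def _sanitize_role(self, role: str) -> str:
        return role.lower().replace(" ", "_")

    def save(self, value: Any, metadata: dict[str, Any]) -> None:
        self._items.append({"content": value, "metadata": metadata})

    def search(
        self,
        query: str,
        limit: int = 5,
        filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        matches = [i for i in self._items if query.lower() in str(i["content"]).lower()]
        return matches[:limit]

    def reset(self) -> None:
        self._items.clear()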

View File

@@ -1,83 +0,0 @@
import os
import re
import portalocker
from chromadb import PersistentClient
from hashlib import md5
from typing import Optional
from crewai.utilities.paths import db_storage_path
MIN_COLLECTION_LENGTH = 3
MAX_COLLECTION_LENGTH = 63
DEFAULT_COLLECTION = "default_collection"
# Compiled regex patterns for better performance
INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
def is_ipv4_pattern(name: str) -> bool:
"""
Check if a string matches an IPv4 address pattern.
Args:
name: The string to check
Returns:
True if the string matches an IPv4 pattern, False otherwise
"""
return bool(IPV4_PATTERN.match(name))
def sanitize_collection_name(
name: Optional[str], max_collection_length: int = MAX_COLLECTION_LENGTH
) -> str:
"""
Sanitize a collection name to meet ChromaDB requirements:
1. 3-63 characters long
2. Starts and ends with alphanumeric character
3. Contains only alphanumeric characters, underscores, or hyphens
4. No consecutive periods
5. Not a valid IPv4 address
Args:
name: The original collection name to sanitize
Returns:
A sanitized collection name that meets ChromaDB requirements
"""
if not name:
return DEFAULT_COLLECTION
if is_ipv4_pattern(name):
name = f"ip_{name}"
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
if not sanitized[0].isalnum():
sanitized = "a" + sanitized
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
if len(sanitized) < MIN_COLLECTION_LENGTH:
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
if len(sanitized) > max_collection_length:
sanitized = sanitized[:max_collection_length]
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
return sanitized
def create_persistent_client(path: str, **kwargs):
"""
Creates a persistent client for ChromaDB with a lock file to prevent
concurrent creations. Works for both multi-threads and multi-processes
environments.
"""
lock_id = md5(path.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(db_storage_path(), f"chromadb-{lock_id}.lock")
with portalocker.Lock(lockfile):
client = PersistentClient(path=path, **kwargs)
return client

View File

@@ -9,19 +9,19 @@ import pytest
from crewai import Agent, Crew, Task
from crewai.agents.cache import CacheHandler
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.llm import LLM
from crewai.process import Process
from crewai.tools import tool
from crewai.tools.tool_calling import InstructorToolCalling
from crewai.tools.tool_usage import ToolUsage
from crewai.utilities import RPMController
from crewai.utilities.errors import AgentRepositoryError
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.process import Process
def test_agent_llm_creation_with_env_vars():
@@ -445,7 +445,7 @@ def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_powered_by_new_o_model_family_that_uses_tool():
@tool
def comapny_customer_data() -> float:
def comapny_customer_data() -> str:
"""Useful for getting customer related data."""
return "The company has 42 customers"
@@ -500,6 +500,15 @@ def test_agent_custom_max_iterations():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_repeated_tool_usage(capsys):
"""Test that agents handle repeated tool usage appropriately.
Notes:
Investigate whether to pin down the specific execution flow by examining
src/crewai/agents/crew_agent_executor.py:177-186 (max iterations check)
and src/crewai/tools/tool_usage.py:152-157 (repeated usage detection)
to ensure deterministic behavior.
"""
@tool
def get_final_answer() -> float:
"""Get the final answer but don't give it yet, just re-use this tool non-stop."""
@@ -527,41 +536,15 @@ def test_agent_repeated_tool_usage(capsys):
)
captured = capsys.readouterr()
output = (
captured.out.replace("\n", " ")
.replace(" ", " ")
.strip()
.replace("", "")
.replace("", "")
.replace("", "")
.replace("", "")
.replace("", "")
.replace("", "")
.replace("[", "")
.replace("]", "")
.replace("bold", "")
.replace("blue", "")
.replace("yellow", "")
.replace("green", "")
.replace("red", "")
.replace("dim", "")
.replace("🤖", "")
.replace("🔧", "")
.replace("", "")
.replace("\x1b[93m", "")
.replace("\x1b[00m", "")
.replace("\\", "")
.replace('"', "")
.replace("'", "")
)
output_lower = captured.out.lower()
# Look for the message in the normalized output, handling the apostrophe difference
expected_message = (
"I tried reusing the same input, I must stop using this action input."
has_repeated_usage_message = "tried reusing the same input" in output_lower
has_max_iterations = "maximum iterations reached" in output_lower
has_final_answer = "final answer" in output_lower or "42" in captured.out
assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
)
assert (
expected_message in output
), f"Expected message not found in output. Output was: {output}"
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -602,9 +585,9 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
has_max_iterations = "maximum iterations reached" in output_lower
has_final_answer = "final answer" in output_lower or "42" in captured.out
assert (
has_repeated_usage_message or (has_max_iterations and has_final_answer)
), f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -880,7 +863,7 @@ def test_agent_step_callback():
with patch.object(StepCallback, "callback") as callback:
@tool
def learn_about_AI() -> str:
def learn_about_ai() -> str:
"""Useful for when you need to learn about AI to write an paragraph about it."""
return "AI is a very broad field."
@@ -888,7 +871,7 @@ def test_agent_step_callback():
role="test role",
goal="test goal",
backstory="test backstory",
tools=[learn_about_AI],
tools=[learn_about_ai],
step_callback=StepCallback().callback,
)
@@ -910,7 +893,7 @@ def test_agent_function_calling_llm():
llm = "gpt-4o"
@tool
def learn_about_AI() -> str:
def learn_about_ai() -> str:
"""Useful for when you need to learn about AI to write an paragraph about it."""
return "AI is a very broad field."
@@ -918,7 +901,7 @@ def test_agent_function_calling_llm():
role="test role",
goal="test goal",
backstory="test backstory",
tools=[learn_about_AI],
tools=[learn_about_ai],
llm="gpt-4o",
max_iter=2,
function_calling_llm=llm,
@@ -1356,7 +1339,7 @@ def test_agent_training_handler(crew_training_handler):
verbose=True,
)
crew_training_handler().load.return_value = {
f"{str(agent.id)}": {"0": {"human_feedback": "good"}}
f"{agent.id!s}": {"0": {"human_feedback": "good"}}
}
result = agent._training_handler(task_prompt=task_prompt)
@@ -1473,7 +1456,7 @@ def test_agent_with_custom_stop_words():
)
assert isinstance(agent.llm, LLM)
assert set(agent.llm.stop) == set(stop_words + ["\nObservation:"])
assert set(agent.llm.stop) == set([*stop_words, "\nObservation:"])
assert all(word in agent.llm.stop for word in stop_words)
assert "\nObservation:" in agent.llm.stop
@@ -1530,7 +1513,7 @@ def test_llm_call_with_error():
llm = LLM(model="non-existent-model")
messages = [{"role": "user", "content": "This should fail"}]
with pytest.raises(Exception):
with pytest.raises(Exception): # noqa: B017
llm.call(messages)
@@ -1830,11 +1813,11 @@ def test_agent_execute_task_with_ollama():
def test_agent_with_knowledge_sources():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
with patch("crewai.knowledge") as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
with patch("crewai.knowledge") as mock_knowledge:
mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.search.return_value = [{"content": content}]
MockKnowledge.add_sources.return_value = [string_source]
mock_knowledge.add_sources.return_value = [string_source]
agent = Agent(
role="Information Agent",
@@ -1863,12 +1846,25 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig(results_limit=10, score_threshold=0.5)
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
with (
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
with patch.object(Knowledge, "query") as mock_knowledge_query:
agent = Agent(
role="Information Agent",
@@ -1898,15 +1894,27 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_defau
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig()
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
with (
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
with patch.object(Knowledge, "query") as mock_knowledge_query:
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig()
agent = Agent(
role="Information Agent",
goal="Provide information based on knowledge sources",
@@ -1935,10 +1943,16 @@ def test_agent_with_knowledge_sources_extensive_role():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
with patch("crewai.knowledge") as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
with (
patch("crewai.knowledge") as mock_knowledge,
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage.save"
) as mock_save,
):
mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
mock_save.return_value = None
agent = Agent(
role="Information Agent with extensive role description that is longer than 80 characters",
@@ -1968,8 +1982,8 @@ def test_agent_with_knowledge_sources_works_with_copy():
with patch(
"crewai.knowledge.source.base_knowledge_source.BaseKnowledgeSource",
autospec=True,
) as MockKnowledgeSource:
mock_knowledge_source_instance = MockKnowledgeSource.return_value
) as mock_knowledge_source:
mock_knowledge_source_instance = mock_knowledge_source.return_value
mock_knowledge_source_instance.__class__ = BaseKnowledgeSource
mock_knowledge_source_instance.sources = [string_source]
@@ -1983,9 +1997,9 @@ def test_agent_with_knowledge_sources_works_with_copy():
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as MockKnowledgeStorage:
mock_knowledge_storage = MockKnowledgeStorage.return_value
agent.knowledge_storage = mock_knowledge_storage
) as mock_knowledge_storage:
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
agent.knowledge_storage = mock_knowledge_storage_instance
agent_copy = agent.copy()
@@ -2004,11 +2018,30 @@ def test_agent_with_knowledge_sources_generate_search_query():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
with patch("crewai.knowledge") as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
with (
patch("crewai.knowledge") as mock_knowledge,
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
agent = Agent(
role="Information Agent with extensive role description that is longer than 80 characters",
goal="Provide information based on knowledge sources",
@@ -2270,7 +2303,26 @@ def test_get_knowledge_search_query():
i18n = I18N()
task_prompt = task.prompt()
with patch.object(agent, "_get_knowledge_search_query") as mock_get_query:
with (
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
patch.object(agent, "_get_knowledge_search_query") as mock_get_query,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None
mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None
mock_base_knowledge_storage.return_value = mock_storage_instance
mock_get_query.return_value = "Capital of France"
crew = Crew(agents=[agent], tasks=[task])
@@ -2312,9 +2364,9 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
# Mock embedchain initialization to prevent race conditions in parallel CI execution
with patch("embedchain.client.Client.setup"):
from crewai_tools import (
SerperDevTool,
FileReadTool,
EnterpriseActionTool,
FileReadTool,
SerperDevTool,
)
mock_get_response = MagicMock()
@@ -2347,7 +2399,7 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
tool_action = EnterpriseActionTool(
name="test_name",
description="test_description",
enterprise_action_token="test_token",
enterprise_action_token="test_token", # noqa: S106
action_name="test_action_name",
action_schema={"test": "test"},
)

View File

@@ -0,0 +1,130 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
personal goal is: Test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
the world\n\nThis is the expected criteria for your final answer: hello world\nyou
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
This is VERY important to you, use the tools available and give your best Final
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '825'
content-type:
- application/json
cookie:
- _cfuvid=NaXWifUGChHp6Ap1mvfMrNzmO4HdzddrqXkSR9T.hYo-1754508545647-0.0.1.1-604800000
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.93.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.93.0
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.9
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFLBbtswDL37Kzid4yFx46bxbVixtsfssB22wlAl2lEri5okJ+uK/Psg
OY3dtQV2MWA+vqf3SD5lAExJVgETWx5EZ3X++SrcY/Hrcec3l+SKP5frm/16Yx92m6/f9mwWGXR3
jyI8sz4K6qzGoMgMsHDIA0bVxaq8WCzPyrN5AjqSqCOttSFfUt4po/JiXizz+SpfXBzZW1ICPavg
RwYA8JS+0aeR+JtVkLRSpUPveYusOjUBMEc6Vhj3XvnATWCzERRkAppk/QYM7UFwA63aIXBoo23g
xu/RAfw0X5ThGj6l/wquUWuawXdyWn6YSjpses9jLNNrPQG4MRR4HEsKc3tEDif7mlrr6M7/Q2WN
Mspva4fck4lWfSDLEnrIAG7TmPoXyZl11NlQB3rA9NyiXA16bNzOFD2CgQLXk/qqmL2hV0sMXGk/
GTQTXGxRjtRxK7yXiiZANkn92s1b2kNyZdr/kR8BIdAGlLV1KJV4mXhscxiP972205STYebR7ZTA
Oih0cRMSG97r4aSYf/QBu7pRpkVnnRruqrF1eT7nzTmW5Zplh+wvAAAA//8DAGKunMhlAwAA
headers:
CF-RAY:
- 980b99a73c1c22c6-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 17 Sep 2025 21:12:11 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=Ahwkw3J9CDiluZudRgDmybz4FO07eXLz2MQDtkgfct4-1758143531-1.0.1.1-_3e8agfTZW.FPpRMLb1A2nET4OHQEGKNZeGeWT8LIiuSi8R2HWsGsJyueUyzYBYnfHqsfBUO16K1.TkEo2XiqVCaIi6pymeeQxwtXFF1wj8;
path=/; expires=Wed, 17-Sep-25 21:42:11 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=iHqLoc_2sNQLMyzfGCLtGol8vf1Y44xirzQJUuUF_TI-1758143531242-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '419'
openai-project:
- proj_xitITlrFeen7zjNSzML82h9x
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '609'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-tokens:
- '150000000'
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-project-tokens:
- '149999827'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999830'
x-ratelimit-reset-project-tokens:
- 0s
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_ece5f999e09e4c189d38e5bc08b2fad9
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,128 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
personal goal is: Test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
the world\n\nThis is the expected criteria for your final answer: hello world\nyou
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
This is VERY important to you, use the tools available and give your best Final
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '825'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.93.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.93.0
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.9
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFJNj9MwEL3nV4x8blBSmrabG0ICViqCCydYRbPOJDHreIztbIFV/zty
0m1SPiQukTJv3vN7M/OUAAhVixKE7DDI3ur09dvgfa8P/FO+e/9wa/aHb4cPH9viEw5yK1aRwfdf
SYZn1gvJvdUUFJsJlo4wUFTNd8U+32zybD0CPdekI621Id1w2iuj0nW23qTZLs33Z3bHSpIXJXxO
AACexm/0aWr6LkrIVs+VnrzHlkR5aQIQjnWsCPRe+YAmiNUMSjaBzGj9FgwfQaKBVj0SILTRNqDx
R3IAX8wbZVDDq/G/hI60Zjiy0/VS0FEzeIyhzKD1AkBjOGAcyhjl7oycLuY1t9bxvf+NKhpllO8q
R+jZRKM+sBUjekoA7sYhDVe5hXXc21AFfqDxubzYTXpi3s0CfXkGAwfUi/ruPNprvaqmgEr7xZiF
RNlRPVPnneBQK14AySL1n27+pj0lV6b9H/kZkJJsoLqyjmolrxPPbY7i6f6r7TLl0bDw5B6VpCoo
cnETNTU46OmghP/hA/VVo0xLzjo1XVVjq2KbYbOlorgRySn5BQAA//8DALxsmCBjAwAA
headers:
CF-RAY:
- 980ba79a4ab5f555-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 17 Sep 2025 21:21:42 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=aMMf0fLckKHz0BLW_2lATxD.7R61uYo1ZVW8aeFbruA-1758144102-1.0.1.1-6EKM3UxpdczoiQ6VpPpqqVnY7ftnXndFRWE4vyTzVcy.CQ4N539D97Wh8Ye9EUAvpUuukhW.r5MznkXq4tPXgCCmEv44RvVz2GBAz_e31h8;
path=/; expires=Wed, 17-Sep-25 21:51:42 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=VqrtvU8.QdEHc4.1XXUVmccaCcoj_CiNfI2zhKJoGRs-1758144102566-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '308'
openai-project:
- proj_xitITlrFeen7zjNSzML82h9x
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '620'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-tokens:
- '150000000'
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-project-tokens:
- '149999827'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999830'
x-ratelimit-reset-project-tokens:
- 0s
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_fa896433021140238115972280c05651
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,127 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
personal goal is: Test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Test task\n\nThis
is the expected criteria for your final answer: test output\nyou MUST return
the actual complete content as the final answer, not a summary.\n\nBegin! This
is VERY important to you, use the tools available and give your best Final Answer,
your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop": ["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '812'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.93.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.93.0
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.9
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFLLbtswELzrKxY8W4WV+JHoVgR95NJD4UvaBgJDrSS2FJclV3bSwP9e
kHYsuU2BXghwZ2c4s8vnDEDoWpQgVCdZ9c7kNx+4v7vb7jafnrrPX25/vtObX48f1Q31m+uFmEUG
PXxHxS+sN4p6Z5A12QOsPErGqFqsl1fF4nJdzBPQU40m0lrH+YLyXludX8wvFvl8nRdXR3ZHWmEQ
JXzNAACe0xl92hofRQlJK1V6DEG2KMpTE4DwZGJFyBB0YGlZzEZQkWW0yfotWNqBkhZavUWQ0Ebb
IG3YoQf4Zt9rKw28TfcSNhgYaGA3nAl6bIYgYyg7GDMBpLXEMg4lRbk/IvuTeUOt8/QQ/qCKRlsd
usqjDGSj0cDkREL3GcB9GtJwlls4T73jiukHpueK5eKgJ8bdTNDLI8jE0kzqq/XsFb2qRpbahMmY
hZKqw3qkjjuRQ61pAmST1H+7eU37kFzb9n/kR0ApdIx15TzWWp0nHts8xq/7r7bTlJNhEdBvtcKK
Nfq4iRobOZjD/kV4Cox91WjbondeH35V46rlai6bFS6X1yLbZ78BAAD//wMAZdfoWWMDAAA=
headers:
CF-RAY:
- 980b9e0c5fa516a0-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 17 Sep 2025 21:15:11 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=w6UZxbAZgYg9EFkKPfrSbMK97MB4jfs7YyvcEmgkvak-1758143711-1.0.1.1-j7YC1nvoMKxYK0T.5G2XDF6TXUCPu_HUs4YO9v65r3NHQFIcOaHbQXX4vqabSgynL2tZy23pbZgD8Cdmxhdw9dp4zkAXhU.imP43_pw4dSE;
path=/; expires=Wed, 17-Sep-25 21:45:11 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=ij9Q8tB7sj2GczANlJ7gbXVjj6hMhz1iVb6oGHuRYu8-1758143711202-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '462'
openai-project:
- proj_xitITlrFeen7zjNSzML82h9x
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '665'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-tokens:
- '150000000'
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-project-tokens:
- '149999830'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999830'
x-ratelimit-reset-project-tokens:
- 0s
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_04536db97c8c4768a200e38c1368c176
status:
code: 200
message: OK
version: 1

View File

@@ -1,7 +1,6 @@
"""Test Knowledge creation and querying functionality."""
from pathlib import Path
from typing import List, Union
from unittest.mock import patch
import pytest
@@ -23,7 +22,7 @@ def mock_vector_db():
instance = mock.return_value
instance.query.return_value = [
{
"context": "Brandon's favorite color is blue and he likes Mexican food.",
"content": "Brandon's favorite color is blue and he likes Mexican food.",
"score": 0.9,
}
]
@@ -44,13 +43,13 @@ def test_single_short_string(mock_vector_db):
content=content, metadata={"preference": "personal"}
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite color?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("blue" in result["context"].lower() for result in results)
assert any("blue" in result["content"].lower() for result in results)
# Verify the mock was called
mock_vector_db.query.assert_called_once()
@@ -84,14 +83,14 @@ def test_single_2k_character_string(mock_vector_db):
content=content, metadata={"preference": "personal"}
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite movie?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("inception" in result["context"].lower() for result in results)
assert any("inception" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -109,7 +108,7 @@ def test_multiple_short_strings(mock_vector_db):
# Mock the vector db query response
mock_vector_db.query.return_value = [
{"context": "Brandon has a dog named Max.", "score": 0.9}
{"content": "Brandon has a dog named Max.", "score": 0.9}
]
mock_vector_db.sources = string_sources
@@ -119,7 +118,7 @@ def test_multiple_short_strings(mock_vector_db):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("max" in result["context"].lower() for result in results)
assert any("max" in result["content"].lower() for result in results)
# Verify the mock was called
mock_vector_db.query.assert_called_once()
@@ -180,7 +179,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
]
mock_vector_db.sources = string_sources
mock_vector_db.query.return_value = [{"context": contents[1], "score": 0.9}]
mock_vector_db.query.return_value = [{"content": contents[1], "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite book?"
@@ -188,7 +187,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
# Assert that the correct information is retrieved
assert any(
"the hitchhiker's guide to the galaxy" in result["context"].lower()
"the hitchhiker's guide to the galaxy" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -205,13 +204,13 @@ def test_single_short_file(mock_vector_db, tmpdir):
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What sport does Brandon like?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("basketball" in result["context"].lower() for result in results)
assert any("basketball" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -247,13 +246,13 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite movie?"
results = mock_vector_db.query(query)
# Assert that the results contain the expected information
assert any("inception" in result["context"].lower() for result in results)
assert any("inception" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -286,13 +285,13 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
]
mock_vector_db.sources = file_sources
mock_vector_db.query.return_value = [
{"context": "Brandon lives in New York.", "score": 0.9}
{"content": "Brandon lives in New York.", "score": 0.9}
]
# Perform a query
query = "What city does he reside in?"
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("new york" in result["context"].lower() for result in results)
assert any("new york" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -360,7 +359,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
mock_vector_db.sources = file_sources
mock_vector_db.query.return_value = [
{
"context": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
"content": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
"score": 0.9,
}
]
@@ -370,7 +369,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
# Assert that the correct information is retrieved
assert any(
"the hitchhiker's guide to the galaxy" in result["context"].lower()
"the hitchhiker's guide to the galaxy" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -407,14 +406,14 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
# Combine string and file sources
mock_vector_db.sources = string_sources + file_sources
mock_vector_db.query.return_value = [{"context": file_contents[1], "score": 0.9}]
mock_vector_db.query.return_value = [{"content": file_contents[1], "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite book?"
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("the alchemist" in result["context"].lower() for result in results)
assert any("the alchemist" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -430,7 +429,7 @@ def test_pdf_knowledge_source(mock_vector_db):
)
mock_vector_db.sources = [pdf_source]
mock_vector_db.query.return_value = [
{"context": "crewai create crew latest-ai-development", "score": 0.9}
{"content": "crewai create crew latest-ai-development", "score": 0.9}
]
# Perform a query
@@ -439,7 +438,7 @@ def test_pdf_knowledge_source(mock_vector_db):
# Assert that the correct information is retrieved
assert any(
"crewai create crew latest-ai-development" in result["context"].lower()
"crewai create crew latest-ai-development" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -467,7 +466,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [csv_source]
mock_vector_db.query.return_value = [
{"context": "Brandon is 30 years old.", "score": 0.9}
{"content": "Brandon is 30 years old.", "score": 0.9}
]
# Perform a query
@@ -475,7 +474,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("30" in result["context"] for result in results)
assert any("30" in result["content"] for result in results)
mock_vector_db.query.assert_called_once()
@@ -502,7 +501,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [json_source]
mock_vector_db.query.return_value = [
{"context": "Alice lives in Los Angeles.", "score": 0.9}
{"content": "Alice lives in Los Angeles.", "score": 0.9}
]
# Perform a query
@@ -510,7 +509,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("los angeles" in result["context"].lower() for result in results)
assert any("los angeles" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@@ -518,7 +517,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
"""Test ExcelKnowledgeSource with a simple Excel file."""
# Create an Excel file with sample data
import pandas as pd
import pandas as pd # type: ignore[import-untyped]
excel_data = {
"Name": ["Brandon", "Alice", "Bob"],
@@ -535,7 +534,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [excel_source]
mock_vector_db.query.return_value = [
{"context": "Brandon is 30 years old.", "score": 0.9}
{"content": "Brandon is 30 years old.", "score": 0.9}
]
# Perform a query
@@ -543,7 +542,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("30" in result["context"] for result in results)
assert any("30" in result["content"] for result in results)
mock_vector_db.query.assert_called_once()
@@ -557,20 +556,20 @@ def test_docling_source(mock_vector_db):
mock_vector_db.sources = [docling_source]
mock_vector_db.query.return_value = [
{
"context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
"content": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
"score": 0.9,
}
]
# Perform a query
query = "What is reward hacking?"
results = mock_vector_db.query(query)
assert any("reward hacking" in result["context"].lower() for result in results)
assert any("reward hacking" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()
@pytest.mark.vcr
def test_multiple_docling_sources():
urls: List[Union[Path, str]] = [
def test_multiple_docling_sources() -> None:
urls: list[Path | str] = [
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
"https://lilianweng.github.io/posts/2024-07-07-hallucination/",
]
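
The hunks above consistently rename the `"context"` key to `"content"` in the mocked query results. A minimal sketch of consuming the new result shape, assuming (per these tests) each hit is a dict with `content`, `score`, and optional `metadata`; the `SearchResult` TypedDict and `pick_relevant` helper below are illustrative stand-ins, not crewai's own definitions.

```python
from typing import TypedDict


class SearchResult(TypedDict, total=False):
    """Illustrative stand-in for the shape implied by the updated tests."""

    content: str
    score: float
    metadata: dict


def pick_relevant(results: list[SearchResult], needle: str) -> list[str]:
    """Hypothetical helper: return result text that mentions `needle` (case-insensitive)."""
    return [r["content"] for r in results if needle.lower() in r.get("content", "").lower()]


# Mirrors the mocked return values used in the tests above.
hits: list[SearchResult] = [{"content": "Brandon lives in New York.", "score": 0.9}]
assert pick_relevant(hits, "new york") == ["Brandon lives in New York."]
```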

View File

@@ -0,0 +1,191 @@
"""Tests for Knowledge SearchResult type conversion and integration."""
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.knowledge import Knowledge # type: ignore[import-untyped]
from crewai.knowledge.source.string_knowledge_source import ( # type: ignore[import-untyped]
StringKnowledgeSource,
)
from crewai.knowledge.utils.knowledge_utils import ( # type: ignore[import-untyped]
extract_knowledge_context,
)
def test_knowledge_query_returns_searchresult() -> None:
"""Test that Knowledge.query returns SearchResult format."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.return_value = [
{
"content": "AI is fascinating",
"score": 0.9,
"metadata": {"source": "doc1"},
},
{
"content": "Machine learning rocks",
"score": 0.8,
"metadata": {"source": "doc2"},
},
]
sources = [StringKnowledgeSource(content="Test knowledge content")]
knowledge = Knowledge(collection_name="test_collection", sources=sources)
results = knowledge.query(
["AI technology"], results_limit=5, score_threshold=0.3
)
mock_storage.search.assert_called_once_with(
["AI technology"], limit=5, score_threshold=0.3
)
assert isinstance(results, list)
assert len(results) == 2
for result in results:
assert isinstance(result, dict)
assert "content" in result
assert "score" in result
assert "metadata" in result
assert results[0]["content"] == "AI is fascinating"
assert results[0]["score"] == 0.9
assert results[1]["content"] == "Machine learning rocks"
assert results[1]["score"] == 0.8
def test_knowledge_query_with_empty_results() -> None:
"""Test Knowledge.query with empty search results."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.return_value = []
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="empty_test", sources=sources)
results = knowledge.query(["nonexistent query"])
assert isinstance(results, list)
assert len(results) == 0
def test_extract_knowledge_context_with_searchresult() -> None:
"""Test extract_knowledge_context works with SearchResult format."""
search_results = [
{"content": "Python is great for AI", "score": 0.95, "metadata": {}},
{"content": "Machine learning algorithms", "score": 0.88, "metadata": {}},
{"content": "Deep learning frameworks", "score": 0.82, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert "Additional Information:" in context
assert "Python is great for AI" in context
assert "Machine learning algorithms" in context
assert "Deep learning frameworks" in context
expected_content = (
"Python is great for AI\nMachine learning algorithms\nDeep learning frameworks"
)
assert expected_content in context
def test_extract_knowledge_context_with_empty_content() -> None:
"""Test extract_knowledge_context handles empty or invalid content."""
search_results = [
{"content": "", "score": 0.5, "metadata": {}},
{"content": None, "score": 0.4, "metadata": {}},
{"score": 0.3, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert context == ""
def test_extract_knowledge_context_filters_invalid_results() -> None:
"""Test that extract_knowledge_context filters out invalid results."""
search_results: list[dict[str, Any] | None] = [
{"content": "Valid content 1", "score": 0.9, "metadata": {}},
{"content": "", "score": 0.8, "metadata": {}},
{"content": "Valid content 2", "score": 0.7, "metadata": {}},
None,
{"content": None, "score": 0.6, "metadata": {}},
]
context = extract_knowledge_context(search_results)
assert "Additional Information:" in context
assert "Valid content 1" in context
assert "Valid content 2" in context
assert context.count("\n") == 1
@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_storage_exception_handling(
mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test Knowledge handles storage exceptions gracefully."""
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.side_effect = Exception("Storage error")
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="error_test", sources=sources)
with pytest.raises(ValueError, match="Storage is not initialized"):
knowledge.storage = None
knowledge.query(["test query"])
def test_knowledge_add_sources_integration() -> None:
"""Test Knowledge.add_sources integrates properly with storage."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
sources = [
StringKnowledgeSource(content="Content 1"),
StringKnowledgeSource(content="Content 2"),
]
knowledge = Knowledge(collection_name="add_sources_test", sources=sources)
knowledge.add_sources()
for source in sources:
assert source.storage == mock_storage
def test_knowledge_reset_integration() -> None:
"""Test Knowledge.reset integrates with storage."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="reset_test", sources=sources)
knowledge.reset()
mock_storage.reset.assert_called_once()
@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_reset_without_storage(
mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test Knowledge.reset raises error when storage is None."""
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="no_storage_test", sources=sources)
knowledge.storage = None
with pytest.raises(ValueError, match="Storage is not initialized"):
knowledge.reset()
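
Taken together, the new tests pin down how `Knowledge.query` forwards its tuning knobs to the underlying storage. A minimal usage sketch under the same mocking assumptions as the tests above (the collection name and query text are arbitrary):

```python
from unittest.mock import MagicMock, patch

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
    storage = MagicMock()
    storage.search.return_value = [
        {"content": "AI is fascinating", "score": 0.9, "metadata": {"source": "doc1"}}
    ]
    mock_storage_class.return_value = storage

    knowledge = Knowledge(
        collection_name="demo",
        sources=[StringKnowledgeSource(content="Some knowledge")],
    )
    hits = knowledge.query(["AI technology"], results_limit=5, score_threshold=0.3)

    # results_limit/score_threshold are passed straight through to storage.search.
    storage.search.assert_called_once_with(["AI technology"], limit=5, score_threshold=0.3)
    print(hits[0]["content"])  # -> "AI is fascinating"
```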

View File

@@ -0,0 +1,196 @@
"""Integration tests for KnowledgeStorage RAG client migration."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
KnowledgeStorage,
)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.create_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_knowledge_storage_uses_rag_client(
mock_get_embedding: MagicMock,
mock_create_client: MagicMock,
mock_get_client: MagicMock,
) -> None:
"""Test that KnowledgeStorage properly integrates with RAG client."""
mock_client = MagicMock()
mock_create_client.return_value = mock_client
mock_get_client.return_value = mock_client
mock_client.search.return_value = [
{"content": "test content", "score": 0.9, "metadata": {"source": "test"}}
]
embedder_config = {"provider": "openai", "model": "text-embedding-3-small"}
storage = KnowledgeStorage(
embedder=embedder_config, collection_name="test_knowledge"
)
mock_create_client.assert_called_once()
results = storage.search(["test query"], limit=5, score_threshold=0.3)
mock_get_client.assert_not_called()
mock_client.search.assert_called_once_with(
collection_name="knowledge_test_knowledge",
query="test query",
limit=5,
metadata_filter=None,
score_threshold=0.3,
)
assert isinstance(results, list)
assert len(results) == 1
assert isinstance(results[0], dict)
assert "content" in results[0]
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_collection_name_prefixing(mock_get_client: MagicMock) -> None:
"""Test that collection names are properly prefixed."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage(collection_name="custom_knowledge")
storage.search(["test"], limit=1)
mock_client.search.assert_called_once()
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["collection_name"] == "knowledge_custom_knowledge"
mock_client.reset_mock()
storage_default = KnowledgeStorage()
storage_default.search(["test"], limit=1)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["collection_name"] == "knowledge"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_save_documents_integration(mock_get_client: MagicMock) -> None:
"""Test document saving through RAG client."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
storage = KnowledgeStorage(collection_name="test_docs")
documents = ["Document 1 content", "Document 2 content"]
storage.save(documents)
mock_client.get_or_create_collection.assert_called_once_with(
collection_name="knowledge_test_docs"
)
mock_client.add_documents.assert_called_once()
call_kwargs = mock_client.add_documents.call_args.kwargs
added_docs = call_kwargs["documents"]
assert len(added_docs) == 2
assert added_docs[0]["content"] == "Document 1 content"
assert added_docs[1]["content"] == "Document 2 content"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_reset_integration(mock_get_client: MagicMock) -> None:
"""Test collection reset through RAG client."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
storage = KnowledgeStorage(collection_name="test_reset")
storage.reset()
mock_client.delete_collection.assert_called_once_with(
collection_name="knowledge_test_reset"
)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_search_error_handling(mock_get_client: MagicMock) -> None:
"""Test error handling during search operations."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = Exception("RAG client error")
storage = KnowledgeStorage(collection_name="error_test")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_embedding_configuration_flow(
mock_get_embedding: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test that embedding configuration flows properly to RAG client."""
mock_embedding_func = MagicMock()
mock_get_embedding.return_value = mock_embedding_func
mock_get_client.return_value = MagicMock()
embedder_config = {
"provider": "sentence-transformer",
"model_name": "all-MiniLM-L6-v2",
}
KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
mock_get_embedding.assert_called_once_with(embedder_config)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_query_list_conversion(mock_get_client: MagicMock) -> None:
"""Test that query list is properly converted to string."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage()
storage.search(["single query"])
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["query"] == "single query"
mock_client.reset_mock()
storage.search(["query one", "query two"])
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["query"] == "query one query two"
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_metadata_filter_handling(mock_get_client: MagicMock) -> None:
"""Test metadata filter parameter handling."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []
storage = KnowledgeStorage()
metadata_filter = {"category": "technical", "priority": "high"}
storage.search(["test"], metadata_filter=metadata_filter)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["metadata_filter"] == metadata_filter
mock_client.reset_mock()
storage.search(["test"], metadata_filter=None)
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["metadata_filter"] is None
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None:
"""Test specific handling of dimension mismatch errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.get_or_create_collection.return_value = None
mock_client.add_documents.side_effect = Exception("dimension mismatch detected")
storage = KnowledgeStorage(collection_name="dimension_test")
with pytest.raises(ValueError, match="Embedding dimension mismatch"):
storage.save(["test document"])
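
A condensed sketch of the two storage behaviours these tests rely on most, collection-name prefixing and multi-query joining, using the same `get_rag_client` patch point as above:

```python
from unittest.mock import MagicMock, patch

from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage

with patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") as get_client:
    client = MagicMock()
    client.search.return_value = []
    get_client.return_value = client

    storage = KnowledgeStorage(collection_name="custom_knowledge")
    storage.search(["query one", "query two"], limit=3)

    kwargs = client.search.call_args.kwargs
    assert kwargs["collection_name"] == "knowledge_custom_knowledge"  # prefixed
    assert kwargs["query"] == "query one query two"                   # list joined to one string
    assert kwargs["limit"] == 3
```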

View File

@@ -1,19 +1,20 @@
from unittest.mock import patch, ANY
from collections import defaultdict
from unittest.mock import ANY, patch
import pytest
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.task import Task
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
)
@pytest.fixture
@@ -38,22 +39,23 @@ def short_term_memory():
def test_short_term_memory_search_events(short_term_memory):
events = defaultdict(list)
with crewai_event_bus.scoped_handlers():
with patch("crewai.rag.chromadb.client.ChromaDBClient.search", return_value=[]):
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
# Call the save method
short_term_memory.search(
query="test value",
limit=3,
score_threshold=0.35,
)
# Call the save method
short_term_memory.search(
query="test value",
limit=3,
score_threshold=0.35,
)
assert len(events["MemoryQueryStartedEvent"]) == 1
assert len(events["MemoryQueryCompletedEvent"]) == 1
@@ -173,12 +175,12 @@ def test_save_and_search(short_term_memory):
expected_result = [
{
"context": memory.data,
"content": memory.data,
"metadata": {"agent": "test_agent"},
"score": 0.95,
}
]
with patch.object(ShortTermMemory, "search", return_value=expected_result):
find = short_term_memory.search("test value", score_threshold=0.01)[0]
assert find["context"] == memory.data, "Data value mismatch."
assert find["content"] == memory.data, "Data value mismatch."
assert find["metadata"]["agent"] == "test_agent", "Agent value mismatch."
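
For reference outside pytest, a sketch of the scoped-handler pattern the reworked test uses; `collect_memory_events` is a hypothetical helper, and `short_term_memory` is assumed to be a `ShortTermMemory` instance wired up like the fixture above:

```python
from collections import defaultdict
from unittest.mock import patch

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryQueryCompletedEvent,
    MemoryQueryStartedEvent,
)


def collect_memory_events(short_term_memory):
    """Run a memory search with the vector search patched out and collect the emitted events."""
    events = defaultdict(list)
    with patch("crewai.rag.chromadb.client.ChromaDBClient.search", return_value=[]):
        with crewai_event_bus.scoped_handlers():

            @crewai_event_bus.on(MemoryQueryStartedEvent)
            def on_started(source, event):
                events["started"].append(event)

            @crewai_event_bus.on(MemoryQueryCompletedEvent)
            def on_completed(source, event):
                events["completed"].append(event)

            short_term_memory.search(query="test value", limit=3, score_threshold=0.35)
    return events
```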

View File

@@ -236,7 +236,7 @@ class TestChromaDBClient:
def test_add_documents(self, client, mock_chromadb_client) -> None:
"""Test that add_documents adds documents to collection."""
mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{
@@ -247,7 +247,7 @@ class TestChromaDBClient:
client.add_documents(collection_name="test_collection", documents=documents)
mock_chromadb_client.get_collection.assert_called_once_with(
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=client.embedding_function,
)
@@ -262,7 +262,7 @@ class TestChromaDBClient:
def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
"""Test add_documents with custom document IDs."""
mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{
@@ -285,6 +285,43 @@ class TestChromaDBClient:
metadatas=[{"source": "test1"}, {"source": "test2"}],
)
def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
"""Test add_documents with documents that have no metadata."""
mock_collection = Mock()
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"content": "Document without metadata"},
{"content": "Another document", "metadata": None},
{"content": "Document with metadata", "metadata": {"key": "value"}},
]
client.add_documents(collection_name="test_collection", documents=documents)
# Verify upsert was called with empty dicts for missing metadata
mock_collection.upsert.assert_called_once()
call_args = mock_collection.upsert.call_args
assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]
def test_add_documents_all_without_metadata(
self, client, mock_chromadb_client
) -> None:
"""Test add_documents when all documents have no metadata."""
mock_collection = Mock()
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"content": "Document 1"},
{"content": "Document 2"},
{"content": "Document 3"},
]
client.add_documents(collection_name="test_collection", documents=documents)
mock_collection.upsert.assert_called_once()
call_args = mock_collection.upsert.call_args
assert call_args[1]["metadatas"] is None
def test_add_documents_empty_list_raises_error(
self, client, mock_chromadb_client
) -> None:
@@ -298,7 +335,7 @@ class TestChromaDBClient:
) -> None:
"""Test that aadd_documents adds documents to collection asynchronously."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
@@ -313,7 +350,7 @@ class TestChromaDBClient:
collection_name="test_collection", documents=documents
)
mock_async_chromadb_client.get_collection.assert_called_once_with(
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=async_client.embedding_function,
)
@@ -331,7 +368,7 @@ class TestChromaDBClient:
) -> None:
"""Test aadd_documents with custom document IDs."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
@@ -358,6 +395,31 @@ class TestChromaDBClient:
metadatas=[{"source": "test1"}, {"source": "test2"}],
)
@pytest.mark.asyncio
async def test_aadd_documents_without_metadata(
self, async_client, mock_async_chromadb_client
) -> None:
"""Test aadd_documents with documents that have no metadata."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
documents: list[BaseRecord] = [
{"content": "Document without metadata"},
{"content": "Another document", "metadata": None},
{"content": "Document with metadata", "metadata": {"key": "value"}},
]
await async_client.aadd_documents(
collection_name="test_collection", documents=documents
)
# Verify upsert was called with empty dicts for missing metadata
mock_collection.upsert.assert_called_once()
call_args = mock_collection.upsert.call_args
assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]
@pytest.mark.asyncio
async def test_aadd_documents_empty_list_raises_error(
self, async_client, mock_async_chromadb_client
@@ -372,7 +434,7 @@ class TestChromaDBClient:
"""Test that search queries the collection correctly."""
mock_collection = Mock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
mock_collection.query.return_value = {
"ids": [["doc1", "doc2"]],
"documents": [["Document 1", "Document 2"]],
@@ -382,13 +444,13 @@ class TestChromaDBClient:
results = client.search(collection_name="test_collection", query="test query")
mock_chromadb_client.get_collection.assert_called_once_with(
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=client.embedding_function,
)
mock_collection.query.assert_called_once_with(
query_texts=["test query"],
n_results=10,
n_results=5,
where=None,
where_document=None,
include=["metadatas", "documents", "distances"],
@@ -404,7 +466,7 @@ class TestChromaDBClient:
"""Test search with optional parameters."""
mock_collection = Mock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
mock_collection.query.return_value = {
"ids": [["doc1", "doc2", "doc3"]],
"documents": [["Document 1", "Document 2", "Document 3"]],
@@ -437,7 +499,7 @@ class TestChromaDBClient:
"""Test that asearch queries the collection correctly."""
mock_collection = AsyncMock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
mock_collection.query = AsyncMock(
@@ -453,13 +515,13 @@ class TestChromaDBClient:
collection_name="test_collection", query="test query"
)
mock_async_chromadb_client.get_collection.assert_called_once_with(
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=async_client.embedding_function,
)
mock_collection.query.assert_called_once_with(
query_texts=["test query"],
n_results=10,
n_results=5,
where=None,
where_document=None,
include=["metadatas", "documents", "distances"],
@@ -478,7 +540,7 @@ class TestChromaDBClient:
"""Test asearch with optional parameters."""
mock_collection = AsyncMock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
mock_collection.query = AsyncMock(

View File

@@ -0,0 +1,95 @@
"""Tests for ChromaDB utility functions."""
from crewai.rag.chromadb.utils import (
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
_is_ipv4_pattern,
_sanitize_collection_name,
)
class TestChromaDBUtils:
"""Test suite for ChromaDB utility functions."""
def test_sanitize_collection_name_long_name(self) -> None:
"""Test sanitizing a very long collection name."""
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
sanitized = _sanitize_collection_name(long_name)
assert len(sanitized) <= MAX_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_sanitize_collection_name_special_chars(self) -> None:
"""Test sanitizing a name with special characters."""
special_chars = "Agent@123!#$%^&*()"
sanitized = _sanitize_collection_name(special_chars)
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_sanitize_collection_name_short_name(self) -> None:
"""Test sanitizing a very short name."""
short_name = "A"
sanitized = _sanitize_collection_name(short_name)
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
def test_sanitize_collection_name_bad_ends(self) -> None:
"""Test sanitizing a name with non-alphanumeric start/end."""
bad_ends = "_Agent_"
sanitized = _sanitize_collection_name(bad_ends)
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
def test_sanitize_collection_name_none(self) -> None:
"""Test sanitizing a None value."""
sanitized = _sanitize_collection_name(None)
assert sanitized == "default_collection"
def test_sanitize_collection_name_ipv4_pattern(self) -> None:
"""Test sanitizing an IPv4 address."""
ipv4 = "192.168.1.1"
sanitized = _sanitize_collection_name(ipv4)
assert sanitized.startswith("ip_")
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)
def test_is_ipv4_pattern(self) -> None:
"""Test IPv4 pattern detection."""
assert _is_ipv4_pattern("192.168.1.1") is True
assert _is_ipv4_pattern("not.an.ip.address") is False
def test_sanitize_collection_name_properties(self) -> None:
"""Test that sanitized collection names always meet ChromaDB requirements."""
test_cases: list[str] = [
"A" * 100, # Very long name
"_start_with_underscore",
"end_with_underscore_",
"contains@special#characters",
"192.168.1.1", # IPv4 address
"a" * 2, # Too short
]
for test_case in test_cases:
sanitized = _sanitize_collection_name(test_case)
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert len(sanitized) <= MAX_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
def test_sanitize_collection_name_empty_string(self) -> None:
"""Test sanitizing an empty string."""
sanitized = _sanitize_collection_name("")
assert sanitized == "default_collection"
def test_sanitize_collection_name_whitespace_only(self) -> None:
"""Test sanitizing a string with only whitespace."""
sanitized = _sanitize_collection_name(" ")
assert (
sanitized == "a__z"
) # Spaces become underscores, padded to meet requirements
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
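
A minimal sketch of the sanitization invariants these tests assert, using the same helpers the new test module imports:

```python
from crewai.rag.chromadb.utils import (
    MAX_COLLECTION_LENGTH,
    MIN_COLLECTION_LENGTH,
    _sanitize_collection_name,
)

# Whatever goes in, the sanitized name stays within ChromaDB's length bounds,
# starts and ends with an alphanumeric character, and IPv4-looking names get an
# "ip_" prefix; None/empty input falls back to "default_collection".
for raw in ["Agent@123!#$", "_Agent_", "192.168.1.1", "A" * 100, None]:
    name = _sanitize_collection_name(raw)
    assert MIN_COLLECTION_LENGTH <= len(name) <= MAX_COLLECTION_LENGTH
    assert name[0].isalnum() and name[-1].isalnum()
    print(raw, "->", name)
```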

View File

@@ -0,0 +1,250 @@
"""Enhanced tests for embedding function factory."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
get_embedding_function,
)
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
def test_get_embedding_function_default() -> None:
"""Test default embedding function when no config provided."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
with patch(
"crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
):
result = get_embedding_function()
mock_openai.assert_called_once_with(
api_key="test-api-key", model_name="text-embedding-3-small"
)
assert result == mock_instance
def test_get_embedding_function_with_embedding_options() -> None:
"""Test embedding function creation with EmbeddingOptions object."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
options = EmbeddingOptions(
provider="openai", api_key="test-key", model="text-embedding-3-large"
)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
# OpenAI uses model_name parameter, not model
assert result == mock_instance
def test_get_embedding_function_sentence_transformer() -> None:
"""Test sentence transformer embedding function."""
with patch(
"crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction"
) as mock_st:
mock_instance = MagicMock()
mock_st.return_value = mock_instance
config = {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"}
result = get_embedding_function(config)
mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2")
assert result == mock_instance
def test_get_embedding_function_ollama() -> None:
"""Test Ollama embedding function."""
with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as mock_ollama:
mock_instance = MagicMock()
mock_ollama.return_value = mock_instance
config = {
"provider": "ollama",
"model_name": "nomic-embed-text",
"url": "http://localhost:11434",
}
result = get_embedding_function(config)
mock_ollama.assert_called_once_with(
model_name="nomic-embed-text", url="http://localhost:11434"
)
assert result == mock_instance
def test_get_embedding_function_cohere() -> None:
"""Test Cohere embedding function."""
with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as mock_cohere:
mock_instance = MagicMock()
mock_cohere.return_value = mock_instance
config = {
"provider": "cohere",
"api_key": "cohere-key",
"model_name": "embed-english-v3.0",
}
result = get_embedding_function(config)
mock_cohere.assert_called_once_with(
api_key="cohere-key", model_name="embed-english-v3.0"
)
assert result == mock_instance
def test_get_embedding_function_huggingface() -> None:
"""Test HuggingFace embedding function."""
with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as mock_hf:
mock_instance = MagicMock()
mock_hf.return_value = mock_instance
config = {
"provider": "huggingface",
"api_key": "hf-token",
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
}
result = get_embedding_function(config)
mock_hf.assert_called_once_with(
api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
)
assert result == mock_instance
def test_get_embedding_function_onnx() -> None:
"""Test ONNX embedding function."""
with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as mock_onnx:
mock_instance = MagicMock()
mock_onnx.return_value = mock_instance
config = {"provider": "onnx"}
result = get_embedding_function(config)
mock_onnx.assert_called_once()
assert result == mock_instance
def test_get_embedding_function_google_palm() -> None:
"""Test Google PaLM embedding function."""
with patch(
"crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction"
) as mock_palm:
mock_instance = MagicMock()
mock_palm.return_value = mock_instance
config = {"provider": "google-palm", "api_key": "palm-key"}
result = get_embedding_function(config)
mock_palm.assert_called_once_with(api_key="palm-key")
assert result == mock_instance
def test_get_embedding_function_amazon_bedrock() -> None:
"""Test Amazon Bedrock embedding function."""
with patch(
"crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction"
) as mock_bedrock:
mock_instance = MagicMock()
mock_bedrock.return_value = mock_instance
config = {
"provider": "amazon-bedrock",
"region_name": "us-west-2",
"model_name": "amazon.titan-embed-text-v1",
}
result = get_embedding_function(config)
mock_bedrock.assert_called_once_with(
region_name="us-west-2", model_name="amazon.titan-embed-text-v1"
)
assert result == mock_instance
def test_get_embedding_function_jina() -> None:
"""Test Jina embedding function."""
with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as mock_jina:
mock_instance = MagicMock()
mock_jina.return_value = mock_instance
config = {
"provider": "jina",
"api_key": "jina-key",
"model_name": "jina-embeddings-v2-base-en",
}
result = get_embedding_function(config)
mock_jina.assert_called_once_with(
api_key="jina-key", model_name="jina-embeddings-v2-base-en"
)
assert result == mock_instance
def test_get_embedding_function_unsupported_provider() -> None:
"""Test handling of unsupported provider."""
config = {"provider": "unsupported-provider"}
with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
get_embedding_function(config)
def test_get_embedding_function_config_modification() -> None:
"""Test that original config dict is not modified."""
original_config = {
"provider": "openai",
"api_key": "test-key",
"model": "text-embedding-3-small",
}
config_copy = original_config.copy()
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
get_embedding_function(config_copy)
assert config_copy == original_config
def test_get_embedding_function_exclude_none_values() -> None:
"""Test that None values are excluded from embedding function calls."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
options = EmbeddingOptions(provider="openai", api_key="test-key", model=None)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
assert "model" not in call_kwargs
assert result == mock_instance
def test_get_embedding_function_instructor() -> None:
"""Test Instructor embedding function."""
with patch(
"crewai.rag.embeddings.factory.InstructorEmbeddingFunction"
) as mock_instructor:
mock_instance = MagicMock()
mock_instructor.return_value = mock_instance
config = {"provider": "instructor", "model_name": "hkunlp/instructor-large"}
result = get_embedding_function(config)
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
assert result == mock_instance
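
A usage sketch for the factory covered above: a provider config dict (or an `EmbeddingOptions` object) yields a ChromaDB-compatible embedding function, unknown providers raise `ValueError`, and with no config the tests show an OpenAI `text-embedding-3-small` default keyed off an environment-variable API key. The Ollama values below are placeholders and assume the matching chromadb embedding dependency is installed:

```python
from crewai.rag.embeddings.factory import get_embedding_function

# Placeholder Ollama config; constructing the function should not contact the server.
embedder = get_embedding_function(
    {
        "provider": "ollama",
        "model_name": "nomic-embed-text",
        "url": "http://localhost:11434",
    }
)

# Unsupported providers fail fast, as asserted in the tests above.
try:
    get_embedding_function({"provider": "unsupported-provider"})
except ValueError as exc:
    print(exc)  # Unsupported provider: unsupported-provider
```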

View File

@@ -3,7 +3,8 @@
from unittest.mock import AsyncMock, Mock
import pytest
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client import AsyncQdrantClient
from qdrant_client import QdrantClient as SyncQdrantClient
from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.qdrant.client import QdrantClient
@@ -435,7 +436,7 @@ class TestQdrantClient:
call_args = mock_qdrant_client.query_points.call_args
assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
assert call_args.kwargs["limit"] == 10
assert call_args.kwargs["limit"] == 5
assert call_args.kwargs["with_payload"] is True
assert call_args.kwargs["with_vectors"] is False
@@ -540,7 +541,7 @@ class TestQdrantClient:
call_args = mock_async_qdrant_client.query_points.call_args
assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
assert call_args.kwargs["limit"] == 10
assert call_args.kwargs["limit"] == 5
assert call_args.kwargs["with_payload"] is True
assert call_args.kwargs["with_vectors"] is False
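
Both the ChromaDB and Qdrant hunks above move the default search limit from 10 to 5. A hedged note for callers that depended on the old default: pass `limit` explicitly. Shown against a stand-in mock client, since only the call shape matters here:

```python
from unittest.mock import MagicMock

client = MagicMock()  # stand-in for either RAG client
client.search(
    collection_name="test_collection",
    query="test query",
    limit=10,  # restore the previous default explicitly
)
assert client.search.call_args.kwargs["limit"] == 10
```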

View File

@@ -0,0 +1,218 @@
"""Tests for RAG client error handling scenarios."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
KnowledgeStorage,
)
from crewai.memory.storage.rag_storage import RAGStorage # type: ignore[import-untyped]
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_connection_failure(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles RAG client connection failures."""
mock_get_client.side_effect = ConnectionError("Unable to connect to ChromaDB")
storage = KnowledgeStorage(collection_name="connection_test")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_search_timeout(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles search timeouts gracefully."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = TimeoutError("Search operation timed out")
storage = KnowledgeStorage(collection_name="timeout_test")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_collection_not_found(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles missing collections."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = ValueError(
"Collection 'knowledge_missing' does not exist"
)
storage = KnowledgeStorage(collection_name="missing_collection")
results = storage.search(["test query"])
assert results == []
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles invalid embedding configurations."""
mock_get_client.return_value = MagicMock()
with patch(
"crewai.knowledge.storage.knowledge_storage.get_embedding_function"
) as mock_get_embedding:
mock_get_embedding.side_effect = ValueError(
"Unsupported provider: invalid_provider"
)
with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
KnowledgeStorage(
embedder={"provider": "invalid_provider"},
collection_name="invalid_embedding_test",
)
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_client_failure(mock_get_client: MagicMock) -> None:
"""Test RAGStorage handles RAG client failures in memory operations."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = RuntimeError("ChromaDB server error")
storage = RAGStorage("short_term", crew=None)
results = storage.search("test query")
assert results == []
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_save_failure(mock_get_client: MagicMock) -> None:
"""Test RAGStorage handles save operation failures."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.add_documents.side_effect = Exception("Failed to add documents")
storage = RAGStorage("long_term", crew=None)
storage.save("test memory", {"key": "value"})
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_readonly_database(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage reset handles readonly database errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.delete_collection.side_effect = Exception(
"attempt to write a readonly database"
)
storage = KnowledgeStorage(collection_name="readonly_test")
storage.reset()
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_collection_does_not_exist(
mock_get_client: MagicMock,
) -> None:
"""Test KnowledgeStorage reset handles non-existent collections."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.delete_collection.side_effect = Exception("Collection does not exist")
storage = KnowledgeStorage(collection_name="nonexistent_test")
storage.reset()
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_reset_failure_propagation(mock_get_client: MagicMock) -> None:
"""Test RAGStorage reset propagates unexpected errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.delete_collection.side_effect = Exception("Unexpected database error")
storage = RAGStorage("entities", crew=None)
with pytest.raises(
Exception, match="An error occurred while resetting the entities memory"
):
storage.reset()
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_malformed_search_results(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles malformed search results."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = [
{"content": "valid result", "metadata": {"source": "test"}},
{"invalid": "missing content field", "metadata": {"source": "test"}},
None,
{"content": None, "metadata": {"source": "test"}},
]
storage = KnowledgeStorage(collection_name="malformed_test")
results = storage.search(["test query"])
assert isinstance(results, list)
assert len(results) == 4
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_network_interruption(mock_get_client: MagicMock) -> None:
"""Test KnowledgeStorage handles network interruptions during operations."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = [
ConnectionError("Network interruption"),
[{"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}],
]
storage = KnowledgeStorage(collection_name="network_test")
first_attempt = storage.search(["test query"])
assert first_attempt == []
mock_client.search.side_effect = None
mock_client.search.return_value = [
{"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}
]
second_attempt = storage.search(["test query"])
assert len(second_attempt) == 1
assert second_attempt[0]["content"] == "recovered result"
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_collection_creation_failure(mock_get_client: MagicMock) -> None:
"""Test RAGStorage handles collection creation failures."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.get_or_create_collection.side_effect = Exception(
"Failed to create collection"
)
storage = RAGStorage("user_memory", crew=None)
storage.save("test data", {"metadata": "test"})
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_embedding_dimension_mismatch_detailed(
mock_get_client: MagicMock,
) -> None:
"""Test detailed handling of embedding dimension mismatch errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.get_or_create_collection.return_value = None
mock_client.add_documents.side_effect = Exception(
"Embedding dimension mismatch: expected 384, got 1536"
)
storage = KnowledgeStorage(collection_name="dimension_detailed_test")
with pytest.raises(ValueError) as exc_info:
storage.save(["test document"])
assert "Embedding dimension mismatch" in str(exc_info.value)
assert "Make sure you're using the same embedding model" in str(exc_info.value)
assert "crewai reset-memories -a" in str(exc_info.value)
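
A sketch of the failure semantics this suite locks in: search errors degrade to an empty result list, while a dimension-mismatch during save is surfaced as a `ValueError` with recovery guidance. Same patch point and mock style as the tests above; the collection name is arbitrary:

```python
from unittest.mock import MagicMock, patch

from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage

with patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") as get_client:
    client = MagicMock()
    get_client.return_value = client

    storage = KnowledgeStorage(collection_name="demo")

    # Connection/search failures are swallowed and reported as "no results".
    client.search.side_effect = ConnectionError("Unable to connect to ChromaDB")
    assert storage.search(["anything"]) == []

    # A dimension mismatch during save is re-raised as a ValueError with guidance.
    client.add_documents.side_effect = Exception("dimension mismatch detected")
    try:
        storage.save(["a document"])
    except ValueError as exc:
        print(exc)  # mentions the embedding dimension mismatch and how to recover
```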

View File

@@ -1,8 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from mem0.client.main import MemoryClient
from mem0.memory.main import Memory
from mem0 import Memory, MemoryClient
from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -13,11 +12,71 @@ class MockCrew:
self.agents = [MagicMock(role="Test Agent")]
# Test data constants
SYSTEM_CONTENT = (
"You are Friendly chatbot assistant. You are a kind and "
"knowledgeable chatbot assistant. You excel at understanding user needs, "
"providing helpful responses, and maintaining engaging conversations. "
"You remember previous interactions to provide a personalized experience.\n"
"Your personal goal is: Engage in useful and interesting conversations "
"with users while remembering context.\n"
"To give my best complete final answer to the task respond using the exact "
"following format:\n\n"
"Thought: I now can give a great answer\n"
"Final Answer: Your final answer must be the great and the most complete "
"as possible, it must be outcome described.\n\n"
"I MUST use these formats, my job depends on it!"
)
USER_CONTENT = (
"\nCurrent Task: Respond to user conversation. User message: "
"What do you know about me?\n\n"
"This is the expected criteria for your final answer: Contextually "
"appropriate, helpful, and friendly response.\n"
"you MUST return the actual complete content as the final answer, "
"not a summary.\n\n"
"# Useful context: \nExternal memories:\n"
"- User is from India\n"
"- User is interested in the solar system\n"
"- User name is Vidit Ostwal\n"
"- User is interested in French cuisine\n\n"
"Begin! This is VERY important to you, use the tools available and give "
"your best Final Answer, your job depends on it!\n\n"
"Thought:"
)
ASSISTANT_CONTENT = (
"I now can give a great answer \n"
"Final Answer: Hi Vidit! From our previous conversations, I know you're "
"from India and have a great interest in the solar system. It's fascinating "
"to explore the wonders of space, isn't it? Also, I remember you have a "
"passion for French cuisine, which has so many delightful dishes to explore. "
"If there's anything specific you'd like to discuss or learn about—whether "
"it's about the solar system or some great French recipes—feel free to let "
"me know! I'm here to help."
)
TEST_DESCRIPTION = (
"Respond to user conversation. User message: What do you know about me?"
)
# Extracted content (after processing by _get_user_message and _get_assistant_message)
EXTRACTED_USER_CONTENT = "What do you know about me?"
EXTRACTED_ASSISTANT_CONTENT = (
"Hi Vidit! From our previous conversations, I know you're "
"from India and have a great interest in the solar system. It's fascinating "
"to explore the wonders of space, isn't it? Also, I remember you have a "
"passion for French cuisine, which has so many delightful dishes to explore. "
"If there's anything specific you'd like to discuss or learn about—whether "
"it's about the solar system or some great French recipes—feel free to let "
"me know! I'm here to help."
)
@pytest.fixture
def mock_mem0_memory():
"""Fixture to create a mock Memory instance"""
mock_memory = MagicMock(spec=Memory)
return mock_memory
return MagicMock(spec=Memory)
@pytest.fixture
@@ -25,7 +84,9 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
# Patch the Memory class to return our mock
with patch("mem0.memory.main.Memory.from_config", return_value=mock_mem0_memory) as mock_from_config:
with patch(
"mem0.Memory.from_config", return_value=mock_mem0_memory
) as mock_from_config:
config = {
"vector_store": {
"provider": "mock_vector_store",
@@ -56,7 +117,14 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
# Parameters like run_id, includes, and excludes doesn't matter in Memory OSS
crew = MockCrew()
embedder_config={"user_id": "test_user", "local_mem0_config": config, "run_id": "my_run_id", "includes": "include1","excludes": "exclude1", "infer" : True}
embedder_config = {
"user_id": "test_user",
"local_mem0_config": config,
"run_id": "my_run_id",
"includes": "include1",
"excludes": "exclude1",
"infer": True,
}
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=embedder_config)
return mem0_storage, mock_from_config, config
@@ -73,8 +141,7 @@ def test_mem0_storage_initialization(mem0_storage_with_mocked_config, mock_mem0_
@pytest.fixture
def mock_mem0_memory_client():
"""Fixture to create a mock MemoryClient instance"""
mock_memory = MagicMock(spec=MemoryClient)
return mock_memory
return MagicMock(spec=MemoryClient)
@pytest.fixture
@@ -85,34 +152,35 @@ def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_clie
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
crew = MockCrew()
embedder_config={
"user_id": "test_user",
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"run_id": "my_run_id",
"includes": "include1",
"excludes": "exclude1",
"infer": True
}
embedder_config = {
"user_id": "test_user",
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"run_id": "my_run_id",
"includes": "include1",
"excludes": "exclude1",
"infer": True,
}
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=embedder_config)
return mem0_storage
return Mem0Storage(type="short_term", crew=crew, config=embedder_config)
@pytest.fixture
def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client, mock_mem0_memory):
def mem0_storage_with_memory_client_using_explictly_config(
mock_mem0_memory_client, mock_mem0_memory
):
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
# We need to patch both MemoryClient and Memory to prevent actual initialization
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client), \
patch.object(Memory, "__new__", return_value=mock_mem0_memory):
with (
patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client),
patch.object(Memory, "__new__", return_value=mock_mem0_memory),
):
crew = MockCrew()
new_config = {"provider": "mem0", "config": {"api_key": "new-api-key"}}
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=new_config)
return mem0_storage
return Mem0Storage(type="short_term", crew=crew, config=new_config)
def test_mem0_storage_with_memory_client_initialization(
@@ -142,18 +210,23 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
mock_mem0_memory_client.update_project = MagicMock()
new_categories = [
{"lifestyle_management_concerns": "Tracks daily routines, habits, hobbies and interests including cooking, time management and work-life balance"},
{
"lifestyle_management_concerns": (
"Tracks daily routines, habits, hobbies and interests "
"including cooking, time management and work-life balance"
)
},
]
crew = MockCrew()
config={
"user_id": "test_user",
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"custom_categories": new_categories
}
config = {
"user_id": "test_user",
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"custom_categories": new_categories,
}
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
_ = Mem0Storage(type="short_term", crew=crew, config=config)
@@ -163,8 +236,6 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
)
def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test save method for different memory types"""
mem0_storage, _, _ = mem0_storage_with_mocked_config
@@ -172,68 +243,134 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {"key": "value"}
test_metadata = {
"description": TEST_DESCRIPTION,
"messages": [
{"role": "system", "content": SYSTEM_CONTENT},
{"role": "user", "content": USER_CONTENT},
{"role": "assistant", "content": ASSISTANT_CONTENT},
],
"agent": "Friendly chatbot assistant",
}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{"role": "assistant" , "content": test_value}],
[
{"role": "user", "content": EXTRACTED_USER_CONTENT},
{
"role": "assistant",
"content": EXTRACTED_ASSISTANT_CONTENT,
},
],
infer=True,
metadata={"type": "short_term", "key": "value"},
metadata={
"type": "short_term",
"description": TEST_DESCRIPTION,
"agent": "Friendly chatbot assistant",
},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent'
agent_id="Test_Agent",
)
def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config):
mem0_storage, _, _ = mem0_storage_with_mocked_config
mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")]
mem0_storage.crew.agents = [
MagicMock(role="Test Agent"),
MagicMock(role="Test Agent 2"),
MagicMock(role="Test Agent 3"),
]
mem0_storage.memory.add = MagicMock()
test_value = "This is a test memory"
test_metadata = {"key": "value"}
test_metadata = {
"description": TEST_DESCRIPTION,
"messages": [
{"role": "system", "content": SYSTEM_CONTENT},
{"role": "user", "content": USER_CONTENT},
{"role": "assistant", "content": ASSISTANT_CONTENT},
],
"agent": "Friendly chatbot assistant",
}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{"role": "assistant" , "content": test_value}],
[
{"role": "user", "content": EXTRACTED_USER_CONTENT},
{
"role": "assistant",
"content": EXTRACTED_ASSISTANT_CONTENT,
},
],
infer=True,
metadata={"type": "short_term", "key": "value"},
metadata={
"type": "short_term",
"description": TEST_DESCRIPTION,
"agent": "Friendly chatbot assistant",
},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent_Test_Agent_2_Test_Agent_3'
agent_id="Test_Agent_Test_Agent_2_Test_Agent_3",
)
def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
def test_save_method_with_memory_client(
mem0_storage_with_memory_client_using_config_from_crew,
):
"""Test save method for different memory types"""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mem0_storage.memory.add = MagicMock()
# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {"key": "value"}
test_metadata = {
"description": TEST_DESCRIPTION,
"messages": [
{"role": "system", "content": SYSTEM_CONTENT},
{"role": "user", "content": USER_CONTENT},
{"role": "assistant", "content": ASSISTANT_CONTENT},
],
"agent": "Friendly chatbot assistant",
}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{'role': 'assistant' , 'content': test_value}],
[
{"role": "user", "content": EXTRACTED_USER_CONTENT},
{
"role": "assistant",
"content": EXTRACTED_ASSISTANT_CONTENT,
},
],
infer=True,
metadata={"type": "short_term", "key": "value"},
metadata={
"type": "short_term",
"description": TEST_DESCRIPTION,
"agent": "Friendly chatbot assistant",
},
version="v2",
run_id="my_run_id",
includes="include1",
excludes="exclude1",
output_format='v1.1',
user_id='test_user',
agent_id='Test_Agent'
output_format="v1.1",
user_id="test_user",
agent_id="Test_Agent",
)
def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test search method for different memory types"""
mem0_storage, _, _ = mem0_storage_with_mocked_config
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
@@ -242,18 +379,25 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
query="test query",
limit=5,
user_id="test_user",
filters={'AND': [{'run_id': 'my_run_id'}]},
threshold=0.5
filters={"AND": [{"run_id": "my_run_id"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["context"] == "Result 1"
assert results[0]["content"] == "Result 1"
def test_search_method_with_memory_client(
mem0_storage_with_memory_client_using_config_from_crew,
):
"""Test search method for different memory types"""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
@@ -263,15 +407,15 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_
limit=5,
metadata={"type": "short_term"},
user_id="test_user",
version="v2",
run_id="my_run_id",
output_format="v1.1",
filters={"AND": [{"run_id": "my_run_id"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["context"] == "Result 1"
assert results[0]["content"] == "Result 1"
def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
@@ -279,14 +423,12 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
crew = MockCrew()
config = {"user_id": "test_user", "api_key": "ABCDEFGH"}
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=config)
assert mem0_storage.infer is True
def test_save_memory_using_agent_entity(mock_mem0_memory_client):
config = {
"agent_id": "agent-123",
@@ -297,19 +439,25 @@ def test_save_memory_using_agent_entity(mock_mem0_memory_client):
mem0_storage = Mem0Storage(type="external", config=config)
mem0_storage.save("test memory", {"key": "value"})
mem0_storage.memory.add.assert_called_once_with(
[{"role": "assistant", "content": "test memory"}],
infer=True,
metadata={"type": "external", "key": "value"},
agent_id="agent-123",
)
def test_search_method_with_agent_entity():
config = {
"agent_id": "agent-123",
}
mock_memory = MagicMock(spec=Memory)
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
with patch.object(Memory, "__new__", return_value=mock_memory):
mem0_storage = Mem0Storage(type="external", config=config)
@@ -318,22 +466,29 @@ def test_search_method_with_agent_entity():
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
filters={"AND": [{"agent_id": "agent-123"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["context"] == "Result 1"
assert results[0]["content"] == "Result 1"
def test_search_method_with_agent_id_and_user_id():
mock_memory = MagicMock(spec=Memory)
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
mock_results = {
"results": [
{"score": 0.9, "memory": "Result 1"},
{"score": 0.4, "memory": "Result 2"},
]
}
with patch.object(Memory, "__new__", return_value=mock_memory):
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"})
mem0_storage = Mem0Storage(
type="external", config={"agent_id": "agent-123", "user_id": "user-123"}
)
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
@@ -341,10 +496,10 @@ def test_search_method_with_agent_id_and_user_id():
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
user_id="user-123",
filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["context"] == "Result 1"
assert results[0]["content"] == "Result 1"

tests/test_context.py (new file, 216 lines)

@@ -0,0 +1,216 @@
# ruff: noqa: S105
import os
import pytest
from unittest.mock import patch
from crewai.context import (
set_platform_integration_token,
get_platform_integration_token,
platform_context,
_platform_integration_token,
)
class TestPlatformIntegrationToken:
def setup_method(self):
_platform_integration_token.set(None)
def teardown_method(self):
_platform_integration_token.set(None)
def test_set_platform_integration_token(self):
test_token = "test-token-123"
assert get_platform_integration_token() is None
set_platform_integration_token(test_token)
assert get_platform_integration_token() == test_token
def test_get_platform_integration_token_from_context_var(self):
test_token = "context-var-token"
_platform_integration_token.set(test_token)
assert get_platform_integration_token() == test_token
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token-456"})
def test_get_platform_integration_token_from_env_var(self):
assert _platform_integration_token.get() is None
assert get_platform_integration_token() == "env-token-456"
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token"})
def test_context_var_takes_precedence_over_env_var(self):
context_token = "context-token"
set_platform_integration_token(context_token)
assert get_platform_integration_token() == context_token
@patch.dict(os.environ, {}, clear=True)
def test_get_platform_integration_token_returns_none_when_not_set(self):
assert _platform_integration_token.get() is None
assert get_platform_integration_token() is None
def test_platform_context_manager_basic_usage(self):
test_token = "context-manager-token"
assert get_platform_integration_token() is None
with platform_context(test_token):
assert get_platform_integration_token() == test_token
assert get_platform_integration_token() is None
def test_platform_context_manager_nested_contexts(self):
"""Test nested platform_context context managers."""
outer_token = "outer-token"
inner_token = "inner-token"
assert get_platform_integration_token() is None
with platform_context(outer_token):
assert get_platform_integration_token() == outer_token
with platform_context(inner_token):
assert get_platform_integration_token() == inner_token
assert get_platform_integration_token() == outer_token
assert get_platform_integration_token() is None
def test_platform_context_manager_preserves_existing_token(self):
"""Test that platform_context preserves existing token when exiting."""
initial_token = "initial-token"
context_token = "context-token"
set_platform_integration_token(initial_token)
assert get_platform_integration_token() == initial_token
with platform_context(context_token):
assert get_platform_integration_token() == context_token
assert get_platform_integration_token() == initial_token
def test_platform_context_manager_exception_handling(self):
"""Test that platform_context properly resets token even when exception occurs."""
initial_token = "initial-token"
context_token = "context-token"
set_platform_integration_token(initial_token)
with pytest.raises(ValueError):
with platform_context(context_token):
assert get_platform_integration_token() == context_token
raise ValueError("Test exception")
assert get_platform_integration_token() == initial_token
def test_platform_context_manager_with_none_initial_state(self):
"""Test platform_context when initial state is None."""
context_token = "context-token"
assert get_platform_integration_token() is None
with pytest.raises(RuntimeError):
with platform_context(context_token):
assert get_platform_integration_token() == context_token
raise RuntimeError("Test exception")
assert get_platform_integration_token() is None
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-backup"})
def test_platform_context_with_env_fallback(self):
"""Test platform_context interaction with environment variable fallback."""
context_token = "context-token"
assert get_platform_integration_token() == "env-backup"
with platform_context(context_token):
assert get_platform_integration_token() == context_token
assert get_platform_integration_token() == "env-backup"
def test_multiple_sequential_context_managers(self):
"""Test multiple sequential uses of platform_context."""
token1 = "token-1"
token2 = "token-2"
token3 = "token-3"
with platform_context(token1):
assert get_platform_integration_token() == token1
assert get_platform_integration_token() is None
with platform_context(token2):
assert get_platform_integration_token() == token2
assert get_platform_integration_token() is None
with platform_context(token3):
assert get_platform_integration_token() == token3
assert get_platform_integration_token() is None
def test_empty_string_token(self):
empty_token = ""
set_platform_integration_token(empty_token)
assert get_platform_integration_token() == ""
with platform_context(empty_token):
assert get_platform_integration_token() == ""
def test_special_characters_in_token(self):
special_token = "token-with-!@#$%^&*()_+-={}[]|\\:;\"'<>?,./"
set_platform_integration_token(special_token)
assert get_platform_integration_token() == special_token
with platform_context(special_token):
assert get_platform_integration_token() == special_token
def test_very_long_token(self):
long_token = "a" * 10000
set_platform_integration_token(long_token)
assert get_platform_integration_token() == long_token
with platform_context(long_token):
assert get_platform_integration_token() == long_token
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": ""})
def test_empty_env_var(self):
assert _platform_integration_token.get() is None
assert get_platform_integration_token() == ""
@patch('crewai.context.os.getenv')
def test_env_var_access_error_handling(self, mock_getenv):
mock_getenv.side_effect = OSError("Environment access error")
with pytest.raises(OSError):
get_platform_integration_token()
def test_context_var_isolation_between_tests(self):
"""Test that context variable changes don't leak between test methods."""
test_token = "isolation-test-token"
assert get_platform_integration_token() is None
set_platform_integration_token(test_token)
assert get_platform_integration_token() == test_token
def test_context_manager_return_value(self):
"""Test that platform_context can be used in with statement with return value."""
test_token = "return-value-token"
with platform_context(test_token):
assert get_platform_integration_token() == test_token
with platform_context(test_token) as ctx:
assert ctx is None
assert get_platform_integration_token() == test_token
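The tests in this new file spell out the full contract of the token helpers: the context variable takes precedence over the CREWAI_PLATFORM_INTEGRATION_TOKEN environment variable, platform_context restores the previous value even when the body raises, and the context manager yields None. A minimal contextvars-based sketch that satisfies this contract, offered as an illustration only since the real crewai.context module may differ in its details:

import os
from contextlib import contextmanager
from contextvars import ContextVar

_platform_integration_token: ContextVar[str | None] = ContextVar(
    "platform_integration_token", default=None
)


def set_platform_integration_token(token: str) -> None:
    _platform_integration_token.set(token)


def get_platform_integration_token() -> str | None:
    # The context variable wins; fall back to the environment variable.
    token = _platform_integration_token.get()
    if token is not None:
        return token
    return os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")


@contextmanager
def platform_context(token: str):
    # Restore the previous value even if the body raises.
    previous = _platform_integration_token.get()
    _platform_integration_token.set(token)
    try:
        yield  # yields None, matching test_context_manager_return_value
    finally:
        _platform_integration_token.set(previous)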


@@ -1,17 +1,20 @@
import os
from unittest.mock import MagicMock, Mock, patch
import pytest
from crewai import Agent, Crew, Task
from crewai.events.listeners.tracing.first_time_trace_handler import (
FirstTimeTraceHandler,
)
from crewai.events.listeners.tracing.trace_batch_manager import (
TraceBatchManager,
)
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.flow.flow import Flow, start
class TestTraceListenerSetup:
@@ -281,9 +284,9 @@ class TestTraceListenerSetup:
):
trace_handlers.append(handler)
assert len(trace_handlers) == 0, (
f"Found {len(trace_handlers)} trace handlers when tracing should be disabled"
)
def test_trace_listener_setup_correctly_for_crew(self):
"""Test that trace listener is set up correctly when enabled"""
@@ -403,3 +406,254 @@ class TestTraceListenerSetup:
from crewai.events.event_bus import crewai_event_bus
crewai_event_bus._handlers.clear()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_first_time_user_trace_collection_with_timeout(self, mock_plus_api_calls):
"""Test first-time user trace collection logic with timeout behavior"""
with (
patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
patch(
"crewai.events.listeners.tracing.utils._is_test_environment",
return_value=False,
),
patch(
"crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.utils.is_first_execution",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
return_value=False,
) as mock_prompt,
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
) as mock_mark_completed,
):
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm="gpt-4o-mini",
)
task = Task(
description="Say hello to the world",
expected_output="hello world",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task], verbose=True)
from crewai.events.event_bus import crewai_event_bus
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
assert trace_listener.first_time_handler.is_first_time is True
assert trace_listener.first_time_handler.collected_events is False
with (
patch.object(
trace_listener.first_time_handler,
"handle_execution_completion",
wraps=trace_listener.first_time_handler.handle_execution_completion,
) as mock_handle_completion,
patch.object(
trace_listener.batch_manager,
"add_event",
wraps=trace_listener.batch_manager.add_event,
) as mock_add_event,
):
result = crew.kickoff()
assert result is not None
assert mock_handle_completion.call_count >= 1
assert mock_add_event.call_count >= 1
assert trace_listener.first_time_handler.collected_events is True
mock_prompt.assert_called_once_with(timeout_seconds=20)
mock_mark_completed.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_first_time_user_trace_collection_user_accepts(self, mock_plus_api_calls):
"""Test first-time user trace collection when user accepts viewing traces"""
with (
patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
patch(
"crewai.events.listeners.tracing.utils._is_test_environment",
return_value=False,
),
patch(
"crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.utils.is_first_execution",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
) as mock_mark_completed,
):
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm="gpt-4o-mini",
)
task = Task(
description="Say hello to the world",
expected_output="hello world",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task], verbose=True)
from crewai.events.event_bus import crewai_event_bus
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
assert trace_listener.first_time_handler.is_first_time is True
with (
patch.object(
trace_listener.first_time_handler,
"_initialize_backend_and_send_events",
wraps=trace_listener.first_time_handler._initialize_backend_and_send_events,
) as mock_init_backend,
patch.object(
trace_listener.first_time_handler, "_display_ephemeral_trace_link"
) as mock_display_link,
patch.object(
trace_listener.first_time_handler,
"handle_execution_completion",
wraps=trace_listener.first_time_handler.handle_execution_completion,
) as mock_handle_completion,
):
trace_listener.batch_manager.ephemeral_trace_url = (
"https://crewai.com/trace/mock-id"
)
crew.kickoff()
assert mock_handle_completion.call_count >= 1, (
"handle_execution_completion should be called"
)
assert trace_listener.first_time_handler.collected_events is True, (
"Events should be marked as collected"
)
mock_init_backend.assert_called_once()
mock_display_link.assert_called_once()
mock_mark_completed.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_first_time_user_trace_consolidation_logic(self, mock_plus_api_calls):
"""Test the consolidation logic for first-time users vs regular tracing"""
with (
patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
patch(
"crewai.events.listeners.tracing.utils._is_test_environment",
return_value=False,
),
patch(
"crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.utils.is_first_execution",
return_value=True,
),
):
from crewai.events.event_bus import crewai_event_bus
crewai_event_bus._handlers.clear()
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
assert trace_listener.first_time_handler.is_first_time is True
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm="gpt-4o-mini",
)
task = Task(
description="Test task", expected_output="test output", agent=agent
)
crew = Crew(agents=[agent], tasks=[task])
with patch.object(TraceBatchManager, "initialize_batch") as mock_initialize:
result = crew.kickoff()
assert mock_initialize.call_count >= 1
assert mock_initialize.call_args_list[0][1]["use_ephemeral"] is True
assert result is not None
def test_first_time_handler_timeout_behavior(self):
"""Test the timeout behavior of the first-time trace prompt"""
with (
patch(
"crewai.events.listeners.tracing.utils._is_test_environment",
return_value=False,
),
patch("threading.Thread") as mock_thread,
):
from crewai.events.listeners.tracing.utils import (
prompt_user_for_trace_viewing,
)
mock_thread_instance = Mock()
mock_thread_instance.is_alive.return_value = True
mock_thread.return_value = mock_thread_instance
result = prompt_user_for_trace_viewing(timeout_seconds=5)
assert result is False
mock_thread.assert_called_once()
call_args = mock_thread.call_args
assert call_args[1]["daemon"] is True
mock_thread_instance.start.assert_called_once()
mock_thread_instance.join.assert_called_once_with(timeout=5)
mock_thread_instance.is_alive.assert_called_once()
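The assertions above fix the prompt's timeout mechanism: a daemon thread reads user input, join(timeout=...) bounds the wait, and a thread that is still alive afterwards means the prompt timed out and the default answer is False. A sketch of that pattern, with the prompt wording and answer handling assumed rather than taken from the crewai utils module:

import threading


def prompt_with_timeout(timeout_seconds: int = 20) -> bool:
    answer: list[str] = []

    def _read_input() -> None:
        try:
            answer.append(input("View execution traces? [y/N]: "))
        except EOFError:
            pass

    thread = threading.Thread(target=_read_input, daemon=True)
    thread.start()
    thread.join(timeout=timeout_seconds)
    if thread.is_alive():
        # No input arrived before the timeout; default to not viewing traces.
        return False
    return bool(answer) and answer[0].strip().lower() in ("y", "yes")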
def test_first_time_handler_graceful_error_handling(self):
"""Test graceful error handling in first-time trace logic"""
with (
patch(
"crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
return_value=True,
),
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
side_effect=Exception("Prompt failed"),
),
patch(
"crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
) as mock_mark_completed,
):
handler = FirstTimeTraceHandler()
handler.is_first_time = True
handler.collected_events = True
handler.handle_execution_completion()
mock_mark_completed.assert_called_once()
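This last tracing test encodes an invariant rather than an implementation: even when the trace-viewing prompt raises, the first execution must still be marked as completed. A self-contained sketch of that control flow, with the collaborators passed in as callables because the handler's internals are not shown in this diff:

from collections.abc import Callable


def finish_first_time_trace(
    prompt: Callable[[], bool],
    send_and_display: Callable[[], None],
    mark_completed: Callable[[], None],
) -> None:
    try:
        if prompt():
            send_and_display()
    except Exception:
        pass  # tracing is best effort; a failed prompt must not break the run
    finally:
        mark_completed()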


@@ -1,123 +0,0 @@
import multiprocessing
import tempfile
import unittest
from chromadb.config import Settings
from unittest.mock import patch, MagicMock
from crewai.utilities.chromadb import (
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
is_ipv4_pattern,
sanitize_collection_name,
create_persistent_client,
)
def persistent_client_worker(path, queue):
try:
create_persistent_client(path=path)
queue.put(None)
except Exception as e:
queue.put(e)
class TestChromadbUtils(unittest.TestCase):
def test_sanitize_collection_name_long_name(self):
"""Test sanitizing a very long collection name."""
long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
sanitized = sanitize_collection_name(long_name)
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_sanitize_collection_name_special_chars(self):
"""Test sanitizing a name with special characters."""
special_chars = "Agent@123!#$%^&*()"
sanitized = sanitize_collection_name(special_chars)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_sanitize_collection_name_short_name(self):
"""Test sanitizing a very short name."""
short_name = "A"
sanitized = sanitize_collection_name(short_name)
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_sanitize_collection_name_bad_ends(self):
"""Test sanitizing a name with non-alphanumeric start/end."""
bad_ends = "_Agent_"
sanitized = sanitize_collection_name(bad_ends)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_sanitize_collection_name_none(self):
"""Test sanitizing a None value."""
sanitized = sanitize_collection_name(None)
self.assertEqual(sanitized, "default_collection")
def test_sanitize_collection_name_ipv4_pattern(self):
"""Test sanitizing an IPv4 address."""
ipv4 = "192.168.1.1"
sanitized = sanitize_collection_name(ipv4)
self.assertTrue(sanitized.startswith("ip_"))
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))
def test_is_ipv4_pattern(self):
"""Test IPv4 pattern detection."""
self.assertTrue(is_ipv4_pattern("192.168.1.1"))
self.assertFalse(is_ipv4_pattern("not.an.ip.address"))
def test_sanitize_collection_name_properties(self):
"""Test that sanitized collection names always meet ChromaDB requirements."""
test_cases = [
"A" * 100, # Very long name
"_start_with_underscore",
"end_with_underscore_",
"contains@special#characters",
"192.168.1.1", # IPv4 address
"a" * 2, # Too short
]
for test_case in test_cases:
sanitized = sanitize_collection_name(test_case)
self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
self.assertTrue(sanitized[0].isalnum())
self.assertTrue(sanitized[-1].isalnum())
def test_create_persistent_client_passes_args(self):
with patch(
"crewai.utilities.chromadb.PersistentClient"
) as mock_persistent_client, tempfile.TemporaryDirectory() as tmpdir:
mock_instance = MagicMock()
mock_persistent_client.return_value = mock_instance
settings = Settings(allow_reset=True)
client = create_persistent_client(path=tmpdir, settings=settings)
mock_persistent_client.assert_called_once_with(
path=tmpdir, settings=settings
)
self.assertIs(client, mock_instance)
def test_create_persistent_client_process_safe(self):
with tempfile.TemporaryDirectory() as tmpdir:
queue = multiprocessing.Queue()
processes = [
multiprocessing.Process(
target=persistent_client_worker, args=(tmpdir, queue)
)
for _ in range(5)
]
[p.start() for p in processes]
[p.join() for p in processes]
errors = [queue.get(timeout=5) for _ in processes]
self.assertTrue(all(err is None for err in errors))
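Although this file is deleted in the diff, its property test documents the collection-name rules ChromaDB enforces: length between MIN_COLLECTION_LENGTH and MAX_COLLECTION_LENGTH, alphanumeric first and last characters, only alphanumerics, underscores, and hyphens, and IPv4-looking names prefixed with ip_. A sketch of a sanitizer that satisfies those properties; the constants and regexes here are assumptions, not the removed crewai.utilities.chromadb implementation:

import re

MIN_COLLECTION_LENGTH = 3
MAX_COLLECTION_LENGTH = 63


def looks_like_ipv4(name: str) -> bool:
    return bool(re.fullmatch(r"(\d{1,3}\.){3}\d{1,3}", name))


def sanitize_name(name: str | None) -> str:
    if not name:
        return "default_collection"
    if looks_like_ipv4(name):
        name = f"ip_{name}"
    # Keep only alphanumerics, underscores, and hyphens.
    name = re.sub(r"[^A-Za-z0-9_-]", "_", name)
    # Enforce the length bounds.
    name = name[:MAX_COLLECTION_LENGTH].ljust(MIN_COLLECTION_LENGTH, "x")
    # Force alphanumeric first and last characters.
    if not name[0].isalnum():
        name = "a" + name[1:]
    if not name[-1].isalnum():
        name = name[:-1] + "z"
    return name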


@@ -29,13 +29,15 @@ def mock_knowledge_source():
"""
return StringKnowledgeSource(content=content)
@patch("crewai.rag.config.utils.get_rag_client")
def test_knowledge_included_in_planning(mock_get_client):
"""Test that verifies knowledge sources are properly included in planning."""
# Mock RAG client
mock_client = mock_get_client.return_value
mock_client.get_or_create_collection.return_value = None
mock_client.add_documents.return_value = None
# Create an agent with knowledge
agent = Agent(
role="AI Researcher",
@@ -45,14 +47,14 @@ def test_knowledge_included_in_planning(mock_chroma):
StringKnowledgeSource(
content="AI systems require careful training and validation."
)
],
)
# Create a task for the agent
task = Task(
description="Explain the basics of AI systems",
expected_output="A clear explanation of AI fundamentals",
agent=agent,
)
# Create a crew planner
@@ -62,23 +64,29 @@ def test_knowledge_included_in_planning(mock_chroma):
task_summary = planner._create_tasks_summary()
# Verify that knowledge is included in planning when present
assert "AI systems require careful training" in task_summary, \
assert "AI systems require careful training" in task_summary, (
"Knowledge content should be present in task summary when knowledge exists"
assert '"agent_knowledge"' in task_summary, \
)
assert '"agent_knowledge"' in task_summary, (
"agent_knowledge field should be present in task summary when knowledge exists"
)
# Verify that knowledge is properly formatted
assert isinstance(task.agent.knowledge_sources, list), \
assert isinstance(task.agent.knowledge_sources, list), (
"Knowledge sources should be stored in a list"
assert len(task.agent.knowledge_sources) > 0, \
)
assert len(task.agent.knowledge_sources) > 0, (
"At least one knowledge source should be present"
assert task.agent.knowledge_sources[0].content in task_summary, \
)
assert task.agent.knowledge_sources[0].content in task_summary, (
"Knowledge source content should be included in task summary"
)
# Verify that other expected components are still present
assert task.description in task_summary, \
assert task.description in task_summary, (
"Task description should be present in task summary"
assert task.expected_output in task_summary, \
)
assert task.expected_output in task_summary, (
"Expected output should be present in task summary"
assert agent.role in task_summary, \
"Agent role should be present in task summary"
)
assert agent.role in task_summary, "Agent role should be present in task summary"