Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-03 21:28:29 +00:00

Compare commits: 17 commits (devin/1757… … joaomdmour…)
| Author | SHA1 | Date |
|---|---|---|
|  | c47ff15bf6 |  |
|  | 270e0b6edd |  |
|  | a0cbb5cfdb |  |
|  | 2f682e1564 |  |
|  | d4aa676195 |  |
|  | 578fa8c2e4 |  |
|  | 6f5af2b27c |  |
|  | 8ee3cf4874 |  |
|  | f2d3fd0c0f |  |
|  | f28e78c5ba |  |
|  | 81bd81e5f5 |  |
|  | 1b00cc71ef |  |
|  | 45d0c9912c |  |
|  | 1f1ab14b07 |  |
|  | 1a70f1698e |  |
|  | 8883fb656b |  |
|  | 79d65e55a1 |  |
.github/workflows/codeql.yml (vendored, new file, 102 lines)

@@ -0,0 +1,102 @@
```yaml
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL Advanced"

on:
  push:
    branches: [ "main" ]
    paths-ignore:
      - "src/crewai/cli/templates/**"
  pull_request:
    branches: [ "main" ]
    paths-ignore:
      - "src/crewai/cli/templates/**"

jobs:
  analyze:
    name: Analyze (${{ matrix.language }})
    # Runner size impacts CodeQL analysis time. To learn more, please see:
    #   - https://gh.io/recommended-hardware-resources-for-running-codeql
    #   - https://gh.io/supported-runners-and-hardware-resources
    #   - https://gh.io/using-larger-runners (GitHub.com only)
    # Consider using larger runners or machines with greater resources for possible analysis time improvements.
    runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
    permissions:
      # required for all workflows
      security-events: write

      # required to fetch internal or private CodeQL packs
      packages: read

      # only required for workflows in private repositories
      actions: read
      contents: read

    strategy:
      fail-fast: false
      matrix:
        include:
          - language: actions
            build-mode: none
          - language: python
            build-mode: none
        # CodeQL supports the following values keywords for 'language': 'actions', 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'rust', 'swift'
        # Use `c-cpp` to analyze code written in C, C++ or both
        # Use 'java-kotlin' to analyze code written in Java, Kotlin or both
        # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
        # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
        # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
        # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
        # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
    steps:
    - name: Checkout repository
      uses: actions/checkout@v4

    # Add any setup steps before running the `github/codeql-action/init` action.
    # This includes steps like installing compilers or runtimes (`actions/setup-node`
    # or others). This is typically only required for manual builds.
    # - name: Setup runtime (example)
    #   uses: actions/setup-example@v1

    # Initializes the CodeQL tools for scanning.
    - name: Initialize CodeQL
      uses: github/codeql-action/init@v3
      with:
        languages: ${{ matrix.language }}
        build-mode: ${{ matrix.build-mode }}
        # If you wish to specify custom queries, you can do so here or in a config file.
        # By default, queries listed here will override any specified in a config file.
        # Prefix the list here with "+" to use these queries and those in the config file.

        # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
        # queries: security-extended,security-and-quality

    # If the analyze step fails for one of the languages you are analyzing with
    # "We were unable to automatically build your code", modify the matrix above
    # to set the build mode to "manual" for that language. Then modify this step
    # to build your code.
    # ℹ️ Command-line programs to run using the OS shell.
    # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
    - if: matrix.build-mode == 'manual'
      shell: bash
      run: |
        echo 'If you are using a "manual" build mode for one or more of the' \
          'languages you are analyzing, replace this with the commands to build' \
          'your code, for example:'
        echo '  make bootstrap'
        echo '  make release'
        exit 1

    - name: Perform CodeQL Analysis
      uses: github/codeql-action/analyze@v3
      with:
        category: "/language:${{matrix.language}}"
```
.github/workflows/tests.yml (vendored, 29 changed lines)

```diff
@@ -22,6 +22,8 @@ jobs:
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
+        with:
+          fetch-depth: 0  # Fetch all history for proper diff

       - name: Restore global uv cache
         id: cache-restore
@@ -45,14 +47,41 @@ jobs:
       - name: Install the project
         run: uv sync --all-groups --all-extras

+      - name: Restore test durations
+        uses: actions/cache/restore@v4
+        with:
+          path: .test_durations_py*
+          key: test-durations-py${{ matrix.python-version }}
+
       - name: Run tests (group ${{ matrix.group }} of 8)
         run: |
+          PYTHON_VERSION_SAFE=$(echo "${{ matrix.python-version }}" | tr '.' '_')
+          DURATION_FILE=".test_durations_py${PYTHON_VERSION_SAFE}"
+
+          # Temporarily always skip cached durations to fix test splitting
+          # When durations don't match, pytest-split runs duplicate tests instead of splitting
+          echo "Using even test splitting (duration cache disabled until fix merged)"
+          DURATIONS_ARG=""
+
+          # Original logic (disabled temporarily):
+          # if [ ! -f "$DURATION_FILE" ]; then
+          #   echo "No cached durations found, tests will be split evenly"
+          #   DURATIONS_ARG=""
+          # elif git diff origin/${{ github.base_ref }}...HEAD --name-only 2>/dev/null | grep -q "^tests/.*\.py$"; then
+          #   echo "Test files have changed, skipping cached durations to avoid mismatches"
+          #   DURATIONS_ARG=""
+          # else
+          #   echo "No test changes detected, using cached test durations for optimal splitting"
+          #   DURATIONS_ARG="--durations-path=${DURATION_FILE}"
+          # fi
+
           uv run pytest \
             --block-network \
             --timeout=30 \
             -vv \
+            --splits 8 \
+            --group ${{ matrix.group }} \
+            $DURATIONS_ARG \
             --durations=10 \
             -n auto \
             --maxfail=3
```
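For context on what this duration cache holds: pytest-split's `--store-durations` records per-test runtimes that `--splits`/`--group` later use to balance the eight groups. A minimal sketch for inspecting such a file follows; it assumes pytest-split's format of a JSON object mapping test node IDs to seconds (verify against your pytest-split version), with the file name taken from the workflow's naming scheme above.

```python
# Minimal sketch, assuming pytest-split stores durations as a JSON object
# of {test node id: seconds}; file name follows the workflow's scheme.
import json
from pathlib import Path


def summarize_durations(path: str = ".test_durations_py3_12") -> None:
    """Print the five slowest tests recorded by `pytest --store-durations`."""
    durations: dict[str, float] = json.loads(Path(path).read_text())
    slowest = sorted(durations.items(), key=lambda kv: kv[1], reverse=True)[:5]
    for node_id, seconds in slowest:
        print(f"{seconds:8.2f}s  {node_id}")


if __name__ == "__main__":
    summarize_durations()
```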
.github/workflows/update-test-durations.yml (vendored, new file, 71 lines)

@@ -0,0 +1,71 @@
```yaml
name: Update Test Durations

on:
  push:
    branches:
      - main
    paths:
      - 'tests/**/*.py'
  workflow_dispatch:

permissions:
  contents: read

jobs:
  update-durations:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.10', '3.11', '3.12', '3.13']
    env:
      OPENAI_API_KEY: fake-api-key
      PYTHONUNBUFFERED: 1

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Restore global uv cache
        id: cache-restore
        uses: actions/cache/restore@v4
        with:
          path: |
            ~/.cache/uv
            ~/.local/share/uv
            .venv
          key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
          restore-keys: |
            uv-main-py${{ matrix.python-version }}-

      - name: Install uv
        uses: astral-sh/setup-uv@v6
        with:
          version: "0.8.4"
          python-version: ${{ matrix.python-version }}
          enable-cache: false

      - name: Install the project
        run: uv sync --all-groups --all-extras

      - name: Run all tests and store durations
        run: |
          PYTHON_VERSION_SAFE=$(echo "${{ matrix.python-version }}" | tr '.' '_')
          uv run pytest --store-durations --durations-path=.test_durations_py${PYTHON_VERSION_SAFE} -n auto
        continue-on-error: true

      - name: Save durations to cache
        if: always()
        uses: actions/cache/save@v4
        with:
          path: .test_durations_py*
          key: test-durations-py${{ matrix.python-version }}

      - name: Save uv caches
        if: steps.cache-restore.outputs.cache-hit != 'true'
        uses: actions/cache/save@v4
        with:
          path: |
            ~/.cache/uv
            ~/.local/share/uv
            .venv
          key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
```
pyproject.toml (file name inferred from the ruff/mypy/bandit configuration; the capture omits this header)

```diff
@@ -131,13 +131,14 @@ select = [
     "I001", # sort imports
     "I002", # remove unused imports
 ]
-ignore = ["E501"] # ignore line too long
+ignore = ["E501"] # ignore line too long globally

 [tool.ruff.lint.per-file-ignores]
-"tests/**/*.py" = ["S101"] # Allow assert statements in tests
+"tests/**/*.py" = ["S101", "RET504"] # Allow assert statements and unnecessary assignments before return in tests

 [tool.mypy]
-exclude = ["src/crewai/cli/templates", "tests"]
+exclude = ["src/crewai/cli/templates", "tests/"]


 [tool.bandit]
 exclude_dirs = ["src/crewai/cli/templates"]
```
@@ -1,47 +1,56 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
"""LangGraph agent adapter for CrewAI integration.
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
This module contains the LangGraphAgentAdapter class that integrates LangGraph ReAct agents
|
||||
with CrewAI's agent system. Provides memory persistence, tool integration, and structured
|
||||
output functionality.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, PrivateAttr
|
||||
|
||||
from crewai.agents.agent_adapters.base_agent_adapter import BaseAgentAdapter
|
||||
from crewai.agents.agent_adapters.langgraph.langgraph_tool_adapter import (
|
||||
LangGraphToolAdapter,
|
||||
)
|
||||
from crewai.agents.agent_adapters.langgraph.protocols import (
|
||||
LangGraphCheckPointMemoryModule,
|
||||
LangGraphPrebuiltModule,
|
||||
)
|
||||
from crewai.agents.agent_adapters.langgraph.structured_output_converter import (
|
||||
LangGraphConverterAdapter,
|
||||
)
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
|
||||
try:
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
LANGGRAPH_AVAILABLE = True
|
||||
except ImportError:
|
||||
LANGGRAPH_AVAILABLE = False
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.utilities.import_utils import require
|
||||
|
||||
|
||||
class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
"""Adapter for LangGraph agents to work with CrewAI."""
|
||||
"""Adapter for LangGraph agents to work with CrewAI.
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
This adapter integrates LangGraph's ReAct agents with CrewAI's agent system,
|
||||
providing memory persistence, tool integration, and structured output support.
|
||||
"""
|
||||
|
||||
_logger: Logger = PrivateAttr(default_factory=lambda: Logger())
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
_logger: Logger = PrivateAttr(default_factory=Logger)
|
||||
_tool_adapter: LangGraphToolAdapter = PrivateAttr()
|
||||
_graph: Any = PrivateAttr(default=None)
|
||||
_memory: Any = PrivateAttr(default=None)
|
||||
_max_iterations: int = PrivateAttr(default=10)
|
||||
function_calling_llm: Any = Field(default=None)
|
||||
step_callback: Any = Field(default=None)
|
||||
step_callback: Callable[..., Any] | None = Field(default=None)
|
||||
|
||||
model: str = Field(default="gpt-4o")
|
||||
verbose: bool = Field(default=False)
|
||||
@@ -51,17 +60,24 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
role: str,
|
||||
goal: str,
|
||||
backstory: str,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
llm: Any = None,
|
||||
max_iterations: int = 10,
|
||||
agent_config: Optional[Dict[str, Any]] = None,
|
||||
agent_config: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the LangGraph agent adapter."""
|
||||
if not LANGGRAPH_AVAILABLE:
|
||||
raise ImportError(
|
||||
"LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
|
||||
)
|
||||
) -> None:
|
||||
"""Initialize the LangGraph agent adapter.
|
||||
|
||||
Args:
|
||||
role: The role description for the agent.
|
||||
goal: The primary goal the agent should achieve.
|
||||
backstory: Background information about the agent.
|
||||
tools: Optional list of tools available to the agent.
|
||||
llm: Language model to use, defaults to gpt-4o.
|
||||
max_iterations: Maximum number of iterations for task execution.
|
||||
agent_config: Additional configuration for the LangGraph agent.
|
||||
**kwargs: Additional arguments passed to the base adapter.
|
||||
"""
|
||||
super().__init__(
|
||||
role=role,
|
||||
goal=goal,
|
||||
@@ -72,46 +88,65 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
**kwargs,
|
||||
)
|
||||
self._tool_adapter = LangGraphToolAdapter(tools=tools)
|
||||
self._converter_adapter = LangGraphConverterAdapter(self)
|
||||
self._converter_adapter: LangGraphConverterAdapter = LangGraphConverterAdapter(
|
||||
self
|
||||
)
|
||||
self._max_iterations = max_iterations
|
||||
self._setup_graph()
|
||||
|
||||
def _setup_graph(self) -> None:
|
||||
"""Set up the LangGraph workflow graph."""
|
||||
try:
|
||||
self._memory = MemorySaver()
|
||||
"""Set up the LangGraph workflow graph.
|
||||
|
||||
converted_tools: List[Any] = self._tool_adapter.tools()
|
||||
if self._agent_config:
|
||||
self._graph = create_react_agent(
|
||||
model=self.llm,
|
||||
tools=converted_tools,
|
||||
checkpointer=self._memory,
|
||||
debug=self.verbose,
|
||||
**self._agent_config,
|
||||
)
|
||||
else:
|
||||
self._graph = create_react_agent(
|
||||
model=self.llm,
|
||||
tools=converted_tools or [],
|
||||
checkpointer=self._memory,
|
||||
debug=self.verbose,
|
||||
)
|
||||
Initializes the memory saver and creates a ReAct agent with the configured
|
||||
tools, memory checkpointer, and debug settings.
|
||||
"""
|
||||
|
||||
except ImportError as e:
|
||||
self._logger.log(
|
||||
"error", f"Failed to import LangGraph dependencies: {str(e)}"
|
||||
memory_saver: type[Any] = cast(
|
||||
LangGraphCheckPointMemoryModule,
|
||||
require(
|
||||
"langgraph.checkpoint.memory",
|
||||
purpose="LangGraph core functionality",
|
||||
),
|
||||
).MemorySaver
|
||||
create_react_agent: Callable[..., Any] = cast(
|
||||
LangGraphPrebuiltModule,
|
||||
require(
|
||||
"langgraph.prebuilt",
|
||||
purpose="LangGraph core functionality",
|
||||
),
|
||||
).create_react_agent
|
||||
|
||||
self._memory = memory_saver()
|
||||
|
||||
converted_tools: list[Any] = self._tool_adapter.tools()
|
||||
if self._agent_config:
|
||||
self._graph = create_react_agent(
|
||||
model=self.llm,
|
||||
tools=converted_tools,
|
||||
checkpointer=self._memory,
|
||||
debug=self.verbose,
|
||||
**self._agent_config,
|
||||
)
|
||||
else:
|
||||
self._graph = create_react_agent(
|
||||
model=self.llm,
|
||||
tools=converted_tools or [],
|
||||
checkpointer=self._memory,
|
||||
debug=self.verbose,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error setting up LangGraph agent: {str(e)}")
|
||||
raise
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build a system prompt for the LangGraph agent."""
|
||||
"""Build a system prompt for the LangGraph agent.
|
||||
|
||||
Creates a prompt that includes the agent's role, goal, and backstory,
|
||||
then enhances it through the converter adapter for structured output.
|
||||
|
||||
Returns:
|
||||
The complete system prompt string.
|
||||
"""
|
||||
base_prompt = f"""
|
||||
You are {self.role}.
|
||||
|
||||
|
||||
Your goal is: {self.goal}
|
||||
|
||||
Your backstory: {self.backstory}
|
||||
@@ -123,10 +158,25 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
def execute_task(
|
||||
self,
|
||||
task: Any,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
"""Execute a task using the LangGraph workflow."""
|
||||
"""Execute a task using the LangGraph workflow.
|
||||
|
||||
Configures the agent, processes the task through the LangGraph workflow,
|
||||
and handles event emission for execution tracking.
|
||||
|
||||
Args:
|
||||
task: The task object to execute.
|
||||
context: Optional context information for the task.
|
||||
tools: Optional additional tools for this specific execution.
|
||||
|
||||
Returns:
|
||||
The final answer from the task execution.
|
||||
|
||||
Raises:
|
||||
Exception: If task execution fails.
|
||||
"""
|
||||
self.create_agent_executor(tools)
|
||||
|
||||
self.configure_structured_output(task)
|
||||
@@ -151,9 +201,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
|
||||
session_id = f"task_{id(task)}"
|
||||
|
||||
config = {"configurable": {"thread_id": session_id}}
|
||||
config: dict[str, dict[str, str]] = {
|
||||
"configurable": {"thread_id": session_id}
|
||||
}
|
||||
|
||||
result = self._graph.invoke(
|
||||
result: dict[str, Any] = self._graph.invoke(
|
||||
{
|
||||
"messages": [
|
||||
("system", self._build_system_prompt()),
|
||||
@@ -163,10 +215,10 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
config,
|
||||
)
|
||||
|
||||
messages = result.get("messages", [])
|
||||
last_message = messages[-1] if messages else None
|
||||
messages: list[Any] = result.get("messages", [])
|
||||
last_message: Any = messages[-1] if messages else None
|
||||
|
||||
final_answer = ""
|
||||
final_answer: str = ""
|
||||
if isinstance(last_message, dict):
|
||||
final_answer = last_message.get("content", "")
|
||||
elif hasattr(last_message, "content"):
|
||||
@@ -186,7 +238,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
return final_answer
|
||||
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error executing LangGraph task: {str(e)}")
|
||||
self._logger.log("error", f"Error executing LangGraph task: {e!s}")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
@@ -197,29 +249,67 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
)
|
||||
raise
|
||||
|
||||
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""Configure the LangGraph agent for execution."""
|
||||
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
|
||||
"""Configure the LangGraph agent for execution.
|
||||
|
||||
Args:
|
||||
tools: Optional tools to configure for the agent.
|
||||
"""
|
||||
self.configure_tools(tools)
|
||||
|
||||
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""Configure tools for the LangGraph agent."""
|
||||
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
|
||||
"""Configure tools for the LangGraph agent.
|
||||
|
||||
Merges additional tools with existing ones and updates the graph's
|
||||
available tools through the tool adapter.
|
||||
|
||||
Args:
|
||||
tools: Optional additional tools to configure.
|
||||
"""
|
||||
if tools:
|
||||
all_tools = list(self.tools or []) + list(tools or [])
|
||||
all_tools: list[BaseTool] = list(self.tools or []) + list(tools or [])
|
||||
self._tool_adapter.configure_tools(all_tools)
|
||||
available_tools = self._tool_adapter.tools()
|
||||
available_tools: list[Any] = self._tool_adapter.tools()
|
||||
self._graph.tools = available_tools
|
||||
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
|
||||
"""Implement delegation tools support for LangGraph."""
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
"""Implement delegation tools support for LangGraph.
|
||||
|
||||
Creates delegation tools that allow this agent to delegate tasks to other agents.
|
||||
|
||||
Args:
|
||||
agents: List of agents available for delegation.
|
||||
|
||||
Returns:
|
||||
List of delegation tools.
|
||||
"""
|
||||
agent_tools: AgentTools = AgentTools(agents=agents)
|
||||
return agent_tools.tools()
|
||||
|
||||
@staticmethod
|
||||
def get_output_converter(
|
||||
self, llm: Any, text: str, model: Any, instructions: str
|
||||
) -> Any:
|
||||
"""Convert output format if needed."""
|
||||
llm: Any, text: str, model: Any, instructions: str
|
||||
) -> Converter:
|
||||
"""Convert output format if needed.
|
||||
|
||||
Args:
|
||||
llm: Language model instance.
|
||||
text: Text to convert.
|
||||
model: Model configuration.
|
||||
instructions: Conversion instructions.
|
||||
|
||||
Returns:
|
||||
Converter instance for output transformation.
|
||||
"""
|
||||
return Converter(llm=llm, text=text, model=model, instructions=instructions)
|
||||
|
||||
def configure_structured_output(self, task) -> None:
|
||||
"""Configure the structured output for LangGraph."""
|
||||
def configure_structured_output(self, task: Any) -> None:
|
||||
"""Configure the structured output for LangGraph.
|
||||
|
||||
Uses the converter adapter to set up structured output formatting
|
||||
based on the task requirements.
|
||||
|
||||
Args:
|
||||
task: Task object containing output requirements.
|
||||
"""
|
||||
self._converter_adapter.configure_structured_output(task)
|
||||
|
||||
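The diff above swaps the module-level try/except ImportError guard for calls to `crewai.utilities.import_utils.require`. The capture does not show that helper's body, so the following is only a hypothetical sketch of what such a helper could look like; the error message wording is an assumption.

```python
# Hypothetical sketch of an import helper like `require`; the real
# crewai.utilities.import_utils.require may differ in signature and message.
import importlib
from types import ModuleType


def require(name: str, *, purpose: str) -> ModuleType:
    """Import `name`, raising a purpose-labelled ImportError if it is missing."""
    try:
        return importlib.import_module(name)
    except ImportError as e:
        raise ImportError(
            f"Module {name!r} is required for {purpose} but is not installed."
        ) from e
```

The practical difference from the old pattern: the failure is raised lazily at the call site with a message naming the missing feature, instead of being swallowed into a module-level boolean flag.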
src/crewai/agents/agent_adapters/langgraph/langgraph_tool_adapter.py

```diff
@@ -1,38 +1,72 @@
+"""LangGraph tool adapter for CrewAI tool integration.
+
+This module contains the LangGraphToolAdapter class that converts CrewAI tools
+to LangGraph-compatible format using langchain_core.tools.
+"""
+
 import inspect
-from typing import Any, List, Optional
+from collections.abc import Awaitable
+from typing import Any

 from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter
 from crewai.tools.base_tool import BaseTool


 class LangGraphToolAdapter(BaseToolAdapter):
-    """Adapts CrewAI tools to LangGraph agent tool compatible format"""
+    """Adapts CrewAI tools to LangGraph agent tool compatible format.
+
+    Converts CrewAI BaseTool instances to langchain_core.tools format
+    that can be used by LangGraph agents.
+    """

-    def __init__(self, tools: Optional[List[BaseTool]] = None):
-        self.original_tools = tools or []
-        self.converted_tools = []
+    def __init__(self, tools: list[BaseTool] | None = None) -> None:
+        """Initialize the tool adapter.
+
+        Args:
+            tools: Optional list of CrewAI tools to adapt.
+        """
+        super().__init__()
+        self.original_tools: list[BaseTool] = tools or []
+        self.converted_tools: list[Any] = []

-    def configure_tools(self, tools: List[BaseTool]) -> None:
-        """
-        Configure and convert CrewAI tools to LangGraph-compatible format.
-        LangGraph expects tools in langchain_core.tools format.
-        """
-        from langchain_core.tools import BaseTool, StructuredTool
+    def configure_tools(self, tools: list[BaseTool]) -> None:
+        """Configure and convert CrewAI tools to LangGraph-compatible format.
+
+        LangGraph expects tools in langchain_core.tools format. This method
+        converts CrewAI BaseTool instances to StructuredTool instances.
+
+        Args:
+            tools: List of CrewAI tools to convert.
+        """
+        from langchain_core.tools import BaseTool as LangChainBaseTool
+        from langchain_core.tools import StructuredTool

-        converted_tools = []
+        converted_tools: list[Any] = []
         if self.original_tools:
-            all_tools = tools + self.original_tools
+            all_tools: list[BaseTool] = tools + self.original_tools
         else:
             all_tools = tools
         for tool in all_tools:
-            if isinstance(tool, BaseTool):
+            if isinstance(tool, LangChainBaseTool):
                 converted_tools.append(tool)
                 continue

-            sanitized_name = self.sanitize_tool_name(tool.name)
+            sanitized_name: str = self.sanitize_tool_name(tool.name)

-            async def tool_wrapper(*args, tool=tool, **kwargs):
-                output = None
+            async def tool_wrapper(
+                *args: Any, tool: BaseTool = tool, **kwargs: Any
+            ) -> Any:
+                """Wrapper function to adapt CrewAI tool calls to LangGraph format.
+
+                Args:
+                    *args: Positional arguments for the tool.
+                    tool: The CrewAI tool to wrap.
+                    **kwargs: Keyword arguments for the tool.
+
+                Returns:
+                    The result from the tool execution.
+                """
+                output: Any | Awaitable[Any]
                 if len(args) > 0 and isinstance(args[0], str):
                     output = tool.run(args[0])
                 elif "input" in kwargs:
@@ -41,12 +75,12 @@ class LangGraphToolAdapter(BaseToolAdapter):
                     output = tool.run(**kwargs)

                 if inspect.isawaitable(output):
-                    result = await output
+                    result: Any = await output
                 else:
                     result = output
                 return result

-            converted_tool = StructuredTool(
+            converted_tool: StructuredTool = StructuredTool(
                 name=sanitized_name,
                 description=tool.description,
                 func=tool_wrapper,
@@ -57,5 +91,10 @@ class LangGraphToolAdapter(BaseToolAdapter):

         self.converted_tools = converted_tools

-    def tools(self) -> List[Any]:
+    def tools(self) -> list[Any]:
+        """Get the list of converted tools.
+
+        Returns:
+            List of LangGraph-compatible tools.
+        """
         return self.converted_tools or []
```
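The `tool_wrapper` above leans on `inspect.isawaitable` so one wrapper can serve both sync and async tools. A self-contained sketch of that bridging pattern, with made-up `echo`/`aecho` functions standing in for real tools:

```python
# Standalone illustration of the sync/async bridging used by tool_wrapper
# above; the `echo`/`aecho` functions are invented for the example.
import asyncio
import inspect
from typing import Any, Callable


async def call_tool(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Invoke fn, awaiting the result only if it is awaitable."""
    output = fn(*args, **kwargs)
    if inspect.isawaitable(output):
        return await output
    return output


def echo(text: str) -> str:
    return text


async def aecho(text: str) -> str:
    return text


print(asyncio.run(call_tool(echo, "sync result")))
print(asyncio.run(call_tool(aecho, "async result")))
```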
src/crewai/agents/agent_adapters/langgraph/protocols.py (new file, 55 lines)

@@ -0,0 +1,55 @@
```python
"""Type protocols for LangGraph modules."""

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class LangGraphMemorySaver(Protocol):
    """Protocol for LangGraph MemorySaver.

    Defines the interface for LangGraph's memory persistence mechanism.
    """

    def __init__(self) -> None:
        """Initialize the memory saver."""
        ...


@runtime_checkable
class LangGraphCheckPointMemoryModule(Protocol):
    """Protocol for LangGraph checkpoint memory module.

    Defines the interface for modules containing memory checkpoint functionality.
    """

    MemorySaver: type[LangGraphMemorySaver]


@runtime_checkable
class LangGraphPrebuiltModule(Protocol):
    """Protocol for LangGraph prebuilt module.

    Defines the interface for modules containing prebuilt agent factories.
    """

    def create_react_agent(
        self,
        model: Any,
        tools: list[Any],
        checkpointer: Any,
        debug: bool = False,
        **kwargs: Any,
    ) -> Any:
        """Create a ReAct agent with the given configuration.

        Args:
            model: The language model to use for the agent.
            tools: List of tools available to the agent.
            checkpointer: Memory checkpointer for state persistence.
            debug: Whether to enable debug mode.
            **kwargs: Additional configuration options.

        Returns:
            The configured ReAct agent instance.
        """
        ...
```
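These module protocols exist so the dynamically imported langgraph modules can be given static types at the `cast(...)` call sites in the adapter. A dependency-free sketch of the same cast-a-module-to-a-Protocol idea, using an invented `MathModule` protocol over the standard `math` module:

```python
# Dependency-free illustration of typing a dynamically imported module with
# a Protocol + cast, as protocols.py enables for langgraph. `MathModule` is
# a made-up protocol describing part of the standard `math` module.
import importlib
from typing import Protocol, cast


class MathModule(Protocol):
    pi: float

    def sqrt(self, x: float, /) -> float: ...


math_mod = cast(MathModule, importlib.import_module("math"))
print(math_mod.sqrt(math_mod.pi))  # attribute access is now type checked
```

The cast is unchecked at runtime; its value is that tools like mypy can verify every use of the module against the protocol even though the import is dynamic.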
src/crewai/agents/agent_adapters/langgraph/structured_output_converter.py

```diff
@@ -1,21 +1,45 @@
+"""LangGraph structured output converter for CrewAI task integration.
+
+This module contains the LangGraphConverterAdapter class that handles structured
+output conversion for LangGraph agents, supporting JSON and Pydantic model formats.
+"""
+
 import json
+import re
+from typing import Any, Literal

 from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
 from crewai.utilities.converter import generate_model_description


 class LangGraphConverterAdapter(BaseConverterAdapter):
-    """Adapter for handling structured output conversion in LangGraph agents"""
+    """Adapter for handling structured output conversion in LangGraph agents.
+
+    Converts task output requirements into system prompt modifications and
+    post-processing logic to ensure agents return properly structured outputs.
+    """

-    def __init__(self, agent_adapter):
-        """Initialize the converter adapter with a reference to the agent adapter"""
-        self.agent_adapter = agent_adapter
-        self._output_format = None
-        self._schema = None
-        self._system_prompt_appendix = None
+    def __init__(self, agent_adapter: Any) -> None:
+        """Initialize the converter adapter with a reference to the agent adapter.
+
+        Args:
+            agent_adapter: The LangGraph agent adapter instance.
+        """
+        super().__init__(agent_adapter=agent_adapter)
+        self.agent_adapter: Any = agent_adapter
+        self._output_format: Literal["json", "pydantic"] | None = None
+        self._schema: str | None = None
+        self._system_prompt_appendix: str | None = None

-    def configure_structured_output(self, task) -> None:
-        """Configure the structured output for LangGraph."""
+    def configure_structured_output(self, task: Any) -> None:
+        """Configure the structured output for LangGraph.
+
+        Analyzes the task's output requirements and sets up the necessary
+        formatting and validation logic.
+
+        Args:
+            task: The task object containing output format specifications.
+        """
         if not (task.output_json or task.output_pydantic):
             self._output_format = None
             self._schema = None
@@ -32,7 +56,14 @@ class LangGraphConverterAdapter(BaseConverterAdapter):
         self._system_prompt_appendix = self._generate_system_prompt_appendix()

     def _generate_system_prompt_appendix(self) -> str:
-        """Generate an appendix for the system prompt to enforce structured output"""
+        """Generate an appendix for the system prompt to enforce structured output.
+
+        Creates instructions that are appended to the system prompt to guide
+        the agent in producing properly formatted output.
+
+        Returns:
+            System prompt appendix string, or empty string if no structured output.
+        """
         if not self._output_format or not self._schema:
             return ""

@@ -41,19 +72,36 @@ Important: Your final answer MUST be provided in the following structured format

 {self._schema}

-DO NOT include any markdown code blocks, backticks, or other formatting around your response.
+DO NOT include any markdown code blocks, backticks, or other formatting around your response.
 The output should be raw JSON that exactly matches the specified schema.
 """

     def enhance_system_prompt(self, original_prompt: str) -> str:
-        """Add structured output instructions to the system prompt if needed"""
+        """Add structured output instructions to the system prompt if needed.
+
+        Args:
+            original_prompt: The base system prompt.
+
+        Returns:
+            Enhanced system prompt with structured output instructions.
+        """
         if not self._system_prompt_appendix:
             return original_prompt

         return f"{original_prompt}\n{self._system_prompt_appendix}"

     def post_process_result(self, result: str) -> str:
-        """Post-process the result to ensure it matches the expected format"""
+        """Post-process the result to ensure it matches the expected format.
+
+        Attempts to extract and validate JSON content from agent responses,
+        handling cases where JSON may be wrapped in markdown or other formatting.
+
+        Args:
+            result: The raw result string from the agent.
+
+        Returns:
+            Processed result string, ideally in valid JSON format.
+        """
         if not self._output_format:
             return result

@@ -65,16 +113,16 @@ The output should be raw JSON that exactly matches the specified schema.
                 return result
             except json.JSONDecodeError:
                 # Try to extract JSON from the text
-                import re
-
-                json_match = re.search(r"(\{.*\})", result, re.DOTALL)
+                json_match: re.Match[str] | None = re.search(
+                    r"(\{.*})", result, re.DOTALL
+                )
                 if json_match:
                     try:
-                        extracted = json_match.group(1)
+                        extracted: str = json_match.group(1)
                         # Validate it's proper JSON
                         json.loads(extracted)
                         return extracted
-                    except:
+                    except json.JSONDecodeError:
                         pass

         return result
```
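The post-processing above tries a strict parse first and falls back to pulling the widest `{...}` span out of the text before giving up. A standalone sketch of the same idea, with an invented sample response:

```python
# Standalone demo of the "parse, then extract a {...} span" fallback used in
# post_process_result above; the sample response string is invented.
import json
import re


def extract_json(result: str) -> str:
    try:
        json.loads(result)
        return result  # already valid JSON
    except json.JSONDecodeError:
        match = re.search(r"(\{.*})", result, re.DOTALL)
        if match:
            candidate = match.group(1)
            try:
                json.loads(candidate)  # validate before returning
                return candidate
            except json.JSONDecodeError:
                pass
    return result  # fall back to the raw text


print(extract_json('Here you go: {"answer": 42} hope that helps'))
# -> {"answer": 42}
```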
OpenAI agents adapter, src/crewai/agents/agent_adapters/openai_agents/ (path inferred; the capture omits the file header)

```diff
@@ -1,78 +1,99 @@
-from typing import Any, List, Optional
+"""OpenAI agents adapter for CrewAI integration.

-from pydantic import Field, PrivateAttr
+This module contains the OpenAIAgentAdapter class that integrates OpenAI Assistants
+with CrewAI's agent system, providing tool integration and structured output support.
+"""
+
+from typing import Any, cast
+
+from pydantic import ConfigDict, Field, PrivateAttr
+from typing_extensions import Unpack

 from crewai.agents.agent_adapters.base_agent_adapter import BaseAgentAdapter
+from crewai.agents.agent_adapters.openai_agents.openai_agent_tool_adapter import (
+    OpenAIAgentToolAdapter,
+)
+from crewai.agents.agent_adapters.openai_agents.protocols import (
+    AgentKwargs,
+    OpenAIAgentsModule,
+)
+from crewai.agents.agent_adapters.openai_agents.protocols import (
+    OpenAIAgent as OpenAIAgentProtocol,
+)
 from crewai.agents.agent_adapters.openai_agents.structured_output_converter import (
     OpenAIConverterAdapter,
 )
 from crewai.agents.agent_builder.base_agent import BaseAgent
-from crewai.tools import BaseTool
-from crewai.tools.agent_tools.agent_tools import AgentTools
-from crewai.utilities import Logger
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.agent_events import (
     AgentExecutionCompletedEvent,
     AgentExecutionErrorEvent,
     AgentExecutionStartedEvent,
 )
+from crewai.tools import BaseTool
+from crewai.tools.agent_tools.agent_tools import AgentTools
+from crewai.utilities import Logger
+from crewai.utilities.import_utils import require

-try:
-    from agents import Agent as OpenAIAgent  # type: ignore
-    from agents import Runner, enable_verbose_stdout_logging  # type: ignore
-
-    from .openai_agent_tool_adapter import OpenAIAgentToolAdapter
-
-    OPENAI_AVAILABLE = True
-except ImportError:
-    OPENAI_AVAILABLE = False
+openai_agents_module = cast(
+    OpenAIAgentsModule,
+    require(
+        "agents",
+        purpose="OpenAI agents functionality",
+    ),
+)
+OpenAIAgent = openai_agents_module.Agent
+Runner = openai_agents_module.Runner
+enable_verbose_stdout_logging = openai_agents_module.enable_verbose_stdout_logging


 class OpenAIAgentAdapter(BaseAgentAdapter):
-    """Adapter for OpenAI Assistants"""
+    """Adapter for OpenAI Assistants.
+
+    Integrates OpenAI Assistants API with CrewAI's agent system, providing
+    tool configuration, structured output handling, and task execution.
+    """

-    model_config = {"arbitrary_types_allowed": True}
+    model_config = ConfigDict(arbitrary_types_allowed=True)

-    _openai_agent: "OpenAIAgent" = PrivateAttr()
-    _logger: Logger = PrivateAttr(default_factory=lambda: Logger())
-    _active_thread: Optional[str] = PrivateAttr(default=None)
+    _openai_agent: OpenAIAgentProtocol = PrivateAttr()
+    _logger: Logger = PrivateAttr(default_factory=Logger)
+    _active_thread: str | None = PrivateAttr(default=None)
     function_calling_llm: Any = Field(default=None)
     step_callback: Any = Field(default=None)
-    _tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr()
+    _tool_adapter: OpenAIAgentToolAdapter = PrivateAttr()
     _converter_adapter: OpenAIConverterAdapter = PrivateAttr()

     def __init__(
         self,
-        model: str = "gpt-4o-mini",
-        tools: Optional[List[BaseTool]] = None,
-        agent_config: Optional[dict] = None,
-        **kwargs,
-    ):
-        if not OPENAI_AVAILABLE:
-            raise ImportError(
-                "OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
-            )
-        else:
-            role = kwargs.pop("role", None)
-            goal = kwargs.pop("goal", None)
-            backstory = kwargs.pop("backstory", None)
-            super().__init__(
-                role=role,
-                goal=goal,
-                backstory=backstory,
-                tools=tools,
-                agent_config=agent_config,
-                **kwargs,
-            )
-            self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
-            self.llm = model
-            self._converter_adapter = OpenAIConverterAdapter(self)
+        **kwargs: Unpack[AgentKwargs],
+    ) -> None:
+        """Initialize the OpenAI agent adapter.
+
+        Args:
+            **kwargs: All initialization arguments including role, goal, backstory,
+                model, tools, and agent_config.
+
+        Raises:
+            ImportError: If OpenAI agent dependencies are not installed.
+        """
+        self.llm = kwargs.pop("model", "gpt-4o-mini")
+        super().__init__(**kwargs)
+        self._tool_adapter = OpenAIAgentToolAdapter(tools=kwargs.get("tools"))
+        self._converter_adapter = OpenAIConverterAdapter(agent_adapter=self)

     def _build_system_prompt(self) -> str:
-        """Build a system prompt for the OpenAI agent."""
+        """Build a system prompt for the OpenAI agent.
+
+        Creates a prompt containing the agent's role, goal, and backstory,
+        then enhances it with structured output instructions if needed.
+
+        Returns:
+            The complete system prompt string.
+        """
         base_prompt = f"""
         You are {self.role}.

         Your goal is: {self.goal}

         Your backstory: {self.backstory}
@@ -84,10 +105,25 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
     def execute_task(
         self,
         task: Any,
-        context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
     ) -> str:
-        """Execute a task using the OpenAI Assistant"""
+        """Execute a task using the OpenAI Assistant.
+
+        Configures the assistant, processes the task, and handles event emission
+        for execution tracking.
+
+        Args:
+            task: The task object to execute.
+            context: Optional context information for the task.
+            tools: Optional additional tools for this execution.
+
+        Returns:
+            The final answer from the task execution.
+
+        Raises:
+            Exception: If task execution fails.
+        """
         self._converter_adapter.configure_structured_output(task)
         self.create_agent_executor(tools)

@@ -95,7 +131,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
             enable_verbose_stdout_logging()

         try:
-            task_prompt = task.prompt()
+            task_prompt: str = task.prompt()
             if context:
                 task_prompt = self.i18n.slice("task_with_context").format(
                     task=task_prompt, context=context
@@ -109,8 +145,8 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
                     task=task,
                 ),
             )
-            result = self.agent_executor.run_sync(self._openai_agent, task_prompt)
-            final_answer = self.handle_execution_result(result)
+            result: Any = self.agent_executor.run_sync(self._openai_agent, task_prompt)
+            final_answer: str = self.handle_execution_result(result)
             crewai_event_bus.emit(
                 self,
                 event=AgentExecutionCompletedEvent(
@@ -120,7 +156,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
             return final_answer

         except Exception as e:
-            self._logger.log("error", f"Error executing OpenAI task: {str(e)}")
+            self._logger.log("error", f"Error executing OpenAI task: {e!s}")
             crewai_event_bus.emit(
                 self,
                 event=AgentExecutionErrorEvent(
@@ -131,15 +167,22 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
             )
             raise

-    def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
-        """
-        Configure the OpenAI agent for execution.
-        While OpenAI handles execution differently through Runner,
-        we can use this method to set up tools and configurations.
-        """
-        all_tools = list(self.tools or []) + list(tools or [])
+    def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
+        """Configure the OpenAI agent for execution.
+
+        While OpenAI handles execution differently through Runner,
+        this method sets up tools and agent configuration.
+
+        Args:
+            tools: Optional tools to configure for the agent.
+
+        Notes:
+            TODO: Properly type agent_executor in BaseAgent to avoid type issues
+            when assigning Runner class to this attribute.
+        """
+        all_tools: list[BaseTool] = list(self.tools or []) + list(tools or [])

-        instructions = self._build_system_prompt()
+        instructions: str = self._build_system_prompt()
         self._openai_agent = OpenAIAgent(
             name=self.role,
             instructions=instructions,
@@ -152,27 +195,48 @@ class OpenAIAgentAdapter(BaseAgentAdapter):

         self.agent_executor = Runner

-    def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
-        """Configure tools for the OpenAI Assistant"""
+    def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
+        """Configure tools for the OpenAI Assistant.
+
+        Args:
+            tools: Optional tools to configure for the assistant.
+        """
         if tools:
             self._tool_adapter.configure_tools(tools)
             if self._tool_adapter.converted_tools:
                 self._openai_agent.tools = self._tool_adapter.converted_tools

     def handle_execution_result(self, result: Any) -> str:
-        """Process OpenAI Assistant execution result converting any structured output to a string"""
+        """Process OpenAI Assistant execution result.
+
+        Converts any structured output to a string through the converter adapter.
+
+        Args:
+            result: The execution result from the OpenAI assistant.
+
+        Returns:
+            Processed result as a string.
+        """
         return self._converter_adapter.post_process_result(result.final_output)

-    def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
-        """Implement delegation tools support"""
-        agent_tools = AgentTools(agents=agents)
-        tools = agent_tools.tools()
-        return tools
+    def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
+        """Implement delegation tools support.
+
+        Creates delegation tools that allow this agent to delegate tasks to other agents.
+
+        Args:
+            agents: List of agents available for delegation.
+
+        Returns:
+            List of delegation tools.
+        """
+        agent_tools: AgentTools = AgentTools(agents=agents)
+        return agent_tools.tools()

-    def configure_structured_output(self, task) -> None:
-        """Configure the structured output for the specific agent implementation.
-
-        Args:
-            structured_output: The structured output to be configured
-        """
+    def configure_structured_output(self, task: Any) -> None:
+        """Configure the structured output for the specific agent implementation.
+
+        Args:
+            task: The task object containing output format specifications.
+        """
         self._converter_adapter.configure_structured_output(task)
```
src/crewai/agents/agent_adapters/openai_agents/openai_agent_tool_adapter.py

```diff
@@ -1,57 +1,125 @@
-import inspect
-from typing import Any, List, Optional
+"""OpenAI agent tool adapter for CrewAI tool integration.

-from agents import FunctionTool, Tool
+This module contains the OpenAIAgentToolAdapter class that converts CrewAI tools
+to OpenAI Assistant-compatible format using the agents library.
+"""
+
+import inspect
+import json
+import re
+from collections.abc import Awaitable
+from typing import Any, cast

 from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter
+from crewai.agents.agent_adapters.openai_agents.protocols import (
+    OpenAIFunctionTool,
+    OpenAITool,
+)
 from crewai.tools import BaseTool
+from crewai.utilities.import_utils import require
+
+agents_module = cast(
+    Any,
+    require(
+        "agents",
+        purpose="OpenAI agents functionality",
+    ),
+)
+FunctionTool = agents_module.FunctionTool
+Tool = agents_module.Tool


 class OpenAIAgentToolAdapter(BaseToolAdapter):
-    """Adapter for OpenAI Assistant tools"""
+    """Adapter for OpenAI Assistant tools.
+
+    Converts CrewAI BaseTool instances to OpenAI Assistant FunctionTool format
+    that can be used by OpenAI agents.
+    """

-    def __init__(self, tools: Optional[List[BaseTool]] = None):
-        self.original_tools = tools or []
+    def __init__(self, tools: list[BaseTool] | None = None) -> None:
+        """Initialize the tool adapter.
+
+        Args:
+            tools: Optional list of CrewAI tools to adapt.
+        """
+        super().__init__()
+        self.original_tools: list[BaseTool] = tools or []
+        self.converted_tools: list[OpenAITool] = []

-    def configure_tools(self, tools: List[BaseTool]) -> None:
-        """Configure tools for the OpenAI Assistant"""
+    def configure_tools(self, tools: list[BaseTool]) -> None:
+        """Configure tools for the OpenAI Assistant.
+
+        Merges provided tools with original tools and converts them to
+        OpenAI Assistant format.
+
+        Args:
+            tools: List of CrewAI tools to configure.
+        """
         if self.original_tools:
-            all_tools = tools + self.original_tools
+            all_tools: list[BaseTool] = tools + self.original_tools
         else:
             all_tools = tools
         if all_tools:
             self.converted_tools = self._convert_tools_to_openai_format(all_tools)

+    @staticmethod
     def _convert_tools_to_openai_format(
-        self, tools: Optional[List[BaseTool]]
-    ) -> List[Tool]:
-        """Convert CrewAI tools to OpenAI Assistant tool format"""
+        tools: list[BaseTool] | None,
+    ) -> list[OpenAITool]:
+        """Convert CrewAI tools to OpenAI Assistant tool format.
+
+        Args:
+            tools: List of CrewAI tools to convert.
+
+        Returns:
+            List of OpenAI Assistant FunctionTool instances.
+        """
         if not tools:
             return []

         def sanitize_tool_name(name: str) -> str:
-            """Convert tool name to match OpenAI's required pattern"""
-            import re
-
-            sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
-            return sanitized
+            """Convert tool name to match OpenAI's required pattern.
+
+            Args:
+                name: Original tool name.
+
+            Returns:
+                Sanitized tool name matching OpenAI requirements.
+            """
+            return re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()

-        def create_tool_wrapper(tool: BaseTool):
-            """Create a wrapper function that handles the OpenAI function tool interface"""
+        def create_tool_wrapper(tool: BaseTool) -> Any:
+            """Create a wrapper function that handles the OpenAI function tool interface.
+
+            Args:
+                tool: The CrewAI tool to wrap.
+
+            Returns:
+                Async wrapper function for OpenAI agent integration.
+            """

             async def wrapper(context_wrapper: Any, arguments: Any) -> Any:
+                """Wrapper function to adapt CrewAI tool calls to OpenAI format.
+
+                Args:
+                    context_wrapper: OpenAI context wrapper.
+                    arguments: Tool arguments from OpenAI.
+
+                Returns:
+                    Tool execution result.
+                """
                 # Get the parameter name from the schema
-                param_name = list(
-                    tool.args_schema.model_json_schema()["properties"].keys()
-                )[0]
+                param_name: str = next(
+                    iter(tool.args_schema.model_json_schema()["properties"].keys())
+                )

                 # Handle different argument types
+                args_dict: dict[str, Any]
                 if isinstance(arguments, dict):
                     args_dict = arguments
                 elif isinstance(arguments, str):
                     try:
-                        import json
-
                         args_dict = json.loads(arguments)
                     except json.JSONDecodeError:
                         args_dict = {param_name: arguments}
@@ -59,11 +127,11 @@ class OpenAIAgentToolAdapter(BaseToolAdapter):
                     args_dict = {param_name: str(arguments)}

                 # Run the tool with the processed arguments
-                output = tool._run(**args_dict)
+                output: Any | Awaitable[Any] = tool._run(**args_dict)

                 # Await if the tool returned a coroutine
                 if inspect.isawaitable(output):
-                    result = await output
+                    result: Any = await output
                 else:
                     result = output

@@ -74,17 +142,20 @@ class OpenAIAgentToolAdapter(BaseToolAdapter):

             return wrapper

-        openai_tools = []
+        openai_tools: list[OpenAITool] = []
         for tool in tools:
-            schema = tool.args_schema.model_json_schema()
+            schema: dict[str, Any] = tool.args_schema.model_json_schema()

             schema.update({"additionalProperties": False, "type": "object"})

-            openai_tool = FunctionTool(
-                name=sanitize_tool_name(tool.name),
-                description=tool.description,
-                params_json_schema=schema,
-                on_invoke_tool=create_tool_wrapper(tool),
+            openai_tool: OpenAIFunctionTool = cast(
+                OpenAIFunctionTool,
+                FunctionTool(
+                    name=sanitize_tool_name(tool.name),
+                    description=tool.description,
+                    params_json_schema=schema,
+                    on_invoke_tool=create_tool_wrapper(tool),
+                ),
             )
             openai_tools.append(openai_tool)
```
src/crewai/agents/agent_adapters/openai_agents/protocols.py (new file, 74 lines)

@@ -0,0 +1,74 @@
```python
"""Type protocols for OpenAI agents modules."""

from collections.abc import Callable
from typing import Any, Protocol, TypedDict, runtime_checkable

from crewai.tools.base_tool import BaseTool


class AgentKwargs(TypedDict, total=False):
    """Typed dict for agent initialization kwargs."""

    role: str
    goal: str
    backstory: str
    model: str
    tools: list[BaseTool] | None
    agent_config: dict[str, Any] | None


@runtime_checkable
class OpenAIAgent(Protocol):
    """Protocol for OpenAI Agent."""

    def __init__(
        self,
        name: str,
        instructions: str,
        model: str,
        **kwargs: Any,
    ) -> None:
        """Initialize the OpenAI agent."""
        ...

    tools: list[Any]
    output_type: Any


@runtime_checkable
class OpenAIRunner(Protocol):
    """Protocol for OpenAI Runner."""

    @classmethod
    def run_sync(cls, agent: OpenAIAgent, message: str) -> Any:
        """Run agent synchronously with a message."""
        ...


@runtime_checkable
class OpenAIAgentsModule(Protocol):
    """Protocol for OpenAI agents module."""

    Agent: type[OpenAIAgent]
    Runner: type[OpenAIRunner]
    enable_verbose_stdout_logging: Callable[[], None]


@runtime_checkable
class OpenAITool(Protocol):
    """Protocol for OpenAI Tool."""


@runtime_checkable
class OpenAIFunctionTool(Protocol):
    """Protocol for OpenAI FunctionTool."""

    def __init__(
        self,
        name: str,
        description: str,
        params_json_schema: dict[str, Any],
        on_invoke_tool: Any,
    ) -> None:
        """Initialize the function tool."""
        ...
```
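`AgentKwargs` exists so the adapter's `**kwargs`-only `__init__` can still be statically checked. A small sketch of how `Unpack` ties a TypedDict to `**kwargs`; the `make_agent` function and its keys are illustrative only:

```python
# Illustration of typing **kwargs with a TypedDict via Unpack, as the
# OpenAIAgentAdapter.__init__ signature above does; `make_agent` is made up.
from typing import Any, TypedDict

from typing_extensions import Unpack


class AgentKwargs(TypedDict, total=False):
    role: str
    goal: str
    model: str


def make_agent(**kwargs: Unpack[AgentKwargs]) -> dict[str, Any]:
    # Type checkers now reject unknown keys and wrongly typed values.
    return dict(kwargs)


print(make_agent(role="researcher", model="gpt-4o-mini"))
# make_agent(rol="typo")  # would be flagged by mypy/pyright
```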
src/crewai/agents/agent_adapters/openai_agents/structured_output_converter.py

```diff
@@ -1,5 +1,12 @@
+"""OpenAI structured output converter for CrewAI task integration.
+
+This module contains the OpenAIConverterAdapter class that handles structured
+output conversion for OpenAI agents, supporting JSON and Pydantic model formats.
+"""
+
 import json
 import re
+from typing import Any, Literal

 from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
 from crewai.utilities.converter import generate_model_description
@@ -7,8 +14,7 @@ from crewai.utilities.i18n import I18N


 class OpenAIConverterAdapter(BaseConverterAdapter):
-    """
-    Adapter for handling structured output conversion in OpenAI agents.
+    """Adapter for handling structured output conversion in OpenAI agents.

     This adapter enhances the OpenAI agent to handle structured output formats
     and post-processes the results when needed.
@@ -19,19 +25,23 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
         _output_model: The Pydantic model for the output
     """

-    def __init__(self, agent_adapter):
-        """Initialize the converter adapter with a reference to the agent adapter"""
-        self.agent_adapter = agent_adapter
-        self._output_format = None
-        self._schema = None
-        self._output_model = None
-
-    def configure_structured_output(self, task) -> None:
-        """
-        Configure the structured output for OpenAI agent based on task requirements.
+    def __init__(self, agent_adapter: Any) -> None:
+        """Initialize the converter adapter with a reference to the agent adapter.

         Args:
-            task: The task containing output format requirements
+            agent_adapter: The OpenAI agent adapter instance.
+        """
+        super().__init__(agent_adapter=agent_adapter)
+        self.agent_adapter: Any = agent_adapter
+        self._output_format: Literal["json", "pydantic"] | None = None
+        self._schema: str | None = None
+        self._output_model: Any = None
+
+    def configure_structured_output(self, task: Any) -> None:
+        """Configure the structured output for OpenAI agent based on task requirements.
+
+        Args:
+            task: The task containing output format requirements.
         """
         # Reset configuration
         self._output_format = None
@@ -55,19 +65,18 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
         self._output_model = task.output_pydantic

     def enhance_system_prompt(self, base_prompt: str) -> str:
-        """
-        Enhance the base system prompt with structured output requirements if needed.
+        """Enhance the base system prompt with structured output requirements if needed.

         Args:
-            base_prompt: The original system prompt
+            base_prompt: The original system prompt.

         Returns:
-            Enhanced system prompt with output format instructions if needed
+            Enhanced system prompt with output format instructions if needed.
         """
         if not self._output_format:
             return base_prompt

-        output_schema = (
+        output_schema: str = (
             I18N()
             .slice("formatted_task_instructions")
             .format(output_format=self._schema)
@@ -76,16 +85,15 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
         return f"{base_prompt}\n\n{output_schema}"

     def post_process_result(self, result: str) -> str:
-        """
-        Post-process the result to ensure it matches the expected format.
+        """Post-process the result to ensure it matches the expected format.

         This method attempts to extract valid JSON from the result if necessary.

         Args:
-            result: The raw result from the agent
+            result: The raw result from the agent.

         Returns:
-            Processed result conforming to the expected output format
+            Processed result conforming to the expected output format.
         """
         if not self._output_format:
             return result
@@ -97,26 +105,30 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
             return result
         except json.JSONDecodeError:
             # Try to extract JSON from markdown code blocks
-            code_block_pattern = r"```(?:json)?\s*([\s\S]*?)```"
-            code_blocks = re.findall(code_block_pattern, result)
+            code_block_pattern: str = r"```(?:json)?\s*([\s\S]*?)```"
+            code_blocks: list[str] = re.findall(code_block_pattern, result)

             for block in code_blocks:
+                stripped_block = block.strip()
                 try:
-                    json.loads(block.strip())
-                    return block.strip()
+                    json.loads(stripped_block)
+                    return stripped_block
                 except json.JSONDecodeError:
-                    continue
+                    pass

             # Try to extract any JSON-like structure
-            json_pattern = r"(\{[\s\S]*\})"
-            json_matches = re.findall(json_pattern, result, re.DOTALL)
+            json_pattern: str = r"(\{[\s\S]*\})"
+            json_matches: list[str] = re.findall(json_pattern, result, re.DOTALL)

             for match in json_matches:
+                is_valid = True
                 try:
                     json.loads(match)
-                    return match
                 except json.JSONDecodeError:
-                    continue
+                    is_valid = False
+
+                if is_valid:
+                    return match

             # If all extraction attempts fail, return the original
             return str(result)
```
||||
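The extraction chain above is easiest to follow with a concrete reply. A standalone demonstration of the same two regexes (no adapter import needed; the sample string is illustrative):

import json
import re

raw = 'Sure!\n```json\n{"title": "Report", "score": 8}\n```'

# Pass 1: fenced ```json blocks, mirroring post_process_result.
for block in re.findall(r"```(?:json)?\s*([\s\S]*?)```", raw):
    stripped = block.strip()
    try:
        json.loads(stripped)
        print(stripped)  # -> {"title": "Report", "score": 8}
        break
    except json.JSONDecodeError:
        pass
else:
    # Pass 2: any brace-delimited span anywhere in the reply.
    for match in re.findall(r"(\{[\s\S]*\})", raw, re.DOTALL):
        try:
            json.loads(match)
            print(match)
            break
        except json.JSONDecodeError:
            pass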
25
src/crewai/context.py
Normal file
@@ -0,0 +1,25 @@
import os
import contextvars
from typing import Optional
from contextlib import contextmanager

_platform_integration_token: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "platform_integration_token", default=None
)

def set_platform_integration_token(integration_token: str) -> None:
    _platform_integration_token.set(integration_token)

def get_platform_integration_token() -> Optional[str]:
    token = _platform_integration_token.get()
    if token is None:
        token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
    return token

@contextmanager
def platform_context(integration_token: str):
    token = _platform_integration_token.set(integration_token)
    try:
        yield
    finally:
        _platform_integration_token.reset(token)
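Since the token lives in a ContextVar, platform_context scopes an override to the current thread or async task and restores the prior value on exit; when nothing has been set, the getter falls back to the CREWAI_PLATFORM_INTEGRATION_TOKEN environment variable. A usage sketch, with an illustrative token value:

from crewai.context import get_platform_integration_token, platform_context

with platform_context("tok-demo-123"):  # illustrative value
    assert get_platform_integration_token() == "tok-demo-123"

# Outside the block the override is gone; the env var (if set) applies again.
print(get_platform_integration_token())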
@@ -3,26 +3,17 @@ import json
import re
import uuid
import warnings
from collections.abc import Callable
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Union,
    cast,
)

from opentelemetry import baggage
from opentelemetry.context import attach, detach

from crewai.utilities.crew.models import CrewContext

from pydantic import (
    UUID4,
    BaseModel,
@@ -39,26 +30,15 @@ from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM, BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
from crewai.security import Fingerprint, SecurityConfig
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
    TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
    is_tracing_enabled,
    should_auto_collect_first_time_traces,
)
from crewai.events.types.crew_events import (
    CrewKickoffCompletedEvent,
    CrewKickoffFailedEvent,
@@ -70,16 +50,28 @@ from crewai.events.types.crew_events import (
    CrewTrainFailedEvent,
    CrewTrainStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
    TraceCollectionListener,
)


from crewai.events.listeners.tracing.utils import (
    is_tracing_enabled,
)
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM, BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
from crewai.rag.types import SearchResult
from crewai.security import Fingerprint, SecurityConfig
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.crew.models import CrewContext
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.formatter import (
    aggregate_raw_outputs_from_task_outputs,
    aggregate_raw_outputs_from_tasks,
@@ -94,28 +86,40 @@ warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")

class Crew(FlowTrackable, BaseModel):
    """
    Represents a group of agents, defining how they should collaborate and the tasks they should perform.
    Represents a group of agents, defining how they should collaborate and the
    tasks they should perform.

    Attributes:
        tasks: List of tasks assigned to the crew.
        agents: List of agents part of this crew.
        tasks: list of tasks assigned to the crew.
        agents: list of agents part of this crew.
        manager_llm: The language model that will run manager agent.
        manager_agent: Custom agent that will be used as manager.
        memory: Whether the crew should use memory to store memories of its execution.
        cache: Whether the crew should use a cache to store the results of the tools execution.
        function_calling_llm: The language model that will run the tool calling for all the agents.
        process: The process flow that the crew will follow (e.g., sequential, hierarchical).
        memory: Whether the crew should use memory to store memories of its
            execution.
        cache: Whether the crew should use a cache to store the results of the
            tools execution.
        function_calling_llm: The language model that will run the tool calling
            for all the agents.
        process: The process flow that the crew will follow (e.g., sequential,
            hierarchical).
        verbose: Indicates the verbosity level for logging during execution.
        config: Configuration settings for the crew.
        max_rpm: Maximum number of requests per minute for the crew execution to be respected.
        max_rpm: Maximum number of requests per minute for the crew execution to
            be respected.
        prompt_file: Path to the prompt json file to be used for the crew.
        id: A unique identifier for the crew instance.
        task_callback: Callback to be executed after each task for every agents execution.
        step_callback: Callback to be executed after each step for every agents execution.
        share_crew: Whether you want to share the complete crew information and execution with crewAI to make the library better, and allow us to train models.
        task_callback: Callback to be executed after each task for every agents
            execution.
        step_callback: Callback to be executed after each step for every agents
            execution.
        share_crew: Whether you want to share the complete crew information and
            execution with crewAI to make the library better, and allow us to
            train models.
        planning: Plan the crew execution and add the plan to the crew.
        chat_llm: The language model used for orchestrating chat interactions with the crew.
        security_config: Security configuration for the crew, including fingerprinting.
        chat_llm: The language model used for orchestrating chat interactions
            with the crew.
        security_config: Security configuration for the crew, including
            fingerprinting.
    """

    __hash__ = object.__hash__  # type: ignore
@@ -124,13 +128,13 @@ class Crew(FlowTrackable, BaseModel):
    _logger: Logger = PrivateAttr()
    _file_handler: FileHandler = PrivateAttr()
    _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
    _short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
    _long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
    _entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr()
    _external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr()
    _train: Optional[bool] = PrivateAttr(default=False)
    _train_iteration: Optional[int] = PrivateAttr()
    _inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
    _short_term_memory: InstanceOf[ShortTermMemory] | None = PrivateAttr()
    _long_term_memory: InstanceOf[LongTermMemory] | None = PrivateAttr()
    _entity_memory: InstanceOf[EntityMemory] | None = PrivateAttr()
    _external_memory: InstanceOf[ExternalMemory] | None = PrivateAttr()
    _train: bool | None = PrivateAttr(default=False)
    _train_iteration: int | None = PrivateAttr()
    _inputs: dict[str, Any] | None = PrivateAttr(default=None)
    _logging_color: str = PrivateAttr(
        default="bold_purple",
    )
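The pattern running through this hunk and the rest of the file is the PEP 585/PEP 604 rewrite: typing.List/Dict/Tuple/Set become builtin generics, and Optional/Union become | unions. A minimal before/after pair for reference (the | syntax needs Python 3.10+ at runtime, or `from __future__ import annotations` for annotation-only use):

from typing import Any, Dict, List, Optional

# Before: typing-module generics
def pick_old(inputs: Optional[Dict[str, Any]], keys: List[str]) -> Optional[List[str]]:
    return [k for k in keys if k in inputs] if inputs else None

# After: builtin generics and PEP 604 unions
def pick_new(inputs: dict[str, Any] | None, keys: list[str]) -> list[str] | None:
    return [k for k in keys if k in inputs] if inputs else None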
@@ -138,107 +142,121 @@ class Crew(FlowTrackable, BaseModel):
        default_factory=TaskOutputStorageHandler
    )

    name: Optional[str] = Field(default="crew")
    name: str | None = Field(default="crew")
    cache: bool = Field(default=True)
    tasks: List[Task] = Field(default_factory=list)
    agents: List[BaseAgent] = Field(default_factory=list)
    tasks: list[Task] = Field(default_factory=list)
    agents: list[BaseAgent] = Field(default_factory=list)
    process: Process = Field(default=Process.sequential)
    verbose: bool = Field(default=False)
    memory: bool = Field(
        default=False,
        description="Whether the crew should use memory to store memories of its execution",
        description="If crew should use memory to store memories of its execution",
    )
    short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field(
    short_term_memory: InstanceOf[ShortTermMemory] | None = Field(
        default=None,
        description="An Instance of the ShortTermMemory to be used by the Crew",
    )
    long_term_memory: Optional[InstanceOf[LongTermMemory]] = Field(
    long_term_memory: InstanceOf[LongTermMemory] | None = Field(
        default=None,
        description="An Instance of the LongTermMemory to be used by the Crew",
    )
    entity_memory: Optional[InstanceOf[EntityMemory]] = Field(
    entity_memory: InstanceOf[EntityMemory] | None = Field(
        default=None,
        description="An Instance of the EntityMemory to be used by the Crew",
    )
    external_memory: Optional[InstanceOf[ExternalMemory]] = Field(
    external_memory: InstanceOf[ExternalMemory] | None = Field(
        default=None,
        description="An Instance of the ExternalMemory to be used by the Crew",
    )
    embedder: Optional[dict] = Field(
    embedder: dict | None = Field(
        default=None,
        description="Configuration for the embedder to be used for the crew.",
    )
    usage_metrics: Optional[UsageMetrics] = Field(
    usage_metrics: UsageMetrics | None = Field(
        default=None,
        description="Metrics for the LLM usage during all tasks execution.",
    )
    manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
    manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
        description="Language model that will run the agent.", default=None
    )
    manager_agent: Optional[BaseAgent] = Field(
    manager_agent: BaseAgent | None = Field(
        description="Custom agent that will be used as manager.", default=None
    )
    function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
    function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
        description="Language model that will run the agent.", default=None
    )
    config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
    config: Json | dict[str, Any] | None = Field(default=None)
    id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
    share_crew: Optional[bool] = Field(default=False)
    step_callback: Optional[Any] = Field(
    share_crew: bool | None = Field(default=False)
    step_callback: Any | None = Field(
        default=None,
        description="Callback to be executed after each step for all agents execution.",
    )
    task_callback: Optional[Any] = Field(
    task_callback: Any | None = Field(
        default=None,
        description="Callback to be executed after each task for all agents execution.",
    )
    before_kickoff_callbacks: List[
        Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]]
    before_kickoff_callbacks: list[
        Callable[[dict[str, Any] | None], dict[str, Any] | None]
    ] = Field(
        default_factory=list,
        description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.",
        description=(
            "List of callbacks to be executed before crew kickoff. "
            "It may be used to adjust inputs before the crew is executed."
        ),
    )
    after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field(
    after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
        default_factory=list,
        description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.",
        description=(
            "List of callbacks to be executed after crew kickoff. "
            "It may be used to adjust the output of the crew."
        ),
    )
    max_rpm: Optional[int] = Field(
    max_rpm: int | None = Field(
        default=None,
        description="Maximum number of requests per minute for the crew execution to be respected.",
        description=(
            "Maximum number of requests per minute for the crew execution "
            "to be respected."
        ),
    )
    prompt_file: Optional[str] = Field(
    prompt_file: str | None = Field(
        default=None,
        description="Path to the prompt json file to be used for the crew.",
    )
    output_log_file: Optional[Union[bool, str]] = Field(
    output_log_file: bool | str | None = Field(
        default=None,
        description="Path to the log file to be saved",
    )
    planning: Optional[bool] = Field(
    planning: bool | None = Field(
        default=False,
        description="Plan the crew execution and add the plan to the crew.",
    )
    planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
    planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
        default=None,
        description="Language model that will run the AgentPlanner if planning is True.",
        description=(
            "Language model that will run the AgentPlanner if planning is True."
        ),
    )
    task_execution_output_json_files: Optional[List[str]] = Field(
    task_execution_output_json_files: list[str] | None = Field(
        default=None,
        description="List of file paths for task execution JSON files.",
    )
    execution_logs: List[Dict[str, Any]] = Field(
    execution_logs: list[dict[str, Any]] = Field(
        default=[],
        description="List of execution logs for tasks",
    )
    knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
    knowledge_sources: list[BaseKnowledgeSource] | None = Field(
        default=None,
        description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
        description=(
            "Knowledge sources for the crew. Add knowledge sources to the "
            "knowledge object."
        ),
    )
    chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
    chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
        default=None,
        description="LLM used to handle chatting with the crew.",
    )
    knowledge: Optional[Knowledge] = Field(
    knowledge: Knowledge | None = Field(
        default=None,
        description="Knowledge for the crew.",
    )
@@ -246,18 +264,18 @@ class Crew(FlowTrackable, BaseModel):
        default_factory=SecurityConfig,
        description="Security configuration for the crew, including fingerprinting.",
    )
    token_usage: Optional[UsageMetrics] = Field(
    token_usage: UsageMetrics | None = Field(
        default=None,
        description="Metrics for the LLM usage during all tasks execution.",
    )
    tracing: Optional[bool] = Field(
    tracing: bool | None = Field(
        default=False,
        description="Whether to enable tracing for the crew.",
    )

    @field_validator("id", mode="before")
    @classmethod
    def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
    def _deny_user_set_id(cls, v: UUID4 | None) -> None:
        """Prevent manual setting of the 'id' field by users."""
        if v:
            raise PydanticCustomError(
@@ -266,9 +284,7 @@ class Crew(FlowTrackable, BaseModel):

    @field_validator("config", mode="before")
    @classmethod
    def check_config_type(
        cls, v: Union[Json, Dict[str, Any]]
    ) -> Union[Json, Dict[str, Any]]:
    def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
        """Validates that the config is a valid type.
        Args:
            v: The config to be validated.
@@ -281,12 +297,16 @@ class Crew(FlowTrackable, BaseModel):

    @model_validator(mode="after")
    def set_private_attrs(self) -> "Crew":
        """Set private attributes."""

        self._cache_handler = CacheHandler()
        event_listener = EventListener()

        if is_tracing_enabled() or self.tracing:
        if (
            is_tracing_enabled()
            or self.tracing
            or should_auto_collect_first_time_traces()
        ):
            trace_listener = TraceCollectionListener()
            trace_listener.setup_listeners(crewai_event_bus)
        event_listener.verbose = self.verbose
@@ -314,7 +334,8 @@ class Crew(FlowTrackable, BaseModel):
    def create_crew_memory(self) -> "Crew":
        """Initialize private memory attributes."""
        self._external_memory = (
            # External memory doesn't support a default value since it was designed to be managed entirely externally
            # External memory does not support a default value since it was
            # designed to be managed entirely externally
            self.external_memory.set_crew(self) if self.external_memory else None
        )
@@ -355,7 +376,10 @@ class Crew(FlowTrackable, BaseModel):
        if not self.manager_llm and not self.manager_agent:
            raise PydanticCustomError(
                "missing_manager_llm_or_manager_agent",
                "Attribute `manager_llm` or `manager_agent` is required when using hierarchical process.",
                (
                    "Attribute `manager_llm` or `manager_agent` is required "
                    "when using hierarchical process."
                ),
                {},
            )

@@ -398,7 +422,10 @@ class Crew(FlowTrackable, BaseModel):
            if task.agent is None:
                raise PydanticCustomError(
                    "missing_agent_in_task",
                    f"Sequential process error: Agent is missing in the task with the following description: {task.description}",  # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
                    (
                        f"Sequential process error: Agent is missing in the task "
                        f"with the following description: {task.description}"
                    ),  # type: ignore # Dynamic string in error message
                    {},
                )

@@ -459,7 +486,10 @@ class Crew(FlowTrackable, BaseModel):
            if task.async_execution and isinstance(task, ConditionalTask):
                raise PydanticCustomError(
                    "invalid_async_conditional_task",
                    f"Conditional Task: {task.description} , cannot be executed asynchronously.",  # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
                    (
                        f"Conditional Task: {task.description}, "
                        f"cannot be executed asynchronously."
                    ),
                    {},
                )
        return self
@@ -478,7 +508,9 @@ class Crew(FlowTrackable, BaseModel):
                for j in range(i - 1, -1, -1):
                    if self.tasks[j] == context_task:
                        raise ValueError(
                            f"Task '{task.description}' is asynchronous and cannot include other sequential asynchronous tasks in its context."
                            f"Task '{task.description}' is asynchronous and "
                            f"cannot include other sequential asynchronous "
                            f"tasks in its context."
                        )
                    if not self.tasks[j].async_execution:
                        break
@@ -496,13 +528,15 @@ class Crew(FlowTrackable, BaseModel):
                    continue  # Skip context tasks not in the main tasks list
                if task_indices[id(context_task)] > task_indices[id(task)]:
                    raise ValueError(
                        f"Task '{task.description}' has a context dependency on a future task '{context_task.description}', which is not allowed."
                        f"Task '{task.description}' has a context dependency "
                        f"on a future task '{context_task.description}', "
                        f"which is not allowed."
                    )
        return self

    @property
    def key(self) -> str:
        source: List[str] = [agent.key for agent in self.agents] + [
        source: list[str] = [agent.key for agent in self.agents] + [
            task.key for task in self.tasks
        ]
        return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
@@ -518,9 +552,9 @@ class Crew(FlowTrackable, BaseModel):
        return self.security_config.fingerprint

    def _setup_from_config(self):
        assert self.config is not None, "Config should not be None."

        """Initializes agents and tasks from the provided config."""
        if self.config is None:
            raise ValueError("Config should not be None.")
        if not self.config.get("agents") or not self.config.get("tasks"):
            raise PydanticCustomError(
                "missing_keys_in_config", "Config should have 'agents' and 'tasks'.", {}
@@ -530,7 +564,7 @@ class Crew(FlowTrackable, BaseModel):
        self.agents = [Agent(**agent) for agent in self.config["agents"]]
        self.tasks = [self._create_task(task) for task in self.config["tasks"]]

    def _create_task(self, task_config: Dict[str, Any]) -> Task:
    def _create_task(self, task_config: dict[str, Any]) -> Task:
        """Creates a task instance from its configuration.

        Args:
@@ -559,7 +593,7 @@ class Crew(FlowTrackable, BaseModel):
            CrewTrainingHandler(filename).initialize_file()

    def train(
        self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = None
        self, n_iterations: int, filename: str, inputs: dict[str, Any] | None = None
    ) -> None:
        """Trains the crew for a given number of iterations."""
        inputs = inputs or {}
@@ -611,7 +645,7 @@ class Crew(FlowTrackable, BaseModel):

    def kickoff(
        self,
        inputs: Optional[Dict[str, Any]] = None,
        inputs: dict[str, Any] | None = None,
    ) -> CrewOutput:
        ctx = baggage.set_baggage(
            "crew_context", CrewContext(id=str(self.id), key=self.key)
@@ -682,9 +716,9 @@ class Crew(FlowTrackable, BaseModel):
        finally:
            detach(token)

    def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
        """Executes the Crew's workflow for each input in the list and aggregates results."""
        results: List[CrewOutput] = []
    def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
        """Executes the Crew's workflow for each input and aggregates results."""
        results: list[CrewOutput] = []

        # Initialize the parent crew's usage metrics
        total_usage_metrics = UsageMetrics()
@@ -703,14 +737,12 @@ class Crew(FlowTrackable, BaseModel):
        self._task_output_handler.reset()
        return results

    async def kickoff_async(
        self, inputs: Optional[Dict[str, Any]] = None
    ) -> CrewOutput:
    async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> CrewOutput:
        """Asynchronous kickoff method to start the crew execution."""
        inputs = inputs or {}
        return await asyncio.to_thread(self.kickoff, inputs)
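For orientation, the kickoff entry points differ only in fan-out and concurrency: kickoff runs once, kickoff_for_each loops over a list of input dicts, and kickoff_async pushes kickoff onto a worker thread. A hedged usage sketch, assuming `crew` is an already-constructed Crew and the input keys are illustrative:

import asyncio

result = crew.kickoff(inputs={"topic": "AI"})  # single synchronous run
batch = crew.kickoff_for_each(inputs=[{"topic": "AI"}, {"topic": "ML"}])
result_async = asyncio.run(crew.kickoff_async(inputs={"topic": "AI"}))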
    async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]:
    async def kickoff_for_each_async(self, inputs: list[dict]) -> list[CrewOutput]:
        crew_copies = [self.copy() for _ in inputs]

        async def run_crew(crew, input_data):
@@ -739,7 +771,9 @@ class Crew(FlowTrackable, BaseModel):
            tasks=self.tasks, planning_agent_llm=self.planning_llm
        )._handle_crew_planning()

        for task, step_plan in zip(self.tasks, result.list_of_plans_per_task):
        for task, step_plan in zip(
            self.tasks, result.list_of_plans_per_task, strict=False
        ):
            task.description += step_plan.plan

    def _store_execution_log(
@@ -776,7 +810,7 @@ class Crew(FlowTrackable, BaseModel):
        return self._execute_tasks(self.tasks)

    def _run_hierarchical_process(self) -> CrewOutput:
        """Creates and assigns a manager agent to make sure the crew completes the tasks."""
        """Creates and assigns a manager agent to complete the tasks."""
        self._create_manager_agent()
        return self._execute_tasks(self.tasks)

@@ -807,23 +841,24 @@ class Crew(FlowTrackable, BaseModel):

    def _execute_tasks(
        self,
        tasks: List[Task],
        start_index: Optional[int] = 0,
        tasks: list[Task],
        start_index: int | None = 0,
        was_replayed: bool = False,
    ) -> CrewOutput:
        """Executes tasks sequentially and returns the final output.

        Args:
            tasks (List[Task]): List of tasks to execute
            manager (Optional[BaseAgent], optional): Manager agent to use for delegation. Defaults to None.
            manager (Optional[BaseAgent], optional): Manager agent to use for
                delegation. Defaults to None.

        Returns:
            CrewOutput: Final output of the crew
        """

        task_outputs: List[TaskOutput] = []
        futures: List[Tuple[Task, Future[TaskOutput], int]] = []
        last_sync_output: Optional[TaskOutput] = None
        task_outputs: list[TaskOutput] = []
        futures: list[tuple[Task, Future[TaskOutput], int]] = []
        last_sync_output: TaskOutput | None = None

        for task_index, task in enumerate(tasks):
            if start_index is not None and task_index < start_index:
@@ -838,7 +873,9 @@ class Crew(FlowTrackable, BaseModel):
            agent_to_use = self._get_agent_to_use(task)
            if agent_to_use is None:
                raise ValueError(
                    f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided."
                    f"No agent available for task: {task.description}. "
                    f"Ensure that either the task has an assigned agent "
                    f"or a manager agent is provided."
                )

            # Determine which tools to use - task tools take precedence over agent tools
@@ -847,7 +884,7 @@ class Crew(FlowTrackable, BaseModel):
            tools_for_task = self._prepare_tools(
                agent_to_use,
                task,
                cast(Union[List[Tool], List[BaseTool]], tools_for_task),
                cast(list[Tool] | list[BaseTool], tools_for_task),
            )

            self._log_task_start(task, agent_to_use.role)
@@ -867,7 +904,7 @@ class Crew(FlowTrackable, BaseModel):
                future = task.execute_async(
                    agent=agent_to_use,
                    context=context,
                    tools=cast(List[BaseTool], tools_for_task),
                    tools=cast(list[BaseTool], tools_for_task),
                )
                futures.append((task, future, task_index))
            else:
@@ -879,7 +916,7 @@ class Crew(FlowTrackable, BaseModel):
                task_output = task.execute_sync(
                    agent=agent_to_use,
                    context=context,
                    tools=cast(List[BaseTool], tools_for_task),
                    tools=cast(list[BaseTool], tools_for_task),
                )
                task_outputs.append(task_output)
                self._process_task_result(task, task_output)
@@ -893,11 +930,11 @@ class Crew(FlowTrackable, BaseModel):
    def _handle_conditional_task(
        self,
        task: ConditionalTask,
        task_outputs: List[TaskOutput],
        futures: List[Tuple[Task, Future[TaskOutput], int]],
        task_outputs: list[TaskOutput],
        futures: list[tuple[Task, Future[TaskOutput], int]],
        task_index: int,
        was_replayed: bool,
    ) -> Optional[TaskOutput]:
    ) -> TaskOutput | None:
        if futures:
            task_outputs = self._process_async_tasks(futures, was_replayed)
            futures.clear()
@@ -917,8 +954,8 @@ class Crew(FlowTrackable, BaseModel):
        return None

    def _prepare_tools(
        self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
    ) -> List[BaseTool]:
        self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
    ) -> list[BaseTool]:
        # Add delegation tools if agent allows delegation
        if hasattr(agent, "allow_delegation") and getattr(
            agent, "allow_delegation", False
@@ -947,22 +984,22 @@ class Crew(FlowTrackable, BaseModel):
        ):
            tools = self._add_multimodal_tools(agent, tools)

        # Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
        return cast(List[BaseTool], tools)
        # Return a List[BaseTool] compatible with Task.execute_sync and execute_async
        return cast(list[BaseTool], tools)

    def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
    def _get_agent_to_use(self, task: Task) -> BaseAgent | None:
        if self.process == Process.hierarchical:
            return self.manager_agent
        return task.agent

    def _merge_tools(
        self,
        existing_tools: Union[List[Tool], List[BaseTool]],
        new_tools: Union[List[Tool], List[BaseTool]],
    ) -> List[BaseTool]:
        """Merge new tools into existing tools list, avoiding duplicates by tool name."""
        existing_tools: list[Tool] | list[BaseTool],
        new_tools: list[Tool] | list[BaseTool],
    ) -> list[BaseTool]:
        """Merge new tools into existing tools list, avoiding duplicates."""
        if not new_tools:
            return cast(List[BaseTool], existing_tools)
            return cast(list[BaseTool], existing_tools)

        # Create mapping of tool names to new tools
        new_tool_map = {tool.name: tool for tool in new_tools}
@@ -973,41 +1010,41 @@ class Crew(FlowTrackable, BaseModel):
        # Add all new tools
        tools.extend(new_tools)

        return cast(List[BaseTool], tools)
        return cast(list[BaseTool], tools)

    def _inject_delegation_tools(
        self,
        tools: Union[List[Tool], List[BaseTool]],
        tools: list[Tool] | list[BaseTool],
        task_agent: BaseAgent,
        agents: List[BaseAgent],
    ) -> List[BaseTool]:
        agents: list[BaseAgent],
    ) -> list[BaseTool]:
        if hasattr(task_agent, "get_delegation_tools"):
            delegation_tools = task_agent.get_delegation_tools(agents)
            # Cast delegation_tools to the expected type for _merge_tools
            return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
        return cast(List[BaseTool], tools)
            return self._merge_tools(tools, cast(list[BaseTool], delegation_tools))
        return cast(list[BaseTool], tools)

    def _add_multimodal_tools(
        self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
    ) -> List[BaseTool]:
        self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
    ) -> list[BaseTool]:
        if hasattr(agent, "get_multimodal_tools"):
            multimodal_tools = agent.get_multimodal_tools()
            # Cast multimodal_tools to the expected type for _merge_tools
            return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
        return cast(List[BaseTool], tools)
            return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
        return cast(list[BaseTool], tools)

    def _add_code_execution_tools(
        self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
    ) -> List[BaseTool]:
        self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
    ) -> list[BaseTool]:
        if hasattr(agent, "get_code_execution_tools"):
            code_tools = agent.get_code_execution_tools()
            # Cast code_tools to the expected type for _merge_tools
            return self._merge_tools(tools, cast(List[BaseTool], code_tools))
        return cast(List[BaseTool], tools)
            return self._merge_tools(tools, cast(list[BaseTool], code_tools))
        return cast(list[BaseTool], tools)

    def _add_delegation_tools(
        self, task: Task, tools: Union[List[Tool], List[BaseTool]]
    ) -> List[BaseTool]:
        self, task: Task, tools: list[Tool] | list[BaseTool]
    ) -> list[BaseTool]:
        agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
        if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
            if not tools:
@@ -1015,7 +1052,7 @@ class Crew(FlowTrackable, BaseModel):
                tools = self._inject_delegation_tools(
                    tools, task.agent, agents_for_delegation
                )
        return cast(List[BaseTool], tools)
        return cast(list[BaseTool], tools)

    def _log_task_start(self, task: Task, role: str = "None"):
        if self.output_log_file:
@@ -1024,8 +1061,8 @@ class Crew(FlowTrackable, BaseModel):
            )

    def _update_manager_tools(
        self, task: Task, tools: Union[List[Tool], List[BaseTool]]
    ) -> List[BaseTool]:
        self, task: Task, tools: list[Tool] | list[BaseTool]
    ) -> list[BaseTool]:
        if self.manager_agent:
            if task.agent:
                tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
@@ -1033,18 +1070,17 @@ class Crew(FlowTrackable, BaseModel):
                tools = self._inject_delegation_tools(
                    tools, self.manager_agent, self.agents
                )
        return cast(List[BaseTool], tools)
        return cast(list[BaseTool], tools)

    def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
    def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
        if not task.context:
            return ""

        context = (
        return (
            aggregate_raw_outputs_from_task_outputs(task_outputs)
            if task.context is NOT_SPECIFIED
            else aggregate_raw_outputs_from_tasks(task.context)
        )
        return context

    def _process_task_result(self, task: Task, output: TaskOutput) -> None:
        role = task.agent.role if task.agent is not None else "None"
@@ -1057,7 +1093,7 @@ class Crew(FlowTrackable, BaseModel):
            output=output.raw,
        )

    def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput:
    def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput:
        if not task_outputs:
            raise ValueError("No task outputs available to create crew output.")

@@ -1088,10 +1124,10 @@ class Crew(FlowTrackable, BaseModel):

    def _process_async_tasks(
        self,
        futures: List[Tuple[Task, Future[TaskOutput], int]],
        futures: list[tuple[Task, Future[TaskOutput], int]],
        was_replayed: bool = False,
    ) -> List[TaskOutput]:
        task_outputs: List[TaskOutput] = []
    ) -> list[TaskOutput]:
        task_outputs: list[TaskOutput] = []
        for future_task, future, task_index in futures:
            task_output = future.result()
            task_outputs.append(task_output)
@@ -1101,9 +1137,7 @@ class Crew(FlowTrackable, BaseModel):
            )
        return task_outputs

    def _find_task_index(
        self, task_id: str, stored_outputs: List[Any]
    ) -> Optional[int]:
    def _find_task_index(self, task_id: str, stored_outputs: list[Any]) -> int | None:
        return next(
            (
                index
@@ -1113,9 +1147,8 @@ class Crew(FlowTrackable, BaseModel):
            None,
        )

    def replay(
        self, task_id: str, inputs: Optional[Dict[str, Any]] = None
    ) -> CrewOutput:
    def replay(self, task_id: str, inputs: dict[str, Any] | None = None) -> CrewOutput:
        """Replay the crew execution from a specific task."""
        stored_outputs = self._task_output_handler.load()
        if not stored_outputs:
            raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")
@@ -1151,19 +1184,19 @@ class Crew(FlowTrackable, BaseModel):
            self.tasks[i].output = task_output

        self._logging_color = "bold_blue"
        result = self._execute_tasks(self.tasks, start_index, True)
        return result
        return self._execute_tasks(self.tasks, start_index, True)

    def query_knowledge(
        self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
    ) -> Union[List[Dict[str, Any]], None]:
        self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
    ) -> list[SearchResult] | None:
        """Query the crew's knowledge base for relevant information."""
        if self.knowledge:
            return self.knowledge.query(
                query, results_limit=results_limit, score_threshold=score_threshold
            )
        return None

    def fetch_inputs(self) -> Set[str]:
    def fetch_inputs(self) -> set[str]:
        """
        Gathers placeholders (e.g., {something}) referenced in tasks or agents.
        Scans each task's 'description' + 'expected_output', and each agent's
@@ -1172,11 +1205,11 @@ class Crew(FlowTrackable, BaseModel):
        Returns a set of all discovered placeholder names.
        """
        placeholder_pattern = re.compile(r"\{(.+?)\}")
        required_inputs: Set[str] = set()
        required_inputs: set[str] = set()

        # Scan tasks for inputs
        for task in self.tasks:
            # description and expected_output might contain e.g. {topic}, {user_name}, etc.
            # description and expected_output might contain e.g. {topic}, {user_name}
            text = f"{task.description or ''} {task.expected_output or ''}"
            required_inputs.update(placeholder_pattern.findall(text))
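The placeholder scan is a plain non-greedy brace regex over each task's and agent's strings, so it can be sanity-checked in isolation:

import re

placeholder_pattern = re.compile(r"\{(.+?)\}")
text = "Research {topic} and report back to {user_name}."
print(placeholder_pattern.findall(text))  # ['topic', 'user_name']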
@@ -1230,7 +1263,7 @@ class Crew(FlowTrackable, BaseModel):
            cloned_tasks.append(cloned_task)
            task_mapping[task.key] = cloned_task

        for cloned_task, original_task in zip(cloned_tasks, self.tasks):
        for cloned_task, original_task in zip(cloned_tasks, self.tasks, strict=False):
            if isinstance(original_task.context, list):
                cloned_context = [
                    task_mapping[context_task.key]
@@ -1256,7 +1289,7 @@ class Crew(FlowTrackable, BaseModel):
        copied_data.pop("agents", None)
        copied_data.pop("tasks", None)

        copied_crew = Crew(
        return Crew(
            **copied_data,
            agents=cloned_agents,
            tasks=cloned_tasks,
@@ -1266,15 +1299,13 @@ class Crew(FlowTrackable, BaseModel):
            manager_llm=manager_llm,
        )

        return copied_crew

    def _set_tasks_callbacks(self) -> None:
        """Sets callback for every task using task_callback"""
        for task in self.tasks:
            if not task.callback:
                task.callback = self.task_callback

    def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
    def _interpolate_inputs(self, inputs: dict[str, Any]) -> None:
        """Interpolates the inputs in the tasks and agents."""
        [
            task.interpolate_inputs_and_add_conversation_history(
@@ -1307,10 +1338,13 @@ class Crew(FlowTrackable, BaseModel):
    def test(
        self,
        n_iterations: int,
        eval_llm: Union[str, InstanceOf[BaseLLM]],
        inputs: Optional[Dict[str, Any]] = None,
        eval_llm: str | InstanceOf[BaseLLM],
        inputs: dict[str, Any] | None = None,
    ) -> None:
        """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
        """Test and evaluate the Crew with the given inputs for n iterations.

        Uses concurrent.futures for concurrent execution.
        """
        try:
            # Create LLM instance and ensure it's of type LLM for CrewEvaluator
            llm_instance = create_llm(eval_llm)
@@ -1350,7 +1384,11 @@ class Crew(FlowTrackable, BaseModel):
            raise

    def __repr__(self):
        return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
        return (
            f"Crew(id={self.id}, process={self.process}, "
            f"number_of_agents={len(self.agents)}, "
            f"number_of_tasks={len(self.tasks)})"
        )

    def reset_memories(self, command_type: str) -> None:
        """Reset specific or all memories for the crew.
@@ -1364,7 +1402,7 @@ class Crew(FlowTrackable, BaseModel):
            ValueError: If an invalid command type is provided.
            RuntimeError: If memory reset operation fails.
        """
        VALID_TYPES = frozenset(
        valid_types = frozenset(
            [
                "long",
                "short",
@@ -1377,9 +1415,10 @@ class Crew(FlowTrackable, BaseModel):
            ]
        )

        if command_type not in VALID_TYPES:
        if command_type not in valid_types:
            raise ValueError(
                f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}"
                f"Invalid command type. Must be one of: "
                f"{', '.join(sorted(valid_types))}"
            )

        try:
@@ -1389,7 +1428,7 @@ class Crew(FlowTrackable, BaseModel):
                self._reset_specific_memory(command_type)

        except Exception as e:
            error_msg = f"Failed to reset {command_type} memory: {str(e)}"
            error_msg = f"Failed to reset {command_type} memory: {e!s}"
            self._logger.log("error", error_msg)
            raise RuntimeError(error_msg) from e

@@ -1397,7 +1436,7 @@ class Crew(FlowTrackable, BaseModel):
        """Reset all available memory systems."""
        memory_systems = self._get_memory_systems()

        for memory_type, config in memory_systems.items():
        for config in memory_systems.values():
            if (system := config.get("system")) is not None:
                name = config.get("name")
                try:
@@ -1405,11 +1444,13 @@ class Crew(FlowTrackable, BaseModel):
                    reset_fn(system)
                    self._logger.log(
                        "info",
                        f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
                        f"[Crew ({self.name if self.name else self.id})] "
                        f"{name} memory has been reset",
                    )
                except Exception as e:
                    raise RuntimeError(
                        f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
                        f"[Crew ({self.name if self.name else self.id})] "
                        f"Failed to reset {name} memory: {e!s}"
                    ) from e

    def _reset_specific_memory(self, memory_type: str) -> None:
@@ -1434,18 +1475,21 @@ class Crew(FlowTrackable, BaseModel):
                    reset_fn(system)
                    self._logger.log(
                        "info",
                        f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
                        f"[Crew ({self.name if self.name else self.id})] "
                        f"{name} memory has been reset",
                    )
                except Exception as e:
                    raise RuntimeError(
                        f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
                        f"[Crew ({self.name if self.name else self.id})] "
                        f"Failed to reset {name} memory: {e!s}"
                    ) from e

    def _get_memory_systems(self):
        """Get all available memory systems with their configuration.

        Returns:
            Dict containing all memory systems with their reset functions and display names.
            Dict containing all memory systems with their reset functions and
            display names.
        """

        def default_reset(memory):
@@ -1506,7 +1550,7 @@ class Crew(FlowTrackable, BaseModel):
        },
    }

    def reset_knowledge(self, knowledges: List[Knowledge]) -> None:
    def reset_knowledge(self, knowledges: list[Knowledge]) -> None:
        """Reset crew and agent knowledge storage."""
        for ks in knowledges:
            ks.reset()
@@ -9,48 +9,158 @@ This module provides the event infrastructure that allows users to:

from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import crewai_event_bus

from crewai.events.types.agent_events import (
    AgentEvaluationCompletedEvent,
    AgentEvaluationFailedEvent,
    AgentEvaluationStartedEvent,
    AgentExecutionCompletedEvent,
    AgentExecutionErrorEvent,
    AgentExecutionStartedEvent,
    LiteAgentExecutionCompletedEvent,
    LiteAgentExecutionErrorEvent,
    LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
    CrewKickoffCompletedEvent,
    CrewKickoffFailedEvent,
    CrewKickoffStartedEvent,
    CrewTestCompletedEvent,
    CrewTestFailedEvent,
    CrewTestResultEvent,
    CrewTestStartedEvent,
    CrewTrainCompletedEvent,
    CrewTrainFailedEvent,
    CrewTrainStartedEvent,
)
from crewai.events.types.flow_events import (
    FlowCreatedEvent,
    FlowEvent,
    FlowFinishedEvent,
    FlowPlotEvent,
    FlowStartedEvent,
    MethodExecutionFailedEvent,
    MethodExecutionFinishedEvent,
    MethodExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
    KnowledgeQueryCompletedEvent,
    KnowledgeQueryFailedEvent,
    KnowledgeQueryStartedEvent,
    KnowledgeRetrievalCompletedEvent,
    KnowledgeRetrievalStartedEvent,
    KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.llm_events import (
    LLMCallCompletedEvent,
    LLMCallFailedEvent,
    LLMCallStartedEvent,
    LLMStreamChunkEvent,
)
from crewai.events.types.llm_guardrail_events import (
    LLMGuardrailCompletedEvent,
    LLMGuardrailStartedEvent,
)
from crewai.events.types.logging_events import (
    AgentLogsExecutionEvent,
    AgentLogsStartedEvent,
)
from crewai.events.types.memory_events import (
    MemoryQueryCompletedEvent,
    MemorySaveCompletedEvent,
    MemorySaveStartedEvent,
    MemoryQueryFailedEvent,
    MemoryQueryStartedEvent,
    MemoryRetrievalCompletedEvent,
    MemoryRetrievalStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveFailedEvent,
    MemoryQueryFailedEvent,
    MemorySaveStartedEvent,
)

from crewai.events.types.knowledge_events import (
    KnowledgeRetrievalStartedEvent,
    KnowledgeRetrievalCompletedEvent,
from crewai.events.types.reasoning_events import (
    AgentReasoningCompletedEvent,
    AgentReasoningFailedEvent,
    AgentReasoningStartedEvent,
    ReasoningEvent,
)

from crewai.events.types.crew_events import (
    CrewKickoffStartedEvent,
    CrewKickoffCompletedEvent,
from crewai.events.types.task_events import (
    TaskCompletedEvent,
    TaskEvaluationEvent,
    TaskFailedEvent,
    TaskStartedEvent,
)
from crewai.events.types.agent_events import (
    AgentExecutionCompletedEvent,
)

from crewai.events.types.llm_events import (
    LLMStreamChunkEvent,
from crewai.events.types.tool_usage_events import (
    ToolExecutionErrorEvent,
    ToolSelectionErrorEvent,
    ToolUsageErrorEvent,
    ToolUsageEvent,
    ToolUsageFinishedEvent,
    ToolUsageStartedEvent,
    ToolValidateInputErrorEvent,
)

__all__ = [
    "AgentEvaluationCompletedEvent",
    "AgentEvaluationFailedEvent",
    "AgentEvaluationStartedEvent",
    "AgentExecutionCompletedEvent",
    "AgentExecutionErrorEvent",
    "AgentExecutionStartedEvent",
    "AgentLogsExecutionEvent",
    "AgentLogsStartedEvent",
    "AgentReasoningCompletedEvent",
    "AgentReasoningFailedEvent",
    "AgentReasoningStartedEvent",
    "BaseEventListener",
    "crewai_event_bus",
    "CrewKickoffCompletedEvent",
    "CrewKickoffFailedEvent",
    "CrewKickoffStartedEvent",
    "CrewTestCompletedEvent",
    "CrewTestFailedEvent",
    "CrewTestResultEvent",
    "CrewTestStartedEvent",
    "CrewTrainCompletedEvent",
    "CrewTrainFailedEvent",
    "CrewTrainStartedEvent",
    "FlowCreatedEvent",
    "FlowEvent",
    "FlowFinishedEvent",
    "FlowPlotEvent",
    "FlowStartedEvent",
    "KnowledgeQueryCompletedEvent",
    "KnowledgeQueryFailedEvent",
    "KnowledgeQueryStartedEvent",
    "KnowledgeRetrievalCompletedEvent",
    "KnowledgeRetrievalStartedEvent",
    "KnowledgeSearchQueryFailedEvent",
    "LLMCallCompletedEvent",
    "LLMCallFailedEvent",
    "LLMCallStartedEvent",
    "LLMGuardrailCompletedEvent",
    "LLMGuardrailStartedEvent",
    "LLMStreamChunkEvent",
    "LiteAgentExecutionCompletedEvent",
    "LiteAgentExecutionErrorEvent",
    "LiteAgentExecutionStartedEvent",
    "MemoryQueryCompletedEvent",
    "MemorySaveCompletedEvent",
    "MemorySaveStartedEvent",
    "MemoryQueryFailedEvent",
    "MemoryQueryStartedEvent",
    "MemoryRetrievalCompletedEvent",
    "MemoryRetrievalStartedEvent",
    "MemorySaveCompletedEvent",
    "MemorySaveFailedEvent",
    "MemoryQueryFailedEvent",
    "KnowledgeRetrievalStartedEvent",
    "KnowledgeRetrievalCompletedEvent",
    "CrewKickoffStartedEvent",
    "CrewKickoffCompletedEvent",
    "AgentExecutionCompletedEvent",
    "LLMStreamChunkEvent",
]
    "MemorySaveStartedEvent",
    "MethodExecutionFailedEvent",
    "MethodExecutionFinishedEvent",
    "MethodExecutionStartedEvent",
    "ReasoningEvent",
    "TaskCompletedEvent",
    "TaskEvaluationEvent",
    "TaskFailedEvent",
    "TaskStartedEvent",
    "ToolExecutionErrorEvent",
    "ToolSelectionErrorEvent",
    "ToolUsageErrorEvent",
    "ToolUsageEvent",
    "ToolUsageFinishedEvent",
    "ToolUsageStartedEvent",
    "ToolValidateInputErrorEvent",
    "crewai_event_bus",
]
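With the re-exported names above, a custom listener subclasses BaseEventListener and registers handlers on crewai_event_bus; handlers receive (source, event) per the bus convention, and instantiating the listener wires it up (assuming BaseEventListener's initializer registers setup_listeners against the global bus, as in current crewAI releases). A minimal sketch:

from crewai.events import BaseEventListener, CrewKickoffStartedEvent, crewai_event_bus

class KickoffLogger(BaseEventListener):
    def setup_listeners(self, bus):
        @bus.on(CrewKickoffStartedEvent)
        def on_kickoff(source, event):
            # `crew_name` is carried by the kickoff event payload.
            print(f"Crew kickoff started: {event.crew_name}")

KickoffLogger()  # instantiating hooks the handlers into the global bus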
230
src/crewai/events/listeners/tracing/first_time_trace_handler.py
Normal file
@@ -0,0 +1,230 @@
import logging
import os
import uuid
import webbrowser
from pathlib import Path

from rich.console import Console
from rich.panel import Panel

from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager
from crewai.events.listeners.tracing.utils import (
    mark_first_execution_completed,
    prompt_user_for_trace_viewing,
    should_auto_collect_first_time_traces,
)

logger = logging.getLogger(__name__)


def _update_or_create_env_file():
    """Update or create .env file with CREWAI_TRACING_ENABLED=true."""
    env_path = Path(".env")
    env_content = ""
    variable_name = "CREWAI_TRACING_ENABLED"
    variable_value = "true"

    # Read existing content if file exists
    if env_path.exists():
        with open(env_path, "r") as f:
            env_content = f.read()

    # Check if CREWAI_TRACING_ENABLED is already set
    lines = env_content.splitlines()
    variable_exists = False
    updated_lines = []

    for line in lines:
        if line.strip().startswith(f"{variable_name}="):
            # Update existing variable
            updated_lines.append(f"{variable_name}={variable_value}")
            variable_exists = True
        else:
            updated_lines.append(line)

    # Add variable if it doesn't exist
    if not variable_exists:
        if updated_lines and not updated_lines[-1].strip():
            # If last line is empty, replace it
            updated_lines[-1] = f"{variable_name}={variable_value}"
        else:
            # Add new line and then the variable
            updated_lines.append(f"{variable_name}={variable_value}")

    # Write updated content
    with open(env_path, "w") as f:
        f.write("\n".join(updated_lines))
        if updated_lines:  # Add final newline if there's content
            f.write("\n")
class FirstTimeTraceHandler:
|
||||
"""Handles the first-time user trace collection and display flow."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_first_time: bool = False
|
||||
self.collected_events: bool = False
|
||||
self.trace_batch_id: str | None = None
|
||||
self.ephemeral_url: str | None = None
|
||||
self.batch_manager: TraceBatchManager | None = None
|
||||
|
||||
def initialize_for_first_time_user(self) -> bool:
|
||||
"""Check if this is first time and initialize collection."""
|
||||
self.is_first_time = should_auto_collect_first_time_traces()
|
||||
return self.is_first_time
|
||||
|
||||
def set_batch_manager(self, batch_manager: TraceBatchManager):
|
||||
"""Set reference to batch manager for sending events."""
|
||||
self.batch_manager = batch_manager
|
||||
|
||||
def mark_events_collected(self):
|
||||
"""Mark that events have been collected during execution."""
|
||||
self.collected_events = True
|
||||
|
||||
def handle_execution_completion(self):
|
||||
"""Handle the completion flow as shown in your diagram."""
|
||||
if not self.is_first_time or not self.collected_events:
|
||||
return
|
||||
|
||||
try:
|
||||
user_wants_traces = prompt_user_for_trace_viewing(timeout_seconds=20)
|
||||
|
||||
if user_wants_traces:
|
||||
self._initialize_backend_and_send_events()
|
||||
|
||||
# Enable tracing for future runs by updating .env file
|
||||
try:
|
||||
_update_or_create_env_file()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if self.ephemeral_url:
|
||||
self._display_ephemeral_trace_link()
|
||||
|
||||
mark_first_execution_completed()
|
||||
|
||||
except Exception as e:
|
||||
self._gracefully_fail(f"Error in trace handling: {e}")
|
||||
mark_first_execution_completed()
|
||||
|
||||
def _initialize_backend_and_send_events(self):
|
||||
"""Initialize backend batch and send collected events."""
|
||||
if not self.batch_manager:
|
||||
return
|
||||
|
||||
try:
|
||||
if not self.batch_manager.backend_initialized:
|
||||
original_metadata = (
|
||||
self.batch_manager.current_batch.execution_metadata
|
||||
if self.batch_manager.current_batch
|
||||
else {}
|
||||
)
|
||||
|
||||
user_context = {
|
||||
"privacy_level": "standard",
|
||||
"user_id": "first_time_user",
|
||||
"session_id": str(uuid.uuid4()),
|
||||
"trace_id": self.batch_manager.trace_batch_id,
|
||||
}
|
||||
|
||||
execution_metadata = {
|
||||
"execution_type": original_metadata.get("execution_type", "crew"),
|
||||
"crew_name": original_metadata.get(
|
||||
"crew_name", "First Time Execution"
|
||||
),
|
||||
"flow_name": original_metadata.get("flow_name"),
|
||||
"agent_count": original_metadata.get("agent_count", 1),
|
||||
"task_count": original_metadata.get("task_count", 1),
|
||||
"crewai_version": original_metadata.get("crewai_version"),
|
||||
}
|
||||
|
||||
self.batch_manager._initialize_backend_batch(
|
||||
user_context=user_context,
|
||||
execution_metadata=execution_metadata,
|
||||
use_ephemeral=True,
|
||||
)
|
||||
self.batch_manager.backend_initialized = True
|
||||
|
||||
if self.batch_manager.event_buffer:
|
||||
self.batch_manager._send_events_to_backend()
|
||||
|
||||
self.batch_manager.finalize_batch()
|
||||
self.ephemeral_url = self.batch_manager.ephemeral_trace_url
|
||||
|
||||
if not self.ephemeral_url:
|
||||
self._show_local_trace_message()
|
||||
|
||||
except Exception as e:
|
||||
self._gracefully_fail(f"Backend initialization failed: {e}")
|
||||
|
||||
def _display_ephemeral_trace_link(self):
|
||||
"""Display the ephemeral trace link to the user and automatically open browser."""
|
||||
console = Console()
|
||||
|
||||
try:
|
||||
webbrowser.open(self.ephemeral_url)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
panel_content = f"""
|
||||
🎉 Your First CrewAI Execution Trace is Ready!
|
||||
|
||||
View your execution details here:
|
||||
{self.ephemeral_url}
|
||||
|
||||
This trace shows:
|
||||
• Agent decisions and interactions
|
||||
• Task execution timeline
|
||||
• Tool usage and results
|
||||
• LLM calls and responses
|
||||
|
||||
✅ Tracing has been enabled for future runs! (CREWAI_TRACING_ENABLED=true added to .env)
|
||||
You can also add tracing=True to your Crew(tracing=True) / Flow(tracing=True) for more control.
|
||||
|
||||
📝 Note: This link will expire in 24 hours.
|
||||
""".strip()
|
||||
|
||||
panel = Panel(
|
||||
panel_content,
|
||||
title="🔍 Execution Trace Generated",
|
||||
border_style="bright_green",
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
||||
console.print("\n")
|
||||
console.print(panel)
|
||||
console.print()
|
||||
|
||||
def _gracefully_fail(self, error_message: str):
|
||||
"""Handle errors gracefully without disrupting user experience."""
|
||||
console = Console()
|
||||
console.print(f"[yellow]Note: {error_message}[/yellow]")
|
||||
|
||||
logger.debug(f"First-time trace error: {error_message}")
|
||||
|
||||
def _show_local_trace_message(self):
|
||||
"""Show message when traces were collected locally but couldn't be uploaded."""
|
||||
console = Console()
|
||||
|
||||
panel_content = f"""
|
||||
📊 Your execution traces were collected locally!
|
||||
|
||||
Unfortunately, we couldn't upload them to the server right now, but here's what we captured:
|
||||
• {len(self.batch_manager.event_buffer)} trace events
|
||||
• Execution duration: {self.batch_manager.calculate_duration("execution")}ms
|
||||
• Batch ID: {self.batch_manager.trace_batch_id}
|
||||
|
||||
Tracing has been enabled for future runs! (CREWAI_TRACING_ENABLED=true added to .env)
|
||||
The traces include agent decisions, task execution, and tool usage.
|
||||
""".strip()
|
||||
|
||||
panel = Panel(
|
||||
panel_content,
|
||||
title="🔍 Local Traces Collected",
|
||||
border_style="yellow",
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
||||
console.print("\n")
|
||||
console.print(panel)
|
||||
console.print()
|
||||
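For orientation, the handler above is driven by the trace listener further down in this diff. A condensed lifecycle sketch — `batch_manager` here is a stand-in for the listener's TraceBatchManager instance:

handler = FirstTimeTraceHandler()
if handler.initialize_for_first_time_user():  # true only on the very first run
    handler.set_batch_manager(batch_manager)  # batch_manager: TraceBatchManager
# ... execution happens; events are buffered by the batch manager ...
handler.mark_events_collected()
handler.handle_execution_completion()  # prompt, upload, show link, mark done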
@@ -1,18 +1,18 @@
 import uuid
-from datetime import datetime, timezone
-from typing import Dict, List, Any, Optional
 from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from logging import getLogger
+from typing import Any

-from crewai.utilities.constants import CREWAI_BASE_URL
-from crewai.cli.authentication.token import AuthError, get_auth_token
-from crewai.cli.version import get_crewai_version
-from crewai.cli.plus_api import PlusAPI
 from rich.console import Console
 from rich.panel import Panel

+from crewai.cli.authentication.token import AuthError, get_auth_token
+from crewai.cli.plus_api import PlusAPI
+from crewai.cli.version import get_crewai_version
 from crewai.events.listeners.tracing.types import TraceEvent
-from logging import getLogger
+from crewai.events.listeners.tracing.utils import should_auto_collect_first_time_traces
+from crewai.utilities.constants import CREWAI_BASE_URL

 logger = getLogger(__name__)

@@ -23,11 +23,11 @@ class TraceBatch:

     version: str = field(default_factory=get_crewai_version)
     batch_id: str = field(default_factory=lambda: str(uuid.uuid4()))
-    user_context: Dict[str, str] = field(default_factory=dict)
-    execution_metadata: Dict[str, Any] = field(default_factory=dict)
-    events: List[TraceEvent] = field(default_factory=list)
+    user_context: dict[str, str] = field(default_factory=dict)
+    execution_metadata: dict[str, Any] = field(default_factory=dict)
+    events: list[TraceEvent] = field(default_factory=list)

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         return {
             "version": self.version,
             "batch_id": self.batch_id,

@@ -40,26 +40,28 @@
 class TraceBatchManager:
     """Single responsibility: Manage batches and event buffering"""

-    is_current_batch_ephemeral: bool = False
-    trace_batch_id: Optional[str] = None
-    current_batch: Optional[TraceBatch] = None
-    event_buffer: List[TraceEvent] = []
-    execution_start_times: Dict[str, datetime] = {}
-    batch_owner_type: Optional[str] = None
-    batch_owner_id: Optional[str] = None
-
     def __init__(self):
+        self.is_current_batch_ephemeral: bool = False
+        self.trace_batch_id: str | None = None
+        self.current_batch: TraceBatch | None = None
+        self.event_buffer: list[TraceEvent] = []
+        self.execution_start_times: dict[str, datetime] = {}
+        self.batch_owner_type: str | None = None
+        self.batch_owner_id: str | None = None
+        self.backend_initialized: bool = False
+        self.ephemeral_trace_url: str | None = None
         try:
             self.plus_api = PlusAPI(
                 api_key=get_auth_token(),
             )
         except AuthError:
             self.plus_api = PlusAPI(api_key="")
         self.ephemeral_trace_url = None

     def initialize_batch(
         self,
-        user_context: Dict[str, str],
-        execution_metadata: Dict[str, Any],
+        user_context: dict[str, str],
+        execution_metadata: dict[str, Any],
         use_ephemeral: bool = False,
     ) -> TraceBatch:
         """Initialize a new trace batch"""

@@ -70,14 +72,21 @@
         self.is_current_batch_ephemeral = use_ephemeral

         self.record_start_time("execution")
-        self._initialize_backend_batch(user_context, execution_metadata, use_ephemeral)
+        if should_auto_collect_first_time_traces():
+            self.trace_batch_id = self.current_batch.batch_id
+        else:
+            self._initialize_backend_batch(
+                user_context, execution_metadata, use_ephemeral
+            )
+            self.backend_initialized = True

         return self.current_batch

     def _initialize_backend_batch(
         self,
-        user_context: Dict[str, str],
-        execution_metadata: Dict[str, Any],
+        user_context: dict[str, str],
+        execution_metadata: dict[str, Any],
         use_ephemeral: bool = False,
     ):
         """Send batch initialization to backend"""

@@ -129,13 +138,6 @@
                 if not use_ephemeral
                 else response_data["ephemeral_trace_id"]
             )
-            console = Console()
-            panel = Panel(
-                f"✅ Trace batch initialized with session ID: {self.trace_batch_id}",
-                title="Trace Batch Initialization",
-                border_style="green",
-            )
-            console.print(panel)
         else:
             logger.warning(
                 f"Trace batch initialization returned status {response.status_code}. Continuing without tracing."

@@ -143,7 +145,7 @@
         except Exception as e:
             logger.warning(
-                f"Error initializing trace batch: {str(e)}. Continuing without tracing."
+                f"Error initializing trace batch: {e}. Continuing without tracing."
             )

     def add_event(self, trace_event: TraceEvent):

@@ -154,7 +156,6 @@
         """Send buffered events to backend with graceful failure handling"""
         if not self.plus_api or not self.trace_batch_id or not self.event_buffer:
             return 500

         try:
             payload = {
                 "events": [event.to_dict() for event in self.event_buffer],

@@ -178,19 +179,19 @@
             if response.status_code in [200, 201]:
                 self.event_buffer.clear()
                 return 200
             else:
                 logger.warning(
                     f"Failed to send events: {response.status_code}. Events will be lost."
                 )
                 return 500

         except Exception as e:
             logger.warning(
-                f"Error sending events to backend: {str(e)}. Events will be lost."
+                f"Error sending events to backend: {e}. Events will be lost."
             )
             return 500

-    def finalize_batch(self) -> Optional[TraceBatch]:
+    def finalize_batch(self) -> TraceBatch | None:
         """Finalize batch and return it for sending"""
         if not self.current_batch:
             return None

@@ -246,12 +247,27 @@
                         if not self.is_current_batch_ephemeral and access_code is None
                         else f"{CREWAI_BASE_URL}/crewai_plus/ephemeral_trace_batches/{self.trace_batch_id}?access_code={access_code}"
                     )

+                    if self.is_current_batch_ephemeral:
+                        self.ephemeral_trace_url = return_link
+
+                    # Create a properly formatted message with URL on its own line
+                    message_parts = [
+                        f"✅ Trace batch finalized with session ID: {self.trace_batch_id}",
+                        "",
+                        f"🔗 View here: {return_link}",
+                    ]
+
+                    if access_code:
+                        message_parts.append(f"🔑 Access Code: {access_code}")
+
                     panel = Panel(
-                        f"✅ Trace batch finalized with session ID: {self.trace_batch_id}. View here: {return_link} {f', Access Code: {access_code}' if access_code else ''}",
+                        "\n".join(message_parts),
                         title="Trace Batch Finalization",
                         border_style="green",
                     )
-                    console.print(panel)
+                    if not should_auto_collect_first_time_traces():
+                        console.print(panel)

                 else:
                     logger.error(

@@ -259,8 +275,8 @@
                     )

         except Exception as e:
-            logger.error(f"❌ Error finalizing trace batch: {str(e)}")
-            # TODO: send error to app
+            logger.error(f"❌ Error finalizing trace batch: {e}")
+            # TODO: send error to app marking as failed

     def _cleanup_batch_data(self):
         """Clean up batch data after successful finalization to free memory"""

@@ -277,7 +293,7 @@
             self.batch_sequence = 0

         except Exception as e:
-            logger.error(f"Warning: Error during cleanup: {str(e)}")
+            logger.error(f"Warning: Error during cleanup: {e}")

     def has_events(self) -> bool:
         """Check if there are events in the buffer"""

@@ -306,7 +322,7 @@
                 return duration_ms
         return 0

-    def get_trace_id(self) -> Optional[str]:
+    def get_trace_id(self) -> str | None:
         """Get current trace ID"""
         if self.current_batch:
             return self.current_batch.user_context.get("trace_id")
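Taken together, the manager's public surface is small. A usage sketch, under the assumption that TraceEvent instances come from the listener during execution:

manager = TraceBatchManager()
manager.initialize_batch(
    user_context={"user_id": "anonymous"},
    execution_metadata={"execution_type": "crew"},
    use_ephemeral=True,  # unauthenticated runs get a shareable ephemeral link
)
# manager.add_event(trace_event)  # called per event while the crew runs
manager.finalize_batch()  # uploads, prints the view link, returns TraceBatch | None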
@@ -1,28 +1,59 @@
 import os
 import uuid
+from typing import Any, ClassVar

-from typing import Dict, Any, Optional

 from crewai.cli.authentication.token import AuthError, get_auth_token
 from crewai.cli.version import get_crewai_version
 from crewai.events.base_event_listener import BaseEventListener
-from crewai.events.types.agent_events import (
-    AgentExecutionCompletedEvent,
-    AgentExecutionStartedEvent,
-    LiteAgentExecutionStartedEvent,
-    LiteAgentExecutionCompletedEvent,
-    LiteAgentExecutionErrorEvent,
-    AgentExecutionErrorEvent,
+from crewai.events.listeners.tracing.first_time_trace_handler import (
+    FirstTimeTraceHandler,
 )
 from crewai.events.listeners.tracing.types import TraceEvent
-from crewai.events.types.reasoning_events import (
-    AgentReasoningStartedEvent,
-    AgentReasoningCompletedEvent,
-    AgentReasoningFailedEvent,
+from crewai.events.listeners.tracing.utils import safe_serialize_to_dict
+from crewai.events.types.agent_events import (
+    AgentExecutionCompletedEvent,
+    AgentExecutionErrorEvent,
+    AgentExecutionStartedEvent,
+    LiteAgentExecutionCompletedEvent,
+    LiteAgentExecutionErrorEvent,
+    LiteAgentExecutionStartedEvent,
 )
 from crewai.events.types.crew_events import (
     CrewKickoffCompletedEvent,
     CrewKickoffFailedEvent,
     CrewKickoffStartedEvent,
 )
+from crewai.events.types.flow_events import (
+    FlowCreatedEvent,
+    FlowFinishedEvent,
+    FlowPlotEvent,
+    FlowStartedEvent,
+    MethodExecutionFailedEvent,
+    MethodExecutionFinishedEvent,
+    MethodExecutionStartedEvent,
+)
+from crewai.events.types.llm_events import (
+    LLMCallCompletedEvent,
+    LLMCallFailedEvent,
+    LLMCallStartedEvent,
+)
+from crewai.events.types.llm_guardrail_events import (
+    LLMGuardrailCompletedEvent,
+    LLMGuardrailStartedEvent,
+)
+from crewai.events.types.memory_events import (
+    MemoryQueryCompletedEvent,
+    MemoryQueryFailedEvent,
+    MemoryQueryStartedEvent,
+    MemorySaveCompletedEvent,
+    MemorySaveFailedEvent,
+    MemorySaveStartedEvent,
+)
+from crewai.events.types.reasoning_events import (
+    AgentReasoningCompletedEvent,
+    AgentReasoningFailedEvent,
+    AgentReasoningStartedEvent,
+)
 from crewai.events.types.task_events import (
     TaskCompletedEvent,
     TaskFailedEvent,

@@ -33,49 +64,16 @@ from crewai.events.types.tool_usage_events import (
     ToolUsageFinishedEvent,
     ToolUsageStartedEvent,
 )
-from crewai.events.types.llm_events import (
-    LLMCallCompletedEvent,
-    LLMCallFailedEvent,
-    LLMCallStartedEvent,
-)
-
-from crewai.events.types.flow_events import (
-    FlowCreatedEvent,
-    FlowStartedEvent,
-    FlowFinishedEvent,
-    MethodExecutionStartedEvent,
-    MethodExecutionFinishedEvent,
-    MethodExecutionFailedEvent,
-    FlowPlotEvent,
-)
-from crewai.events.types.llm_guardrail_events import (
-    LLMGuardrailStartedEvent,
-    LLMGuardrailCompletedEvent,
-)
-from crewai.utilities.serialization import to_serializable
-

 from .trace_batch_manager import TraceBatchManager

-from crewai.events.types.memory_events import (
-    MemoryQueryStartedEvent,
-    MemoryQueryCompletedEvent,
-    MemoryQueryFailedEvent,
-    MemorySaveStartedEvent,
-    MemorySaveCompletedEvent,
-    MemorySaveFailedEvent,
-)
-
-from crewai.cli.authentication.token import AuthError, get_auth_token
-from crewai.cli.version import get_crewai_version
-

 class TraceCollectionListener(BaseEventListener):
     """
     Trace collection listener that orchestrates trace collection
     """

-    complex_events = [
+    complex_events: ClassVar[list[str]] = [
         "task_started",
         "task_completed",
         "llm_call_started",

@@ -88,14 +86,14 @@ class TraceCollectionListener(BaseEventListener):
     _initialized = False
     _listeners_setup = False

-    def __new__(cls, batch_manager=None):
+    def __new__(cls, batch_manager: TraceBatchManager | None = None):
         if cls._instance is None:
             cls._instance = super().__new__(cls)
         return cls._instance

     def __init__(
         self,
-        batch_manager: Optional[TraceBatchManager] = None,
+        batch_manager: TraceBatchManager | None = None,
     ):
         if self._initialized:
             return

@@ -103,16 +101,19 @@ class TraceCollectionListener(BaseEventListener):
         super().__init__()
         self.batch_manager = batch_manager or TraceBatchManager()
         self._initialized = True
+        self.first_time_handler = FirstTimeTraceHandler()
+
+        if self.first_time_handler.initialize_for_first_time_user():
+            self.first_time_handler.set_batch_manager(self.batch_manager)

     def _check_authenticated(self) -> bool:
         """Check if tracing should be enabled"""
         try:
-            res = bool(get_auth_token())
-            return res
+            return bool(get_auth_token())
         except AuthError:
             return False

-    def _get_user_context(self) -> Dict[str, str]:
+    def _get_user_context(self) -> dict[str, str]:
         """Extract user context for tracing"""
         return {
             "user_id": os.getenv("CREWAI_USER_ID", "anonymous"),

@@ -161,8 +162,14 @@ class TraceCollectionListener(BaseEventListener):
         @event_bus.on(FlowFinishedEvent)
         def on_flow_finished(source, event):
             self._handle_trace_event("flow_finished", source, event)

             if self.batch_manager.batch_owner_type == "flow":
-                self.batch_manager.finalize_batch()
+                if self.first_time_handler.is_first_time:
+                    self.first_time_handler.mark_events_collected()
+                    self.first_time_handler.handle_execution_completion()
+                else:
+                    # Normal flow finalization
+                    self.batch_manager.finalize_batch()

         @event_bus.on(FlowPlotEvent)
         def on_flow_plot(source, event):

@@ -181,12 +188,20 @@ class TraceCollectionListener(BaseEventListener):
         def on_crew_completed(source, event):
             self._handle_trace_event("crew_kickoff_completed", source, event)
             if self.batch_manager.batch_owner_type == "crew":
-                self.batch_manager.finalize_batch()
+                if self.first_time_handler.is_first_time:
+                    self.first_time_handler.mark_events_collected()
+                    self.first_time_handler.handle_execution_completion()
+                else:
+                    self.batch_manager.finalize_batch()

         @event_bus.on(CrewKickoffFailedEvent)
         def on_crew_failed(source, event):
             self._handle_trace_event("crew_kickoff_failed", source, event)
-            self.batch_manager.finalize_batch()
+            if self.first_time_handler.is_first_time:
+                self.first_time_handler.mark_events_collected()
+                self.first_time_handler.handle_execution_completion()
+            else:
+                self.batch_manager.finalize_batch()

         @event_bus.on(TaskStartedEvent)
         def on_task_started(source, event):

@@ -325,17 +340,19 @@ class TraceCollectionListener(BaseEventListener):
             self._initialize_batch(user_context, execution_metadata)

     def _initialize_batch(
-        self, user_context: Dict[str, str], execution_metadata: Dict[str, Any]
+        self, user_context: dict[str, str], execution_metadata: dict[str, Any]
     ):
-        """Initialize trace batch if ephemeral"""
-        if not self._check_authenticated():
-            self.batch_manager.initialize_batch(
+        """Initialize trace batch - auto-enable ephemeral for first-time users."""
+
+        if self.first_time_handler.is_first_time:
+            return self.batch_manager.initialize_batch(
                 user_context, execution_metadata, use_ephemeral=True
             )
-        else:
-            self.batch_manager.initialize_batch(
-                user_context, execution_metadata, use_ephemeral=False
-            )
+
+        use_ephemeral = not self._check_authenticated()
+        return self.batch_manager.initialize_batch(
+            user_context, execution_metadata, use_ephemeral=use_ephemeral
+        )

     def _handle_trace_event(self, event_type: str, source: Any, event: Any):
         """Generic handler for context end events"""

@@ -371,11 +388,11 @@ class TraceCollectionListener(BaseEventListener):

     def _build_event_data(
         self, event_type: str, event: Any, source: Any
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Build event data"""
         if event_type not in self.complex_events:
-            return self._safe_serialize_to_dict(event)
-        elif event_type == "task_started":
+            return safe_serialize_to_dict(event)
+        if event_type == "task_started":
             return {
                 "task_description": event.task.description,
                 "expected_output": event.task.expected_output,

@@ -384,7 +401,7 @@ class TraceCollectionListener(BaseEventListener):
                 "agent_role": source.agent.role,
                 "task_id": str(event.task.id),
             }
-        elif event_type == "task_completed":
+        if event_type == "task_completed":
             return {
                 "task_description": event.task.description if event.task else None,
                 "task_name": event.task.name or event.task.description

@@ -397,63 +414,31 @@ class TraceCollectionListener(BaseEventListener):
                 else None,
                 "agent_role": event.output.agent if event.output else None,
             }
-        elif event_type == "agent_execution_started":
+        if event_type == "agent_execution_started":
             return {
                 "agent_role": event.agent.role,
                 "agent_goal": event.agent.goal,
                 "agent_backstory": event.agent.backstory,
             }
-        elif event_type == "agent_execution_completed":
+        if event_type == "agent_execution_completed":
             return {
                 "agent_role": event.agent.role,
                 "agent_goal": event.agent.goal,
                 "agent_backstory": event.agent.backstory,
             }
-        elif event_type == "llm_call_started":
-            event_data = self._safe_serialize_to_dict(event)
+        if event_type == "llm_call_started":
+            event_data = safe_serialize_to_dict(event)
             event_data["task_name"] = (
                 event.task_name or event.task_description
                 if hasattr(event, "task_name") and event.task_name
                 else None
             )
             return event_data
-        elif event_type == "llm_call_completed":
-            return self._safe_serialize_to_dict(event)
-        else:
-            return {
-                "event_type": event_type,
-                "event": self._safe_serialize_to_dict(event),
-                "source": source,
-            }
+        if event_type == "llm_call_completed":
+            return safe_serialize_to_dict(event)

-    # TODO: move to utils
-    def _safe_serialize_to_dict(
-        self, obj, exclude: set[str] | None = None
-    ) -> Dict[str, Any]:
-        """Safely serialize an object to a dictionary for event data."""
-        try:
-            serialized = to_serializable(obj, exclude)
-            if isinstance(serialized, dict):
-                return serialized
-            else:
-                return {"serialized_data": serialized}
-        except Exception as e:
-            return {"serialization_error": str(e), "object_type": type(obj).__name__}
-
-    # TODO: move to utils
-    def _truncate_messages(self, messages, max_content_length=500, max_messages=5):
-        """Truncate message content and limit number of messages"""
-        if not messages or not isinstance(messages, list):
-            return messages
-
-        # Limit number of messages
-        limited_messages = messages[:max_messages]
-
-        # Truncate each message content
-        for msg in limited_messages:
-            if isinstance(msg, dict) and "content" in msg:
-                content = msg["content"]
-                if len(content) > max_content_length:
-                    msg["content"] = content[:max_content_length] + "..."
-
-        return limited_messages
+        return {
+            "event_type": event_type,
+            "event": safe_serialize_to_dict(event),
+            "source": source,
+        }
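The listener is a singleton and is wired to the global bus exactly once; this mirrors the registration that flow.py performs later in this diff:

from crewai.events.event_bus import crewai_event_bus

listener = TraceCollectionListener()        # __new__ returns the shared instance
listener.setup_listeners(crewai_event_bus)  # registers the @event_bus.on handlers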
@@ -1,17 +1,25 @@
 import getpass
 import hashlib
 import json
+import logging
 import os
 import platform
-import uuid
-import hashlib
-import subprocess
-import getpass
-from pathlib import Path
-from datetime import datetime
 import re
-import json
+import subprocess
+import uuid
+from datetime import datetime
+from pathlib import Path
+from typing import Any

 import click
 from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text

 from crewai.utilities.paths import db_storage_path
+from crewai.utilities.serialization import to_serializable
+
+logger = logging.getLogger(__name__)


 def is_tracing_enabled() -> bool:

@@ -43,13 +51,11 @@ def _get_machine_id() -> str:

     try:
         mac = ":".join(
-            ["{:02x}".format((uuid.getnode() >> b) & 0xFF) for b in range(0, 12, 2)][
-                ::-1
-            ]
+            [f"{(uuid.getnode() >> b) & 0xFF:02x}" for b in range(0, 12, 2)][::-1]
         )
         parts.append(mac)
     except Exception:
-        pass
+        logger.warning("Error getting machine id for fingerprinting")

     sysname = platform.system()
     parts.append(sysname)

@@ -57,7 +63,7 @@ def _get_machine_id() -> str:
     try:
         if sysname == "Darwin":
             res = subprocess.run(
-                ["system_profiler", "SPHardwareDataType"],
+                ["/usr/sbin/system_profiler", "SPHardwareDataType"],
                 capture_output=True,
                 text=True,
                 timeout=2,

@@ -72,7 +78,7 @@ def _get_machine_id() -> str:
             parts.append(Path("/sys/class/dmi/id/product_uuid").read_text().strip())
         elif sysname == "Windows":
             res = subprocess.run(
-                ["wmic", "csproduct", "get", "UUID"],
+                ["C:\\Windows\\System32\\wbem\\wmic.exe", "csproduct", "get", "UUID"],
                 capture_output=True,
                 text=True,
                 timeout=2,

@@ -81,7 +87,7 @@ def _get_machine_id() -> str:
             if len(lines) >= 2:
                 parts.append(lines[1])
     except Exception:
-        pass
+        logger.exception("Error getting machine ID")

     return hashlib.sha256("".join(parts).encode()).hexdigest()

@@ -97,8 +103,8 @@ def _load_user_data() -> dict:
     if p.exists():
         try:
             return json.loads(p.read_text())
-        except Exception:
-            pass
+        except (json.JSONDecodeError, OSError, PermissionError) as e:
+            logger.warning(f"Failed to load user data: {e}")
     return {}


@@ -106,8 +112,8 @@ def _save_user_data(data: dict) -> None:
     try:
         p = _user_data_file()
         p.write_text(json.dumps(data, indent=2))
-    except Exception:
-        pass
+    except (OSError, PermissionError) as e:
+        logger.warning(f"Failed to save user data: {e}")


 def get_user_id() -> str:

@@ -151,3 +157,103 @@ def mark_first_execution_done() -> None:
         }
     )
     _save_user_data(data)
+
+
+def safe_serialize_to_dict(obj, exclude: set[str] | None = None) -> dict[str, Any]:
+    """Safely serialize an object to a dictionary for event data."""
+    try:
+        serialized = to_serializable(obj, exclude)
+        if isinstance(serialized, dict):
+            return serialized
+        return {"serialized_data": serialized}
+    except Exception as e:
+        return {"serialization_error": str(e), "object_type": type(obj).__name__}
+
+
+def truncate_messages(messages, max_content_length=500, max_messages=5):
+    """Truncate message content and limit the number of messages."""
+    if not messages or not isinstance(messages, list):
+        return messages
+
+    limited_messages = messages[:max_messages]
+
+    for msg in limited_messages:
+        if isinstance(msg, dict) and "content" in msg:
+            content = msg["content"]
+            if len(content) > max_content_length:
+                msg["content"] = content[:max_content_length] + "..."
+
+    return limited_messages
+
+
+def should_auto_collect_first_time_traces() -> bool:
+    """True if we should auto-collect traces for a first-time user."""
+    if _is_test_environment():
+        return False
+    return is_first_execution()
+
+
+def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
+    """
+    Prompt the user to view their traces, with a timeout.
+    Returns True if the user wants to see traces, False otherwise.
+    """
+    if _is_test_environment():
+        return False
+
+    try:
+        import threading
+
+        console = Console()
+
+        content = Text()
+        content.append("🔍 ", style="cyan bold")
+        content.append(
+            "Detailed execution traces are available!\n\n", style="cyan bold"
+        )
+        content.append("View insights including:\n", style="white")
+        content.append("  • Agent decision-making process\n", style="bright_blue")
+        content.append("  • Task execution flow and timing\n", style="bright_blue")
+        content.append("  • Tool usage details", style="bright_blue")
+
+        panel = Panel(
+            content,
+            title="[bold cyan]Execution Traces[/bold cyan]",
+            border_style="cyan",
+            padding=(1, 2),
+        )
+        console.print("\n")
+        console.print(panel)
+
+        prompt_text = click.style(
+            f"Would you like to view your execution traces? [y/N] ({timeout_seconds}s timeout): ",
+            fg="white",
+            bold=True,
+        )
+        click.echo(prompt_text, nl=False)
+
+        result = [False]
+
+        def get_input():
+            try:
+                response = input().strip().lower()
+                result[0] = response in ["y", "yes"]
+            except (EOFError, KeyboardInterrupt):
+                result[0] = False
+
+        input_thread = threading.Thread(target=get_input, daemon=True)
+        input_thread.start()
+        input_thread.join(timeout=timeout_seconds)
+
+        if input_thread.is_alive():
+            return False
+
+        return result[0]
+
+    except Exception:
+        return False
+
+
+def mark_first_execution_completed() -> None:
+    """Mark first execution as completed (called after trace prompt)."""
+    mark_first_execution_done()
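A quick check of the truncation helper added above (pure illustration; values mirror the defaults):

msgs = [{"role": "user", "content": "x" * 1200} for _ in range(10)]
trimmed = truncate_messages(msgs)
assert len(trimmed) == 5                      # capped at max_messages
assert trimmed[0]["content"].endswith("...")  # content capped at 500 chars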
@@ -2,30 +2,22 @@ import asyncio
 import copy
 import inspect
 import logging
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Set,
-    Type,
-    TypeVar,
-    Union,
-    cast,
-)
+from collections.abc import Callable
+from typing import Any, ClassVar, Generic, TypeVar, cast
 from uuid import uuid4

 from opentelemetry import baggage
 from opentelemetry.context import attach, detach
 from pydantic import BaseModel, Field, ValidationError

-from crewai.flow.flow_visualizer import plot_flow
-from crewai.flow.persistence.base import FlowPersistence
-from crewai.flow.types import FlowExecutionData
-from crewai.flow.utils import get_possible_return_constants
 from crewai.events.event_bus import crewai_event_bus
+from crewai.events.listeners.tracing.trace_listener import (
+    TraceCollectionListener,
+)
+from crewai.events.listeners.tracing.utils import (
+    is_tracing_enabled,
+    should_auto_collect_first_time_traces,
+)
 from crewai.events.types.flow_events import (
     FlowCreatedEvent,
     FlowFinishedEvent,

@@ -35,12 +27,10 @@ from crewai.events.types.flow_events import (
     MethodExecutionFinishedEvent,
     MethodExecutionStartedEvent,
 )
-from crewai.events.listeners.tracing.trace_listener import (
-    TraceCollectionListener,
-)
-from crewai.events.listeners.tracing.utils import (
-    is_tracing_enabled,
-)
+from crewai.flow.flow_visualizer import plot_flow
+from crewai.flow.persistence.base import FlowPersistence
+from crewai.flow.types import FlowExecutionData
+from crewai.flow.utils import get_possible_return_constants
 from crewai.utilities.printer import Printer

 logger = logging.getLogger(__name__)

@@ -55,16 +45,14 @@ class FlowState(BaseModel):
     )


 # Type variables with explicit bounds
-T = TypeVar(
-    "T", bound=Union[Dict[str, Any], BaseModel]
-)  # Generic flow state type parameter
+T = TypeVar("T", bound=dict[str, Any] | BaseModel)  # Generic flow state type parameter
 StateT = TypeVar(
-    "StateT", bound=Union[Dict[str, Any], BaseModel]
+    "StateT", bound=dict[str, Any] | BaseModel
 )  # State validation type parameter


-def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
+def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
     """Ensure state matches expected type with proper validation.

     Args:

@@ -104,7 +92,7 @@ def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
     raise TypeError(f"Invalid expected_type: {expected_type}")


-def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
+def start(condition: str | dict | Callable | None = None) -> Callable:
     """
     Marks a method as a flow's starting point.

@@ -171,7 +159,7 @@
     return decorator


-def listen(condition: Union[str, dict, Callable]) -> Callable:
+def listen(condition: str | dict | Callable) -> Callable:
     """
     Creates a listener that executes when specified conditions are met.

@@ -231,7 +219,7 @@
     return decorator


-def router(condition: Union[str, dict, Callable]) -> Callable:
+def router(condition: str | dict | Callable) -> Callable:
     """
     Creates a routing method that directs flow execution based on conditions.

@@ -297,7 +285,7 @@
     return decorator


-def or_(*conditions: Union[str, dict, Callable]) -> dict:
+def or_(*conditions: str | dict | Callable) -> dict:
     """
     Combines multiple conditions with OR logic for flow control.

@@ -343,7 +331,7 @@
     return {"type": "OR", "methods": methods}


-def and_(*conditions: Union[str, dict, Callable]) -> dict:
+def and_(*conditions: str | dict | Callable) -> dict:
     """
     Combines multiple conditions with AND logic for flow control.

@@ -425,10 +413,10 @@ class FlowMeta(type):
             if possible_returns:
                 router_paths[attr_name] = possible_returns

-        setattr(cls, "_start_methods", start_methods)
-        setattr(cls, "_listeners", listeners)
-        setattr(cls, "_routers", routers)
-        setattr(cls, "_router_paths", router_paths)
+        cls._start_methods = start_methods
+        cls._listeners = listeners
+        cls._routers = routers
+        cls._router_paths = router_paths

         return cls

@@ -436,29 +424,29 @@
 class Flow(Generic[T], metaclass=FlowMeta):
     """Base class for all flows.

     Type parameter T must be either dict[str, Any] or a subclass of BaseModel."""

     _printer = Printer()

-    _start_methods: List[str] = []
-    _listeners: Dict[str, tuple[str, List[str]]] = {}
-    _routers: Set[str] = set()
-    _router_paths: Dict[str, List[str]] = {}
-    initial_state: Union[Type[T], T, None] = None
-    name: Optional[str] = None
-    tracing: Optional[bool] = False
+    _start_methods: ClassVar[list[str]] = []
+    _listeners: ClassVar[dict[str, tuple[str, list[str]]]] = {}
+    _routers: ClassVar[set[str]] = set()
+    _router_paths: ClassVar[dict[str, list[str]]] = {}
+    initial_state: type[T] | T | None = None
+    name: str | None = None
+    tracing: bool | None = False

-    def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
+    def __class_getitem__(cls: type["Flow"], item: type[T]) -> type["Flow"]:
         class _FlowGeneric(cls):  # type: ignore
-            _initial_state_T = item  # type: ignore
+            _initial_state_t = item  # type: ignore

         _FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]"
         return _FlowGeneric

     def __init__(
         self,
-        persistence: Optional[FlowPersistence] = None,
-        tracing: Optional[bool] = False,
+        persistence: FlowPersistence | None = None,
+        tracing: bool | None = False,
         **kwargs: Any,
     ) -> None:
         """Initialize a new Flow instance.

@@ -468,18 +456,22 @@
             **kwargs: Additional state values to initialize or override
         """
         # Initialize basic instance attributes
-        self._methods: Dict[str, Callable] = {}
-        self._method_execution_counts: Dict[str, int] = {}
-        self._pending_and_listeners: Dict[str, Set[str]] = {}
-        self._method_outputs: List[Any] = []  # List to store all method outputs
-        self._completed_methods: Set[str] = set()  # Track completed methods for reload
-        self._persistence: Optional[FlowPersistence] = persistence
+        self._methods: dict[str, Callable] = {}
+        self._method_execution_counts: dict[str, int] = {}
+        self._pending_and_listeners: dict[str, set[str]] = {}
+        self._method_outputs: list[Any] = []  # list to store all method outputs
+        self._completed_methods: set[str] = set()  # Track completed methods for reload
+        self._persistence: FlowPersistence | None = persistence
         self._is_execution_resuming: bool = False

         # Initialize state with initial values
         self._state = self._create_initial_state()
         self.tracing = tracing
-        if is_tracing_enabled() or self.tracing:
+        if (
+            is_tracing_enabled()
+            or self.tracing
+            or should_auto_collect_first_time_traces()
+        ):
             trace_listener = TraceCollectionListener()
             trace_listener.setup_listeners(crewai_event_bus)
         # Apply any additional kwargs

@@ -521,25 +513,25 @@
             TypeError: If state is neither BaseModel nor dictionary
         """
         # Handle case where initial_state is None but we have a type parameter
-        if self.initial_state is None and hasattr(self, "_initial_state_T"):
-            state_type = getattr(self, "_initial_state_T")
+        if self.initial_state is None and hasattr(self, "_initial_state_t"):
+            state_type = self._initial_state_t
             if isinstance(state_type, type):
                 if issubclass(state_type, FlowState):
                     # Create instance without id, then set it
                     instance = state_type()
                     if not hasattr(instance, "id"):
-                        setattr(instance, "id", str(uuid4()))
+                        instance.id = str(uuid4())
                     return cast(T, instance)
-                elif issubclass(state_type, BaseModel):
+                if issubclass(state_type, BaseModel):
                     # Create a new type that includes the ID field
                     class StateWithId(state_type, FlowState):  # type: ignore
                         pass

                     instance = StateWithId()
                     if not hasattr(instance, "id"):
-                        setattr(instance, "id", str(uuid4()))
+                        instance.id = str(uuid4())
                     return cast(T, instance)
-                elif state_type is dict:
+                if state_type is dict:
                     return cast(T, {"id": str(uuid4())})

         # Handle case where no initial state is provided

@@ -550,13 +542,13 @@
         if isinstance(self.initial_state, type):
             if issubclass(self.initial_state, FlowState):
                 return cast(T, self.initial_state())  # Uses model defaults
-            elif issubclass(self.initial_state, BaseModel):
+            if issubclass(self.initial_state, BaseModel):
                 # Validate that the model has an id field
                 model_fields = getattr(self.initial_state, "model_fields", None)
                 if not model_fields or "id" not in model_fields:
                     raise ValueError("Flow state model must have an 'id' field")
                 return cast(T, self.initial_state())  # Uses model defaults
-            elif self.initial_state is dict:
+            if self.initial_state is dict:
                 return cast(T, {"id": str(uuid4())})

         # Handle dictionary instance case

@@ -600,7 +592,7 @@
         return self._state

     @property
-    def method_outputs(self) -> List[Any]:
+    def method_outputs(self) -> list[Any]:
         """Returns the list of all outputs from executed methods."""
         return self._method_outputs

@@ -631,13 +623,13 @@

             if isinstance(self._state, dict):
                 return str(self._state.get("id", ""))
-            elif isinstance(self._state, BaseModel):
+            if isinstance(self._state, BaseModel):
                 return str(getattr(self._state, "id", ""))
             return ""
         except (AttributeError, TypeError):
             return ""  # Safely handle any unexpected attribute access issues

-    def _initialize_state(self, inputs: Dict[str, Any]) -> None:
+    def _initialize_state(self, inputs: dict[str, Any]) -> None:
         """Initialize or update flow state with new inputs.

         Args:

@@ -691,7 +683,7 @@
         else:
             raise TypeError("State must be a BaseModel instance or a dictionary.")

-    def _restore_state(self, stored_state: Dict[str, Any]) -> None:
+    def _restore_state(self, stored_state: dict[str, Any]) -> None:
         """Restore flow state from persistence.

         Args:

@@ -735,7 +727,7 @@
             execution_data: Flow execution data containing:
                 - id: Flow execution ID
                 - flow: Flow structure
-                - completed_methods: List of successfully completed methods
+                - completed_methods: list of successfully completed methods
                 - execution_methods: All execution methods with their status
         """
         flow_id = execution_data.get("id")

@@ -771,7 +763,7 @@
         if state_to_apply:
             self._apply_state_updates(state_to_apply)

-        for i, method in enumerate(sorted_methods[:-1]):
+        for method in sorted_methods[:-1]:
             method_name = method.get("flow_method", {}).get("name")
             if method_name:
                 self._completed_methods.add(method_name)

@@ -783,7 +775,7 @@
             elif hasattr(self._state, field_name):
                 object.__setattr__(self._state, field_name, value)

-    def _apply_state_updates(self, updates: Dict[str, Any]) -> None:
+    def _apply_state_updates(self, updates: dict[str, Any]) -> None:
         """Apply multiple state updates efficiently."""
         if isinstance(self._state, dict):
             self._state.update(updates)

@@ -792,7 +784,7 @@
             if hasattr(self._state, key):
                 object.__setattr__(self._state, key, value)

-    def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
+    def kickoff(self, inputs: dict[str, Any] | None = None) -> Any:
         """
         Start the flow execution in a synchronous context.

@@ -805,7 +797,7 @@

         return asyncio.run(run_flow())

-    async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
+    async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any:
         """
         Start the flow execution asynchronously.

@@ -840,7 +832,7 @@
         if isinstance(self._state, dict):
             self._state["id"] = inputs["id"]
         elif isinstance(self._state, BaseModel):
-            setattr(self._state, "id", inputs["id"])
+            setattr(self._state, "id", inputs["id"])  # noqa: B010

         # If persistence is enabled, attempt to restore the stored state using the provided id.
         if "id" in inputs and self._persistence is not None:

@@ -1075,7 +1067,7 @@
             )

         # Now execute normal listeners for all router results and the original trigger
-        all_triggers = [trigger_method] + router_results
+        all_triggers = [trigger_method, *router_results]

         for current_trigger in all_triggers:
             if current_trigger:  # Skip None results

@@ -1109,7 +1101,7 @@

     def _find_triggered_methods(
         self, trigger_method: str, router_only: bool
-    ) -> List[str]:
+    ) -> list[str]:
         """
         Finds all methods that should be triggered based on conditions.

@@ -1126,7 +1118,7 @@

         Returns
         -------
-        List[str]
+        list[str]
             Names of methods that should be triggered.

         Notes
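A minimal flow exercising the decorators above — a sketch assuming the import path crewai.flow.flow, with unstructured dict state:

from crewai.flow.flow import Flow, listen, start


class GreetFlow(Flow[dict]):
    @start()
    def begin(self):
        self.state["name"] = "world"

    @listen(begin)
    def greet(self):
        return f"hello, {self.state['name']}"


result = GreetFlow().kickoff()  # runs begin, then greet; returns "hello, world"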
@@ -1,10 +1,11 @@
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any

 from pydantic import BaseModel, ConfigDict, Field

 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
+from crewai.rag.types import SearchResult

 os.environ["TOKENIZERS_PARALLELISM"] = "false"  # removes logging from fastembed

@@ -13,23 +14,23 @@ class Knowledge(BaseModel):
     """
     Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
     Args:
-        sources: List[BaseKnowledgeSource] = Field(default_factory=list)
-        storage: Optional[KnowledgeStorage] = Field(default=None)
-        embedder: Optional[Dict[str, Any]] = None
+        sources: list[BaseKnowledgeSource] = Field(default_factory=list)
+        storage: KnowledgeStorage | None = Field(default=None)
+        embedder: dict[str, Any] | None = None
     """

-    sources: List[BaseKnowledgeSource] = Field(default_factory=list)
+    sources: list[BaseKnowledgeSource] = Field(default_factory=list)
     model_config = ConfigDict(arbitrary_types_allowed=True)
-    storage: Optional[KnowledgeStorage] = Field(default=None)
-    embedder: Optional[Dict[str, Any]] = None
-    collection_name: Optional[str] = None
+    storage: KnowledgeStorage | None = Field(default=None)
+    embedder: dict[str, Any] | None = None
+    collection_name: str | None = None

     def __init__(
         self,
         collection_name: str,
-        sources: List[BaseKnowledgeSource],
-        embedder: Optional[Dict[str, Any]] = None,
-        storage: Optional[KnowledgeStorage] = None,
+        sources: list[BaseKnowledgeSource],
+        embedder: dict[str, Any] | None = None,
+        storage: KnowledgeStorage | None = None,
         **data,
     ):
         super().__init__(**data)

@@ -40,11 +41,10 @@
             embedder=embedder, collection_name=collection_name
         )
         self.sources = sources
-        self.storage.initialize_knowledge_storage()

     def query(
-        self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
-    ) -> List[Dict[str, Any]]:
+        self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
+    ) -> list[SearchResult]:
         """
         Query across all knowledge sources to find the most relevant information.
         Returns the top_k most relevant chunks.

@@ -55,12 +55,11 @@
         if self.storage is None:
             raise ValueError("Storage is not initialized.")

-        results = self.storage.search(
+        return self.storage.search(
             query,
             limit=results_limit,
             score_threshold=score_threshold,
         )
-        return results

     def add_sources(self):
         try:
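With the defaults above now 5 results at a 0.6 threshold, a query sketch — StringKnowledgeSource is assumed available, as in crewAI's other knowledge sources:

from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

source = StringKnowledgeSource(content="CrewAI flows coordinate agents and tasks.")
kb = Knowledge(collection_name="demo", sources=[source])
hits = kb.query(["how do flows coordinate agents?"])  # list[SearchResult]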
@@ -9,8 +9,8 @@ class KnowledgeConfig(BaseModel):
         score_threshold (float): The minimum score for a document to be considered relevant.
     """

-    results_limit: int = Field(default=3, description="The number of results to return")
+    results_limit: int = Field(default=5, description="The number of results to return")
     score_threshold: float = Field(
-        default=0.35,
+        default=0.6,
         description="The minimum score for a result to be considered relevant",
     )
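Callers that preferred the old, looser behavior can restore it explicitly wherever a KnowledgeConfig is accepted — a sketch:

config = KnowledgeConfig(results_limit=3, score_threshold=0.35)  # pre-change defaults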
@@ -1,5 +1,7 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any
+
+from crewai.rag.types import SearchResult


 class BaseKnowledgeStorage(ABC):

@@ -8,22 +10,17 @@ class BaseKnowledgeStorage(ABC):
     @abstractmethod
     def search(
         self,
-        query: List[str],
-        limit: int = 3,
-        filter: Optional[dict] = None,
-        score_threshold: float = 0.35,
-    ) -> List[Dict[str, Any]]:
+        query: list[str],
+        limit: int = 5,
+        metadata_filter: dict[str, Any] | None = None,
+        score_threshold: float = 0.6,
+    ) -> list[SearchResult]:
         """Search for documents in the knowledge base."""
         pass

     @abstractmethod
-    def save(
-        self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
-    ) -> None:
+    def save(self, documents: list[str]) -> None:
         """Save documents to the knowledge base."""
         pass

     @abstractmethod
     def reset(self) -> None:
         """Reset the knowledge base."""
         pass
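Any concrete backend only needs to satisfy this interface. A toy in-memory sketch — the SearchResult shape used here (a dict with a "content" key) is an assumption for illustration only:

class InMemoryKnowledgeStorage(BaseKnowledgeStorage):
    def __init__(self) -> None:
        self._docs: list[str] = []

    def search(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[SearchResult]:
        # naive substring match; a real backend would rank by embedding similarity
        hits = [d for d in self._docs if any(q.lower() in d.lower() for q in query)]
        return [{"content": d} for d in hits[:limit]]  # assumed SearchResult shape

    def save(self, documents: list[str]) -> None:
        self._docs.extend(documents)

    def reset(self) -> None:
        self._docs.clear()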
@@ -1,24 +1,17 @@
import hashlib
import logging
import os
import shutil
from typing import Any, Dict, List, Optional, Union

import chromadb
import chromadb.errors
from chromadb.api import ClientAPI
from chromadb.api.types import OneOrMany
from chromadb.config import Settings
import traceback
import warnings
from typing import Any, cast

from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import sanitize_collection_name
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path
from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.logger_utils import suppress_logging


class KnowledgeStorage(BaseKnowledgeStorage):
@@ -27,167 +20,105 @@ class KnowledgeStorage(BaseKnowledgeStorage):
    search efficiency.
    """

    collection: Optional[chromadb.Collection] = None
    collection_name: Optional[str] = "knowledge"
    app: Optional[ClientAPI] = None

    def __init__(
        self,
        embedder: Optional[Dict[str, Any]] = None,
        collection_name: Optional[str] = None,
    ):
        embedder: dict[str, Any] | None = None,
        collection_name: str | None = None,
    ) -> None:
        self.collection_name = collection_name
        self._set_embedder_config(embedder)
        self._client: BaseClient | None = None

    def search(
        self,
        query: List[str],
        limit: int = 3,
        filter: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Dict[str, Any]]:
        with suppress_logging(
            "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
        ):
            if self.collection:
                fetched = self.collection.query(
                    query_texts=query,
                    n_results=limit,
                    where=filter,
                )
                results = []
                for i in range(len(fetched["ids"][0])):  # type: ignore
                    result = {
                        "id": fetched["ids"][0][i],  # type: ignore
                        "metadata": fetched["metadatas"][0][i],  # type: ignore
                        "context": fetched["documents"][0][i],  # type: ignore
                        "score": fetched["distances"][0][i],  # type: ignore
                    }
                    if result["score"] >= score_threshold:
                        results.append(result)
                return results
            else:
                raise Exception("Collection not initialized")

    def initialize_knowledge_storage(self):
        # Suppress deprecation warnings from chromadb, which are not relevant to us
        # TODO: Remove this once we upgrade chromadb to at least 1.0.8.
        warnings.filterwarnings(
            "ignore",
            message=r".*'model_fields'.*is deprecated.*",
            module=r"^chromadb(\.|$)",
        )

        self.app = create_persistent_client(
            path=os.path.join(db_storage_path(), "knowledge"),
            settings=Settings(allow_reset=True),
        )
        if embedder:
            embedding_function = get_embedding_function(embedder)
            config = ChromaDBConfig(
                embedding_function=cast(
                    ChromaEmbeddingFunctionWrapper, embedding_function
                )
            )
            self._client = create_client(config)

    def _get_client(self) -> BaseClient:
        """Get the appropriate client - instance-specific or global."""
        return self._client if self._client else get_rag_client()

    def search(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[SearchResult]:
        try:
            if not query:
                raise ValueError("Query cannot be empty")

            client = self._get_client()
            collection_name = (
                f"knowledge_{self.collection_name}"
                if self.collection_name
                else "knowledge"
            )
            if self.app:
                self.collection = self.app.get_or_create_collection(
                    name=sanitize_collection_name(collection_name),
                    embedding_function=self.embedder,
                )
            else:
                raise Exception("Vector Database Client not initialized")
        except Exception:
            raise Exception("Failed to create or get collection")
            query_text = " ".join(query) if len(query) > 1 else query[0]

    def reset(self):
        base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
        if not self.app:
            self.app = create_persistent_client(
                path=base_path, settings=Settings(allow_reset=True)
            return client.search(
                collection_name=collection_name,
                query=query_text,
                limit=limit,
                metadata_filter=metadata_filter,
                score_threshold=score_threshold,
            )

        self.app.reset()
        shutil.rmtree(base_path)
        self.app = None
        self.collection = None

    def save(
        self,
        documents: List[str],
        metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
    ):
        if not self.collection:
            raise Exception("Collection not initialized")

        try:
            # Create a dictionary to store unique documents
            unique_docs = {}

            # Generate IDs and create a mapping of id -> (document, metadata)
            for idx, doc in enumerate(documents):
                doc_id = hashlib.sha256(doc.encode("utf-8")).hexdigest()
                doc_metadata = None
                if metadata is not None:
                    if isinstance(metadata, list):
                        doc_metadata = metadata[idx]
                    else:
                        doc_metadata = metadata
                unique_docs[doc_id] = (doc, doc_metadata)

            # Prepare filtered lists for ChromaDB
            filtered_docs = []
            filtered_metadata = []
            filtered_ids = []

            # Build the filtered lists
            for doc_id, (doc, meta) in unique_docs.items():
                filtered_docs.append(doc)
                filtered_metadata.append(meta)
                filtered_ids.append(doc_id)

            # If we have no metadata at all, set it to None
            final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
                None if all(m is None for m in filtered_metadata) else filtered_metadata
            )

            self.collection.upsert(
                documents=filtered_docs,
                metadatas=final_metadata,
                ids=filtered_ids,
            )
        except chromadb.errors.InvalidDimensionException as e:
            Logger(verbose=True).log(
                "error",
                "Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
                "red",
            )
            raise ValueError(
                "Embedding dimension mismatch. Make sure you're using the same embedding model "
                "across all operations with this collection."
                "Try resetting the collection using `crewai reset-memories -a`"
            ) from e
        except Exception as e:
            logging.error(
                f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
            )
            return []

    def reset(self) -> None:
        try:
            client = self._get_client()
            collection_name = (
                f"knowledge_{self.collection_name}"
                if self.collection_name
                else "knowledge"
            )
            client.delete_collection(collection_name=collection_name)
        except Exception as e:
            logging.error(
                f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
            )

    def save(self, documents: list[str]) -> None:
        try:
            client = self._get_client()
            collection_name = (
                f"knowledge_{self.collection_name}"
                if self.collection_name
                else "knowledge"
            )
            client.get_or_create_collection(collection_name=collection_name)

            rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]

            client.add_documents(
                collection_name=collection_name, documents=rag_documents
            )
        except Exception as e:
            if "dimension mismatch" in str(e).lower():
                Logger(verbose=True).log(
                    "error",
                    "Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
                    "red",
                )
                raise ValueError(
                    "Embedding dimension mismatch. Make sure you're using the same embedding model "
                    "across all operations with this collection."
                    "Try resetting the collection using `crewai reset-memories -a`"
                ) from e
            Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
            raise

    def _create_default_embedding_function(self):
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        return OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
        )

    def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
        """Set the embedding configuration for the knowledge storage.

        Args:
            embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
                If None or empty, defaults to the default embedding function.
        """
        self.embedder = (
            EmbeddingConfigurator().configure_embedder(embedder)
            if embedder
            else self._create_default_embedding_function()
        )

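For orientation, the reworked storage is driven through a shared RAG client instead of a ChromaDB collection held on the instance. A minimal usage sketch — the collection name and document text below are hypothetical, and it assumes a RAG client is available (via a passed embedder or the global default):

# Hypothetical driver for the new KnowledgeStorage; names are examples only.
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage

storage = KnowledgeStorage(collection_name="docs")
storage.save(["CrewAI supports pluggable knowledge sources."])  # upserts into "knowledge_docs"
results = storage.search(["knowledge sources"])  # new defaults: limit=5, score_threshold=0.6
for r in results:
    # each result is a SearchResult mapping: {"id", "content", "metadata", "score"}
    print(r["content"], r["score"])
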
@@ -1,12 +1,12 @@
from typing import Any, Dict, List
from crewai.rag.types import SearchResult


def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str:
def extract_knowledge_context(knowledge_snippets: list[SearchResult]) -> str:
    """Extract knowledge from the task prompt."""
    valid_snippets = [
        result["context"]
        result["content"]
        for result in knowledge_snippets
        if result and result.get("context")
        if result and result.get("content")
    ]
    snippet = "\n".join(valid_snippets)
    return f"Additional Information: {snippet}" if valid_snippets else ""

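The helper is small enough to run standalone with the renamed "content" key; a self-contained copy with made-up sample results:

def extract_knowledge_context(knowledge_snippets: list[dict]) -> str:
    """Extract knowledge from the task prompt."""
    valid_snippets = [
        result["content"]
        for result in knowledge_snippets
        if result and result.get("content")
    ]
    snippet = "\n".join(valid_snippets)
    return f"Additional Information: {snippet}" if valid_snippets else ""

snippets = [
    {"id": "1", "content": "Paris is the capital of France.", "score": 0.9},
    {"id": "2", "content": "", "score": 0.7},  # empty content is filtered out
]
print(extract_knowledge_context(snippets))
# Additional Information: Paris is the capital of France.
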
@@ -1,4 +1,6 @@
from typing import Optional, TYPE_CHECKING
from __future__ import annotations

from typing import TYPE_CHECKING

from crewai.memory import (
    EntityMemory,
@@ -19,9 +21,9 @@ class ContextualMemory:
        ltm: LongTermMemory,
        em: EntityMemory,
        exm: ExternalMemory,
        agent: Optional["Agent"] = None,
        task: Optional["Task"] = None,
    ):
        agent: Agent | None = None,
        task: Task | None = None,
    ) -> None:
        self.stm = stm
        self.ltm = ltm
        self.em = em
@@ -42,7 +44,7 @@ class ContextualMemory:
            self.exm.agent = self.agent
            self.exm.task = self.task

    def build_context_for_task(self, task, context) -> str:
    def build_context_for_task(self, task: Task, context: str) -> str:
        """
        Automatically builds a minimal, highly relevant set of contextual information
        for a given task.
@@ -52,14 +54,15 @@ class ContextualMemory:
        if query == "":
            return ""

        context = []
        context.append(self._fetch_ltm_context(task.description))
        context.append(self._fetch_stm_context(query))
        context.append(self._fetch_entity_context(query))
        context.append(self._fetch_external_context(query))
        return "\n".join(filter(None, context))
        context_parts = [
            self._fetch_ltm_context(task.description),
            self._fetch_stm_context(query),
            self._fetch_entity_context(query),
            self._fetch_external_context(query),
        ]
        return "\n".join(filter(None, context_parts))

    def _fetch_stm_context(self, query) -> str:
    def _fetch_stm_context(self, query: str) -> str:
        """
        Fetches recent relevant insights from STM related to the task's description and expected_output,
        formatted as bullet points.
@@ -70,11 +73,11 @@ class ContextualMemory:

        stm_results = self.stm.search(query)
        formatted_results = "\n".join(
            [f"- {result['context']}" for result in stm_results]
            [f"- {result['content']}" for result in stm_results]
        )
        return f"Recent Insights:\n{formatted_results}" if stm_results else ""

    def _fetch_ltm_context(self, task) -> Optional[str]:
    def _fetch_ltm_context(self, task: str) -> str | None:
        """
        Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
        formatted as bullet points.
@@ -90,14 +93,14 @@ class ContextualMemory:
        formatted_results = [
            suggestion
            for result in ltm_results
            for suggestion in result["metadata"]["suggestions"]  # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
            for suggestion in result["metadata"]["suggestions"]
        ]
        formatted_results = list(dict.fromkeys(formatted_results))
        formatted_results = "\n".join([f"- {result}" for result in formatted_results])  # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")

        return f"Historical Data:\n{formatted_results}" if ltm_results else ""

    def _fetch_entity_context(self, query) -> str:
    def _fetch_entity_context(self, query: str) -> str:
        """
        Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
        formatted as bullet points.
@@ -107,7 +110,7 @@ class ContextualMemory:

        em_results = self.em.search(query)
        formatted_results = "\n".join(
            [f"- {result['context']}" for result in em_results]  # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
            [f"- {result['content']}" for result in em_results]
        )
        return f"Entities:\n{formatted_results}" if em_results else ""

@@ -128,6 +131,6 @@ class ContextualMemory:
            return ""

        formatted_memories = "\n".join(
            f"- {result['context']}" for result in external_memories
            f"- {result['content']}" for result in external_memories
        )
        return f"External memories:\n{formatted_memories}"

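The list-plus-filter form that replaces the four append calls drops empty sections before joining; a tiny illustration with made-up values:

context_parts = [
    "Historical Data:\n- prior run succeeded",
    "",    # e.g. no recent insights
    None,  # e.g. external memory returned nothing
    "Entities:\n- ACME Corp",
]
# filter(None, ...) removes both "" and None entries.
print("\n".join(filter(None, context_parts)))
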
@@ -1,20 +1,20 @@
from typing import Any
import time
from typing import Any

from pydantic import PrivateAttr

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryQueryCompletedEvent,
    MemoryQueryFailedEvent,
    MemoryQueryStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveFailedEvent,
    MemorySaveStartedEvent,
)
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryQueryStartedEvent,
    MemoryQueryCompletedEvent,
    MemoryQueryFailedEvent,
    MemorySaveStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveFailedEvent,
)


class EntityMemory(Memory):
@@ -31,10 +31,10 @@ class EntityMemory(Memory):
        if memory_provider == "mem0":
            try:
                from crewai.memory.storage.mem0_storage import Mem0Storage
            except ImportError:
            except ImportError as e:
                raise ImportError(
                    "Mem0 is not installed. Please install it with `pip install mem0ai`."
                )
                ) from e
            config = embedder_config.get("config") if embedder_config else None
            storage = Mem0Storage(type="short_term", crew=crew, config=config)
        else:
@@ -90,23 +90,31 @@ class EntityMemory(Memory):
        saved_count = 0
        errors = []

        def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
            """Save a single item and return success status."""
            try:
                if self._memory_provider == "mem0":
                    data = f"""
                    Remember details about the following entity:
                    Name: {item.name}
                    Type: {item.type}
                    Entity Description: {item.description}
                    """
                else:
                    data = f"{item.name}({item.type}): {item.description}"

                super(EntityMemory, self).save(data, item.metadata)
                return True, None
            except Exception as e:
                return False, f"{item.name}: {e!s}"

        try:
            for item in items:
                try:
                    if self._memory_provider == "mem0":
                        data = f"""
                        Remember details about the following entity:
                        Name: {item.name}
                        Type: {item.type}
                        Entity Description: {item.description}
                        """
                    else:
                        data = f"{item.name}({item.type}): {item.description}"

                    super().save(data, item.metadata)
                success, error = save_single_item(item)
                if success:
                    saved_count += 1
                except Exception as e:
                    errors.append(f"{item.name}: {str(e)}")
                else:
                    errors.append(error)

            if is_batch:
                emit_value = f"Saved {saved_count} entities"
@@ -153,8 +161,8 @@ class EntityMemory(Memory):
    def search(
        self,
        query: str,
        limit: int = 3,
        score_threshold: float = 0.35,
        limit: int = 5,
        score_threshold: float = 0.6,
    ):
        crewai_event_bus.emit(
            self,
@@ -206,4 +214,6 @@ class EntityMemory(Memory):
        try:
            self.storage.reset()
        except Exception as e:
            raise Exception(f"An error occurred while resetting the entity memory: {e}")
            raise Exception(
                f"An error occurred while resetting the entity memory: {e}"
            ) from e

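The extracted save_single_item helper is a closure that reports (success, error) per item, so the loop tallies counts without nested try/except. A self-contained sketch of the same pattern — the Item type and in-memory store are stand-ins, not the real classes:

from dataclasses import dataclass

@dataclass
class Item:  # stand-in for EntityMemoryItem
    name: str
    type: str
    description: str

def save_single_item(item: Item, store: list[str]) -> tuple[bool, str | None]:
    """Save a single item and return success status plus an optional error."""
    try:
        store.append(f"{item.name}({item.type}): {item.description}")
        return True, None
    except Exception as e:
        return False, f"{item.name}: {e!s}"

store: list[str] = []
errors: list[str | None] = []
saved_count = 0
for item in [Item("ACME", "organization", "A fictional company")]:
    success, error = save_single_item(item, store)
    if success:
        saved_count += 1
    else:
        errors.append(error)
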
34
src/crewai/memory/external/external_memory.py
vendored
@@ -1,41 +1,41 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
import time
from typing import TYPE_CHECKING, Any

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryQueryCompletedEvent,
    MemoryQueryFailedEvent,
    MemoryQueryStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveFailedEvent,
    MemorySaveStartedEvent,
)
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.interface import Storage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryQueryStartedEvent,
    MemoryQueryCompletedEvent,
    MemoryQueryFailedEvent,
    MemorySaveStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveFailedEvent,
)

if TYPE_CHECKING:
    from crewai.memory.storage.mem0_storage import Mem0Storage


class ExternalMemory(Memory):
    def __init__(self, storage: Optional[Storage] = None, **data: Any):
    def __init__(self, storage: Storage | None = None, **data: Any):
        super().__init__(storage=storage, **data)

    @staticmethod
    def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
    def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
        from crewai.memory.storage.mem0_storage import Mem0Storage

        return Mem0Storage(type="external", crew=crew, config=config)

    @staticmethod
    def external_supported_storages() -> Dict[str, Any]:
    def external_supported_storages() -> dict[str, Any]:
        return {
            "mem0": ExternalMemory._configure_mem0,
        }

    @staticmethod
    def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
    def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
        if not embedder_config:
            raise ValueError("embedder_config is required")

@@ -52,7 +52,7 @@ class ExternalMemory(Memory):
    def save(
        self,
        value: Any,
        metadata: Optional[Dict[str, Any]] = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Saves a value into the external storage."""
        crewai_event_bus.emit(
@@ -103,8 +103,8 @@ class ExternalMemory(Memory):
    def search(
        self,
        query: str,
        limit: int = 3,
        score_threshold: float = 0.35,
        limit: int = 5,
        score_threshold: float = 0.6,
    ):
        crewai_event_bus.emit(
            self,

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional

from pydantic import BaseModel

@@ -12,8 +12,8 @@ class Memory(BaseModel):
    Base class for memory, now supporting agent tags and generic metadata.
    """

    embedder_config: Optional[Dict[str, Any]] = None
    crew: Optional[Any] = None
    embedder_config: dict[str, Any] | None = None
    crew: Any | None = None

    storage: Any
    _agent: Optional["Agent"] = None
@@ -45,7 +45,7 @@ class Memory(BaseModel):
    def save(
        self,
        value: Any,
        metadata: Optional[Dict[str, Any]] = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        metadata = metadata or {}

@@ -54,9 +54,9 @@ class Memory(BaseModel):
    def search(
        self,
        query: str,
        limit: int = 3,
        score_threshold: float = 0.35,
    ) -> List[Any]:
        limit: int = 5,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        return self.storage.search(
            query=query, limit=limit, score_threshold=score_threshold
        )

@@ -1,20 +1,20 @@
from typing import Any, Dict, Optional
import time
from typing import Any

from pydantic import PrivateAttr

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryQueryCompletedEvent,
    MemoryQueryFailedEvent,
    MemoryQueryStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveFailedEvent,
    MemorySaveStartedEvent,
)
from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryQueryStartedEvent,
    MemoryQueryCompletedEvent,
    MemoryQueryFailedEvent,
    MemorySaveStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveFailedEvent,
)


class ShortTermMemory(Memory):
@@ -26,17 +26,17 @@ class ShortTermMemory(Memory):
    MemoryItem instances.
    """

    _memory_provider: Optional[str] = PrivateAttr()
    _memory_provider: str | None = PrivateAttr()

    def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
        memory_provider = embedder_config.get("provider") if embedder_config else None
        if memory_provider == "mem0":
            try:
                from crewai.memory.storage.mem0_storage import Mem0Storage
            except ImportError:
            except ImportError as e:
                raise ImportError(
                    "Mem0 is not installed. Please install it with `pip install mem0ai`."
                )
                ) from e
            config = embedder_config.get("config") if embedder_config else None
            storage = Mem0Storage(type="short_term", crew=crew, config=config)
        else:
@@ -56,7 +56,7 @@ class ShortTermMemory(Memory):
    def save(
        self,
        value: Any,
        metadata: Optional[Dict[str, Any]] = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        crewai_event_bus.emit(
            self,
@@ -112,8 +112,8 @@ class ShortTermMemory(Memory):
    def search(
        self,
        query: str,
        limit: int = 3,
        score_threshold: float = 0.35,
        limit: int = 5,
        score_threshold: float = 0.6,
    ):
        crewai_event_bus.emit(
            self,
@@ -167,4 +167,4 @@ class ShortTermMemory(Memory):
        except Exception as e:
            raise Exception(
                f"An error occurred while resetting the short-term memory: {e}"
            )
            ) from e

@@ -1,10 +1,13 @@
import os
from typing import Any, Dict, List
import re
from collections import defaultdict
from mem0 import Memory, MemoryClient
from crewai.utilities.chromadb import sanitize_collection_name
from collections.abc import Iterable
from typing import Any

from mem0 import Memory, MemoryClient  # type: ignore[import-untyped,import-not-found]

from crewai.memory.storage.interface import Storage
from crewai.rag.chromadb.utils import _sanitize_collection_name

MAX_AGENT_ID_LENGTH_MEM0 = 255

@@ -13,6 +16,7 @@ class Mem0Storage(Storage):
    """
    Extends Storage to handle embedding and searching across entities using Mem0.
    """

    def __init__(self, type, crew=None, config=None):
        super().__init__()

@@ -28,7 +32,8 @@ class Mem0Storage(Storage):
        supported_types = {"short_term", "long_term", "entities", "external"}
        if type not in supported_types:
            raise ValueError(
                f"Invalid type '{type}' for Mem0Storage. Must be one of: {', '.join(supported_types)}"
                f"Invalid type '{type}' for Mem0Storage. "
                f"Must be one of: {', '.join(supported_types)}"
            )

    def _extract_config_values(self):
@@ -66,7 +71,8 @@ class Mem0Storage(Storage):
        - Includes user_id and agent_id if both are present.
        - Includes user_id if only user_id is present.
        - Includes agent_id if only agent_id is present.
        - Includes run_id if memory_type is 'short_term' and mem0_run_id is present.
        - Includes run_id if memory_type is 'short_term' and
          mem0_run_id is present.
        """
        filter = defaultdict(list)

@@ -86,21 +92,44 @@ class Mem0Storage(Storage):

        return filter

    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
    def save(self, value: Any, metadata: dict[str, Any]) -> None:
        def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str:
            return next(
                (
                    m.get("content", "")
                    for m in reversed(list(messages))
                    if m.get("role") == role
                ),
                "",
            )

        conversations = []
        messages = metadata.pop("messages", None)
        if messages:
            last_user = _last_content(messages, "user")
            last_assistant = _last_content(messages, "assistant")

            if user_msg := self._get_user_message(last_user):
                conversations.append({"role": "user", "content": user_msg})

            if assistant_msg := self._get_assistant_message(last_assistant):
                conversations.append({"role": "assistant", "content": assistant_msg})
        else:
            conversations.append({"role": "assistant", "content": value})

        user_id = self.config.get("user_id", "")
        assistant_message = [{"role" : "assistant","content" : value}]

        base_metadata = {
            "short_term": "short_term",
            "long_term": "long_term",
            "entities": "entity",
            "external": "external"
            "external": "external",
        }

        # Shared base params
        params: dict[str, Any] = {
            "metadata": {"type": base_metadata[self.memory_type], **metadata},
            "infer": self.infer
            "infer": self.infer,
        }

        # MemoryClient-specific overrides
@@ -119,15 +148,17 @@ class Mem0Storage(Storage):
        if agent_id := self.config.get("agent_id", self._get_agent_name()):
            params["agent_id"] = agent_id

        self.memory.add(assistant_message, **params)
        self.memory.add(conversations, **params)

    def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
    def search(
        self, query: str, limit: int = 5, score_threshold: float = 0.6
    ) -> list[Any]:
        params = {
            "query": query,
            "limit": limit,
            "version": "v2",
            "output_format": "v1.1"
        }
            "output_format": "v1.1",
        }

        if user_id := self.config.get("user_id", ""):
            params["user_id"] = user_id
@@ -148,10 +179,10 @@ class Mem0Storage(Storage):
        # automatically when the crew is created.

        params["filters"] = self._create_filter_for_search()
        params['threshold'] = score_threshold
        params["threshold"] = score_threshold

        if isinstance(self.memory, Memory):
            del params["metadata"], params["version"], params['output_format']
            del params["metadata"], params["version"], params["output_format"]
            if params.get("run_id"):
                del params["run_id"]

@@ -159,8 +190,8 @@ class Mem0Storage(Storage):

        # This makes it compatible for Contextual Memory to retrieve
        for result in results["results"]:
            result["context"] = result["memory"]

            result["content"] = result["memory"]

        return [r for r in results["results"]]

    def reset(self):
@@ -180,4 +211,19 @@ class Mem0Storage(Storage):
        agents = self.crew.agents
        agents = [self._sanitize_role(agent.role) for agent in agents]
        agents = "_".join(agents)
        return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
        return _sanitize_collection_name(
            name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
        )

    def _get_assistant_message(self, text: str) -> str:
        marker = "Final Answer:"
        if marker in text:
            return text.split(marker, 1)[1].strip()
        return text

    def _get_user_message(self, text: str) -> str:
        pattern = r"User message:\s*(.*)"
        match = re.search(pattern, text)
        if match:
            return match.group(1).strip()
        return text

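The two parsing helpers added at the bottom of Mem0Storage are self-contained enough to exercise on their own; a runnable copy with example inputs:

import re

def _get_assistant_message(text: str) -> str:
    # Everything after the first "Final Answer:" marker, if present.
    marker = "Final Answer:"
    if marker in text:
        return text.split(marker, 1)[1].strip()
    return text

def _get_user_message(text: str) -> str:
    # The text following a "User message:" prefix, if present.
    match = re.search(r"User message:\s*(.*)", text)
    if match:
        return match.group(1).strip()
    return text

print(_get_user_message("User message: What is our churn rate?"))   # "What is our churn rate?"
print(_get_assistant_message("Thought: done\nFinal Answer: 4.2%"))  # "4.2%"
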
@@ -1,17 +1,17 @@
import logging
import os
import shutil
import uuid
import traceback
import warnings
from typing import Any

from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import create_persistent_client
from crewai.rag.types import BaseRecord
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
from crewai.utilities.logger_utils import suppress_logging
import warnings


class RAGStorage(BaseRAGStorage):
@@ -20,8 +20,6 @@ class RAGStorage(BaseRAGStorage):
    search efficiency.
    """

    app: ClientAPI | None = None

    def __init__(
        self, type, allow_reset=True, embedder_config=None, crew=None, path=None
    ):
@@ -33,37 +31,25 @@ class RAGStorage(BaseRAGStorage):
        self.storage_file_name = self._build_storage_file_name(type, agents)

        self.type = type
        self._client: BaseClient | None = None

        self.allow_reset = allow_reset
        self.path = path
        self._initialize_app()

    def _set_embedder_config(self):
        configurator = EmbeddingConfigurator()
        self.embedder_config = configurator.configure_embedder(self.embedder_config)

    def _initialize_app(self):
        from chromadb.config import Settings

        # Suppress deprecation warnings from chromadb, which are not relevant to us
        # TODO: Remove this once we upgrade chromadb to at least 1.0.8.
        warnings.filterwarnings(
            "ignore",
            message=r".*'model_fields'.*is deprecated.*",
            module=r"^chromadb(\.|$)",
        )

        self._set_embedder_config()
        if self.embedder_config:
            embedding_function = get_embedding_function(self.embedder_config)
            config = ChromaDBConfig(embedding_function=embedding_function)
            self._client = create_client(config)

        self.app = create_persistent_client(
            path=self.path if self.path else self.storage_file_name,
            settings=Settings(allow_reset=self.allow_reset),
        )

        self.collection = self.app.get_or_create_collection(
            name=self.type, embedding_function=self.embedder_config
        )
        logging.info(f"Collection found or created: {self.collection}")
    def _get_client(self) -> BaseClient:
        """Get the appropriate client - instance-specific or global."""
        return self._client if self._client else get_rag_client()

    def _sanitize_role(self, role: str) -> str:
        """
@@ -85,77 +71,69 @@ class RAGStorage(BaseRAGStorage):

        return f"{base_path}/{file_name}"

    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        if not hasattr(self, "app") or not hasattr(self, "collection"):
            self._initialize_app()
    def save(self, value: Any, metadata: dict[str, Any]) -> None:
        try:
            self._generate_embedding(value, metadata)
            client = self._get_client()
            collection_name = (
                f"memory_{self.type}_{self.agents}"
                if self.agents
                else f"memory_{self.type}"
            )
            client.get_or_create_collection(collection_name=collection_name)

            document: BaseRecord = {"content": value}
            if metadata:
                document["metadata"] = metadata

            client.add_documents(collection_name=collection_name, documents=[document])
        except Exception as e:
            logging.error(f"Error during {self.type} save: {str(e)}")
            logging.error(
                f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
            )

    def search(
        self,
        query: str,
        limit: int = 3,
        filter: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Any]:
        if not hasattr(self, "app"):
            self._initialize_app()

        limit: int = 5,
        filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        try:
            with suppress_logging(
                "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
            ):
                response = self.collection.query(query_texts=query, n_results=limit)

            results = []
            for i in range(len(response["ids"][0])):
                result = {
                    "id": response["ids"][0][i],
                    "metadata": response["metadatas"][0][i],
                    "context": response["documents"][0][i],
                    "score": response["distances"][0][i],
                }
                if result["score"] >= score_threshold:
                    results.append(result)

            return results
            client = self._get_client()
            collection_name = (
                f"memory_{self.type}_{self.agents}"
                if self.agents
                else f"memory_{self.type}"
            )
            return client.search(
                collection_name=collection_name,
                query=query,
                limit=limit,
                metadata_filter=filter,
                score_threshold=score_threshold,
            )
        except Exception as e:
            logging.error(f"Error during {self.type} search: {str(e)}")
            logging.error(
                f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
            )
            return []

    def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:  # type: ignore
        if not hasattr(self, "app") or not hasattr(self, "collection"):
            self._initialize_app()

        self.collection.add(
            documents=[text],
            metadatas=[metadata or {}],
            ids=[str(uuid.uuid4())],
        )

    def reset(self) -> None:
        try:
            if self.app:
                self.app.reset()
                shutil.rmtree(f"{db_storage_path()}/{self.type}")
                self.app = None
                self.collection = None
            client = self._get_client()
            collection_name = (
                f"memory_{self.type}_{self.agents}"
                if self.agents
                else f"memory_{self.type}"
            )
            client.delete_collection(collection_name=collection_name)
        except Exception as e:
            if "attempt to write a readonly database" in str(e):
                # Ignore this specific error
            if "attempt to write a readonly database" in str(
                e
            ) or "does not exist" in str(e):
                # Ignore readonly database and collection not found errors (already reset)
                pass
            else:
                raise Exception(
                    f"An error occurred while resetting the {self.type} memory: {e}"
                )

    def _create_default_embedding_function(self):
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        return OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
        )
                ) from e

@@ -4,8 +4,9 @@ import logging
from typing import Any

from chromadb.api.types import (
    Embeddable,
    EmbeddingFunction as ChromaEmbeddingFunction,
)
from chromadb.api.types import (
    QueryResult,
)
from typing_extensions import Unpack
@@ -23,13 +24,13 @@ from crewai.rag.chromadb.utils import (
    _process_query_results,
    _sanitize_collection_name,
)
from crewai.utilities.logger_utils import suppress_logging
from crewai.rag.core.base_client import (
    BaseClient,
    BaseCollectionParams,
    BaseCollectionAddParams,
    BaseCollectionParams,
)
from crewai.rag.types import SearchResult
from crewai.utilities.logger_utils import suppress_logging


class ChromaDBClient(BaseClient):
@@ -41,21 +42,29 @@ class ChromaDBClient(BaseClient):
    Attributes:
        client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
        embedding_function: Function to generate embeddings for documents.
        default_limit: Default number of results to return in searches.
        default_score_threshold: Default minimum score for search results.
    """

    def __init__(
        self,
        client: ChromaDBClientType,
        embedding_function: ChromaEmbeddingFunction[Embeddable],
        embedding_function: ChromaEmbeddingFunction,
        default_limit: int = 5,
        default_score_threshold: float = 0.6,
    ) -> None:
        """Initialize ChromaDBClient with client and embedding function.

        Args:
            client: Pre-configured ChromaDB client instance.
            embedding_function: Embedding function for text to vector conversion.
            default_limit: Default number of results to return in searches.
            default_score_threshold: Default minimum score for search results.
        """
        self.client = client
        self.embedding_function = embedding_function
        self.default_limit = default_limit
        self.default_score_threshold = default_score_threshold

    def create_collection(
        self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -300,16 +309,18 @@ class ChromaDBClient(BaseClient):
        if not documents:
            raise ValueError("Documents list cannot be empty")

        collection = self.client.get_collection(
        collection = self.client.get_or_create_collection(
            name=_sanitize_collection_name(collection_name),
            embedding_function=self.embedding_function,
        )

        prepared = _prepare_documents_for_chromadb(documents)
        # ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
        metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
        collection.upsert(
            ids=prepared.ids,
            documents=prepared.texts,
            metadatas=prepared.metadatas,
            metadatas=metadatas,
        )

    async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
@@ -342,15 +353,17 @@ class ChromaDBClient(BaseClient):
        if not documents:
            raise ValueError("Documents list cannot be empty")

        collection = await self.client.get_collection(
        collection = await self.client.get_or_create_collection(
            name=_sanitize_collection_name(collection_name),
            embedding_function=self.embedding_function,
        )
        prepared = _prepare_documents_for_chromadb(documents)
        # ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
        metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
        await collection.upsert(
            ids=prepared.ids,
            documents=prepared.texts,
            metadatas=prepared.metadatas,
            metadatas=metadatas,
        )

    def search(
@@ -385,9 +398,14 @@ class ChromaDBClient(BaseClient):
                "Use asearch() for AsyncClientAPI."
            )

        if "limit" not in kwargs:
            kwargs["limit"] = self.default_limit
        if "score_threshold" not in kwargs:
            kwargs["score_threshold"] = self.default_score_threshold

        params = _extract_search_params(kwargs)

        collection = self.client.get_collection(
        collection = self.client.get_or_create_collection(
            name=_sanitize_collection_name(params.collection_name),
            embedding_function=self.embedding_function,
        )
@@ -443,9 +461,14 @@ class ChromaDBClient(BaseClient):
                "Use search() for ClientAPI."
            )

        if "limit" not in kwargs:
            kwargs["limit"] = self.default_limit
        if "score_threshold" not in kwargs:
            kwargs["score_threshold"] = self.default_score_threshold

        params = _extract_search_params(kwargs)

        collection = await self.client.get_collection(
        collection = await self.client.get_or_create_collection(
            name=_sanitize_collection_name(params.collection_name),
            embedding_function=self.embedding_function,
        )

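Both search() and asearch() now fall back to client-level defaults when the caller omits limit or score_threshold. The fallback logic in isolation (a sketch; the real methods mutate the incoming kwargs the same way):

# setdefault() mirrors the `if "limit" not in kwargs:` checks above.
def resolve_search_defaults(
    kwargs: dict,
    default_limit: int = 5,
    default_score_threshold: float = 0.6,
) -> dict:
    kwargs.setdefault("limit", default_limit)
    kwargs.setdefault("score_threshold", default_score_threshold)
    return kwargs

print(resolve_search_defaults({"query": "q"}))              # limit=5, score_threshold=0.6 added
print(resolve_search_defaults({"query": "q", "limit": 2}))  # caller's limit=2 wins
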
@@ -1,20 +1,20 @@
"""ChromaDB configuration model."""

import os
import warnings
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction

from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig
from chromadb.config import Settings
from pydantic.dataclasses import dataclass as pyd_dataclass

from crewai.rag.chromadb.constants import (
    DEFAULT_TENANT,
    DEFAULT_DATABASE,
    DEFAULT_STORAGE_PATH,
    DEFAULT_TENANT,
)

from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig

warnings.filterwarnings(
    "ignore",
@@ -49,7 +49,17 @@ def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
    Returns:
        Default embedding function using all-MiniLM-L6-v2 via ONNX.
    """
    return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
    from chromadb.utils.embedding_functions.openai_embedding_function import (
        OpenAIEmbeddingFunction,
    )

    return cast(
        ChromaEmbeddingFunctionWrapper,
        OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"),
            model_name="text-embedding-3-small",
        ),
    )


@pyd_dataclass(frozen=True)

@@ -2,11 +2,12 @@

import os
from hashlib import md5

import portalocker
from chromadb import PersistentClient

from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.client import ChromaDBClient
from crewai.rag.chromadb.config import ChromaDBConfig


def create_client(config: ChromaDBConfig) -> ChromaDBClient:
@@ -23,6 +24,7 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
    """

    persist_dir = config.settings.persist_directory
    os.makedirs(persist_dir, exist_ok=True)
    lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
    lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")

@@ -37,4 +39,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
    return ChromaDBClient(
        client=client,
        embedding_function=config.embedding_function,
        default_limit=config.limit,
        default_score_threshold=config.score_threshold,
    )

@@ -3,27 +3,28 @@

from collections.abc import Mapping
from typing import Any, NamedTuple

from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from chromadb.api import ClientAPI, AsyncClientAPI
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.configuration import CollectionConfigurationInterface
from chromadb.api.types import (
    CollectionMetadata,
    DataLoader,
    Embeddable,
    EmbeddingFunction as ChromaEmbeddingFunction,
    Include,
    Loadable,
    Where,
    WhereDocument,
)
from chromadb.api.types import (
    EmbeddingFunction as ChromaEmbeddingFunction,
)
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams

ChromaDBClientType = ClientAPI | AsyncClientAPI


class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]):
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
    """Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""

    @classmethod
@@ -44,7 +45,7 @@ class PreparedDocuments(NamedTuple):
    Attributes:
        ids: List of document IDs
        texts: List of document texts
        metadatas: List of document metadata mappings
        metadatas: List of document metadata mappings (empty dict for no metadata)
    """

    ids: list[str]
@@ -85,7 +86,7 @@ class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):

    configuration: CollectionConfigurationInterface
    metadata: CollectionMetadata
    embedding_function: ChromaEmbeddingFunction[Embeddable]
    embedding_function: ChromaEmbeddingFunction
    data_loader: DataLoader[Loadable]
    get_or_create: bool


@@ -5,13 +5,14 @@ from collections.abc import Mapping
from typing import Literal, TypeGuard, cast

from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
    Include,
    IncludeEnum,
    QueryResult,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection

from crewai.rag.chromadb.constants import (
    DEFAULT_COLLECTION,
    INVALID_CHARS_PATTERN,
@@ -78,7 +79,7 @@ def _prepare_documents_for_chromadb(
        metadata = doc.get("metadata")
        if metadata:
            if isinstance(metadata, list):
                metadatas.append(metadata[0] if metadata else {})
                metadatas.append(metadata[0] if metadata and metadata[0] else {})
            else:
                metadatas.append(metadata)
        else:
@@ -154,7 +155,7 @@ def _convert_chromadb_results_to_search_results(
    """
    search_results: list[SearchResult] = []

    include_strings = [item.value for item in include]
    include_strings = [item.value for item in include] if include else []

    ids = results["ids"][0] if results.get("ids") else []

@@ -188,7 +189,9 @@ def _convert_chromadb_results_to_search_results(
        result: SearchResult = {
            "id": doc_id,
            "content": documents[i] if documents and i < len(documents) else "",
            "metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {},
            "metadata": dict(metadatas[i])
            if metadatas and i < len(metadatas) and metadatas[i] is not None
            else {},
            "score": score,
        }
        search_results.append(result)
@@ -271,7 +274,7 @@ def _sanitize_collection_name(
        sanitized = sanitized[:-1] + "z"

    if len(sanitized) < MIN_COLLECTION_LENGTH:
        sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
        sanitized += "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
    if len(sanitized) > max_collection_length:
        sanitized = sanitized[:max_collection_length]
        if not sanitized[-1].isalnum():

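The empty-metadata guard used by add_documents is easiest to see with concrete values: ChromaDB rejects empty metadata dicts, so the whole metadatas argument collapses to None when no document carries metadata (a small illustration with made-up lists):

prepared_metadatas: list[dict] = [{}, {}, {}]
metadatas = prepared_metadatas if any(m for m in prepared_metadatas) else None
assert metadatas is None  # all empty -> pass None to upsert

prepared_metadatas = [{}, {"source": "faq.md"}, {}]
metadatas = prepared_metadatas if any(m for m in prepared_metadatas) else None
assert metadatas == [{}, {"source": "faq.md"}, {}]  # any non-empty -> keep the list
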
@@ -14,3 +14,5 @@ class BaseRagConfig:

    provider: SupportedProvider = field(init=False)
    embedding_function: Any | None = field(default=None)
    limit: int = field(default=5)
    score_threshold: float = field(default=0.6)

@@ -1,15 +1,15 @@
"""Protocol for vector database client implementations."""

from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable, Annotated
from typing_extensions import Unpack, Required, TypedDict
from typing import Annotated, Any, Protocol, runtime_checkable

from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

from typing_extensions import Required, TypedDict, Unpack

from crewai.rag.types import (
    EmbeddingFunction,
    BaseRecord,
    EmbeddingFunction,
    SearchResult,
)

@@ -57,7 +57,7 @@ class BaseCollectionSearchParams(BaseCollectionParams, total=False):

    query: Required[str]
    limit: int
    metadata_filter: dict[str, Any]
    metadata_filter: dict[str, Any] | None
    score_threshold: float



@@ -10,8 +10,8 @@ from chromadb.utils.embedding_functions.cohere_embedding_function import (
    CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
    GooglePalmEmbeddingFunction,
    GoogleGenerativeAiEmbeddingFunction,
    GooglePalmEmbeddingFunction,
    GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
@@ -60,7 +60,7 @@ def get_embedding_function(
        EmbeddingFunction instance ready for use with ChromaDB

    Supported providers:
        - openai: OpenAI embeddings (default)
        - openai: OpenAI embeddings
        - cohere: Cohere embeddings
        - ollama: Ollama local embeddings
        - huggingface: HuggingFace embeddings
@@ -77,7 +77,7 @@ def get_embedding_function(
        - onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)

    Examples:
        # Use default OpenAI with retry logic
        # Use default OpenAI embedding
        >>> embedder = get_embedding_function()

        # Use Cohere with dict

@@ -6,8 +6,8 @@ from typing_extensions import Unpack

from crewai.rag.core.base_client import (
    BaseClient,
    BaseCollectionParams,
    BaseCollectionAddParams,
    BaseCollectionParams,
    BaseCollectionSearchParams,
)
from crewai.rag.core.exceptions import ClientMethodMismatchError
@@ -18,11 +18,11 @@ from crewai.rag.qdrant.types import (
    QdrantCollectionCreateParams,
)
from crewai.rag.qdrant.utils import (
    _create_point_from_document,
    _get_collection_params,
    _is_async_client,
    _is_async_embedding_function,
    _is_sync_client,
    _create_point_from_document,
    _get_collection_params,
    _prepare_search_params,
    _process_search_results,
)
@@ -38,21 +38,29 @@ class QdrantClient(BaseClient):
    Attributes:
        client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
        embedding_function: Function to generate embeddings for documents.
        default_limit: Default number of results to return in searches.
        default_score_threshold: Default minimum score for search results.
    """

    def __init__(
        self,
        client: QdrantClientType,
        embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
        default_limit: int = 5,
        default_score_threshold: float = 0.6,
    ) -> None:
        """Initialize QdrantClient with client and embedding function.

        Args:
            client: Pre-configured Qdrant client instance.
            embedding_function: Embedding function for text to vector conversion.
            default_limit: Default number of results to return in searches.
            default_score_threshold: Default minimum score for search results.
        """
        self.client = client
        self.embedding_function = embedding_function
        self.default_limit = default_limit
        self.default_score_threshold = default_score_threshold

    def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
        """Create a new collection in Qdrant.
@@ -332,9 +340,9 @@ class QdrantClient(BaseClient):

        collection_name = kwargs["collection_name"]
        query = kwargs["query"]
        limit = kwargs.get("limit", 10)
        limit = kwargs.get("limit", self.default_limit)
        metadata_filter = kwargs.get("metadata_filter")
        score_threshold = kwargs.get("score_threshold")
        score_threshold = kwargs.get("score_threshold", self.default_score_threshold)

        if not self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist")
@@ -387,9 +395,9 @@ class QdrantClient(BaseClient):

        collection_name = kwargs["collection_name"]
        query = kwargs["query"]
        limit = kwargs.get("limit", 10)
        limit = kwargs.get("limit", self.default_limit)
        metadata_filter = kwargs.get("metadata_filter")
        score_threshold = kwargs.get("score_threshold")
        score_threshold = kwargs.get("score_threshold", self.default_score_threshold)

        if not await self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist")

@@ -1,6 +1,7 @@
"""Factory functions for creating Qdrant clients from configuration."""

from qdrant_client import QdrantClient as SyncQdrantClientBase

from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig

@@ -17,5 +18,8 @@ def create_client(config: QdrantConfig) -> QdrantClient:

    qdrant_client = SyncQdrantClientBase(**config.options)
    return QdrantClient(
        client=qdrant_client, embedding_function=config.embedding_function
        client=qdrant_client,
        embedding_function=config.embedding_function,
        default_limit=config.limit,
        default_score_threshold=config.score_threshold,
    )

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any


class BaseRAGStorage(ABC):
@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
        self,
        type: str,
        allow_reset: bool = True,
        embedder_config: Optional[Dict[str, Any]] = None,
        embedder_config: dict[str, Any] | None = None,
        crew: Any = None,
    ):
        self.type = type
@@ -32,45 +32,21 @@ class BaseRAGStorage(ABC):
    @abstractmethod
    def _sanitize_role(self, role: str) -> str:
        """Sanitizes agent roles to ensure valid directory names."""
        pass

    @abstractmethod
    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
    def save(self, value: Any, metadata: dict[str, Any]) -> None:
        """Save a value with metadata to the storage."""
        pass

    @abstractmethod
    def search(
        self,
        query: str,
        limit: int = 3,
        filter: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Any]:
        limit: int = 5,
        filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        """Search for entries in the storage."""
        pass

    @abstractmethod
    def reset(self) -> None:
        """Reset the storage."""
        pass

    @abstractmethod
    def _generate_embedding(
        self, text: str, metadata: Optional[Dict[str, Any]] = None
    ) -> Any:
        """Generate an embedding for the given text and metadata."""
        pass

    @abstractmethod
    def _initialize_app(self):
        """Initialize the vector db."""
        pass

    def setup_config(self, config: Dict[str, Any]):
        """Setup the config of the storage."""
        pass

    def initialize_client(self):
        """Initialize the client of the storage. This should setup the app and the db collection"""
        pass

@@ -1,83 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import portalocker
|
||||
from chromadb import PersistentClient
|
||||
from hashlib import md5
|
||||
from typing import Optional
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
MIN_COLLECTION_LENGTH = 3
|
||||
MAX_COLLECTION_LENGTH = 63
|
||||
DEFAULT_COLLECTION = "default_collection"
|
||||
|
||||
# Compiled regex patterns for better performance
|
||||
INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
|
||||
IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
|
||||
|
||||
|
||||
def is_ipv4_pattern(name: str) -> bool:
|
||||
"""
|
||||
Check if a string matches an IPv4 address pattern.
|
||||
|
||||
Args:
|
||||
name: The string to check
|
||||
|
||||
Returns:
|
||||
True if the string matches an IPv4 pattern, False otherwise
|
||||
"""
|
||||
return bool(IPV4_PATTERN.match(name))
|
||||
|
||||
|
||||
def sanitize_collection_name(
|
||||
name: Optional[str], max_collection_length: int = MAX_COLLECTION_LENGTH
|
||||
) -> str:
|
||||
"""
|
||||
Sanitize a collection name to meet ChromaDB requirements:
|
||||
1. 3-63 characters long
|
||||
2. Starts and ends with alphanumeric character
|
||||
3. Contains only alphanumeric characters, underscores, or hyphens
|
||||
4. No consecutive periods
|
||||
5. Not a valid IPv4 address
|
||||
|
||||
Args:
|
||||
name: The original collection name to sanitize
|
||||
|
||||
Returns:
|
||||
A sanitized collection name that meets ChromaDB requirements
|
||||
"""
|
||||
if not name:
|
||||
return DEFAULT_COLLECTION
|
||||
|
||||
if is_ipv4_pattern(name):
|
||||
name = f"ip_{name}"
|
||||
|
||||
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
|
||||
|
||||
if not sanitized[0].isalnum():
|
||||
sanitized = "a" + sanitized
|
||||
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
if len(sanitized) < MIN_COLLECTION_LENGTH:
|
||||
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
|
||||
if len(sanitized) > max_collection_length:
|
||||
sanitized = sanitized[:max_collection_length]
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
return sanitized
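[Editor's note: worked examples of the rules above, traced by hand through the function; the file itself is deleted by this diff, so these are documentation of the removed behavior, not new tests.]

sanitize_collection_name(None)        # falsy input      -> "default_collection"
sanitize_collection_name("10.0.0.1")  # IPv4, prefixed   -> "ip_10_0_0_1"
sanitize_collection_name("ab")        # padded to min 3  -> "abx"
sanitize_collection_name("-crew-")    # edge chars fixed -> "a-crewz"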


def create_persistent_client(path: str, **kwargs):
"""
Creates a persistent client for ChromaDB with a lock file to prevent
concurrent creations. Works for both multi-threads and multi-processes
environments.
"""
lock_id = md5(path.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(db_storage_path(), f"chromadb-{lock_id}.lock")
with portalocker.Lock(lockfile):
client = PersistentClient(path=path, **kwargs)

return client
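[Editor's note: the deleted helper above is a standard file-lock guard around one-time client construction. The same pattern, generalized, looks like this sketch; it is offered for reference since the diff removes the helper without an in-file replacement.]

import portalocker  # cross-platform file locking, as used above


def with_file_lock(lockfile: str, factory, *args, **kwargs):
    """Run `factory` while holding an exclusive lock file, so concurrent
    processes or threads cannot race on a shared initialization step."""
    with portalocker.Lock(lockfile):
        return factory(*args, **kwargs)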
@@ -9,19 +9,19 @@ import pytest
from crewai import Agent, Crew, Task
from crewai.agents.cache import CacheHandler
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.llm import LLM
from crewai.process import Process
from crewai.tools import tool
from crewai.tools.tool_calling import InstructorToolCalling
from crewai.tools.tool_usage import ToolUsage
from crewai.utilities import RPMController
from crewai.utilities.errors import AgentRepositoryError
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.process import Process


def test_agent_llm_creation_with_env_vars():
@@ -445,7 +445,7 @@ def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_powered_by_new_o_model_family_that_uses_tool():
@tool
def comapny_customer_data() -> float:
def comapny_customer_data() -> str:
"""Useful for getting customer related data."""
return "The company has 42 customers"

@@ -500,6 +500,15 @@ def test_agent_custom_max_iterations():

@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_repeated_tool_usage(capsys):
"""Test that agents handle repeated tool usage appropriately.

Notes:
Investigate whether to pin down the specific execution flow by examining
src/crewai/agents/crew_agent_executor.py:177-186 (max iterations check)
and src/crewai/tools/tool_usage.py:152-157 (repeated usage detection)
to ensure deterministic behavior.
"""

@tool
def get_final_answer() -> float:
"""Get the final answer but don't give it yet, just re-use this tool non-stop."""
@@ -527,41 +536,15 @@ def test_agent_repeated_tool_usage(capsys):
)

captured = capsys.readouterr()
output = (
captured.out.replace("\n", " ")
.replace(" ", " ")
.strip()
.replace("╭", "")
.replace("╮", "")
.replace("╯", "")
.replace("╰", "")
.replace("│", "")
.replace("─", "")
.replace("[", "")
.replace("]", "")
.replace("bold", "")
.replace("blue", "")
.replace("yellow", "")
.replace("green", "")
.replace("red", "")
.replace("dim", "")
.replace("🤖", "")
.replace("🔧", "")
.replace("✅", "")
.replace("\x1b[93m", "")
.replace("\x1b[00m", "")
.replace("\\", "")
.replace('"', "")
.replace("'", "")
)
output_lower = captured.out.lower()

# Look for the message in the normalized output, handling the apostrophe difference
expected_message = (
"I tried reusing the same input, I must stop using this action input."
has_repeated_usage_message = "tried reusing the same input" in output_lower
has_max_iterations = "maximum iterations reached" in output_lower
has_final_answer = "final answer" in output_lower or "42" in captured.out

assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
)
assert (
expected_message in output
), f"Expected message not found in output. Output was: {output}"


@pytest.mark.vcr(filter_headers=["authorization"])
@@ -602,9 +585,9 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
has_max_iterations = "maximum iterations reached" in output_lower
has_final_answer = "final answer" in output_lower or "42" in captured.out

assert (
has_repeated_usage_message or (has_max_iterations and has_final_answer)
), f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
)


@pytest.mark.vcr(filter_headers=["authorization"])
@@ -880,7 +863,7 @@ def test_agent_step_callback():
with patch.object(StepCallback, "callback") as callback:

@tool
def learn_about_AI() -> str:
def learn_about_ai() -> str:
"""Useful for when you need to learn about AI to write an paragraph about it."""
return "AI is a very broad field."

@@ -888,7 +871,7 @@ def test_agent_step_callback():
role="test role",
goal="test goal",
backstory="test backstory",
tools=[learn_about_AI],
tools=[learn_about_ai],
step_callback=StepCallback().callback,
)

@@ -910,7 +893,7 @@ def test_agent_function_calling_llm():
llm = "gpt-4o"

@tool
def learn_about_AI() -> str:
def learn_about_ai() -> str:
"""Useful for when you need to learn about AI to write an paragraph about it."""
return "AI is a very broad field."

@@ -918,7 +901,7 @@ def test_agent_function_calling_llm():
role="test role",
goal="test goal",
backstory="test backstory",
tools=[learn_about_AI],
tools=[learn_about_ai],
llm="gpt-4o",
max_iter=2,
function_calling_llm=llm,
@@ -1356,7 +1339,7 @@ def test_agent_training_handler(crew_training_handler):
verbose=True,
)
crew_training_handler().load.return_value = {
f"{str(agent.id)}": {"0": {"human_feedback": "good"}}
f"{agent.id!s}": {"0": {"human_feedback": "good"}}
}

result = agent._training_handler(task_prompt=task_prompt)
@@ -1473,7 +1456,7 @@ def test_agent_with_custom_stop_words():
)

assert isinstance(agent.llm, LLM)
assert set(agent.llm.stop) == set(stop_words + ["\nObservation:"])
assert set(agent.llm.stop) == set([*stop_words, "\nObservation:"])
assert all(word in agent.llm.stop for word in stop_words)
assert "\nObservation:" in agent.llm.stop

@@ -1530,7 +1513,7 @@ def test_llm_call_with_error():
llm = LLM(model="non-existent-model")
messages = [{"role": "user", "content": "This should fail"}]

with pytest.raises(Exception):
with pytest.raises(Exception):  # noqa: B017
llm.call(messages)


@@ -1830,11 +1813,11 @@ def test_agent_execute_task_with_ollama():
def test_agent_with_knowledge_sources():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
with patch("crewai.knowledge") as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
with patch("crewai.knowledge") as mock_knowledge:
mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.search.return_value = [{"content": content}]
MockKnowledge.add_sources.return_value = [string_source]
mock_knowledge.add_sources.return_value = [string_source]

agent = Agent(
role="Information Agent",
@@ -1863,12 +1846,25 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig(results_limit=10, score_threshold=0.5)
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
with (
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None

mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None

mock_base_knowledge_storage.return_value = mock_storage_instance

with patch.object(Knowledge, "query") as mock_knowledge_query:
agent = Agent(
role="Information Agent",
@@ -1898,15 +1894,27 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_defau
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig()
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]

with (
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None

mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None

mock_base_knowledge_storage.return_value = mock_storage_instance

with patch.object(Knowledge, "query") as mock_knowledge_query:
string_source = StringKnowledgeSource(content=content)
knowledge_config = KnowledgeConfig()
agent = Agent(
role="Information Agent",
goal="Provide information based on knowledge sources",
@@ -1935,10 +1943,16 @@ def test_agent_with_knowledge_sources_extensive_role():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)

with patch("crewai.knowledge") as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
with (
patch("crewai.knowledge") as mock_knowledge,
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage.save"
) as mock_save,
):
mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]
mock_save.return_value = None

agent = Agent(
role="Information Agent with extensive role description that is longer than 80 characters",
@@ -1968,8 +1982,8 @@ def test_agent_with_knowledge_sources_works_with_copy():
with patch(
"crewai.knowledge.source.base_knowledge_source.BaseKnowledgeSource",
autospec=True,
) as MockKnowledgeSource:
mock_knowledge_source_instance = MockKnowledgeSource.return_value
) as mock_knowledge_source:
mock_knowledge_source_instance = mock_knowledge_source.return_value
mock_knowledge_source_instance.__class__ = BaseKnowledgeSource
mock_knowledge_source_instance.sources = [string_source]

@@ -1983,9 +1997,9 @@ def test_agent_with_knowledge_sources_works_with_copy():

with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as MockKnowledgeStorage:
mock_knowledge_storage = MockKnowledgeStorage.return_value
agent.knowledge_storage = mock_knowledge_storage
) as mock_knowledge_storage:
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
agent.knowledge_storage = mock_knowledge_storage_instance

agent_copy = agent.copy()

@@ -2004,11 +2018,30 @@ def test_agent_with_knowledge_sources_generate_search_query():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)

with patch("crewai.knowledge") as MockKnowledge:
mock_knowledge_instance = MockKnowledge.return_value
with (
patch("crewai.knowledge") as mock_knowledge,
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
):
mock_knowledge_instance = mock_knowledge.return_value
mock_knowledge_instance.sources = [string_source]
mock_knowledge_instance.query.return_value = [{"content": content}]

mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None

mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None

mock_base_knowledge_storage.return_value = mock_storage_instance

agent = Agent(
role="Information Agent with extensive role description that is longer than 80 characters",
goal="Provide information based on knowledge sources",
@@ -2270,7 +2303,26 @@ def test_get_knowledge_search_query():
i18n = I18N()
task_prompt = task.prompt()

with patch.object(agent, "_get_knowledge_search_query") as mock_get_query:
with (
patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage,
patch(
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
) as mock_base_knowledge_storage,
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
patch.object(agent, "_get_knowledge_search_query") as mock_get_query,
):
mock_storage_instance = mock_knowledge_storage.return_value
mock_storage_instance.sources = [string_source]
mock_storage_instance.query.return_value = [{"content": content}]
mock_storage_instance.save.return_value = None

mock_chromadb_instance = mock_chromadb.return_value
mock_chromadb_instance.add_documents.return_value = None

mock_base_knowledge_storage.return_value = mock_storage_instance

mock_get_query.return_value = "Capital of France"

crew = Crew(agents=[agent], tasks=[task])
@@ -2312,9 +2364,9 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
# Mock embedchain initialization to prevent race conditions in parallel CI execution
with patch("embedchain.client.Client.setup"):
from crewai_tools import (
SerperDevTool,
FileReadTool,
EnterpriseActionTool,
FileReadTool,
SerperDevTool,
)

mock_get_response = MagicMock()
@@ -2347,7 +2399,7 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
tool_action = EnterpriseActionTool(
name="test_name",
description="test_description",
enterprise_action_token="test_token",
enterprise_action_token="test_token",  # noqa: S106
action_name="test_action_name",
action_schema={"test": "test"},
)

@@ -0,0 +1,130 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
personal goal is: Test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
the world\n\nThis is the expected criteria for your final answer: hello world\nyou
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
This is VERY important to you, use the tools available and give your best Final
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '825'
content-type:
- application/json
cookie:
- _cfuvid=NaXWifUGChHp6Ap1mvfMrNzmO4HdzddrqXkSR9T.hYo-1754508545647-0.0.1.1-604800000
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.93.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.93.0
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.9
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFLBbtswDL37Kzid4yFx46bxbVixtsfssB22wlAl2lEri5okJ+uK/Psg
OY3dtQV2MWA+vqf3SD5lAExJVgETWx5EZ3X++SrcY/Hrcec3l+SKP5frm/16Yx92m6/f9mwWGXR3
jyI8sz4K6qzGoMgMsHDIA0bVxaq8WCzPyrN5AjqSqCOttSFfUt4po/JiXizz+SpfXBzZW1ICPavg
RwYA8JS+0aeR+JtVkLRSpUPveYusOjUBMEc6Vhj3XvnATWCzERRkAppk/QYM7UFwA63aIXBoo23g
xu/RAfw0X5ThGj6l/wquUWuawXdyWn6YSjpses9jLNNrPQG4MRR4HEsKc3tEDif7mlrr6M7/Q2WN
Mspva4fck4lWfSDLEnrIAG7TmPoXyZl11NlQB3rA9NyiXA16bNzOFD2CgQLXk/qqmL2hV0sMXGk/
GTQTXGxRjtRxK7yXiiZANkn92s1b2kNyZdr/kR8BIdAGlLV1KJV4mXhscxiP972205STYebR7ZTA
Oih0cRMSG97r4aSYf/QBu7pRpkVnnRruqrF1eT7nzTmW5Zplh+wvAAAA//8DAGKunMhlAwAA
headers:
CF-RAY:
- 980b99a73c1c22c6-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 17 Sep 2025 21:12:11 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=Ahwkw3J9CDiluZudRgDmybz4FO07eXLz2MQDtkgfct4-1758143531-1.0.1.1-_3e8agfTZW.FPpRMLb1A2nET4OHQEGKNZeGeWT8LIiuSi8R2HWsGsJyueUyzYBYnfHqsfBUO16K1.TkEo2XiqVCaIi6pymeeQxwtXFF1wj8;
path=/; expires=Wed, 17-Sep-25 21:42:11 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=iHqLoc_2sNQLMyzfGCLtGol8vf1Y44xirzQJUuUF_TI-1758143531242-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '419'
openai-project:
- proj_xitITlrFeen7zjNSzML82h9x
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '609'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-tokens:
- '150000000'
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-project-tokens:
- '149999827'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999830'
x-ratelimit-reset-project-tokens:
- 0s
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_ece5f999e09e4c189d38e5bc08b2fad9
status:
code: 200
message: OK
version: 1
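[Editor's note: the three new YAML files in this diff are VCR-style cassettes: recorded HTTP request/response pairs that the @pytest.mark.vcr-decorated tests replay instead of calling the OpenAI API. A sketch of how such a cassette is consumed, with an illustrative test name:]

import pytest

@pytest.mark.vcr(filter_headers=["authorization"])  # same marker used in the test diff above
def test_hello_world_replays_cassette():
    # Any HTTP call made inside is intercepted and answered from the matching
    # cassette, so the test runs offline and deterministically.
    ...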
@@ -0,0 +1,128 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
personal goal is: Test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
the world\n\nThis is the expected criteria for your final answer: hello world\nyou
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
This is VERY important to you, use the tools available and give your best Final
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '825'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.93.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.93.0
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.9
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFJNj9MwEL3nV4x8blBSmrabG0ICViqCCydYRbPOJDHreIztbIFV/zty
0m1SPiQukTJv3vN7M/OUAAhVixKE7DDI3ur09dvgfa8P/FO+e/9wa/aHb4cPH9viEw5yK1aRwfdf
SYZn1gvJvdUUFJsJlo4wUFTNd8U+32zybD0CPdekI621Id1w2iuj0nW23qTZLs33Z3bHSpIXJXxO
AACexm/0aWr6LkrIVs+VnrzHlkR5aQIQjnWsCPRe+YAmiNUMSjaBzGj9FgwfQaKBVj0SILTRNqDx
R3IAX8wbZVDDq/G/hI60Zjiy0/VS0FEzeIyhzKD1AkBjOGAcyhjl7oycLuY1t9bxvf+NKhpllO8q
R+jZRKM+sBUjekoA7sYhDVe5hXXc21AFfqDxubzYTXpi3s0CfXkGAwfUi/ruPNprvaqmgEr7xZiF
RNlRPVPnneBQK14AySL1n27+pj0lV6b9H/kZkJJsoLqyjmolrxPPbY7i6f6r7TLl0bDw5B6VpCoo
cnETNTU46OmghP/hA/VVo0xLzjo1XVVjq2KbYbOlorgRySn5BQAA//8DALxsmCBjAwAA
headers:
CF-RAY:
- 980ba79a4ab5f555-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 17 Sep 2025 21:21:42 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=aMMf0fLckKHz0BLW_2lATxD.7R61uYo1ZVW8aeFbruA-1758144102-1.0.1.1-6EKM3UxpdczoiQ6VpPpqqVnY7ftnXndFRWE4vyTzVcy.CQ4N539D97Wh8Ye9EUAvpUuukhW.r5MznkXq4tPXgCCmEv44RvVz2GBAz_e31h8;
path=/; expires=Wed, 17-Sep-25 21:51:42 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=VqrtvU8.QdEHc4.1XXUVmccaCcoj_CiNfI2zhKJoGRs-1758144102566-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '308'
openai-project:
- proj_xitITlrFeen7zjNSzML82h9x
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '620'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-tokens:
- '150000000'
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-project-tokens:
- '149999827'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999830'
x-ratelimit-reset-project-tokens:
- 0s
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_fa896433021140238115972280c05651
status:
code: 200
message: OK
version: 1
@@ -0,0 +1,127 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
personal goal is: Test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Test task\n\nThis
is the expected criteria for your final answer: test output\nyou MUST return
the actual complete content as the final answer, not a summary.\n\nBegin! This
is VERY important to you, use the tools available and give your best Final Answer,
your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop": ["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '812'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.93.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.93.0
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.9
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFLLbtswELzrKxY8W4WV+JHoVgR95NJD4UvaBgJDrSS2FJclV3bSwP9e
kHYsuU2BXghwZ2c4s8vnDEDoWpQgVCdZ9c7kNx+4v7vb7jafnrrPX25/vtObX48f1Q31m+uFmEUG
PXxHxS+sN4p6Z5A12QOsPErGqFqsl1fF4nJdzBPQU40m0lrH+YLyXludX8wvFvl8nRdXR3ZHWmEQ
JXzNAACe0xl92hofRQlJK1V6DEG2KMpTE4DwZGJFyBB0YGlZzEZQkWW0yfotWNqBkhZavUWQ0Ebb
IG3YoQf4Zt9rKw28TfcSNhgYaGA3nAl6bIYgYyg7GDMBpLXEMg4lRbk/IvuTeUOt8/QQ/qCKRlsd
usqjDGSj0cDkREL3GcB9GtJwlls4T73jiukHpueK5eKgJ8bdTNDLI8jE0kzqq/XsFb2qRpbahMmY
hZKqw3qkjjuRQ61pAmST1H+7eU37kFzb9n/kR0ApdIx15TzWWp0nHts8xq/7r7bTlJNhEdBvtcKK
Nfq4iRobOZjD/kV4Cox91WjbondeH35V46rlai6bFS6X1yLbZ78BAAD//wMAZdfoWWMDAAA=
headers:
CF-RAY:
- 980b9e0c5fa516a0-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 17 Sep 2025 21:15:11 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=w6UZxbAZgYg9EFkKPfrSbMK97MB4jfs7YyvcEmgkvak-1758143711-1.0.1.1-j7YC1nvoMKxYK0T.5G2XDF6TXUCPu_HUs4YO9v65r3NHQFIcOaHbQXX4vqabSgynL2tZy23pbZgD8Cdmxhdw9dp4zkAXhU.imP43_pw4dSE;
path=/; expires=Wed, 17-Sep-25 21:45:11 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=ij9Q8tB7sj2GczANlJ7gbXVjj6hMhz1iVb6oGHuRYu8-1758143711202-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '462'
openai-project:
- proj_xitITlrFeen7zjNSzML82h9x
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '665'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-tokens:
- '150000000'
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-project-tokens:
- '149999830'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999830'
x-ratelimit-reset-project-tokens:
- 0s
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_04536db97c8c4768a200e38c1368c176
status:
code: 200
message: OK
version: 1
@@ -1,7 +1,6 @@
"""Test Knowledge creation and querying functionality."""

from pathlib import Path
from typing import List, Union
from unittest.mock import patch

import pytest
@@ -23,7 +22,7 @@ def mock_vector_db():
instance = mock.return_value
instance.query.return_value = [
{
"context": "Brandon's favorite color is blue and he likes Mexican food.",
"content": "Brandon's favorite color is blue and he likes Mexican food.",
"score": 0.9,
}
]
@@ -44,13 +43,13 @@ def test_single_short_string(mock_vector_db):
content=content, metadata={"preference": "personal"}
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite color?"
results = mock_vector_db.query(query)

# Assert that the results contain the expected information
assert any("blue" in result["context"].lower() for result in results)
assert any("blue" in result["content"].lower() for result in results)
# Verify the mock was called
mock_vector_db.query.assert_called_once()

@@ -84,14 +83,14 @@ def test_single_2k_character_string(mock_vector_db):
content=content, metadata={"preference": "personal"}
)
mock_vector_db.sources = [string_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]

# Perform a query
query = "What is Brandon's favorite movie?"
results = mock_vector_db.query(query)

# Assert that the results contain the expected information
assert any("inception" in result["context"].lower() for result in results)
assert any("inception" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()


@@ -109,7 +108,7 @@ def test_multiple_short_strings(mock_vector_db):

# Mock the vector db query response
mock_vector_db.query.return_value = [
{"context": "Brandon has a dog named Max.", "score": 0.9}
{"content": "Brandon has a dog named Max.", "score": 0.9}
]

mock_vector_db.sources = string_sources
@@ -119,7 +118,7 @@ def test_multiple_short_strings(mock_vector_db):
results = mock_vector_db.query(query)

# Assert that the correct information is retrieved
assert any("max" in result["context"].lower() for result in results)
assert any("max" in result["content"].lower() for result in results)
# Verify the mock was called
mock_vector_db.query.assert_called_once()

@@ -180,7 +179,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
]

mock_vector_db.sources = string_sources
mock_vector_db.query.return_value = [{"context": contents[1], "score": 0.9}]
mock_vector_db.query.return_value = [{"content": contents[1], "score": 0.9}]

# Perform a query
query = "What is Brandon's favorite book?"
@@ -188,7 +187,7 @@ def test_multiple_2k_character_strings(mock_vector_db):

# Assert that the correct information is retrieved
assert any(
"the hitchhiker's guide to the galaxy" in result["context"].lower()
"the hitchhiker's guide to the galaxy" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -205,13 +204,13 @@ def test_single_short_file(mock_vector_db, tmpdir):
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What sport does Brandon like?"
results = mock_vector_db.query(query)

# Assert that the results contain the expected information
assert any("basketball" in result["context"].lower() for result in results)
assert any("basketball" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()


@@ -247,13 +246,13 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
file_paths=[file_path], metadata={"preference": "personal"}
)
mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
# Perform a query
query = "What is Brandon's favorite movie?"
results = mock_vector_db.query(query)

# Assert that the results contain the expected information
assert any("inception" in result["context"].lower() for result in results)
assert any("inception" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()


@@ -286,13 +285,13 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
]
mock_vector_db.sources = file_sources
mock_vector_db.query.return_value = [
{"context": "Brandon lives in New York.", "score": 0.9}
{"content": "Brandon lives in New York.", "score": 0.9}
]
# Perform a query
query = "What city does he reside in?"
results = mock_vector_db.query(query)
# Assert that the correct information is retrieved
assert any("new york" in result["context"].lower() for result in results)
assert any("new york" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()


@@ -360,7 +359,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
mock_vector_db.sources = file_sources
mock_vector_db.query.return_value = [
{
"context": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
"content": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
"score": 0.9,
}
]
@@ -370,7 +369,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):

# Assert that the correct information is retrieved
assert any(
"the hitchhiker's guide to the galaxy" in result["context"].lower()
"the hitchhiker's guide to the galaxy" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -407,14 +406,14 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):

# Combine string and file sources
mock_vector_db.sources = string_sources + file_sources
mock_vector_db.query.return_value = [{"context": file_contents[1], "score": 0.9}]
mock_vector_db.query.return_value = [{"content": file_contents[1], "score": 0.9}]

# Perform a query
query = "What is Brandon's favorite book?"
results = mock_vector_db.query(query)

# Assert that the correct information is retrieved
assert any("the alchemist" in result["context"].lower() for result in results)
assert any("the alchemist" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()


@@ -430,7 +429,7 @@ def test_pdf_knowledge_source(mock_vector_db):
)
mock_vector_db.sources = [pdf_source]
mock_vector_db.query.return_value = [
{"context": "crewai create crew latest-ai-development", "score": 0.9}
{"content": "crewai create crew latest-ai-development", "score": 0.9}
]

# Perform a query
@@ -439,7 +438,7 @@ def test_pdf_knowledge_source(mock_vector_db):

# Assert that the correct information is retrieved
assert any(
"crewai create crew latest-ai-development" in result["context"].lower()
"crewai create crew latest-ai-development" in result["content"].lower()
for result in results
)
mock_vector_db.query.assert_called_once()
@@ -467,7 +466,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [csv_source]
mock_vector_db.query.return_value = [
{"context": "Brandon is 30 years old.", "score": 0.9}
{"content": "Brandon is 30 years old.", "score": 0.9}
]

# Perform a query
@@ -475,7 +474,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)

# Assert that the correct information is retrieved
assert any("30" in result["context"] for result in results)
assert any("30" in result["content"] for result in results)
mock_vector_db.query.assert_called_once()


@@ -502,7 +501,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [json_source]
mock_vector_db.query.return_value = [
{"context": "Alice lives in Los Angeles.", "score": 0.9}
{"content": "Alice lives in Los Angeles.", "score": 0.9}
]

# Perform a query
@@ -510,7 +509,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)

# Assert that the correct information is retrieved
assert any("los angeles" in result["context"].lower() for result in results)
assert any("los angeles" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()


@@ -518,7 +517,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
"""Test ExcelKnowledgeSource with a simple Excel file."""

# Create an Excel file with sample data
import pandas as pd
import pandas as pd  # type: ignore[import-untyped]

excel_data = {
"Name": ["Brandon", "Alice", "Bob"],
@@ -535,7 +534,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
)
mock_vector_db.sources = [excel_source]
mock_vector_db.query.return_value = [
{"context": "Brandon is 30 years old.", "score": 0.9}
{"content": "Brandon is 30 years old.", "score": 0.9}
]

# Perform a query
@@ -543,7 +542,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
results = mock_vector_db.query(query)

# Assert that the correct information is retrieved
assert any("30" in result["context"] for result in results)
assert any("30" in result["content"] for result in results)
mock_vector_db.query.assert_called_once()


@@ -557,20 +556,20 @@ def test_docling_source(mock_vector_db):
mock_vector_db.sources = [docling_source]
mock_vector_db.query.return_value = [
{
"context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
"content": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
"score": 0.9,
}
]
# Perform a query
query = "What is reward hacking?"
results = mock_vector_db.query(query)
assert any("reward hacking" in result["context"].lower() for result in results)
assert any("reward hacking" in result["content"].lower() for result in results)
mock_vector_db.query.assert_called_once()


@pytest.mark.vcr
def test_multiple_docling_sources():
urls: List[Union[Path, str]] = [
def test_multiple_docling_sources() -> None:
urls: list[Path | str] = [
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
"https://lilianweng.github.io/posts/2024-07-07-hallucination/",
]
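[Editor's note: the systematic rename from "context" to "content" in these mocks tracks a standardized SearchResult shape; the dict that the new test file below asserts against would look roughly like this inferred sketch, not a verbatim definition from the source.]

from typing import Any, TypedDict


class SearchResult(TypedDict):
    """Shape implied by the assertions in the new tests below."""
    content: str
    score: float
    metadata: dict[str, Any]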

191
tests/knowledge/test_knowledge_searchresult.py
Normal file
@@ -0,0 +1,191 @@
"""Tests for Knowledge SearchResult type conversion and integration."""

from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from crewai.knowledge.knowledge import Knowledge  # type: ignore[import-untyped]
from crewai.knowledge.source.string_knowledge_source import (  # type: ignore[import-untyped]
StringKnowledgeSource,
)
from crewai.knowledge.utils.knowledge_utils import (  # type: ignore[import-untyped]
extract_knowledge_context,
)


def test_knowledge_query_returns_searchresult() -> None:
"""Test that Knowledge.query returns SearchResult format."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.return_value = [
{
"content": "AI is fascinating",
"score": 0.9,
"metadata": {"source": "doc1"},
},
{
"content": "Machine learning rocks",
"score": 0.8,
"metadata": {"source": "doc2"},
},
]

sources = [StringKnowledgeSource(content="Test knowledge content")]
knowledge = Knowledge(collection_name="test_collection", sources=sources)

results = knowledge.query(
["AI technology"], results_limit=5, score_threshold=0.3
)

mock_storage.search.assert_called_once_with(
["AI technology"], limit=5, score_threshold=0.3
)

assert isinstance(results, list)
assert len(results) == 2

for result in results:
assert isinstance(result, dict)
assert "content" in result
assert "score" in result
assert "metadata" in result

assert results[0]["content"] == "AI is fascinating"
assert results[0]["score"] == 0.9
assert results[1]["content"] == "Machine learning rocks"
assert results[1]["score"] == 0.8


def test_knowledge_query_with_empty_results() -> None:
"""Test Knowledge.query with empty search results."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.return_value = []

sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="empty_test", sources=sources)

results = knowledge.query(["nonexistent query"])

assert isinstance(results, list)
assert len(results) == 0


def test_extract_knowledge_context_with_searchresult() -> None:
"""Test extract_knowledge_context works with SearchResult format."""
search_results = [
{"content": "Python is great for AI", "score": 0.95, "metadata": {}},
{"content": "Machine learning algorithms", "score": 0.88, "metadata": {}},
{"content": "Deep learning frameworks", "score": 0.82, "metadata": {}},
]

context = extract_knowledge_context(search_results)

assert "Additional Information:" in context
assert "Python is great for AI" in context
assert "Machine learning algorithms" in context
assert "Deep learning frameworks" in context

expected_content = (
"Python is great for AI\nMachine learning algorithms\nDeep learning frameworks"
)
assert expected_content in context


def test_extract_knowledge_context_with_empty_content() -> None:
"""Test extract_knowledge_context handles empty or invalid content."""
search_results = [
{"content": "", "score": 0.5, "metadata": {}},
{"content": None, "score": 0.4, "metadata": {}},
{"score": 0.3, "metadata": {}},
]

context = extract_knowledge_context(search_results)

assert context == ""


def test_extract_knowledge_context_filters_invalid_results() -> None:
"""Test that extract_knowledge_context filters out invalid results."""
search_results: list[dict[str, Any] | None] = [
{"content": "Valid content 1", "score": 0.9, "metadata": {}},
{"content": "", "score": 0.8, "metadata": {}},
{"content": "Valid content 2", "score": 0.7, "metadata": {}},
None,
{"content": None, "score": 0.6, "metadata": {}},
]

context = extract_knowledge_context(search_results)

assert "Additional Information:" in context
assert "Valid content 1" in context
assert "Valid content 2" in context
assert context.count("\n") == 1
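[Editor's note: taken together, the assertions above pin down the helper's contract. A behavior-equivalent sketch, inferred from the tests; the real implementation in crewai.knowledge.utils.knowledge_utils may differ in detail.]

def extract_knowledge_context_sketch(results):
    # Keep only dict results with a truthy "content", per the filtering tests.
    snippets = [r["content"] for r in results if isinstance(r, dict) and r.get("content")]
    if not snippets:
        return ""
    joined = "\n".join(snippets)
    # Single-line prefix keeps the newline count at len(snippets) - 1,
    # matching the count assertion above.
    return f"Additional Information: {joined}"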


@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_storage_exception_handling(
mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test Knowledge handles storage exceptions gracefully."""
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.search.side_effect = Exception("Storage error")

sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="error_test", sources=sources)

with pytest.raises(ValueError, match="Storage is not initialized"):
knowledge.storage = None
knowledge.query(["test query"])


def test_knowledge_add_sources_integration() -> None:
"""Test Knowledge.add_sources integrates properly with storage."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage

sources = [
StringKnowledgeSource(content="Content 1"),
StringKnowledgeSource(content="Content 2"),
]
knowledge = Knowledge(collection_name="add_sources_test", sources=sources)

knowledge.add_sources()

for source in sources:
assert source.storage == mock_storage


def test_knowledge_reset_integration() -> None:
"""Test Knowledge.reset integrates with storage."""
with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage

sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="reset_test", sources=sources)

knowledge.reset()

mock_storage.reset.assert_called_once()


@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_reset_without_storage(
mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test Knowledge.reset raises error when storage is None."""
sources = [StringKnowledgeSource(content="Test content")]
knowledge = Knowledge(collection_name="no_storage_test", sources=sources)

knowledge.storage = None

with pytest.raises(ValueError, match="Storage is not initialized"):
knowledge.reset()
196
tests/knowledge/test_knowledge_storage_integration.py
Normal file
@@ -0,0 +1,196 @@
"""Integration tests for KnowledgeStorage RAG client migration."""

from unittest.mock import MagicMock, patch

import pytest

from crewai.knowledge.storage.knowledge_storage import (  # type: ignore[import-untyped]
KnowledgeStorage,
)


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.create_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_knowledge_storage_uses_rag_client(
mock_get_embedding: MagicMock,
mock_create_client: MagicMock,
mock_get_client: MagicMock,
) -> None:
"""Test that KnowledgeStorage properly integrates with RAG client."""
mock_client = MagicMock()
mock_create_client.return_value = mock_client
mock_get_client.return_value = mock_client
mock_client.search.return_value = [
{"content": "test content", "score": 0.9, "metadata": {"source": "test"}}
]

embedder_config = {"provider": "openai", "model": "text-embedding-3-small"}
storage = KnowledgeStorage(
embedder=embedder_config, collection_name="test_knowledge"
)

mock_create_client.assert_called_once()

results = storage.search(["test query"], limit=5, score_threshold=0.3)

mock_get_client.assert_not_called()
mock_client.search.assert_called_once_with(
collection_name="knowledge_test_knowledge",
query="test query",
limit=5,
metadata_filter=None,
score_threshold=0.3,
)

assert isinstance(results, list)
assert len(results) == 1
assert isinstance(results[0], dict)
assert "content" in results[0]


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_collection_name_prefixing(mock_get_client: MagicMock) -> None:
"""Test that collection names are properly prefixed."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []

storage = KnowledgeStorage(collection_name="custom_knowledge")
storage.search(["test"], limit=1)

mock_client.search.assert_called_once()
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["collection_name"] == "knowledge_custom_knowledge"

mock_client.reset_mock()
storage_default = KnowledgeStorage()
storage_default.search(["test"], limit=1)

call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["collection_name"] == "knowledge"
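[Editor's note: the two assertions above fix the collection naming scheme. Expressed as a tiny helper, inferred from the test rather than lifted from the source:]

def full_collection_name(name: str | None) -> str:
    # "knowledge" is both the prefix and the fallback for unnamed storages.
    return f"knowledge_{name}" if name else "knowledge"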


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_save_documents_integration(mock_get_client: MagicMock) -> None:
"""Test document saving through RAG client."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client

storage = KnowledgeStorage(collection_name="test_docs")
documents = ["Document 1 content", "Document 2 content"]

storage.save(documents)

mock_client.get_or_create_collection.assert_called_once_with(
collection_name="knowledge_test_docs"
)
mock_client.add_documents.assert_called_once()

call_kwargs = mock_client.add_documents.call_args.kwargs
added_docs = call_kwargs["documents"]
assert len(added_docs) == 2
assert added_docs[0]["content"] == "Document 1 content"
assert added_docs[1]["content"] == "Document 2 content"


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_reset_integration(mock_get_client: MagicMock) -> None:
"""Test collection reset through RAG client."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client

storage = KnowledgeStorage(collection_name="test_reset")
storage.reset()

mock_client.delete_collection.assert_called_once_with(
collection_name="knowledge_test_reset"
)


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_search_error_handling(mock_get_client: MagicMock) -> None:
"""Test error handling during search operations."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.side_effect = Exception("RAG client error")

storage = KnowledgeStorage(collection_name="error_test")

results = storage.search(["test query"])
assert results == []


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_embedding_configuration_flow(
mock_get_embedding: MagicMock, mock_get_client: MagicMock
) -> None:
"""Test that embedding configuration flows properly to RAG client."""
mock_embedding_func = MagicMock()
mock_get_embedding.return_value = mock_embedding_func
mock_get_client.return_value = MagicMock()

embedder_config = {
"provider": "sentence-transformer",
"model_name": "all-MiniLM-L6-v2",
}

KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")

mock_get_embedding.assert_called_once_with(embedder_config)


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_query_list_conversion(mock_get_client: MagicMock) -> None:
"""Test that query list is properly converted to string."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []

storage = KnowledgeStorage()

storage.search(["single query"])
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["query"] == "single query"

mock_client.reset_mock()
storage.search(["query one", "query two"])
call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["query"] == "query one query two"


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_metadata_filter_handling(mock_get_client: MagicMock) -> None:
"""Test metadata filter parameter handling."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.search.return_value = []

storage = KnowledgeStorage()

metadata_filter = {"category": "technical", "priority": "high"}
storage.search(["test"], metadata_filter=metadata_filter)

call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["metadata_filter"] == metadata_filter

mock_client.reset_mock()
storage.search(["test"], metadata_filter=None)

call_kwargs = mock_client.search.call_args.kwargs
assert call_kwargs["metadata_filter"] is None


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None:
"""Test specific handling of dimension mismatch errors."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_client.get_or_create_collection.return_value = None
mock_client.add_documents.side_effect = Exception("dimension mismatch detected")

storage = KnowledgeStorage(collection_name="dimension_test")

with pytest.raises(ValueError, match="Embedding dimension mismatch"):
storage.save(["test document"])
|
||||
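A minimal sketch of the behavior the dimension-mismatch test above pins down, assuming save() funnels client errors through a translation step like this (the helper name is illustrative, not the library's actual internals):

def _translate_save_error(error: Exception) -> Exception:
    """Map low-level vector-store failures to actionable errors."""
    if "dimension mismatch" in str(error).lower():
        return ValueError(
            "Embedding dimension mismatch. Make sure you're using the same "
            "embedding model, or reset with `crewai reset-memories -a`."
        )
    return error

# Usage: a save() implementation would re-raise the translated error.
try:
    raise Exception("dimension mismatch detected")  # simulated client failure
except Exception as exc:
    assert isinstance(_translate_save_error(exc), ValueError)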
@@ -1,19 +1,20 @@
from unittest.mock import patch, ANY
from collections import defaultdict
from unittest.mock import ANY, patch

import pytest

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryQueryCompletedEvent,
    MemoryQueryStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveStartedEvent,
)
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.task import Task
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemorySaveStartedEvent,
    MemorySaveCompletedEvent,
    MemoryQueryStartedEvent,
    MemoryQueryCompletedEvent,
)


@pytest.fixture
@@ -38,22 +39,23 @@ def short_term_memory():
def test_short_term_memory_search_events(short_term_memory):
    events = defaultdict(list)

    with crewai_event_bus.scoped_handlers():
    with patch("crewai.rag.chromadb.client.ChromaDBClient.search", return_value=[]):
        with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(MemoryQueryStartedEvent)
        def on_search_started(source, event):
            events["MemoryQueryStartedEvent"].append(event)
            @crewai_event_bus.on(MemoryQueryStartedEvent)
            def on_search_started(source, event):
                events["MemoryQueryStartedEvent"].append(event)

        @crewai_event_bus.on(MemoryQueryCompletedEvent)
        def on_search_completed(source, event):
            events["MemoryQueryCompletedEvent"].append(event)
            @crewai_event_bus.on(MemoryQueryCompletedEvent)
            def on_search_completed(source, event):
                events["MemoryQueryCompletedEvent"].append(event)

        # Call the search method
        short_term_memory.search(
            query="test value",
            limit=3,
            score_threshold=0.35,
        )
            # Call the search method
            short_term_memory.search(
                query="test value",
                limit=3,
                score_threshold=0.35,
            )

    assert len(events["MemoryQueryStartedEvent"]) == 1
    assert len(events["MemoryQueryCompletedEvent"]) == 1
@@ -173,12 +175,12 @@ def test_save_and_search(short_term_memory):

    expected_result = [
        {
            "context": memory.data,
            "content": memory.data,
            "metadata": {"agent": "test_agent"},
            "score": 0.95,
        }
    ]
    with patch.object(ShortTermMemory, "search", return_value=expected_result):
        find = short_term_memory.search("test value", score_threshold=0.01)[0]
        assert find["context"] == memory.data, "Data value mismatch."
        assert find["content"] == memory.data, "Data value mismatch."
        assert find["metadata"]["agent"] == "test_agent", "Agent value mismatch."

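The rewritten events test above leans on two details worth noting: patching ChromaDBClient.search keeps the test hermetic (no real vector store is touched), and crewai_event_bus.scoped_handlers() unregisters handlers when the block exits. A stripped-down sketch of the same pattern, using only APIs the test itself imports:

from collections import defaultdict

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import MemoryQueryStartedEvent

seen = defaultdict(list)

with crewai_event_bus.scoped_handlers():
    @crewai_event_bus.on(MemoryQueryStartedEvent)
    def record(source, event):
        seen["started"].append(event)
    # ... trigger a memory query here; the handler is detached on exit.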
@@ -236,7 +236,7 @@ class TestChromaDBClient:
    def test_add_documents(self, client, mock_chromadb_client) -> None:
        """Test that add_documents adds documents to collection."""
        mock_collection = Mock()
        mock_chromadb_client.get_collection.return_value = mock_collection
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {
@@ -247,7 +247,7 @@ class TestChromaDBClient:

        client.add_documents(collection_name="test_collection", documents=documents)

        mock_chromadb_client.get_collection.assert_called_once_with(
        mock_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=client.embedding_function,
        )
@@ -262,7 +262,7 @@ class TestChromaDBClient:
    def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
        """Test add_documents with custom document IDs."""
        mock_collection = Mock()
        mock_chromadb_client.get_collection.return_value = mock_collection
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {
@@ -285,6 +285,43 @@ class TestChromaDBClient:
            metadatas=[{"source": "test1"}, {"source": "test2"}],
        )

    def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
        """Test add_documents with documents that have no metadata."""
        mock_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {"content": "Document without metadata"},
            {"content": "Another document", "metadata": None},
            {"content": "Document with metadata", "metadata": {"key": "value"}},
        ]

        client.add_documents(collection_name="test_collection", documents=documents)

        # Verify upsert was called with empty dicts for missing metadata
        mock_collection.upsert.assert_called_once()
        call_args = mock_collection.upsert.call_args
        assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]

    def test_add_documents_all_without_metadata(
        self, client, mock_chromadb_client
    ) -> None:
        """Test add_documents when all documents have no metadata."""
        mock_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {"content": "Document 1"},
            {"content": "Document 2"},
            {"content": "Document 3"},
        ]

        client.add_documents(collection_name="test_collection", documents=documents)

        mock_collection.upsert.assert_called_once()
        call_args = mock_collection.upsert.call_args
        assert call_args[1]["metadatas"] is None

    def test_add_documents_empty_list_raises_error(
        self, client, mock_chromadb_client
    ) -> None:
@@ -298,7 +335,7 @@ class TestChromaDBClient:
    ) -> None:
        """Test that aadd_documents adds documents to collection asynchronously."""
        mock_collection = AsyncMock()
        mock_async_chromadb_client.get_collection = AsyncMock(
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )

@@ -313,7 +350,7 @@ class TestChromaDBClient:
            collection_name="test_collection", documents=documents
        )

        mock_async_chromadb_client.get_collection.assert_called_once_with(
        mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=async_client.embedding_function,
        )
@@ -331,7 +368,7 @@ class TestChromaDBClient:
    ) -> None:
        """Test aadd_documents with custom document IDs."""
        mock_collection = AsyncMock()
        mock_async_chromadb_client.get_collection = AsyncMock(
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )

@@ -358,6 +395,31 @@ class TestChromaDBClient:
            metadatas=[{"source": "test1"}, {"source": "test2"}],
        )

    @pytest.mark.asyncio
    async def test_aadd_documents_without_metadata(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test aadd_documents with documents that have no metadata."""
        mock_collection = AsyncMock()
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )

        documents: list[BaseRecord] = [
            {"content": "Document without metadata"},
            {"content": "Another document", "metadata": None},
            {"content": "Document with metadata", "metadata": {"key": "value"}},
        ]

        await async_client.aadd_documents(
            collection_name="test_collection", documents=documents
        )

        # Verify upsert was called with empty dicts for missing metadata
        mock_collection.upsert.assert_called_once()
        call_args = mock_collection.upsert.call_args
        assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]

    @pytest.mark.asyncio
    async def test_aadd_documents_empty_list_raises_error(
        self, async_client, mock_async_chromadb_client
@@ -372,7 +434,7 @@ class TestChromaDBClient:
        """Test that search queries the collection correctly."""
        mock_collection = Mock()
        mock_collection.metadata = {"hnsw:space": "cosine"}
        mock_chromadb_client.get_collection.return_value = mock_collection
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection
        mock_collection.query.return_value = {
            "ids": [["doc1", "doc2"]],
            "documents": [["Document 1", "Document 2"]],
@@ -382,13 +444,13 @@ class TestChromaDBClient:

        results = client.search(collection_name="test_collection", query="test query")

        mock_chromadb_client.get_collection.assert_called_once_with(
        mock_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=client.embedding_function,
        )
        mock_collection.query.assert_called_once_with(
            query_texts=["test query"],
            n_results=10,
            n_results=5,
            where=None,
            where_document=None,
            include=["metadatas", "documents", "distances"],
@@ -404,7 +466,7 @@ class TestChromaDBClient:
        """Test search with optional parameters."""
        mock_collection = Mock()
        mock_collection.metadata = {"hnsw:space": "cosine"}
        mock_chromadb_client.get_collection.return_value = mock_collection
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection
        mock_collection.query.return_value = {
            "ids": [["doc1", "doc2", "doc3"]],
            "documents": [["Document 1", "Document 2", "Document 3"]],
@@ -437,7 +499,7 @@ class TestChromaDBClient:
        """Test that asearch queries the collection correctly."""
        mock_collection = AsyncMock()
        mock_collection.metadata = {"hnsw:space": "cosine"}
        mock_async_chromadb_client.get_collection = AsyncMock(
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )
        mock_collection.query = AsyncMock(
@@ -453,13 +515,13 @@ class TestChromaDBClient:
            collection_name="test_collection", query="test query"
        )

        mock_async_chromadb_client.get_collection.assert_called_once_with(
        mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=async_client.embedding_function,
        )
        mock_collection.query.assert_called_once_with(
            query_texts=["test query"],
            n_results=10,
            n_results=5,
            where=None,
            where_document=None,
            include=["metadatas", "documents", "distances"],
@@ -478,7 +540,7 @@ class TestChromaDBClient:
        """Test asearch with optional parameters."""
        mock_collection = AsyncMock()
        mock_collection.metadata = {"hnsw:space": "cosine"}
        mock_async_chromadb_client.get_collection = AsyncMock(
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )
        mock_collection.query = AsyncMock(
95
tests/rag/chromadb/test_utils.py
Normal file
@@ -0,0 +1,95 @@
"""Tests for ChromaDB utility functions."""

from crewai.rag.chromadb.utils import (
    MAX_COLLECTION_LENGTH,
    MIN_COLLECTION_LENGTH,
    _is_ipv4_pattern,
    _sanitize_collection_name,
)


class TestChromaDBUtils:
    """Test suite for ChromaDB utility functions."""

    def test_sanitize_collection_name_long_name(self) -> None:
        """Test sanitizing a very long collection name."""
        long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
        sanitized = _sanitize_collection_name(long_name)
        assert len(sanitized) <= MAX_COLLECTION_LENGTH
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
        assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)

    def test_sanitize_collection_name_special_chars(self) -> None:
        """Test sanitizing a name with special characters."""
        special_chars = "Agent@123!#$%^&*()"
        sanitized = _sanitize_collection_name(special_chars)
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
        assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)

    def test_sanitize_collection_name_short_name(self) -> None:
        """Test sanitizing a very short name."""
        short_name = "A"
        sanitized = _sanitize_collection_name(short_name)
        assert len(sanitized) >= MIN_COLLECTION_LENGTH
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()

    def test_sanitize_collection_name_bad_ends(self) -> None:
        """Test sanitizing a name with non-alphanumeric start/end."""
        bad_ends = "_Agent_"
        sanitized = _sanitize_collection_name(bad_ends)
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()

    def test_sanitize_collection_name_none(self) -> None:
        """Test sanitizing a None value."""
        sanitized = _sanitize_collection_name(None)
        assert sanitized == "default_collection"

    def test_sanitize_collection_name_ipv4_pattern(self) -> None:
        """Test sanitizing an IPv4 address."""
        ipv4 = "192.168.1.1"
        sanitized = _sanitize_collection_name(ipv4)
        assert sanitized.startswith("ip_")
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
        assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)

    def test_is_ipv4_pattern(self) -> None:
        """Test IPv4 pattern detection."""
        assert _is_ipv4_pattern("192.168.1.1") is True
        assert _is_ipv4_pattern("not.an.ip.address") is False

    def test_sanitize_collection_name_properties(self) -> None:
        """Test that sanitized collection names always meet ChromaDB requirements."""
        test_cases: list[str] = [
            "A" * 100,  # Very long name
            "_start_with_underscore",
            "end_with_underscore_",
            "contains@special#characters",
            "192.168.1.1",  # IPv4 address
            "a" * 2,  # Too short
        ]
        for test_case in test_cases:
            sanitized = _sanitize_collection_name(test_case)
            assert len(sanitized) >= MIN_COLLECTION_LENGTH
            assert len(sanitized) <= MAX_COLLECTION_LENGTH
            assert sanitized[0].isalnum()
            assert sanitized[-1].isalnum()

    def test_sanitize_collection_name_empty_string(self) -> None:
        """Test sanitizing an empty string."""
        sanitized = _sanitize_collection_name("")
        assert sanitized == "default_collection"

    def test_sanitize_collection_name_whitespace_only(self) -> None:
        """Test sanitizing a string with only whitespace."""
        sanitized = _sanitize_collection_name(" ")
        assert (
            sanitized == "a__z"
        )  # Spaces become underscores, padded to meet requirements
        assert len(sanitized) >= MIN_COLLECTION_LENGTH
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
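A simplified sketch of the sanitization contract these tests pin down. The real implementation lives in crewai.rag.chromadb.utils; this version only illustrates the documented rules (bounds, allowed characters, IPv4 prefixing), and the constants and regex here are assumptions, not the library's code:

import re

MIN_LEN, MAX_LEN = 3, 63  # assumed bounds mirroring MIN/MAX_COLLECTION_LENGTH

def sanitize(name: str | None) -> str:
    """Illustrative only: enforce ChromaDB-style collection naming rules."""
    if name is None or name == "":
        return "default_collection"
    if re.fullmatch(r"(\d{1,3}\.){3}\d{1,3}", name):
        name = f"ip_{name}"  # IPv4-looking names get a prefix
    name = re.sub(r"[^A-Za-z0-9_-]", "_", name)[:MAX_LEN]
    if not name[0].isalnum():
        name = "a" + name[1:]  # force an alphanumeric start
    while len(name) < MIN_LEN:
        name += "z"  # pad short names up to the minimum length
    if not name[-1].isalnum():
        name = name[:-1] + "z"  # force an alphanumeric end
    return name

assert sanitize(None) == "default_collection"
assert sanitize("192.168.1.1").startswith("ip_")
assert sanitize("_Agent_")[0].isalnum() and sanitize("_Agent_")[-1].isalnum()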
250
tests/rag/embeddings/test_factory_enhanced.py
Normal file
@@ -0,0 +1,250 @@
"""Enhanced tests for embedding function factory."""

from unittest.mock import MagicMock, patch

import pytest

from crewai.rag.embeddings.factory import (  # type: ignore[import-untyped]
    get_embedding_function,
)
from crewai.rag.embeddings.types import EmbeddingOptions  # type: ignore[import-untyped]


def test_get_embedding_function_default() -> None:
    """Test default embedding function when no config provided."""
    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
        mock_instance = MagicMock()
        mock_openai.return_value = mock_instance

        with patch(
            "crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
        ):
            result = get_embedding_function()

        mock_openai.assert_called_once_with(
            api_key="test-api-key", model_name="text-embedding-3-small"
        )
        assert result == mock_instance


def test_get_embedding_function_with_embedding_options() -> None:
    """Test embedding function creation with EmbeddingOptions object."""
    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
        mock_instance = MagicMock()
        mock_openai.return_value = mock_instance

        options = EmbeddingOptions(
            provider="openai", api_key="test-key", model="text-embedding-3-large"
        )

        result = get_embedding_function(options)

        call_kwargs = mock_openai.call_args.kwargs
        assert "api_key" in call_kwargs
        assert call_kwargs["api_key"].get_secret_value() == "test-key"
        # OpenAI uses model_name parameter, not model
        assert result == mock_instance


def test_get_embedding_function_sentence_transformer() -> None:
    """Test sentence transformer embedding function."""
    with patch(
        "crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction"
    ) as mock_st:
        mock_instance = MagicMock()
        mock_st.return_value = mock_instance

        config = {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"}

        result = get_embedding_function(config)

        mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2")
        assert result == mock_instance


def test_get_embedding_function_ollama() -> None:
    """Test Ollama embedding function."""
    with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as mock_ollama:
        mock_instance = MagicMock()
        mock_ollama.return_value = mock_instance

        config = {
            "provider": "ollama",
            "model_name": "nomic-embed-text",
            "url": "http://localhost:11434",
        }

        result = get_embedding_function(config)

        mock_ollama.assert_called_once_with(
            model_name="nomic-embed-text", url="http://localhost:11434"
        )
        assert result == mock_instance


def test_get_embedding_function_cohere() -> None:
    """Test Cohere embedding function."""
    with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as mock_cohere:
        mock_instance = MagicMock()
        mock_cohere.return_value = mock_instance

        config = {
            "provider": "cohere",
            "api_key": "cohere-key",
            "model_name": "embed-english-v3.0",
        }

        result = get_embedding_function(config)

        mock_cohere.assert_called_once_with(
            api_key="cohere-key", model_name="embed-english-v3.0"
        )
        assert result == mock_instance


def test_get_embedding_function_huggingface() -> None:
    """Test HuggingFace embedding function."""
    with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as mock_hf:
        mock_instance = MagicMock()
        mock_hf.return_value = mock_instance

        config = {
            "provider": "huggingface",
            "api_key": "hf-token",
            "model_name": "sentence-transformers/all-MiniLM-L6-v2",
        }

        result = get_embedding_function(config)

        mock_hf.assert_called_once_with(
            api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        assert result == mock_instance


def test_get_embedding_function_onnx() -> None:
    """Test ONNX embedding function."""
    with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as mock_onnx:
        mock_instance = MagicMock()
        mock_onnx.return_value = mock_instance

        config = {"provider": "onnx"}

        result = get_embedding_function(config)

        mock_onnx.assert_called_once()
        assert result == mock_instance


def test_get_embedding_function_google_palm() -> None:
    """Test Google PaLM embedding function."""
    with patch(
        "crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction"
    ) as mock_palm:
        mock_instance = MagicMock()
        mock_palm.return_value = mock_instance

        config = {"provider": "google-palm", "api_key": "palm-key"}

        result = get_embedding_function(config)

        mock_palm.assert_called_once_with(api_key="palm-key")
        assert result == mock_instance


def test_get_embedding_function_amazon_bedrock() -> None:
    """Test Amazon Bedrock embedding function."""
    with patch(
        "crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction"
    ) as mock_bedrock:
        mock_instance = MagicMock()
        mock_bedrock.return_value = mock_instance

        config = {
            "provider": "amazon-bedrock",
            "region_name": "us-west-2",
            "model_name": "amazon.titan-embed-text-v1",
        }

        result = get_embedding_function(config)

        mock_bedrock.assert_called_once_with(
            region_name="us-west-2", model_name="amazon.titan-embed-text-v1"
        )
        assert result == mock_instance


def test_get_embedding_function_jina() -> None:
    """Test Jina embedding function."""
    with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as mock_jina:
        mock_instance = MagicMock()
        mock_jina.return_value = mock_instance

        config = {
            "provider": "jina",
            "api_key": "jina-key",
            "model_name": "jina-embeddings-v2-base-en",
        }

        result = get_embedding_function(config)

        mock_jina.assert_called_once_with(
            api_key="jina-key", model_name="jina-embeddings-v2-base-en"
        )
        assert result == mock_instance


def test_get_embedding_function_unsupported_provider() -> None:
    """Test handling of unsupported provider."""
    config = {"provider": "unsupported-provider"}

    with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
        get_embedding_function(config)


def test_get_embedding_function_config_modification() -> None:
    """Test that original config dict is not modified."""
    original_config = {
        "provider": "openai",
        "api_key": "test-key",
        "model": "text-embedding-3-small",
    }
    config_copy = original_config.copy()

    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
        get_embedding_function(config_copy)

    assert config_copy == original_config


def test_get_embedding_function_exclude_none_values() -> None:
    """Test that None values are excluded from embedding function calls."""
    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
        mock_instance = MagicMock()
        mock_openai.return_value = mock_instance

        options = EmbeddingOptions(provider="openai", api_key="test-key", model=None)

        result = get_embedding_function(options)

        call_kwargs = mock_openai.call_args.kwargs
        assert "api_key" in call_kwargs
        assert call_kwargs["api_key"].get_secret_value() == "test-key"
        assert "model" not in call_kwargs
        assert result == mock_instance


def test_get_embedding_function_instructor() -> None:
    """Test Instructor embedding function."""
    with patch(
        "crewai.rag.embeddings.factory.InstructorEmbeddingFunction"
    ) as mock_instructor:
        mock_instance = MagicMock()
        mock_instructor.return_value = mock_instance

        config = {"provider": "instructor", "model_name": "hkunlp/instructor-large"}

        result = get_embedding_function(config)

        mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
        assert result == mock_instance
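Taken together, these tests document the factory's contract: pass a plain dict (or EmbeddingOptions) with a "provider" key plus provider-specific kwargs, and get back a ChromaDB-style embedding function. A typical call, assuming a local Ollama server (the URL is illustrative):

from crewai.rag.embeddings.factory import get_embedding_function

embedder = get_embedding_function(
    {
        "provider": "ollama",
        "model_name": "nomic-embed-text",
        "url": "http://localhost:11434",  # assumed local endpoint
    }
)
# Alternatively, the same config dict can be passed straight to
# KnowledgeStorage(embedder=...), which calls this factory internally,
# as test_embedding_configuration_flow earlier in this diff shows.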
@@ -3,7 +3,8 @@
from unittest.mock import AsyncMock, Mock

import pytest
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client import AsyncQdrantClient
from qdrant_client import QdrantClient as SyncQdrantClient

from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.qdrant.client import QdrantClient
@@ -435,7 +436,7 @@ class TestQdrantClient:
        call_args = mock_qdrant_client.query_points.call_args
        assert call_args.kwargs["collection_name"] == "test_collection"
        assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
        assert call_args.kwargs["limit"] == 10
        assert call_args.kwargs["limit"] == 5
        assert call_args.kwargs["with_payload"] is True
        assert call_args.kwargs["with_vectors"] is False

@@ -540,7 +541,7 @@ class TestQdrantClient:
        call_args = mock_async_qdrant_client.query_points.call_args
        assert call_args.kwargs["collection_name"] == "test_collection"
        assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
        assert call_args.kwargs["limit"] == 10
        assert call_args.kwargs["limit"] == 5
        assert call_args.kwargs["with_payload"] is True
        assert call_args.kwargs["with_vectors"] is False
218
tests/rag/test_error_handling.py
Normal file
@@ -0,0 +1,218 @@
"""Tests for RAG client error handling scenarios."""

from unittest.mock import MagicMock, patch

import pytest

from crewai.knowledge.storage.knowledge_storage import (  # type: ignore[import-untyped]
    KnowledgeStorage,
)
from crewai.memory.storage.rag_storage import RAGStorage  # type: ignore[import-untyped]


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_connection_failure(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles RAG client connection failures."""
    mock_get_client.side_effect = ConnectionError("Unable to connect to ChromaDB")

    storage = KnowledgeStorage(collection_name="connection_test")

    results = storage.search(["test query"])
    assert results == []


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_search_timeout(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles search timeouts gracefully."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.side_effect = TimeoutError("Search operation timed out")

    storage = KnowledgeStorage(collection_name="timeout_test")

    results = storage.search(["test query"])
    assert results == []


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_collection_not_found(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles missing collections."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.side_effect = ValueError(
        "Collection 'knowledge_missing' does not exist"
    )

    storage = KnowledgeStorage(collection_name="missing_collection")

    results = storage.search(["test query"])
    assert results == []


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles invalid embedding configurations."""
    mock_get_client.return_value = MagicMock()

    with patch(
        "crewai.knowledge.storage.knowledge_storage.get_embedding_function"
    ) as mock_get_embedding:
        mock_get_embedding.side_effect = ValueError(
            "Unsupported provider: invalid_provider"
        )

        with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
            KnowledgeStorage(
                embedder={"provider": "invalid_provider"},
                collection_name="invalid_embedding_test",
            )


@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_client_failure(mock_get_client: MagicMock) -> None:
    """Test RAGStorage handles RAG client failures in memory operations."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.side_effect = RuntimeError("ChromaDB server error")

    storage = RAGStorage("short_term", crew=None)

    results = storage.search("test query")
    assert results == []


@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_save_failure(mock_get_client: MagicMock) -> None:
    """Test RAGStorage handles save operation failures."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.add_documents.side_effect = Exception("Failed to add documents")

    storage = RAGStorage("long_term", crew=None)

    storage.save("test memory", {"key": "value"})


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_readonly_database(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage reset handles readonly database errors."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.delete_collection.side_effect = Exception(
        "attempt to write a readonly database"
    )

    storage = KnowledgeStorage(collection_name="readonly_test")

    storage.reset()


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_collection_does_not_exist(
    mock_get_client: MagicMock,
) -> None:
    """Test KnowledgeStorage reset handles non-existent collections."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.delete_collection.side_effect = Exception("Collection does not exist")

    storage = KnowledgeStorage(collection_name="nonexistent_test")

    storage.reset()


@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_reset_failure_propagation(mock_get_client: MagicMock) -> None:
    """Test RAGStorage reset propagates unexpected errors."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.delete_collection.side_effect = Exception("Unexpected database error")

    storage = RAGStorage("entities", crew=None)

    with pytest.raises(
        Exception, match="An error occurred while resetting the entities memory"
    ):
        storage.reset()


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_malformed_search_results(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles malformed search results."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.return_value = [
        {"content": "valid result", "metadata": {"source": "test"}},
        {"invalid": "missing content field", "metadata": {"source": "test"}},
        None,
        {"content": None, "metadata": {"source": "test"}},
    ]

    storage = KnowledgeStorage(collection_name="malformed_test")

    results = storage.search(["test query"])

    assert isinstance(results, list)
    assert len(results) == 4


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_network_interruption(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles network interruptions during operations."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client

    mock_client.search.side_effect = [
        ConnectionError("Network interruption"),
        [{"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}],
    ]

    storage = KnowledgeStorage(collection_name="network_test")

    first_attempt = storage.search(["test query"])
    assert first_attempt == []

    mock_client.search.side_effect = None
    mock_client.search.return_value = [
        {"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}
    ]

    second_attempt = storage.search(["test query"])
    assert len(second_attempt) == 1
    assert second_attempt[0]["content"] == "recovered result"


@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_collection_creation_failure(mock_get_client: MagicMock) -> None:
    """Test RAGStorage handles collection creation failures."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.get_or_create_collection.side_effect = Exception(
        "Failed to create collection"
    )

    storage = RAGStorage("user_memory", crew=None)

    storage.save("test data", {"metadata": "test"})


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_embedding_dimension_mismatch_detailed(
    mock_get_client: MagicMock,
) -> None:
    """Test detailed handling of embedding dimension mismatch errors."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.get_or_create_collection.return_value = None
    mock_client.add_documents.side_effect = Exception(
        "Embedding dimension mismatch: expected 384, got 1536"
    )

    storage = KnowledgeStorage(collection_name="dimension_detailed_test")

    with pytest.raises(ValueError) as exc_info:
        storage.save(["test document"])

    assert "Embedding dimension mismatch" in str(exc_info.value)
    assert "Make sure you're using the same embedding model" in str(exc_info.value)
    assert "crewai reset-memories -a" in str(exc_info.value)
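The common thread in these error-handling tests: read paths swallow infrastructure failures and return an empty list, while reset() distinguishes benign failures (readonly database, missing collection) from real ones. A minimal sketch of the read-path pattern, under the assumption the real storage classes wrap their client roughly this way:

def safe_search(client, **kwargs) -> list[dict]:
    """Return [] on any client failure instead of propagating it."""
    try:
        return client.search(**kwargs)
    except Exception:
        return []

class _FailingClient:
    def search(self, **kwargs):
        raise ConnectionError("Unable to connect to ChromaDB")

assert safe_search(_FailingClient(), query="test") == []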
@@ -1,8 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from mem0.client.main import MemoryClient
|
||||
from mem0.memory.main import Memory
|
||||
from mem0 import Memory, MemoryClient
|
||||
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
@@ -13,11 +12,71 @@ class MockCrew:
|
||||
self.agents = [MagicMock(role="Test Agent")]
|
||||
|
||||
|
||||
# Test data constants
|
||||
SYSTEM_CONTENT = (
|
||||
"You are Friendly chatbot assistant. You are a kind and "
|
||||
"knowledgeable chatbot assistant. You excel at understanding user needs, "
|
||||
"providing helpful responses, and maintaining engaging conversations. "
|
||||
"You remember previous interactions to provide a personalized experience.\n"
|
||||
"Your personal goal is: Engage in useful and interesting conversations "
|
||||
"with users while remembering context.\n"
|
||||
"To give my best complete final answer to the task respond using the exact "
|
||||
"following format:\n\n"
|
||||
"Thought: I now can give a great answer\n"
|
||||
"Final Answer: Your final answer must be the great and the most complete "
|
||||
"as possible, it must be outcome described.\n\n"
|
||||
"I MUST use these formats, my job depends on it!"
|
||||
)
|
||||
|
||||
USER_CONTENT = (
|
||||
"\nCurrent Task: Respond to user conversation. User message: "
|
||||
"What do you know about me?\n\n"
|
||||
"This is the expected criteria for your final answer: Contextually "
|
||||
"appropriate, helpful, and friendly response.\n"
|
||||
"you MUST return the actual complete content as the final answer, "
|
||||
"not a summary.\n\n"
|
||||
"# Useful context: \nExternal memories:\n"
|
||||
"- User is from India\n"
|
||||
"- User is interested in the solar system\n"
|
||||
"- User name is Vidit Ostwal\n"
|
||||
"- User is interested in French cuisine\n\n"
|
||||
"Begin! This is VERY important to you, use the tools available and give "
|
||||
"your best Final Answer, your job depends on it!\n\n"
|
||||
"Thought:"
|
||||
)
|
||||
|
||||
ASSISTANT_CONTENT = (
|
||||
"I now can give a great answer \n"
|
||||
"Final Answer: Hi Vidit! From our previous conversations, I know you're "
|
||||
"from India and have a great interest in the solar system. It's fascinating "
|
||||
"to explore the wonders of space, isn't it? Also, I remember you have a "
|
||||
"passion for French cuisine, which has so many delightful dishes to explore. "
|
||||
"If there's anything specific you'd like to discuss or learn about—whether "
|
||||
"it's about the solar system or some great French recipes—feel free to let "
|
||||
"me know! I'm here to help."
|
||||
)
|
||||
|
||||
TEST_DESCRIPTION = (
|
||||
"Respond to user conversation. User message: What do you know about me?"
|
||||
)
|
||||
|
||||
# Extracted content (after processing by _get_user_message and _get_assistant_message)
|
||||
EXTRACTED_USER_CONTENT = "What do you know about me?"
|
||||
EXTRACTED_ASSISTANT_CONTENT = (
|
||||
"Hi Vidit! From our previous conversations, I know you're "
|
||||
"from India and have a great interest in the solar system. It's fascinating "
|
||||
"to explore the wonders of space, isn't it? Also, I remember you have a "
|
||||
"passion for French cuisine, which has so many delightful dishes to explore. "
|
||||
"If there's anything specific you'd like to discuss or learn about—whether "
|
||||
"it's about the solar system or some great French recipes—feel free to let "
|
||||
"me know! I'm here to help."
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory():
|
||||
"""Fixture to create a mock Memory instance"""
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
return mock_memory
|
||||
return MagicMock(spec=Memory)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -25,7 +84,9 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
|
||||
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
|
||||
|
||||
# Patch the Memory class to return our mock
|
||||
with patch("mem0.memory.main.Memory.from_config", return_value=mock_mem0_memory) as mock_from_config:
|
||||
with patch(
|
||||
"mem0.Memory.from_config", return_value=mock_mem0_memory
|
||||
) as mock_from_config:
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "mock_vector_store",
|
||||
@@ -56,7 +117,14 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
|
||||
# Parameters like run_id, includes, and excludes doesn't matter in Memory OSS
|
||||
crew = MockCrew()
|
||||
|
||||
embedder_config={"user_id": "test_user", "local_mem0_config": config, "run_id": "my_run_id", "includes": "include1","excludes": "exclude1", "infer" : True}
|
||||
embedder_config = {
|
||||
"user_id": "test_user",
|
||||
"local_mem0_config": config,
|
||||
"run_id": "my_run_id",
|
||||
"includes": "include1",
|
||||
"excludes": "exclude1",
|
||||
"infer": True,
|
||||
}
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=embedder_config)
|
||||
return mem0_storage, mock_from_config, config
|
||||
@@ -73,8 +141,7 @@ def test_mem0_storage_initialization(mem0_storage_with_mocked_config, mock_mem0_
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory_client():
|
||||
"""Fixture to create a mock MemoryClient instance"""
|
||||
mock_memory = MagicMock(spec=MemoryClient)
|
||||
return mock_memory
|
||||
return MagicMock(spec=MemoryClient)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -85,34 +152,35 @@ def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_clie
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
crew = MockCrew()
|
||||
|
||||
embedder_config={
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"run_id": "my_run_id",
|
||||
"includes": "include1",
|
||||
"excludes": "exclude1",
|
||||
"infer": True
|
||||
}
|
||||
embedder_config = {
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"run_id": "my_run_id",
|
||||
"includes": "include1",
|
||||
"excludes": "exclude1",
|
||||
"infer": True,
|
||||
}
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=embedder_config)
|
||||
return mem0_storage
|
||||
return Mem0Storage(type="short_term", crew=crew, config=embedder_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client, mock_mem0_memory):
|
||||
def mem0_storage_with_memory_client_using_explictly_config(
|
||||
mock_mem0_memory_client, mock_mem0_memory
|
||||
):
|
||||
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
|
||||
|
||||
# We need to patch both MemoryClient and Memory to prevent actual initialization
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client), \
|
||||
patch.object(Memory, "__new__", return_value=mock_mem0_memory):
|
||||
|
||||
with (
|
||||
patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client),
|
||||
patch.object(Memory, "__new__", return_value=mock_mem0_memory),
|
||||
):
|
||||
crew = MockCrew()
|
||||
new_config = {"provider": "mem0", "config": {"api_key": "new-api-key"}}
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=new_config)
|
||||
return mem0_storage
|
||||
return Mem0Storage(type="short_term", crew=crew, config=new_config)
|
||||
|
||||
|
||||
def test_mem0_storage_with_memory_client_initialization(
|
||||
@@ -142,18 +210,23 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
|
||||
mock_mem0_memory_client.update_project = MagicMock()
|
||||
|
||||
new_categories = [
|
||||
{"lifestyle_management_concerns": "Tracks daily routines, habits, hobbies and interests including cooking, time management and work-life balance"},
|
||||
{
|
||||
"lifestyle_management_concerns": (
|
||||
"Tracks daily routines, habits, hobbies and interests "
|
||||
"including cooking, time management and work-life balance"
|
||||
)
|
||||
},
|
||||
]
|
||||
|
||||
crew = MockCrew()
|
||||
|
||||
config={
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"custom_categories": new_categories
|
||||
}
|
||||
config = {
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"custom_categories": new_categories,
|
||||
}
|
||||
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
_ = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
@@ -163,8 +236,6 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
"""Test save method for different memory types"""
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
@@ -172,68 +243,134 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
|
||||
# Test short_term memory type (already set in fixture)
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {"key": "value"}
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{"role": "assistant" , "content": test_value}],
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
infer=True,
|
||||
metadata={"type": "short_term", "key": "value"},
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
run_id="my_run_id",
|
||||
user_id="test_user",
|
||||
agent_id='Test_Agent'
|
||||
agent_id="Test_Agent",
|
||||
)
|
||||
|
||||
|
||||
def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config):
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")]
|
||||
mem0_storage.crew.agents = [
|
||||
MagicMock(role="Test Agent"),
|
||||
MagicMock(role="Test Agent 2"),
|
||||
MagicMock(role="Test Agent 3"),
|
||||
]
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {"key": "value"}
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{"role": "assistant" , "content": test_value}],
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
infer=True,
|
||||
metadata={"type": "short_term", "key": "value"},
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
run_id="my_run_id",
|
||||
user_id="test_user",
|
||||
agent_id='Test_Agent_Test_Agent_2_Test_Agent_3'
|
||||
agent_id="Test_Agent_Test_Agent_2_Test_Agent_3",
|
||||
)
|
||||
|
||||
|
||||
def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
|
||||
def test_save_method_with_memory_client(
|
||||
mem0_storage_with_memory_client_using_config_from_crew,
|
||||
):
|
||||
"""Test save method for different memory types"""
|
||||
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
# Test short_term memory type (already set in fixture)
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {"key": "value"}
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{'role': 'assistant' , 'content': test_value}],
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
infer=True,
|
||||
metadata={"type": "short_term", "key": "value"},
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
version="v2",
|
||||
run_id="my_run_id",
|
||||
includes="include1",
|
||||
excludes="exclude1",
|
||||
output_format='v1.1',
|
||||
user_id='test_user',
|
||||
agent_id='Test_Agent'
|
||||
output_format="v1.1",
|
||||
user_id="test_user",
|
||||
agent_id="Test_Agent",
|
||||
)
|
||||
|
||||
|
||||
def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
"""Test search method for different memory types"""
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
@@ -242,18 +379,25 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
query="test query",
|
||||
limit=5,
|
||||
user_id="test_user",
|
||||
filters={'AND': [{'run_id': 'my_run_id'}]},
|
||||
threshold=0.5
|
||||
filters={"AND": [{"run_id": "my_run_id"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["context"] == "Result 1"
|
||||
assert results[0]["content"] == "Result 1"
|
||||
|
||||
|
||||
def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
|
||||
def test_search_method_with_memory_client(
|
||||
mem0_storage_with_memory_client_using_config_from_crew,
|
||||
):
|
||||
"""Test search method for different memory types"""
|
||||
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
|
||||
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
@@ -263,15 +407,15 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_
|
||||
limit=5,
|
||||
metadata={"type": "short_term"},
|
||||
user_id="test_user",
|
||||
version='v2',
|
||||
version="v2",
|
||||
run_id="my_run_id",
|
||||
output_format='v1.1',
|
||||
filters={'AND': [{'run_id': 'my_run_id'}]},
|
||||
threshold=0.5
|
||||
output_format="v1.1",
|
||||
filters={"AND": [{"run_id": "my_run_id"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["context"] == "Result 1"
|
||||
assert results[0]["content"] == "Result 1"
|
||||
|
||||
|
||||
def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
|
||||
@@ -279,14 +423,12 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
crew = MockCrew()
|
||||
|
||||
config={
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH"
|
||||
}
|
||||
config = {"user_id": "test_user", "api_key": "ABCDEFGH"}
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
assert mem0_storage.infer is True
|
||||
|
||||
|
||||
def test_save_memory_using_agent_entity(mock_mem0_memory_client):
|
||||
config = {
|
||||
"agent_id": "agent-123",
|
||||
@@ -297,19 +439,25 @@ def test_save_memory_using_agent_entity(mock_mem0_memory_client):
|
||||
mem0_storage = Mem0Storage(type="external", config=config)
|
||||
mem0_storage.save("test memory", {"key": "value"})
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{'role': 'assistant' , 'content': 'test memory'}],
|
||||
[{"role": "assistant", "content": "test memory"}],
|
||||
infer=True,
|
||||
metadata={"type": "external", "key": "value"},
|
||||
agent_id="agent-123",
|
||||
)
|


def test_search_method_with_agent_entity():
    config = {
        "agent_id": "agent-123",
    }

    mock_memory = MagicMock(spec=Memory)
    mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
    mock_results = {
        "results": [
            {"score": 0.9, "memory": "Result 1"},
            {"score": 0.4, "memory": "Result 2"},
        ]
    }

    with patch.object(Memory, "__new__", return_value=mock_memory):
        mem0_storage = Mem0Storage(type="external", config=config)
@@ -318,22 +466,29 @@ def test_search_method_with_agent_entity():
    results = mem0_storage.search("test query", limit=5, score_threshold=0.5)

    mem0_storage.memory.search.assert_called_once_with(
            query="test query",
            limit=5,
            filters={"AND": [{"agent_id": "agent-123"}]},
            threshold=0.5,
        )
        query="test query",
        limit=5,
        filters={"AND": [{"agent_id": "agent-123"}]},
        threshold=0.5,
    )

    assert len(results) == 2
    assert results[0]["context"] == "Result 1"
    assert results[0]["content"] == "Result 1"


def test_search_method_with_agent_id_and_user_id():
    mock_memory = MagicMock(spec=Memory)
    mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
    mock_results = {
        "results": [
            {"score": 0.9, "memory": "Result 1"},
            {"score": 0.4, "memory": "Result 2"},
        ]
    }

    with patch.object(Memory, "__new__", return_value=mock_memory):
        mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"})
        mem0_storage = Mem0Storage(
            type="external", config={"agent_id": "agent-123", "user_id": "user-123"}
        )

    mem0_storage.memory.search = MagicMock(return_value=mock_results)
    results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
@@ -341,10 +496,10 @@ def test_search_method_with_agent_id_and_user_id():
    mem0_storage.memory.search.assert_called_once_with(
        query="test query",
        limit=5,
        user_id='user-123',
        user_id="user-123",
        filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]},
        threshold=0.5,
    )

    assert len(results) == 2
    assert results[0]["context"] == "Result 1"
    assert results[0]["content"] == "Result 1"
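Read side by side, the two tests above imply how entity filters are assembled: a lone agent_id produces an AND clause, while a user_id plus an agent_id produces an OR across both entities (with user_id also passed as a top-level keyword). A sketch of that branching, under the assumption that this is the whole rule (build_entity_filters is an illustrative name, not part of the codebase):

def build_entity_filters(user_id: str | None, agent_id: str | None) -> dict:
    # Both entities configured: match memories belonging to either one.
    if user_id and agent_id:
        return {"OR": [{"user_id": user_id}, {"agent_id": agent_id}]}
    # A single entity: constrain the search to exactly that entity.
    if agent_id:
        return {"AND": [{"agent_id": agent_id}]}
    if user_id:
        return {"AND": [{"user_id": user_id}]}
    return {}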
216
tests/test_context.py
Normal file
@@ -0,0 +1,216 @@
# ruff: noqa: S105

import os
import pytest
from unittest.mock import patch

from crewai.context import (
    set_platform_integration_token,
    get_platform_integration_token,
    platform_context,
    _platform_integration_token,
)


class TestPlatformIntegrationToken:
    def setup_method(self):
        _platform_integration_token.set(None)

    def teardown_method(self):
        _platform_integration_token.set(None)

    def test_set_platform_integration_token(self):
        test_token = "test-token-123"

        assert get_platform_integration_token() is None

        set_platform_integration_token(test_token)

        assert get_platform_integration_token() == test_token

    def test_get_platform_integration_token_from_context_var(self):
        test_token = "context-var-token"

        _platform_integration_token.set(test_token)

        assert get_platform_integration_token() == test_token

    @patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token-456"})
    def test_get_platform_integration_token_from_env_var(self):
        assert _platform_integration_token.get() is None

        assert get_platform_integration_token() == "env-token-456"

    @patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token"})
    def test_context_var_takes_precedence_over_env_var(self):
        context_token = "context-token"

        set_platform_integration_token(context_token)

        assert get_platform_integration_token() == context_token

    @patch.dict(os.environ, {}, clear=True)
    def test_get_platform_integration_token_returns_none_when_not_set(self):
        assert _platform_integration_token.get() is None

        assert get_platform_integration_token() is None

    def test_platform_context_manager_basic_usage(self):
        test_token = "context-manager-token"

        assert get_platform_integration_token() is None

        with platform_context(test_token):
            assert get_platform_integration_token() == test_token

        assert get_platform_integration_token() is None

    def test_platform_context_manager_nested_contexts(self):
        """Test nested platform_context context managers."""
        outer_token = "outer-token"
        inner_token = "inner-token"

        assert get_platform_integration_token() is None

        with platform_context(outer_token):
            assert get_platform_integration_token() == outer_token

            with platform_context(inner_token):
                assert get_platform_integration_token() == inner_token

            assert get_platform_integration_token() == outer_token

        assert get_platform_integration_token() is None

    def test_platform_context_manager_preserves_existing_token(self):
        """Test that platform_context preserves existing token when exiting."""
        initial_token = "initial-token"
        context_token = "context-token"

        set_platform_integration_token(initial_token)
        assert get_platform_integration_token() == initial_token

        with platform_context(context_token):
            assert get_platform_integration_token() == context_token

        assert get_platform_integration_token() == initial_token

    def test_platform_context_manager_exception_handling(self):
        """Test that platform_context properly resets token even when exception occurs."""
        initial_token = "initial-token"
        context_token = "context-token"

        set_platform_integration_token(initial_token)

        with pytest.raises(ValueError):
            with platform_context(context_token):
                assert get_platform_integration_token() == context_token
                raise ValueError("Test exception")

        assert get_platform_integration_token() == initial_token

    def test_platform_context_manager_with_none_initial_state(self):
        """Test platform_context when initial state is None."""
        context_token = "context-token"

        assert get_platform_integration_token() is None

        with pytest.raises(RuntimeError):
            with platform_context(context_token):
                assert get_platform_integration_token() == context_token
                raise RuntimeError("Test exception")

        assert get_platform_integration_token() is None

    @patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-backup"})
    def test_platform_context_with_env_fallback(self):
        """Test platform_context interaction with environment variable fallback."""
        context_token = "context-token"

        assert get_platform_integration_token() == "env-backup"

        with platform_context(context_token):
            assert get_platform_integration_token() == context_token

        assert get_platform_integration_token() == "env-backup"

    def test_multiple_sequential_context_managers(self):
        """Test multiple sequential uses of platform_context."""
        token1 = "token-1"
        token2 = "token-2"
        token3 = "token-3"

        with platform_context(token1):
            assert get_platform_integration_token() == token1

        assert get_platform_integration_token() is None

        with platform_context(token2):
            assert get_platform_integration_token() == token2

        assert get_platform_integration_token() is None

        with platform_context(token3):
            assert get_platform_integration_token() == token3

        assert get_platform_integration_token() is None

    def test_empty_string_token(self):
        empty_token = ""

        set_platform_integration_token(empty_token)
        assert get_platform_integration_token() == ""

        with platform_context(empty_token):
            assert get_platform_integration_token() == ""

    def test_special_characters_in_token(self):
        special_token = "token-with-!@#$%^&*()_+-={}[]|\\:;\"'<>?,./"

        set_platform_integration_token(special_token)
        assert get_platform_integration_token() == special_token

        with platform_context(special_token):
            assert get_platform_integration_token() == special_token

    def test_very_long_token(self):
        long_token = "a" * 10000

        set_platform_integration_token(long_token)
        assert get_platform_integration_token() == long_token

        with platform_context(long_token):
            assert get_platform_integration_token() == long_token

    @patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": ""})
    def test_empty_env_var(self):
        assert _platform_integration_token.get() is None
        assert get_platform_integration_token() == ""

    @patch('crewai.context.os.getenv')
    def test_env_var_access_error_handling(self, mock_getenv):
        mock_getenv.side_effect = OSError("Environment access error")

        with pytest.raises(OSError):
            get_platform_integration_token()

    def test_context_var_isolation_between_tests(self):
        """Test that context variable changes don't leak between test methods."""
        test_token = "isolation-test-token"

        assert get_platform_integration_token() is None

        set_platform_integration_token(test_token)
        assert get_platform_integration_token() == test_token

    def test_context_manager_return_value(self):
        """Test that platform_context can be used in with statement with return value."""
        test_token = "return-value-token"

        with platform_context(test_token):
            assert get_platform_integration_token() == test_token

        with platform_context(test_token) as ctx:
            assert ctx is None
            assert get_platform_integration_token() == test_token
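The new test file exercises a small token-context API. What the tests collectively imply is a ContextVar-backed token that takes precedence over the CREWAI_PLATFORM_INTEGRATION_TOKEN environment variable, plus a context manager that yields None and restores the previous value even when an exception escapes. A sketch consistent with every assertion above (an illustrative reconstruction, not the actual crewai.context source):

import os
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator

_platform_integration_token: ContextVar[str | None] = ContextVar(
    "platform_integration_token", default=None
)


def set_platform_integration_token(token: str) -> None:
    _platform_integration_token.set(token)


def get_platform_integration_token() -> str | None:
    token = _platform_integration_token.get()
    if token is not None:  # an empty string still wins over the env var
        return token
    return os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")


@contextmanager
def platform_context(token: str) -> Iterator[None]:
    reset_token = _platform_integration_token.set(token)
    try:
        yield  # yields None, matching the `as ctx: assert ctx is None` test
    finally:
        # reset() restores the previous value, which makes nesting and
        # exception safety fall out for free.
        _platform_integration_token.reset(reset_token)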
@@ -1,17 +1,20 @@
import os
from unittest.mock import MagicMock, Mock, patch

import pytest
from unittest.mock import patch, MagicMock


from crewai import Agent, Task, Crew
from crewai.flow.flow import Flow, start
from crewai.events.listeners.tracing.trace_listener import (
    TraceCollectionListener,
from crewai import Agent, Crew, Task
from crewai.events.listeners.tracing.first_time_trace_handler import (
    FirstTimeTraceHandler,
)
from crewai.events.listeners.tracing.trace_batch_manager import (
    TraceBatchManager,
)
from crewai.events.listeners.tracing.trace_listener import (
    TraceCollectionListener,
)
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.flow.flow import Flow, start


class TestTraceListenerSetup:
@@ -281,9 +284,9 @@ class TestTraceListenerSetup:
        ):
            trace_handlers.append(handler)

        assert (
            len(trace_handlers) == 0
        ), f"Found {len(trace_handlers)} trace handlers when tracing should be disabled"
        assert len(trace_handlers) == 0, (
            f"Found {len(trace_handlers)} trace handlers when tracing should be disabled"
        )

    def test_trace_listener_setup_correctly_for_crew(self):
        """Test that trace listener is set up correctly when enabled"""
@@ -403,3 +406,254 @@ class TestTraceListenerSetup:
        from crewai.events.event_bus import crewai_event_bus

        crewai_event_bus._handlers.clear()

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_first_time_user_trace_collection_with_timeout(self, mock_plus_api_calls):
        """Test first-time user trace collection logic with timeout behavior"""

        with (
            patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
            patch(
                "crewai.events.listeners.tracing.utils._is_test_environment",
                return_value=False,
            ),
            patch(
                "crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.utils.is_first_execution",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
                return_value=False,
            ) as mock_prompt,
            patch(
                "crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
            ) as mock_mark_completed,
        ):
            agent = Agent(
                role="Test Agent",
                goal="Test goal",
                backstory="Test backstory",
                llm="gpt-4o-mini",
            )
            task = Task(
                description="Say hello to the world",
                expected_output="hello world",
                agent=agent,
            )
            crew = Crew(agents=[agent], tasks=[task], verbose=True)

            from crewai.events.event_bus import crewai_event_bus

            trace_listener = TraceCollectionListener()
            trace_listener.setup_listeners(crewai_event_bus)

            assert trace_listener.first_time_handler.is_first_time is True
            assert trace_listener.first_time_handler.collected_events is False

            with (
                patch.object(
                    trace_listener.first_time_handler,
                    "handle_execution_completion",
                    wraps=trace_listener.first_time_handler.handle_execution_completion,
                ) as mock_handle_completion,
                patch.object(
                    trace_listener.batch_manager,
                    "add_event",
                    wraps=trace_listener.batch_manager.add_event,
                ) as mock_add_event,
            ):
                result = crew.kickoff()
                assert result is not None

                assert mock_handle_completion.call_count >= 1
                assert mock_add_event.call_count >= 1

                assert trace_listener.first_time_handler.collected_events is True

            mock_prompt.assert_called_once_with(timeout_seconds=20)

            mock_mark_completed.assert_called_once()
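This test fixes the completion protocol for first-time users: after a run that collected events, the user is prompted with a 20-second timeout and the first execution is marked completed regardless of the answer. A hedged sketch of that flow (the body is an illustrative reconstruction; the two callables are the patched targets above, passed in here only to keep the sketch self-contained):

from typing import Any, Callable


def handle_execution_completion_sketch(
    handler: Any,
    prompt_user_for_trace_viewing: Callable[..., bool],
    mark_first_execution_completed: Callable[[], None],
) -> None:
    if not (handler.is_first_time and handler.collected_events):
        return
    try:
        if prompt_user_for_trace_viewing(timeout_seconds=20):
            # Only ship the collected batch if the user opted in.
            handler._initialize_backend_and_send_events()
            handler._display_ephemeral_trace_link()
    except Exception:
        pass  # tracing must never break the user's run (see the error-handling test below)
    finally:
        mark_first_execution_completed()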
    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_first_time_user_trace_collection_user_accepts(self, mock_plus_api_calls):
        """Test first-time user trace collection when user accepts viewing traces"""

        with (
            patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
            patch(
                "crewai.events.listeners.tracing.utils._is_test_environment",
                return_value=False,
            ),
            patch(
                "crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.utils.is_first_execution",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
            ) as mock_mark_completed,
        ):
            agent = Agent(
                role="Test Agent",
                goal="Test goal",
                backstory="Test backstory",
                llm="gpt-4o-mini",
            )
            task = Task(
                description="Say hello to the world",
                expected_output="hello world",
                agent=agent,
            )
            crew = Crew(agents=[agent], tasks=[task], verbose=True)

            from crewai.events.event_bus import crewai_event_bus

            trace_listener = TraceCollectionListener()
            trace_listener.setup_listeners(crewai_event_bus)

            assert trace_listener.first_time_handler.is_first_time is True

            with (
                patch.object(
                    trace_listener.first_time_handler,
                    "_initialize_backend_and_send_events",
                    wraps=trace_listener.first_time_handler._initialize_backend_and_send_events,
                ) as mock_init_backend,
                patch.object(
                    trace_listener.first_time_handler, "_display_ephemeral_trace_link"
                ) as mock_display_link,
                patch.object(
                    trace_listener.first_time_handler,
                    "handle_execution_completion",
                    wraps=trace_listener.first_time_handler.handle_execution_completion,
                ) as mock_handle_completion,
            ):
                trace_listener.batch_manager.ephemeral_trace_url = (
                    "https://crewai.com/trace/mock-id"
                )

                crew.kickoff()

                assert mock_handle_completion.call_count >= 1, (
                    "handle_execution_completion should be called"
                )

                assert trace_listener.first_time_handler.collected_events is True, (
                    "Events should be marked as collected"
                )

                mock_init_backend.assert_called_once()

                mock_display_link.assert_called_once()

            mock_mark_completed.assert_called_once()

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_first_time_user_trace_consolidation_logic(self, mock_plus_api_calls):
        """Test the consolidation logic for first-time users vs regular tracing"""

        with (
            patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
            patch(
                "crewai.events.listeners.tracing.utils._is_test_environment",
                return_value=False,
            ),
            patch(
                "crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.utils.is_first_execution",
                return_value=True,
            ),
        ):
            from crewai.events.event_bus import crewai_event_bus

            crewai_event_bus._handlers.clear()

            trace_listener = TraceCollectionListener()
            trace_listener.setup_listeners(crewai_event_bus)

            assert trace_listener.first_time_handler.is_first_time is True

            agent = Agent(
                role="Test Agent",
                goal="Test goal",
                backstory="Test backstory",
                llm="gpt-4o-mini",
            )
            task = Task(
                description="Test task", expected_output="test output", agent=agent
            )
            crew = Crew(agents=[agent], tasks=[task])

            with patch.object(TraceBatchManager, "initialize_batch") as mock_initialize:
                result = crew.kickoff()

                assert mock_initialize.call_count >= 1
                assert mock_initialize.call_args_list[0][1]["use_ephemeral"] is True
                assert result is not None

    def test_first_time_handler_timeout_behavior(self):
        """Test the timeout behavior of the first-time trace prompt"""

        with (
            patch(
                "crewai.events.listeners.tracing.utils._is_test_environment",
                return_value=False,
            ),
            patch("threading.Thread") as mock_thread,
        ):
            from crewai.events.listeners.tracing.utils import (
                prompt_user_for_trace_viewing,
            )

            mock_thread_instance = Mock()
            mock_thread_instance.is_alive.return_value = True
            mock_thread.return_value = mock_thread_instance

            result = prompt_user_for_trace_viewing(timeout_seconds=5)

            assert result is False
            mock_thread.assert_called_once()
            call_args = mock_thread.call_args
            assert call_args[1]["daemon"] is True

            mock_thread_instance.start.assert_called_once()
            mock_thread_instance.join.assert_called_once_with(timeout=5)
            mock_thread_instance.is_alive.assert_called_once()

    def test_first_time_handler_graceful_error_handling(self):
        """Test graceful error handling in first-time trace logic"""

        with (
            patch(
                "crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
                side_effect=Exception("Prompt failed"),
            ),
            patch(
                "crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
            ) as mock_mark_completed,
        ):
            handler = FirstTimeTraceHandler()
            handler.is_first_time = True
            handler.collected_events = True

            handler.handle_execution_completion()

            mock_mark_completed.assert_called_once()
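The timeout test above constrains prompt_user_for_trace_viewing precisely: it must start a daemon thread, join it with the given timeout, and return False when the thread is still alive afterwards. A sketch that satisfies those constraints (an illustration of the contract, not the shipped utils implementation):

import threading


def prompt_user_for_trace_viewing_sketch(timeout_seconds: int = 20) -> bool:
    answer: list[bool] = []

    def ask() -> None:
        try:
            reply = input("View your execution trace? [y/N]: ")
            answer.append(reply.strip().lower() in ("y", "yes"))
        except EOFError:
            answer.append(False)

    thread = threading.Thread(target=ask, daemon=True)
    thread.start()
    thread.join(timeout=timeout_seconds)
    if thread.is_alive():
        return False  # no answer within the timeout: default to not viewing
    return bool(answer and answer[0])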
@@ -1,123 +0,0 @@
import multiprocessing
import tempfile
import unittest

from chromadb.config import Settings
from unittest.mock import patch, MagicMock

from crewai.utilities.chromadb import (
    MAX_COLLECTION_LENGTH,
    MIN_COLLECTION_LENGTH,
    is_ipv4_pattern,
    sanitize_collection_name,
    create_persistent_client,
)


def persistent_client_worker(path, queue):
    try:
        create_persistent_client(path=path)
        queue.put(None)
    except Exception as e:
        queue.put(e)


class TestChromadbUtils(unittest.TestCase):
    def test_sanitize_collection_name_long_name(self):
        """Test sanitizing a very long collection name."""
        long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
        sanitized = sanitize_collection_name(long_name)
        self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())
        self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))

    def test_sanitize_collection_name_special_chars(self):
        """Test sanitizing a name with special characters."""
        special_chars = "Agent@123!#$%^&*()"
        sanitized = sanitize_collection_name(special_chars)
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())
        self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))

    def test_sanitize_collection_name_short_name(self):
        """Test sanitizing a very short name."""
        short_name = "A"
        sanitized = sanitize_collection_name(short_name)
        self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())

    def test_sanitize_collection_name_bad_ends(self):
        """Test sanitizing a name with non-alphanumeric start/end."""
        bad_ends = "_Agent_"
        sanitized = sanitize_collection_name(bad_ends)
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())

    def test_sanitize_collection_name_none(self):
        """Test sanitizing a None value."""
        sanitized = sanitize_collection_name(None)
        self.assertEqual(sanitized, "default_collection")

    def test_sanitize_collection_name_ipv4_pattern(self):
        """Test sanitizing an IPv4 address."""
        ipv4 = "192.168.1.1"
        sanitized = sanitize_collection_name(ipv4)
        self.assertTrue(sanitized.startswith("ip_"))
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())
        self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))

    def test_is_ipv4_pattern(self):
        """Test IPv4 pattern detection."""
        self.assertTrue(is_ipv4_pattern("192.168.1.1"))
        self.assertFalse(is_ipv4_pattern("not.an.ip.address"))

    def test_sanitize_collection_name_properties(self):
        """Test that sanitized collection names always meet ChromaDB requirements."""
        test_cases = [
            "A" * 100,  # Very long name
            "_start_with_underscore",
            "end_with_underscore_",
            "contains@special#characters",
            "192.168.1.1",  # IPv4 address
            "a" * 2,  # Too short
        ]
        for test_case in test_cases:
            sanitized = sanitize_collection_name(test_case)
            self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
            self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
            self.assertTrue(sanitized[0].isalnum())
            self.assertTrue(sanitized[-1].isalnum())

    def test_create_persistent_client_passes_args(self):
        with patch(
            "crewai.utilities.chromadb.PersistentClient"
        ) as mock_persistent_client, tempfile.TemporaryDirectory() as tmpdir:
            mock_instance = MagicMock()
            mock_persistent_client.return_value = mock_instance

            settings = Settings(allow_reset=True)
            client = create_persistent_client(path=tmpdir, settings=settings)

            mock_persistent_client.assert_called_once_with(
                path=tmpdir, settings=settings
            )
            self.assertIs(client, mock_instance)

    def test_create_persistent_client_process_safe(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            queue = multiprocessing.Queue()
            processes = [
                multiprocessing.Process(
                    target=persistent_client_worker, args=(tmpdir, queue)
                )
                for _ in range(5)
            ]

            [p.start() for p in processes]
            [p.join() for p in processes]

            errors = [queue.get(timeout=5) for _ in processes]
            self.assertTrue(all(err is None for err in errors))
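The deleted chromadb test module above documented the sanitizer's invariants: 3-63 characters, only [a-zA-Z0-9_-], alphanumeric first and last characters, an ip_ prefix for IPv4-shaped names, and "default_collection" for None. A sketch that satisfies those invariants (an illustration of the contract the tests checked, not the removed implementation):

import re

MIN_COLLECTION_LENGTH = 3
MAX_COLLECTION_LENGTH = 63
_IPV4 = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")


def sanitize_collection_name_sketch(name: str | None) -> str:
    if name is None:
        return "default_collection"
    if _IPV4.match(name):
        name = f"ip_{name}"  # avoid bare IPv4-looking collection names
    name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:MAX_COLLECTION_LENGTH]
    name = name.ljust(MIN_COLLECTION_LENGTH, "x")  # pad very short names
    if not name[0].isalnum():
        name = "a" + name[1:]
    if not name[-1].isalnum():
        name = name[:-1] + "z"
    return name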
@@ -29,13 +29,15 @@ def mock_knowledge_source():
    """
    return StringKnowledgeSource(content=content)

@patch('crewai.knowledge.storage.knowledge_storage.chromadb')
def test_knowledge_included_in_planning(mock_chroma):

@patch("crewai.rag.config.utils.get_rag_client")
def test_knowledge_included_in_planning(mock_get_client):
    """Test that verifies knowledge sources are properly included in planning."""
    # Mock ChromaDB collection
    mock_collection = mock_chroma.return_value.get_or_create_collection.return_value
    mock_collection.add.return_value = None

    # Mock RAG client
    mock_client = mock_get_client.return_value
    mock_client.get_or_create_collection.return_value = None
    mock_client.add_documents.return_value = None

    # Create an agent with knowledge
    agent = Agent(
        role="AI Researcher",
@@ -45,14 +47,14 @@ def test_knowledge_included_in_planning(mock_chroma):
            StringKnowledgeSource(
                content="AI systems require careful training and validation."
            )
        ]
        ],
    )

    # Create a task for the agent
    task = Task(
        description="Explain the basics of AI systems",
        expected_output="A clear explanation of AI fundamentals",
        agent=agent
        agent=agent,
    )

    # Create a crew planner
@@ -62,23 +64,29 @@ def test_knowledge_included_in_planning(mock_chroma):
    task_summary = planner._create_tasks_summary()

    # Verify that knowledge is included in planning when present
    assert "AI systems require careful training" in task_summary, \
    assert "AI systems require careful training" in task_summary, (
        "Knowledge content should be present in task summary when knowledge exists"
    assert '"agent_knowledge"' in task_summary, \
    )
    assert '"agent_knowledge"' in task_summary, (
        "agent_knowledge field should be present in task summary when knowledge exists"
    )

    # Verify that knowledge is properly formatted
    assert isinstance(task.agent.knowledge_sources, list), \
    assert isinstance(task.agent.knowledge_sources, list), (
        "Knowledge sources should be stored in a list"
    assert len(task.agent.knowledge_sources) > 0, \
    )
    assert len(task.agent.knowledge_sources) > 0, (
        "At least one knowledge source should be present"
    assert task.agent.knowledge_sources[0].content in task_summary, \
    )
    assert task.agent.knowledge_sources[0].content in task_summary, (
        "Knowledge source content should be included in task summary"
    )

    # Verify that other expected components are still present
    assert task.description in task_summary, \
    assert task.description in task_summary, (
        "Task description should be present in task summary"
    assert task.expected_output in task_summary, \
    )
    assert task.expected_output in task_summary, (
        "Expected output should be present in task summary"
    assert agent.role in task_summary, \
        "Agent role should be present in task summary"
    )
    assert agent.role in task_summary, "Agent role should be present in task summary"
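The planning assertions above rely on the task summary embedding the description, expected output, agent role, and an "agent_knowledge" field whenever knowledge sources exist. A hypothetical sketch of that summary shape (illustrative only; CrewPlanner._create_tasks_summary's real output format is not shown in this diff):

import json


def create_task_summary_sketch(description: str, expected_output: str,
                               agent_role: str, knowledge: list[str]) -> str:
    summary: dict[str, object] = {
        "task_description": description,
        "task_expected_output": expected_output,
        "agent": agent_role,
    }
    if knowledge:
        # Surfacing knowledge content lets the planner reason with it.
        summary["agent_knowledge"] = knowledge
    return json.dumps(summary, indent=2)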