Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-08 07:38:29 +00:00

Compare commits: devin/1758... with gl/chore/p... (3 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 777cb74ab9 | |
| | d0641a8084 | |
| | c7d80348ec | |
.github/workflows/build-uv-cache.yml (vendored, 46 lines)
@@ -1,46 +0,0 @@
name: Build uv cache

on:
  push:
    branches:
      - main
    paths:
      - "uv.lock"
      - "pyproject.toml"
  workflow_dispatch:

permissions:
  contents: read

jobs:
  build-cache:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12", "3.13"]

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v6
        with:
          version: "0.8.4"
          python-version: ${{ matrix.python-version }}
          enable-cache: false

      - name: Install dependencies and populate cache
        run: |
          echo "Building global UV cache for Python ${{ matrix.python-version }}..."
          uv sync --all-groups --all-extras --no-install-project
          echo "Cache populated successfully"

      - name: Save uv caches
        uses: actions/cache/save@v4
        with:
          path: |
            ~/.cache/uv
            ~/.local/share/uv
            .venv
          key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
.github/workflows/codeql.yml (vendored, 102 lines)
@@ -1,102 +0,0 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL Advanced"

on:
  push:
    branches: [ "main" ]
    paths-ignore:
      - "src/crewai/cli/templates/**"
  pull_request:
    branches: [ "main" ]
    paths-ignore:
      - "src/crewai/cli/templates/**"

jobs:
  analyze:
    name: Analyze (${{ matrix.language }})
    # Runner size impacts CodeQL analysis time. To learn more, please see:
    #   - https://gh.io/recommended-hardware-resources-for-running-codeql
    #   - https://gh.io/supported-runners-and-hardware-resources
    #   - https://gh.io/using-larger-runners (GitHub.com only)
    # Consider using larger runners or machines with greater resources for possible analysis time improvements.
    runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
    permissions:
      # required for all workflows
      security-events: write

      # required to fetch internal or private CodeQL packs
      packages: read

      # only required for workflows in private repositories
      actions: read
      contents: read

    strategy:
      fail-fast: false
      matrix:
        include:
          - language: actions
            build-mode: none
          - language: python
            build-mode: none
        # CodeQL supports the following values keywords for 'language': 'actions', 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'rust', 'swift'
        # Use `c-cpp` to analyze code written in C, C++ or both
        # Use 'java-kotlin' to analyze code written in Java, Kotlin or both
        # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
        # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
        # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
        # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
        # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      # Add any setup steps before running the `github/codeql-action/init` action.
      # This includes steps like installing compilers or runtimes (`actions/setup-node`
      # or others). This is typically only required for manual builds.
      # - name: Setup runtime (example)
      #   uses: actions/setup-example@v1

      # Initializes the CodeQL tools for scanning.
      - name: Initialize CodeQL
        uses: github/codeql-action/init@v3
        with:
          languages: ${{ matrix.language }}
          build-mode: ${{ matrix.build-mode }}
          # If you wish to specify custom queries, you can do so here or in a config file.
          # By default, queries listed here will override any specified in a config file.
          # Prefix the list here with "+" to use these queries and those in the config file.

          # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
          # queries: security-extended,security-and-quality

      # If the analyze step fails for one of the languages you are analyzing with
      # "We were unable to automatically build your code", modify the matrix above
      # to set the build mode to "manual" for that language. Then modify this step
      # to build your code.
      # ℹ️ Command-line programs to run using the OS shell.
      # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
      - if: matrix.build-mode == 'manual'
        shell: bash
        run: |
          echo 'If you are using a "manual" build mode for one or more of the' \
            'languages you are analyzing, replace this with the commands to build' \
            'your code, for example:'
          echo '  make bootstrap'
          echo '  make release'
          exit 1

      - name: Perform CodeQL Analysis
        uses: github/codeql-action/analyze@v3
        with:
          category: "/language:${{matrix.language}}"
.github/workflows/linter.yml (vendored, 37 lines)
@@ -2,9 +2,6 @@ name: Lint

on: [pull_request]

permissions:
  contents: read

jobs:
  lint:
    runs-on: ubuntu-latest
@@ -18,27 +15,19 @@ jobs:
      - name: Fetch Target Branch
        run: git fetch origin $TARGET_BRANCH --depth=1

      - name: Restore global uv cache
        id: cache-restore
        uses: actions/cache/restore@v4
        with:
          path: |
            ~/.cache/uv
            ~/.local/share/uv
            .venv
          key: uv-main-py3.11-${{ hashFiles('uv.lock') }}
          restore-keys: |
            uv-main-py3.11-

      - name: Install uv
        uses: astral-sh/setup-uv@v6
        with:
          version: "0.8.4"
          python-version: "3.11"
          enable-cache: false
          enable-cache: true
          cache-dependency-glob: |
            **/pyproject.toml
            **/uv.lock

      - name: Set up Python
        run: uv python install 3.11

      - name: Install dependencies
        run: uv sync --all-groups --all-extras --no-install-project
        run: uv sync --dev --no-install-project

      - name: Get Changed Python Files
        id: changed-files
@@ -56,13 +45,3 @@ jobs:
            | tr ' ' '\n' \
            | grep -v 'src/crewai/cli/templates/' \
            | xargs -I{} uv run ruff check "{}"

      - name: Save uv caches
        if: steps.cache-restore.outputs.cache-hit != 'true'
        uses: actions/cache/save@v4
        with:
          path: |
            ~/.cache/uv
            ~/.local/share/uv
            .venv
          key: uv-main-py3.11-${{ hashFiles('uv.lock') }}
.github/workflows/security-checker.yml (vendored, new file, 29 lines)
@@ -0,0 +1,29 @@
name: Security Checker

on: [pull_request]

jobs:
  security-check:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v6
        with:
          enable-cache: true
          cache-dependency-glob: |
            **/pyproject.toml
            **/uv.lock

      - name: Set up Python
        run: uv python install 3.11

      - name: Install dependencies
        run: uv sync --dev --no-install-project

      - name: Run Bandit
        run: uv run bandit -c pyproject.toml -r src/ -ll
.github/workflows/tests.yml (vendored, 65 lines)
@@ -3,7 +3,7 @@ name: Run Tests

on: [pull_request]

permissions:
  contents: read
  contents: write

env:
  OPENAI_API_KEY: fake-api-key
@@ -22,76 +22,29 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0 # Fetch all history for proper diff

      - name: Restore global uv cache
        id: cache-restore
        uses: actions/cache/restore@v4
        with:
          path: |
            ~/.cache/uv
            ~/.local/share/uv
            .venv
          key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
          restore-keys: |
            uv-main-py${{ matrix.python-version }}-

      - name: Install uv
        uses: astral-sh/setup-uv@v6
        with:
          version: "0.8.4"
          python-version: ${{ matrix.python-version }}
          enable-cache: false
          enable-cache: true
          cache-dependency-glob: |
            **/pyproject.toml
            **/uv.lock

      - name: Set up Python ${{ matrix.python-version }}
        run: uv python install ${{ matrix.python-version }}

      - name: Install the project
        run: uv sync --all-groups --all-extras

      - name: Restore test durations
        uses: actions/cache/restore@v4
        with:
          path: .test_durations_py*
          key: test-durations-py${{ matrix.python-version }}
        run: uv sync --dev --all-extras

      - name: Run tests (group ${{ matrix.group }} of 8)
        run: |
          PYTHON_VERSION_SAFE=$(echo "${{ matrix.python-version }}" | tr '.' '_')
          DURATION_FILE=".test_durations_py${PYTHON_VERSION_SAFE}"

          # Temporarily always skip cached durations to fix test splitting
          # When durations don't match, pytest-split runs duplicate tests instead of splitting
          echo "Using even test splitting (duration cache disabled until fix merged)"
          DURATIONS_ARG=""

          # Original logic (disabled temporarily):
          # if [ ! -f "$DURATION_FILE" ]; then
          #   echo "No cached durations found, tests will be split evenly"
          #   DURATIONS_ARG=""
          # elif git diff origin/${{ github.base_ref }}...HEAD --name-only 2>/dev/null | grep -q "^tests/.*\.py$"; then
          #   echo "Test files have changed, skipping cached durations to avoid mismatches"
          #   DURATIONS_ARG=""
          # else
          #   echo "No test changes detected, using cached test durations for optimal splitting"
          #   DURATIONS_ARG="--durations-path=${DURATION_FILE}"
          # fi

          uv run pytest \
            --block-network \
            --timeout=30 \
            -vv \
            --splits 8 \
            --group ${{ matrix.group }} \
            $DURATIONS_ARG \
            --durations=10 \
            -n auto \
            --maxfail=3

      - name: Save uv caches
        if: steps.cache-restore.outputs.cache-hit != 'true'
        uses: actions/cache/save@v4
        with:
          path: |
            ~/.cache/uv
            ~/.local/share/uv
            .venv
          key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
.github/workflows/type-checker.yml (vendored, 36 lines)
@@ -3,7 +3,7 @@ name: Run Type Checks

on: [pull_request]

permissions:
  contents: read
  contents: write

jobs:
  type-checker-matrix:
@@ -20,27 +20,19 @@ jobs:
        with:
          fetch-depth: 0 # Fetch all history for proper diff

      - name: Restore global uv cache
        id: cache-restore
        uses: actions/cache/restore@v4
        with:
          path: |
            ~/.cache/uv
            ~/.local/share/uv
            .venv
          key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
          restore-keys: |
            uv-main-py${{ matrix.python-version }}-

      - name: Install uv
        uses: astral-sh/setup-uv@v6
        with:
          version: "0.8.4"
          python-version: ${{ matrix.python-version }}
          enable-cache: false
          enable-cache: true
          cache-dependency-glob: |
            **/pyproject.toml
            **/uv.lock

      - name: Set up Python ${{ matrix.python-version }}
        run: uv python install ${{ matrix.python-version }}

      - name: Install dependencies
        run: uv sync --all-groups --all-extras
        run: uv sync --dev --no-install-project

      - name: Get changed Python files
        id: changed-files
@@ -74,16 +66,6 @@ jobs:
        if: steps.changed-files.outputs.has_changes == 'false'
        run: echo "No Python files in src/ were modified - skipping type checks"

      - name: Save uv caches
        if: steps.cache-restore.outputs.cache-hit != 'true'
        uses: actions/cache/save@v4
        with:
          path: |
            ~/.cache/uv
            ~/.local/share/uv
            .venv
          key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}

  # Summary job to provide single status for branch protection
  type-checker:
    name: type-checker
.github/workflows/update-test-durations.yml (vendored, 71 lines)
@@ -1,71 +0,0 @@
name: Update Test Durations

on:
  push:
    branches:
      - main
    paths:
      - 'tests/**/*.py'
  workflow_dispatch:

permissions:
  contents: read

jobs:
  update-durations:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.10', '3.11', '3.12', '3.13']
    env:
      OPENAI_API_KEY: fake-api-key
      PYTHONUNBUFFERED: 1

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Restore global uv cache
        id: cache-restore
        uses: actions/cache/restore@v4
        with:
          path: |
            ~/.cache/uv
            ~/.local/share/uv
            .venv
          key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
          restore-keys: |
            uv-main-py${{ matrix.python-version }}-

      - name: Install uv
        uses: astral-sh/setup-uv@v6
        with:
          version: "0.8.4"
          python-version: ${{ matrix.python-version }}
          enable-cache: false

      - name: Install the project
        run: uv sync --all-groups --all-extras

      - name: Run all tests and store durations
        run: |
          PYTHON_VERSION_SAFE=$(echo "${{ matrix.python-version }}" | tr '.' '_')
          uv run pytest --store-durations --durations-path=.test_durations_py${PYTHON_VERSION_SAFE} -n auto
        continue-on-error: true

      - name: Save durations to cache
        if: always()
        uses: actions/cache/save@v4
        with:
          path: .test_durations_py*
          key: test-durations-py${{ matrix.python-version }}

      - name: Save uv caches
        if: steps.cache-restore.outputs.cache-hit != 'true'
        uses: actions/cache/save@v4
        with:
          path: |
            ~/.cache/uv
            ~/.local/share/uv
            .venv
          key: uv-main-py${{ matrix.python-version }}-${{ hashFiles('uv.lock') }}
@@ -1,19 +1,14 @@
repos:
  - repo: local
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.12.11
    hooks:
      - id: ruff
        name: ruff
        entry: uv run ruff check
        language: system
        types: [python]
        args: ["--config", "pyproject.toml"]
      - id: ruff-format
        name: ruff-format
        entry: uv run ruff format
        language: system
        types: [python]
        args: ["--config", "pyproject.toml"]

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.17.1
    hooks:
      - id: mypy
        name: mypy
        entry: uv run mypy
        language: system
        types: [python]
        exclude: ^tests/
        args: ["--config-file", "pyproject.toml"]
@@ -7,7 +7,7 @@ mode: "wide"

## Overview

The CrewAI framework provides a sophisticated memory system designed to significantly enhance AI agent capabilities. CrewAI offers **two distinct memory approaches** that serve different use cases:
The CrewAI framework provides a sophisticated memory system designed to significantly enhance AI agent capabilities. CrewAI offers **three distinct memory approaches** that serve different use cases:

1. **Basic Memory System** - Built-in short-term, long-term, and entity memory
2. **External Memory** - Standalone external memory providers
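As an orientation for the basic approach listed above, here is a minimal sketch, assuming the standard `memory=True` flag on `Crew`; the roles, task text, and names are illustrative placeholders, not part of this diff:

```python
from crewai import Agent, Crew, Task

analyst = Agent(
    role="Research Analyst",
    goal="Collect and summarize background information",
    backstory="A methodical analyst.",
)

task = Task(
    description="Summarize the key findings on the given topic.",
    expected_output="A short summary.",
    agent=analyst,
)

# memory=True turns on the built-in short-term, long-term, and entity memory.
crew = Crew(agents=[analyst], tasks=[task], memory=True)
```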
@@ -142,7 +142,7 @@ with MCPServerAdapter(server_params, "tool_name", connect_timeout=60) as mcp_too

## Using with CrewBase

To use MCPServer tools within a CrewBase class, use the `get_mcp_tools` method. Server configurations should be provided via the `mcp_server_params` attribute. You can pass either a single configuration or a list of multiple server configurations.
To use MCPServer tools within a CrewBase class, use the `mcp_tools` method. Server configurations should be provided via the `mcp_server_params` attribute. You can pass either a single configuration or a list of multiple server configurations.

```python
@CrewBase
@@ -175,34 +175,6 @@ class CrewWithMCP:
    # ... rest of your crew setup ...
```

### Connection Timeout Configuration

You can configure the connection timeout for MCP servers by setting the `mcp_connect_timeout` class attribute. If no timeout is specified, it defaults to 30 seconds.

```python
@CrewBase
class CrewWithMCP:
    mcp_server_params = [...]
    mcp_connect_timeout = 60  # 60 seconds timeout for all MCP connections

    @agent
    def your_agent(self):
        return Agent(config=self.agents_config["your_agent"], tools=self.get_mcp_tools())
```

```python
@CrewBase
class CrewWithDefaultTimeout:
    mcp_server_params = [...]
    # No mcp_connect_timeout specified - uses default 30 seconds

    @agent
    def your_agent(self):
        return Agent(config=self.agents_config["your_agent"], tools=self.get_mcp_tools())
```

### Filtering Tools

You can filter which tools are available to your agent by passing a list of tool names to the `get_mcp_tools` method.

```python
@@ -214,22 +186,6 @@ def another_agent(self):
    )
```

The timeout configuration applies to all MCP tool calls within the crew:

```python
@CrewBase
class CrewWithCustomTimeout:
    mcp_server_params = [...]
    mcp_connect_timeout = 90  # 90 seconds timeout for all MCP connections

    @agent
    def filtered_agent(self):
        return Agent(
            config=self.agents_config["your_agent"],
            tools=self.get_mcp_tools("tool_1", "tool_2")  # specific tools with custom timeout
        )
```

## Explore MCP Integrations

<CardGroup cols={2}>
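As an aside, a minimal standalone sketch of the adapter pattern referenced in the hunk header above; the server command and tool name are placeholders, and the `MCPServerAdapter` / `StdioServerParameters` usage mirrors the crewai-tools docs rather than anything introduced in this diff:

```python
from crewai import Agent
from crewai_tools import MCPServerAdapter
from mcp import StdioServerParameters

# Placeholder MCP server started over stdio; replace with a real server command.
server_params = StdioServerParameters(command="python", args=["my_mcp_server.py"])

# Filter to a single tool and allow a slower startup, as in the hunk context above.
with MCPServerAdapter(server_params, "tool_name", connect_timeout=60) as mcp_tools:
    agent = Agent(
        role="MCP-powered assistant",
        goal="Use the exposed MCP tool to answer questions",
        backstory="An agent wired to an external MCP server.",
        tools=mcp_tools,
    )
```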
@@ -7,8 +7,8 @@ mode: "wide"

## 개요

[Model Context Protocol](https://modelcontextprotocol.io/introduction) (MCP)는 AI 에이전트가 MCP 서버로 알려진 외부 서비스와 통신함으로써 LLM에 컨텍스트를 제공할 수 있도록 표준화된 방식을 제공합니다.
`crewai-tools` 라이브러리는 CrewAI의 기능을 확장하여, 이러한 MCP 서버에서 제공하는 툴을 에이전트에 원활하게 통합할 수 있도록 해줍니다.
[Model Context Protocol](https://modelcontextprotocol.io/introduction) (MCP)는 AI 에이전트가 MCP 서버로 알려진 외부 서비스와 통신함으로써 LLM에 컨텍스트를 제공할 수 있도록 표준화된 방식을 제공합니다.
`crewai-tools` 라이브러리는 CrewAI의 기능을 확장하여, 이러한 MCP 서버에서 제공하는 툴을 에이전트에 원활하게 통합할 수 있도록 해줍니다.
이를 통해 여러분의 crew는 방대한 기능 에코시스템에 접근할 수 있습니다.

현재 다음과 같은 전송 메커니즘을 지원합니다:
@@ -142,7 +142,7 @@ with MCPServerAdapter(server_params, "tool_name", connect_timeout=60) as mcp_too

## CrewBase와 함께 사용하기

CrewBase 클래스 내에서 MCPServer 도구를 사용하려면 `get_mcp_tools` 메서드를 사용하세요. 서버 구성은 `mcp_server_params` 속성을 통해 제공되어야 합니다. 단일 구성 또는 여러 서버 구성을 리스트 형태로 전달할 수 있습니다.
CrewBase 클래스 내에서 MCPServer 도구를 사용하려면 `mcp_tools` 메서드를 사용하세요. 서버 구성은 `mcp_server_params` 속성을 통해 제공되어야 합니다. 단일 구성 또는 여러 서버 구성을 리스트 형태로 전달할 수 있습니다.

```python
@CrewBase
@@ -175,34 +175,6 @@ class CrewWithMCP:
    # ... 나머지 crew 설정 ...
```

### 연결 타임아웃 구성

`mcp_connect_timeout` 클래스 속성을 설정하여 MCP 서버의 연결 타임아웃을 구성할 수 있습니다. 타임아웃을 지정하지 않으면 기본값으로 30초가 사용됩니다.

```python
@CrewBase
class CrewWithMCP:
    mcp_server_params = [...]
    mcp_connect_timeout = 60  # 모든 MCP 연결에 60초 타임아웃

    @agent
    def your_agent(self):
        return Agent(config=self.agents_config["your_agent"], tools=self.get_mcp_tools())
```

```python
@CrewBase
class CrewWithDefaultTimeout:
    mcp_server_params = [...]
    # mcp_connect_timeout 지정하지 않음 - 기본 30초 사용

    @agent
    def your_agent(self):
        return Agent(config=self.agents_config["your_agent"], tools=self.get_mcp_tools())
```

### 도구 필터링

`get_mcp_tools` 메서드에 도구 이름의 리스트를 전달하여, 에이전트에 제공되는 도구를 필터링할 수 있습니다.

```python
@@ -214,22 +186,6 @@ def another_agent(self):
    )
```

타임아웃 구성은 crew 내의 모든 MCP 도구 호출에 적용됩니다:

```python
@CrewBase
class CrewWithCustomTimeout:
    mcp_server_params = [...]
    mcp_connect_timeout = 90  # 모든 MCP 연결에 90초 타임아웃

    @agent
    def filtered_agent(self):
        return Agent(
            config=self.agents_config["your_agent"],
            tools=self.get_mcp_tools("tool_1", "tool_2")  # 사용자 지정 타임아웃으로 특정 도구
        )
```

## MCP 통합 탐색

<CardGroup cols={2}>
@@ -305,4 +261,4 @@ SSE 전송은 적절하게 보안되지 않은 경우 DNS 리바인딩 공격에

### 제한 사항
* **지원되는 프리미티브**: 현재 `MCPServerAdapter`는 주로 MCP `tools`를 어댑팅하는 기능을 지원합니다. 다른 MCP 프리미티브(예: `prompts` 또는 `resources`)는 현재 이 어댑터를 통해 CrewAI 컴포넌트로 직접 통합되어 있지 않습니다.
* **출력 처리**: 어댑터는 일반적으로 MCP tool의 주요 텍스트 출력(예: `.content[0].text`)을 처리합니다. 복잡하거나 멀티모달 출력의 경우 이 패턴에 맞지 않으면 별도의 커스텀 처리가 필요할 수 있습니다.
* **출력 처리**: 어댑터는 일반적으로 MCP tool의 주요 텍스트 출력(예: `.content[0].text`)을 처리합니다. 복잡하거나 멀티모달 출력의 경우 이 패턴에 맞지 않으면 별도의 커스텀 처리가 필요할 수 있습니다.
@@ -118,7 +118,7 @@ with MCPServerAdapter(server_params, connect_timeout=60) as mcp_tools:

## Usando com CrewBase

Para usar ferramentas de servidores MCP dentro de uma classe CrewBase, utilize o método `get_mcp_tools`. As configurações dos servidores devem ser fornecidas via o atributo `mcp_server_params`. Você pode passar uma configuração única ou uma lista com múltiplas configurações.
Para usar ferramentas de servidores MCP dentro de uma classe CrewBase, utilize o método `mcp_tools`. As configurações dos servidores devem ser fornecidas via o atributo `mcp_server_params`. Você pode passar uma configuração única ou uma lista com múltiplas configurações.

```python
@CrewBase
@@ -146,65 +146,10 @@ class CrewWithMCP:

    @agent
    def your_agent(self):
        return Agent(config=self.agents_config["your_agent"], tools=self.get_mcp_tools())  # obter todas as ferramentas disponíveis
        return Agent(config=self.agents_config["your_agent"], tools=self.get_mcp_tools())  # você também pode filtrar quais ferramentas estarão disponíveis

    # ... restante da configuração do seu crew ...
```

### Configuração de Timeout de Conexão

Você pode configurar o timeout de conexão para servidores MCP definindo o atributo de classe `mcp_connect_timeout`. Se nenhum timeout for especificado, o padrão é 30 segundos.

```python
@CrewBase
class CrewWithMCP:
    mcp_server_params = [...]
    mcp_connect_timeout = 60  # timeout de 60 segundos para todas as conexões MCP

    @agent
    def your_agent(self):
        return Agent(config=self.agents_config["your_agent"], tools=self.get_mcp_tools())
```

```python
@CrewBase
class CrewWithDefaultTimeout:
    mcp_server_params = [...]
    # Nenhum mcp_connect_timeout especificado - usa padrão de 30 segundos

    @agent
    def your_agent(self):
        return Agent(config=self.agents_config["your_agent"], tools=self.get_mcp_tools())
```

### Filtragem de Ferramentas

Você pode filtrar quais ferramentas estão disponíveis para seu agente passando uma lista de nomes de ferramentas para o método `get_mcp_tools`.

```python
@agent
def another_agent(self):
    return Agent(
        config=self.agents_config["your_agent"],
        tools=self.get_mcp_tools("tool_1", "tool_2")  # obter ferramentas específicas
    )
```

A configuração de timeout se aplica a todas as chamadas de ferramentas MCP dentro do crew:

```python
@CrewBase
class CrewWithCustomTimeout:
    mcp_server_params = [...]
    mcp_connect_timeout = 90  # timeout de 90 segundos para todas as conexões MCP

    @agent
    def filtered_agent(self):
        return Agent(
            config=self.agents_config["your_agent"],
            tools=self.get_mcp_tools("tool_1", "tool_2")  # ferramentas específicas com timeout personalizado
        )
```

## Explore Integrações MCP

<CardGroup cols={2}>
@@ -1,187 +0,0 @@
"""
Example demonstrating prompt caching with CrewAI for cost optimization.

This example shows how to use prompt caching with kickoff_for_each() and
kickoff_async() to reduce costs when processing multiple similar inputs.
"""

from crewai import Agent, Crew, Task, LLM
import asyncio


def create_crew_with_caching():
    """Create a crew with prompt caching enabled."""

    llm = LLM(
        model="anthropic/claude-3-5-sonnet-20240620",
        enable_prompt_caching=True,
        temperature=0.1
    )

    analyst = Agent(
        role="Data Analyst",
        goal="Analyze data and provide insights",
        backstory="""You are an experienced data analyst with expertise in
        statistical analysis, data visualization, and business intelligence.
        You have worked with various industries including finance, healthcare,
        and technology. Your approach is methodical and you always provide
        actionable insights based on data patterns.""",
        llm=llm
    )

    analysis_task = Task(
        description="""Analyze the following dataset: {dataset}

        Please provide:
        1. Summary statistics
        2. Key patterns and trends
        3. Actionable recommendations
        4. Potential risks or concerns

        Be thorough in your analysis and provide specific examples.""",
        expected_output="A comprehensive analysis report with statistics, trends, and recommendations",
        agent=analyst
    )

    return Crew(agents=[analyst], tasks=[analysis_task])


def example_kickoff_for_each():
    """Example using kickoff_for_each with prompt caching."""
    print("Running kickoff_for_each example with prompt caching...")

    crew = create_crew_with_caching()

    datasets = [
        {"dataset": "Q1 2024 sales data showing 15% growth in mobile segment"},
        {"dataset": "Q2 2024 customer satisfaction scores with 4.2/5 average rating"},
        {"dataset": "Q3 2024 website traffic data with 25% increase in organic search"},
        {"dataset": "Q4 2024 employee engagement survey with 78% satisfaction rate"}
    ]

    results = crew.kickoff_for_each(datasets)

    for i, result in enumerate(results, 1):
        print(f"\n--- Analysis {i} ---")
        print(result.raw)

    if crew.usage_metrics:
        print(f"\nTotal usage metrics:")
        print(f"Total tokens: {crew.usage_metrics.total_tokens}")
        print(f"Prompt tokens: {crew.usage_metrics.prompt_tokens}")
        print(f"Completion tokens: {crew.usage_metrics.completion_tokens}")


async def example_kickoff_for_each_async():
    """Example using kickoff_for_each_async with prompt caching."""
    print("Running kickoff_for_each_async example with prompt caching...")

    crew = create_crew_with_caching()

    datasets = [
        {"dataset": "Marketing campaign A: 12% CTR, 3.5% conversion rate"},
        {"dataset": "Marketing campaign B: 8% CTR, 4.1% conversion rate"},
        {"dataset": "Marketing campaign C: 15% CTR, 2.8% conversion rate"}
    ]

    results = await crew.kickoff_for_each_async(datasets)

    for i, result in enumerate(results, 1):
        print(f"\n--- Async Analysis {i} ---")
        print(result.raw)

    if crew.usage_metrics:
        print(f"\nTotal async usage metrics:")
        print(f"Total tokens: {crew.usage_metrics.total_tokens}")


def example_bedrock_caching():
    """Example using AWS Bedrock with prompt caching."""
    print("Running Bedrock example with prompt caching...")

    llm = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
        enable_prompt_caching=True
    )

    agent = Agent(
        role="Legal Analyst",
        goal="Review legal documents and identify key clauses",
        backstory="Expert legal analyst with 10+ years experience in contract review",
        llm=llm
    )

    task = Task(
        description="Review this contract section: {contract_section}",
        expected_output="Summary of key legal points and potential issues",
        agent=agent
    )

    crew = Crew(agents=[agent], tasks=[task])

    contract_sections = [
        {"contract_section": "Section 1: Payment terms and conditions"},
        {"contract_section": "Section 2: Intellectual property rights"},
        {"contract_section": "Section 3: Termination clauses"}
    ]

    results = crew.kickoff_for_each(contract_sections)

    for i, result in enumerate(results, 1):
        print(f"\n--- Legal Review {i} ---")
        print(result.raw)


def example_openai_caching():
    """Example using OpenAI with prompt caching."""
    print("Running OpenAI example with prompt caching...")

    llm = LLM(
        model="gpt-4o",
        enable_prompt_caching=True
    )

    agent = Agent(
        role="Content Writer",
        goal="Create engaging content for different audiences",
        backstory="Professional content writer with expertise in various writing styles and formats",
        llm=llm
    )

    task = Task(
        description="Write a {content_type} about: {topic}",
        expected_output="Well-structured and engaging content piece",
        agent=agent
    )

    crew = Crew(agents=[agent], tasks=[task])

    content_requests = [
        {"content_type": "blog post", "topic": "benefits of renewable energy"},
        {"content_type": "social media post", "topic": "importance of cybersecurity"},
        {"content_type": "newsletter", "topic": "latest AI developments"}
    ]

    results = crew.kickoff_for_each(content_requests)

    for i, result in enumerate(results, 1):
        print(f"\n--- Content Piece {i} ---")
        print(result.raw)


if __name__ == "__main__":
    print("=== CrewAI Prompt Caching Examples ===\n")

    example_kickoff_for_each()

    print("\n" + "="*50 + "\n")

    asyncio.run(example_kickoff_for_each_async())

    print("\n" + "="*50 + "\n")

    example_bedrock_caching()

    print("\n" + "="*50 + "\n")

    example_openai_caching()
@@ -48,7 +48,7 @@ Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"

[project.optional-dependencies]
tools = ["crewai-tools~=0.71.0"]
tools = ["crewai-tools~=0.69.0"]
embeddings = [
    "tiktoken~=0.8.0"
]
@@ -131,14 +131,10 @@ select = [
    "I001", # sort imports
    "I002", # remove unused imports
]
ignore = ["E501"] # ignore line too long globally

[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["S101", "RET504"] # Allow assert statements and unnecessary assignments before return in tests
ignore = ["E501"] # ignore line too long

[tool.mypy]
exclude = ["src/crewai/cli/templates", "tests/"]
exclude = ["src/crewai/cli/templates", "tests"]

[tool.bandit]
exclude_dirs = ["src/crewai/cli/templates"]
@@ -1,21 +1,6 @@
import threading
import urllib.request
import warnings
from typing import Any

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.process import Process
from crewai.task import Task
from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.task_output import TaskOutput
from crewai.telemetry.telemetry import Telemetry


def _suppress_pydantic_deprecation_warnings() -> None:
    """Suppress Pydantic deprecation warnings using targeted monkey patch."""
@@ -35,12 +20,27 @@ def _suppress_pydantic_deprecation_warnings() -> None:
            return None
        return original_warn(message, category, stacklevel + 1, source)

    warnings.warn = filtered_warn  # type: ignore[assignment]
    setattr(warnings, "warn", filtered_warn)


_suppress_pydantic_deprecation_warnings()

__version__ = "0.186.1"
import threading
import urllib.request

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.process import Process
from crewai.task import Task
from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.task_output import TaskOutput
from crewai.telemetry.telemetry import Telemetry

_telemetry_submitted = False

@@ -54,12 +54,13 @@ def _track_install() -> None:
    try:
        pixel_url = "https://api.scarf.sh/v2/packages/CrewAI/crewai/docs/00f2dad1-8334-4a39-934e-003b2e1146db"

        req = urllib.request.Request(pixel_url)  # noqa: S310
        req = urllib.request.Request(pixel_url)
        req.add_header("User-Agent", f"CrewAI-Python/{__version__}")

        with urllib.request.urlopen(req, timeout=2):  # noqa: S310
        with urllib.request.urlopen(req, timeout=2):  # nosec B310
            _telemetry_submitted = True
    except Exception:  # noqa: S110

    except Exception:
        pass

@@ -71,17 +72,19 @@ def _track_install_async() -> None:

_track_install_async()

__version__ = "0.177.0"
__all__ = [
    "LLM",
    "Agent",
    "BaseLLM",
    "Crew",
    "CrewOutput",
    "Flow",
    "Knowledge",
    "LLMGuardrail",
    "Process",
    "Task",
    "LLM",
    "BaseLLM",
    "Flow",
    "Knowledge",
    "TaskOutput",
    "LLMGuardrail",
    "__version__",
]
@@ -1,58 +1,29 @@
"""Base converter adapter for structured output conversion."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from crewai.agents.agent_adapters.base_agent_adapter import BaseAgentAdapter
    from crewai.task import Task


class BaseConverterAdapter(ABC):
    """Abstract base class for converter adapters in CrewAI.
    """Base class for all converter adapters in CrewAI.

    Defines the common interface for converting agent outputs to structured formats.
    All converter adapters must implement the methods defined here.
    This abstract class defines the common interface and functionality that all
    converter adapters must implement for converting structured output.
    """

    def __init__(self, agent_adapter: BaseAgentAdapter) -> None:
        """Initialize the converter adapter.

        Args:
            agent_adapter: The agent adapter to configure for structured output.
        """
    def __init__(self, agent_adapter):
        self.agent_adapter = agent_adapter

    @abstractmethod
    def configure_structured_output(self, task: Task) -> None:
    def configure_structured_output(self, task) -> None:
        """Configure agents to return structured output.

        Must support both JSON and Pydantic output formats.

        Args:
            task: The task requiring structured output.
        Must support json and pydantic output.
        """
        pass

    @abstractmethod
    def enhance_system_prompt(self, base_prompt: str) -> str:
        """Enhance the system prompt with structured output instructions.

        Args:
            base_prompt: The original system prompt.

        Returns:
            Enhanced prompt with structured output guidance.
        """
        """Enhance the system prompt with structured output instructions."""
        pass

    @abstractmethod
    def post_process_result(self, result: str) -> str:
        """Post-process the result to ensure proper string format.

        Args:
            result: The raw result from agent execution.

        Returns:
            Processed result as a string.
        """
        """Post-process the result to ensure it matches the expected format: string."""
        pass
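To make the interface above concrete, here is a minimal hypothetical subclass; it is not part of this diff, and the `task.output_json` / `task.output_pydantic` attributes mentioned in the comment are assumptions about what a real adapter would inspect:

```python
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter


class PassthroughConverterAdapter(BaseConverterAdapter):
    """Illustrative adapter that leaves agent output untouched."""

    def configure_structured_output(self, task) -> None:
        # A real adapter would look at the task's structured-output settings here
        # (for example task.output_json or task.output_pydantic) and record a schema.
        self._configured_task = task

    def enhance_system_prompt(self, base_prompt: str) -> str:
        # No structured-output instructions are appended in this sketch.
        return base_prompt

    def post_process_result(self, result: str) -> str:
        # Guarantee the caller always receives a plain string.
        return result if isinstance(result, str) else str(result)
```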
@@ -1,56 +1,47 @@
"""LangGraph agent adapter for CrewAI integration.
from typing import Any, Dict, List, Optional

This module contains the LangGraphAgentAdapter class that integrates LangGraph ReAct agents
with CrewAI's agent system. Provides memory persistence, tool integration, and structured
output functionality.
"""

from collections.abc import Callable
from typing import Any, cast

from pydantic import ConfigDict, Field, PrivateAttr
from pydantic import Field, PrivateAttr

from crewai.agents.agent_adapters.base_agent_adapter import BaseAgentAdapter
from crewai.agents.agent_adapters.langgraph.langgraph_tool_adapter import (
    LangGraphToolAdapter,
)
from crewai.agents.agent_adapters.langgraph.protocols import (
    LangGraphCheckPointMemoryModule,
    LangGraphPrebuiltModule,
)
from crewai.agents.agent_adapters.langgraph.structured_output_converter import (
    LangGraphConverterAdapter,
)
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool
from crewai.utilities import Logger
from crewai.utilities.converter import Converter
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
    AgentExecutionCompletedEvent,
    AgentExecutionErrorEvent,
    AgentExecutionStartedEvent,
)
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool
from crewai.utilities import Logger
from crewai.utilities.converter import Converter
from crewai.utilities.import_utils import require

try:
    from langgraph.checkpoint.memory import MemorySaver
    from langgraph.prebuilt import create_react_agent

    LANGGRAPH_AVAILABLE = True
except ImportError:
    LANGGRAPH_AVAILABLE = False


class LangGraphAgentAdapter(BaseAgentAdapter):
    """Adapter for LangGraph agents to work with CrewAI.
    """Adapter for LangGraph agents to work with CrewAI."""

    This adapter integrates LangGraph's ReAct agents with CrewAI's agent system,
    providing memory persistence, tool integration, and structured output support.
    """
    model_config = {"arbitrary_types_allowed": True}

    model_config = ConfigDict(arbitrary_types_allowed=True)

    _logger: Logger = PrivateAttr(default_factory=Logger)
    _logger: Logger = PrivateAttr(default_factory=lambda: Logger())
    _tool_adapter: LangGraphToolAdapter = PrivateAttr()
    _graph: Any = PrivateAttr(default=None)
    _memory: Any = PrivateAttr(default=None)
    _max_iterations: int = PrivateAttr(default=10)
    function_calling_llm: Any = Field(default=None)
    step_callback: Callable[..., Any] | None = Field(default=None)
    step_callback: Any = Field(default=None)

    model: str = Field(default="gpt-4o")
    verbose: bool = Field(default=False)
@@ -60,24 +51,17 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
        role: str,
        goal: str,
        backstory: str,
        tools: list[BaseTool] | None = None,
        tools: Optional[List[BaseTool]] = None,
        llm: Any = None,
        max_iterations: int = 10,
        agent_config: dict[str, Any] | None = None,
        agent_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """Initialize the LangGraph agent adapter.

        Args:
            role: The role description for the agent.
            goal: The primary goal the agent should achieve.
            backstory: Background information about the agent.
            tools: Optional list of tools available to the agent.
            llm: Language model to use, defaults to gpt-4o.
            max_iterations: Maximum number of iterations for task execution.
            agent_config: Additional configuration for the LangGraph agent.
            **kwargs: Additional arguments passed to the base adapter.
        """
    ):
        """Initialize the LangGraph agent adapter."""
        if not LANGGRAPH_AVAILABLE:
            raise ImportError(
                "LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
            )
        super().__init__(
            role=role,
            goal=goal,
@@ -88,65 +72,46 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
            **kwargs,
        )
        self._tool_adapter = LangGraphToolAdapter(tools=tools)
        self._converter_adapter: LangGraphConverterAdapter = LangGraphConverterAdapter(
            self
        )
        self._converter_adapter = LangGraphConverterAdapter(self)
        self._max_iterations = max_iterations
        self._setup_graph()

    def _setup_graph(self) -> None:
        """Set up the LangGraph workflow graph.
        """Set up the LangGraph workflow graph."""
        try:
            self._memory = MemorySaver()

        Initializes the memory saver and creates a ReAct agent with the configured
        tools, memory checkpointer, and debug settings.
        """
            converted_tools: List[Any] = self._tool_adapter.tools()
            if self._agent_config:
                self._graph = create_react_agent(
                    model=self.llm,
                    tools=converted_tools,
                    checkpointer=self._memory,
                    debug=self.verbose,
                    **self._agent_config,
                )
            else:
                self._graph = create_react_agent(
                    model=self.llm,
                    tools=converted_tools or [],
                    checkpointer=self._memory,
                    debug=self.verbose,
                )

        memory_saver: type[Any] = cast(
            LangGraphCheckPointMemoryModule,
            require(
                "langgraph.checkpoint.memory",
                purpose="LangGraph core functionality",
            ),
        ).MemorySaver
        create_react_agent: Callable[..., Any] = cast(
            LangGraphPrebuiltModule,
            require(
                "langgraph.prebuilt",
                purpose="LangGraph core functionality",
            ),
        ).create_react_agent

        self._memory = memory_saver()

        converted_tools: list[Any] = self._tool_adapter.tools()
        if self._agent_config:
            self._graph = create_react_agent(
                model=self.llm,
                tools=converted_tools,
                checkpointer=self._memory,
                debug=self.verbose,
                **self._agent_config,
            )
        else:
            self._graph = create_react_agent(
                model=self.llm,
                tools=converted_tools or [],
                checkpointer=self._memory,
                debug=self.verbose,
        except ImportError as e:
            self._logger.log(
                "error", f"Failed to import LangGraph dependencies: {str(e)}"
            )
            raise
        except Exception as e:
            self._logger.log("error", f"Error setting up LangGraph agent: {str(e)}")
            raise

    def _build_system_prompt(self) -> str:
        """Build a system prompt for the LangGraph agent.

        Creates a prompt that includes the agent's role, goal, and backstory,
        then enhances it through the converter adapter for structured output.

        Returns:
            The complete system prompt string.
        """
        """Build a system prompt for the LangGraph agent."""
        base_prompt = f"""
        You are {self.role}.

        Your goal is: {self.goal}

        Your backstory: {self.backstory}
@@ -158,25 +123,10 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
    def execute_task(
        self,
        task: Any,
        context: str | None = None,
        tools: list[BaseTool] | None = None,
        context: Optional[str] = None,
        tools: Optional[List[BaseTool]] = None,
    ) -> str:
        """Execute a task using the LangGraph workflow.

        Configures the agent, processes the task through the LangGraph workflow,
        and handles event emission for execution tracking.

        Args:
            task: The task object to execute.
            context: Optional context information for the task.
            tools: Optional additional tools for this specific execution.

        Returns:
            The final answer from the task execution.

        Raises:
            Exception: If task execution fails.
        """
        """Execute a task using the LangGraph workflow."""
        self.create_agent_executor(tools)

        self.configure_structured_output(task)
@@ -201,11 +151,9 @@ class LangGraphAgentAdapter(BaseAgentAdapter):

            session_id = f"task_{id(task)}"

            config: dict[str, dict[str, str]] = {
                "configurable": {"thread_id": session_id}
            }
            config = {"configurable": {"thread_id": session_id}}

            result: dict[str, Any] = self._graph.invoke(
            result = self._graph.invoke(
                {
                    "messages": [
                        ("system", self._build_system_prompt()),
@@ -215,10 +163,10 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
                config,
            )

            messages: list[Any] = result.get("messages", [])
            last_message: Any = messages[-1] if messages else None
            messages = result.get("messages", [])
            last_message = messages[-1] if messages else None

            final_answer: str = ""
            final_answer = ""
            if isinstance(last_message, dict):
                final_answer = last_message.get("content", "")
            elif hasattr(last_message, "content"):
@@ -238,7 +186,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
            return final_answer

        except Exception as e:
            self._logger.log("error", f"Error executing LangGraph task: {e!s}")
            self._logger.log("error", f"Error executing LangGraph task: {str(e)}")
            crewai_event_bus.emit(
                self,
                event=AgentExecutionErrorEvent(
@@ -249,67 +197,29 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
            )
            raise

    def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
        """Configure the LangGraph agent for execution.

        Args:
            tools: Optional tools to configure for the agent.
        """
    def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
        """Configure the LangGraph agent for execution."""
        self.configure_tools(tools)

    def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
        """Configure tools for the LangGraph agent.

        Merges additional tools with existing ones and updates the graph's
        available tools through the tool adapter.

        Args:
            tools: Optional additional tools to configure.
        """
    def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
        """Configure tools for the LangGraph agent."""
        if tools:
            all_tools: list[BaseTool] = list(self.tools or []) + list(tools or [])
            all_tools = list(self.tools or []) + list(tools or [])
            self._tool_adapter.configure_tools(all_tools)
            available_tools: list[Any] = self._tool_adapter.tools()
            available_tools = self._tool_adapter.tools()
            self._graph.tools = available_tools

    def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
        """Implement delegation tools support for LangGraph.

        Creates delegation tools that allow this agent to delegate tasks to other agents.

        Args:
            agents: List of agents available for delegation.

        Returns:
            List of delegation tools.
        """
        agent_tools: AgentTools = AgentTools(agents=agents)
    def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
        """Implement delegation tools support for LangGraph."""
        agent_tools = AgentTools(agents=agents)
        return agent_tools.tools()

    @staticmethod
    def get_output_converter(
        llm: Any, text: str, model: Any, instructions: str
    ) -> Converter:
        """Convert output format if needed.

        Args:
            llm: Language model instance.
            text: Text to convert.
            model: Model configuration.
            instructions: Conversion instructions.

        Returns:
            Converter instance for output transformation.
        """
        self, llm: Any, text: str, model: Any, instructions: str
    ) -> Any:
        """Convert output format if needed."""
        return Converter(llm=llm, text=text, model=model, instructions=instructions)

    def configure_structured_output(self, task: Any) -> None:
        """Configure the structured output for LangGraph.

        Uses the converter adapter to set up structured output formatting
        based on the task requirements.

        Args:
            task: Task object containing output requirements.
        """
    def configure_structured_output(self, task) -> None:
        """Configure the structured output for LangGraph."""
        self._converter_adapter.configure_structured_output(task)
@@ -1,72 +1,38 @@
"""LangGraph tool adapter for CrewAI tool integration.

This module contains the LangGraphToolAdapter class that converts CrewAI tools
to LangGraph-compatible format using langchain_core.tools.
"""

import inspect
from collections.abc import Awaitable
from typing import Any
from typing import Any, List, Optional

from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter
from crewai.tools.base_tool import BaseTool


class LangGraphToolAdapter(BaseToolAdapter):
    """Adapts CrewAI tools to LangGraph agent tool compatible format.
    """Adapts CrewAI tools to LangGraph agent tool compatible format"""

    Converts CrewAI BaseTool instances to langchain_core.tools format
    that can be used by LangGraph agents.
    """
    def __init__(self, tools: Optional[List[BaseTool]] = None):
        self.original_tools = tools or []
        self.converted_tools = []

    def __init__(self, tools: list[BaseTool] | None = None) -> None:
        """Initialize the tool adapter.

        Args:
            tools: Optional list of CrewAI tools to adapt.
    def configure_tools(self, tools: List[BaseTool]) -> None:
        """
        super().__init__()
        self.original_tools: list[BaseTool] = tools or []
        self.converted_tools: list[Any] = []

    def configure_tools(self, tools: list[BaseTool]) -> None:
        """Configure and convert CrewAI tools to LangGraph-compatible format.

        LangGraph expects tools in langchain_core.tools format. This method
        converts CrewAI BaseTool instances to StructuredTool instances.

        Args:
            tools: List of CrewAI tools to convert.
        Configure and convert CrewAI tools to LangGraph-compatible format.
        LangGraph expects tools in langchain_core.tools format.
        """
        from langchain_core.tools import BaseTool as LangChainBaseTool
        from langchain_core.tools import StructuredTool
        from langchain_core.tools import BaseTool, StructuredTool

        converted_tools: list[Any] = []
        converted_tools = []
        if self.original_tools:
            all_tools: list[BaseTool] = tools + self.original_tools
            all_tools = tools + self.original_tools
        else:
            all_tools = tools
        for tool in all_tools:
            if isinstance(tool, LangChainBaseTool):
            if isinstance(tool, BaseTool):
                converted_tools.append(tool)
                continue

            sanitized_name: str = self.sanitize_tool_name(tool.name)
            sanitized_name = self.sanitize_tool_name(tool.name)

            async def tool_wrapper(
                *args: Any, tool: BaseTool = tool, **kwargs: Any
            ) -> Any:
                """Wrapper function to adapt CrewAI tool calls to LangGraph format.

                Args:
                    *args: Positional arguments for the tool.
                    tool: The CrewAI tool to wrap.
                    **kwargs: Keyword arguments for the tool.

                Returns:
                    The result from the tool execution.
                """
                output: Any | Awaitable[Any]
            async def tool_wrapper(*args, tool=tool, **kwargs):
                output = None
                if len(args) > 0 and isinstance(args[0], str):
                    output = tool.run(args[0])
                elif "input" in kwargs:
@@ -75,12 +41,12 @@ class LangGraphToolAdapter(BaseToolAdapter):
                    output = tool.run(**kwargs)

                if inspect.isawaitable(output):
                    result: Any = await output
                    result = await output
                else:
                    result = output
                return result

            converted_tool: StructuredTool = StructuredTool(
            converted_tool = StructuredTool(
                name=sanitized_name,
                description=tool.description,
                func=tool_wrapper,
@@ -91,10 +57,5 @@ class LangGraphToolAdapter(BaseToolAdapter):

        self.converted_tools = converted_tools

    def tools(self) -> list[Any]:
        """Get the list of converted tools.

        Returns:
            List of LangGraph-compatible tools.
        """
    def tools(self) -> List[Any]:
        return self.converted_tools or []
@@ -1,55 +0,0 @@
"""Type protocols for LangGraph modules."""

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class LangGraphMemorySaver(Protocol):
    """Protocol for LangGraph MemorySaver.

    Defines the interface for LangGraph's memory persistence mechanism.
    """

    def __init__(self) -> None:
        """Initialize the memory saver."""
        ...


@runtime_checkable
class LangGraphCheckPointMemoryModule(Protocol):
    """Protocol for LangGraph checkpoint memory module.

    Defines the interface for modules containing memory checkpoint functionality.
    """

    MemorySaver: type[LangGraphMemorySaver]


@runtime_checkable
class LangGraphPrebuiltModule(Protocol):
    """Protocol for LangGraph prebuilt module.

    Defines the interface for modules containing prebuilt agent factories.
    """

    def create_react_agent(
        self,
        model: Any,
        tools: list[Any],
        checkpointer: Any,
        debug: bool = False,
        **kwargs: Any,
    ) -> Any:
        """Create a ReAct agent with the given configuration.

        Args:
            model: The language model to use for the agent.
            tools: List of tools available to the agent.
            checkpointer: Memory checkpointer for state persistence.
            debug: Whether to enable debug mode.
            **kwargs: Additional configuration options.

        Returns:
            The configured ReAct agent instance.
        """
        ...
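Aside (not part of the diff): these protocols are runtime-checkable, so any object with matching members satisfies them structurally, without inheriting from them. A small hypothetical sketch of that idea:

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class PrebuiltModule(Protocol):
    def create_react_agent(self, model: Any, tools: list[Any], **kwargs: Any) -> Any: ...


class FakePrebuilt:
    # Matches the protocol purely by shape; no inheritance needed.
    def create_react_agent(self, model: Any, tools: list[Any], **kwargs: Any) -> Any:
        return {"model": model, "tools": tools}


assert isinstance(FakePrebuilt(), PrebuiltModule)  # structural check at runtime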
@@ -1,45 +1,21 @@
"""LangGraph structured output converter for CrewAI task integration.

This module contains the LangGraphConverterAdapter class that handles structured
output conversion for LangGraph agents, supporting JSON and Pydantic model formats.
"""

import json
import re
from typing import Any, Literal

from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
from crewai.utilities.converter import generate_model_description


class LangGraphConverterAdapter(BaseConverterAdapter):
"""Adapter for handling structured output conversion in LangGraph agents.
"""Adapter for handling structured output conversion in LangGraph agents"""

Converts task output requirements into system prompt modifications and
post-processing logic to ensure agents return properly structured outputs.
"""
def __init__(self, agent_adapter):
"""Initialize the converter adapter with a reference to the agent adapter"""
self.agent_adapter = agent_adapter
self._output_format = None
self._schema = None
self._system_prompt_appendix = None

def __init__(self, agent_adapter: Any) -> None:
"""Initialize the converter adapter with a reference to the agent adapter.

Args:
agent_adapter: The LangGraph agent adapter instance.
"""
super().__init__(agent_adapter=agent_adapter)
self.agent_adapter: Any = agent_adapter
self._output_format: Literal["json", "pydantic"] | None = None
self._schema: str | None = None
self._system_prompt_appendix: str | None = None

def configure_structured_output(self, task: Any) -> None:
"""Configure the structured output for LangGraph.

Analyzes the task's output requirements and sets up the necessary
formatting and validation logic.

Args:
task: The task object containing output format specifications.
"""
def configure_structured_output(self, task) -> None:
"""Configure the structured output for LangGraph."""
if not (task.output_json or task.output_pydantic):
self._output_format = None
self._schema = None
@@ -56,14 +32,7 @@ class LangGraphConverterAdapter(BaseConverterAdapter):
self._system_prompt_appendix = self._generate_system_prompt_appendix()

def _generate_system_prompt_appendix(self) -> str:
"""Generate an appendix for the system prompt to enforce structured output.

Creates instructions that are appended to the system prompt to guide
the agent in producing properly formatted output.

Returns:
System prompt appendix string, or empty string if no structured output.
"""
"""Generate an appendix for the system prompt to enforce structured output"""
if not self._output_format or not self._schema:
return ""

@@ -72,36 +41,19 @@ Important: Your final answer MUST be provided in the following structured format

{self._schema}

DO NOT include any markdown code blocks, backticks, or other formatting around your response.
The output should be raw JSON that exactly matches the specified schema.
"""

def enhance_system_prompt(self, original_prompt: str) -> str:
"""Add structured output instructions to the system prompt if needed.

Args:
original_prompt: The base system prompt.

Returns:
Enhanced system prompt with structured output instructions.
"""
"""Add structured output instructions to the system prompt if needed"""
if not self._system_prompt_appendix:
return original_prompt

return f"{original_prompt}\n{self._system_prompt_appendix}"

def post_process_result(self, result: str) -> str:
"""Post-process the result to ensure it matches the expected format.

Attempts to extract and validate JSON content from agent responses,
handling cases where JSON may be wrapped in markdown or other formatting.

Args:
result: The raw result string from the agent.

Returns:
Processed result string, ideally in valid JSON format.
"""
"""Post-process the result to ensure it matches the expected format"""
if not self._output_format:
return result

@@ -113,16 +65,16 @@ The output should be raw JSON that exactly matches the specified schema.
return result
except json.JSONDecodeError:
# Try to extract JSON from the text
json_match: re.Match[str] | None = re.search(
r"(\{.*})", result, re.DOTALL
)
import re

json_match = re.search(r"(\{.*\})", result, re.DOTALL)
if json_match:
try:
extracted: str = json_match.group(1)
extracted = json_match.group(1)
# Validate it's proper JSON
json.loads(extracted)
return extracted
except json.JSONDecodeError:
except:
pass

return result
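Aside (not part of the diff): the post-processing above falls back to pulling the first {...} span out of the text and keeping it only if it parses as JSON. A self-contained sketch of that fallback:

import json
import re


def extract_json(result: str) -> str:
    # Return the raw result when it already parses as JSON.
    try:
        json.loads(result)
        return result
    except json.JSONDecodeError:
        pass
    # Otherwise pull out the first {...} block and keep it only if it is valid JSON.
    match = re.search(r"(\{.*})", result, re.DOTALL)
    if match:
        candidate = match.group(1)
        try:
            json.loads(candidate)
            return candidate
        except json.JSONDecodeError:
            pass
    return result


print(extract_json('Final answer: {"score": 7}'))  # -> {"score": 7}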
@@ -1,99 +1,78 @@
"""OpenAI agents adapter for CrewAI integration.
from typing import Any, List, Optional

This module contains the OpenAIAgentAdapter class that integrates OpenAI Assistants
with CrewAI's agent system, providing tool integration and structured output support.
"""

from typing import Any, cast

from pydantic import ConfigDict, Field, PrivateAttr
from typing_extensions import Unpack
from pydantic import Field, PrivateAttr

from crewai.agents.agent_adapters.base_agent_adapter import BaseAgentAdapter
from crewai.agents.agent_adapters.openai_agents.openai_agent_tool_adapter import (
OpenAIAgentToolAdapter,
)
from crewai.agents.agent_adapters.openai_agents.protocols import (
AgentKwargs,
OpenAIAgentsModule,
)
from crewai.agents.agent_adapters.openai_agents.protocols import (
OpenAIAgent as OpenAIAgentProtocol,
)
from crewai.agents.agent_adapters.openai_agents.structured_output_converter import (
OpenAIConverterAdapter,
)
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tools import BaseTool
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.utilities import Logger
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.tools import BaseTool
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.utilities import Logger
from crewai.utilities.import_utils import require

openai_agents_module = cast(
OpenAIAgentsModule,
require(
"agents",
purpose="OpenAI agents functionality",
),
)
OpenAIAgent = openai_agents_module.Agent
Runner = openai_agents_module.Runner
enable_verbose_stdout_logging = openai_agents_module.enable_verbose_stdout_logging
try:
from agents import Agent as OpenAIAgent  # type: ignore
from agents import Runner, enable_verbose_stdout_logging  # type: ignore

from .openai_agent_tool_adapter import OpenAIAgentToolAdapter

OPENAI_AVAILABLE = True
except ImportError:
OPENAI_AVAILABLE = False


class OpenAIAgentAdapter(BaseAgentAdapter):
"""Adapter for OpenAI Assistants.
"""Adapter for OpenAI Assistants"""

Integrates OpenAI Assistants API with CrewAI's agent system, providing
tool configuration, structured output handling, and task execution.
"""
model_config = {"arbitrary_types_allowed": True}

model_config = ConfigDict(arbitrary_types_allowed=True)

_openai_agent: OpenAIAgentProtocol = PrivateAttr()
_logger: Logger = PrivateAttr(default_factory=Logger)
_active_thread: str | None = PrivateAttr(default=None)
_openai_agent: "OpenAIAgent" = PrivateAttr()
_logger: Logger = PrivateAttr(default_factory=lambda: Logger())
_active_thread: Optional[str] = PrivateAttr(default=None)
function_calling_llm: Any = Field(default=None)
step_callback: Any = Field(default=None)
_tool_adapter: OpenAIAgentToolAdapter = PrivateAttr()
_tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr()
_converter_adapter: OpenAIConverterAdapter = PrivateAttr()

def __init__(
self,
**kwargs: Unpack[AgentKwargs],
) -> None:
"""Initialize the OpenAI agent adapter.

Args:
**kwargs: All initialization arguments including role, goal, backstory,
model, tools, and agent_config.

Raises:
ImportError: If OpenAI agent dependencies are not installed.
"""
self.llm = kwargs.pop("model", "gpt-4o-mini")
super().__init__(**kwargs)
self._tool_adapter = OpenAIAgentToolAdapter(tools=kwargs.get("tools"))
self._converter_adapter = OpenAIConverterAdapter(agent_adapter=self)
model: str = "gpt-4o-mini",
tools: Optional[List[BaseTool]] = None,
agent_config: Optional[dict] = None,
**kwargs,
):
if not OPENAI_AVAILABLE:
raise ImportError(
"OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
)
else:
role = kwargs.pop("role", None)
goal = kwargs.pop("goal", None)
backstory = kwargs.pop("backstory", None)
super().__init__(
role=role,
goal=goal,
backstory=backstory,
tools=tools,
agent_config=agent_config,
**kwargs,
)
self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
self.llm = model
self._converter_adapter = OpenAIConverterAdapter(self)

def _build_system_prompt(self) -> str:
"""Build a system prompt for the OpenAI agent.

Creates a prompt containing the agent's role, goal, and backstory,
then enhances it with structured output instructions if needed.

Returns:
The complete system prompt string.
"""
"""Build a system prompt for the OpenAI agent."""
base_prompt = f"""
You are {self.role}.

Your goal is: {self.goal}

Your backstory: {self.backstory}
@@ -105,25 +84,10 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
def execute_task(
self,
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
) -> str:
"""Execute a task using the OpenAI Assistant.

Configures the assistant, processes the task, and handles event emission
for execution tracking.

Args:
task: The task object to execute.
context: Optional context information for the task.
tools: Optional additional tools for this execution.

Returns:
The final answer from the task execution.

Raises:
Exception: If task execution fails.
"""
"""Execute a task using the OpenAI Assistant"""
self._converter_adapter.configure_structured_output(task)
self.create_agent_executor(tools)

@@ -131,7 +95,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
enable_verbose_stdout_logging()

try:
task_prompt: str = task.prompt()
task_prompt = task.prompt()
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context
@@ -145,8 +109,8 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
task=task,
),
)
result: Any = self.agent_executor.run_sync(self._openai_agent, task_prompt)
final_answer: str = self.handle_execution_result(result)
result = self.agent_executor.run_sync(self._openai_agent, task_prompt)
final_answer = self.handle_execution_result(result)
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(
@@ -156,7 +120,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
return final_answer

except Exception as e:
self._logger.log("error", f"Error executing OpenAI task: {e!s}")
self._logger.log("error", f"Error executing OpenAI task: {str(e)}")
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -167,22 +131,15 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
)
raise

def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
"""Configure the OpenAI agent for execution.

While OpenAI handles execution differently through Runner,
this method sets up tools and agent configuration.

Args:
tools: Optional tools to configure for the agent.

Notes:
TODO: Properly type agent_executor in BaseAgent to avoid type issues
when assigning Runner class to this attribute.
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
"""
all_tools: list[BaseTool] = list(self.tools or []) + list(tools or [])
Configure the OpenAI agent for execution.
While OpenAI handles execution differently through Runner,
we can use this method to set up tools and configurations.
"""
all_tools = list(self.tools or []) + list(tools or [])

instructions: str = self._build_system_prompt()
instructions = self._build_system_prompt()
self._openai_agent = OpenAIAgent(
name=self.role,
instructions=instructions,
@@ -195,48 +152,27 @@ class OpenAIAgentAdapter(BaseAgentAdapter):

self.agent_executor = Runner

def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
"""Configure tools for the OpenAI Assistant.

Args:
tools: Optional tools to configure for the assistant.
"""
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
"""Configure tools for the OpenAI Assistant"""
if tools:
self._tool_adapter.configure_tools(tools)
if self._tool_adapter.converted_tools:
self._openai_agent.tools = self._tool_adapter.converted_tools

def handle_execution_result(self, result: Any) -> str:
"""Process OpenAI Assistant execution result.

Converts any structured output to a string through the converter adapter.

Args:
result: The execution result from the OpenAI assistant.

Returns:
Processed result as a string.
"""
"""Process OpenAI Assistant execution result converting any structured output to a string"""
return self._converter_adapter.post_process_result(result.final_output)

def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
"""Implement delegation tools support.
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
"""Implement delegation tools support"""
agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools()
return tools

Creates delegation tools that allow this agent to delegate tasks to other agents.

Args:
agents: List of agents available for delegation.

Returns:
List of delegation tools.
"""
agent_tools: AgentTools = AgentTools(agents=agents)
return agent_tools.tools()

def configure_structured_output(self, task: Any) -> None:
def configure_structured_output(self, task) -> None:
"""Configure the structured output for the specific agent implementation.

Args:
task: The task object containing output format specifications.
structured_output: The structured output to be configured
"""
self._converter_adapter.configure_structured_output(task)
@@ -1,125 +1,57 @@
"""OpenAI agent tool adapter for CrewAI tool integration.

This module contains the OpenAIAgentToolAdapter class that converts CrewAI tools
to OpenAI Assistant-compatible format using the agents library.
"""

import inspect
import json
import re
from collections.abc import Awaitable
from typing import Any, cast
from typing import Any, List, Optional

from agents import FunctionTool, Tool

from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter
from crewai.agents.agent_adapters.openai_agents.protocols import (
OpenAIFunctionTool,
OpenAITool,
)
from crewai.tools import BaseTool
from crewai.utilities.import_utils import require

agents_module = cast(
Any,
require(
"agents",
purpose="OpenAI agents functionality",
),
)
FunctionTool = agents_module.FunctionTool
Tool = agents_module.Tool


class OpenAIAgentToolAdapter(BaseToolAdapter):
"""Adapter for OpenAI Assistant tools.
"""Adapter for OpenAI Assistant tools"""

Converts CrewAI BaseTool instances to OpenAI Assistant FunctionTool format
that can be used by OpenAI agents.
"""
def __init__(self, tools: Optional[List[BaseTool]] = None):
self.original_tools = tools or []

def __init__(self, tools: list[BaseTool] | None = None) -> None:
"""Initialize the tool adapter.

Args:
tools: Optional list of CrewAI tools to adapt.
"""
super().__init__()
self.original_tools: list[BaseTool] = tools or []
self.converted_tools: list[OpenAITool] = []

def configure_tools(self, tools: list[BaseTool]) -> None:
"""Configure tools for the OpenAI Assistant.

Merges provided tools with original tools and converts them to
OpenAI Assistant format.

Args:
tools: List of CrewAI tools to configure.
"""
def configure_tools(self, tools: List[BaseTool]) -> None:
"""Configure tools for the OpenAI Assistant"""
if self.original_tools:
all_tools: list[BaseTool] = tools + self.original_tools
all_tools = tools + self.original_tools
else:
all_tools = tools
if all_tools:
self.converted_tools = self._convert_tools_to_openai_format(all_tools)

@staticmethod
def _convert_tools_to_openai_format(
tools: list[BaseTool] | None,
) -> list[OpenAITool]:
"""Convert CrewAI tools to OpenAI Assistant tool format.

Args:
tools: List of CrewAI tools to convert.

Returns:
List of OpenAI Assistant FunctionTool instances.
"""
self, tools: Optional[List[BaseTool]]
) -> List[Tool]:
"""Convert CrewAI tools to OpenAI Assistant tool format"""
if not tools:
return []

def sanitize_tool_name(name: str) -> str:
"""Convert tool name to match OpenAI's required pattern.
"""Convert tool name to match OpenAI's required pattern"""
import re

Args:
name: Original tool name.
sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
return sanitized

Returns:
Sanitized tool name matching OpenAI requirements.
"""

return re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()

def create_tool_wrapper(tool: BaseTool) -> Any:
"""Create a wrapper function that handles the OpenAI function tool interface.

Args:
tool: The CrewAI tool to wrap.

Returns:
Async wrapper function for OpenAI agent integration.
"""
def create_tool_wrapper(tool: BaseTool):
"""Create a wrapper function that handles the OpenAI function tool interface"""

async def wrapper(context_wrapper: Any, arguments: Any) -> Any:
"""Wrapper function to adapt CrewAI tool calls to OpenAI format.

Args:
context_wrapper: OpenAI context wrapper.
arguments: Tool arguments from OpenAI.

Returns:
Tool execution result.
"""
# Get the parameter name from the schema
param_name: str = next(
iter(tool.args_schema.model_json_schema()["properties"].keys())
)
param_name = list(
tool.args_schema.model_json_schema()["properties"].keys()
)[0]

# Handle different argument types
args_dict: dict[str, Any]
if isinstance(arguments, dict):
args_dict = arguments
elif isinstance(arguments, str):
try:
import json

args_dict = json.loads(arguments)
except json.JSONDecodeError:
args_dict = {param_name: arguments}
@@ -127,11 +59,11 @@ class OpenAIAgentToolAdapter(BaseToolAdapter):
args_dict = {param_name: str(arguments)}

# Run the tool with the processed arguments
output: Any | Awaitable[Any] = tool._run(**args_dict)
output = tool._run(**args_dict)

# Await if the tool returned a coroutine
if inspect.isawaitable(output):
result: Any = await output
result = await output
else:
result = output

@@ -142,20 +74,17 @@ class OpenAIAgentToolAdapter(BaseToolAdapter):

return wrapper

openai_tools: list[OpenAITool] = []
openai_tools = []
for tool in tools:
schema: dict[str, Any] = tool.args_schema.model_json_schema()
schema = tool.args_schema.model_json_schema()

schema.update({"additionalProperties": False, "type": "object"})

openai_tool: OpenAIFunctionTool = cast(
OpenAIFunctionTool,
FunctionTool(
name=sanitize_tool_name(tool.name),
description=tool.description,
params_json_schema=schema,
on_invoke_tool=create_tool_wrapper(tool),
),
openai_tool = FunctionTool(
name=sanitize_tool_name(tool.name),
description=tool.description,
params_json_schema=schema,
on_invoke_tool=create_tool_wrapper(tool),
)
openai_tools.append(openai_tool)
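Aside (not part of the diff): the wrapper above has to accept either a dict or a JSON string from the agents runtime. A standalone sketch of that argument normalization, with a hypothetical parameter name:

import json
from typing import Any


def normalize_arguments(arguments: Any, param_name: str) -> dict[str, Any]:
    # Dicts pass through; JSON strings are decoded; anything else is wrapped
    # under the tool's first declared parameter name.
    if isinstance(arguments, dict):
        return arguments
    if isinstance(arguments, str):
        try:
            return json.loads(arguments)
        except json.JSONDecodeError:
            return {param_name: arguments}
    return {param_name: str(arguments)}


print(normalize_arguments('{"query": "crewai"}', "query"))  # {'query': 'crewai'}
print(normalize_arguments("plain text", "query"))           # {'query': 'plain text'}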
@@ -1,74 +0,0 @@
"""Type protocols for OpenAI agents modules."""

from collections.abc import Callable
from typing import Any, Protocol, TypedDict, runtime_checkable

from crewai.tools.base_tool import BaseTool


class AgentKwargs(TypedDict, total=False):
    """Typed dict for agent initialization kwargs."""

    role: str
    goal: str
    backstory: str
    model: str
    tools: list[BaseTool] | None
    agent_config: dict[str, Any] | None


@runtime_checkable
class OpenAIAgent(Protocol):
    """Protocol for OpenAI Agent."""

    def __init__(
        self,
        name: str,
        instructions: str,
        model: str,
        **kwargs: Any,
    ) -> None:
        """Initialize the OpenAI agent."""
        ...

    tools: list[Any]
    output_type: Any


@runtime_checkable
class OpenAIRunner(Protocol):
    """Protocol for OpenAI Runner."""

    @classmethod
    def run_sync(cls, agent: OpenAIAgent, message: str) -> Any:
        """Run agent synchronously with a message."""
        ...


@runtime_checkable
class OpenAIAgentsModule(Protocol):
    """Protocol for OpenAI agents module."""

    Agent: type[OpenAIAgent]
    Runner: type[OpenAIRunner]
    enable_verbose_stdout_logging: Callable[[], None]


@runtime_checkable
class OpenAITool(Protocol):
    """Protocol for OpenAI Tool."""


@runtime_checkable
class OpenAIFunctionTool(Protocol):
    """Protocol for OpenAI FunctionTool."""

    def __init__(
        self,
        name: str,
        description: str,
        params_json_schema: dict[str, Any],
        on_invoke_tool: Any,
    ) -> None:
        """Initialize the function tool."""
        ...
@@ -1,12 +1,5 @@
"""OpenAI structured output converter for CrewAI task integration.

This module contains the OpenAIConverterAdapter class that handles structured
output conversion for OpenAI agents, supporting JSON and Pydantic model formats.
"""

import json
import re
from typing import Any, Literal

from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
from crewai.utilities.converter import generate_model_description
@@ -14,7 +7,8 @@ from crewai.utilities.i18n import I18N


class OpenAIConverterAdapter(BaseConverterAdapter):
"""Adapter for handling structured output conversion in OpenAI agents.
"""
Adapter for handling structured output conversion in OpenAI agents.

This adapter enhances the OpenAI agent to handle structured output formats
and post-processes the results when needed.
@@ -25,23 +19,19 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
_output_model: The Pydantic model for the output
"""

def __init__(self, agent_adapter: Any) -> None:
"""Initialize the converter adapter with a reference to the agent adapter.
def __init__(self, agent_adapter):
"""Initialize the converter adapter with a reference to the agent adapter"""
self.agent_adapter = agent_adapter
self._output_format = None
self._schema = None
self._output_model = None

Args:
agent_adapter: The OpenAI agent adapter instance.
def configure_structured_output(self, task) -> None:
"""
super().__init__(agent_adapter=agent_adapter)
self.agent_adapter: Any = agent_adapter
self._output_format: Literal["json", "pydantic"] | None = None
self._schema: str | None = None
self._output_model: Any = None

def configure_structured_output(self, task: Any) -> None:
"""Configure the structured output for OpenAI agent based on task requirements.
Configure the structured output for OpenAI agent based on task requirements.

Args:
task: The task containing output format requirements.
task: The task containing output format requirements
"""
# Reset configuration
self._output_format = None
@@ -65,18 +55,19 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
self._output_model = task.output_pydantic

def enhance_system_prompt(self, base_prompt: str) -> str:
"""Enhance the base system prompt with structured output requirements if needed.
"""
Enhance the base system prompt with structured output requirements if needed.

Args:
base_prompt: The original system prompt.
base_prompt: The original system prompt

Returns:
Enhanced system prompt with output format instructions if needed.
Enhanced system prompt with output format instructions if needed
"""
if not self._output_format:
return base_prompt

output_schema: str = (
output_schema = (
I18N()
.slice("formatted_task_instructions")
.format(output_format=self._schema)
@@ -85,15 +76,16 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
return f"{base_prompt}\n\n{output_schema}"

def post_process_result(self, result: str) -> str:
"""Post-process the result to ensure it matches the expected format.
"""
Post-process the result to ensure it matches the expected format.

This method attempts to extract valid JSON from the result if necessary.

Args:
result: The raw result from the agent.
result: The raw result from the agent

Returns:
Processed result conforming to the expected output format.
Processed result conforming to the expected output format
"""
if not self._output_format:
return result
@@ -105,30 +97,26 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
return result
except json.JSONDecodeError:
# Try to extract JSON from markdown code blocks
code_block_pattern: str = r"```(?:json)?\s*([\s\S]*?)```"
code_blocks: list[str] = re.findall(code_block_pattern, result)
code_block_pattern = r"```(?:json)?\s*([\s\S]*?)```"
code_blocks = re.findall(code_block_pattern, result)

for block in code_blocks:
stripped_block = block.strip()
try:
json.loads(stripped_block)
return stripped_block
json.loads(block.strip())
return block.strip()
except json.JSONDecodeError:
pass
continue

# Try to extract any JSON-like structure
json_pattern: str = r"(\{[\s\S]*\})"
json_matches: list[str] = re.findall(json_pattern, result, re.DOTALL)
json_pattern = r"(\{[\s\S]*\})"
json_matches = re.findall(json_pattern, result, re.DOTALL)

for match in json_matches:
is_valid = True
try:
json.loads(match)
except json.JSONDecodeError:
is_valid = False

if is_valid:
return match
except json.JSONDecodeError:
continue

# If all extraction attempts fail, return the original
return str(result)
@@ -1,32 +1,29 @@
"""Base output converter for transforming text into structured formats."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Optional

from pydantic import BaseModel, Field


class OutputConverter(BaseModel, ABC):
"""Abstract base class for converting text to structured formats.
"""
Abstract base class for converting task results into structured formats.

Uses language models to transform unstructured text into either Pydantic models
or JSON objects based on provided instructions and target schemas.
This class provides a framework for converting unstructured text into
either Pydantic models or JSON, tailored for specific agent requirements.
It uses a language model to interpret and structure the input text based
on given instructions.

Attributes:
text: The input text to be converted.
llm: The language model used for conversion.
model: The target Pydantic model class for structuring output.
instructions: Specific instructions for the conversion process.
max_attempts: Maximum number of conversion attempts (default: 3).
text (str): The input text to be converted.
llm (Any): The language model used for conversion.
model (Any): The target model for structuring the output.
instructions (str): Specific instructions for the conversion process.
max_attempts (int): Maximum number of conversion attempts (default: 3).
"""

text: str = Field(description="Text to be converted.")
llm: Any = Field(description="The language model to be used to convert the text.")
model: type[BaseModel] = Field(
description="The model to be used to convert the text."
)
model: Any = Field(description="The model to be used to convert the text.")
instructions: str = Field(description="Conversion instructions to the LLM.")
max_attempts: int = Field(
description="Max number of attempts to try to get the output formatted.",
@@ -34,23 +31,11 @@ class OutputConverter(BaseModel, ABC):
)

@abstractmethod
def to_pydantic(self, current_attempt: int = 1) -> BaseModel:
"""Convert text to a Pydantic model instance.

Args:
current_attempt: Current attempt number for retry logic.

Returns:
Pydantic model instance with structured data.
"""
def to_pydantic(self, current_attempt=1) -> BaseModel:
"""Convert text to pydantic."""
pass

@abstractmethod
def to_json(self, current_attempt: int = 1) -> dict[str, Any]:
"""Convert text to a JSON dictionary.

Args:
current_attempt: Current attempt number for retry logic.

Returns:
Dictionary containing structured JSON data.
"""
def to_json(self, current_attempt=1) -> dict:
"""Convert text to json."""
pass
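Aside (not part of the diff): a hypothetical, stand-alone stand-in that mirrors the to_pydantic/to_json contract without subclassing the real pydantic-based class (which also requires text, llm, model, and instructions fields):

from typing import Any

from pydantic import BaseModel


class Score(BaseModel):
    value: int


class EchoConverter:
    # Hypothetical converter: parses the text directly instead of calling an LLM.
    def __init__(self, text: str) -> None:
        self.text = text

    def to_pydantic(self, current_attempt: int = 1) -> BaseModel:
        return Score(value=int(self.text))

    def to_json(self, current_attempt: int = 1) -> dict[str, Any]:
        return self.to_pydantic(current_attempt).model_dump()


print(EchoConverter("7").to_json())  # {'value': 7}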
@@ -1,25 +1,8 @@
"""Token usage tracking utilities.

This module provides utilities for tracking token consumption and request
metrics during agent execution.
"""

from crewai.types.usage_metrics import UsageMetrics


class TokenProcess:
"""Track token usage during agent processing.

Attributes:
total_tokens: Total number of tokens used.
prompt_tokens: Number of tokens used in prompts.
cached_prompt_tokens: Number of cached prompt tokens used.
completion_tokens: Number of tokens used in completions.
successful_requests: Number of successful requests made.
"""

def __init__(self) -> None:
"""Initialize token tracking with zero values."""
self.total_tokens: int = 0
self.prompt_tokens: int = 0
self.cached_prompt_tokens: int = 0
@@ -27,45 +10,20 @@ class TokenProcess:
self.successful_requests: int = 0

def sum_prompt_tokens(self, tokens: int) -> None:
"""Add prompt tokens to the running totals.

Args:
tokens: Number of prompt tokens to add.
"""
self.prompt_tokens += tokens
self.total_tokens += tokens

def sum_completion_tokens(self, tokens: int) -> None:
"""Add completion tokens to the running totals.

Args:
tokens: Number of completion tokens to add.
"""
self.completion_tokens += tokens
self.total_tokens += tokens

def sum_cached_prompt_tokens(self, tokens: int) -> None:
"""Add cached prompt tokens to the running total.

Args:
tokens: Number of cached prompt tokens to add.
"""
self.cached_prompt_tokens += tokens

def sum_successful_requests(self, requests: int) -> None:
"""Add successful requests to the running total.

Args:
requests: Number of successful requests to add.
"""
self.successful_requests += requests

def get_summary(self) -> UsageMetrics:
"""Get a summary of all tracked metrics.

Returns:
UsageMetrics object with current totals.
"""
return UsageMetrics(
total_tokens=self.total_tokens,
prompt_tokens=self.prompt_tokens,
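Aside (not part of the diff): a stand-alone mirror of the token accounting above, showing how prompt and completion tokens roll up into the total (the numbers are hypothetical):

class MiniTokenProcess:
    # Self-contained mirror of the TokenProcess accounting shown above.
    def __init__(self) -> None:
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.cached_prompt_tokens = 0
        self.successful_requests = 0

    def sum_prompt_tokens(self, tokens: int) -> None:
        self.prompt_tokens += tokens
        self.total_tokens += tokens

    def sum_completion_tokens(self, tokens: int) -> None:
        self.completion_tokens += tokens
        self.total_tokens += tokens


tracker = MiniTokenProcess()
tracker.sum_prompt_tokens(120)
tracker.sum_completion_tokens(48)
print(tracker.total_tokens)  # 168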
40 src/crewai/agents/cache/cache_handler.py vendored
@@ -1,45 +1,15 @@
"""Cache handler for tool usage results."""

from typing import Any
from typing import Any, Dict, Optional

from pydantic import BaseModel, PrivateAttr


class CacheHandler(BaseModel):
"""Handles caching of tool execution results.
"""Callback handler for tool usage."""

Provides in-memory caching for tool outputs based on tool name and input.
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict)

Notes:
- TODO: Make thread-safe.
"""

_cache: dict[str, Any] = PrivateAttr(default_factory=dict)

def add(self, tool: str, input: str, output: Any) -> None:
"""Add a tool result to the cache.

Args:
tool: Name of the tool.
input: Input string used for the tool.
output: Output result from tool execution.

Notes:
- TODO: Rename 'input' parameter to avoid shadowing builtin.
"""
def add(self, tool, input, output):
self._cache[f"{tool}-{input}"] = output

def read(self, tool: str, input: str) -> Any | None:
"""Retrieve a cached tool result.

Args:
tool: Name of the tool.
input: Input string used for the tool.

Returns:
Cached result if found, None otherwise.

Notes:
- TODO: Rename 'input' parameter to avoid shadowing builtin.
"""
def read(self, tool, input) -> Optional[str]:
return self._cache.get(f"{tool}-{input}")
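Aside (not part of the diff): typical use of the handler above, assuming crewai is installed. Cache keys are the "<tool>-<input>" pair, so only an exact repeat of the same tool call is a hit:

from crewai.agents.cache import CacheHandler

cache = CacheHandler()
cache.add(tool="search", input="crewai docs", output=["https://docs.crewai.com"])

print(cache.read(tool="search", input="crewai docs"))  # ['https://docs.crewai.com']
print(cache.read(tool="search", input="other query"))  # None (cache miss)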
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
    "crewai[tools]>=0.186.1,<1.0.0"
    "crewai[tools]>=0.177.0,<1.0.0"
]

[project.scripts]

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
    "crewai[tools]>=0.186.1,<1.0.0",
    "crewai[tools]>=0.177.0,<1.0.0",
]

[project.scripts]

@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
    "crewai[tools]>=0.186.1"
    "crewai[tools]>=0.177.0"
]

[tool.crewai]
@@ -1,25 +0,0 @@
import os
import contextvars
from typing import Optional
from contextlib import contextmanager

_platform_integration_token: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "platform_integration_token", default=None
)


def set_platform_integration_token(integration_token: str) -> None:
    _platform_integration_token.set(integration_token)


def get_platform_integration_token() -> Optional[str]:
    token = _platform_integration_token.get()
    if token is None:
        token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
    return token


@contextmanager
def platform_context(integration_token: str):
    token = _platform_integration_token.set(integration_token)
    try:
        yield
    finally:
        _platform_integration_token.reset(token)
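Aside (not part of the diff): the removed helper above is a standard contextvars pattern; setting the token inside the context manager and resetting it on exit keeps the value scoped to the with-block. A self-contained sketch of the same idea:

import contextvars
from contextlib import contextmanager
from typing import Iterator, Optional

_token_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "platform_integration_token", default=None
)


@contextmanager
def platform_context(integration_token: str) -> Iterator[None]:
    # Set the token for the duration of the block, then restore the previous value.
    token = _token_var.set(integration_token)
    try:
        yield
    finally:
        _token_var.reset(token)


with platform_context("abc123"):
    print(_token_var.get())  # abc123
print(_token_var.get())      # None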
@@ -3,17 +3,26 @@ import json
|
||||
import re
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from opentelemetry import baggage
|
||||
from opentelemetry.context import attach, detach
|
||||
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
@@ -30,15 +39,26 @@ from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
should_auto_collect_first_time_traces,
|
||||
)
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.process import Process
|
||||
from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
@@ -50,28 +70,16 @@ from crewai.events.types.crew_events import (
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.process import Process
|
||||
from crewai.rag.types import SearchResult
|
||||
from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
)
|
||||
from crewai.utilities.formatter import (
|
||||
aggregate_raw_outputs_from_task_outputs,
|
||||
aggregate_raw_outputs_from_tasks,
|
||||
@@ -86,40 +94,28 @@ warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
|
||||
|
||||
class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
Represents a group of agents, defining how they should collaborate and the
|
||||
tasks they should perform.
|
||||
Represents a group of agents, defining how they should collaborate and the tasks they should perform.
|
||||
|
||||
Attributes:
|
||||
tasks: list of tasks assigned to the crew.
|
||||
agents: list of agents part of this crew.
|
||||
tasks: List of tasks assigned to the crew.
|
||||
agents: List of agents part of this crew.
|
||||
manager_llm: The language model that will run manager agent.
|
||||
manager_agent: Custom agent that will be used as manager.
|
||||
memory: Whether the crew should use memory to store memories of it's
|
||||
execution.
|
||||
cache: Whether the crew should use a cache to store the results of the
|
||||
tools execution.
|
||||
function_calling_llm: The language model that will run the tool calling
|
||||
for all the agents.
|
||||
process: The process flow that the crew will follow (e.g., sequential,
|
||||
hierarchical).
|
||||
memory: Whether the crew should use memory to store memories of it's execution.
|
||||
cache: Whether the crew should use a cache to store the results of the tools execution.
|
||||
function_calling_llm: The language model that will run the tool calling for all the agents.
|
||||
process: The process flow that the crew will follow (e.g., sequential, hierarchical).
|
||||
verbose: Indicates the verbosity level for logging during execution.
|
||||
config: Configuration settings for the crew.
|
||||
max_rpm: Maximum number of requests per minute for the crew execution to
|
||||
be respected.
|
||||
max_rpm: Maximum number of requests per minute for the crew execution to be respected.
|
||||
prompt_file: Path to the prompt json file to be used for the crew.
|
||||
id: A unique identifier for the crew instance.
|
||||
task_callback: Callback to be executed after each task for every agents
|
||||
execution.
|
||||
step_callback: Callback to be executed after each step for every agents
|
||||
execution.
|
||||
share_crew: Whether you want to share the complete crew information and
|
||||
execution with crewAI to make the library better, and allow us to
|
||||
train models.
|
||||
task_callback: Callback to be executed after each task for every agents execution.
|
||||
step_callback: Callback to be executed after each step for every agents execution.
|
||||
share_crew: Whether you want to share the complete crew information and execution with crewAI to make the library better, and allow us to train models.
|
||||
planning: Plan the crew execution and add the plan to the crew.
|
||||
chat_llm: The language model used for orchestrating chat interactions
|
||||
with the crew.
|
||||
security_config: Security configuration for the crew, including
|
||||
fingerprinting.
|
||||
chat_llm: The language model used for orchestrating chat interactions with the crew.
|
||||
security_config: Security configuration for the crew, including fingerprinting.
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__ # type: ignore
|
||||
@@ -128,13 +124,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
_logger: Logger = PrivateAttr()
|
||||
_file_handler: FileHandler = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
|
||||
_short_term_memory: InstanceOf[ShortTermMemory] | None = PrivateAttr()
|
||||
_long_term_memory: InstanceOf[LongTermMemory] | None = PrivateAttr()
|
||||
_entity_memory: InstanceOf[EntityMemory] | None = PrivateAttr()
|
||||
_external_memory: InstanceOf[ExternalMemory] | None = PrivateAttr()
|
||||
_train: bool | None = PrivateAttr(default=False)
|
||||
_train_iteration: int | None = PrivateAttr()
|
||||
_inputs: dict[str, Any] | None = PrivateAttr(default=None)
|
||||
_short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
|
||||
_long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
|
||||
_entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr()
|
||||
_external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr()
|
||||
_train: Optional[bool] = PrivateAttr(default=False)
|
||||
_train_iteration: Optional[int] = PrivateAttr()
|
||||
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
|
||||
_logging_color: str = PrivateAttr(
|
||||
default="bold_purple",
|
||||
)
|
||||
@@ -142,121 +138,107 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default_factory=TaskOutputStorageHandler
|
||||
)
|
||||
|
||||
name: str | None = Field(default="crew")
|
||||
name: Optional[str] = Field(default="crew")
|
||||
cache: bool = Field(default=True)
|
||||
tasks: list[Task] = Field(default_factory=list)
|
||||
agents: list[BaseAgent] = Field(default_factory=list)
|
||||
tasks: List[Task] = Field(default_factory=list)
|
||||
agents: List[BaseAgent] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool = Field(
|
||||
default=False,
|
||||
description="If crew should use memory to store memories of it's execution",
|
||||
description="Whether the crew should use memory to store memories of it's execution",
|
||||
)
|
||||
short_term_memory: InstanceOf[ShortTermMemory] | None = Field(
|
||||
short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field(
|
||||
default=None,
|
||||
description="An Instance of the ShortTermMemory to be used by the Crew",
|
||||
)
|
||||
long_term_memory: InstanceOf[LongTermMemory] | None = Field(
|
||||
long_term_memory: Optional[InstanceOf[LongTermMemory]] = Field(
|
||||
default=None,
|
||||
description="An Instance of the LongTermMemory to be used by the Crew",
|
||||
)
|
||||
entity_memory: InstanceOf[EntityMemory] | None = Field(
|
||||
entity_memory: Optional[InstanceOf[EntityMemory]] = Field(
|
||||
default=None,
|
||||
description="An Instance of the EntityMemory to be used by the Crew",
|
||||
)
|
||||
external_memory: InstanceOf[ExternalMemory] | None = Field(
|
||||
external_memory: Optional[InstanceOf[ExternalMemory]] = Field(
|
||||
default=None,
|
||||
description="An Instance of the ExternalMemory to be used by the Crew",
|
||||
)
|
||||
embedder: dict | None = Field(
|
||||
embedder: Optional[dict] = Field(
|
||||
default=None,
|
||||
description="Configuration for the embedder to be used for the crew.",
|
||||
)
|
||||
usage_metrics: UsageMetrics | None = Field(
|
||||
usage_metrics: Optional[UsageMetrics] = Field(
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: BaseAgent | None = Field(
|
||||
manager_agent: Optional[BaseAgent] = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Json | dict[str, Any] | None = Field(default=None)
|
||||
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
share_crew: bool | None = Field(default=False)
|
||||
step_callback: Any | None = Field(
|
||||
share_crew: Optional[bool] = Field(default=False)
|
||||
step_callback: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each step for all agents execution.",
|
||||
)
|
||||
task_callback: Any | None = Field(
|
||||
task_callback: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each task for all agents execution.",
|
||||
)
|
||||
before_kickoff_callbacks: list[
|
||||
Callable[[dict[str, Any] | None], dict[str, Any] | None]
|
||||
before_kickoff_callbacks: List[
|
||||
Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]]
|
||||
] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"List of callbacks to be executed before crew kickoff. "
|
||||
"It may be used to adjust inputs before the crew is executed."
|
||||
),
|
||||
description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.",
|
||||
)
|
||||
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
|
||||
after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"List of callbacks to be executed after crew kickoff. "
|
||||
"It may be used to adjust the output of the crew."
|
||||
),
|
||||
description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.",
|
||||
)
|
||||
max_rpm: int | None = Field(
|
||||
max_rpm: Optional[int] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Maximum number of requests per minute for the crew execution "
|
||||
"to be respected."
|
||||
),
|
||||
description="Maximum number of requests per minute for the crew execution to be respected.",
|
||||
)
|
||||
prompt_file: str | None = Field(
|
||||
prompt_file: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to the prompt json file to be used for the crew.",
|
||||
)
|
||||
output_log_file: bool | str | None = Field(
|
||||
output_log_file: Optional[Union[bool, str]] = Field(
|
||||
default=None,
|
||||
description="Path to the log file to be saved",
|
||||
)
|
||||
planning: bool | None = Field(
|
||||
planning: Optional[bool] = Field(
|
||||
default=False,
|
||||
description="Plan the crew execution and add the plan to the crew.",
|
||||
)
|
||||
planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Language model that will run the AgentPlanner if planning is True."
|
||||
),
|
||||
description="Language model that will run the AgentPlanner if planning is True.",
|
||||
)
|
||||
task_execution_output_json_files: list[str] | None = Field(
|
||||
task_execution_output_json_files: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="list of file paths for task execution JSON files.",
|
||||
description="List of file paths for task execution JSON files.",
|
||||
)
|
||||
execution_logs: list[dict[str, Any]] = Field(
|
||||
execution_logs: List[Dict[str, Any]] = Field(
|
||||
default=[],
|
||||
description="list of execution logs for tasks",
|
||||
description="List of execution logs for tasks",
|
||||
)
|
||||
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
|
||||
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Knowledge sources for the crew. Add knowledge sources to the "
|
||||
"knowledge object."
|
||||
),
|
||||
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
|
||||
)
|
||||
chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
default=None,
|
||||
description="LLM used to handle chatting with the crew.",
|
||||
)
|
||||
knowledge: Knowledge | None = Field(
|
||||
knowledge: Optional[Knowledge] = Field(
|
||||
default=None,
|
||||
description="Knowledge for the crew.",
|
||||
)
|
||||
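The paired field declarations above differ only in annotation style: X | None (PEP 604, Python 3.10+) and typing.Optional[X] describe the same type. A minimal sketch of the equivalence, using hypothetical names rather than CrewAI code:

from typing import Optional

new_style: str | None = None       # PEP 604 union syntax
old_style: Optional[str] = None    # typing-module spelling


def load_prompts(prompt_file: str | None = None) -> dict:
    # hypothetical helper: both annotations accept a string path or None
    return {} if prompt_file is None else {"path": prompt_file}


print(load_prompts(), load_prompts("prompts.json"))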
@@ -264,18 +246,18 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default_factory=SecurityConfig,
|
||||
description="Security configuration for the crew, including fingerprinting.",
|
||||
)
|
||||
token_usage: UsageMetrics | None = Field(
|
||||
token_usage: Optional[UsageMetrics] = Field(
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
tracing: bool | None = Field(
|
||||
tracing: Optional[bool] = Field(
|
||||
default=False,
|
||||
description="Whether to enable tracing for the crew.",
|
||||
)
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
|
||||
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
|
||||
"""Prevent manual setting of the 'id' field by users."""
|
||||
if v:
|
||||
raise PydanticCustomError(
|
||||
@@ -284,7 +266,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
|
||||
def check_config_type(
|
||||
cls, v: Union[Json, Dict[str, Any]]
|
||||
) -> Union[Json, Dict[str, Any]]:
|
||||
"""Validates that the config is a valid type.
|
||||
Args:
|
||||
v: The config to be validated.
|
||||
@@ -297,16 +281,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_private_attrs(self) -> "Crew":
|
||||
"""set private attributes."""
|
||||
"""Set private attributes."""
|
||||
|
||||
self._cache_handler = CacheHandler()
|
||||
event_listener = EventListener()
|
||||
|
||||
if (
|
||||
is_tracing_enabled()
|
||||
or self.tracing
|
||||
or should_auto_collect_first_time_traces()
|
||||
):
|
||||
if is_tracing_enabled() or self.tracing:
|
||||
trace_listener = TraceCollectionListener()
|
||||
trace_listener.setup_listeners(crewai_event_bus)
|
||||
event_listener.verbose = self.verbose
|
||||
@@ -334,8 +314,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def create_crew_memory(self) -> "Crew":
|
||||
"""Initialize private memory attributes."""
|
||||
self._external_memory = (
|
||||
# External memory does not support a default value since it was
|
||||
# designed to be managed entirely externally
|
||||
# External memory doesn’t support a default value since it was designed to be managed entirely externally
|
||||
self.external_memory.set_crew(self) if self.external_memory else None
|
||||
)
|
||||
|
||||
@@ -376,10 +355,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if not self.manager_llm and not self.manager_agent:
|
||||
raise PydanticCustomError(
|
||||
"missing_manager_llm_or_manager_agent",
|
||||
(
|
||||
"Attribute `manager_llm` or `manager_agent` is required "
|
||||
"when using hierarchical process."
|
||||
),
|
||||
"Attribute `manager_llm` or `manager_agent` is required when using hierarchical process.",
|
||||
{},
|
||||
)
|
||||
|
||||
@@ -422,10 +398,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if task.agent is None:
|
||||
raise PydanticCustomError(
|
||||
"missing_agent_in_task",
|
||||
(
|
||||
f"Sequential process error: Agent is missing in the task "
|
||||
f"with the following description: {task.description}"
|
||||
), # type: ignore # Dynamic string in error message
|
||||
f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
|
||||
{},
|
||||
)
|
||||
|
||||
@@ -486,10 +459,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if task.async_execution and isinstance(task, ConditionalTask):
|
||||
raise PydanticCustomError(
|
||||
"invalid_async_conditional_task",
|
||||
(
|
||||
f"Conditional Task: {task.description}, "
|
||||
f"cannot be executed asynchronously."
|
||||
),
|
||||
f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
|
||||
{},
|
||||
)
|
||||
return self
|
||||
@@ -508,9 +478,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
for j in range(i - 1, -1, -1):
|
||||
if self.tasks[j] == context_task:
|
||||
raise ValueError(
|
||||
f"Task '{task.description}' is asynchronous and "
|
||||
f"cannot include other sequential asynchronous "
|
||||
f"tasks in its context."
|
||||
f"Task '{task.description}' is asynchronous and cannot include other sequential asynchronous tasks in its context."
|
||||
)
|
||||
if not self.tasks[j].async_execution:
|
||||
break
|
||||
@@ -528,15 +496,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
continue # Skip context tasks not in the main tasks list
|
||||
if task_indices[id(context_task)] > task_indices[id(task)]:
|
||||
raise ValueError(
|
||||
f"Task '{task.description}' has a context dependency "
|
||||
f"on a future task '{context_task.description}', "
|
||||
f"which is not allowed."
|
||||
f"Task '{task.description}' has a context dependency on a future task '{context_task.description}', which is not allowed."
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
source: list[str] = [agent.key for agent in self.agents] + [
|
||||
source: List[str] = [agent.key for agent in self.agents] + [
|
||||
task.key for task in self.tasks
|
||||
]
|
||||
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
|
||||
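The key property above derives a stable fingerprint from the agent and task keys. A small standalone sketch of that derivation, with made-up keys:

from hashlib import md5

agent_keys = ["researcher", "writer"]
task_keys = ["outline", "draft"]
source = agent_keys + task_keys
print(md5("|".join(source).encode(), usedforsecurity=False).hexdigest())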
@@ -552,9 +518,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self.security_config.fingerprint
|
||||
|
||||
def _setup_from_config(self):
|
||||
assert self.config is not None, "Config should not be None."
|
||||
|
||||
"""Initializes agents and tasks from the provided config."""
|
||||
if self.config is None:
|
||||
raise ValueError("Config should not be None.")
|
||||
if not self.config.get("agents") or not self.config.get("tasks"):
|
||||
raise PydanticCustomError(
|
||||
"missing_keys_in_config", "Config should have 'agents' and 'tasks'.", {}
|
||||
@@ -564,7 +530,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self.agents = [Agent(**agent) for agent in self.config["agents"]]
|
||||
self.tasks = [self._create_task(task) for task in self.config["tasks"]]
|
||||
|
||||
def _create_task(self, task_config: dict[str, Any]) -> Task:
|
||||
def _create_task(self, task_config: Dict[str, Any]) -> Task:
|
||||
"""Creates a task instance from its configuration.
|
||||
|
||||
Args:
|
||||
@@ -593,7 +559,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
CrewTrainingHandler(filename).initialize_file()
|
||||
|
||||
def train(
|
||||
self, n_iterations: int, filename: str, inputs: dict[str, Any] | None = None
|
||||
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Trains the crew for a given number of iterations."""
|
||||
inputs = inputs or {}
|
||||
@@ -645,7 +611,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> CrewOutput:
|
||||
ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
@@ -716,9 +682,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
|
||||
"""Executes the Crew's workflow for each input and aggregates results."""
|
||||
results: list[CrewOutput] = []
|
||||
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
|
||||
"""Executes the Crew's workflow for each input in the list and aggregates results."""
|
||||
results: List[CrewOutput] = []
|
||||
|
||||
# Initialize the parent crew's usage metrics
|
||||
total_usage_metrics = UsageMetrics()
|
||||
@@ -737,12 +703,14 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._task_output_handler.reset()
|
||||
return results
|
||||
|
||||
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> CrewOutput:
|
||||
async def kickoff_async(
|
||||
self, inputs: Optional[Dict[str, Any]] = None
|
||||
) -> CrewOutput:
|
||||
"""Asynchronous kickoff method to start the crew execution."""
|
||||
inputs = inputs or {}
|
||||
return await asyncio.to_thread(self.kickoff, inputs)
|
||||
|
||||
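kickoff_async above simply runs the synchronous kickoff in a worker thread via asyncio.to_thread. A self-contained sketch of the same pattern, where do_work stands in for Crew.kickoff:

import asyncio


def do_work(inputs: dict) -> str:
    # blocking, synchronous work
    return f"processed {sorted(inputs)}"


async def do_work_async(inputs: dict | None = None) -> str:
    inputs = inputs or {}
    # run the blocking call in a thread so the event loop stays responsive
    return await asyncio.to_thread(do_work, inputs)


print(asyncio.run(do_work_async({"topic": "tracing"})))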
async def kickoff_for_each_async(self, inputs: list[dict]) -> list[CrewOutput]:
|
||||
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]:
|
||||
crew_copies = [self.copy() for _ in inputs]
|
||||
|
||||
async def run_crew(crew, input_data):
|
||||
@@ -771,9 +739,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tasks=self.tasks, planning_agent_llm=self.planning_llm
|
||||
)._handle_crew_planning()
|
||||
|
||||
for task, step_plan in zip(
|
||||
self.tasks, result.list_of_plans_per_task, strict=False
|
||||
):
|
||||
for task, step_plan in zip(self.tasks, result.list_of_plans_per_task):
|
||||
task.description += step_plan.plan
|
||||
|
||||
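The two zip calls above differ only in the strict flag (Python 3.10+). An illustrative snippet of the difference, with made-up data:

plans = ["plan a", "plan b"]
tasks = ["task 1", "task 2", "task 3"]

# strict=False (also the default) silently truncates to the shorter iterable
print(list(zip(tasks, plans, strict=False)))  # [('task 1', 'plan a'), ('task 2', 'plan b')]

# zip(tasks, plans, strict=True) would instead raise ValueError on the length mismatch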
def _store_execution_log(
|
||||
@@ -810,7 +776,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self._execute_tasks(self.tasks)
|
||||
|
||||
def _run_hierarchical_process(self) -> CrewOutput:
|
||||
"""Creates and assigns a manager agent to complete the tasks."""
|
||||
"""Creates and assigns a manager agent to make sure the crew completes the tasks."""
|
||||
self._create_manager_agent()
|
||||
return self._execute_tasks(self.tasks)
|
||||
|
||||
@@ -841,24 +807,23 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _execute_tasks(
|
||||
self,
|
||||
tasks: list[Task],
|
||||
start_index: int | None = 0,
|
||||
tasks: List[Task],
|
||||
start_index: Optional[int] = 0,
|
||||
was_replayed: bool = False,
|
||||
) -> CrewOutput:
|
||||
"""Executes tasks sequentially and returns the final output.
|
||||
|
||||
Args:
|
||||
tasks (List[Task]): List of tasks to execute
|
||||
manager (Optional[BaseAgent], optional): Manager agent to use for
|
||||
delegation. Defaults to None.
|
||||
manager (Optional[BaseAgent], optional): Manager agent to use for delegation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
CrewOutput: Final output of the crew
|
||||
"""
|
||||
|
||||
task_outputs: list[TaskOutput] = []
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]] = []
|
||||
last_sync_output: TaskOutput | None = None
|
||||
task_outputs: List[TaskOutput] = []
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
|
||||
last_sync_output: Optional[TaskOutput] = None
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
if start_index is not None and task_index < start_index:
|
||||
@@ -873,9 +838,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
agent_to_use = self._get_agent_to_use(task)
|
||||
if agent_to_use is None:
|
||||
raise ValueError(
|
||||
f"No agent available for task: {task.description}. "
|
||||
f"Ensure that either the task has an assigned agent "
|
||||
f"or a manager agent is provided."
|
||||
f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided."
|
||||
)
|
||||
|
||||
# Determine which tools to use - task tools take precedence over agent tools
|
||||
@@ -884,7 +847,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools_for_task = self._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
cast(list[Tool] | list[BaseTool], tools_for_task),
|
||||
cast(Union[List[Tool], List[BaseTool]], tools_for_task),
|
||||
)
|
||||
|
||||
self._log_task_start(task, agent_to_use.role)
|
||||
@@ -904,7 +867,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
future = task.execute_async(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=cast(list[BaseTool], tools_for_task),
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
)
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
@@ -916,7 +879,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task_output = task.execute_sync(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=cast(list[BaseTool], tools_for_task),
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
@@ -930,11 +893,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def _handle_conditional_task(
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
task_outputs: list[TaskOutput],
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]],
|
||||
task_outputs: List[TaskOutput],
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> TaskOutput | None:
|
||||
) -> Optional[TaskOutput]:
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
futures.clear()
|
||||
@@ -954,8 +917,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return None
|
||||
|
||||
def _prepare_tools(
|
||||
self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
# Add delegation tools if agent allows delegation
|
||||
if hasattr(agent, "allow_delegation") and getattr(
|
||||
agent, "allow_delegation", False
|
||||
@@ -984,22 +947,22 @@ class Crew(FlowTrackable, BaseModel):
|
||||
):
|
||||
tools = self._add_multimodal_tools(agent, tools)
|
||||
|
||||
# Return a List[BaseTool] compatible with Task.execute_sync and execute_async
|
||||
return cast(list[BaseTool], tools)
|
||||
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _get_agent_to_use(self, task: Task) -> BaseAgent | None:
|
||||
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
|
||||
if self.process == Process.hierarchical:
|
||||
return self.manager_agent
|
||||
return task.agent
|
||||
|
||||
def _merge_tools(
|
||||
self,
|
||||
existing_tools: list[Tool] | list[BaseTool],
|
||||
new_tools: list[Tool] | list[BaseTool],
|
||||
) -> list[BaseTool]:
|
||||
"""Merge new tools into existing tools list, avoiding duplicates."""
|
||||
existing_tools: Union[List[Tool], List[BaseTool]],
|
||||
new_tools: Union[List[Tool], List[BaseTool]],
|
||||
) -> List[BaseTool]:
|
||||
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
|
||||
if not new_tools:
|
||||
return cast(list[BaseTool], existing_tools)
|
||||
return cast(List[BaseTool], existing_tools)
|
||||
|
||||
# Create mapping of tool names to new tools
|
||||
new_tool_map = {tool.name: tool for tool in new_tools}
|
||||
@@ -1010,41 +973,41 @@ class Crew(FlowTrackable, BaseModel):
|
||||
# Add all new tools
|
||||
tools.extend(new_tools)
|
||||
|
||||
return cast(list[BaseTool], tools)
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
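_merge_tools above deduplicates by tool name: existing tools shadowed by a new tool of the same name are dropped, then the new tools are appended. A rough sketch under that assumption; Tool here is a stand-in dataclass, not the CrewAI type:

from dataclasses import dataclass


@dataclass
class Tool:
    name: str


def merge_tools(existing: list[Tool], new: list[Tool]) -> list[Tool]:
    new_names = {tool.name for tool in new}
    kept = [tool for tool in existing if tool.name not in new_names]  # drop shadowed tools
    return kept + new


print([t.name for t in merge_tools([Tool("search"), Tool("code")], [Tool("search")])])
# ['code', 'search']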
def _inject_delegation_tools(
|
||||
self,
|
||||
tools: list[Tool] | list[BaseTool],
|
||||
tools: Union[List[Tool], List[BaseTool]],
|
||||
task_agent: BaseAgent,
|
||||
agents: list[BaseAgent],
|
||||
) -> list[BaseTool]:
|
||||
agents: List[BaseAgent],
|
||||
) -> List[BaseTool]:
|
||||
if hasattr(task_agent, "get_delegation_tools"):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
# Cast delegation_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(list[BaseTool], delegation_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_multimodal_tools(
|
||||
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
if hasattr(agent, "get_multimodal_tools"):
|
||||
multimodal_tools = agent.get_multimodal_tools()
|
||||
# Cast multimodal_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_code_execution_tools(
|
||||
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
if hasattr(agent, "get_code_execution_tools"):
|
||||
code_tools = agent.get_code_execution_tools()
|
||||
# Cast code_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_delegation_tools(
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
|
||||
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
|
||||
if not tools:
|
||||
@@ -1052,7 +1015,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, task.agent, agents_for_delegation
|
||||
)
|
||||
return cast(list[BaseTool], tools)
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _log_task_start(self, task: Task, role: str = "None"):
|
||||
if self.output_log_file:
|
||||
@@ -1061,8 +1024,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
|
||||
def _update_manager_tools(
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
if self.manager_agent:
|
||||
if task.agent:
|
||||
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
|
||||
@@ -1070,17 +1033,18 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, self.manager_agent, self.agents
|
||||
)
|
||||
return cast(list[BaseTool], tools)
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
|
||||
def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
|
||||
if not task.context:
|
||||
return ""
|
||||
|
||||
return (
|
||||
context = (
|
||||
aggregate_raw_outputs_from_task_outputs(task_outputs)
|
||||
if task.context is NOT_SPECIFIED
|
||||
else aggregate_raw_outputs_from_tasks(task.context)
|
||||
)
|
||||
return context
|
||||
|
||||
def _process_task_result(self, task: Task, output: TaskOutput) -> None:
|
||||
role = task.agent.role if task.agent is not None else "None"
|
||||
@@ -1093,7 +1057,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
output=output.raw,
|
||||
)
|
||||
|
||||
def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput:
|
||||
def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput:
|
||||
if not task_outputs:
|
||||
raise ValueError("No task outputs available to create crew output.")
|
||||
|
||||
@@ -1124,10 +1088,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _process_async_tasks(
|
||||
self,
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]],
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]],
|
||||
was_replayed: bool = False,
|
||||
) -> list[TaskOutput]:
|
||||
task_outputs: list[TaskOutput] = []
|
||||
) -> List[TaskOutput]:
|
||||
task_outputs: List[TaskOutput] = []
|
||||
for future_task, future, task_index in futures:
|
||||
task_output = future.result()
|
||||
task_outputs.append(task_output)
|
||||
@@ -1137,7 +1101,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
return task_outputs
|
||||
|
||||
def _find_task_index(self, task_id: str, stored_outputs: list[Any]) -> int | None:
|
||||
def _find_task_index(
|
||||
self, task_id: str, stored_outputs: List[Any]
|
||||
) -> Optional[int]:
|
||||
return next(
|
||||
(
|
||||
index
|
||||
@@ -1147,8 +1113,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
None,
|
||||
)
|
||||
|
||||
def replay(self, task_id: str, inputs: dict[str, Any] | None = None) -> CrewOutput:
|
||||
"""Replay the crew execution from a specific task."""
|
||||
def replay(
|
||||
self, task_id: str, inputs: Optional[Dict[str, Any]] = None
|
||||
) -> CrewOutput:
|
||||
stored_outputs = self._task_output_handler.load()
|
||||
if not stored_outputs:
|
||||
raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")
|
||||
@@ -1184,19 +1151,19 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self.tasks[i].output = task_output
|
||||
|
||||
self._logging_color = "bold_blue"
|
||||
return self._execute_tasks(self.tasks, start_index, True)
|
||||
result = self._execute_tasks(self.tasks, start_index, True)
|
||||
return result
|
||||
|
||||
def query_knowledge(
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[SearchResult] | None:
|
||||
"""Query the crew's knowledge base for relevant information."""
|
||||
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> Union[List[Dict[str, Any]], None]:
|
||||
if self.knowledge:
|
||||
return self.knowledge.query(
|
||||
query, results_limit=results_limit, score_threshold=score_threshold
|
||||
)
|
||||
return None
|
||||
|
||||
def fetch_inputs(self) -> set[str]:
|
||||
def fetch_inputs(self) -> Set[str]:
|
||||
"""
|
||||
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
|
||||
Scans each task's 'description' + 'expected_output', and each agent's
|
||||
@@ -1205,11 +1172,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
Returns a set of all discovered placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||
required_inputs: set[str] = set()
|
||||
required_inputs: Set[str] = set()
|
||||
|
||||
# Scan tasks for inputs
|
||||
for task in self.tasks:
|
||||
# description and expected_output might contain e.g. {topic}, {user_name}
|
||||
# description and expected_output might contain e.g. {topic}, {user_name}, etc.
|
||||
text = f"{task.description or ''} {task.expected_output or ''}"
|
||||
required_inputs.update(placeholder_pattern.findall(text))
|
||||
|
||||
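fetch_inputs above collects {placeholder} names with a non-greedy regex over task and agent text. A standalone illustration with sample strings:

import re

placeholder_pattern = re.compile(r"\{(.+?)\}")
description = "Research {topic} for {user_name}"
expected_output = "A short summary of {topic}"

required_inputs: set[str] = set()
required_inputs.update(placeholder_pattern.findall(f"{description} {expected_output}"))
print(sorted(required_inputs))  # ['topic', 'user_name']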
@@ -1263,7 +1230,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
cloned_tasks.append(cloned_task)
|
||||
task_mapping[task.key] = cloned_task
|
||||
|
||||
for cloned_task, original_task in zip(cloned_tasks, self.tasks, strict=False):
|
||||
for cloned_task, original_task in zip(cloned_tasks, self.tasks):
|
||||
if isinstance(original_task.context, list):
|
||||
cloned_context = [
|
||||
task_mapping[context_task.key]
|
||||
@@ -1289,7 +1256,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
copied_data.pop("agents", None)
|
||||
copied_data.pop("tasks", None)
|
||||
|
||||
return Crew(
|
||||
copied_crew = Crew(
|
||||
**copied_data,
|
||||
agents=cloned_agents,
|
||||
tasks=cloned_tasks,
|
||||
@@ -1299,13 +1266,15 @@ class Crew(FlowTrackable, BaseModel):
|
||||
manager_llm=manager_llm,
|
||||
)
|
||||
|
||||
return copied_crew
|
||||
|
||||
def _set_tasks_callbacks(self) -> None:
|
||||
"""Sets callback for every task suing task_callback"""
|
||||
for task in self.tasks:
|
||||
if not task.callback:
|
||||
task.callback = self.task_callback
|
||||
|
||||
def _interpolate_inputs(self, inputs: dict[str, Any]) -> None:
|
||||
def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Interpolates the inputs in the tasks and agents."""
|
||||
[
|
||||
task.interpolate_inputs_and_add_conversation_history(
|
||||
@@ -1338,13 +1307,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def test(
|
||||
self,
|
||||
n_iterations: int,
|
||||
eval_llm: str | InstanceOf[BaseLLM],
|
||||
inputs: dict[str, Any] | None = None,
|
||||
eval_llm: Union[str, InstanceOf[BaseLLM]],
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations.
|
||||
|
||||
Uses concurrent.futures for concurrent execution.
|
||||
"""
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
||||
try:
|
||||
# Create LLM instance and ensure it's of type LLM for CrewEvaluator
|
||||
llm_instance = create_llm(eval_llm)
|
||||
@@ -1384,11 +1350,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
raise
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"Crew(id={self.id}, process={self.process}, "
|
||||
f"number_of_agents={len(self.agents)}, "
|
||||
f"number_of_tasks={len(self.tasks)})"
|
||||
)
|
||||
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
|
||||
|
||||
def reset_memories(self, command_type: str) -> None:
|
||||
"""Reset specific or all memories for the crew.
|
||||
@@ -1402,7 +1364,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
ValueError: If an invalid command type is provided.
|
||||
RuntimeError: If memory reset operation fails.
|
||||
"""
|
||||
valid_types = frozenset(
|
||||
VALID_TYPES = frozenset(
|
||||
[
|
||||
"long",
|
||||
"short",
|
||||
@@ -1415,10 +1377,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
]
|
||||
)
|
||||
|
||||
if command_type not in valid_types:
|
||||
if command_type not in VALID_TYPES:
|
||||
raise ValueError(
|
||||
f"Invalid command type. Must be one of: "
|
||||
f"{', '.join(sorted(valid_types))}"
|
||||
f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -1428,7 +1389,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._reset_specific_memory(command_type)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to reset {command_type} memory: {e!s}"
|
||||
error_msg = f"Failed to reset {command_type} memory: {str(e)}"
|
||||
self._logger.log("error", error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
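reset_memories validates the command against a frozenset before dispatching, as shown above. A compact sketch of that guard; the exact set of valid types is partly elided in this diff, so the list below is an assumption:

VALID_TYPES = frozenset(["long", "short", "entity", "external", "all"])


def check(command_type: str) -> None:
    if command_type not in VALID_TYPES:
        raise ValueError(
            f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}"
        )


check("short")        # passes silently
# check("episodic")   # would raise ValueError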
@@ -1436,7 +1397,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""Reset all available memory systems."""
|
||||
memory_systems = self._get_memory_systems()
|
||||
|
||||
for config in memory_systems.values():
|
||||
for memory_type, config in memory_systems.items():
|
||||
if (system := config.get("system")) is not None:
|
||||
name = config.get("name")
|
||||
try:
|
||||
@@ -1444,13 +1405,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"{name} memory has been reset",
|
||||
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _reset_specific_memory(self, memory_type: str) -> None:
|
||||
@@ -1475,21 +1434,18 @@ class Crew(FlowTrackable, BaseModel):
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"{name} memory has been reset",
|
||||
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _get_memory_systems(self):
|
||||
"""Get all available memory systems with their configuration.
|
||||
|
||||
Returns:
|
||||
Dict containing all memory systems with their reset functions and
|
||||
display names.
|
||||
Dict containing all memory systems with their reset functions and display names.
|
||||
"""
|
||||
|
||||
def default_reset(memory):
|
||||
@@ -1550,7 +1506,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
},
|
||||
}
|
||||
|
||||
def reset_knowledge(self, knowledges: list[Knowledge]) -> None:
|
||||
def reset_knowledge(self, knowledges: List[Knowledge]) -> None:
|
||||
"""Reset crew and agent knowledge storage."""
|
||||
for ks in knowledges:
|
||||
ks.reset()
|
||||
|
||||
@@ -9,158 +9,48 @@ This module provides the event infrastructure that allows users to:
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentEvaluationCompletedEvent,
|
||||
AgentEvaluationFailedEvent,
|
||||
AgentEvaluationStartedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestResultEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowEvent,
|
||||
FlowFinishedEvent,
|
||||
FlowPlotEvent,
|
||||
FlowStartedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.logging_events import (
|
||||
AgentLogsExecutionEvent,
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
|
||||
from crewai.events.types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
ReasoningEvent,
)
from crewai.events.types.task_events import (
TaskCompletedEvent,
TaskEvaluationEvent,
TaskFailedEvent,
TaskStartedEvent,
)
from crewai.events.types.tool_usage_events import (
ToolExecutionErrorEvent,
ToolSelectionErrorEvent,
ToolUsageErrorEvent,
ToolUsageEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
ToolValidateInputErrorEvent,
)
|
||||
|
||||
__all__ = [
|
||||
"AgentEvaluationCompletedEvent",
|
||||
"AgentEvaluationFailedEvent",
|
||||
"AgentEvaluationStartedEvent",
|
||||
"AgentExecutionCompletedEvent",
|
||||
"AgentExecutionErrorEvent",
|
||||
"AgentExecutionStartedEvent",
|
||||
"AgentLogsExecutionEvent",
|
||||
"AgentLogsStartedEvent",
|
||||
"AgentReasoningCompletedEvent",
|
||||
"AgentReasoningFailedEvent",
|
||||
"AgentReasoningStartedEvent",
|
||||
"BaseEventListener",
|
||||
"CrewKickoffCompletedEvent",
|
||||
"CrewKickoffFailedEvent",
|
||||
"CrewKickoffStartedEvent",
|
||||
"CrewTestCompletedEvent",
|
||||
"CrewTestFailedEvent",
|
||||
"CrewTestResultEvent",
|
||||
"CrewTestStartedEvent",
|
||||
"CrewTrainCompletedEvent",
|
||||
"CrewTrainFailedEvent",
|
||||
"CrewTrainStartedEvent",
|
||||
"FlowCreatedEvent",
|
||||
"FlowEvent",
|
||||
"FlowFinishedEvent",
|
||||
"FlowPlotEvent",
|
||||
"FlowStartedEvent",
|
||||
"KnowledgeQueryCompletedEvent",
|
||||
"KnowledgeQueryFailedEvent",
|
||||
"KnowledgeQueryStartedEvent",
|
||||
"KnowledgeRetrievalCompletedEvent",
|
||||
"KnowledgeRetrievalStartedEvent",
|
||||
"KnowledgeSearchQueryFailedEvent",
|
||||
"LLMCallCompletedEvent",
|
||||
"LLMCallFailedEvent",
|
||||
"LLMCallStartedEvent",
|
||||
"LLMGuardrailCompletedEvent",
|
||||
"LLMGuardrailStartedEvent",
|
||||
"LLMStreamChunkEvent",
|
||||
"LiteAgentExecutionCompletedEvent",
|
||||
"LiteAgentExecutionErrorEvent",
|
||||
"LiteAgentExecutionStartedEvent",
|
||||
"crewai_event_bus",
|
||||
"MemoryQueryCompletedEvent",
|
||||
"MemoryQueryFailedEvent",
|
||||
"MemorySaveCompletedEvent",
|
||||
"MemorySaveStartedEvent",
|
||||
"MemoryQueryStartedEvent",
|
||||
"MemoryRetrievalCompletedEvent",
|
||||
"MemoryRetrievalStartedEvent",
|
||||
"MemorySaveCompletedEvent",
|
||||
"MemorySaveFailedEvent",
|
||||
"MemorySaveStartedEvent",
|
||||
"MethodExecutionFailedEvent",
|
||||
"MethodExecutionFinishedEvent",
|
||||
"MethodExecutionStartedEvent",
|
||||
"ReasoningEvent",
|
||||
"TaskCompletedEvent",
|
||||
"TaskEvaluationEvent",
|
||||
"TaskFailedEvent",
|
||||
"TaskStartedEvent",
|
||||
"ToolExecutionErrorEvent",
|
||||
"ToolSelectionErrorEvent",
|
||||
"ToolUsageErrorEvent",
|
||||
"ToolUsageEvent",
|
||||
"ToolUsageFinishedEvent",
|
||||
"ToolUsageStartedEvent",
|
||||
"ToolValidateInputErrorEvent",
|
||||
"crewai_event_bus",
|
||||
]
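The names in __all__ above are what downstream code imports from crewai.events. A hedged usage sketch, assuming a BaseEventListener subclass wires its handlers in setup_listeners (as the trace listener later in this diff does; the base class may also perform this wiring on init):

from crewai.events import BaseEventListener, CrewKickoffStartedEvent, crewai_event_bus


class MyListener(BaseEventListener):
    def setup_listeners(self, event_bus):
        @event_bus.on(CrewKickoffStartedEvent)
        def on_kickoff_started(source, event):
            print("crew kickoff started")


listener = MyListener()
listener.setup_listeners(crewai_event_bus)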
@@ -1,174 +0,0 @@
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
mark_first_execution_completed,
|
||||
prompt_user_for_trace_viewing,
|
||||
should_auto_collect_first_time_traces,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FirstTimeTraceHandler:
|
||||
"""Handles the first-time user trace collection and display flow."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_first_time: bool = False
|
||||
self.collected_events: bool = False
|
||||
self.trace_batch_id: str | None = None
|
||||
self.ephemeral_url: str | None = None
|
||||
self.batch_manager: TraceBatchManager | None = None
|
||||
|
||||
def initialize_for_first_time_user(self) -> bool:
|
||||
"""Check if this is first time and initialize collection."""
|
||||
self.is_first_time = should_auto_collect_first_time_traces()
|
||||
return self.is_first_time
|
||||
|
||||
def set_batch_manager(self, batch_manager: TraceBatchManager):
|
||||
"""Set reference to batch manager for sending events."""
|
||||
self.batch_manager = batch_manager
|
||||
|
||||
def mark_events_collected(self):
|
||||
"""Mark that events have been collected during execution."""
|
||||
self.collected_events = True
|
||||
|
||||
def handle_execution_completion(self):
|
||||
"""Handle the completion flow as shown in your diagram."""
|
||||
if not self.is_first_time or not self.collected_events:
|
||||
return
|
||||
|
||||
try:
|
||||
user_wants_traces = prompt_user_for_trace_viewing(timeout_seconds=20)
|
||||
|
||||
if user_wants_traces:
|
||||
self._initialize_backend_and_send_events()
|
||||
|
||||
if self.ephemeral_url:
|
||||
self._display_ephemeral_trace_link()
|
||||
|
||||
mark_first_execution_completed()
|
||||
|
||||
except Exception as e:
|
||||
self._gracefully_fail(f"Error in trace handling: {e}")
|
||||
mark_first_execution_completed()
|
||||
|
||||
def _initialize_backend_and_send_events(self):
|
||||
"""Initialize backend batch and send collected events."""
|
||||
if not self.batch_manager:
|
||||
return
|
||||
|
||||
try:
|
||||
if not self.batch_manager.backend_initialized:
|
||||
original_metadata = (
|
||||
self.batch_manager.current_batch.execution_metadata
|
||||
if self.batch_manager.current_batch
|
||||
else {}
|
||||
)
|
||||
|
||||
user_context = {
|
||||
"privacy_level": "standard",
|
||||
"user_id": "first_time_user",
|
||||
"session_id": str(uuid.uuid4()),
|
||||
"trace_id": self.batch_manager.trace_batch_id,
|
||||
}
|
||||
|
||||
execution_metadata = {
|
||||
"execution_type": original_metadata.get("execution_type", "crew"),
|
||||
"crew_name": original_metadata.get(
|
||||
"crew_name", "First Time Execution"
|
||||
),
|
||||
"flow_name": original_metadata.get("flow_name"),
|
||||
"agent_count": original_metadata.get("agent_count", 1),
|
||||
"task_count": original_metadata.get("task_count", 1),
|
||||
"crewai_version": original_metadata.get("crewai_version"),
|
||||
}
|
||||
|
||||
self.batch_manager._initialize_backend_batch(
|
||||
user_context=user_context,
|
||||
execution_metadata=execution_metadata,
|
||||
use_ephemeral=True,
|
||||
)
|
||||
self.batch_manager.backend_initialized = True
|
||||
|
||||
if self.batch_manager.event_buffer:
|
||||
self.batch_manager._send_events_to_backend()
|
||||
|
||||
self.batch_manager.finalize_batch()
|
||||
self.ephemeral_url = self.batch_manager.ephemeral_trace_url
|
||||
|
||||
if not self.ephemeral_url:
|
||||
self._show_local_trace_message()
|
||||
|
||||
except Exception as e:
|
||||
self._gracefully_fail(f"Backend initialization failed: {e}")
|
||||
|
||||
def _display_ephemeral_trace_link(self):
|
||||
"""Display the ephemeral trace link to the user."""
|
||||
console = Console()
|
||||
|
||||
panel_content = f"""
|
||||
🎉 Your First CrewAI Execution Trace is Ready!
|
||||
|
||||
View your execution details here:
|
||||
{self.ephemeral_url}
|
||||
|
||||
This trace shows:
|
||||
• Agent decisions and interactions
|
||||
• Task execution timeline
|
||||
• Tool usage and results
|
||||
• LLM calls and responses
|
||||
|
||||
To use traces add tracing=True to your Crew(tracing=True) / Flow(tracing=True)
|
||||
|
||||
📝 Note: This link will expire in 24 hours.
|
||||
""".strip()
|
||||
|
||||
panel = Panel(
|
||||
panel_content,
|
||||
title="🔍 Execution Trace Generated",
|
||||
border_style="bright_green",
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
||||
console.print("\n")
|
||||
console.print(panel)
|
||||
console.print()
|
||||
|
||||
def _gracefully_fail(self, error_message: str):
|
||||
"""Handle errors gracefully without disrupting user experience."""
|
||||
console = Console()
|
||||
console.print(f"[yellow]Note: {error_message}[/yellow]")
|
||||
|
||||
logger.debug(f"First-time trace error: {error_message}")
|
||||
|
||||
def _show_local_trace_message(self):
|
||||
"""Show message when traces were collected locally but couldn't be uploaded."""
|
||||
console = Console()
|
||||
|
||||
panel_content = f"""
|
||||
📊 Your execution traces were collected locally!
|
||||
|
||||
Unfortunately, we couldn't upload them to the server right now, but here's what we captured:
|
||||
• {len(self.batch_manager.event_buffer)} trace events
|
||||
• Execution duration: {self.batch_manager.calculate_duration("execution")}ms
|
||||
• Batch ID: {self.batch_manager.trace_batch_id}
|
||||
|
||||
The traces include agent decisions, task execution, and tool usage.
|
||||
Try running with CREWAI_TRACING_ENABLED=true next time for persistent traces.
|
||||
""".strip()
|
||||
|
||||
panel = Panel(
|
||||
panel_content,
|
||||
title="🔍 Local Traces Collected",
|
||||
border_style="yellow",
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
||||
console.print("\n")
|
||||
console.print(panel)
|
||||
console.print()
|
||||
@@ -1,18 +1,18 @@
|
||||
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from logging import getLogger
from typing import Any, Dict, List, Optional

from rich.console import Console
from rich.panel import Panel

from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.plus_api import PlusAPI
from crewai.cli.version import get_crewai_version
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.listeners.tracing.utils import should_auto_collect_first_time_traces
from crewai.utilities.constants import CREWAI_BASE_URL
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
@@ -23,11 +23,11 @@ class TraceBatch:
|
||||
|
||||
version: str = field(default_factory=get_crewai_version)
|
||||
batch_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
user_context: dict[str, str] = field(default_factory=dict)
|
||||
execution_metadata: dict[str, Any] = field(default_factory=dict)
|
||||
events: list[TraceEvent] = field(default_factory=list)
|
||||
user_context: Dict[str, str] = field(default_factory=dict)
|
||||
execution_metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
events: List[TraceEvent] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"version": self.version,
|
||||
"batch_id": self.batch_id,
|
||||
@@ -40,28 +40,26 @@ class TraceBatch:
|
||||
class TraceBatchManager:
|
||||
"""Single responsibility: Manage batches and event buffering"""
|
||||
|
||||
is_current_batch_ephemeral: bool = False
|
||||
trace_batch_id: Optional[str] = None
|
||||
current_batch: Optional[TraceBatch] = None
|
||||
event_buffer: List[TraceEvent] = []
|
||||
execution_start_times: Dict[str, datetime] = {}
|
||||
batch_owner_type: Optional[str] = None
|
||||
batch_owner_id: Optional[str] = None
|
||||
|
||||
def __init__(self):
|
||||
self.is_current_batch_ephemeral: bool = False
|
||||
self.trace_batch_id: str | None = None
|
||||
self.current_batch: TraceBatch | None = None
|
||||
self.event_buffer: list[TraceEvent] = []
|
||||
self.execution_start_times: dict[str, datetime] = {}
|
||||
self.batch_owner_type: str | None = None
|
||||
self.batch_owner_id: str | None = None
|
||||
self.backend_initialized: bool = False
|
||||
self.ephemeral_trace_url: str | None = None
|
||||
try:
|
||||
self.plus_api = PlusAPI(
|
||||
api_key=get_auth_token(),
|
||||
)
|
||||
except AuthError:
|
||||
self.plus_api = PlusAPI(api_key="")
|
||||
self.ephemeral_trace_url = None
|
||||
|
||||
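The two TraceBatchManager variants above differ in where state lives: class-level attributes (with mutable defaults shared by every instance) versus per-instance attributes assigned in __init__. A generic illustration of why the distinction matters; ListBuffer is made up, not CrewAI code:

class ListBuffer:
    shared_buffer = []          # class attribute: one list shared by all instances

    def __init__(self):
        self.own_buffer = []    # instance attribute: a fresh list per instance


a, b = ListBuffer(), ListBuffer()
a.shared_buffer.append("event")
a.own_buffer.append("event")
print(len(b.shared_buffer), len(b.own_buffer))  # 1 0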
def initialize_batch(
|
||||
self,
|
||||
user_context: dict[str, str],
|
||||
execution_metadata: dict[str, Any],
|
||||
user_context: Dict[str, str],
|
||||
execution_metadata: Dict[str, Any],
|
||||
use_ephemeral: bool = False,
|
||||
) -> TraceBatch:
|
||||
"""Initialize a new trace batch"""
|
||||
@@ -72,21 +70,14 @@ class TraceBatchManager:
|
||||
self.is_current_batch_ephemeral = use_ephemeral
|
||||
|
||||
self.record_start_time("execution")
|
||||
|
||||
if should_auto_collect_first_time_traces():
|
||||
self.trace_batch_id = self.current_batch.batch_id
|
||||
else:
|
||||
self._initialize_backend_batch(
|
||||
user_context, execution_metadata, use_ephemeral
|
||||
)
|
||||
self.backend_initialized = True
|
||||
self._initialize_backend_batch(user_context, execution_metadata, use_ephemeral)
|
||||
|
||||
return self.current_batch
|
||||
|
||||
def _initialize_backend_batch(
|
||||
self,
|
||||
user_context: dict[str, str],
|
||||
execution_metadata: dict[str, Any],
|
||||
user_context: Dict[str, str],
|
||||
execution_metadata: Dict[str, Any],
|
||||
use_ephemeral: bool = False,
|
||||
):
|
||||
"""Send batch initialization to backend"""
|
||||
@@ -152,7 +143,7 @@ class TraceBatchManager:
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error initializing trace batch: {e}. Continuing without tracing."
|
||||
f"Error initializing trace batch: {str(e)}. Continuing without tracing."
|
||||
)
|
||||
|
||||
def add_event(self, trace_event: TraceEvent):
|
||||
@@ -163,6 +154,7 @@ class TraceBatchManager:
|
||||
"""Send buffered events to backend with graceful failure handling"""
|
||||
if not self.plus_api or not self.trace_batch_id or not self.event_buffer:
|
||||
return 500
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"events": [event.to_dict() for event in self.event_buffer],
|
||||
@@ -186,19 +178,19 @@ class TraceBatchManager:
|
||||
if response.status_code in [200, 201]:
|
||||
self.event_buffer.clear()
|
||||
return 200
|
||||
|
||||
logger.warning(
|
||||
f"Failed to send events: {response.status_code}. Events will be lost."
|
||||
)
|
||||
return 500
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to send events: {response.status_code}. Events will be lost."
|
||||
)
|
||||
return 500
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error sending events to backend: {e}. Events will be lost."
|
||||
f"Error sending events to backend: {str(e)}. Events will be lost."
|
||||
)
|
||||
return 500
|
||||
|
||||
def finalize_batch(self) -> TraceBatch | None:
|
||||
def finalize_batch(self) -> Optional[TraceBatch]:
|
||||
"""Finalize batch and return it for sending"""
|
||||
if not self.current_batch:
|
||||
return None
|
||||
@@ -254,10 +246,6 @@ class TraceBatchManager:
|
||||
if not self.is_current_batch_ephemeral and access_code is None
|
||||
else f"{CREWAI_BASE_URL}/crewai_plus/ephemeral_trace_batches/{self.trace_batch_id}?access_code={access_code}"
|
||||
)
|
||||
|
||||
if self.is_current_batch_ephemeral:
|
||||
self.ephemeral_trace_url = return_link
|
||||
|
||||
panel = Panel(
|
||||
f"✅ Trace batch finalized with session ID: {self.trace_batch_id}. View here: {return_link} {f', Access Code: {access_code}' if access_code else ''}",
|
||||
title="Trace Batch Finalization",
|
||||
@@ -271,8 +259,8 @@ class TraceBatchManager:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error finalizing trace batch: {e}")
|
||||
# TODO: send error to app marking as failed
|
||||
logger.error(f"❌ Error finalizing trace batch: {str(e)}")
|
||||
# TODO: send error to app
|
||||
|
||||
def _cleanup_batch_data(self):
|
||||
"""Clean up batch data after successful finalization to free memory"""
|
||||
@@ -289,7 +277,7 @@ class TraceBatchManager:
|
||||
self.batch_sequence = 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Warning: Error during cleanup: {e}")
|
||||
logger.error(f"Warning: Error during cleanup: {str(e)}")
|
||||
|
||||
def has_events(self) -> bool:
|
||||
"""Check if there are events in the buffer"""
|
||||
@@ -318,7 +306,7 @@ class TraceBatchManager:
|
||||
return duration_ms
|
||||
return 0
|
||||
|
||||
def get_trace_id(self) -> str | None:
|
||||
def get_trace_id(self) -> Optional[str]:
|
||||
"""Get current trace ID"""
|
||||
if self.current_batch:
|
||||
return self.current_batch.user_context.get("trace_id")
|
||||
|
||||
@@ -1,59 +1,28 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, ClassVar, Dict, Optional

from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.listeners.tracing.first_time_trace_handler import (
|
||||
FirstTimeTraceHandler,
|
||||
)
|
||||
from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.events.listeners.tracing.utils import safe_serialize_to_dict
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
)
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
FlowPlotEvent,
|
||||
FlowStartedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from crewai.events.types.task_events import (
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
@@ -64,16 +33,49 @@ from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.utilities.serialization import to_serializable

from .trace_batch_manager import TraceBatchManager
|
||||
|
||||
class TraceCollectionListener(BaseEventListener):
|
||||
"""
|
||||
Trace collection listener that orchestrates trace collection
|
||||
"""
|
||||
|
||||
complex_events: ClassVar[list[str]] = [
|
||||
complex_events = [
|
||||
"task_started",
|
||||
"task_completed",
|
||||
"llm_call_started",
|
||||
@@ -86,14 +88,14 @@ class TraceCollectionListener(BaseEventListener):
|
||||
_initialized = False
|
||||
_listeners_setup = False
|
||||
|
||||
def __new__(cls, batch_manager: TraceBatchManager | None = None):
|
||||
def __new__(cls, batch_manager=None):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
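TraceCollectionListener uses __new__ plus a class-level _instance to behave as a singleton, with an _initialized flag so repeated construction is a no-op. A stripped-down sketch of that pattern:

class Singleton:
    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance


print(Singleton() is Singleton())  # True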
def __init__(
|
||||
self,
|
||||
batch_manager: TraceBatchManager | None = None,
|
||||
batch_manager: Optional[TraceBatchManager] = None,
|
||||
):
|
||||
if self._initialized:
|
||||
return
|
||||
@@ -101,19 +103,16 @@ class TraceCollectionListener(BaseEventListener):
|
||||
super().__init__()
|
||||
self.batch_manager = batch_manager or TraceBatchManager()
|
||||
self._initialized = True
|
||||
self.first_time_handler = FirstTimeTraceHandler()
|
||||
|
||||
if self.first_time_handler.initialize_for_first_time_user():
|
||||
self.first_time_handler.set_batch_manager(self.batch_manager)
|
||||
|
||||
def _check_authenticated(self) -> bool:
|
||||
"""Check if tracing should be enabled"""
|
||||
try:
|
||||
return bool(get_auth_token())
|
||||
res = bool(get_auth_token())
|
||||
return res
|
||||
except AuthError:
|
||||
return False
|
||||
|
||||
def _get_user_context(self) -> dict[str, str]:
|
||||
def _get_user_context(self) -> Dict[str, str]:
|
||||
"""Extract user context for tracing"""
|
||||
return {
|
||||
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
|
||||
@@ -162,14 +161,8 @@ class TraceCollectionListener(BaseEventListener):
|
||||
@event_bus.on(FlowFinishedEvent)
|
||||
def on_flow_finished(source, event):
|
||||
self._handle_trace_event("flow_finished", source, event)
|
||||
|
||||
if self.batch_manager.batch_owner_type == "flow":
|
||||
if self.first_time_handler.is_first_time:
|
||||
self.first_time_handler.mark_events_collected()
|
||||
self.first_time_handler.handle_execution_completion()
|
||||
else:
|
||||
# Normal flow finalization
|
||||
self.batch_manager.finalize_batch()
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(FlowPlotEvent)
|
||||
def on_flow_plot(source, event):
|
||||
@@ -188,20 +181,12 @@ class TraceCollectionListener(BaseEventListener):
|
||||
def on_crew_completed(source, event):
|
||||
self._handle_trace_event("crew_kickoff_completed", source, event)
|
||||
if self.batch_manager.batch_owner_type == "crew":
|
||||
if self.first_time_handler.is_first_time:
|
||||
self.first_time_handler.mark_events_collected()
|
||||
self.first_time_handler.handle_execution_completion()
|
||||
else:
|
||||
self.batch_manager.finalize_batch()
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(CrewKickoffFailedEvent)
|
||||
def on_crew_failed(source, event):
|
||||
self._handle_trace_event("crew_kickoff_failed", source, event)
|
||||
if self.first_time_handler.is_first_time:
|
||||
self.first_time_handler.mark_events_collected()
|
||||
self.first_time_handler.handle_execution_completion()
|
||||
else:
|
||||
self.batch_manager.finalize_batch()
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(TaskStartedEvent)
|
||||
def on_task_started(source, event):
|
||||
@@ -340,19 +325,17 @@ class TraceCollectionListener(BaseEventListener):
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
def _initialize_batch(
|
||||
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
|
||||
self, user_context: Dict[str, str], execution_metadata: Dict[str, Any]
|
||||
):
|
||||
"""Initialize trace batch - auto-enable ephemeral for first-time users."""
|
||||
|
||||
if self.first_time_handler.is_first_time:
|
||||
return self.batch_manager.initialize_batch(
|
||||
"""Initialize trace batch if ephemeral"""
|
||||
if not self._check_authenticated():
|
||||
self.batch_manager.initialize_batch(
|
||||
user_context, execution_metadata, use_ephemeral=True
|
||||
)
|
||||
|
||||
use_ephemeral = not self._check_authenticated()
|
||||
return self.batch_manager.initialize_batch(
|
||||
user_context, execution_metadata, use_ephemeral=use_ephemeral
|
||||
)
|
||||
else:
|
||||
self.batch_manager.initialize_batch(
|
||||
user_context, execution_metadata, use_ephemeral=False
|
||||
)
|
||||
|
||||
def _handle_trace_event(self, event_type: str, source: Any, event: Any):
|
||||
"""Generic handler for context end events"""
|
||||
@@ -388,11 +371,11 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
def _build_event_data(
|
||||
self, event_type: str, event: Any, source: Any
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
"""Build event data"""
|
||||
if event_type not in self.complex_events:
|
||||
return safe_serialize_to_dict(event)
|
||||
if event_type == "task_started":
|
||||
return self._safe_serialize_to_dict(event)
|
||||
elif event_type == "task_started":
|
||||
return {
|
||||
"task_description": event.task.description,
|
||||
"expected_output": event.task.expected_output,
|
||||
@@ -401,7 +384,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"agent_role": source.agent.role,
|
||||
"task_id": str(event.task.id),
|
||||
}
|
||||
if event_type == "task_completed":
|
||||
elif event_type == "task_completed":
|
||||
return {
|
||||
"task_description": event.task.description if event.task else None,
|
||||
"task_name": event.task.name or event.task.description
|
||||
@@ -414,31 +397,63 @@ class TraceCollectionListener(BaseEventListener):
|
||||
else None,
|
||||
"agent_role": event.output.agent if event.output else None,
|
||||
}
|
||||
if event_type == "agent_execution_started":
|
||||
elif event_type == "agent_execution_started":
|
||||
return {
|
||||
"agent_role": event.agent.role,
|
||||
"agent_goal": event.agent.goal,
|
||||
"agent_backstory": event.agent.backstory,
|
||||
}
|
||||
if event_type == "agent_execution_completed":
|
||||
elif event_type == "agent_execution_completed":
|
||||
return {
|
||||
"agent_role": event.agent.role,
|
||||
"agent_goal": event.agent.goal,
|
||||
"agent_backstory": event.agent.backstory,
|
||||
}
|
||||
if event_type == "llm_call_started":
|
||||
event_data = safe_serialize_to_dict(event)
|
||||
elif event_type == "llm_call_started":
|
||||
event_data = self._safe_serialize_to_dict(event)
|
||||
event_data["task_name"] = (
|
||||
event.task_name or event.task_description
|
||||
if hasattr(event, "task_name") and event.task_name
|
||||
else None
|
||||
)
|
||||
return event_data
|
||||
if event_type == "llm_call_completed":
|
||||
return safe_serialize_to_dict(event)
|
||||
elif event_type == "llm_call_completed":
|
||||
return self._safe_serialize_to_dict(event)
|
||||
else:
|
||||
return {
|
||||
"event_type": event_type,
|
||||
"event": self._safe_serialize_to_dict(event),
|
||||
"source": source,
|
||||
}
|
||||
|
||||
return {
|
||||
"event_type": event_type,
|
||||
"event": safe_serialize_to_dict(event),
|
||||
"source": source,
|
||||
}
|
||||
# TODO: move to utils
|
||||
def _safe_serialize_to_dict(
|
||||
self, obj, exclude: set[str] | None = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Safely serialize an object to a dictionary for event data."""
|
||||
try:
|
||||
serialized = to_serializable(obj, exclude)
|
||||
if isinstance(serialized, dict):
|
||||
return serialized
|
||||
else:
|
||||
return {"serialized_data": serialized}
|
||||
except Exception as e:
|
||||
return {"serialization_error": str(e), "object_type": type(obj).__name__}
|
||||
|
||||
# TODO: move to utils
|
||||
def _truncate_messages(self, messages, max_content_length=500, max_messages=5):
|
||||
"""Truncate message content and limit number of messages"""
|
||||
if not messages or not isinstance(messages, list):
|
||||
return messages
|
||||
|
||||
# Limit number of messages
|
||||
limited_messages = messages[:max_messages]
|
||||
|
||||
# Truncate each message content
|
||||
for msg in limited_messages:
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
content = msg["content"]
|
||||
if len(content) > max_content_length:
|
||||
msg["content"] = content[:max_content_length] + "..."
|
||||
|
||||
return limited_messages
|
||||
|
||||
@@ -1,25 +1,17 @@
import getpass
import hashlib
import json
import logging
import os
import platform
import re
import subprocess
import uuid
from datetime import datetime
import hashlib
import subprocess
import getpass
from pathlib import Path
from typing import Any
from datetime import datetime
import re
import json

import click
from rich.console import Console
from rich.panel import Panel
from rich.text import Text

from crewai.utilities.paths import db_storage_path
from crewai.utilities.serialization import to_serializable

logger = logging.getLogger(__name__)


def is_tracing_enabled() -> bool:
|
||||
@@ -51,11 +43,13 @@ def _get_machine_id() -> str:
|
||||
|
||||
try:
|
||||
mac = ":".join(
|
||||
[f"{(uuid.getnode() >> b) & 0xFF:02x}" for b in range(0, 12, 2)][::-1]
|
||||
["{:02x}".format((uuid.getnode() >> b) & 0xFF) for b in range(0, 12, 2)][
|
||||
::-1
|
||||
]
|
||||
)
|
||||
parts.append(mac)
|
||||
except Exception:
|
||||
logger.warning("Error getting machine id for fingerprinting")
|
||||
pass
|
||||
|
||||
sysname = platform.system()
|
||||
parts.append(sysname)
|
||||
@@ -63,7 +57,7 @@ def _get_machine_id() -> str:
|
||||
try:
|
||||
if sysname == "Darwin":
|
||||
res = subprocess.run(
|
||||
["/usr/sbin/system_profiler", "SPHardwareDataType"],
|
||||
["system_profiler", "SPHardwareDataType"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=2,
|
||||
@@ -78,7 +72,7 @@ def _get_machine_id() -> str:
|
||||
parts.append(Path("/sys/class/dmi/id/product_uuid").read_text().strip())
|
||||
elif sysname == "Windows":
|
||||
res = subprocess.run(
|
||||
["C:\\Windows\\System32\\wbem\\wmic.exe", "csproduct", "get", "UUID"],
|
||||
["wmic", "csproduct", "get", "UUID"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=2,
|
||||
@@ -87,7 +81,7 @@ def _get_machine_id() -> str:
|
||||
if len(lines) >= 2:
|
||||
parts.append(lines[1])
|
||||
except Exception:
|
||||
logger.exception("Error getting machine ID")
|
||||
pass
|
||||
|
||||
return hashlib.sha256("".join(parts).encode()).hexdigest()
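Note: the fingerprinting shown in these hunks (MAC address, platform name, and a per-OS hardware UUID, all hashed with SHA-256) can be reduced to a minimal self-contained sketch. This is an illustrative simplification under stated assumptions, not the exact helper in this diff; it keeps only the MAC and platform parts and skips the subprocess lookups.

```python
import hashlib
import platform
import uuid


def simple_machine_fingerprint() -> str:
    """Illustrative sketch: hash a few stable host identifiers into one opaque ID."""
    parts: list[str] = []
    try:
        # uuid.getnode() returns the MAC address as a 48-bit integer; format it as hex pairs.
        mac = ":".join(f"{(uuid.getnode() >> b) & 0xFF:02x}" for b in range(0, 48, 8))
        parts.append(mac)
    except Exception:
        pass  # fall back to whatever identifiers were collected
    parts.append(platform.system())
    return hashlib.sha256("".join(parts).encode()).hexdigest()


if __name__ == "__main__":
    print(simple_machine_fingerprint())
```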
|
||||
|
||||
@@ -103,8 +97,8 @@ def _load_user_data() -> dict:
if p.exists():
try:
return json.loads(p.read_text())
except (json.JSONDecodeError, OSError, PermissionError) as e:
logger.warning(f"Failed to load user data: {e}")
except Exception:
pass
return {}


@@ -112,8 +106,8 @@ def _save_user_data(data: dict) -> None:
try:
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
except (OSError, PermissionError) as e:
logger.warning(f"Failed to save user data: {e}")
except Exception:
pass


def get_user_id() -> str:
|
||||
@@ -157,103 +151,3 @@ def mark_first_execution_done() -> None:
|
||||
}
|
||||
)
|
||||
_save_user_data(data)
|
||||
|
||||
|
||||
def safe_serialize_to_dict(obj, exclude: set[str] | None = None) -> dict[str, Any]:
|
||||
"""Safely serialize an object to a dictionary for event data."""
|
||||
try:
|
||||
serialized = to_serializable(obj, exclude)
|
||||
if isinstance(serialized, dict):
|
||||
return serialized
|
||||
return {"serialized_data": serialized}
|
||||
except Exception as e:
|
||||
return {"serialization_error": str(e), "object_type": type(obj).__name__}
|
||||
|
||||
|
||||
def truncate_messages(messages, max_content_length=500, max_messages=5):
|
||||
"""Truncate message content and limit number of messages"""
|
||||
if not messages or not isinstance(messages, list):
|
||||
return messages
|
||||
|
||||
limited_messages = messages[:max_messages]
|
||||
|
||||
for msg in limited_messages:
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
content = msg["content"]
|
||||
if len(content) > max_content_length:
|
||||
msg["content"] = content[:max_content_length] + "..."
|
||||
|
||||
return limited_messages
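A quick usage sketch of the truncation helper above; the function body is copied from this hunk so the snippet runs on its own, and the oversized string is only an illustration.

```python
def truncate_messages(messages, max_content_length=500, max_messages=5):
    """Truncate message content and limit number of messages (copied from the hunk above)."""
    if not messages or not isinstance(messages, list):
        return messages
    limited_messages = messages[:max_messages]
    for msg in limited_messages:
        if isinstance(msg, dict) and "content" in msg:
            content = msg["content"]
            if len(content) > max_content_length:
                msg["content"] = content[:max_content_length] + "..."
    return limited_messages


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "x" * 2000},  # oversized content is cut to 500 chars plus "..."
]
trimmed = truncate_messages(messages)
assert len(trimmed) == 2
assert trimmed[1]["content"].endswith("...") and len(trimmed[1]["content"]) == 503
```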
|
||||
|
||||
|
||||
def should_auto_collect_first_time_traces() -> bool:
|
||||
"""True if we should auto-collect traces for first-time user."""
|
||||
if _is_test_environment():
|
||||
return False
|
||||
return is_first_execution()
|
||||
|
||||
|
||||
def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
"""
|
||||
Prompt user if they want to see their traces with timeout.
|
||||
Returns True if user wants to see traces, False otherwise.
|
||||
"""
|
||||
if _is_test_environment():
|
||||
return False
|
||||
|
||||
try:
|
||||
import threading
|
||||
|
||||
console = Console()
|
||||
|
||||
content = Text()
|
||||
content.append("🔍 ", style="cyan bold")
|
||||
content.append(
|
||||
"Detailed execution traces are available!\n\n", style="cyan bold"
|
||||
)
|
||||
content.append("View insights including:\n", style="white")
|
||||
content.append(" • Agent decision-making process\n", style="bright_blue")
|
||||
content.append(" • Task execution flow and timing\n", style="bright_blue")
|
||||
content.append(" • Tool usage details", style="bright_blue")
|
||||
|
||||
panel = Panel(
|
||||
content,
|
||||
title="[bold cyan]Execution Traces[/bold cyan]",
|
||||
border_style="cyan",
|
||||
padding=(1, 2),
|
||||
)
|
||||
console.print("\n")
|
||||
console.print(panel)
|
||||
|
||||
prompt_text = click.style(
|
||||
f"Would you like to view your execution traces? [y/N] ({timeout_seconds}s timeout): ",
|
||||
fg="white",
|
||||
bold=True,
|
||||
)
|
||||
click.echo(prompt_text, nl=False)
|
||||
|
||||
result = [False]
|
||||
|
||||
def get_input():
|
||||
try:
|
||||
response = input().strip().lower()
|
||||
result[0] = response in ["y", "yes"]
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
result[0] = False
|
||||
|
||||
input_thread = threading.Thread(target=get_input, daemon=True)
|
||||
input_thread.start()
|
||||
input_thread.join(timeout=timeout_seconds)
|
||||
|
||||
if input_thread.is_alive():
|
||||
return False
|
||||
|
||||
return result[0]
|
||||
|
||||
except Exception:
|
||||
return False
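The prompt above relies on a common pattern: read stdin on a daemon thread and give up after a timeout, since input() cannot be interrupted portably. A stripped-down sketch of just that pattern; the names here are illustrative, not the helper's real API.

```python
import threading


def ask_yes_no_with_timeout(prompt: str, timeout_seconds: int = 20) -> bool:
    """Return True only if the user answers yes before the timeout expires."""
    answer = [False]

    def read_input() -> None:
        try:
            answer[0] = input(prompt).strip().lower() in ("y", "yes")
        except (EOFError, KeyboardInterrupt):
            answer[0] = False

    reader = threading.Thread(target=read_input, daemon=True)
    reader.start()
    reader.join(timeout=timeout_seconds)
    # If the thread is still blocked on input(), treat the missing answer as a "no".
    return False if reader.is_alive() else answer[0]


if __name__ == "__main__":
    print(ask_yes_no_with_timeout("View traces? [y/N]: ", timeout_seconds=5))
```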
|
||||
|
||||
|
||||
def mark_first_execution_completed() -> None:
|
||||
"""Mark first execution as completed (called after trace prompt)."""
|
||||
mark_first_execution_done()
|
||||
|
||||
@@ -2,22 +2,30 @@ import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from uuid import uuid4
|
||||
|
||||
from opentelemetry import baggage
|
||||
from opentelemetry.context import attach, detach
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from crewai.flow.flow_visualizer import plot_flow
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.types import FlowExecutionData
|
||||
from crewai.flow.utils import get_possible_return_constants
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
should_auto_collect_first_time_traces,
|
||||
)
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
@@ -27,10 +35,12 @@ from crewai.events.types.flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow_visualizer import plot_flow
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.types import FlowExecutionData
|
||||
from crewai.flow.utils import get_possible_return_constants
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -45,14 +55,16 @@ class FlowState(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# type variables with explicit bounds
|
||||
T = TypeVar("T", bound=dict[str, Any] | BaseModel) # Generic flow state type parameter
|
||||
# Type variables with explicit bounds
|
||||
T = TypeVar(
|
||||
"T", bound=Union[Dict[str, Any], BaseModel]
|
||||
) # Generic flow state type parameter
|
||||
StateT = TypeVar(
|
||||
"StateT", bound=dict[str, Any] | BaseModel
|
||||
"StateT", bound=Union[Dict[str, Any], BaseModel]
|
||||
) # State validation type parameter
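The two spellings in this hunk are equivalent at runtime; the difference is only the minimum Python version. `dict[str, Any] | BaseModel` (PEP 585/604) needs Python 3.10+, while `Union[Dict[str, Any], BaseModel]` also works on older interpreters. A minimal illustration, assuming Python 3.10+ and pydantic installed (both variants bound the TypeVar to BaseModel):

```python
from typing import Any, Dict, TypeVar, Union

from pydantic import BaseModel

# Pre-3.10 spelling, as restored by this change:
StateT_old = TypeVar("StateT_old", bound=Union[Dict[str, Any], BaseModel])

# 3.10+ spelling, as it appeared before this change:
StateT_new = TypeVar("StateT_new", bound=dict[str, Any] | BaseModel)

# Both bounds accept the same set of state types, e.g. a plain dict or a BaseModel subclass.
```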
|
||||
|
||||
|
||||
def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
|
||||
def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
|
||||
"""Ensure state matches expected type with proper validation.
|
||||
|
||||
Args:
|
||||
@@ -92,7 +104,7 @@ def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
|
||||
raise TypeError(f"Invalid expected_type: {expected_type}")
|
||||
|
||||
|
||||
def start(condition: str | dict | Callable | None = None) -> Callable:
|
||||
def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
|
||||
"""
|
||||
Marks a method as a flow's starting point.
|
||||
|
||||
@@ -159,7 +171,7 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
|
||||
return decorator
|
||||
|
||||
|
||||
def listen(condition: str | dict | Callable) -> Callable:
|
||||
def listen(condition: Union[str, dict, Callable]) -> Callable:
|
||||
"""
|
||||
Creates a listener that executes when specified conditions are met.
|
||||
|
||||
@@ -219,7 +231,7 @@ def listen(condition: str | dict | Callable) -> Callable:
|
||||
return decorator
|
||||
|
||||
|
||||
def router(condition: str | dict | Callable) -> Callable:
|
||||
def router(condition: Union[str, dict, Callable]) -> Callable:
|
||||
"""
|
||||
Creates a routing method that directs flow execution based on conditions.
|
||||
|
||||
@@ -285,7 +297,7 @@ def router(condition: str | dict | Callable) -> Callable:
|
||||
return decorator
|
||||
|
||||
|
||||
def or_(*conditions: str | dict | Callable) -> dict:
|
||||
def or_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
"""
|
||||
Combines multiple conditions with OR logic for flow control.
|
||||
|
||||
@@ -331,7 +343,7 @@ def or_(*conditions: str | dict | Callable) -> dict:
|
||||
return {"type": "OR", "methods": methods}
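These combinators only package method names into a small condition dict ({"type": "OR", "methods": [...]}, with and_ below mirroring it as "AND"). As a rough illustration of how such a dict could be checked against the set of methods that have already fired — this evaluator is an assumption for illustration, not the Flow engine's actual trigger logic:

```python
def condition_is_met(condition: dict, fired: set[str]) -> bool:
    """Illustrative check of an OR/AND condition dict against already-fired method names."""
    methods = condition["methods"]
    if condition["type"] == "OR":
        return any(name in fired for name in methods)
    if condition["type"] == "AND":
        return all(name in fired for name in methods)
    raise ValueError(f"Unknown condition type: {condition['type']}")


# An OR condition fires as soon as either method has completed; AND waits for both.
assert condition_is_met({"type": "OR", "methods": ["fetch", "load"]}, {"load"})
assert not condition_is_met({"type": "AND", "methods": ["fetch", "load"]}, {"load"})
```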
|
||||
|
||||
|
||||
def and_(*conditions: str | dict | Callable) -> dict:
|
||||
def and_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
"""
|
||||
Combines multiple conditions with AND logic for flow control.
|
||||
|
||||
@@ -413,10 +425,10 @@ class FlowMeta(type):
|
||||
if possible_returns:
|
||||
router_paths[attr_name] = possible_returns
|
||||
|
||||
cls._start_methods = start_methods
|
||||
cls._listeners = listeners
|
||||
cls._routers = routers
|
||||
cls._router_paths = router_paths
|
||||
setattr(cls, "_start_methods", start_methods)
|
||||
setattr(cls, "_listeners", listeners)
|
||||
setattr(cls, "_routers", routers)
|
||||
setattr(cls, "_router_paths", router_paths)
|
||||
|
||||
return cls
|
||||
|
||||
@@ -424,29 +436,29 @@ class FlowMeta(type):
|
||||
class Flow(Generic[T], metaclass=FlowMeta):
|
||||
"""Base class for all flows.
|
||||
|
||||
type parameter T must be either dict[str, Any] or a subclass of BaseModel."""
|
||||
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
|
||||
|
||||
_printer = Printer()
|
||||
|
||||
_start_methods: ClassVar[list[str]] = []
|
||||
_listeners: ClassVar[dict[str, tuple[str, list[str]]]] = {}
|
||||
_routers: ClassVar[set[str]] = set()
|
||||
_router_paths: ClassVar[dict[str, list[str]]] = {}
|
||||
initial_state: type[T] | T | None = None
|
||||
name: str | None = None
|
||||
tracing: bool | None = False
|
||||
_start_methods: List[str] = []
|
||||
_listeners: Dict[str, tuple[str, List[str]]] = {}
|
||||
_routers: Set[str] = set()
|
||||
_router_paths: Dict[str, List[str]] = {}
|
||||
initial_state: Union[Type[T], T, None] = None
|
||||
name: Optional[str] = None
|
||||
tracing: Optional[bool] = False
|
||||
|
||||
def __class_getitem__(cls: type["Flow"], item: type[T]) -> type["Flow"]:
|
||||
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
|
||||
class _FlowGeneric(cls): # type: ignore
|
||||
_initial_state_t = item # type: ignore
|
||||
_initial_state_T = item # type: ignore
|
||||
|
||||
_FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]"
|
||||
return _FlowGeneric
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
persistence: FlowPersistence | None = None,
|
||||
tracing: bool | None = False,
|
||||
persistence: Optional[FlowPersistence] = None,
|
||||
tracing: Optional[bool] = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a new Flow instance.
|
||||
@@ -456,22 +468,18 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
**kwargs: Additional state values to initialize or override
|
||||
"""
|
||||
# Initialize basic instance attributes
|
||||
self._methods: dict[str, Callable] = {}
|
||||
self._method_execution_counts: dict[str, int] = {}
|
||||
self._pending_and_listeners: dict[str, set[str]] = {}
|
||||
self._method_outputs: list[Any] = [] # list to store all method outputs
|
||||
self._completed_methods: set[str] = set() # Track completed methods for reload
|
||||
self._persistence: FlowPersistence | None = persistence
|
||||
self._methods: Dict[str, Callable] = {}
|
||||
self._method_execution_counts: Dict[str, int] = {}
|
||||
self._pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||
self._completed_methods: Set[str] = set() # Track completed methods for reload
|
||||
self._persistence: Optional[FlowPersistence] = persistence
|
||||
self._is_execution_resuming: bool = False
|
||||
|
||||
# Initialize state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
self.tracing = tracing
|
||||
if (
|
||||
is_tracing_enabled()
|
||||
or self.tracing
|
||||
or should_auto_collect_first_time_traces()
|
||||
):
|
||||
if is_tracing_enabled() or self.tracing:
|
||||
trace_listener = TraceCollectionListener()
|
||||
trace_listener.setup_listeners(crewai_event_bus)
|
||||
# Apply any additional kwargs
|
||||
@@ -513,25 +521,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
TypeError: If state is neither BaseModel nor dictionary
|
||||
"""
|
||||
# Handle case where initial_state is None but we have a type parameter
|
||||
if self.initial_state is None and hasattr(self, "_initial_state_t"):
|
||||
state_type = self._initial_state_t
|
||||
if self.initial_state is None and hasattr(self, "_initial_state_T"):
|
||||
state_type = getattr(self, "_initial_state_T")
|
||||
if isinstance(state_type, type):
|
||||
if issubclass(state_type, FlowState):
|
||||
# Create instance without id, then set it
|
||||
instance = state_type()
|
||||
if not hasattr(instance, "id"):
|
||||
instance.id = str(uuid4())
|
||||
setattr(instance, "id", str(uuid4()))
|
||||
return cast(T, instance)
|
||||
if issubclass(state_type, BaseModel):
|
||||
elif issubclass(state_type, BaseModel):
|
||||
# Create a new type that includes the ID field
|
||||
class StateWithId(state_type, FlowState): # type: ignore
|
||||
pass
|
||||
|
||||
instance = StateWithId()
|
||||
if not hasattr(instance, "id"):
|
||||
instance.id = str(uuid4())
|
||||
setattr(instance, "id", str(uuid4()))
|
||||
return cast(T, instance)
|
||||
if state_type is dict:
|
||||
elif state_type is dict:
|
||||
return cast(T, {"id": str(uuid4())})
|
||||
|
||||
# Handle case where no initial state is provided
|
||||
@@ -542,13 +550,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if isinstance(self.initial_state, type):
|
||||
if issubclass(self.initial_state, FlowState):
|
||||
return cast(T, self.initial_state()) # Uses model defaults
|
||||
if issubclass(self.initial_state, BaseModel):
|
||||
elif issubclass(self.initial_state, BaseModel):
|
||||
# Validate that the model has an id field
|
||||
model_fields = getattr(self.initial_state, "model_fields", None)
|
||||
if not model_fields or "id" not in model_fields:
|
||||
raise ValueError("Flow state model must have an 'id' field")
|
||||
return cast(T, self.initial_state()) # Uses model defaults
|
||||
if self.initial_state is dict:
|
||||
elif self.initial_state is dict:
|
||||
return cast(T, {"id": str(uuid4())})
|
||||
|
||||
# Handle dictionary instance case
|
||||
@@ -592,7 +600,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def method_outputs(self) -> list[Any]:
|
||||
def method_outputs(self) -> List[Any]:
|
||||
"""Returns the list of all outputs from executed methods."""
|
||||
return self._method_outputs
|
||||
|
||||
@@ -623,13 +631,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
if isinstance(self._state, dict):
|
||||
return str(self._state.get("id", ""))
|
||||
if isinstance(self._state, BaseModel):
|
||||
elif isinstance(self._state, BaseModel):
|
||||
return str(getattr(self._state, "id", ""))
|
||||
return ""
|
||||
except (AttributeError, TypeError):
|
||||
return "" # Safely handle any unexpected attribute access issues
|
||||
|
||||
def _initialize_state(self, inputs: dict[str, Any]) -> None:
|
||||
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Initialize or update flow state with new inputs.
|
||||
|
||||
Args:
|
||||
@@ -683,7 +691,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
else:
|
||||
raise TypeError("State must be a BaseModel instance or a dictionary.")
|
||||
|
||||
def _restore_state(self, stored_state: dict[str, Any]) -> None:
|
||||
def _restore_state(self, stored_state: Dict[str, Any]) -> None:
|
||||
"""Restore flow state from persistence.
|
||||
|
||||
Args:
|
||||
@@ -727,7 +735,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
execution_data: Flow execution data containing:
|
||||
- id: Flow execution ID
|
||||
- flow: Flow structure
|
||||
- completed_methods: list of successfully completed methods
|
||||
- completed_methods: List of successfully completed methods
|
||||
- execution_methods: All execution methods with their status
|
||||
"""
|
||||
flow_id = execution_data.get("id")
|
||||
@@ -763,7 +771,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if state_to_apply:
|
||||
self._apply_state_updates(state_to_apply)
|
||||
|
||||
for method in sorted_methods[:-1]:
|
||||
for i, method in enumerate(sorted_methods[:-1]):
|
||||
method_name = method.get("flow_method", {}).get("name")
|
||||
if method_name:
|
||||
self._completed_methods.add(method_name)
|
||||
@@ -775,7 +783,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
elif hasattr(self._state, field_name):
|
||||
object.__setattr__(self._state, field_name, value)
|
||||
|
||||
def _apply_state_updates(self, updates: dict[str, Any]) -> None:
|
||||
def _apply_state_updates(self, updates: Dict[str, Any]) -> None:
|
||||
"""Apply multiple state updates efficiently."""
|
||||
if isinstance(self._state, dict):
|
||||
self._state.update(updates)
|
||||
@@ -784,7 +792,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if hasattr(self._state, key):
|
||||
object.__setattr__(self._state, key, value)
|
||||
|
||||
def kickoff(self, inputs: dict[str, Any] | None = None) -> Any:
|
||||
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
Start the flow execution in a synchronous context.
|
||||
|
||||
@@ -797,7 +805,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
return asyncio.run(run_flow())
|
||||
|
||||
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any:
|
||||
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
Start the flow execution asynchronously.
|
||||
|
||||
@@ -832,7 +840,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if isinstance(self._state, dict):
|
||||
self._state["id"] = inputs["id"]
|
||||
elif isinstance(self._state, BaseModel):
|
||||
setattr(self._state, "id", inputs["id"]) # noqa: B010
|
||||
setattr(self._state, "id", inputs["id"])
|
||||
|
||||
# If persistence is enabled, attempt to restore the stored state using the provided id.
|
||||
if "id" in inputs and self._persistence is not None:
|
||||
@@ -1067,7 +1075,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
# Now execute normal listeners for all router results and the original trigger
|
||||
all_triggers = [trigger_method, *router_results]
|
||||
all_triggers = [trigger_method] + router_results
|
||||
|
||||
for current_trigger in all_triggers:
|
||||
if current_trigger: # Skip None results
|
||||
@@ -1101,7 +1109,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: str, router_only: bool
|
||||
) -> list[str]:
|
||||
) -> List[str]:
|
||||
"""
|
||||
Finds all methods that should be triggered based on conditions.
|
||||
|
||||
@@ -1118,7 +1126,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
List[str]
|
||||
Names of methods that should be triggered.
|
||||
|
||||
Notes
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
|
||||
@@ -14,23 +13,23 @@ class Knowledge(BaseModel):
|
||||
"""
|
||||
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
Args:
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
embedder: dict[str, Any] | None = None
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
"""
|
||||
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
embedder: dict[str, Any] | None = None
|
||||
collection_name: str | None = None
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
collection_name: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: dict[str, Any] | None = None,
|
||||
storage: KnowledgeStorage | None = None,
|
||||
sources: List[BaseKnowledgeSource],
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
storage: Optional[KnowledgeStorage] = None,
|
||||
**data,
|
||||
):
|
||||
super().__init__(**data)
|
||||
@@ -41,10 +40,11 @@ class Knowledge(BaseModel):
|
||||
embedder=embedder, collection_name=collection_name
|
||||
)
|
||||
self.sources = sources
|
||||
self.storage.initialize_knowledge_storage()
|
||||
|
||||
def query(
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[SearchResult]:
|
||||
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query across all knowledge sources to find the most relevant information.
|
||||
Returns the top_k most relevant chunks.
|
||||
@@ -55,11 +55,12 @@ class Knowledge(BaseModel):
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
return self.storage.search(
|
||||
results = self.storage.search(
|
||||
query,
|
||||
limit=results_limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
return results
|
||||
|
||||
def add_sources(self):
|
||||
try:
|
||||
|
||||
@@ -1,7 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any

from crewai.rag.types import SearchResult
from typing import Any, Dict, List, Optional


class BaseKnowledgeStorage(ABC):
@@ -10,17 +8,22 @@ class BaseKnowledgeStorage(ABC):
@abstractmethod
def search(
self,
query: list[str],
query: List[str],
limit: int = 3,
metadata_filter: dict[str, Any] | None = None,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> list[SearchResult]:
) -> List[Dict[str, Any]]:
"""Search for documents in the knowledge base."""
pass

@abstractmethod
def save(self, documents: list[str]) -> None:
def save(
self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
) -> None:
"""Save documents to the knowledge base."""
pass

@abstractmethod
def reset(self) -> None:
"""Reset the knowledge base."""
pass

@@ -1,16 +1,24 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import chromadb
|
||||
import chromadb.errors
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.types import OneOrMany
|
||||
from chromadb.config import Settings
|
||||
import warnings
|
||||
from typing import Any, cast
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
|
||||
|
||||
class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
@@ -19,101 +27,167 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
collection: Optional[chromadb.Collection] = None
|
||||
collection_name: Optional[str] = "knowledge"
|
||||
app: Optional[ClientAPI] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: dict[str, Any] | None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self._client: BaseClient | None = None
|
||||
self._set_embedder_config(embedder)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: List[str],
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Dict[str, Any]]:
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
if self.collection:
|
||||
fetched = self.collection.query(
|
||||
query_texts=query,
|
||||
n_results=limit,
|
||||
where=filter,
|
||||
)
|
||||
results = []
|
||||
for i in range(len(fetched["ids"][0])): # type: ignore
|
||||
result = {
|
||||
"id": fetched["ids"][0][i], # type: ignore
|
||||
"metadata": fetched["metadatas"][0][i], # type: ignore
|
||||
"context": fetched["documents"][0][i], # type: ignore
|
||||
"score": fetched["distances"][0][i], # type: ignore
|
||||
}
|
||||
if result["score"] >= score_threshold:
|
||||
results.append(result)
|
||||
return results
|
||||
else:
|
||||
raise Exception("Collection not initialized")
|
||||
|
||||
def initialize_knowledge_storage(self):
|
||||
# Suppress deprecation warnings from chromadb, which are not relevant to us
|
||||
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r".*'model_fields'.*is deprecated.*",
|
||||
module=r"^chromadb(\.|$)",
|
||||
)
|
||||
|
||||
if embedder:
|
||||
embedding_function = get_embedding_function(embedder)
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
self.app = create_persistent_client(
|
||||
path=os.path.join(db_storage_path(), "knowledge"),
|
||||
settings=Settings(allow_reset=True),
|
||||
)
|
||||
|
||||
try:
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "knowledge"
|
||||
)
|
||||
self._client = create_client(config)
|
||||
if self.app:
|
||||
self.collection = self.app.get_or_create_collection(
|
||||
name=sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedder,
|
||||
)
|
||||
else:
|
||||
raise Exception("Vector Database Client not initialized")
|
||||
except Exception:
|
||||
raise Exception("Failed to create or get collection")
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
"""Get the appropriate client - instance-specific or global."""
|
||||
return self._client if self._client else get_rag_client()
|
||||
def reset(self):
|
||||
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
|
||||
if not self.app:
|
||||
self.app = create_persistent_client(
|
||||
path=base_path, settings=Settings(allow_reset=True)
|
||||
)
|
||||
|
||||
def search(
|
||||
self.app.reset()
|
||||
shutil.rmtree(base_path)
|
||||
self.app = None
|
||||
self.collection = None
|
||||
|
||||
def save(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 3,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> list[SearchResult]:
|
||||
documents: List[str],
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
):
|
||||
if not self.collection:
|
||||
raise Exception("Collection not initialized")
|
||||
|
||||
try:
|
||||
if not query:
|
||||
raise ValueError("Query cannot be empty")
|
||||
# Create a dictionary to store unique documents
|
||||
unique_docs = {}
|
||||
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "knowledge"
|
||||
)
|
||||
query_text = " ".join(query) if len(query) > 1 else query[0]
|
||||
# Generate IDs and create a mapping of id -> (document, metadata)
|
||||
for idx, doc in enumerate(documents):
|
||||
doc_id = hashlib.sha256(doc.encode("utf-8")).hexdigest()
|
||||
doc_metadata = None
|
||||
if metadata is not None:
|
||||
if isinstance(metadata, list):
|
||||
doc_metadata = metadata[idx]
|
||||
else:
|
||||
doc_metadata = metadata
|
||||
unique_docs[doc_id] = (doc, doc_metadata)
|
||||
|
||||
return client.search(
|
||||
collection_name=collection_name,
|
||||
query=query_text,
|
||||
limit=limit,
|
||||
metadata_filter=metadata_filter,
|
||||
score_threshold=score_threshold,
|
||||
# Prepare filtered lists for ChromaDB
|
||||
filtered_docs = []
|
||||
filtered_metadata = []
|
||||
filtered_ids = []
|
||||
|
||||
# Build the filtered lists
|
||||
for doc_id, (doc, meta) in unique_docs.items():
|
||||
filtered_docs.append(doc)
|
||||
filtered_metadata.append(meta)
|
||||
filtered_ids.append(doc_id)
|
||||
|
||||
# If we have no metadata at all, set it to None
|
||||
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
|
||||
None if all(m is None for m in filtered_metadata) else filtered_metadata
|
||||
)
|
||||
|
||||
self.collection.upsert(
|
||||
documents=filtered_docs,
|
||||
metadatas=final_metadata,
|
||||
ids=filtered_ids,
|
||||
)
|
||||
except chromadb.errors.InvalidDimensionException as e:
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
|
||||
"red",
|
||||
)
|
||||
raise ValueError(
|
||||
"Embedding dimension mismatch. Make sure you're using the same embedding model "
|
||||
"across all operations with this collection."
|
||||
"Try resetting the collection using `crewai reset-memories -a`"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logging.error(f"Error during knowledge search: {e!s}")
|
||||
return []
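The save path in this file dedupes documents by hashing their content into stable IDs before upserting, so re-adding the same text never creates duplicate records. A minimal sketch of that idea in isolation (no ChromaDB involved; names are illustrative):

```python
import hashlib


def dedupe_by_content(documents: list[str]) -> dict[str, str]:
    """Map a stable SHA-256 id to each unique document, dropping exact duplicates."""
    unique: dict[str, str] = {}
    for doc in documents:
        doc_id = hashlib.sha256(doc.encode("utf-8")).hexdigest()
        unique[doc_id] = doc
    return unique


docs = ["alpha", "beta", "alpha"]
assert len(dedupe_by_content(docs)) == 2  # the duplicate "alpha" collapses to one id
```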
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "knowledge"
|
||||
)
|
||||
client.delete_collection(collection_name=collection_name)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during knowledge reset: {e!s}")
|
||||
|
||||
def save(self, documents: list[str]) -> None:
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "knowledge"
|
||||
)
|
||||
client.get_or_create_collection(collection_name=collection_name)
|
||||
|
||||
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
|
||||
|
||||
client.add_documents(
|
||||
collection_name=collection_name, documents=rag_documents
|
||||
)
|
||||
except Exception as e:
|
||||
if "dimension mismatch" in str(e).lower():
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
|
||||
"red",
|
||||
)
|
||||
raise ValueError(
|
||||
"Embedding dimension mismatch. Make sure you're using the same embedding model "
|
||||
"across all operations with this collection."
|
||||
"Try resetting the collection using `crewai reset-memories -a`"
|
||||
) from e
|
||||
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||
raise
|
||||
|
||||
def _create_default_embedding_function(self):
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Set the embedding configuration for the knowledge storage.
|
||||
|
||||
Args:
|
||||
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
||||
If None or empty, defaults to the default embedding function.
|
||||
"""
|
||||
self.embedder = (
|
||||
EmbeddingConfigurator().configure_embedder(embedder)
|
||||
if embedder
|
||||
else self._create_default_embedding_function()
|
||||
)
|
||||
|
||||
@@ -1,12 +1,12 @@
from crewai.rag.types import SearchResult
from typing import Any, Dict, List


def extract_knowledge_context(knowledge_snippets: list[SearchResult]) -> str:
def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str:
"""Extract knowledge from the task prompt."""
valid_snippets = [
result["content"]
result["context"]
for result in knowledge_snippets
if result and result.get("content")
if result and result.get("context")
]
snippet = "\n".join(valid_snippets)
return f"Additional Information: {snippet}" if valid_snippets else ""

@@ -6,14 +6,19 @@ import threading
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from datetime import datetime
|
||||
from dotenv import load_dotenv
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -26,9 +31,9 @@ from crewai.events.types.llm_events import (
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
@@ -46,8 +51,8 @@ with warnings.catch_warnings():
|
||||
import io
|
||||
from typing import TextIO
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
@@ -263,14 +268,14 @@ def suppress_warnings():


class Delta(TypedDict):
content: str | None
role: str | None
content: Optional[str]
role: Optional[str]


class StreamingChoices(TypedDict):
delta: Delta
index: int
finish_reason: str | None
finish_reason: Optional[str]


class FunctionArgs(BaseModel):
|
||||
@@ -283,34 +288,32 @@ class AccumulatedToolArgs(BaseModel):
|
||||
|
||||
|
||||
class LLM(BaseLLM):
|
||||
completion_cost: float | None = None
|
||||
completion_cost: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
timeout: float | int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
n: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[int, float] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
logprobs: int | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
base_url: str | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
seed: Optional[int] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] | None = None,
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
stream: bool = False,
|
||||
enable_prompt_caching: bool = False,
|
||||
cache_control: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
@@ -337,14 +340,12 @@ class LLM(BaseLLM):
|
||||
self.additional_params = kwargs
|
||||
self.is_anthropic = self._is_anthropic_model(model)
|
||||
self.stream = stream
|
||||
self.enable_prompt_caching = enable_prompt_caching
|
||||
self.cache_control = cache_control or {"type": "ephemeral"}
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
# Normalize self.stop to always be a List[str]
|
||||
if stop is None:
|
||||
self.stop: list[str] = []
|
||||
self.stop: List[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
@@ -362,82 +363,14 @@ class LLM(BaseLLM):
|
||||
Returns:
|
||||
bool: True if the model is from Anthropic, False otherwise.
|
||||
"""
|
||||
anthropic_prefixes = ("anthropic/", "claude-", "claude/")
|
||||
if "bedrock/" in model.lower():
|
||||
return False
|
||||
return any(prefix in model.lower() for prefix in anthropic_prefixes)
|
||||
|
||||
def _supports_prompt_caching(self) -> bool:
|
||||
"""Check if the current model supports prompt caching.
|
||||
|
||||
Returns:
|
||||
bool: True if the model supports prompt caching, False otherwise.
|
||||
"""
|
||||
supported_prefixes = (
|
||||
"gpt-",
|
||||
"openai/",
|
||||
"anthropic/",
|
||||
"claude-",
|
||||
"bedrock/",
|
||||
"deepseek/",
|
||||
)
|
||||
return any(prefix in self.model.lower() for prefix in supported_prefixes)
|
||||
|
||||
def _apply_prompt_caching(
|
||||
self, messages: list[dict[str, str]]
|
||||
) -> list[dict[str, str]]:
|
||||
"""Apply prompt caching to messages for supported providers.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: Messages with cache_control applied where appropriate
|
||||
"""
|
||||
if not self.is_anthropic:
|
||||
return messages
|
||||
|
||||
# For Anthropic models, add cache_control to the last system message
|
||||
formatted_messages = []
|
||||
system_message_indices = [
|
||||
i for i, msg in enumerate(messages) if msg.get("role") == "system"
|
||||
]
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
formatted_message = message.copy()
|
||||
|
||||
if (
|
||||
message.get("role") == "system"
|
||||
and system_message_indices
|
||||
and i == system_message_indices[-1]
|
||||
):
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, str):
|
||||
formatted_message["content"] = [ # type: ignore[assignment]
|
||||
{
|
||||
"type": "text",
|
||||
"text": content,
|
||||
"cache_control": self.cache_control,
|
||||
}
|
||||
]
|
||||
elif isinstance(content, list) and content:
|
||||
content_copy = content.copy()
|
||||
if content_copy:
|
||||
content_copy[-1] = {
|
||||
**content_copy[-1],
|
||||
"cache_control": self.cache_control,
|
||||
}
|
||||
formatted_message["content"] = content_copy
|
||||
|
||||
formatted_messages.append(formatted_message)
|
||||
|
||||
return formatted_messages
|
||||
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
|
||||
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
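The removed `_apply_prompt_caching` helper wrapped the last system message in a content block carrying `cache_control`, which is how Anthropic-style prompt caching is requested through litellm. A reduced sketch of that transformation on plain message dicts, kept here for reference; it assumes Python 3.10+ and does not assert exact provider behaviour.

```python
from typing import Any


def mark_last_system_message_cacheable(
    messages: list[dict[str, Any]],
    cache_control: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
    """Wrap the final system message's text in a block tagged with cache_control."""
    cache_control = cache_control or {"type": "ephemeral"}
    system_indices = [i for i, m in enumerate(messages) if m.get("role") == "system"]
    out = []
    for i, message in enumerate(messages):
        msg = message.copy()
        if system_indices and i == system_indices[-1] and isinstance(msg.get("content"), str):
            msg["content"] = [
                {"type": "text", "text": msg["content"], "cache_control": cache_control}
            ]
        out.append(msg)
    return out


cached = mark_last_system_message_cacheable(
    [{"role": "system", "content": "Long shared prefix..."}, {"role": "user", "content": "Hi"}]
)
assert cached[0]["content"][0]["cache_control"] == {"type": "ephemeral"}
```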
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Prepare parameters for the completion call.
|
||||
|
||||
Args:
|
||||
@@ -486,11 +419,11 @@ class LLM(BaseLLM):
|
||||
|
||||
def _handle_streaming_response(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
params: Dict[str, Any],
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> str:
|
||||
"""Handle a streaming response from the LLM.
|
||||
|
||||
@@ -514,7 +447,7 @@ class LLM(BaseLLM):
|
||||
usage_info = None
|
||||
tool_calls = None
|
||||
|
||||
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
|
||||
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
|
||||
AccumulatedToolArgs
|
||||
)
|
||||
|
||||
@@ -539,16 +472,16 @@ class LLM(BaseLLM):
|
||||
choices = chunk["choices"]
|
||||
elif hasattr(chunk, "choices"):
|
||||
# Check if choices is not a type but an actual attribute with value
|
||||
if not isinstance(chunk.choices, type):
|
||||
choices = chunk.choices
|
||||
if not isinstance(getattr(chunk, "choices"), type):
|
||||
choices = getattr(chunk, "choices")
|
||||
|
||||
# Try to extract usage information if available
|
||||
if isinstance(chunk, dict) and "usage" in chunk:
|
||||
usage_info = chunk["usage"]
|
||||
elif hasattr(chunk, "usage"):
|
||||
# Check if usage is not a type but an actual attribute with value
|
||||
if not isinstance(chunk.usage, type):
|
||||
usage_info = chunk.usage
|
||||
if not isinstance(getattr(chunk, "usage"), type):
|
||||
usage_info = getattr(chunk, "usage")
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
@@ -558,7 +491,7 @@ class LLM(BaseLLM):
|
||||
if isinstance(choice, dict) and "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
elif hasattr(choice, "delta"):
|
||||
delta = choice.delta
|
||||
delta = getattr(choice, "delta")
|
||||
|
||||
# Extract content from delta
|
||||
if delta:
|
||||
@@ -568,7 +501,7 @@ class LLM(BaseLLM):
|
||||
chunk_content = delta["content"]
|
||||
# Handle object format
|
||||
elif hasattr(delta, "content"):
|
||||
chunk_content = delta.content
|
||||
chunk_content = getattr(delta, "content")
|
||||
|
||||
# Handle case where content might be None or empty
|
||||
if chunk_content is None and isinstance(delta, dict):
|
||||
@@ -600,15 +533,15 @@ class LLM(BaseLLM):
|
||||
full_response += chunk_content
|
||||
|
||||
# Emit the chunk event
|
||||
if hasattr(crewai_event_bus, "emit"):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(
|
||||
chunk=chunk_content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(
|
||||
chunk=chunk_content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
# --- 4) Fallback to non-streaming if no content received
|
||||
if not full_response.strip() and chunk_count == 0:
|
||||
logging.warning(
|
||||
@@ -639,8 +572,8 @@ class LLM(BaseLLM):
|
||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||
choices = last_chunk["choices"]
|
||||
elif hasattr(last_chunk, "choices"):
|
||||
if not isinstance(last_chunk.choices, type):
|
||||
choices = last_chunk.choices
|
||||
if not isinstance(getattr(last_chunk, "choices"), type):
|
||||
choices = getattr(last_chunk, "choices")
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
@@ -650,14 +583,14 @@ class LLM(BaseLLM):
|
||||
if isinstance(choice, dict) and "message" in choice:
|
||||
message = choice["message"]
|
||||
elif hasattr(choice, "message"):
|
||||
message = choice.message
|
||||
message = getattr(choice, "message")
|
||||
|
||||
if message:
|
||||
content = None
|
||||
if isinstance(message, dict) and "content" in message:
|
||||
content = message["content"]
|
||||
elif hasattr(message, "content"):
|
||||
content = message.content
|
||||
content = getattr(message, "content")
|
||||
|
||||
if content:
|
||||
full_response = content
|
||||
@@ -684,8 +617,8 @@ class LLM(BaseLLM):
|
||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||
choices = last_chunk["choices"]
|
||||
elif hasattr(last_chunk, "choices"):
|
||||
if not isinstance(last_chunk.choices, type):
|
||||
choices = last_chunk.choices
|
||||
if not isinstance(getattr(last_chunk, "choices"), type):
|
||||
choices = getattr(last_chunk, "choices")
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
@@ -694,13 +627,13 @@ class LLM(BaseLLM):
|
||||
if isinstance(choice, dict) and "message" in choice:
|
||||
message = choice["message"]
|
||||
elif hasattr(choice, "message"):
|
||||
message = choice.message
|
||||
message = getattr(choice, "message")
|
||||
|
||||
if message:
|
||||
if isinstance(message, dict) and "tool_calls" in message:
|
||||
tool_calls = message["tool_calls"]
|
||||
elif hasattr(message, "tool_calls"):
|
||||
tool_calls = message.tool_calls
|
||||
tool_calls = getattr(message, "tool_calls")
|
||||
except Exception as e:
|
||||
logging.debug(f"Error checking for tool calls: {e}")
|
||||
# --- 8) If no tool calls or no available functions, return the text response directly
|
||||
@@ -740,11 +673,11 @@ class LLM(BaseLLM):
|
||||
# Catch context window errors from litellm and convert them to our own exception type.
|
||||
# This exception is handled by CrewAgentExecutor._invoke_loop() which can then
|
||||
# decide whether to summarize the content or abort based on the respect_context_window flag.
|
||||
raise LLMContextLengthExceededException(str(e)) from e
|
||||
raise LLMContextLengthExceededException(str(e))
|
||||
except Exception as e:
|
||||
logging.error(f"Error in streaming response: {e!s}")
|
||||
logging.error(f"Error in streaming response: {str(e)}")
|
||||
if full_response.strip():
|
||||
logging.warning(f"Returning partial response despite error: {e!s}")
|
||||
logging.warning(f"Returning partial response despite error: {str(e)}")
|
||||
self._handle_emit_call_events(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -755,22 +688,22 @@ class LLM(BaseLLM):
|
||||
return full_response
|
||||
|
||||
# Emit failed event and re-raise the exception
|
||||
if hasattr(crewai_event_bus, "emit"):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=str(e), from_task=from_task, from_agent=from_agent
|
||||
),
|
||||
)
|
||||
raise Exception(f"Failed to get streaming response: {e!s}") from e
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=str(e), from_task=from_task, from_agent=from_agent
|
||||
),
|
||||
)
|
||||
raise Exception(f"Failed to get streaming response: {str(e)}")
|
||||
|
||||
def _handle_streaming_tool_calls(
|
||||
self,
|
||||
tool_calls: list[ChatCompletionDeltaToolCall],
|
||||
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
tool_calls: List[ChatCompletionDeltaToolCall],
|
||||
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> None | str:
|
||||
for tool_call in tool_calls:
|
||||
current_tool_accumulator = accumulated_tool_args[tool_call.index]
|
||||
@@ -782,27 +715,16 @@ class LLM(BaseLLM):
|
||||
current_tool_accumulator.function.arguments += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
if hasattr(crewai_event_bus, "emit"):
|
||||
# Convert ChatCompletionDeltaToolCall to ToolCall format
|
||||
from crewai.events.types.llm_events import ToolCall, FunctionCall
|
||||
converted_tool_call = ToolCall(
|
||||
id=tool_call.id,
|
||||
function=FunctionCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments or ""
|
||||
),
|
||||
type=tool_call.type,
|
||||
index=tool_call.index
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(
|
||||
tool_call=converted_tool_call,
|
||||
chunk=tool_call.function.arguments,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(
|
||||
tool_call=tool_call.to_dict(),
|
||||
chunk=tool_call.function.arguments,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
current_tool_accumulator.function.name
|
||||
@@ -822,9 +744,9 @@ class LLM(BaseLLM):
|
||||
|
||||
def _handle_streaming_callbacks(
|
||||
self,
|
||||
callbacks: list[Any] | None,
|
||||
usage_info: dict[str, Any] | None,
|
||||
last_chunk: Any | None,
|
||||
callbacks: Optional[List[Any]],
|
||||
usage_info: Optional[Dict[str, Any]],
|
||||
last_chunk: Optional[Any],
|
||||
) -> None:
|
||||
"""Handle callbacks with usage info for streaming responses.
|
||||
|
||||
@@ -847,8 +769,10 @@ class LLM(BaseLLM):
|
||||
):
|
||||
usage_info = last_chunk["usage"]
|
||||
elif hasattr(last_chunk, "usage"):
|
||||
if not isinstance(last_chunk.usage, type):
|
||||
usage_info = last_chunk.usage
|
||||
if not isinstance(
|
||||
getattr(last_chunk, "usage"), type
|
||||
):
|
||||
usage_info = getattr(last_chunk, "usage")
|
||||
except Exception as e:
|
||||
logging.debug(f"Error extracting usage info: {e}")
|
||||
|
||||
@@ -862,11 +786,11 @@ class LLM(BaseLLM):
|
||||
|
||||
def _handle_non_streaming_response(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
params: Dict[str, Any],
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> str | Any:
|
||||
"""Handle a non-streaming response from the LLM.
|
||||
|
||||
@@ -891,7 +815,7 @@ class LLM(BaseLLM):
|
||||
except ContextWindowExceededError as e:
|
||||
# Convert litellm's context window error to our own exception type
|
||||
# for consistent handling in the rest of the codebase
|
||||
raise LLMContextLengthExceededException(str(e)) from e
|
||||
raise LLMContextLengthExceededException(str(e))
|
||||
# --- 2) Extract response message and content
|
||||
response_message = cast(Choices, cast(ModelResponse, response).choices)[
|
||||
0
|
||||
@@ -923,7 +847,7 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return text_response
|
||||
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
|
||||
if tool_calls and not available_functions and not text_response:
|
||||
elif tool_calls and not available_functions and not text_response:
|
||||
return tool_calls
|
||||
|
||||
# --- 7) Handle tool calls if present
|
||||
@@ -944,11 +868,11 @@ class LLM(BaseLLM):
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_calls: list[Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | None:
|
||||
tool_calls: List[Any],
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Optional[str]:
|
||||
"""Handle a tool call from the LLM.
|
||||
|
||||
Args:
|
||||
@@ -975,9 +899,9 @@ class LLM(BaseLLM):
|
||||
fn = available_functions[function_name]
|
||||
|
||||
# --- 3.2) Execute function
|
||||
if hasattr(crewai_event_bus, "emit"):
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=function_name,
|
||||
@@ -1015,17 +939,17 @@ class LLM(BaseLLM):
|
||||
function_name, lambda: None
|
||||
) # Ensure fn is always a callable
|
||||
logging.error(f"Error executing function '{function_name}': {e}")
|
||||
if hasattr(crewai_event_bus, "emit"):
|
||||
crewai_event_bus.emit(
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
|
||||
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageErrorEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
error=f"Tool execution error: {e!s}",
|
||||
error=f"Tool execution error: {str(e)}",
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
@@ -1034,13 +958,13 @@ class LLM(BaseLLM):
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""High-level LLM call method.
|
||||
|
||||
Args:
|
||||
@@ -1067,8 +991,8 @@ class LLM(BaseLLM):
|
||||
LLMContextLengthExceededException: If input exceeds model's context limit
|
||||
"""
|
||||
# --- 1) Emit call started event
|
||||
if hasattr(crewai_event_bus, "emit"):
|
||||
crewai_event_bus.emit(
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
messages=messages,
|
||||
@@ -1104,9 +1028,10 @@ class LLM(BaseLLM):
|
||||
return self._handle_streaming_response(
|
||||
params, callbacks, available_functions, from_task, from_agent
|
||||
)
|
||||
return self._handle_non_streaming_response(
|
||||
params, callbacks, available_functions, from_task, from_agent
|
||||
)
|
||||
else:
|
||||
return self._handle_non_streaming_response(
|
||||
params, callbacks, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
except LLMContextLengthExceededException:
|
||||
# Re-raise LLMContextLengthExceededException as it should be handled
|
||||
@@ -1140,21 +1065,21 @@ class LLM(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if hasattr(crewai_event_bus, "emit"):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=str(e), from_task=from_task, from_agent=from_agent
|
||||
),
|
||||
)
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=str(e), from_task=from_task, from_agent=from_agent
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def _handle_emit_call_events(
|
||||
self,
|
||||
response: Any,
|
||||
call_type: LLMCallType,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
messages: str | list[dict[str, Any]] | None = None,
|
||||
):
|
||||
"""Handle the events for the LLM call.
|
||||
@@ -1166,22 +1091,22 @@ class LLM(BaseLLM):
|
||||
from_agent: Optional agent object
|
||||
messages: Optional messages object
|
||||
"""
|
||||
if hasattr(crewai_event_bus, "emit"):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(
|
||||
messages=messages,
|
||||
response=response,
|
||||
call_type=call_type,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
),
|
||||
)
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(
|
||||
messages=messages,
|
||||
response=response,
|
||||
call_type=call_type,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
),
|
||||
)
|
||||
|
||||
def _format_messages_for_provider(
|
||||
self, messages: list[dict[str, str]]
|
||||
) -> list[dict[str, str]]:
|
||||
self, messages: List[Dict[str, str]]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Format messages according to provider requirements.
|
||||
|
||||
Args:
|
||||
@@ -1222,7 +1147,7 @@ class LLM(BaseLLM):
|
||||
if "mistral" in self.model.lower():
|
||||
# Check if the last message has a role of 'assistant'
|
||||
if messages and messages[-1]["role"] == "assistant":
|
||||
return [*messages, {"role": "user", "content": "Please continue."}]
|
||||
return messages + [{"role": "user", "content": "Please continue."}]
|
||||
return messages
|
||||
|
||||
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
|
||||
@@ -1232,22 +1157,20 @@ class LLM(BaseLLM):
|
||||
and messages
|
||||
and messages[-1]["role"] == "assistant"
|
||||
):
|
||||
return [*messages, {"role": "user", "content": ""}]
|
||||
return messages + [{"role": "user", "content": ""}]
|
||||
|
||||
# Handle Anthropic models
|
||||
if self.is_anthropic:
|
||||
# Anthropic requires messages to start with 'user' role
|
||||
if not messages or messages[0]["role"] == "system":
|
||||
# If first message is system or empty, add a placeholder user message
|
||||
messages = [{"role": "user", "content": "."}, *messages]
|
||||
if not self.is_anthropic:
|
||||
return messages
|
||||
|
||||
# Apply prompt caching if enabled and supported (after all other formatting)
|
||||
if self.enable_prompt_caching and self._supports_prompt_caching():
|
||||
messages = self._apply_prompt_caching(messages)
|
||||
# Anthropic requires messages to start with 'user' role
|
||||
if not messages or messages[0]["role"] == "system":
|
||||
# If first message is system or empty, add a placeholder user message
|
||||
return [{"role": "user", "content": "."}, *messages]
|
||||
|
||||
return messages
|
||||
|
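The formatting rules in this hunk (Mistral needing a trailing user turn, Anthropic needing the conversation to open with a user turn) can be read as one small pure function. A sketch under those assumptions; the substring check on "claude" is my simplification of the class's is_anthropic flag:

    def format_messages_for_provider(messages: list[dict[str, str]], model: str) -> list[dict[str, str]]:
        """Illustrative re-statement of the provider rules shown in the diff."""
        if "mistral" in model.lower() and messages and messages[-1]["role"] == "assistant":
            # Mistral rejects conversations that end on an assistant turn.
            return [*messages, {"role": "user", "content": "Please continue."}]
        if "claude" in model.lower() and (not messages or messages[0]["role"] == "system"):
            # Anthropic-style APIs require the first message to come from the user.
            return [{"role": "user", "content": "."}, *messages]
        return messages

    print(format_messages_for_provider([{"role": "system", "content": "be brief"}], "claude-3"))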
||||
def _get_custom_llm_provider(self) -> str | None:
|
||||
def _get_custom_llm_provider(self) -> Optional[str]:
|
||||
"""
|
||||
Derives the custom_llm_provider from the model string.
|
||||
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
|
||||
@@ -1284,7 +1207,7 @@ class LLM(BaseLLM):
|
||||
self.model, custom_llm_provider=provider
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to check function calling support: {e!s}")
|
||||
logging.error(f"Failed to check function calling support: {str(e)}")
|
||||
return False
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
@@ -1292,7 +1215,7 @@ class LLM(BaseLLM):
|
||||
params = get_supported_openai_params(model=self.model)
|
||||
return params is not None and "stop" in params
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to get supported params: {e!s}")
|
||||
logging.error(f"Failed to get supported params: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
@@ -1306,14 +1229,14 @@ class LLM(BaseLLM):
|
||||
if self.context_window_size != 0:
|
||||
return self.context_window_size
|
||||
|
||||
min_context = 1024
|
||||
max_context = 2097152 # Current max from gemini-1.5-pro
|
||||
MIN_CONTEXT = 1024
|
||||
MAX_CONTEXT = 2097152 # Current max from gemini-1.5-pro
|
||||
|
||||
# Validate all context window sizes
|
||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||
if value < min_context or value > max_context:
|
||||
if value < MIN_CONTEXT or value > MAX_CONTEXT:
|
||||
raise ValueError(
|
||||
f"Context window for {key} must be between {min_context} and {max_context}"
|
||||
f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
|
||||
)
|
||||
|
||||
self.context_window_size = int(
|
||||
@@ -1324,7 +1247,7 @@ class LLM(BaseLLM):
|
||||
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
return self.context_window_size
|
||||
|
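The hunk above validates every entry in LLM_CONTEXT_WINDOW_SIZES and then applies CONTEXT_WINDOW_USAGE_RATIO. A self-contained sketch of that bookkeeping; the 0.75 ratio and the window table below are illustrative values, not taken from the diff:

    CONTEXT_WINDOW_USAGE_RATIO = 0.75   # assumed example value
    MIN_CONTEXT = 1024
    MAX_CONTEXT = 2097152               # current max from gemini-1.5-pro

    LLM_CONTEXT_WINDOW_SIZES = {"gpt-4o": 128000, "gemini-1.5-pro": 2097152}

    def usable_context_window(model: str, default: int = 8192) -> int:
        """Return the portion of the model's window the executor should rely on."""
        for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
            if value < MIN_CONTEXT or value > MAX_CONTEXT:
                raise ValueError(
                    f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
                )
        raw = next(
            (v for k, v in LLM_CONTEXT_WINDOW_SIZES.items() if model.startswith(k)),
            default,
        )
        return int(raw * CONTEXT_WINDOW_USAGE_RATIO)

    print(usable_context_window("gpt-4o"))  # 96000 with the assumed 0.75 ratio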
||||
def set_callbacks(self, callbacks: list[Any]):
|
||||
def set_callbacks(self, callbacks: List[Any]):
|
||||
"""
|
||||
Attempt to keep a single set of callbacks in litellm by removing old
|
||||
duplicates and adding new ones.
|
||||
|
||||
@@ -1,14 +1,5 @@
"""Base LLM abstract class for CrewAI.

This module provides the abstract base class for all LLM implementations
in CrewAI.
"""

from abc import ABC, abstractmethod
from typing import Any, Final

DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
@@ -24,38 +15,41 @@ class BaseLLM(ABC):
|
||||
messages when things go wrong.
|
||||
|
||||
Attributes:
|
||||
model: The model identifier/name.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: A list of stop sequences that the LLM should use to stop generation.
|
||||
stop (list): A list of stop sequences that the LLM should use to stop generation.
|
||||
This is used by the CrewAgentExecutor and other components.
|
||||
"""
|
||||
|
||||
model: str
|
||||
temperature: Optional[float] = None
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
temperature: float | None = None,
|
||||
stop: list[str] | None = None,
|
||||
) -> None:
|
||||
temperature: Optional[float] = None,
|
||||
):
|
||||
"""Initialize the BaseLLM with default attributes.
|
||||
|
||||
Args:
|
||||
model: The model identifier/name.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: Optional list of stop sequences for generation.
|
||||
This constructor sets default values for attributes that are expected
|
||||
by the CrewAgentExecutor and other components.
|
||||
|
||||
All custom LLM implementations should call super().__init__() to ensure
|
||||
that these default attributes are properly initialized.
|
||||
"""
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.stop: list[str] = stop or []
|
||||
self.stop = []
|
||||
|
||||
@abstractmethod
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call the LLM with the given messages.
|
||||
|
||||
Args:
|
||||
@@ -70,7 +64,6 @@ class BaseLLM(ABC):
|
||||
available_functions: Optional dict mapping function names to callables
|
||||
that can be invoked by the LLM.
|
||||
from_task: Optional task caller to be used for the LLM call.
|
||||
from_agent: Optional agent caller to be used for the LLM call.
|
||||
|
||||
Returns:
|
||||
Either a text response from the LLM (str) or
|
||||
@@ -81,20 +74,21 @@ class BaseLLM(ABC):
|
||||
TimeoutError: If the LLM request times out.
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
"""
|
||||
pass
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports stop words, False otherwise.
|
||||
bool: True if the LLM supports stop words, False otherwise.
|
||||
"""
|
||||
return DEFAULT_SUPPORTS_STOP_WORDS
|
||||
return True # Default implementation assumes support for stop words
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the LLM.
|
||||
|
||||
Returns:
|
||||
The number of tokens/characters the model can handle.
|
||||
int: The number of tokens/characters the model can handle.
|
||||
"""
|
||||
# Default implementation - subclasses should override with model-specific values
|
||||
return DEFAULT_CONTEXT_WINDOW_SIZE
|
||||
return 4096
|
||||
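Since the BaseLLM interface above is what custom providers implement, here is a minimal, hypothetical subclass showing the expected shape of call(); the echo behaviour is a stand-in for a real API call and the class name is invented:

    from crewai.llms.base_llm import BaseLLM


    class EchoLLM(BaseLLM):
        """Toy provider that returns the last user message; useful only for wiring tests."""

        def __init__(self, model: str = "echo-1") -> None:
            super().__init__(model=model)

        def call(
            self,
            messages,
            tools=None,
            callbacks=None,
            available_functions=None,
            from_task=None,
            from_agent=None,
        ) -> str:
            if isinstance(messages, str):
                return messages
            # Walk backwards to find the most recent user turn.
            for message in reversed(messages):
                if message.get("role") == "user":
                    return message.get("content", "")
            return ""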
|
||||
src/crewai/llms/third_party/ai_suite.py
@@ -1,62 +1,24 @@
|
||||
"""AI Suite LLM integration for CrewAI.
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
This module provides integration with AI Suite for LLM capabilities.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import aisuite as ai # type: ignore
|
||||
import aisuite as ai
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
|
||||
class AISuiteLLM(BaseLLM):
|
||||
"""AI Suite LLM implementation.
|
||||
|
||||
This class provides integration with AI Suite models through the BaseLLM interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
temperature: float | None = None,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the AI Suite LLM.
|
||||
|
||||
Args:
|
||||
model: The model identifier for AI Suite.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: Optional list of stop sequences for generation.
|
||||
**kwargs: Additional keyword arguments passed to the AI Suite client.
|
||||
"""
|
||||
super().__init__(model, temperature, stop)
|
||||
def __init__(self, model: str, temperature: Optional[float] = None, **kwargs):
|
||||
super().__init__(model, temperature, **kwargs)
|
||||
self.client = ai.Client()
|
||||
self.kwargs = kwargs
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Call the AI Suite LLM with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
callbacks: Optional list of callback functions.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
from_task: Optional task caller.
|
||||
from_agent: Optional agent caller.
|
||||
|
||||
Returns:
|
||||
The text response from the LLM.
|
||||
"""
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Union[str, Any]:
|
||||
completion_params = self._prepare_completion_params(messages, tools)
|
||||
response = self.client.chat.completions.create(**completion_params)
|
||||
|
||||
@@ -64,35 +26,15 @@ class AISuiteLLM(BaseLLM):
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare parameters for the AI Suite completion call.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for the completion API.
|
||||
"""
|
||||
params: dict[str, Any] = {
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": self.temperature,
|
||||
"tools": tools,
|
||||
**self.kwargs,
|
||||
}
|
||||
|
||||
if self.stop:
|
||||
params["stop"] = self.stop
|
||||
|
||||
return params
|
||||
|
||||
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.

Returns:
False, as AI Suite does not currently support function calling.
"""
return False
|
||||
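One detail worth noting from the two versions of _prepare_completion_params above: the richer variant only forwards self.stop when it is non-empty. A standalone sketch of that pattern; the parameter names follow the diff, the client call itself is assumed:

    def prepare_completion_params(model, messages, temperature=None, stop=None, tools=None, **kwargs):
        """Build kwargs for a chat-completions style client, omitting empty stop lists."""
        params = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "tools": tools,
            **kwargs,
        }
        if stop:  # only forward stop sequences when the caller provided some
            params["stop"] = stop
        return params

    print(prepare_completion_params("openai:gpt-4o-mini", [{"role": "user", "content": "hi"}], stop=[]))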
|
||||
@@ -1,6 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from crewai.memory import (
|
||||
EntityMemory,
|
||||
@@ -21,9 +19,9 @@ class ContextualMemory:
|
||||
ltm: LongTermMemory,
|
||||
em: EntityMemory,
|
||||
exm: ExternalMemory,
|
||||
agent: Agent | None = None,
|
||||
task: Task | None = None,
|
||||
) -> None:
|
||||
agent: Optional["Agent"] = None,
|
||||
task: Optional["Task"] = None,
|
||||
):
|
||||
self.stm = stm
|
||||
self.ltm = ltm
|
||||
self.em = em
|
||||
@@ -44,7 +42,7 @@ class ContextualMemory:
|
||||
self.exm.agent = self.agent
|
||||
self.exm.task = self.task
|
||||
|
||||
def build_context_for_task(self, task: Task, context: str) -> str:
|
||||
def build_context_for_task(self, task, context) -> str:
|
||||
"""
|
||||
Automatically builds a minimal, highly relevant set of contextual information
|
||||
for a given task.
|
||||
@@ -54,15 +52,14 @@ class ContextualMemory:
|
||||
if query == "":
|
||||
return ""
|
||||
|
||||
context_parts = [
|
||||
self._fetch_ltm_context(task.description),
|
||||
self._fetch_stm_context(query),
|
||||
self._fetch_entity_context(query),
|
||||
self._fetch_external_context(query),
|
||||
]
|
||||
return "\n".join(filter(None, context_parts))
|
||||
context = []
|
||||
context.append(self._fetch_ltm_context(task.description))
|
||||
context.append(self._fetch_stm_context(query))
|
||||
context.append(self._fetch_entity_context(query))
|
||||
context.append(self._fetch_external_context(query))
|
||||
return "\n".join(filter(None, context))
|
||||
|
||||
def _fetch_stm_context(self, query: str) -> str:
|
||||
def _fetch_stm_context(self, query) -> str:
|
||||
"""
|
||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -73,11 +70,11 @@ class ContextualMemory:
|
||||
|
||||
stm_results = self.stm.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in stm_results]
|
||||
[f"- {result['context']}" for result in stm_results]
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
def _fetch_ltm_context(self, task: str) -> str | None:
|
||||
def _fetch_ltm_context(self, task) -> Optional[str]:
|
||||
"""
|
||||
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -93,14 +90,14 @@ class ContextualMemory:
|
||||
formatted_results = [
|
||||
suggestion
|
||||
for result in ltm_results
|
||||
for suggestion in result["metadata"]["suggestions"]
|
||||
for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
]
|
||||
formatted_results = list(dict.fromkeys(formatted_results))
|
||||
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
def _fetch_entity_context(self, query: str) -> str:
|
||||
def _fetch_entity_context(self, query) -> str:
|
||||
"""
|
||||
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -110,7 +107,7 @@ class ContextualMemory:
|
||||
|
||||
em_results = self.em.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in em_results]
|
||||
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
@@ -131,6 +128,6 @@ class ContextualMemory:
return ""

formatted_memories = "\n".join(
f"- {result['content']}" for result in external_memories
f"- {result['context']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"
|
||||
|
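The ContextualMemory changes above boil down to: query each memory store, format each hit list as bullets, and join only the non-empty sections. A compressed sketch of that flow with stub inputs (section titles follow the diff; the fetchers are replaced by plain lists):

    def _bullets(title: str, items: list[str]) -> str:
        return f"{title}:\n" + "\n".join(f"- {item}" for item in items) if items else ""

    def build_context(query: str, stm: list[str], ltm: list[str], entities: list[str], external: list[str]) -> str:
        """Join only the sections that actually returned something."""
        parts = [
            _bullets("Historical Data", ltm),
            _bullets("Recent Insights", stm),
            _bullets("Entities", entities),
            _bullets("External memories", external),
        ]
        return "\n".join(filter(None, parts))

    print(build_context("pricing", stm=["user asked about tiers"], ltm=[], entities=["Acme Corp"], external=[]))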
||||
@@ -1,13 +1,10 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from mem0 import Memory, MemoryClient # type: ignore[import-untyped,import-not-found]
|
||||
from mem0 import Memory, MemoryClient
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.rag.chromadb.utils import _sanitize_collection_name
|
||||
|
||||
MAX_AGENT_ID_LENGTH_MEM0 = 255
|
||||
|
||||
@@ -16,7 +13,6 @@ class Mem0Storage(Storage):
|
||||
"""
|
||||
Extends Storage to handle embedding and searching across entities using Mem0.
|
||||
"""
|
||||
|
||||
def __init__(self, type, crew=None, config=None):
|
||||
super().__init__()
|
||||
|
||||
@@ -32,8 +28,7 @@ class Mem0Storage(Storage):
|
||||
supported_types = {"short_term", "long_term", "entities", "external"}
|
||||
if type not in supported_types:
|
||||
raise ValueError(
|
||||
f"Invalid type '{type}' for Mem0Storage. "
|
||||
f"Must be one of: {', '.join(supported_types)}"
|
||||
f"Invalid type '{type}' for Mem0Storage. Must be one of: {', '.join(supported_types)}"
|
||||
)
|
||||
|
||||
def _extract_config_values(self):
|
||||
@@ -71,8 +66,7 @@ class Mem0Storage(Storage):
|
||||
- Includes user_id and agent_id if both are present.
|
||||
- Includes user_id if only user_id is present.
|
||||
- Includes agent_id if only agent_id is present.
|
||||
- Includes run_id if memory_type is 'short_term' and
|
||||
mem0_run_id is present.
|
||||
- Includes run_id if memory_type is 'short_term' and mem0_run_id is present.
|
||||
"""
|
||||
filter = defaultdict(list)
|
||||
|
||||
@@ -92,44 +86,21 @@ class Mem0Storage(Storage):
|
||||
|
||||
return filter
|
||||
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str:
|
||||
return next(
|
||||
(
|
||||
m.get("content", "")
|
||||
for m in reversed(list(messages))
|
||||
if m.get("role") == role
|
||||
),
|
||||
"",
|
||||
)
|
||||
|
||||
conversations = []
|
||||
messages = metadata.pop("messages", None)
|
||||
if messages:
|
||||
last_user = _last_content(messages, "user")
|
||||
last_assistant = _last_content(messages, "assistant")
|
||||
|
||||
if user_msg := self._get_user_message(last_user):
|
||||
conversations.append({"role": "user", "content": user_msg})
|
||||
|
||||
if assistant_msg := self._get_assistant_message(last_assistant):
|
||||
conversations.append({"role": "assistant", "content": assistant_msg})
|
||||
else:
|
||||
conversations.append({"role": "assistant", "content": value})
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
user_id = self.config.get("user_id", "")
|
||||
assistant_message = [{"role" : "assistant","content" : value}]
|
||||
|
||||
base_metadata = {
|
||||
"short_term": "short_term",
|
||||
"long_term": "long_term",
|
||||
"entities": "entity",
|
||||
"external": "external",
|
||||
"external": "external"
|
||||
}
|
||||
|
||||
# Shared base params
|
||||
params: dict[str, Any] = {
|
||||
"metadata": {"type": base_metadata[self.memory_type], **metadata},
|
||||
"infer": self.infer,
|
||||
"infer": self.infer
|
||||
}
|
||||
|
||||
# MemoryClient-specific overrides
|
||||
@@ -148,17 +119,15 @@ class Mem0Storage(Storage):
|
||||
if agent_id := self.config.get("agent_id", self._get_agent_name()):
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
self.memory.add(conversations, **params)
|
||||
self.memory.add(assistant_message, **params)
|
||||
|
||||
def search(
|
||||
self, query: str, limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[Any]:
|
||||
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
|
||||
params = {
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"version": "v2",
|
||||
"output_format": "v1.1",
|
||||
}
|
||||
"output_format": "v1.1"
|
||||
}
|
||||
|
||||
if user_id := self.config.get("user_id", ""):
|
||||
params["user_id"] = user_id
|
||||
@@ -179,10 +148,10 @@ class Mem0Storage(Storage):
|
||||
# automatically when the crew is created.
|
||||
|
||||
params["filters"] = self._create_filter_for_search()
|
||||
params["threshold"] = score_threshold
|
||||
params['threshold'] = score_threshold
|
||||
|
||||
if isinstance(self.memory, Memory):
|
||||
del params["metadata"], params["version"], params["output_format"]
|
||||
del params["metadata"], params["version"], params['output_format']
|
||||
if params.get("run_id"):
|
||||
del params["run_id"]
|
||||
|
||||
@@ -190,8 +159,8 @@ class Mem0Storage(Storage):
|
||||
|
||||
# This makes it compatible for Contextual Memory to retrieve
|
||||
for result in results["results"]:
|
||||
result["content"] = result["memory"]
|
||||
|
||||
result["context"] = result["memory"]
|
||||
|
||||
return [r for r in results["results"]]
|
||||
|
||||
def reset(self):
|
||||
@@ -211,19 +180,4 @@ class Mem0Storage(Storage):
|
||||
agents = self.crew.agents
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
return _sanitize_collection_name(
|
||||
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
|
||||
)
|
||||
|
||||
def _get_assistant_message(self, text: str) -> str:
|
||||
marker = "Final Answer:"
|
||||
if marker in text:
|
||||
return text.split(marker, 1)[1].strip()
|
||||
return text
|
||||
|
||||
def _get_user_message(self, text: str) -> str:
|
||||
pattern = r"User message:\s*(.*)"
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return text
|
||||
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
|
||||
|
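The newer Mem0Storage.save shown above extracts the last user and assistant turns before writing to memory. A small sketch of that helper, independent of mem0; the function name mirrors the diff's nested _last_content:

    from collections.abc import Iterable
    from typing import Any

    def last_content(messages: Iterable[dict[str, Any]], role: str) -> str:
        """Return the content of the most recent message with the given role, or ''."""
        return next(
            (m.get("content", "") for m in reversed(list(messages)) if m.get("role") == role),
            "",
        )

    history = [
        {"role": "user", "content": "What is our refund policy?"},
        {"role": "assistant", "content": "Final Answer: 30 days, no questions asked."},
    ]
    print(last_content(history, "assistant"))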
||||
@@ -1,16 +1,17 @@
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.rag.factory import create_client
|
||||
from typing import Any, Dict, List, Optional
|
||||
from chromadb.api import ClientAPI
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.rag.types import BaseRecord
|
||||
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
import warnings
|
||||
|
||||
|
||||
class RAGStorage(BaseRAGStorage):
|
||||
@@ -19,6 +20,8 @@ class RAGStorage(BaseRAGStorage):
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
app: ClientAPI | None = None
|
||||
|
||||
def __init__(
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||
):
|
||||
@@ -30,25 +33,37 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
|
||||
self.type = type
|
||||
self._client: BaseClient | None = None
|
||||
|
||||
self.allow_reset = allow_reset
|
||||
self.path = path
|
||||
self._initialize_app()
|
||||
|
||||
def _set_embedder_config(self):
|
||||
configurator = EmbeddingConfigurator()
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
|
||||
def _initialize_app(self):
|
||||
from chromadb.config import Settings
|
||||
|
||||
# Suppress deprecation warnings from chromadb, which are not relevant to us
|
||||
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r".*'model_fields'.*is deprecated.*",
|
||||
module=r"^chromadb(\.|$)",
|
||||
)
|
||||
|
||||
if self.embedder_config:
|
||||
embedding_function = get_embedding_function(self.embedder_config)
|
||||
config = ChromaDBConfig(embedding_function=embedding_function)
|
||||
self._client = create_client(config)
|
||||
self._set_embedder_config()
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
"""Get the appropriate client - instance-specific or global."""
|
||||
return self._client if self._client else get_rag_client()
|
||||
self.app = create_persistent_client(
|
||||
path=self.path if self.path else self.storage_file_name,
|
||||
settings=Settings(allow_reset=self.allow_reset),
|
||||
)
|
||||
|
||||
self.collection = self.app.get_or_create_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
logging.info(f"Collection found or created: {self.collection}")
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
@@ -70,65 +85,77 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
client.get_or_create_collection(collection_name=collection_name)
|
||||
|
||||
document: BaseRecord = {"content": value}
|
||||
if metadata:
|
||||
document["metadata"] = metadata
|
||||
|
||||
client.add_documents(collection_name=collection_name, documents=[document])
|
||||
self._generate_embedding(value, metadata)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} save: {e!s}")
|
||||
logging.error(f"Error during {self.type} save: {str(e)}")
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
filter: dict[str, Any] | None = None,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> list[Any]:
|
||||
) -> List[Any]:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
return client.search(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
limit=limit,
|
||||
metadata_filter=filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
response = self.collection.query(query_texts=query, n_results=limit)
|
||||
|
||||
results = []
|
||||
for i in range(len(response["ids"][0])):
|
||||
result = {
|
||||
"id": response["ids"][0][i],
|
||||
"metadata": response["metadatas"][0][i],
|
||||
"context": response["documents"][0][i],
|
||||
"score": response["distances"][0][i],
|
||||
}
|
||||
if result["score"] >= score_threshold:
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} search: {e!s}")
|
||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||
return []
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
|
||||
self.collection.add(
|
||||
documents=[text],
|
||||
metadatas=[metadata or {}],
|
||||
ids=[str(uuid.uuid4())],
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
client.delete_collection(collection_name=collection_name)
|
||||
if self.app:
|
||||
self.app.reset()
|
||||
shutil.rmtree(f"{db_storage_path()}/{self.type}")
|
||||
self.app = None
|
||||
self.collection = None
|
||||
except Exception as e:
|
||||
if "attempt to write a readonly database" in str(
|
||||
e
|
||||
) or "does not exist" in str(e):
|
||||
# Ignore readonly database and collection not found errors (already reset)
|
||||
if "attempt to write a readonly database" in str(e):
|
||||
# Ignore this specific error
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
) from e
|
||||
)
|
||||
|
||||
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)

return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
|
||||
|
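The older RAGStorage.search above post-filters ChromaDB query results by a score threshold. The same idea as a standalone function; the response layout mirrors a single-query chromadb result, and the >= comparison is kept exactly as the diff has it:

    def filter_query_response(response: dict, score_threshold: float = 0.35) -> list[dict]:
        """Flatten a single-query chromadb-style response and drop low-scoring hits."""
        results = []
        for i in range(len(response["ids"][0])):
            result = {
                "id": response["ids"][0][i],
                "metadata": response["metadatas"][0][i],
                "context": response["documents"][0][i],
                "score": response["distances"][0][i],
            }
            if result["score"] >= score_threshold:
                results.append(result)
        return results

    fake = {"ids": [["a"]], "metadatas": [[{}]], "documents": [["hello"]], "distances": [[0.9]]}
    print(filter_query_response(fake))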
||||
@@ -1,87 +1,95 @@
|
||||
"""Decorators for defining crew components and their behaviors."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
from typing import Any, Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from crewai import Crew
|
||||
from crewai.project.utils import memoize
|
||||
|
||||
"""Decorators for defining crew components and their behaviors."""
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def before_kickoff(func):
|
||||
def before_kickoff(func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Marks a method to execute before crew kickoff."""
|
||||
func.is_before_kickoff = True
|
||||
func.is_before_kickoff = True # type: ignore
|
||||
return func
|
||||
|
||||
|
||||
def after_kickoff(func):
|
||||
def after_kickoff(func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Marks a method to execute after crew kickoff."""
|
||||
func.is_after_kickoff = True
|
||||
func.is_after_kickoff = True # type: ignore
|
||||
return func
|
||||
|
||||
|
||||
def task(func):
|
||||
def task(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]:
|
||||
"""Marks a method as a crew task."""
|
||||
func.is_task = True
|
||||
func.is_task = True # type: ignore
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
result = func(*args, **kwargs)
|
||||
if not result.name:
|
||||
result.name = func.__name__
|
||||
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
result = func(self, *args, **kwargs)
|
||||
if not result.name: # type: ignore
|
||||
result.name = func.__name__ # type: ignore
|
||||
return result
|
||||
|
||||
return memoize(wrapper)
|
||||
|
||||
|
||||
def agent(func):
|
||||
def agent(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]:
|
||||
"""Marks a method as a crew agent."""
|
||||
func.is_agent = True
|
||||
func = memoize(func)
|
||||
return func
|
||||
func.is_agent = True # type: ignore
|
||||
return memoize(func)
|
||||
|
||||
|
||||
def llm(func):
|
||||
def llm(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]:
|
||||
"""Marks a method as an LLM provider."""
|
||||
func.is_llm = True
|
||||
func = memoize(func)
|
||||
return func
|
||||
func.is_llm = True # type: ignore
|
||||
return memoize(func)
|
||||
|
||||
|
||||
def output_json(cls):
|
||||
def output_json(cls: type[R]) -> type[R]:
|
||||
"""Marks a class as JSON output format."""
|
||||
cls.is_output_json = True
|
||||
cls.is_output_json = True # type: ignore
|
||||
return cls
|
||||
|
||||
|
||||
def output_pydantic(cls):
|
||||
def output_pydantic(cls: type[R]) -> type[R]:
|
||||
"""Marks a class as Pydantic output format."""
|
||||
cls.is_output_pydantic = True
|
||||
cls.is_output_pydantic = True # type: ignore
|
||||
return cls
|
||||
|
||||
|
||||
def tool(func):
|
||||
def tool(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]:
|
||||
"""Marks a method as a crew tool."""
|
||||
func.is_tool = True
|
||||
func.is_tool = True # type: ignore
|
||||
return memoize(func)
|
||||
|
||||
|
||||
def callback(func):
|
||||
def callback(
|
||||
func: Callable[Concatenate[Any, P], R],
|
||||
) -> Callable[Concatenate[Any, P], R]:
|
||||
"""Marks a method as a crew callback."""
|
||||
func.is_callback = True
|
||||
func.is_callback = True # type: ignore
|
||||
return memoize(func)
|
||||
|
||||
|
||||
def cache_handler(func):
|
||||
def cache_handler(
|
||||
func: Callable[Concatenate[Any, P], R],
|
||||
) -> Callable[Concatenate[Any, P], R]:
|
||||
"""Marks a method as a cache handler."""
|
||||
func.is_cache_handler = True
|
||||
func.is_cache_handler = True # type: ignore
|
||||
return memoize(func)
|
||||
|
||||
|
||||
def crew(func) -> Callable[..., Crew]:
|
||||
def crew(
|
||||
func: Callable[Concatenate[Any, P], Crew],
|
||||
) -> Callable[Concatenate[Any, P], Crew]:
|
||||
"""Marks a method as the main crew execution point."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs) -> Crew:
|
||||
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> Crew:
|
||||
instantiated_tasks = []
|
||||
instantiated_agents = []
|
||||
agent_roles = set()
|
||||
@@ -91,7 +99,7 @@ def crew(func) -> Callable[..., Crew]:
|
||||
agents = self._original_agents.items()
|
||||
|
||||
# Instantiate tasks in order
|
||||
for task_name, task_method in tasks:
|
||||
for _task_name, task_method in tasks:
|
||||
task_instance = task_method(self)
|
||||
instantiated_tasks.append(task_instance)
|
||||
agent_instance = getattr(task_instance, "agent", None)
|
||||
@@ -100,7 +108,7 @@ def crew(func) -> Callable[..., Crew]:
|
||||
agent_roles.add(agent_instance.role)
|
||||
|
||||
# Instantiate agents not included by tasks
|
||||
for agent_name, agent_method in agents:
|
||||
for _agent_name, agent_method in agents:
|
||||
agent_instance = agent_method(self)
|
||||
if agent_instance.role not in agent_roles:
|
||||
instantiated_agents.append(agent_instance)
|
||||
@@ -109,19 +117,23 @@ def crew(func) -> Callable[..., Crew]:
|
||||
self.agents = instantiated_agents
|
||||
self.tasks = instantiated_tasks
|
||||
|
||||
crew = func(self, *args, **kwargs)
|
||||
crew_result = func(self, *args, **kwargs)
|
||||
|
||||
def callback_wrapper(callback, instance):
|
||||
def wrapper(*args, **kwargs):
|
||||
return callback(instance, *args, **kwargs)
|
||||
def callback_wrapper(callback_func: Any, instance: Any) -> Callable[..., Any]:
|
||||
def inner_wrapper(*cb_args: Any, **cb_kwargs: Any) -> Any:
|
||||
return callback_func(instance, *cb_args, **cb_kwargs)
|
||||
|
||||
return wrapper
|
||||
return inner_wrapper
|
||||
|
||||
for _, callback in self._before_kickoff.items():
|
||||
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
|
||||
for _, callback in self._after_kickoff.items():
|
||||
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))
|
||||
for callback_func in self._before_kickoff.values():
|
||||
crew_result.before_kickoff_callbacks.append(
|
||||
callback_wrapper(callback_func, self)
|
||||
)
|
||||
for callback_func in self._after_kickoff.values():
|
||||
crew_result.after_kickoff_callbacks.append(
|
||||
callback_wrapper(callback_func, self)
|
||||
)
|
||||
|
||||
return crew
|
||||
return crew_result
|
||||
|
||||
return memoize(wrapper)
|
||||
|
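For context, the decorators in this file are consumed by a @CrewBase-annotated class. A skeletal, hypothetical example of how the pieces fit together; the roles, descriptions, and config are placeholders, not taken from the repository:

    from crewai import Agent, Crew, Task
    from crewai.project import CrewBase, agent, crew, task


    @CrewBase
    class ResearchCrew:
        """Placeholder crew showing where @agent, @task and @crew attach."""

        @agent
        def researcher(self) -> Agent:
            return Agent(role="Researcher", goal="Find sources", backstory="...", verbose=True)

        @task
        def research_task(self) -> Task:
            return Task(
                description="Collect three sources on topic X",
                expected_output="A short list",
                agent=self.researcher(),
            )

        @crew
        def crew(self) -> Crew:
            # The @crew wrapper collects the instantiated agents/tasks before calling this.
            return Crew(agents=self.agents, tasks=self.tasks)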
||||
@@ -1,14 +1,12 @@
|
||||
import inspect
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, cast
|
||||
from typing import Any, Callable, Dict, TypeVar, cast, List
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
load_dotenv()
|
||||
|
||||
T = TypeVar("T", bound=type)
|
||||
@@ -16,7 +14,7 @@ T = TypeVar("T", bound=type)
|
||||
"""Base decorator for creating crew classes with configuration and function management."""
|
||||
|
||||
|
||||
def CrewBase(cls: T) -> T: # noqa: N802
|
||||
def CrewBase(cls: T) -> T:
|
||||
"""Wraps a class with crew functionality and configuration management."""
|
||||
|
||||
class WrappedClass(cls): # type: ignore
|
||||
@@ -31,7 +29,6 @@ def CrewBase(cls: T) -> T: # noqa: N802
|
||||
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
|
||||
|
||||
mcp_server_params: Any = getattr(cls, "mcp_server_params", None)
|
||||
mcp_connect_timeout: int = getattr(cls, "mcp_connect_timeout", 30)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -89,18 +86,15 @@ def CrewBase(cls: T) -> T: # noqa: N802
|
||||
import types
|
||||
return types.MethodType(_close_mcp_server, self)
|
||||
|
||||
def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]:
|
||||
def get_mcp_tools(self, *tool_names: list[str]) -> List[BaseTool]:
|
||||
if not self.mcp_server_params:
|
||||
return []
|
||||
|
||||
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
|
||||
from crewai_tools import MCPServerAdapter
|
||||
|
||||
adapter = getattr(self, '_mcp_server_adapter', None)
|
||||
if not adapter:
|
||||
self._mcp_server_adapter = MCPServerAdapter(
|
||||
self.mcp_server_params,
|
||||
connect_timeout=self.mcp_connect_timeout
|
||||
)
|
||||
self._mcp_server_adapter = MCPServerAdapter(self.mcp_server_params)
|
||||
|
||||
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
|
||||
|
||||
@@ -160,8 +154,8 @@ def CrewBase(cls: T) -> T: # noqa: N802
|
||||
}
|
||||
|
||||
def _filter_functions(
|
||||
self, functions: dict[str, Callable], attribute: str
|
||||
) -> dict[str, Callable]:
|
||||
self, functions: Dict[str, Callable], attribute: str
|
||||
) -> Dict[str, Callable]:
|
||||
return {
|
||||
name: func
|
||||
for name, func in functions.items()
|
||||
@@ -190,11 +184,11 @@ def CrewBase(cls: T) -> T: # noqa: N802
|
||||
def _map_agent_variables(
|
||||
self,
|
||||
agent_name: str,
|
||||
agent_info: dict[str, Any],
|
||||
llms: dict[str, Callable],
|
||||
tool_functions: dict[str, Callable],
|
||||
cache_handler_functions: dict[str, Callable],
|
||||
callbacks: dict[str, Callable],
|
||||
agent_info: Dict[str, Any],
|
||||
llms: Dict[str, Callable],
|
||||
tool_functions: Dict[str, Callable],
|
||||
cache_handler_functions: Dict[str, Callable],
|
||||
callbacks: Dict[str, Callable],
|
||||
) -> None:
|
||||
if llm := agent_info.get("llm"):
|
||||
try:
|
||||
@@ -251,13 +245,13 @@ def CrewBase(cls: T) -> T: # noqa: N802
|
||||
def _map_task_variables(
|
||||
self,
|
||||
task_name: str,
|
||||
task_info: dict[str, Any],
|
||||
agents: dict[str, Callable],
|
||||
tasks: dict[str, Callable],
|
||||
output_json_functions: dict[str, Callable],
|
||||
tool_functions: dict[str, Callable],
|
||||
callback_functions: dict[str, Callable],
|
||||
output_pydantic_functions: dict[str, Callable],
|
||||
task_info: Dict[str, Any],
|
||||
agents: Dict[str, Callable],
|
||||
tasks: Dict[str, Callable],
|
||||
output_json_functions: Dict[str, Callable],
|
||||
tool_functions: Dict[str, Callable],
|
||||
callback_functions: Dict[str, Callable],
|
||||
output_pydantic_functions: Dict[str, Callable],
|
||||
) -> None:
|
||||
if context_list := task_info.get("context"):
|
||||
self.tasks_config[task_name]["context"] = [
|
||||
|
||||
@@ -1,11 +1,25 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def memoize(func):
cache = {}
def memoize(func: Callable[P, R]) -> Callable[P, R]:
"""Decorator that caches function results based on arguments.

Args:
func: The function to memoize.

Returns:
The memoized function.
"""
cache: dict[tuple, R] = {}

@wraps(func)
def memoized_func(*args, **kwargs):
def memoized_func(*args: P.args, **kwargs: P.kwargs) -> R:
"""Memoized wrapper function."""
key = (args, tuple(kwargs.items()))
if key not in cache:
cache[key] = func(*args, **kwargs)
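A quick usage note on memoize as defined above: results are cached per (args, kwargs-items) tuple, so repeated calls with the same arguments skip the body entirely. A minimal, untyped sketch of the same idea:

    from functools import wraps

    def memoize(func):
        """Same caching idea as above, untyped for brevity."""
        cache = {}

        @wraps(func)
        def memoized_func(*args, **kwargs):
            key = (args, tuple(kwargs.items()))
            if key not in cache:
                cache[key] = func(*args, **kwargs)
            return cache[key]

        return memoized_func

    @memoize
    def slow_square(n):
        print("computing", n)
        return n * n

    slow_square(4)  # prints "computing 4"
    slow_square(4)  # served from the cache, no print

Note that because the key includes keyword arguments in call order, f(a=1, b=2) and f(b=2, a=1) are cached as separate entries, and all arguments must be hashable.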
|
||||
@@ -4,9 +4,8 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from chromadb.api.types import (
|
||||
Embeddable,
|
||||
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.api.types import (
|
||||
QueryResult,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
@@ -24,13 +23,13 @@ from crewai.rag.chromadb.utils import (
|
||||
_process_query_results,
|
||||
_sanitize_collection_name,
|
||||
)
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionAddParams,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
|
||||
|
||||
class ChromaDBClient(BaseClient):
|
||||
@@ -47,7 +46,7 @@ class ChromaDBClient(BaseClient):
|
||||
def __init__(
|
||||
self,
|
||||
client: ChromaDBClientType,
|
||||
embedding_function: ChromaEmbeddingFunction,
|
||||
embedding_function: ChromaEmbeddingFunction[Embeddable],
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
@@ -307,12 +306,10 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
|
||||
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
|
||||
collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=metadatas,
|
||||
metadatas=prepared.metadatas,
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
@@ -350,12 +347,10 @@ class ChromaDBClient(BaseClient):
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
|
||||
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
|
||||
await collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=metadatas,
|
||||
metadatas=prepared.metadatas,
|
||||
)
|
||||
|
||||
def search(
|
||||
|
||||
@@ -3,18 +3,18 @@
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
from chromadb.config import Settings
|
||||
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.chromadb.constants import (
|
||||
DEFAULT_DATABASE,
|
||||
DEFAULT_STORAGE_PATH,
|
||||
DEFAULT_TENANT,
|
||||
)
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.chromadb.constants import (
|
||||
DEFAULT_TENANT,
|
||||
DEFAULT_DATABASE,
|
||||
DEFAULT_STORAGE_PATH,
|
||||
)
|
||||
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
|
||||
@@ -2,12 +2,11 @@
|
||||
|
||||
import os
|
||||
from hashlib import md5
|
||||
|
||||
import portalocker
|
||||
from chromadb import PersistentClient
|
||||
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
|
||||
|
||||
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
@@ -24,7 +23,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
"""
|
||||
|
||||
persist_dir = config.settings.persist_directory
|
||||
os.makedirs(persist_dir, exist_ok=True)
|
||||
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
|
||||
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
|
||||
|
||||
|
||||
@@ -3,28 +3,27 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from chromadb.api import AsyncClientAPI, ClientAPI
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from chromadb.api import ClientAPI, AsyncClientAPI
|
||||
from chromadb.api.configuration import CollectionConfigurationInterface
|
||||
from chromadb.api.types import (
|
||||
CollectionMetadata,
|
||||
DataLoader,
|
||||
Embeddable,
|
||||
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||
Include,
|
||||
Loadable,
|
||||
Where,
|
||||
WhereDocument,
|
||||
)
|
||||
from chromadb.api.types import (
|
||||
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
|
||||
|
||||
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
||||
|
||||
|
||||
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
|
||||
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]):
|
||||
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
|
||||
|
||||
@classmethod
|
||||
@@ -45,7 +44,7 @@ class PreparedDocuments(NamedTuple):
|
||||
Attributes:
|
||||
ids: List of document IDs
|
||||
texts: List of document texts
|
||||
metadatas: List of document metadata mappings (empty dict for no metadata)
|
||||
metadatas: List of document metadata mappings
|
||||
"""
|
||||
|
||||
ids: list[str]
|
||||
@@ -86,7 +85,7 @@ class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
|
||||
|
||||
configuration: CollectionConfigurationInterface
|
||||
metadata: CollectionMetadata
|
||||
embedding_function: ChromaEmbeddingFunction
|
||||
embedding_function: ChromaEmbeddingFunction[Embeddable]
|
||||
data_loader: DataLoader[Loadable]
|
||||
get_or_create: bool
|
||||
|
||||
|
||||
@@ -5,14 +5,13 @@ from collections.abc import Mapping
|
||||
from typing import Literal, TypeGuard, cast
|
||||
|
||||
from chromadb.api import AsyncClientAPI, ClientAPI
|
||||
from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.api.types import (
|
||||
Include,
|
||||
IncludeEnum,
|
||||
QueryResult,
|
||||
)
|
||||
|
||||
from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from crewai.rag.chromadb.constants import (
|
||||
DEFAULT_COLLECTION,
|
||||
INVALID_CHARS_PATTERN,
|
||||
@@ -79,7 +78,7 @@ def _prepare_documents_for_chromadb(
|
||||
metadata = doc.get("metadata")
|
||||
if metadata:
|
||||
if isinstance(metadata, list):
|
||||
metadatas.append(metadata[0] if metadata and metadata[0] else {})
|
||||
metadatas.append(metadata[0] if metadata else {})
|
||||
else:
|
||||
metadatas.append(metadata)
|
||||
else:
|
||||
@@ -155,7 +154,7 @@ def _convert_chromadb_results_to_search_results(
|
||||
"""
|
||||
search_results: list[SearchResult] = []
|
||||
|
||||
include_strings = [item.value for item in include] if include else []
|
||||
include_strings = [item.value for item in include]
|
||||
|
||||
ids = results["ids"][0] if results.get("ids") else []
|
||||
|
||||
@@ -189,9 +188,7 @@ def _convert_chromadb_results_to_search_results(
|
||||
result: SearchResult = {
|
||||
"id": doc_id,
|
||||
"content": documents[i] if documents and i < len(documents) else "",
|
||||
"metadata": dict(metadatas[i])
|
||||
if metadatas and i < len(metadatas) and metadatas[i] is not None
|
||||
else {},
|
||||
"metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {},
|
||||
"score": score,
|
||||
}
|
||||
search_results.append(result)
|
||||
@@ -274,7 +271,7 @@ def _sanitize_collection_name(
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
if len(sanitized) < MIN_COLLECTION_LENGTH:
|
||||
sanitized += "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
|
||||
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
|
||||
if len(sanitized) > max_collection_length:
|
||||
sanitized = sanitized[:max_collection_length]
|
||||
if not sanitized[-1].isalnum():
|
||||
|
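The _sanitize_collection_name hunk above enforces ChromaDB-style naming rules: length bounds and alphanumeric first/last characters. A simplified re-statement of those constraints; the regex and the 3/63 bounds are assumptions for illustration:

    import re

    MIN_COLLECTION_LENGTH = 3
    MAX_COLLECTION_LENGTH = 63
    INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")

    def sanitize_collection_name(name: str, max_collection_length: int = MAX_COLLECTION_LENGTH) -> str:
        """Clamp length and character set so the name is acceptable to the vector store."""
        sanitized = INVALID_CHARS_PATTERN.sub("_", name)
        if not sanitized or not sanitized[0].isalnum():
            sanitized = "a" + sanitized
        if len(sanitized) < MIN_COLLECTION_LENGTH:
            sanitized += "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
        if len(sanitized) > max_collection_length:
            sanitized = sanitized[:max_collection_length]
        if not sanitized[-1].isalnum():
            sanitized = sanitized[:-1] + "z"
        return sanitized

    print(sanitize_collection_name("Research Crew / Analyst #1"))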
||||
@@ -1,15 +1,15 @@
"""Protocol for vector database client implementations."""

from abc import abstractmethod
from typing import Annotated, Any, Protocol, runtime_checkable

from typing import Any, Protocol, runtime_checkable, Annotated
from typing_extensions import Unpack, Required, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from typing_extensions import Required, TypedDict, Unpack


from crewai.rag.types import (
    BaseRecord,
    EmbeddingFunction,
    BaseRecord,
    SearchResult,
)

@@ -57,7 +57,7 @@ class BaseCollectionSearchParams(BaseCollectionParams, total=False):

    query: Required[str]
    limit: int
    metadata_filter: dict[str, Any] | None
    metadata_filter: dict[str, Any]
    score_threshold: float

@@ -10,8 +10,8 @@ from chromadb.utils.embedding_functions.cohere_embedding_function import (
    CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
    GoogleGenerativeAiEmbeddingFunction,
    GooglePalmEmbeddingFunction,
    GoogleGenerativeAiEmbeddingFunction,
    GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (

@@ -60,7 +60,7 @@ def get_embedding_function(
        EmbeddingFunction instance ready for use with ChromaDB

    Supported providers:
        - openai: OpenAI embeddings
        - openai: OpenAI embeddings (default)
        - cohere: Cohere embeddings
        - ollama: Ollama local embeddings
        - huggingface: HuggingFace embeddings

@@ -77,7 +77,7 @@ def get_embedding_function(
        - onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)

    Examples:
        # Use default OpenAI embedding
        # Use default OpenAI with retry logic
        >>> embedder = get_embedding_function()

        # Use Cohere with dict

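The get_embedding_function hunk above is a provider-dispatch factory: a string key in the config selects one of ChromaDB's embedding-function classes. Below is a minimal sketch of that dispatch pattern; the registry, the "fake" provider, and the config shape are assumptions for illustration and not the factory's actual signature or supported options.

from typing import Any, Callable

EmbeddingFn = Callable[[list[str]], list[list[float]]]

# Hypothetical provider registry; the real factory maps names such as
# "openai", "cohere", "ollama", "huggingface" to ChromaDB embedding functions.
_PROVIDERS: dict[str, Callable[..., EmbeddingFn]] = {}


def register_provider(name: str) -> Callable[[Callable[..., EmbeddingFn]], Callable[..., EmbeddingFn]]:
    def decorator(factory: Callable[..., EmbeddingFn]) -> Callable[..., EmbeddingFn]:
        _PROVIDERS[name] = factory
        return factory
    return decorator


@register_provider("fake")
def _fake_provider(dim: int = 3, **_: Any) -> EmbeddingFn:
    # Deterministic dummy embedder so the sketch runs without API keys.
    return lambda texts: [[float(len(t))] * dim for t in texts]


def get_embedding_function(config: dict[str, Any] | None = None) -> EmbeddingFn:
    config = config or {"provider": "fake"}
    provider = config.get("provider", "fake")
    if provider not in _PROVIDERS:
        raise ValueError(f"Unsupported embedding provider: {provider}")
    return _PROVIDERS[provider](**config.get("config", {}))


embedder = get_embedding_function({"provider": "fake", "config": {"dim": 2}})
print(embedder(["hello world"]))  # [[11.0, 11.0]]
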
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Dict, List, Optional


class BaseRAGStorage(ABC):

@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
        self,
        type: str,
        allow_reset: bool = True,
        embedder_config: dict[str, Any] | None = None,
        embedder_config: Optional[Dict[str, Any]] = None,
        crew: Any = None,
    ):
        self.type = type

@@ -32,21 +32,45 @@ class BaseRAGStorage(ABC):
    @abstractmethod
    def _sanitize_role(self, role: str) -> str:
        """Sanitizes agent roles to ensure valid directory names."""
        pass

    @abstractmethod
    def save(self, value: Any, metadata: dict[str, Any]) -> None:
    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        """Save a value with metadata to the storage."""
        pass

    @abstractmethod
    def search(
        self,
        query: str,
        limit: int = 3,
        filter: dict[str, Any] | None = None,
        filter: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> list[Any]:
    ) -> List[Any]:
        """Search for entries in the storage."""
        pass

    @abstractmethod
    def reset(self) -> None:
        """Reset the storage."""
        pass

    @abstractmethod
    def _generate_embedding(
        self, text: str, metadata: Optional[Dict[str, Any]] = None
    ) -> Any:
        """Generate an embedding for the given text and metadata."""
        pass

    @abstractmethod
    def _initialize_app(self):
        """Initialize the vector db."""
        pass

    def setup_config(self, config: Dict[str, Any]):
        """Setup the config of the storage."""
        pass

    def initialize_client(self):
        """Initialize the client of the storage. This should setup the app and the db collection"""
        pass

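Because BaseRAGStorage is abstract, a backend has to implement all five abstract methods before it can be instantiated. The following is a toy in-memory sketch written only against the signatures visible in the hunk above; the import path and the behaviour of the unseen parts of __init__ are assumptions, and the substring "relevance" check is a placeholder, not how the real vector-store-backed storage works.

from typing import Any, Dict, List, Optional

from crewai.memory.storage.base_rag_storage import BaseRAGStorage  # path assumed


class InMemoryRAGStorage(BaseRAGStorage):
    """Toy storage that keeps entries in a list instead of a vector DB."""

    def __init__(self, type: str = "demo", **kwargs: Any) -> None:
        super().__init__(type, **kwargs)
        self._entries: List[Dict[str, Any]] = []

    def _sanitize_role(self, role: str) -> str:
        return role.lower().replace(" ", "_")

    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        self._entries.append({"value": value, "metadata": metadata})

    def search(
        self,
        query: str,
        limit: int = 3,
        filter: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Any]:
        # Placeholder relevance: keep entries whose text contains the query.
        hits = [e for e in self._entries if query.lower() in str(e["value"]).lower()]
        return hits[:limit]

    def reset(self) -> None:
        self._entries.clear()

    def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
        return [float(len(text))]  # stand-in embedding

    def _initialize_app(self):
        pass
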
@@ -1,15 +0,0 @@
"""Security constants for CrewAI.

This module contains security-related constants used throughout the security module.

Notes:
    - TODO: Determine if CREW_AI_NAMESPACE should be made dynamic or configurable
"""

from typing import Annotated
from uuid import UUID

CREW_AI_NAMESPACE: Annotated[
    UUID,
    "Create a deterministic UUID using v5 (SHA-1). Custom namespace for CrewAI to enhance security.",
] = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")

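The namespace constant above exists so that UUIDv5 generation is deterministic: the same namespace plus the same seed always yields the same identifier. A short illustration, reusing the namespace value shown above:

from uuid import UUID, uuid5

NAMESPACE = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")

a = uuid5(NAMESPACE, "agent: researcher")
b = uuid5(NAMESPACE, "agent: researcher")
c = uuid5(NAMESPACE, "agent: writer")

assert a == b  # same namespace + seed -> same UUID
assert a != c  # different seed -> different UUID
print(a)
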
@@ -1,123 +1,130 @@
|
||||
"""Fingerprint Module
|
||||
"""
|
||||
Fingerprint Module
|
||||
|
||||
This module provides functionality for generating and validating unique identifiers
|
||||
for CrewAI agents. These identifiers are used for tracking, auditing, and security.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any
|
||||
from uuid import UUID, uuid4, uuid5
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.security.constants import CREW_AI_NAMESPACE
|
||||
|
||||
|
||||
def _validate_metadata(v: Any) -> dict[str, Any]:
|
||||
"""Validate that metadata is a dictionary with string keys and valid values."""
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError("Metadata must be a dictionary")
|
||||
|
||||
# Validate that all keys are strings
|
||||
for key, value in v.items():
|
||||
if not isinstance(key, str):
|
||||
raise ValueError(f"Metadata keys must be strings, got {type(key)}")
|
||||
|
||||
# Validate nested dictionaries (prevent deeply nested structures)
|
||||
if isinstance(value, dict):
|
||||
# Check for nested dictionaries (limit depth to 1)
|
||||
for nested_key, nested_value in value.items():
|
||||
if not isinstance(nested_key, str):
|
||||
raise ValueError(
|
||||
f"Nested metadata keys must be strings, got {type(nested_key)}"
|
||||
)
|
||||
if isinstance(nested_value, dict):
|
||||
raise ValueError("Metadata can only be nested one level deep")
|
||||
|
||||
# Check for maximum metadata size (prevent DoS)
|
||||
if len(str(v)) > 10_000: # Limit metadata size to 10KB
|
||||
raise ValueError("Metadata size exceeds maximum allowed (10KB)")
|
||||
|
||||
return v
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class Fingerprint(BaseModel):
|
||||
"""A class for generating and managing unique identifiers for agents.
|
||||
"""
|
||||
A class for generating and managing unique identifiers for agents.
|
||||
|
||||
Each agent has dual identifiers:
|
||||
- Human-readable ID: For debugging and reference (derived from role if not specified)
|
||||
- Fingerprint UUID: Unique runtime identifier for tracking and auditing
|
||||
|
||||
Attributes:
|
||||
uuid_str: String representation of the UUID for this fingerprint, auto-generated
|
||||
created_at: When this fingerprint was created, auto-generated
|
||||
metadata: Additional metadata associated with this fingerprint
|
||||
uuid_str (str): String representation of the UUID for this fingerprint, auto-generated
|
||||
created_at (datetime): When this fingerprint was created, auto-generated
|
||||
metadata (Dict[str, Any]): Additional metadata associated with this fingerprint
|
||||
"""
|
||||
|
||||
_uuid_str: str = PrivateAttr(default_factory=lambda: str(uuid4()))
|
||||
_created_at: datetime = PrivateAttr(default_factory=datetime.now)
|
||||
metadata: Annotated[dict[str, Any], BeforeValidator(_validate_metadata)] = Field(
|
||||
default_factory=dict
|
||||
)
|
||||
uuid_str: str = Field(default_factory=lambda: str(uuid.uuid4()), description="String representation of the UUID")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="When this fingerprint was created")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for this fingerprint")
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator('metadata')
|
||||
@classmethod
|
||||
def validate_metadata(cls, v):
|
||||
"""Validate that metadata is a dictionary with string keys and valid values."""
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError("Metadata must be a dictionary")
|
||||
|
||||
# Validate that all keys are strings
|
||||
for key, value in v.items():
|
||||
if not isinstance(key, str):
|
||||
raise ValueError(f"Metadata keys must be strings, got {type(key)}")
|
||||
|
||||
# Validate nested dictionaries (prevent deeply nested structures)
|
||||
if isinstance(value, dict):
|
||||
# Check for nested dictionaries (limit depth to 1)
|
||||
for nested_key, nested_value in value.items():
|
||||
if not isinstance(nested_key, str):
|
||||
raise ValueError(f"Nested metadata keys must be strings, got {type(nested_key)}")
|
||||
if isinstance(nested_value, dict):
|
||||
raise ValueError("Metadata can only be nested one level deep")
|
||||
|
||||
# Check for maximum metadata size (prevent DoS)
|
||||
if len(str(v)) > 10000: # Limit metadata size to 10KB
|
||||
raise ValueError("Metadata size exceeds maximum allowed (10KB)")
|
||||
|
||||
return v
|
||||
|
||||
def __init__(self, **data):
|
||||
"""Initialize a Fingerprint with auto-generated uuid_str and created_at."""
|
||||
# Remove uuid_str and created_at from data to ensure they're auto-generated
|
||||
if 'uuid_str' in data:
|
||||
data.pop('uuid_str')
|
||||
if 'created_at' in data:
|
||||
data.pop('created_at')
|
||||
|
||||
# Call the parent constructor with the modified data
|
||||
super().__init__(**data)
|
||||
|
||||
@property
|
||||
def uuid_str(self) -> str:
|
||||
"""Get the string representation of the UUID for this fingerprint."""
|
||||
return self._uuid_str
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime:
|
||||
"""Get the creation timestamp for this fingerprint."""
|
||||
return self._created_at
|
||||
|
||||
@property
|
||||
def uuid(self) -> UUID:
|
||||
def uuid(self) -> uuid.UUID:
|
||||
"""Get the UUID object for this fingerprint."""
|
||||
return UUID(self.uuid_str)
|
||||
return uuid.UUID(self.uuid_str)
|
||||
|
||||
@classmethod
|
||||
def _generate_uuid(cls, seed: str) -> str:
|
||||
"""Generate a deterministic UUID based on a seed string.
|
||||
"""
|
||||
Generate a deterministic UUID based on a seed string.
|
||||
|
||||
Args:
|
||||
seed: The seed string to use for UUID generation
|
||||
seed (str): The seed string to use for UUID generation
|
||||
|
||||
Returns:
|
||||
A string representation of the UUID consistently generated from the seed
|
||||
str: A string representation of the UUID consistently generated from the seed
|
||||
"""
|
||||
if not isinstance(seed, str):
|
||||
raise ValueError("Seed must be a string")
|
||||
|
||||
if not seed.strip():
|
||||
raise ValueError("Seed cannot be empty or whitespace")
|
||||
|
||||
# Create a deterministic UUID using v5 (SHA-1)
|
||||
# Custom namespace for CrewAI to enhance security
|
||||
|
||||
return str(uuid5(CREW_AI_NAMESPACE, seed))
|
||||
# Using a unique namespace specific to CrewAI to reduce collision risks
|
||||
CREW_AI_NAMESPACE = uuid.UUID('f47ac10b-58cc-4372-a567-0e02b2c3d479')
|
||||
return str(uuid.uuid5(CREW_AI_NAMESPACE, seed))
|
||||
|
||||
@classmethod
|
||||
def generate(
|
||||
cls, seed: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> Self:
|
||||
"""Static factory method to create a new Fingerprint.
|
||||
def generate(cls, seed: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> 'Fingerprint':
|
||||
"""
|
||||
Static factory method to create a new Fingerprint.
|
||||
|
||||
Args:
|
||||
seed: A string to use as seed for the UUID generation.
|
||||
seed (Optional[str]): A string to use as seed for the UUID generation.
|
||||
If None, a random UUID is generated.
|
||||
metadata: Additional metadata to store with the fingerprint.
|
||||
metadata (Optional[Dict[str, Any]]): Additional metadata to store with the fingerprint.
|
||||
|
||||
Returns:
|
||||
A new Fingerprint instance
|
||||
Fingerprint: A new Fingerprint instance
|
||||
"""
|
||||
fingerprint = cls(metadata=metadata or {})
|
||||
if seed:
|
||||
# For seed-based generation, we need to manually set the _uuid_str after creation
|
||||
fingerprint.__dict__["_uuid_str"] = cls._generate_uuid(seed)
|
||||
# For seed-based generation, we need to manually set the uuid_str after creation
|
||||
object.__setattr__(fingerprint, 'uuid_str', cls._generate_uuid(seed))
|
||||
return fingerprint
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the fingerprint (the UUID)."""
|
||||
return self.uuid_str
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other) -> bool:
|
||||
"""Compare fingerprints by their UUID."""
|
||||
if type(other) is Fingerprint:
|
||||
if isinstance(other, Fingerprint):
|
||||
return self.uuid_str == other.uuid_str
|
||||
return False
|
||||
|
||||
@@ -125,27 +132,29 @@ class Fingerprint(BaseModel):
|
||||
"""Hash of the fingerprint (based on UUID)."""
|
||||
return hash(self.uuid_str)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert the fingerprint to a dictionary representation.
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert the fingerprint to a dictionary representation.
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the fingerprint
|
||||
Dict[str, Any]: Dictionary representation of the fingerprint
|
||||
"""
|
||||
return {
|
||||
"uuid_str": self.uuid_str,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"metadata": self.metadata,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Self:
|
||||
"""Create a Fingerprint from a dictionary representation.
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'Fingerprint':
|
||||
"""
|
||||
Create a Fingerprint from a dictionary representation.
|
||||
|
||||
Args:
|
||||
data: Dictionary representation of a fingerprint
|
||||
data (Dict[str, Any]): Dictionary representation of a fingerprint
|
||||
|
||||
Returns:
|
||||
A new Fingerprint instance
|
||||
Fingerprint: A new Fingerprint instance
|
||||
"""
|
||||
if not data:
|
||||
return cls()
|
||||
@@ -154,10 +163,8 @@ class Fingerprint(BaseModel):
|
||||
|
||||
# For consistency with existing stored fingerprints, we need to manually set these
|
||||
if "uuid_str" in data:
|
||||
fingerprint.__dict__["_uuid_str"] = data["uuid_str"]
|
||||
object.__setattr__(fingerprint, 'uuid_str', data["uuid_str"])
|
||||
if "created_at" in data and isinstance(data["created_at"], str):
|
||||
fingerprint.__dict__["_created_at"] = datetime.fromisoformat(
|
||||
data["created_at"]
|
||||
)
|
||||
object.__setattr__(fingerprint, 'created_at', datetime.fromisoformat(data["created_at"]))
|
||||
|
||||
return fingerprint
|
||||
|
||||
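Putting the Fingerprint hunks above together: fingerprints are auto-generated by default, can be derived deterministically from a seed, and round-trip through to_dict/from_dict. A small usage sketch against the API shown in the diff (the import path matches the imports visible elsewhere in the diff; metadata values are placeholders):

from crewai.security.fingerprint import Fingerprint

# Auto-generated fingerprint with validated metadata
fp = Fingerprint(metadata={"team": "research"})
print(fp.uuid_str, fp.created_at)

# Deterministic fingerprint from a seed: same seed -> same UUID
fp1 = Fingerprint.generate(seed="agent: researcher")
fp2 = Fingerprint.generate(seed="agent: researcher")
assert fp1 == fp2

# Round-trip through the dictionary representation
restored = Fingerprint.from_dict(fp.to_dict())
assert restored.uuid_str == fp.uuid_str
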
@@ -1,4 +1,5 @@
|
||||
"""Security Configuration Module
|
||||
"""
|
||||
Security Configuration Module
|
||||
|
||||
This module provides configuration for CrewAI security features, including:
|
||||
- Authentication settings
|
||||
@@ -9,10 +10,9 @@ The SecurityConfig class is the primary interface for managing security settings
|
||||
in CrewAI applications.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from typing_extensions import Self
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
|
||||
@@ -28,6 +28,7 @@ class SecurityConfig(BaseModel):
|
||||
- Impersonation/delegation tokens *TODO*
|
||||
|
||||
Attributes:
|
||||
version (str): Version of the security configuration
|
||||
fingerprint (Fingerprint): The unique fingerprint automatically generated for the component
|
||||
"""
|
||||
|
||||
@@ -36,52 +37,80 @@ class SecurityConfig(BaseModel):
|
||||
# Note: Cannot use frozen=True as existing tests modify the fingerprint property
|
||||
)
|
||||
|
||||
fingerprint: Fingerprint = Field(
|
||||
default_factory=Fingerprint, description="Unique identifier for the component"
|
||||
version: str = Field(
|
||||
default="1.0.0",
|
||||
description="Version of the security configuration"
|
||||
)
|
||||
|
||||
@field_validator("fingerprint", mode="before")
|
||||
fingerprint: Fingerprint = Field(
|
||||
default_factory=Fingerprint,
|
||||
description="Unique identifier for the component"
|
||||
)
|
||||
|
||||
def is_compatible(self, min_version: str) -> bool:
|
||||
"""
|
||||
Check if this security configuration is compatible with the minimum required version.
|
||||
|
||||
Args:
|
||||
min_version (str): Minimum required version in semver format (e.g., "1.0.0")
|
||||
|
||||
Returns:
|
||||
bool: True if this configuration is compatible, False otherwise
|
||||
"""
|
||||
# Simple version comparison (can be enhanced with packaging.version if needed)
|
||||
current = [int(x) for x in self.version.split(".")]
|
||||
minimum = [int(x) for x in min_version.split(".")]
|
||||
|
||||
# Compare major, minor, patch versions
|
||||
for c, m in zip(current, minimum):
|
||||
if c > m:
|
||||
return True
|
||||
if c < m:
|
||||
return False
|
||||
return True
|
||||
|
||||
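The is_compatible helper above does a plain numeric comparison of dotted version strings, component by component. A standalone sketch of that comparison, including the tie case where all compared components are equal, might look like the following (a real implementation would likely lean on packaging.version, as the comment in the hunk itself notes):

def is_compatible(version: str, min_version: str) -> bool:
    """Return True if `version` is >= `min_version`, comparing major.minor.patch."""
    current = [int(x) for x in version.split(".")]
    minimum = [int(x) for x in min_version.split(".")]
    for c, m in zip(current, minimum):
        if c > m:
            return True
        if c < m:
            return False
    # All compared components equal (e.g. "1.0.0" vs "1.0.0"): treat as compatible.
    return True


assert is_compatible("1.2.0", "1.0.0") is True
assert is_compatible("0.9.9", "1.0.0") is False
assert is_compatible("1.0.0", "1.0.0") is True
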
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def validate_fingerprint(cls, v: Any) -> Fingerprint:
|
||||
def validate_fingerprint(cls, values):
|
||||
"""Ensure fingerprint is properly initialized."""
|
||||
if v is None:
|
||||
return Fingerprint()
|
||||
if isinstance(v, str):
|
||||
if not v.strip():
|
||||
raise ValueError("Fingerprint seed cannot be empty")
|
||||
return Fingerprint.generate(seed=v)
|
||||
if isinstance(v, dict):
|
||||
return Fingerprint.from_dict(v)
|
||||
if isinstance(v, Fingerprint):
|
||||
return v
|
||||
if isinstance(values, dict):
|
||||
# Handle case where fingerprint is not provided or is None
|
||||
if 'fingerprint' not in values or values['fingerprint'] is None:
|
||||
values['fingerprint'] = Fingerprint()
|
||||
# Handle case where fingerprint is a string (seed)
|
||||
elif isinstance(values['fingerprint'], str):
|
||||
if not values['fingerprint'].strip():
|
||||
raise ValueError("Fingerprint seed cannot be empty")
|
||||
values['fingerprint'] = Fingerprint.generate(seed=values['fingerprint'])
|
||||
return values
|
||||
|
||||
raise ValueError(f"Invalid fingerprint type: {type(v)}")
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert the security config to a dictionary.
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the security config
|
||||
Dict[str, Any]: Dictionary representation of the security config
|
||||
"""
|
||||
return {"fingerprint": self.fingerprint.to_dict()}
|
||||
result = {
|
||||
"fingerprint": self.fingerprint.to_dict()
|
||||
}
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Self:
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'SecurityConfig':
|
||||
"""
|
||||
Create a SecurityConfig from a dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary representation of a security config
|
||||
data (Dict[str, Any]): Dictionary representation of a security config
|
||||
|
||||
Returns:
|
||||
A new SecurityConfig instance
|
||||
SecurityConfig: A new SecurityConfig instance
|
||||
"""
|
||||
fingerprint_data = data.get("fingerprint")
|
||||
fingerprint = (
|
||||
Fingerprint.from_dict(fingerprint_data)
|
||||
if fingerprint_data
|
||||
else Fingerprint()
|
||||
)
|
||||
# Make a copy to avoid modifying the original
|
||||
data_copy = data.copy()
|
||||
|
||||
fingerprint_data = data_copy.pop("fingerprint", None)
|
||||
fingerprint = Fingerprint.from_dict(fingerprint_data) if fingerprint_data else Fingerprint()
|
||||
|
||||
return cls(fingerprint=fingerprint)
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
"""Conditional task execution based on previous task output."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -11,54 +8,37 @@ from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class ConditionalTask(Task):
|
||||
"""A task that can be conditionally executed based on the output of another task.
|
||||
|
||||
This task type allows for dynamic workflow execution based on the results of
|
||||
previous tasks in the crew execution chain.
|
||||
|
||||
Attributes:
|
||||
condition: Function that evaluates previous task output to determine execution.
|
||||
|
||||
Notes:
|
||||
- Cannot be the only task in your crew
|
||||
- Cannot be the first task since it needs context from the previous task
|
||||
"""
|
||||
A task that can be conditionally executed based on the output of another task.
|
||||
Note: This cannot be the only task you have in your crew and cannot be the first since its needs context from the previous task.
|
||||
"""
|
||||
|
||||
condition: Callable[[TaskOutput], bool] | None = Field(
|
||||
condition: Callable[[TaskOutput], bool] = Field(
|
||||
default=None,
|
||||
description="Function that determines whether the task should be executed based on previous task output.",
|
||||
description="Maximum number of retries for an agent to execute a task when an error occurs.",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
condition: Callable[[Any], bool] | None = None,
|
||||
condition: Callable[[Any], bool],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.condition = condition
|
||||
|
||||
def should_execute(self, context: TaskOutput) -> bool:
|
||||
"""Determines whether the conditional task should be executed based on the provided context.
|
||||
"""
|
||||
Determines whether the conditional task should be executed based on the provided context.
|
||||
|
||||
Args:
|
||||
context: The output from the previous task that will be evaluated by the condition.
|
||||
context (Any): The context or output from the previous task that will be evaluated by the condition.
|
||||
|
||||
Returns:
|
||||
True if the task should be executed, False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If no condition function is set.
|
||||
bool: True if the task should be executed, False otherwise.
|
||||
"""
|
||||
if self.condition is None:
|
||||
raise ValueError("No condition function set for conditional task")
|
||||
return self.condition(context)
|
||||
|
||||
def get_skipped_task_output(self) -> TaskOutput:
|
||||
"""Generate a TaskOutput for when the conditional task is skipped.
|
||||
|
||||
Returns:
|
||||
Empty TaskOutput with RAW format indicating the task was skipped.
|
||||
"""
|
||||
def get_skipped_task_output(self):
|
||||
return TaskOutput(
|
||||
description=self.description,
|
||||
raw="",
|
||||
|
||||
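The ConditionalTask hunks above boil down to: store a condition callable, evaluate it against the previous task's TaskOutput in should_execute, and emit an empty TaskOutput when the task is skipped. A small usage sketch against the constructor shown in the diff; the import paths follow the module names used in the diff, and the agent/task fields are placeholders:

from crewai import Agent
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput


def needs_more_research(output: TaskOutput) -> bool:
    # Only run the follow-up task if the previous output looks too short.
    return len(output.raw) < 200


followup = ConditionalTask(
    condition=needs_more_research,
    description="Dig deeper into the topic and expand the findings.",
    expected_output="An expanded research summary.",
    agent=Agent(role="Researcher", goal="Research thoroughly", backstory="..."),
)
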
@@ -1,16 +1,8 @@
"""Task output format definitions for CrewAI."""

from enum import Enum


class OutputFormat(str, Enum):
    """Enum that represents the output format of a task.

    Attributes:
        JSON: Output as JSON dictionary format
        PYDANTIC: Output as Pydantic model instance
        RAW: Output as raw unprocessed string
    """
    """Enum that represents the output format of a task."""

    JSON = "json"
    PYDANTIC = "pydantic"

@@ -1,7 +1,5 @@
|
||||
"""Task output representation and formatting."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@@ -9,31 +7,19 @@ from crewai.tasks.output_format import OutputFormat
|
||||
|
||||
|
||||
class TaskOutput(BaseModel):
|
||||
"""Class that represents the result of a task.
|
||||
|
||||
Attributes:
|
||||
description: Description of the task
|
||||
name: Optional name of the task
|
||||
expected_output: Expected output of the task
|
||||
summary: Summary of the task (auto-generated from description)
|
||||
raw: Raw output of the task
|
||||
pydantic: Pydantic model output of the task
|
||||
json_dict: JSON dictionary output of the task
|
||||
agent: Agent that executed the task
|
||||
output_format: Output format of the task (JSON, PYDANTIC, or RAW)
|
||||
"""
|
||||
"""Class that represents the result of a task."""
|
||||
|
||||
description: str = Field(description="Description of the task")
|
||||
name: str | None = Field(description="Name of the task", default=None)
|
||||
expected_output: str | None = Field(
|
||||
name: Optional[str] = Field(description="Name of the task", default=None)
|
||||
expected_output: Optional[str] = Field(
|
||||
description="Expected output of the task", default=None
|
||||
)
|
||||
summary: str | None = Field(description="Summary of the task", default=None)
|
||||
summary: Optional[str] = Field(description="Summary of the task", default=None)
|
||||
raw: str = Field(description="Raw output of the task", default="")
|
||||
pydantic: BaseModel | None = Field(
|
||||
pydantic: Optional[BaseModel] = Field(
|
||||
description="Pydantic output of task", default=None
|
||||
)
|
||||
json_dict: dict[str, Any] | None = Field(
|
||||
json_dict: Optional[Dict[str, Any]] = Field(
|
||||
description="JSON dictionary of task", default=None
|
||||
)
|
||||
agent: str = Field(description="Agent that executed the task")
|
||||
@@ -43,28 +29,13 @@ class TaskOutput(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_summary(self):
|
||||
"""Set the summary field based on the description.
|
||||
|
||||
Returns:
|
||||
Self with updated summary field.
|
||||
"""
|
||||
"""Set the summary field based on the description."""
|
||||
excerpt = " ".join(self.description.split(" ")[:10])
|
||||
self.summary = f"{excerpt}..."
|
||||
return self
|
||||
|
||||
@property
|
||||
def json(self) -> str | None: # type: ignore[override]
|
||||
"""Get the JSON string representation of the task output.
|
||||
|
||||
Returns:
|
||||
JSON string representation of the task output.
|
||||
|
||||
Raises:
|
||||
ValueError: If output format is not JSON.
|
||||
|
||||
Notes:
|
||||
TODO: Refactor to use model_dump_json() to avoid BaseModel method conflict
|
||||
"""
|
||||
def json(self) -> Optional[str]:
|
||||
if self.output_format != OutputFormat.JSON:
|
||||
raise ValueError(
|
||||
"""
|
||||
@@ -76,13 +47,8 @@ class TaskOutput(BaseModel):
|
||||
|
||||
return json.dumps(self.json_dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert json_output and pydantic_output to a dictionary.
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the task output. Prioritizes json_dict
|
||||
over pydantic model dump if both are available.
|
||||
"""
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert json_output and pydantic_output to a dictionary."""
|
||||
output_dict = {}
|
||||
if self.json_dict:
|
||||
output_dict.update(self.json_dict)
|
||||
|
||||
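The TaskOutput hunks above show two behaviours worth calling out: to_dict merges json_dict first (taking priority over the pydantic dump), and the json property only works when the output format is JSON. A brief sketch of that behaviour, constructing a TaskOutput directly with placeholder field values:

from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput

out = TaskOutput(
    description="Summarize the quarterly report for the leadership team",
    raw='{"summary": "Revenue grew 12%"}',
    json_dict={"summary": "Revenue grew 12%"},
    agent="Analyst",
    output_format=OutputFormat.JSON,
)

print(out.summary)    # first ~10 words of the description, ending in "..."
print(out.to_dict())  # {'summary': 'Revenue grew 12%'} (json_dict takes priority)
print(out.json)       # serialized JSON string, valid only because output_format is JSON
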
@@ -1,9 +1,2 @@
"""Telemetry configuration constants.

This module defines constants used for CrewAI telemetry configuration.
"""

from typing import Final

CREWAI_TELEMETRY_BASE_URL: Final[str] = "https://telemetry.crewai.com:4319"
CREWAI_TELEMETRY_SERVICE_NAME: Final[str] = "crewAI-telemetry"
CREWAI_TELEMETRY_BASE_URL: str = "https://telemetry.crewai.com:4319"
CREWAI_TELEMETRY_SERVICE_NAME: str = "crewAI-telemetry"

@@ -1,11 +1,3 @@
|
||||
"""Telemetry module for CrewAI.
|
||||
|
||||
This module provides anonymous telemetry collection for development purposes.
|
||||
No prompts, task descriptions, agent backstories/goals, responses, or sensitive
|
||||
data is collected. Users can opt-in to share more complete data using the
|
||||
`share_crew` attribute.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@@ -13,10 +5,11 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from importlib.metadata import version
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
import threading
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
||||
@@ -28,43 +21,30 @@ from opentelemetry.sdk.trace.export import (
|
||||
BatchSpanProcessor,
|
||||
SpanExportResult,
|
||||
)
|
||||
from opentelemetry.trace import Span
|
||||
from opentelemetry.trace import Span, Status, StatusCode
|
||||
|
||||
from crewai.telemetry.constants import (
|
||||
CREWAI_TELEMETRY_BASE_URL,
|
||||
CREWAI_TELEMETRY_SERVICE_NAME,
|
||||
)
|
||||
from crewai.telemetry.utils import (
|
||||
add_agent_fingerprint_to_span,
|
||||
add_crew_and_task_attributes,
|
||||
add_crew_attributes,
|
||||
close_span,
|
||||
)
|
||||
from crewai.utilities.logger_utils import suppress_warnings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_warnings():
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
yield
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class SafeOTLPSpanExporter(OTLPSpanExporter):
|
||||
"""Safe wrapper for OTLP span exporter that handles exceptions gracefully.
|
||||
|
||||
This exporter prevents telemetry failures from breaking the application
|
||||
by catching and logging exceptions during span export.
|
||||
"""
|
||||
|
||||
def export(self, spans: Any) -> SpanExportResult:
|
||||
"""Export spans to the telemetry backend safely.
|
||||
|
||||
Args:
|
||||
spans: Collection of spans to export.
|
||||
|
||||
Returns:
|
||||
Export result status, FAILURE if an exception occurs.
|
||||
"""
|
||||
def export(self, spans) -> SpanExportResult:
|
||||
try:
|
||||
return super().export(spans)
|
||||
except Exception as e:
|
||||
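SafeOTLPSpanExporter above wraps export so that a telemetry outage can never crash the host application. The except branch is truncated in the diff, so the sketch below fills it in with a plausible fallback (log at debug level and report failure); this is an assumption about the handling, not CrewAI's exact code.

import logging
from typing import Any

from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import SpanExportResult

logger = logging.getLogger(__name__)


class SafeSpanExporter(OTLPSpanExporter):
    """Exporter that degrades gracefully instead of raising."""

    def export(self, spans: Any) -> SpanExportResult:
        try:
            return super().export(spans)
        except Exception as e:
            # Assumed handling: log quietly and report FAILURE so the batch
            # processor drops the spans without disturbing the application.
            logger.debug(f"Telemetry export failed: {e}")
            return SpanExportResult.FAILURE
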
@@ -73,13 +53,16 @@ class SafeOTLPSpanExporter(OTLPSpanExporter):
|
||||
|
||||
|
||||
class Telemetry:
|
||||
"""Handle anonymous telemetry for the CrewAI package.
|
||||
"""A class to handle anonymous telemetry for the crewai package.
|
||||
|
||||
Attributes:
|
||||
ready: Whether telemetry is initialized and ready.
|
||||
trace_set: Whether the tracer provider has been set.
|
||||
resource: OpenTelemetry resource for the telemetry service.
|
||||
provider: OpenTelemetry tracer provider.
|
||||
The data being collected is for development purpose, all data is anonymous.
|
||||
|
||||
There is NO data being collected on the prompts, tasks descriptions
|
||||
agents backstories or goals nor responses or any data that is being
|
||||
processed by the agents, nor any secrets and env vars.
|
||||
|
||||
Users can opt-in to sharing more complete data using the `share_crew`
|
||||
attribute in the Crew class.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
@@ -89,14 +72,14 @@ class Telemetry:
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance = super(Telemetry, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(self, "_initialized") and self._initialized:
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
|
||||
|
||||
self.ready: bool = False
|
||||
self.trace_set: bool = False
|
||||
self._initialized: bool = True
|
||||
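The __new__/__init__ hunk above implements a double-checked-locking singleton so that Telemetry is constructed and initialized exactly once, even under concurrent instantiation. A generic sketch of that pattern (not CrewAI code) is shown below:

import threading


class Singleton:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:          # fast path, no lock
            with cls._lock:
                if cls._instance is None:  # re-check while holding the lock
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        if self._initialized:              # __init__ runs on every call; guard it
            return
        self._initialized = True
        self.ready = False


assert Singleton() is Singleton()
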
@@ -141,41 +124,29 @@ class Telemetry:
|
||||
"""Check if telemetry operations should be executed."""
|
||||
return self.ready and not self._is_telemetry_disabled()
|
||||
|
||||
def set_tracer(self) -> None:
|
||||
"""Set the tracer provider if ready and not already set."""
|
||||
def set_tracer(self):
|
||||
if self.ready and not self.trace_set:
|
||||
try:
|
||||
with suppress_warnings():
|
||||
trace.set_tracer_provider(self.provider)
|
||||
self.trace_set = True
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to set tracer provider: {e}")
|
||||
except Exception:
|
||||
self.ready = False
|
||||
self.trace_set = False
|
||||
|
||||
def _safe_telemetry_operation(self, operation: Callable[[], Any]) -> None:
|
||||
"""Execute telemetry operation safely, checking both readiness and environment variables.
|
||||
|
||||
Args:
|
||||
operation: A callable that performs telemetry operations. May return any value,
|
||||
but the return value is not used by this method.
|
||||
"""
|
||||
def _safe_telemetry_operation(self, operation: Callable[[], None]) -> None:
|
||||
"""Execute telemetry operation safely, checking both readiness and environment variables."""
|
||||
if not self._should_execute_telemetry():
|
||||
return
|
||||
try:
|
||||
operation()
|
||||
except Exception as e:
|
||||
logger.debug(f"Telemetry operation failed: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
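_safe_telemetry_operation above is the single choke point for all telemetry calls: it checks whether telemetry should run at all and swallows any exception raised by the operation. A standalone sketch of that guard, with the opt-out check reduced to a single illustrative environment variable (CrewAI consults its own settings):

import logging
import os
from collections.abc import Callable
from typing import Any

logger = logging.getLogger(__name__)


def telemetry_disabled() -> bool:
    # Illustrative opt-out check only; the env var name is an assumption.
    return os.getenv("TELEMETRY_DISABLED", "false").lower() in ("1", "true")


def safe_telemetry_operation(operation: Callable[[], Any]) -> None:
    """Run a telemetry callable, never letting it raise into application code."""
    if telemetry_disabled():
        return
    try:
        operation()
    except Exception as e:
        logger.debug(f"Telemetry operation failed: {e}")


safe_telemetry_operation(lambda: print("telemetry event recorded"))
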
|
||||
def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None) -> None:
|
||||
"""Records the creation of a crew.
|
||||
def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None):
|
||||
"""Records the creation of a crew."""
|
||||
|
||||
Args:
|
||||
crew: The crew being created.
|
||||
inputs: Optional input parameters for the crew.
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Crew Created")
|
||||
self._add_attribute(
|
||||
@@ -184,14 +155,16 @@ class Telemetry:
|
||||
version("crewai"),
|
||||
)
|
||||
self._add_attribute(span, "python_version", platform.python_version())
|
||||
add_crew_attributes(span, crew, self._add_attribute)
|
||||
self._add_attribute(span, "crew_key", crew.key)
|
||||
self._add_attribute(span, "crew_id", str(crew.id))
|
||||
self._add_attribute(span, "crew_process", crew.process)
|
||||
self._add_attribute(span, "crew_memory", crew.memory)
|
||||
self._add_attribute(span, "crew_number_of_tasks", len(crew.tasks))
|
||||
self._add_attribute(span, "crew_number_of_agents", len(crew.agents))
|
||||
|
||||
# Add additional fingerprint metadata if available
|
||||
# Add fingerprint data
|
||||
if hasattr(crew, "fingerprint") and crew.fingerprint:
|
||||
self._add_attribute(span, "crew_fingerprint", crew.fingerprint.uuid_str)
|
||||
self._add_attribute(
|
||||
span,
|
||||
"crew_fingerprint_created_at",
|
||||
@@ -370,27 +343,29 @@ class Telemetry:
|
||||
]
|
||||
),
|
||||
)
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def task_started(self, crew: Crew, task: Task) -> Span | None:
|
||||
"""Records task started in a crew.
|
||||
"""Records task started in a crew."""
|
||||
|
||||
Args:
|
||||
crew: The crew executing the task.
|
||||
task: The task being started.
|
||||
|
||||
Returns:
|
||||
The span tracking the task execution, or None if telemetry is disabled.
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
|
||||
created_span = tracer.start_span("Task Created")
|
||||
|
||||
add_crew_and_task_attributes(created_span, crew, task, self._add_attribute)
|
||||
self._add_attribute(created_span, "crew_key", crew.key)
|
||||
self._add_attribute(created_span, "crew_id", str(crew.id))
|
||||
self._add_attribute(created_span, "task_key", task.key)
|
||||
self._add_attribute(created_span, "task_id", str(task.id))
|
||||
|
||||
# Add fingerprint data
|
||||
if hasattr(crew, "fingerprint") and crew.fingerprint:
|
||||
self._add_attribute(
|
||||
created_span, "crew_fingerprint", crew.fingerprint.uuid_str
|
||||
)
|
||||
|
||||
if hasattr(task, "fingerprint") and task.fingerprint:
|
||||
self._add_attribute(
|
||||
@@ -411,9 +386,13 @@ class Telemetry:
|
||||
|
||||
# Add agent fingerprint if task has an assigned agent
|
||||
if hasattr(task, "agent") and task.agent:
|
||||
add_agent_fingerprint_to_span(
|
||||
created_span, task.agent, self._add_attribute
|
||||
agent_fingerprint = getattr(
|
||||
getattr(task.agent, "fingerprint", None), "uuid_str", None
|
||||
)
|
||||
if agent_fingerprint:
|
||||
self._add_attribute(
|
||||
created_span, "agent_fingerprint", agent_fingerprint
|
||||
)
|
||||
|
||||
if crew.share_crew:
|
||||
self._add_attribute(
|
||||
@@ -423,18 +402,30 @@ class Telemetry:
|
||||
created_span, "formatted_expected_output", task.expected_output
|
||||
)
|
||||
|
||||
close_span(created_span)
|
||||
created_span.set_status(Status(StatusCode.OK))
|
||||
created_span.end()
|
||||
|
||||
span = tracer.start_span("Task Execution")
|
||||
|
||||
add_crew_and_task_attributes(span, crew, task, self._add_attribute)
|
||||
self._add_attribute(span, "crew_key", crew.key)
|
||||
self._add_attribute(span, "crew_id", str(crew.id))
|
||||
self._add_attribute(span, "task_key", task.key)
|
||||
self._add_attribute(span, "task_id", str(task.id))
|
||||
|
||||
# Add fingerprint data to execution span
|
||||
if hasattr(crew, "fingerprint") and crew.fingerprint:
|
||||
self._add_attribute(span, "crew_fingerprint", crew.fingerprint.uuid_str)
|
||||
|
||||
if hasattr(task, "fingerprint") and task.fingerprint:
|
||||
self._add_attribute(span, "task_fingerprint", task.fingerprint.uuid_str)
|
||||
|
||||
# Add agent fingerprint if task has an assigned agent
|
||||
if hasattr(task, "agent") and task.agent:
|
||||
add_agent_fingerprint_to_span(span, task.agent, self._add_attribute)
|
||||
agent_fingerprint = getattr(
|
||||
getattr(task.agent, "fingerprint", None), "uuid_str", None
|
||||
)
|
||||
if agent_fingerprint:
|
||||
self._add_attribute(span, "agent_fingerprint", agent_fingerprint)
|
||||
|
||||
if crew.share_crew:
|
||||
self._add_attribute(span, "formatted_description", task.description)
|
||||
@@ -444,25 +435,22 @@ class Telemetry:
|
||||
|
||||
return span
|
||||
|
||||
if not self._should_execute_telemetry():
|
||||
return None
|
||||
self._safe_telemetry_operation(operation)
|
||||
return None
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
return _operation()
|
||||
|
||||
def task_ended(self, span: Span, task: Task, crew: Crew) -> None:
|
||||
def task_ended(self, span: Span, task: Task, crew: Crew):
|
||||
"""Records the completion of a task execution in a crew.
|
||||
|
||||
Args:
|
||||
span: The OpenTelemetry span tracking the task execution.
|
||||
task: The task that was completed.
|
||||
crew: The crew context in which the task was executed.
|
||||
span (Span): The OpenTelemetry span tracking the task execution
|
||||
task (Task): The task that was completed
|
||||
crew (Crew): The crew context in which the task was executed
|
||||
|
||||
Note:
|
||||
If share_crew is enabled, this will also record the task output.
|
||||
If share_crew is enabled, this will also record the task output
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
# Ensure fingerprint data is present on completion span
|
||||
if hasattr(task, "fingerprint") and task.fingerprint:
|
||||
self._add_attribute(span, "task_fingerprint", task.fingerprint.uuid_str)
|
||||
@@ -474,20 +462,21 @@ class Telemetry:
|
||||
task.output.raw if task.output else "",
|
||||
)
|
||||
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int) -> None:
|
||||
def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int):
|
||||
"""Records when a tool is used repeatedly, which might indicate an issue.
|
||||
|
||||
Args:
|
||||
llm: The language model being used.
|
||||
tool_name: Name of the tool being repeatedly used.
|
||||
attempts: Number of attempts made with this tool.
|
||||
llm (Any): The language model being used
|
||||
tool_name (str): Name of the tool being repeatedly used
|
||||
attempts (int): Number of attempts made with this tool
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Tool Repeated Usage")
|
||||
self._add_attribute(
|
||||
@@ -499,23 +488,22 @@ class Telemetry:
|
||||
self._add_attribute(span, "attempts", attempts)
|
||||
if llm:
|
||||
self._add_attribute(span, "llm", llm.model)
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def tool_usage(
|
||||
self, llm: Any, tool_name: str, attempts: int, agent: Any = None
|
||||
) -> None:
|
||||
def tool_usage(self, llm: Any, tool_name: str, attempts: int, agent: Any = None):
|
||||
"""Records the usage of a tool by an agent.
|
||||
|
||||
Args:
|
||||
llm: The language model being used.
|
||||
tool_name: Name of the tool being used.
|
||||
attempts: Number of attempts made with this tool.
|
||||
agent: The agent using the tool.
|
||||
llm (Any): The language model being used
|
||||
tool_name (str): Name of the tool being used
|
||||
attempts (int): Number of attempts made with this tool
|
||||
agent (Any, optional): The agent using the tool
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Tool Usage")
|
||||
self._add_attribute(
|
||||
@@ -529,23 +517,30 @@ class Telemetry:
|
||||
self._add_attribute(span, "llm", llm.model)
|
||||
|
||||
# Add agent fingerprint data if available
|
||||
add_agent_fingerprint_to_span(span, agent, self._add_attribute)
|
||||
close_span(span)
|
||||
if agent and hasattr(agent, "fingerprint") and agent.fingerprint:
|
||||
self._add_attribute(
|
||||
span, "agent_fingerprint", agent.fingerprint.uuid_str
|
||||
)
|
||||
if hasattr(agent, "role"):
|
||||
self._add_attribute(span, "agent_role", agent.role)
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def tool_usage_error(
|
||||
self, llm: Any, agent: Any = None, tool_name: str | None = None
|
||||
) -> None:
|
||||
self, llm: Any, agent: Any = None, tool_name: Optional[str] = None
|
||||
):
|
||||
"""Records when a tool usage results in an error.
|
||||
|
||||
Args:
|
||||
llm: The language model being used when the error occurred.
|
||||
agent: The agent using the tool.
|
||||
tool_name: Name of the tool that caused the error.
|
||||
llm (Any): The language model being used when the error occurred
|
||||
agent (Any, optional): The agent using the tool
|
||||
tool_name (str, optional): Name of the tool that caused the error
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Tool Usage Error")
|
||||
self._add_attribute(
|
||||
@@ -560,24 +555,31 @@ class Telemetry:
|
||||
self._add_attribute(span, "tool_name", tool_name)
|
||||
|
||||
# Add agent fingerprint data if available
|
||||
add_agent_fingerprint_to_span(span, agent, self._add_attribute)
|
||||
close_span(span)
|
||||
if agent and hasattr(agent, "fingerprint") and agent.fingerprint:
|
||||
self._add_attribute(
|
||||
span, "agent_fingerprint", agent.fingerprint.uuid_str
|
||||
)
|
||||
if hasattr(agent, "role"):
|
||||
self._add_attribute(span, "agent_role", agent.role)
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def individual_test_result_span(
|
||||
self, crew: Crew, quality: float, exec_time: int, model_name: str
|
||||
) -> None:
|
||||
):
|
||||
"""Records individual test results for a crew execution.
|
||||
|
||||
Args:
|
||||
crew: The crew being tested.
|
||||
quality: Quality score of the execution.
|
||||
exec_time: Execution time in seconds.
|
||||
model_name: Name of the model used.
|
||||
crew (Crew): The crew being tested
|
||||
quality (float): Quality score of the execution
|
||||
exec_time (int): Execution time in seconds
|
||||
model_name (str): Name of the model used
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Crew Individual Test Result")
|
||||
|
||||
@@ -586,15 +588,15 @@ class Telemetry:
|
||||
"crewai_version",
|
||||
version("crewai"),
|
||||
)
|
||||
add_crew_attributes(
|
||||
span, crew, self._add_attribute, include_fingerprint=False
|
||||
)
|
||||
self._add_attribute(span, "crew_key", crew.key)
|
||||
self._add_attribute(span, "crew_id", str(crew.id))
|
||||
self._add_attribute(span, "quality", str(quality))
|
||||
self._add_attribute(span, "exec_time", str(exec_time))
|
||||
self._add_attribute(span, "model_name", model_name)
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def test_execution_span(
|
||||
self,
|
||||
@@ -602,17 +604,17 @@ class Telemetry:
|
||||
iterations: int,
|
||||
inputs: dict[str, Any] | None,
|
||||
model_name: str,
|
||||
) -> None:
|
||||
):
|
||||
"""Records the execution of a test suite for a crew.
|
||||
|
||||
Args:
|
||||
crew: The crew being tested.
|
||||
iterations: Number of test iterations.
|
||||
inputs: Input parameters for the test.
|
||||
model_name: Name of the model used in testing.
|
||||
crew (Crew): The crew being tested
|
||||
iterations (int): Number of test iterations
|
||||
inputs (dict[str, Any] | None): Input parameters for the test
|
||||
model_name (str): Name of the model used in testing
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Crew Test Execution")
|
||||
|
||||
@@ -621,9 +623,8 @@ class Telemetry:
|
||||
"crewai_version",
|
||||
version("crewai"),
|
||||
)
|
||||
add_crew_attributes(
|
||||
span, crew, self._add_attribute, include_fingerprint=False
|
||||
)
|
||||
self._add_attribute(span, "crew_key", crew.key)
|
||||
self._add_attribute(span, "crew_id", str(crew.id))
|
||||
self._add_attribute(span, "iterations", str(iterations))
|
||||
self._add_attribute(span, "model_name", model_name)
|
||||
|
||||
@@ -632,99 +633,93 @@ class Telemetry:
|
||||
span, "inputs", json.dumps(inputs) if inputs else None
|
||||
)
|
||||
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def deploy_signup_error_span(self) -> None:
|
||||
def deploy_signup_error_span(self):
|
||||
"""Records when an error occurs during the deployment signup process."""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Deploy Signup Error")
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def start_deployment_span(self, uuid: str | None = None) -> None:
|
||||
def start_deployment_span(self, uuid: Optional[str] = None):
|
||||
"""Records the start of a deployment process.
|
||||
|
||||
Args:
|
||||
uuid: Unique identifier for the deployment.
|
||||
uuid (Optional[str]): Unique identifier for the deployment
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Start Deployment")
|
||||
if uuid:
|
||||
self._add_attribute(span, "uuid", uuid)
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def create_crew_deployment_span(self) -> None:
|
||||
def create_crew_deployment_span(self):
|
||||
"""Records the creation of a new crew deployment."""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Create Crew Deployment")
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def get_crew_logs_span(
|
||||
self, uuid: str | None, log_type: str = "deployment"
|
||||
) -> None:
|
||||
def get_crew_logs_span(self, uuid: Optional[str], log_type: str = "deployment"):
|
||||
"""Records the retrieval of crew logs.
|
||||
|
||||
Args:
|
||||
uuid: Unique identifier for the crew.
|
||||
log_type: Type of logs being retrieved. Defaults to "deployment".
|
||||
uuid (Optional[str]): Unique identifier for the crew
|
||||
log_type (str, optional): Type of logs being retrieved. Defaults to "deployment".
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Get Crew Logs")
|
||||
self._add_attribute(span, "log_type", log_type)
|
||||
if uuid:
|
||||
self._add_attribute(span, "uuid", uuid)
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def remove_crew_span(self, uuid: str | None = None) -> None:
|
||||
def remove_crew_span(self, uuid: Optional[str] = None):
|
||||
"""Records the removal of a crew.
|
||||
|
||||
Args:
|
||||
uuid: Unique identifier for the crew being removed.
|
||||
uuid (Optional[str]): Unique identifier for the crew being removed
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Remove Crew")
|
||||
if uuid:
|
||||
self._add_attribute(span, "uuid", uuid)
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def crew_execution_span(
|
||||
self, crew: Crew, inputs: dict[str, Any] | None
|
||||
) -> Span | None:
|
||||
def crew_execution_span(self, crew: Crew, inputs: dict[str, Any] | None):
|
||||
"""Records the complete execution of a crew.
|
||||
|
||||
This is only collected if the user has opted-in to share the crew.
|
||||
|
||||
Args:
|
||||
crew: The crew being executed.
|
||||
inputs: Optional input parameters for the crew.
|
||||
|
||||
Returns:
|
||||
The execution span if crew sharing is enabled, None otherwise.
|
||||
"""
|
||||
self.crew_creation(crew, inputs)
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Crew Execution")
|
||||
self._add_attribute(
|
||||
@@ -732,9 +727,8 @@ class Telemetry:
|
||||
"crewai_version",
|
||||
version("crewai"),
|
||||
)
|
||||
add_crew_attributes(
|
||||
span, crew, self._add_attribute, include_fingerprint=False
|
||||
)
|
||||
self._add_attribute(span, "crew_key", crew.key)
|
||||
self._add_attribute(span, "crew_id", str(crew.id))
|
||||
self._add_attribute(
|
||||
span, "crew_inputs", json.dumps(inputs) if inputs else None
|
||||
)
|
||||
@@ -792,19 +786,12 @@ class Telemetry:
|
||||
return span
|
||||
|
||||
if crew.share_crew:
|
||||
self._safe_telemetry_operation(_operation)
|
||||
return _operation()
|
||||
self._safe_telemetry_operation(operation)
|
||||
return operation()
|
||||
return None
|
||||
|
||||
def end_crew(self, crew: Any, final_string_output: str) -> None:
|
||||
"""Records the end of crew execution.
|
||||
|
||||
Args:
|
||||
crew: The crew that finished execution.
|
||||
final_string_output: The final output from the crew.
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def end_crew(self, crew, final_string_output):
|
||||
def operation():
|
||||
self._add_attribute(
|
||||
crew._execution_span,
|
||||
"crewai_version",
|
||||
@@ -827,70 +814,68 @@ class Telemetry:
|
||||
]
|
||||
),
|
||||
)
|
||||
close_span(crew._execution_span)
|
||||
crew._execution_span.set_status(Status(StatusCode.OK))
|
||||
crew._execution_span.end()
|
||||
|
||||
if crew.share_crew:
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def _add_attribute(self, span: Span, key: str, value: Any) -> None:
|
||||
"""Add an attribute to a span.
|
||||
def _add_attribute(self, span, key, value):
|
||||
"""Add an attribute to a span."""
|
||||
|
||||
Args:
|
||||
span: The span to add the attribute to.
|
||||
key: The attribute key.
|
||||
value: The attribute value.
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
return span.set_attribute(key, value)
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def flow_creation_span(self, flow_name: str) -> None:
|
||||
def flow_creation_span(self, flow_name: str):
|
||||
"""Records the creation of a new flow.
|
||||
|
||||
Args:
|
||||
flow_name: Name of the flow being created.
|
||||
flow_name (str): Name of the flow being created
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Flow Creation")
|
||||
self._add_attribute(span, "flow_name", flow_name)
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def flow_plotting_span(self, flow_name: str, node_names: list[str]) -> None:
|
||||
def flow_plotting_span(self, flow_name: str, node_names: list[str]):
|
||||
"""Records flow visualization/plotting activity.
|
||||
|
||||
Args:
|
||||
flow_name: Name of the flow being plotted.
|
||||
node_names: List of node names in the flow.
|
||||
flow_name (str): Name of the flow being plotted
|
||||
node_names (list[str]): List of node names in the flow
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Flow Plotting")
|
||||
self._add_attribute(span, "flow_name", flow_name)
|
||||
self._add_attribute(span, "node_names", json.dumps(node_names))
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def flow_execution_span(self, flow_name: str, node_names: list[str]) -> None:
|
||||
def flow_execution_span(self, flow_name: str, node_names: list[str]):
|
||||
"""Records the execution of a flow.
|
||||
|
||||
Args:
|
||||
flow_name: Name of the flow being executed.
|
||||
node_names: List of nodes being executed in the flow.
|
||||
flow_name (str): Name of the flow being executed
|
||||
node_names (list[str]): List of nodes being executed in the flow
|
||||
"""
|
||||
|
||||
def _operation():
|
||||
def operation():
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Flow Execution")
|
||||
self._add_attribute(span, "flow_name", flow_name)
|
||||
self._add_attribute(span, "node_names", json.dumps(node_names))
|
||||
close_span(span)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
|
||||
self._safe_telemetry_operation(_operation)
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
"""Telemetry utility functions.
|
||||
|
||||
This module provides utility functions for telemetry operations.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from opentelemetry.trace import Span, Status, StatusCode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
def add_agent_fingerprint_to_span(
|
||||
span: Span, agent: Any, add_attribute_fn: Callable[[Span, str, Any], None]
|
||||
) -> None:
|
||||
"""Add agent fingerprint data to a span if available.
|
||||
|
||||
Args:
|
||||
span: The span to add the attributes to.
|
||||
agent: The agent whose fingerprint data should be added.
|
||||
add_attribute_fn: Function to add attributes to the span.
|
||||
"""
|
||||
if agent:
|
||||
# Try to get fingerprint directly
|
||||
if hasattr(agent, "fingerprint") and agent.fingerprint:
|
||||
add_attribute_fn(span, "agent_fingerprint", agent.fingerprint.uuid_str)
|
||||
if hasattr(agent, "role"):
|
||||
add_attribute_fn(span, "agent_role", agent.role)
|
||||
else:
|
||||
# Try to get fingerprint using getattr (for cases where it might not be directly accessible)
|
||||
agent_fingerprint = getattr(
|
||||
getattr(agent, "fingerprint", None), "uuid_str", None
|
||||
)
|
||||
if agent_fingerprint:
|
||||
add_attribute_fn(span, "agent_fingerprint", agent_fingerprint)
|
||||
if hasattr(agent, "role"):
|
||||
add_attribute_fn(span, "agent_role", agent.role)
|
||||
|
||||
|
||||
def add_crew_attributes(
|
||||
span: Span,
|
||||
crew: "Crew",
|
||||
add_attribute_fn: Callable[[Span, str, Any], None],
|
||||
include_fingerprint: bool = True,
|
||||
) -> None:
|
||||
"""Add crew attributes to a span.
|
||||
|
||||
Args:
|
||||
span: The span to add the attributes to.
|
||||
crew: The crew whose attributes should be added.
|
||||
add_attribute_fn: Function to add attributes to the span.
|
||||
include_fingerprint: Whether to include fingerprint data.
|
||||
"""
|
||||
add_attribute_fn(span, "crew_key", crew.key)
|
||||
add_attribute_fn(span, "crew_id", str(crew.id))
|
||||
|
||||
if include_fingerprint and hasattr(crew, "fingerprint") and crew.fingerprint:
|
||||
add_attribute_fn(span, "crew_fingerprint", crew.fingerprint.uuid_str)
|
||||
|
||||
|
||||
def add_task_attributes(
|
||||
span: Span,
|
||||
task: "Task",
|
||||
add_attribute_fn: Callable[[Span, str, Any], None],
|
||||
include_fingerprint: bool = True,
|
||||
) -> None:
|
||||
"""Add task attributes to a span.
|
||||
|
||||
Args:
|
||||
span: The span to add the attributes to.
|
||||
task: The task whose attributes should be added.
|
||||
add_attribute_fn: Function to add attributes to the span.
|
||||
include_fingerprint: Whether to include fingerprint data.
|
||||
"""
|
||||
add_attribute_fn(span, "task_key", task.key)
|
||||
add_attribute_fn(span, "task_id", str(task.id))
|
||||
|
||||
if include_fingerprint and hasattr(task, "fingerprint") and task.fingerprint:
|
||||
add_attribute_fn(span, "task_fingerprint", task.fingerprint.uuid_str)
|
||||
|
||||
|
||||
def add_crew_and_task_attributes(
|
||||
span: Span,
|
||||
crew: "Crew",
|
||||
task: "Task",
|
||||
add_attribute_fn: Callable[[Span, str, Any], None],
|
||||
include_fingerprints: bool = True,
|
||||
) -> None:
|
||||
"""Add both crew and task attributes to a span.
|
||||
|
||||
Args:
|
||||
span: The span to add the attributes to.
|
||||
crew: The crew whose attributes should be added.
|
||||
task: The task whose attributes should be added.
|
||||
add_attribute_fn: Function to add attributes to the span.
|
||||
include_fingerprints: Whether to include fingerprint data.
|
||||
"""
|
||||
add_crew_attributes(span, crew, add_attribute_fn, include_fingerprints)
|
||||
add_task_attributes(span, task, add_attribute_fn, include_fingerprints)
|
||||
|
||||
|
||||
def close_span(span: Span) -> None:
|
||||
"""Set span status to OK and end it.
|
||||
|
||||
Args:
|
||||
span: The span to close.
|
||||
"""
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
span.end()
|
||||
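The new telemetry utility module above centralizes span bookkeeping: `close_span` marks a span OK and ends it, while `add_crew_attributes` / `add_task_attributes` copy keys, ids, and fingerprints onto a span through a caller-supplied `add_attribute_fn`. A hedged sketch of how these helpers compose, using a plain stand-in object instead of a real `Crew`:

```python
# Illustrative composition of the helper signatures shown in the diff above.
# FakeCrew is a stand-in exposing only the attributes the helpers read.
from dataclasses import dataclass

from opentelemetry import trace
from opentelemetry.trace import Span, Status, StatusCode


def close_span(span: Span) -> None:
    """Mark the span successful and end it, mirroring the helper above."""
    span.set_status(Status(StatusCode.OK))
    span.end()


@dataclass
class FakeCrew:
    key: str
    id: str


def add_crew_attributes(span, crew, add_attribute_fn) -> None:
    add_attribute_fn(span, "crew_key", crew.key)
    add_attribute_fn(span, "crew_id", str(crew.id))


tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Crew Created")
add_crew_attributes(span, FakeCrew(key="abc", id="123"),
                    lambda s, k, v: s.set_attribute(k, v))
close_span(span)
```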
@@ -1,22 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import inspect
|
||||
import textwrap
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, get_type_hints
|
||||
from typing import Any, Callable, Optional, Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class ToolUsageLimitExceededError(Exception):
|
||||
class ToolUsageLimitExceeded(Exception):
|
||||
"""Exception raised when a tool has reached its maximum usage limit."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CrewStructuredTool:
|
||||
"""A structured tool that can operate on any number of inputs.
|
||||
@@ -65,10 +69,10 @@ class CrewStructuredTool:
|
||||
def from_function(
|
||||
cls,
|
||||
func: Callable,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: type[BaseModel] | None = None,
|
||||
args_schema: Optional[type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> CrewStructuredTool:
|
||||
@@ -160,7 +164,7 @@ class CrewStructuredTool:
|
||||
|
||||
# Create model
|
||||
schema_name = f"{name.title()}Schema"
|
||||
return create_model(schema_name, **fields) # type: ignore[call-overload]
|
||||
return create_model(schema_name, **fields)
|
||||
|
||||
def _validate_function_signature(self) -> None:
|
||||
"""Validate that the function signature matches the args schema."""
|
||||
@@ -188,7 +192,7 @@ class CrewStructuredTool:
|
||||
f"not found in args_schema"
|
||||
)
|
||||
|
||||
def _parse_args(self, raw_args: str | dict) -> dict:
|
||||
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
|
||||
"""Parse and validate the input arguments against the schema.
|
||||
|
||||
Args:
|
||||
@@ -203,18 +207,18 @@ class CrewStructuredTool:
|
||||
|
||||
raw_args = json.loads(raw_args)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse arguments as JSON: {e}") from e
|
||||
raise ValueError(f"Failed to parse arguments as JSON: {e}")
|
||||
|
||||
try:
|
||||
validated_args = self.args_schema.model_validate(raw_args)
|
||||
return validated_args.model_dump()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Arguments validation failed: {e}") from e
|
||||
raise ValueError(f"Arguments validation failed: {e}")
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: str | dict,
|
||||
config: dict | None = None,
|
||||
input: Union[str, dict],
|
||||
config: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Asynchronously invoke the tool.
|
||||
@@ -230,7 +234,7 @@ class CrewStructuredTool:
|
||||
parsed_args = self._parse_args(input)
|
||||
|
||||
if self.has_reached_max_usage_count():
|
||||
raise ToolUsageLimitExceededError(
|
||||
raise ToolUsageLimitExceeded(
|
||||
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
|
||||
)
|
||||
|
||||
@@ -239,37 +243,44 @@ class CrewStructuredTool:
|
||||
try:
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
return await self.func(**parsed_args, **kwargs)
|
||||
# Run sync functions in a thread pool
|
||||
import asyncio
|
||||
else:
|
||||
# Run sync functions in a thread pool
|
||||
import asyncio
|
||||
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: self.func(**parsed_args, **kwargs)
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: self.func(**parsed_args, **kwargs)
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _run(self, *args, **kwargs) -> Any:
|
||||
"""Legacy method for compatibility."""
|
||||
# Convert args/kwargs to our expected format
|
||||
input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False))
|
||||
input_dict = dict(zip(self.args_schema.model_fields.keys(), args))
|
||||
input_dict.update(kwargs)
|
||||
return self.invoke(input_dict)
|
||||
|
||||
def invoke(
|
||||
self, input: str | dict, config: dict | None = None, **kwargs: Any
|
||||
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Main method for tool execution."""
|
||||
parsed_args = self._parse_args(input)
|
||||
|
||||
if self.has_reached_max_usage_count():
|
||||
raise ToolUsageLimitExceededError(
|
||||
raise ToolUsageLimitExceeded(
|
||||
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
|
||||
)
|
||||
|
||||
self._increment_usage_count()
|
||||
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
return asyncio.run(self.func(**parsed_args, **kwargs))
|
||||
result = asyncio.run(self.func(**parsed_args, **kwargs))
|
||||
return result
|
||||
|
||||
try:
|
||||
result = self.func(**parsed_args, **kwargs)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
result = self.func(**parsed_args, **kwargs)
|
||||
|
||||
|
||||
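The `CrewStructuredTool` changes above mostly modernize type hints (`str | None`, the renamed `ToolUsageLimitExceededError`) and tidy the async/sync dispatch. For orientation, a short sketch of how `from_function` is typically used, based only on the signature visible in this hunk; the import path is an assumption taken from the class name, not confirmed by the diff:

```python
# Sketch based on the from_function signature shown above; the module path
# crewai.tools.structured_tool is assumed rather than shown in this diff.
from crewai.tools.structured_tool import CrewStructuredTool


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


tool = CrewStructuredTool.from_function(
    func=add,
    name="add",
    description="Add two integers",
    infer_schema=True,  # build the args_schema from the type hints on `add`
)

# invoke() parses the dict against the inferred schema and calls the function.
print(tool.invoke({"a": 2, "b": 3}))  # expected: 5
```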
@@ -1,22 +1,16 @@
|
||||
"""Crew chat input models.
|
||||
|
||||
This module provides models for defining chat inputs and fields
|
||||
for crew interactions.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChatInputField(BaseModel):
|
||||
"""Represents a single required input for the crew.
|
||||
|
||||
"""
|
||||
Represents a single required input for the crew, with a name and short description.
|
||||
Example:
|
||||
```python
|
||||
field = ChatInputField(
|
||||
name="topic",
|
||||
description="The topic to focus on for the conversation"
|
||||
)
|
||||
```
|
||||
{
|
||||
"name": "topic",
|
||||
"description": "The topic to focus on for the conversation"
|
||||
}
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the input field")
|
||||
@@ -24,25 +18,23 @@ class ChatInputField(BaseModel):
|
||||
|
||||
|
||||
class ChatInputs(BaseModel):
|
||||
"""Holds crew metadata and input field definitions.
|
||||
|
||||
"""
|
||||
Holds a high-level crew_description plus a list of ChatInputFields.
|
||||
Example:
|
||||
```python
|
||||
inputs = ChatInputs(
|
||||
crew_name="topic-based-qa",
|
||||
crew_description="Use this crew for topic-based Q&A",
|
||||
inputs=[
|
||||
ChatInputField(name="topic", description="The topic to focus on"),
|
||||
ChatInputField(name="username", description="Name of the user"),
|
||||
{
|
||||
"crew_name": "topic-based-qa",
|
||||
"crew_description": "Use this crew for topic-based Q&A",
|
||||
"inputs": [
|
||||
{"name": "topic", "description": "The topic to focus on"},
|
||||
{"name": "username", "description": "Name of the user"},
|
||||
]
|
||||
)
|
||||
```
|
||||
}
|
||||
"""
|
||||
|
||||
crew_name: str = Field(..., description="The name of the crew")
|
||||
crew_description: str = Field(
|
||||
..., description="A description of the crew's purpose"
|
||||
)
|
||||
inputs: list[ChatInputField] = Field(
|
||||
inputs: List[ChatInputField] = Field(
|
||||
default_factory=list, description="A list of input fields for the crew"
|
||||
)
|
||||
|
||||
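The docstring change above trades the Python construction examples for JSON-shaped ones. For a self-contained version of what those examples describe, here is a sketch with simplified stand-in models (the real classes live in CrewAI's chat input module; the `description` field on `ChatInputField` is assumed from the surrounding hunk):

```python
from pydantic import BaseModel, Field


# Simplified stand-ins mirroring the fields shown in the diff above.
class ChatInputField(BaseModel):
    name: str = Field(..., description="The name of the input field")
    description: str = Field(..., description="A short description of the input field")


class ChatInputs(BaseModel):
    crew_name: str
    crew_description: str
    inputs: list[ChatInputField] = Field(default_factory=list)


inputs = ChatInputs(
    crew_name="topic-based-qa",
    crew_description="Use this crew for topic-based Q&A",
    inputs=[ChatInputField(name="topic", description="The topic to focus on")],
)
print(inputs.model_dump())
```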
@@ -1,37 +1,18 @@
|
||||
"""Human-in-the-loop (HITL) type definitions.
|
||||
|
||||
This module provides type definitions for human-in-the-loop interactions
|
||||
in crew executions.
|
||||
"""
|
||||
|
||||
from typing import TypedDict
|
||||
from typing import List, Dict, TypedDict
|
||||
|
||||
|
||||
class HITLResumeInfo(TypedDict, total=False):
|
||||
"""HITL resume information passed from flow to crew.
|
||||
|
||||
Attributes:
|
||||
task_id: Unique identifier for the task.
|
||||
crew_execution_id: Unique identifier for the crew execution.
|
||||
task_key: Key identifying the specific task.
|
||||
task_output: Output from the task before human intervention.
|
||||
human_feedback: Feedback provided by the human.
|
||||
previous_messages: History of messages in the conversation.
|
||||
"""
|
||||
"""HITL resume information passed from flow to crew."""
|
||||
|
||||
task_id: str
|
||||
crew_execution_id: str
|
||||
task_key: str
|
||||
task_output: str
|
||||
human_feedback: str
|
||||
previous_messages: list[dict[str, str]]
|
||||
previous_messages: List[Dict[str, str]]
|
||||
|
||||
|
||||
class CrewInputsWithHITL(TypedDict, total=False):
|
||||
"""Crew inputs that may contain HITL resume information.
|
||||
|
||||
Attributes:
|
||||
_hitl_resume: Optional HITL resume information for continuing execution.
|
||||
"""
|
||||
"""Crew inputs that may contain HITL resume information."""
|
||||
|
||||
_hitl_resume: HITLResumeInfo
|
||||
|
||||
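`HITLResumeInfo` and `CrewInputsWithHITL` are `TypedDict`s declared with `total=False`, so every key is optional and the values are plain dicts at runtime. A small sketch of building the resume payload both the old and new definitions describe:

```python
from typing import TypedDict


class HITLResumeInfo(TypedDict, total=False):
    # Mirrors the keys shown in the diff; total=False makes every key optional.
    task_id: str
    crew_execution_id: str
    task_key: str
    task_output: str
    human_feedback: str
    previous_messages: list[dict[str, str]]


resume: HITLResumeInfo = {
    "task_id": "task-1",
    "human_feedback": "Looks good, but shorten the summary.",
    "previous_messages": [{"role": "user", "content": "Summarize the report"}],
}

# TypedDicts are ordinary dicts at runtime, so the payload nests directly
# under the "_hitl_resume" key that CrewInputsWithHITL declares.
crew_inputs = {"_hitl_resume": resume}
```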
@@ -1,15 +1,9 @@
|
||||
"""Usage metrics tracking for CrewAI execution.
|
||||
|
||||
This module provides models for tracking token usage and request metrics
|
||||
during crew and agent execution.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class UsageMetrics(BaseModel):
|
||||
"""Track usage metrics for crew execution.
|
||||
"""
|
||||
Model to track usage metrics for the crew's execution.
|
||||
|
||||
Attributes:
|
||||
total_tokens: Total number of tokens used.
|
||||
@@ -33,11 +27,12 @@ class UsageMetrics(BaseModel):
|
||||
default=0, description="Number of successful requests made."
|
||||
)
|
||||
|
||||
def add_usage_metrics(self, usage_metrics: Self) -> None:
|
||||
"""Add usage metrics from another UsageMetrics object.
|
||||
def add_usage_metrics(self, usage_metrics: "UsageMetrics"):
|
||||
"""
|
||||
Add the usage metrics from another UsageMetrics object.
|
||||
|
||||
Args:
|
||||
usage_metrics: The usage metrics to add.
|
||||
usage_metrics (UsageMetrics): The usage metrics to add.
|
||||
"""
|
||||
self.total_tokens += usage_metrics.total_tokens
|
||||
self.prompt_tokens += usage_metrics.prompt_tokens
|
||||
|
||||
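`add_usage_metrics` simply accumulates the counters of another `UsageMetrics` instance field by field. A cut-down sketch of that behavior (only `total_tokens` and `prompt_tokens` appear in this hunk; the real model also tracks completion tokens and successful requests):

```python
from pydantic import BaseModel


# Cut-down stand-in showing how add_usage_metrics accumulates counters.
class UsageMetrics(BaseModel):
    total_tokens: int = 0
    prompt_tokens: int = 0

    def add_usage_metrics(self, usage_metrics: "UsageMetrics") -> None:
        self.total_tokens += usage_metrics.total_tokens
        self.prompt_tokens += usage_metrics.prompt_tokens


run_a = UsageMetrics(total_tokens=120, prompt_tokens=80)
run_b = UsageMetrics(total_tokens=60, prompt_tokens=40)
run_a.add_usage_metrics(run_b)
print(run_a.total_tokens, run_a.prompt_tokens)  # 180 120
```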
83
src/crewai/utilities/chromadb.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
import re
|
||||
import portalocker
|
||||
from chromadb import PersistentClient
|
||||
from hashlib import md5
|
||||
from typing import Optional
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
MIN_COLLECTION_LENGTH = 3
|
||||
MAX_COLLECTION_LENGTH = 63
|
||||
DEFAULT_COLLECTION = "default_collection"
|
||||
|
||||
# Compiled regex patterns for better performance
|
||||
INVALID_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
|
||||
IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
|
||||
|
||||
|
||||
def is_ipv4_pattern(name: str) -> bool:
|
||||
"""
|
||||
Check if a string matches an IPv4 address pattern.
|
||||
|
||||
Args:
|
||||
name: The string to check
|
||||
|
||||
Returns:
|
||||
True if the string matches an IPv4 pattern, False otherwise
|
||||
"""
|
||||
return bool(IPV4_PATTERN.match(name))
|
||||
|
||||
|
||||
def sanitize_collection_name(
|
||||
name: Optional[str], max_collection_length: int = MAX_COLLECTION_LENGTH
|
||||
) -> str:
|
||||
"""
|
||||
Sanitize a collection name to meet ChromaDB requirements:
|
||||
1. 3-63 characters long
|
||||
2. Starts and ends with alphanumeric character
|
||||
3. Contains only alphanumeric characters, underscores, or hyphens
|
||||
4. No consecutive periods
|
||||
5. Not a valid IPv4 address
|
||||
|
||||
Args:
|
||||
name: The original collection name to sanitize
|
||||
|
||||
Returns:
|
||||
A sanitized collection name that meets ChromaDB requirements
|
||||
"""
|
||||
if not name:
|
||||
return DEFAULT_COLLECTION
|
||||
|
||||
if is_ipv4_pattern(name):
|
||||
name = f"ip_{name}"
|
||||
|
||||
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
|
||||
|
||||
if not sanitized[0].isalnum():
|
||||
sanitized = "a" + sanitized
|
||||
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
if len(sanitized) < MIN_COLLECTION_LENGTH:
|
||||
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
|
||||
if len(sanitized) > max_collection_length:
|
||||
sanitized = sanitized[:max_collection_length]
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def create_persistent_client(path: str, **kwargs):
|
||||
"""
|
||||
Creates a persistent client for ChromaDB with a lock file to prevent
|
||||
concurrent creations. Works for both multi-threads and multi-processes
|
||||
environments.
|
||||
"""
|
||||
lock_id = md5(path.encode(), usedforsecurity=False).hexdigest()
|
||||
lockfile = os.path.join(db_storage_path(), f"chromadb-{lock_id}.lock")
|
||||
with portalocker.Lock(lockfile):
|
||||
client = PersistentClient(path=path, **kwargs)
|
||||
|
||||
return client
|
||||
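The new `src/crewai/utilities/chromadb.py` module above does two things: it sanitizes collection names to ChromaDB's constraints (3-63 characters, alphanumeric start/end, no IPv4-shaped names) and it serializes `PersistentClient` creation behind a `portalocker` file lock. A quick sketch of the sanitizer on a few inputs; the import path follows the new file's location, and the expected outputs follow the rules implemented above:

```python
from crewai.utilities.chromadb import sanitize_collection_name

print(sanitize_collection_name(None))             # "default_collection"
print(sanitize_collection_name("ab"))             # padded to the 3-char minimum: "abx"
print(sanitize_collection_name("192.168.0.1"))    # IPv4 prefixed, dots replaced: "ip_192_168_0_1"
print(sanitize_collection_name("-my collection-"))  # invalid chars replaced, ends fixed: "a-my_collectionz"
```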
@@ -1,16 +1,12 @@
|
||||
"""Error message definitions for CrewAI database operations.
|
||||
"""Error message definitions for CrewAI database operations."""
|
||||
|
||||
This module provides standardized error classes and message templates
|
||||
for database operations and agent repository handling.
|
||||
"""
|
||||
|
||||
from typing import Final
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DatabaseOperationError(Exception):
|
||||
"""Base exception class for database operation errors."""
|
||||
|
||||
def __init__(self, message: str, original_error: Exception | None = None) -> None:
|
||||
def __init__(self, message: str, original_error: Optional[Exception] = None):
|
||||
"""Initialize the database operation error.
|
||||
|
||||
Args:
|
||||
@@ -22,17 +18,13 @@ class DatabaseOperationError(Exception):
|
||||
|
||||
|
||||
class DatabaseError:
|
||||
"""Standardized error message templates for database operations.
|
||||
"""Standardized error message templates for database operations."""
|
||||
|
||||
Provides consistent error message formatting for various database
|
||||
operation failures.
|
||||
"""
|
||||
|
||||
INIT_ERROR: Final[str] = "Database initialization error: {}"
|
||||
SAVE_ERROR: Final[str] = "Error saving task outputs: {}"
|
||||
UPDATE_ERROR: Final[str] = "Error updating task outputs: {}"
|
||||
LOAD_ERROR: Final[str] = "Error loading task outputs: {}"
|
||||
DELETE_ERROR: Final[str] = "Error deleting task outputs: {}"
|
||||
INIT_ERROR: str = "Database initialization error: {}"
|
||||
SAVE_ERROR: str = "Error saving task outputs: {}"
|
||||
UPDATE_ERROR: str = "Error updating task outputs: {}"
|
||||
LOAD_ERROR: str = "Error loading task outputs: {}"
|
||||
DELETE_ERROR: str = "Error deleting task outputs: {}"
|
||||
|
||||
@classmethod
|
||||
def format_error(cls, template: str, error: Exception) -> str:
|
||||
@@ -50,3 +42,5 @@ class DatabaseError:
|
||||
|
||||
class AgentRepositoryError(Exception):
|
||||
"""Exception raised when an agent repository is not found."""
|
||||
|
||||
...
|
||||
|
||||
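`DatabaseError` keeps its message templates as class-level constants (now typed `Final[str]`), with `format_error` filling in the exception text. A minimal standalone sketch of the pattern; the `format_error` body is not shown in this hunk, so the implementation here is an assumption:

```python
from typing import Final


# Stand-in mirroring the DatabaseError template pattern shown above.
class DatabaseError:
    INIT_ERROR: Final[str] = "Database initialization error: {}"
    SAVE_ERROR: Final[str] = "Error saving task outputs: {}"

    @classmethod
    def format_error(cls, template: str, error: Exception) -> str:
        # Assumed implementation: substitute the exception into the template.
        return template.format(error)


try:
    raise RuntimeError("disk full")
except RuntimeError as exc:
    print(DatabaseError.format_error(DatabaseError.SAVE_ERROR, exc))
    # -> "Error saving task outputs: disk full"
```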
@@ -1,9 +1,8 @@
|
||||
"""Logging and warning utility functions for CrewAI."""
|
||||
"""Logging utility functions for CrewAI."""
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
@@ -37,20 +36,3 @@ def suppress_logging(
|
||||
):
|
||||
yield
|
||||
logger.setLevel(original_level)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_warnings() -> Generator[None, None, None]:
|
||||
"""Context manager to suppress all warnings.
|
||||
|
||||
Yields:
|
||||
None during the context execution.
|
||||
|
||||
Note:
|
||||
There is a similar implementation in src/crewai/llm.py that also
|
||||
suppresses a specific deprecation warning. That version may be
|
||||
consolidated here in the future.
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
yield
|
||||
|
||||
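`suppress_warnings` is a context manager that silences every warning for the duration of the block and restores the previous filters on exit. A minimal usage sketch matching the implementation above:

```python
import contextlib
import warnings
from collections.abc import Generator


@contextlib.contextmanager
def suppress_warnings() -> Generator[None, None, None]:
    """Same shape as the helper above: ignore every warning inside the block."""
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        yield


with suppress_warnings():
    warnings.warn("this deprecation message is swallowed", DeprecationWarning)

warnings.warn("this one is visible again", UserWarning)  # emitted normally
```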
@@ -1,11 +1,5 @@
|
||||
"""Task output storage handler for managing task execution results.
|
||||
|
||||
This module provides functionality for storing and retrieving task outputs
|
||||
from persistent storage, supporting replay and audit capabilities.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -14,64 +8,32 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
)
|
||||
from crewai.task import Task
|
||||
|
||||
"""Handles storage and retrieval of task execution outputs."""
|
||||
|
||||
|
||||
class ExecutionLog(BaseModel):
|
||||
"""Represents a log entry for task execution.
|
||||
|
||||
Attributes:
|
||||
task_id: Unique identifier for the task.
|
||||
expected_output: The expected output description for the task.
|
||||
output: The actual output produced by the task.
|
||||
timestamp: When the task was executed.
|
||||
task_index: The position of the task in the execution sequence.
|
||||
inputs: Input parameters provided to the task.
|
||||
was_replayed: Whether this output was replayed from a previous run.
|
||||
"""
|
||||
"""Represents a log entry for task execution."""
|
||||
|
||||
task_id: str
|
||||
expected_output: str | None = None
|
||||
output: dict[str, Any]
|
||||
expected_output: Optional[str] = None
|
||||
output: Dict[str, Any]
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
task_index: int
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
inputs: Dict[str, Any] = Field(default_factory=dict)
|
||||
was_replayed: bool = False
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Enable dictionary-style access to execution log attributes.
|
||||
|
||||
Args:
|
||||
key: The attribute name to access.
|
||||
|
||||
Returns:
|
||||
The value of the requested attribute.
|
||||
"""
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
"""Manages storage and retrieval of task outputs."""
|
||||
|
||||
|
||||
class TaskOutputStorageHandler:
|
||||
"""Manages storage and retrieval of task outputs.
|
||||
|
||||
This handler provides an interface to persist and retrieve task execution
|
||||
results, supporting features like replay and audit trails.
|
||||
|
||||
Attributes:
|
||||
storage: The underlying SQLite storage implementation.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the task output storage handler."""
|
||||
self.storage = KickoffTaskOutputsSQLiteStorage()
|
||||
|
||||
def update(self, task_index: int, log: dict[str, Any]) -> None:
|
||||
"""Update an existing task output in storage.
|
||||
|
||||
Args:
|
||||
task_index: The index of the task to update.
|
||||
log: Dictionary containing task execution details.
|
||||
|
||||
Raises:
|
||||
ValueError: If no saved outputs exist.
|
||||
"""
|
||||
def update(self, task_index: int, log: Dict[str, Any]):
|
||||
saved_outputs = self.load()
|
||||
if saved_outputs is None:
|
||||
raise ValueError("Logs cannot be None")
|
||||
@@ -94,31 +56,16 @@ class TaskOutputStorageHandler:
|
||||
def add(
|
||||
self,
|
||||
task: Task,
|
||||
output: dict[str, Any],
|
||||
output: Dict[str, Any],
|
||||
task_index: int,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
inputs: Dict[str, Any] | None = None,
|
||||
was_replayed: bool = False,
|
||||
) -> None:
|
||||
"""Add a new task output to storage.
|
||||
|
||||
Args:
|
||||
task: The task that was executed.
|
||||
output: The output produced by the task.
|
||||
task_index: The position of the task in execution sequence.
|
||||
inputs: Optional input parameters for the task.
|
||||
was_replayed: Whether this is a replayed execution.
|
||||
"""
|
||||
):
|
||||
inputs = inputs or {}
|
||||
self.storage.add(task, output, task_index, was_replayed, inputs)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all stored task outputs."""
|
||||
def reset(self):
|
||||
self.storage.delete_all()
|
||||
|
||||
def load(self) -> list[dict[str, Any]] | None:
|
||||
"""Load all stored task outputs.
|
||||
|
||||
Returns:
|
||||
List of task output dictionaries, or None if no outputs exist.
|
||||
"""
|
||||
def load(self) -> Optional[List[Dict[str, Any]]]:
|
||||
return self.storage.load()
|
||||
|
||||
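`ExecutionLog` defines `__getitem__` so stored entries can be read either as attributes or dict-style, which the replay tooling relies on. A cut-down sketch of that access pattern (a stand-in model with a subset of the fields listed above):

```python
from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field


class ExecutionLog(BaseModel):
    # Subset of the fields listed in the diff above.
    task_id: str
    output: dict[str, Any]
    timestamp: datetime = Field(default_factory=datetime.now)
    task_index: int
    was_replayed: bool = False

    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)


log = ExecutionLog(task_id="t-1", output={"raw": "done"}, task_index=0)
print(log.task_id)        # attribute access
print(log["task_index"])  # dict-style access via __getitem__
```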
@@ -1,11 +1,5 @@
|
||||
"""Token counting callback handler for LLM interactions.
|
||||
|
||||
This module provides a callback handler that tracks token usage
|
||||
for LLM API calls through the litellm library.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import Usage
|
||||
@@ -14,38 +8,16 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces
|
||||
|
||||
|
||||
class TokenCalcHandler(CustomLogger):
|
||||
"""Handler for calculating and tracking token usage in LLM calls.
|
||||
|
||||
This handler integrates with litellm's logging system to track
|
||||
prompt tokens, completion tokens, and cached tokens across requests.
|
||||
|
||||
Attributes:
|
||||
token_cost_process: The token process tracker to accumulate usage metrics.
|
||||
"""
|
||||
|
||||
def __init__(self, token_cost_process: TokenProcess | None) -> None:
|
||||
"""Initialize the token calculation handler.
|
||||
|
||||
Args:
|
||||
token_cost_process: Optional token process tracker for accumulating metrics.
|
||||
"""
|
||||
def __init__(self, token_cost_process: Optional[TokenProcess]):
|
||||
self.token_cost_process = token_cost_process
|
||||
|
||||
def log_success_event(
|
||||
self,
|
||||
kwargs: dict[str, Any],
|
||||
response_obj: dict[str, Any],
|
||||
kwargs: Dict[str, Any],
|
||||
response_obj: Dict[str, Any],
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
) -> None:
|
||||
"""Log successful LLM API call and track token usage.
|
||||
|
||||
Args:
|
||||
kwargs: The arguments passed to the LLM call.
|
||||
response_obj: The response object from the LLM API.
|
||||
start_time: The timestamp when the call started.
|
||||
end_time: The timestamp when the call completed.
|
||||
"""
|
||||
if self.token_cost_process is None:
|
||||
return
|
||||
|
||||
|
||||
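`TokenCalcHandler` plugs into litellm's `CustomLogger` hooks and forwards usage counts to a `TokenProcess`; when no tracker is supplied it becomes a no-op. A hedged sketch of that shape, without the litellm base class and with a stand-in tracker (CrewAI's real `TokenProcess` exposes similar summing methods, but the exact usage extraction here is an assumption):

```python
from typing import Any


class FakeTokenProcess:
    """Stand-in tracker; not CrewAI's TokenProcess."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def sum_prompt_tokens(self, n: int) -> None:
        self.prompt_tokens += n

    def sum_completion_tokens(self, n: int) -> None:
        self.completion_tokens += n


class TokenCalcHandler:
    """Minimal sketch of the handler above, outside litellm."""

    def __init__(self, token_cost_process: FakeTokenProcess | None) -> None:
        self.token_cost_process = token_cost_process

    def log_success_event(
        self,
        kwargs: dict[str, Any],
        response_obj: dict[str, Any],
        start_time: float,
        end_time: float,
    ) -> None:
        if self.token_cost_process is None:
            return  # no tracker configured, nothing to record
        usage = response_obj.get("usage", {})
        self.token_cost_process.sum_prompt_tokens(usage.get("prompt_tokens", 0))
        self.token_cost_process.sum_completion_tokens(usage.get("completion_tokens", 0))


tracker = FakeTokenProcess()
handler = TokenCalcHandler(tracker)
handler.log_success_event({}, {"usage": {"prompt_tokens": 12, "completion_tokens": 30}}, 0.0, 0.1)
print(tracker.prompt_tokens, tracker.completion_tokens)  # 12 30
```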
@@ -9,19 +9,19 @@ import pytest
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.process import Process
|
||||
from crewai.tools import tool
|
||||
from crewai.tools.tool_calling import InstructorToolCalling
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
from crewai.utilities import RPMController
|
||||
from crewai.utilities.errors import AgentRepositoryError
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
|
||||
from crewai.process import Process
|
||||
|
||||
|
||||
def test_agent_llm_creation_with_env_vars():
|
||||
@@ -445,7 +445,7 @@ def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_powered_by_new_o_model_family_that_uses_tool():
|
||||
@tool
|
||||
def comapny_customer_data() -> str:
|
||||
def comapny_customer_data() -> float:
|
||||
"""Useful for getting customer related data."""
|
||||
return "The company has 42 customers"
|
||||
|
||||
@@ -500,15 +500,6 @@ def test_agent_custom_max_iterations():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_repeated_tool_usage(capsys):
|
||||
"""Test that agents handle repeated tool usage appropriately.
|
||||
|
||||
Notes:
|
||||
Investigate whether to pin down the specific execution flow by examining
|
||||
src/crewai/agents/crew_agent_executor.py:177-186 (max iterations check)
|
||||
and src/crewai/tools/tool_usage.py:152-157 (repeated usage detection)
|
||||
to ensure deterministic behavior.
|
||||
"""
|
||||
|
||||
@tool
|
||||
def get_final_answer() -> float:
|
||||
"""Get the final answer but don't give it yet, just re-use this tool non-stop."""
|
||||
@@ -536,16 +527,42 @@ def test_agent_repeated_tool_usage(capsys):
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output_lower = captured.out.lower()
|
||||
|
||||
has_repeated_usage_message = "tried reusing the same input" in output_lower
|
||||
has_max_iterations = "maximum iterations reached" in output_lower
|
||||
has_final_answer = "final answer" in output_lower or "42" in captured.out
|
||||
|
||||
assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
|
||||
f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
|
||||
output = (
|
||||
captured.out.replace("\n", " ")
|
||||
.replace(" ", " ")
|
||||
.strip()
|
||||
.replace("╭", "")
|
||||
.replace("╮", "")
|
||||
.replace("╯", "")
|
||||
.replace("╰", "")
|
||||
.replace("│", "")
|
||||
.replace("─", "")
|
||||
.replace("[", "")
|
||||
.replace("]", "")
|
||||
.replace("bold", "")
|
||||
.replace("blue", "")
|
||||
.replace("yellow", "")
|
||||
.replace("green", "")
|
||||
.replace("red", "")
|
||||
.replace("dim", "")
|
||||
.replace("🤖", "")
|
||||
.replace("🔧", "")
|
||||
.replace("✅", "")
|
||||
.replace("\x1b[93m", "")
|
||||
.replace("\x1b[00m", "")
|
||||
.replace("\\", "")
|
||||
.replace('"', "")
|
||||
.replace("'", "")
|
||||
)
|
||||
|
||||
# Look for the message in the normalized output, handling the apostrophe difference
|
||||
expected_message = (
|
||||
"I tried reusing the same input, I must stop using this action input."
|
||||
)
|
||||
assert (
|
||||
expected_message in output
|
||||
), f"Expected message not found in output. Output was: {output}"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
|
||||
@@ -585,9 +602,9 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
|
||||
has_max_iterations = "maximum iterations reached" in output_lower
|
||||
has_final_answer = "final answer" in output_lower or "42" in captured.out
|
||||
|
||||
assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
|
||||
f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
|
||||
)
|
||||
assert (
|
||||
has_repeated_usage_message or (has_max_iterations and has_final_answer)
|
||||
), f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -863,7 +880,7 @@ def test_agent_step_callback():
|
||||
with patch.object(StepCallback, "callback") as callback:
|
||||
|
||||
@tool
|
||||
def learn_about_ai() -> str:
|
||||
def learn_about_AI() -> str:
|
||||
"""Useful for when you need to learn about AI to write an paragraph about it."""
|
||||
return "AI is a very broad field."
|
||||
|
||||
@@ -871,7 +888,7 @@ def test_agent_step_callback():
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
tools=[learn_about_ai],
|
||||
tools=[learn_about_AI],
|
||||
step_callback=StepCallback().callback,
|
||||
)
|
||||
|
||||
@@ -893,7 +910,7 @@ def test_agent_function_calling_llm():
|
||||
llm = "gpt-4o"
|
||||
|
||||
@tool
|
||||
def learn_about_ai() -> str:
|
||||
def learn_about_AI() -> str:
|
||||
"""Useful for when you need to learn about AI to write an paragraph about it."""
|
||||
return "AI is a very broad field."
|
||||
|
||||
@@ -901,7 +918,7 @@ def test_agent_function_calling_llm():
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
tools=[learn_about_ai],
|
||||
tools=[learn_about_AI],
|
||||
llm="gpt-4o",
|
||||
max_iter=2,
|
||||
function_calling_llm=llm,
|
||||
@@ -1339,7 +1356,7 @@ def test_agent_training_handler(crew_training_handler):
|
||||
verbose=True,
|
||||
)
|
||||
crew_training_handler().load.return_value = {
|
||||
f"{agent.id!s}": {"0": {"human_feedback": "good"}}
|
||||
f"{str(agent.id)}": {"0": {"human_feedback": "good"}}
|
||||
}
|
||||
|
||||
result = agent._training_handler(task_prompt=task_prompt)
|
||||
@@ -1456,7 +1473,7 @@ def test_agent_with_custom_stop_words():
|
||||
)
|
||||
|
||||
assert isinstance(agent.llm, LLM)
|
||||
assert set(agent.llm.stop) == set([*stop_words, "\nObservation:"])
|
||||
assert set(agent.llm.stop) == set(stop_words + ["\nObservation:"])
|
||||
assert all(word in agent.llm.stop for word in stop_words)
|
||||
assert "\nObservation:" in agent.llm.stop
|
||||
|
||||
@@ -1513,7 +1530,7 @@ def test_llm_call_with_error():
|
||||
llm = LLM(model="non-existent-model")
|
||||
messages = [{"role": "user", "content": "This should fail"}]
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
with pytest.raises(Exception):
|
||||
llm.call(messages)
|
||||
|
||||
|
||||
@@ -1813,11 +1830,11 @@ def test_agent_execute_task_with_ollama():
|
||||
def test_agent_with_knowledge_sources():
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
with patch("crewai.knowledge") as mock_knowledge:
|
||||
mock_knowledge_instance = mock_knowledge.return_value
|
||||
with patch("crewai.knowledge") as MockKnowledge:
|
||||
mock_knowledge_instance = MockKnowledge.return_value
|
||||
mock_knowledge_instance.sources = [string_source]
|
||||
mock_knowledge_instance.search.return_value = [{"content": content}]
|
||||
mock_knowledge.add_sources.return_value = [string_source]
|
||||
MockKnowledge.add_sources.return_value = [string_source]
|
||||
|
||||
agent = Agent(
|
||||
role="Information Agent",
|
||||
@@ -1846,25 +1863,12 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
knowledge_config = KnowledgeConfig(results_limit=10, score_threshold=0.5)
|
||||
with (
|
||||
patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as mock_knowledge_storage,
|
||||
patch(
|
||||
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
|
||||
) as mock_base_knowledge_storage,
|
||||
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
|
||||
):
|
||||
mock_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_storage_instance.sources = [string_source]
|
||||
mock_storage_instance.query.return_value = [{"content": content}]
|
||||
mock_storage_instance.save.return_value = None
|
||||
|
||||
mock_chromadb_instance = mock_chromadb.return_value
|
||||
mock_chromadb_instance.add_documents.return_value = None
|
||||
|
||||
mock_base_knowledge_storage.return_value = mock_storage_instance
|
||||
|
||||
with patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as MockKnowledge:
|
||||
mock_knowledge_instance = MockKnowledge.return_value
|
||||
mock_knowledge_instance.sources = [string_source]
|
||||
mock_knowledge_instance.query.return_value = [{"content": content}]
|
||||
with patch.object(Knowledge, "query") as mock_knowledge_query:
|
||||
agent = Agent(
|
||||
role="Information Agent",
|
||||
@@ -1894,27 +1898,15 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_defau
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
knowledge_config = KnowledgeConfig()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as mock_knowledge_storage,
|
||||
patch(
|
||||
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
|
||||
) as mock_base_knowledge_storage,
|
||||
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
|
||||
):
|
||||
mock_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_storage_instance.sources = [string_source]
|
||||
mock_storage_instance.query.return_value = [{"content": content}]
|
||||
mock_storage_instance.save.return_value = None
|
||||
|
||||
mock_chromadb_instance = mock_chromadb.return_value
|
||||
mock_chromadb_instance.add_documents.return_value = None
|
||||
|
||||
mock_base_knowledge_storage.return_value = mock_storage_instance
|
||||
|
||||
with patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as MockKnowledge:
|
||||
mock_knowledge_instance = MockKnowledge.return_value
|
||||
mock_knowledge_instance.sources = [string_source]
|
||||
mock_knowledge_instance.query.return_value = [{"content": content}]
|
||||
with patch.object(Knowledge, "query") as mock_knowledge_query:
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
knowledge_config = KnowledgeConfig()
|
||||
agent = Agent(
|
||||
role="Information Agent",
|
||||
goal="Provide information based on knowledge sources",
|
||||
@@ -1943,16 +1935,10 @@ def test_agent_with_knowledge_sources_extensive_role():
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
|
||||
with (
|
||||
patch("crewai.knowledge") as mock_knowledge,
|
||||
patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage.save"
|
||||
) as mock_save,
|
||||
):
|
||||
mock_knowledge_instance = mock_knowledge.return_value
|
||||
with patch("crewai.knowledge") as MockKnowledge:
|
||||
mock_knowledge_instance = MockKnowledge.return_value
|
||||
mock_knowledge_instance.sources = [string_source]
|
||||
mock_knowledge_instance.query.return_value = [{"content": content}]
|
||||
mock_save.return_value = None
|
||||
|
||||
agent = Agent(
|
||||
role="Information Agent with extensive role description that is longer than 80 characters",
|
||||
@@ -1982,8 +1968,8 @@ def test_agent_with_knowledge_sources_works_with_copy():
|
||||
with patch(
|
||||
"crewai.knowledge.source.base_knowledge_source.BaseKnowledgeSource",
|
||||
autospec=True,
|
||||
) as mock_knowledge_source:
|
||||
mock_knowledge_source_instance = mock_knowledge_source.return_value
|
||||
) as MockKnowledgeSource:
|
||||
mock_knowledge_source_instance = MockKnowledgeSource.return_value
|
||||
mock_knowledge_source_instance.__class__ = BaseKnowledgeSource
|
||||
mock_knowledge_source_instance.sources = [string_source]
|
||||
|
||||
@@ -1997,9 +1983,9 @@ def test_agent_with_knowledge_sources_works_with_copy():
|
||||
|
||||
with patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as mock_knowledge_storage:
|
||||
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
|
||||
agent.knowledge_storage = mock_knowledge_storage_instance
|
||||
) as MockKnowledgeStorage:
|
||||
mock_knowledge_storage = MockKnowledgeStorage.return_value
|
||||
agent.knowledge_storage = mock_knowledge_storage
|
||||
|
||||
agent_copy = agent.copy()
|
||||
|
||||
@@ -2018,30 +2004,11 @@ def test_agent_with_knowledge_sources_generate_search_query():
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
|
||||
with (
|
||||
patch("crewai.knowledge") as mock_knowledge,
|
||||
patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as mock_knowledge_storage,
|
||||
patch(
|
||||
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
|
||||
) as mock_base_knowledge_storage,
|
||||
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
|
||||
):
|
||||
mock_knowledge_instance = mock_knowledge.return_value
|
||||
with patch("crewai.knowledge") as MockKnowledge:
|
||||
mock_knowledge_instance = MockKnowledge.return_value
|
||||
mock_knowledge_instance.sources = [string_source]
|
||||
mock_knowledge_instance.query.return_value = [{"content": content}]
|
||||
|
||||
mock_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_storage_instance.sources = [string_source]
|
||||
mock_storage_instance.query.return_value = [{"content": content}]
|
||||
mock_storage_instance.save.return_value = None
|
||||
|
||||
mock_chromadb_instance = mock_chromadb.return_value
|
||||
mock_chromadb_instance.add_documents.return_value = None
|
||||
|
||||
mock_base_knowledge_storage.return_value = mock_storage_instance
|
||||
|
||||
agent = Agent(
|
||||
role="Information Agent with extensive role description that is longer than 80 characters",
|
||||
goal="Provide information based on knowledge sources",
|
||||
@@ -2303,26 +2270,7 @@ def test_get_knowledge_search_query():
|
||||
i18n = I18N()
|
||||
task_prompt = task.prompt()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as mock_knowledge_storage,
|
||||
patch(
|
||||
"crewai.knowledge.source.base_knowledge_source.KnowledgeStorage"
|
||||
) as mock_base_knowledge_storage,
|
||||
patch("crewai.rag.chromadb.client.ChromaDBClient") as mock_chromadb,
|
||||
patch.object(agent, "_get_knowledge_search_query") as mock_get_query,
|
||||
):
|
||||
mock_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_storage_instance.sources = [string_source]
|
||||
mock_storage_instance.query.return_value = [{"content": content}]
|
||||
mock_storage_instance.save.return_value = None
|
||||
|
||||
mock_chromadb_instance = mock_chromadb.return_value
|
||||
mock_chromadb_instance.add_documents.return_value = None
|
||||
|
||||
mock_base_knowledge_storage.return_value = mock_storage_instance
|
||||
|
||||
with patch.object(agent, "_get_knowledge_search_query") as mock_get_query:
|
||||
mock_get_query.return_value = "Capital of France"
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
@@ -2364,9 +2312,9 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
# Mock embedchain initialization to prevent race conditions in parallel CI execution
|
||||
with patch("embedchain.client.Client.setup"):
|
||||
from crewai_tools import (
|
||||
EnterpriseActionTool,
|
||||
FileReadTool,
|
||||
SerperDevTool,
|
||||
FileReadTool,
|
||||
EnterpriseActionTool,
|
||||
)
|
||||
|
||||
mock_get_response = MagicMock()
|
||||
@@ -2399,7 +2347,7 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
tool_action = EnterpriseActionTool(
|
||||
name="test_name",
|
||||
description="test_description",
|
||||
enterprise_action_token="test_token", # noqa: S106
|
||||
enterprise_action_token="test_token",
|
||||
action_name="test_action_name",
|
||||
action_schema={"test": "test"},
|
||||
)
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
# ruff: noqa: S101
|
||||
# mypy: ignore-errors
|
||||
from collections import defaultdict
|
||||
from typing import cast
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai import LLM, Agent
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import LiteAgentExecutionStartedEvent
|
||||
from crewai.events.types.tool_usage_events import ToolUsageStartedEvent
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.tools import BaseTool
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
# A simple test tool
|
||||
@@ -38,9 +37,10 @@ class WebSearchTool(BaseTool):
|
||||
# This is a mock implementation
|
||||
if "tokyo" in query.lower():
|
||||
return "Tokyo's population in 2023 was approximately 21 million people in the city proper, and 37 million in the greater metropolitan area."
|
||||
if "climate change" in query.lower() and "coral" in query.lower():
|
||||
elif "climate change" in query.lower() and "coral" in query.lower():
|
||||
return "Climate change severely impacts coral reefs through: 1) Ocean warming causing coral bleaching, 2) Ocean acidification reducing calcification, 3) Sea level rise affecting light availability, 4) Increased storm frequency damaging reef structures. Sources: NOAA Coral Reef Conservation Program, Global Coral Reef Alliance."
|
||||
return f"Found information about {query}: This is a simulated search result for demonstration purposes."
|
||||
else:
|
||||
return f"Found information about {query}: This is a simulated search result for demonstration purposes."
|
||||
|
||||
|
||||
# Define Mock Calculator Tool
|
||||
@@ -53,11 +53,10 @@ class CalculatorTool(BaseTool):
|
||||
def _run(self, expression: str) -> str:
|
||||
"""Calculate the result of a mathematical expression."""
|
||||
try:
|
||||
# Using eval with restricted builtins for test purposes only
|
||||
result = eval(expression, {"__builtins__": {}}) # noqa: S307
|
||||
result = eval(expression, {"__builtins__": {}})
|
||||
return f"The result of {expression} is {result}"
|
||||
except Exception as e:
|
||||
return f"Error calculating {expression}: {e!s}"
|
||||
return f"Error calculating {expression}: {str(e)}"
|
||||
|
||||
|
||||
# Define a custom response format using Pydantic
|
||||
@@ -149,12 +148,12 @@ def test_lite_agent_with_tools():
|
||||
"What is the population of Tokyo and how many people would that be per square kilometer if Tokyo's area is 2,194 square kilometers?"
|
||||
)
|
||||
|
||||
assert "21 million" in result.raw or "37 million" in result.raw, (
|
||||
"Agent should find Tokyo's population"
|
||||
)
|
||||
assert "per square kilometer" in result.raw, (
|
||||
"Agent should calculate population density"
|
||||
)
|
||||
assert (
|
||||
"21 million" in result.raw or "37 million" in result.raw
|
||||
), "Agent should find Tokyo's population"
|
||||
assert (
|
||||
"per square kilometer" in result.raw
|
||||
), "Agent should calculate population density"
|
||||
|
||||
received_events = []
|
||||
|
||||
@@ -295,7 +294,6 @@ def test_sets_parent_flow_when_inside_flow():
|
||||
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.call.return_value = "Test response"
|
||||
mock_llm.stop = []
|
||||
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
|
||||
personal goal is: Test goal\nTo give my best complete final answer to the task
|
||||
respond using the exact following format:\n\nThought: I now can give a great
|
||||
answer\nFinal Answer: Your final answer must be the great and the most complete
|
||||
as possible, it must be outcome described.\n\nI MUST use these formats, my job
|
||||
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
|
||||
the world\n\nThis is the expected criteria for your final answer: hello world\nyou
|
||||
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
|
||||
This is VERY important to you, use the tools available and give your best Final
|
||||
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
|
||||
["\nObservation:"]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '825'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- _cfuvid=NaXWifUGChHp6Ap1mvfMrNzmO4HdzddrqXkSR9T.hYo-1754508545647-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.93.0
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.93.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- '600.0'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFLBbtswDL37Kzid4yFx46bxbVixtsfssB22wlAl2lEri5okJ+uK/Psg
|
||||
OY3dtQV2MWA+vqf3SD5lAExJVgETWx5EZ3X++SrcY/Hrcec3l+SKP5frm/16Yx92m6/f9mwWGXR3
|
||||
jyI8sz4K6qzGoMgMsHDIA0bVxaq8WCzPyrN5AjqSqCOttSFfUt4po/JiXizz+SpfXBzZW1ICPavg
|
||||
RwYA8JS+0aeR+JtVkLRSpUPveYusOjUBMEc6Vhj3XvnATWCzERRkAppk/QYM7UFwA63aIXBoo23g
|
||||
xu/RAfw0X5ThGj6l/wquUWuawXdyWn6YSjpses9jLNNrPQG4MRR4HEsKc3tEDif7mlrr6M7/Q2WN
|
||||
Mspva4fck4lWfSDLEnrIAG7TmPoXyZl11NlQB3rA9NyiXA16bNzOFD2CgQLXk/qqmL2hV0sMXGk/
|
||||
GTQTXGxRjtRxK7yXiiZANkn92s1b2kNyZdr/kR8BIdAGlLV1KJV4mXhscxiP972205STYebR7ZTA
|
||||
Oih0cRMSG97r4aSYf/QBu7pRpkVnnRruqrF1eT7nzTmW5Zplh+wvAAAA//8DAGKunMhlAwAA
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 980b99a73c1c22c6-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Wed, 17 Sep 2025 21:12:11 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=Ahwkw3J9CDiluZudRgDmybz4FO07eXLz2MQDtkgfct4-1758143531-1.0.1.1-_3e8agfTZW.FPpRMLb1A2nET4OHQEGKNZeGeWT8LIiuSi8R2HWsGsJyueUyzYBYnfHqsfBUO16K1.TkEo2XiqVCaIi6pymeeQxwtXFF1wj8;
|
||||
path=/; expires=Wed, 17-Sep-25 21:42:11 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=iHqLoc_2sNQLMyzfGCLtGol8vf1Y44xirzQJUuUF_TI-1758143531242-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '419'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '609'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '30000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-remaining-project-tokens:
|
||||
- '149999827'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '29999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '149999830'
|
||||
x-ratelimit-reset-project-tokens:
|
||||
- 0s
|
||||
x-ratelimit-reset-requests:
|
||||
- 2ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_ece5f999e09e4c189d38e5bc08b2fad9
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -1,128 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
|
||||
personal goal is: Test goal\nTo give my best complete final answer to the task
|
||||
respond using the exact following format:\n\nThought: I now can give a great
|
||||
answer\nFinal Answer: Your final answer must be the great and the most complete
|
||||
as possible, it must be outcome described.\n\nI MUST use these formats, my job
|
||||
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
|
||||
the world\n\nThis is the expected criteria for your final answer: hello world\nyou
|
||||
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
|
||||
This is VERY important to you, use the tools available and give your best Final
|
||||
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
|
||||
["\nObservation:"]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '825'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.93.0
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.93.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- '600.0'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFJNj9MwEL3nV4x8blBSmrabG0ICViqCCydYRbPOJDHreIztbIFV/zty
|
||||
0m1SPiQukTJv3vN7M/OUAAhVixKE7DDI3ur09dvgfa8P/FO+e/9wa/aHb4cPH9viEw5yK1aRwfdf
|
||||
SYZn1gvJvdUUFJsJlo4wUFTNd8U+32zybD0CPdekI621Id1w2iuj0nW23qTZLs33Z3bHSpIXJXxO
|
||||
AACexm/0aWr6LkrIVs+VnrzHlkR5aQIQjnWsCPRe+YAmiNUMSjaBzGj9FgwfQaKBVj0SILTRNqDx
|
||||
R3IAX8wbZVDDq/G/hI60Zjiy0/VS0FEzeIyhzKD1AkBjOGAcyhjl7oycLuY1t9bxvf+NKhpllO8q
|
||||
R+jZRKM+sBUjekoA7sYhDVe5hXXc21AFfqDxubzYTXpi3s0CfXkGAwfUi/ruPNprvaqmgEr7xZiF
|
||||
RNlRPVPnneBQK14AySL1n27+pj0lV6b9H/kZkJJsoLqyjmolrxPPbY7i6f6r7TLl0bDw5B6VpCoo
|
||||
cnETNTU46OmghP/hA/VVo0xLzjo1XVVjq2KbYbOlorgRySn5BQAA//8DALxsmCBjAwAA
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 980ba79a4ab5f555-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Wed, 17 Sep 2025 21:21:42 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=aMMf0fLckKHz0BLW_2lATxD.7R61uYo1ZVW8aeFbruA-1758144102-1.0.1.1-6EKM3UxpdczoiQ6VpPpqqVnY7ftnXndFRWE4vyTzVcy.CQ4N539D97Wh8Ye9EUAvpUuukhW.r5MznkXq4tPXgCCmEv44RvVz2GBAz_e31h8;
|
||||
path=/; expires=Wed, 17-Sep-25 21:51:42 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=VqrtvU8.QdEHc4.1XXUVmccaCcoj_CiNfI2zhKJoGRs-1758144102566-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '308'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '620'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '30000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-remaining-project-tokens:
|
||||
- '149999827'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '29999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '149999830'
|
||||
x-ratelimit-reset-project-tokens:
|
||||
- 0s
|
||||
x-ratelimit-reset-requests:
|
||||
- 2ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_fa896433021140238115972280c05651
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -1,127 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
|
||||
personal goal is: Test goal\nTo give my best complete final answer to the task
|
||||
respond using the exact following format:\n\nThought: I now can give a great
|
||||
answer\nFinal Answer: Your final answer must be the great and the most complete
|
||||
as possible, it must be outcome described.\n\nI MUST use these formats, my job
|
||||
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Test task\n\nThis
|
||||
is the expected criteria for your final answer: test output\nyou MUST return
|
||||
the actual complete content as the final answer, not a summary.\n\nBegin! This
|
||||
is VERY important to you, use the tools available and give your best Final Answer,
|
||||
your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop": ["\nObservation:"]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '812'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.93.0
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.93.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- '600.0'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFLLbtswELzrKxY8W4WV+JHoVgR95NJD4UvaBgJDrSS2FJclV3bSwP9e
|
||||
kHYsuU2BXghwZ2c4s8vnDEDoWpQgVCdZ9c7kNx+4v7vb7jafnrrPX25/vtObX48f1Q31m+uFmEUG
|
||||
PXxHxS+sN4p6Z5A12QOsPErGqFqsl1fF4nJdzBPQU40m0lrH+YLyXludX8wvFvl8nRdXR3ZHWmEQ
|
||||
JXzNAACe0xl92hofRQlJK1V6DEG2KMpTE4DwZGJFyBB0YGlZzEZQkWW0yfotWNqBkhZavUWQ0Ebb
|
||||
IG3YoQf4Zt9rKw28TfcSNhgYaGA3nAl6bIYgYyg7GDMBpLXEMg4lRbk/IvuTeUOt8/QQ/qCKRlsd
|
||||
usqjDGSj0cDkREL3GcB9GtJwlls4T73jiukHpueK5eKgJ8bdTNDLI8jE0kzqq/XsFb2qRpbahMmY
|
||||
hZKqw3qkjjuRQ61pAmST1H+7eU37kFzb9n/kR0ApdIx15TzWWp0nHts8xq/7r7bTlJNhEdBvtcKK
|
||||
Nfq4iRobOZjD/kV4Cox91WjbondeH35V46rlai6bFS6X1yLbZ78BAAD//wMAZdfoWWMDAAA=
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 980b9e0c5fa516a0-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Wed, 17 Sep 2025 21:15:11 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=w6UZxbAZgYg9EFkKPfrSbMK97MB4jfs7YyvcEmgkvak-1758143711-1.0.1.1-j7YC1nvoMKxYK0T.5G2XDF6TXUCPu_HUs4YO9v65r3NHQFIcOaHbQXX4vqabSgynL2tZy23pbZgD8Cdmxhdw9dp4zkAXhU.imP43_pw4dSE;
|
||||
path=/; expires=Wed, 17-Sep-25 21:45:11 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=ij9Q8tB7sj2GczANlJ7gbXVjj6hMhz1iVb6oGHuRYu8-1758143711202-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '462'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '665'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '30000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-remaining-project-tokens:
|
||||
- '149999830'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '29999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '149999830'
|
||||
x-ratelimit-reset-project-tokens:
|
||||
- 0s
|
||||
x-ratelimit-reset-requests:
|
||||
- 2ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_04536db97c8c4768a200e38c1368c176
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -330,222 +330,4 @@ interactions:
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"input": ["Capital of France"], "model": "text-embedding-3-small", "encoding_format":
|
||||
"base64"}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '96'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- _cfuvid=rvDDZbBWaissP0luvtyuyyAWcPx3AiaoZS9LkAuK4sM-1746636999152-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.93.0
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.93.0
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/embeddings
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAA1R6Ww+yyrbl+/4VK+uV3pGLUFXrjbsISKEgYqfTAUEEBeRWQJ2c/97Br/t094uJ
|
||||
WIZUzVljjjHm/I9//fXX321a5Y/x73/++vtTDuPf/217liVj8vc/f/33f/31119//cfv8/9bmddp
|
||||
nmVlU/yW/34smyxf/v7nL/a/nvzfRf/89XfIFDaOu8OtX0ttiOESiQE+pXaRrnxot/BYFAy5kEUP
|
||||
ONOCJaymU0mOVijT9bPeH/BZXg4kYptvv4S7/QQztojx/ay+Ul5kFR8F+9XEUeizziKFvQ+uIr4T
|
||||
E107h77ZSw1r9zPiUCqSYL4SxUbv5usRy5ZLusgkecBnVXUYOzGbrs25voiDG1c4vEhCTxHz8GCW
|
||||
BAdvL9+sngv6iwv3uxDj7HWgPUkXUIBvFy3YnkpJo7AQIvipwoIEgINatQMnCZbG54EP4NtQGgm1
|
||||
j84HJBDbhxeNI+7DBsboBuS5G27VeLUbGV4TmSEyXDVNYO2qg963GbCJrrazpK3iIk9XzQn1/aGa
|
||||
HvHpDatPvZDj9a5UvL67PpApmxGO6maite0dfTjujIfHLN4F0Dun8IiTkpSc0vyuLRl9hRAz+on4
|
||||
10ruubG6mchYruLEKW2rfYOdyUKFSzKs4XeRsjXNYsmeAMVeXOk9n7hjDu9xnGNTZ57p3DVSDWr5
|
||||
fMXpOdNo4zDGG1m7/QtHFVHoPOe4gw9gDliBlhTQ8z7hRY3aLjneZ4Uu8u3TAfe1ODjU9lrFPfww
|
||||
hvez5uNTtZcBx1Zhglg5CrF231NAhv5ewND/BCQ5FX2/CJdjAW8sKYkyly7lHFe2AdR9HZ9Jc6EC
|
||||
8O8qelzOmcdoegrac2xJ0KLvI8be8UXp28I+fJ/uiCQH4ews2KhKpFTMwxPPo5S+M+Oxh4esr7Hb
|
||||
TDigi6VIaPf8aMS6GVol/PItMgbfA8z3HghaTguYSWJFrmr1AovFCxN6lWOCH8Zy0Cj3tG24rysZ
|
||||
n9r61c8ncz9A32GbqY46xeETxDNIpUswCWHyqRbOfMVIaG0bG4uy9Et2iGSYhTeJqEEyARpwbgEt
|
||||
fzpMxzJXU9bImBBMrPAicq0fwCyP9wiGaWsRR481jeNRCZHD8pG348i7IurdyaE9iRTHha1X0yeE
|
||||
NbicCPCK19j063LnYvSxK5+oxyWshHE3y2h/CD/kfrjXdEaVFIL3y9fwbQWHgApiPMGnM2U4gotf
|
||||
rZ/1nKNkskqcWfRTrXUytNBlfAWr5VD1rak4ETSWm4iVsLxt+8tVwDlaOfWqg7XlfHY6cDjfPGIy
|
||||
4EFJveg+1JsHxmnzzgDl35IHDwafeKabnYLRmXoXamJ1m+ZoDTQara0OOTVZsHr21YoviMHCV9w2
|
||||
Xhyo0Fma0VZh/dXfON9HRTV/7vMAFpFeJxbnZ20Rs28L+x0YiFcKNzDj21P95SexzvRYcVejYNHL
|
||||
3p9w6l15ZzaGYQZWc7gQo/rMKQEr1P+s12KUOD+8gU/9ir193EcVh3mkwnF3eGDjcbS1LV9l8aiE
|
||||
lbfU4hqMtzd6S8X3POFnBStnNPtbjcDbqLGm0n06n77BCkrCD948wBDMDjY8+L7vWRztmSPoLobs
|
||||
/fCAaLN67eeMnxKYHFWJHOpFTvlQEfcwPpWLR7+T6nBPHQ9S9MI6ceprH0yrM/OI8IDBTs6jdNu/
|
||||
DN9WTr3yzmgaX5ATD3fr0/R4tjlWbCmuMlpPOSTqnak0sgPGHo6vm7bhs619j3KhIsKLzDT/8EeV
|
||||
1hpe4nkhF+HtaLPNtTnkXo2LtUA2KUfDD4SscAUY21nSc339iVEYXiSiAuAEfKu9BnSa9hl+HG4i
|
||||
WMTs1aGT3p/w7z7wL7dc0RLgAeMrbKtB08gM7kdXw9nexVrfiZGOPs9uT3R6PmksFWUT1dXxhuWD
|
||||
++l7duVnMGVqiPVdGjpcC+YHhF+uw9bK7bWFrcIYct54xGcd2g6b9t7lT71zGdpRKl2frhTXkeJ9
|
||||
TEnVKDafDxh0yteD3rOp1rgp3uiZ2Tnxjq2vsVdrlKRLvC5Y9vRLT+tFv6DrpZ+JWoHIWVRpfcMt
|
||||
HsSNn2oq3OsgRiqlAbaSl+1wVXSckWv6N5I43Q3wrfdkwHzTHyRWXyxYSb/av/wg3m6t6cJPsJTU
|
||||
lDFxoE2rs6T3ywxD+rxO+8ukpUIowTesh9kh0fhM6dqy1IXtM9OxfBtdwKUjTuA8ePrEcrmiCTwT
|
||||
STCTvSvOZOPjbPjeIvdFHeKJitivfqDP6G7CjBzN/KSNuyV4w21/WAuA3As69GPIvT4uPsx2oq1R
|
||||
zQ+wegGbGLWJtCH83k0QNKHhzVnUpHzifh4ofz1PHmdKpdOuB30AeUY/3jc+Tz2NKrH9xYPkr7Oh
|
||||
Cca7f0PXOb2wdXocA5Zw7AOBV6wSP9bHtPM5wkrIHQdispqv0fNBidD3XJUT3zHYoW7nQyh809Hb
|
||||
H42vI2QobaEa+0+c4XB11putTSA5yhLxrDFKZ41QHd4WDXit8S2rqTo+HvA1zphoPP4GRHncV4jl
|
||||
YzR9g2oE81l+mGLTgpToVcqAWc7LPUpQhDzY2FfA8rkO4Wjw2OOy+FEtmjBc4GMML0S1StXhbOH1
|
||||
RkrTavhxApPGqcevDleddYhy+egBOUmCj2q3GT3JSz/BELGi+wd/5LfHBvRY7WTYIQETXDSPlN3w
|
||||
AgyJhXDEAEgH7RTLcOc5KjlcTueeFtkl+lMPsEISh33czh26z7gnp8K9g29/kic0WAMkh8WgQBh0
|
||||
5gKKy74ix0vLassrjy5o6d4SkfkP0eawuEyQu5V3YufKBcyXbIboy6bvaTl1di+8IzRImSWY09kl
|
||||
PhAI/A6QdbuBKLcG9QMPvrzQTt8nztr6VbHzwJhSFKcslmP9lAp5UXQIjFPiOUTrwWqpX+tP/mtc
|
||||
iPvJCZZQcly+J9rXtirhtexZIJhVTzb8C9aIGy1YfS8x8T7tJ2gMl1/h/az4ntBbRb9AtS7gEoEA
|
||||
O80bgd/9hnyvCuQQJp+eWnp5kXhcatOeQwZdNHy2YWsZPLbXz6tarvfrAz1hKGKv65j+veEV5BhT
|
||||
nahzKpxZqqMVupNRYVvMvJ7/Sv0FbnhHrLpfHHrX5hB+PE8gWos9sAiP3Sy1RnzBxzQK+3l9v2oQ
|
||||
0uyKQ3RuANH8kw6/kxzjbOOXQmSGK/RVTCdwMqtq3Yn6/MtPbL1fOhBOr0uNBll38F2LngHHFZ4O
|
||||
F/4wEcv4qpUQa14Ob2eh8Fgdvx1ivKs3xPWww7kIg2BelEcLLr4tT8vjyqRrVS/RH34h2+2gLexq
|
||||
s7DjvHxipUJK6cangcZxM06uwSWgOHcvcMoDwXvkDy7o+ciZ4XqpP8RRHezwQ/JiID/bxSS6vaJ9
|
||||
X/arQyKBPTGGbglG4SGscPceBhzQQ9ovtzf3hvU1QRPParO27KWwg/zJ2OHj1B6D9ZuZMVQiV8VB
|
||||
9DgFg8a8CpAnB8YTZlvSfvkNtXlf4xv4HgArpW0BNzwm6l37BlRT8g4wnxudqslQg8mLAhOZgoG8
|
||||
9cdHEsRDIKwumnapmPbzQJMJHoKaxcrGX3/4iMDpBYnmfc1KcLsYwvq4PxLtNA2UCF3ng+h10qdd
|
||||
nRFap50aojJPeqx+TnPaSb26R8wcf7G/C0j1zbuOgXs3IB6bjExA90dBB32sF95Mj1YvbN/Rxk9I
|
||||
Yow64HOQFKChzJngJv72q6r2rOSwbIQV5VgFg4dWH5k3scTu+q0Ambs3A1fR8vBt/I7BPBwOCdyJ
|
||||
fEiMV6pro+d9JVCY94FYXzcDvaWXPkpFGeN4qY2eqyJlRnBfH7yPfmTpwOc6I92mMMJX5MQVbYyO
|
||||
lcbbxd/yray2/V/g5/o2cZ7qi7a691eC1FoVf3yhmp+q+YbvCjUkQ/lXowa9qHA9PSBR7WFK5/KR
|
||||
WX/qXYfYCfQ5m13gGLy/2FuPQz/o+rlD6mwM5NZbcs9zXlLD4/1IyHH+wIrGZ1aCA7l88AlWJFjk
|
||||
OjQRMydfD5qDDFojYyL44+PO+lyDRa1OM7zlzh5rytugQh2QGEz6ycAGcvbV+sZ6iXKDuU/rwW/o
|
||||
d+MbwJyHFWdfjnVmfgAJ+O4PDnHzB5cSK/AHwJeJOQmrPFakbNccJRwi2OqRo/G96Tykl1buvPHw
|
||||
clK2ZVAJxITTSJ4chWAKqBPDTb8RV3zNdEnvyQq9em9h88GE/WrRtQU/fXXAg06Xj36SITeKb2Ie
|
||||
P0Ww6GaygmOUO8RwzmUwd/xLRdM9GsnJrJd+0x8h2PALO1nTASp/Yhvm/JvDxySVwAgnkUcbvyWG
|
||||
SC4O6198HlbOtcAGhiCYHkzBI2ZiGqw97RP9jNXThGlkn7APhKPDO8ESAeMqGMTK0Fwt+nOXgy6S
|
||||
NSID4aiRXVyyUGVYi4RZ7lA+oE4CZRrJRNvPvEM13zChJr5u3szEVzoPV8SDa6Iy+CSpaiXgQZog
|
||||
aOgJK2fz7FBWGX3QS7DE2vHm9iy+3VT49mJ54saPowmPq1xA4SSJZLu/gGbJ/EblmwrePn1UYCK7
|
||||
eoYnK2KJ9uNb2KgKVLZZThyLVMEsj+fwd37TaxrELV8T9Yd3WJ51tV+b8+RDJN/DTZ/LgQAPlxYq
|
||||
OU+J+Rk4rVX2hxruibNMHGNeAqGLqIys0rLIvTqCYH6xCgPZ83rFch+9wNAyXAmsTH16/J75gtWF
|
||||
7gPmXmOTw+GLUrryJxPeB93A12BXgpmmvQtxpcX48OHNfnnqeIK/88eTH/XLcjR4KJbu5efngHm3
|
||||
93kwOe+AeJirwLpOwR5BE7XE84vM4YmX+fAwV19vJ5lvSpMskyHoZIpl6B37ae/DCGqWBIjiiZlG
|
||||
iywJYQpG6rH53ajmvBRXsHfPBGNmycDS5EHNFfYhmCStmOliGZMLW2Xq8UkJb9Wq3ezo5/8Qyw7T
|
||||
fmzBPof5x2L/3C9CSz9H3rMQt/yWA77eKwWsDl+RHGuurKZz0/vQct45OZiaXa0JtSZo4zEmj/Rt
|
||||
pbN/vOSIUZ8L0eXqQckS7UNkZfKThPdD3xP541uIpHebeIOwVpueV9FTv2FsVB8/ZfscenCPUYST
|
||||
p2AGwtwNDEj1ycWHwOpTOuj8Bb7qu4LtyXsFlG2tEpYscyVHxaXOUvrnAaaXR01wRnC62rtklTa+
|
||||
hR1hGdI581Ifci2TTNLkKcH8uJ1bBKvHk+SPqA8oLeMc6L6h4OTerRUlju8iIn0+5KkQyVm3+waR
|
||||
LVv48PMn9kfNhsGS7DF2A+/P+6Dn6fuJ9xqxmupvasKyKsNpOZtnjQaXVw43vj7B6kaccfN34GGI
|
||||
NXw5LmxFJLDm6E7BBePzxwros/Yn1LRiSix/jyk7MRIPzVGzscmVWspl+6aAnkk6rL9V2elEET+A
|
||||
wWkFUTAbAbI2UQdTLmwnfsOrmR9oAn5+1MZXNYrIZwbnqSqwBaq4GqD8SX7xxNiBI/3hF6yk4U4y
|
||||
G52BUAdNjN479eEtHybu1+SdXODiFxYxjF2bzk3UxaAy44hc47HqOU6Q9rC7hA/y03NLt+/sP3h0
|
||||
P9xNSoB/lmHER8MEnsvNmdeDPv38S2IEO5XysGFruOHHJLEO2087f+mQJl9sfCx0Lp2E9mVLDfz8
|
||||
H3z9+Y+92T5J0N1TjRejdwKb2e+wB0eZCoo+MPBz3j9JnD4qOvO2HyI218epZ773dGaZvSVKTvbE
|
||||
zmRdnMU2YxNsfi22FokJ6DU/1jDx4yOJDFF0Vv0hJbBdbwdyXIoo5cluWiHiISLKs/s4czHd9/B0
|
||||
OjPeromPlaDloICptzJTtfEvfqsvUHbCK85pvNPWTI8hXKAMiBJBvRLq3TpA53jUsXWm34rW5ZWF
|
||||
9sFsiCwUFqDrTsvBhic/Pzcgh1PRIl5/etNSlhdnLcx1QMHBabB5CWQgiPpeguvudfeks9CDGV01
|
||||
CRJhrxDsxGGwgK8yo9N95TwmONvpli8xOuFBxUotrikVTkYsbviDbZsw2urm5A0vjnojxqbvOYja
|
||||
C9TOk7Pxt5VOh+fDhrITXYlyMPlgIl52gT7lLwTv/XMw9zl0oV0HKsafMtJ6rRH3YNPHHkpz0SEO
|
||||
c6rBFl8PHm4inU971UM/v3njc/1YKr0PkuQBMN78ofVjehbIh0KdOss79/RCwgeaUWdgIypwRa/D
|
||||
Xv35H8T/HCetu+elBze/yqNB+OxX7aaG6B4OAlZrq3MWOC0s2vQrPi4Fn071/ljCULAhdkKxC4Tj
|
||||
3Vp//yd4ifV0OUlaKL5cZ/akMI77OWsECAfxGk/DdQrB/LAqV9rwHpt6V1fLK899wIZPcfps/IGA
|
||||
t9FC/eN+SSSKVT/sv6MJvOvjgI/bfRluJcx/+OsJwuVQkW1/8OcnRZ/hqpEhGAawOMmTyGe3SJfY
|
||||
st9wjdGM5UEBThMOng02PMHpeIrTdZ1SCT7vuYCtpf5Um1/jwWU3lfiw3yf9eH5KJuSr73d68c8d
|
||||
HZ5CfIGJ2EdeluC6H+7+bpB0r9thNb9qwdLg0Yab3+S9SLMC2j3F5Fff8LFdOUCYizTDyzv4bHpP
|
||||
6ac6eXfwgFh92vCkWhRNVCWPNyti2/rH+cMHNj1JsNoFlHuN7gB/eu6w6c25a9Y3PPUePzEgrau5
|
||||
H8sWfs+vkiiS0zgr2U3zD3+J+wVawLMiq4p5qitb/jt0nC+6h+BYnT1287/G5XhioWDjCza+Xgfo
|
||||
O+ImOEGXkOweukDoxFwH0XNRsU+Pbb/A8wMC9TsU5K64VBuDR79CY/QCchIPLaBRtbTIL+Mz9oJz
|
||||
l87v0OrgtH9Zf/g0a+CDBJ4NPhNN2ec9laPXgPKhVLE5JdeKX0x5RUNfM8S5srrDea1Xw+FCCXFD
|
||||
CrRe5NUYNpa/89Ys+VYziBMXuJ5oEWVcdsEqioccNP1RmVB9P1V00BkfIvKZsDKadUA2fQ/XhcrE
|
||||
PJNDwNvzOwTIJQO2ETvR9YVYH8nXu0WwA0+UZRIA4bQcpenr9i9tuJvHHAYnk8XWk+oOnYzWRvVR
|
||||
OmI1HL7OWl4SHsTCqnprBXhtBfPK/PETVLK8gxXoTQt/9ce1rUOwvtxuhTrfdjgUOSXd+jUR4n1o
|
||||
kNvVYPqhOQAIG9ioRFWYkfbV1JVQNdl8i9/ozBoBJpS6XvLKPnlr9BBHe9hO/dMT/UNTzbS7t+Cn
|
||||
f8yBVakwn+UL0voZYVUyBNpMT52Hm77CrtAjbfzohiotXS0RXSqkYNz8Ydg4TE0co7lWa/SwHrAX
|
||||
fEDOF8WmA2aphGqp/GAVnYyK/eknTG4zdrf+Bx837RvWt64izidZKpKhtIOWjn1vTU9cvw4OKP74
|
||||
G27y9SqujqvyT3ycnM+C+aenLkXDTEAzW0B//mckqQGxxPuQbniXw7NPD9P+ZlT93BwoA5Rl4rFN
|
||||
Mgp+fgtqC2AS75kuwXjlpQhKB7qfaIwkrb37wgCUnKWbuXTUtn5OCX563rt7WbVemb6Al0cee4Ji
|
||||
zXRuojL+g/+O63zBuvVv4MjnPTGKwk6XJCgH4BhZNoHhwlS02KEYZucu8sQDklN+5y8t2urptLjE
|
||||
p+MLExaQUztjL3T1YN09CxMpjOYRZ/emv35E9/N7sBeNM/0mQTdB8WR+seXga7pgoy/h7/w3Pkzn
|
||||
PNQiKO/PNn7gRAl+/BSuzXTC0fgEdG57aw+as4mx7J/NXiidaw09XTbxTx99Gc/rwOPb9eTg3phf
|
||||
P6aFy90WPenD3qtl67cARTR1bJH7uZ/rZ7GiD7u+iMWhD5imW9BB/l7zxFJq4AxN1CVQi/rrNOtp
|
||||
nC7P1/AGSuFzePOHqzk9f99w81vxIXqfKf1IdvvTo+QEKxyspfSB0OpmER8SsDrLvm5l2Do0m3a+
|
||||
eXXW2pxldLw7hDjbeQ9bPwKW7TPHbrC8Kirydgw7WUqJcUw9IJCR+n/8Gquw39Xy45PdyaNbf0yv
|
||||
eNeWINSDucUGIUK6/PrV/bd4TWCrH6yUFgXqtM8O6+cdAtMtsHW4a5OAWN6t0pYqPUAoNUzk9bJS
|
||||
/vRVDiRDD6dff41w4/wGh3Z4ELvwTMpLvSr9/GmibPySpgstkYZcxoNKYlLaRUCFW78HywNk6fLh
|
||||
8lL6Puob1hP1E7TNATDQGTST2JE3ONRIPy6MmF3mwQ0vh7h9F1B+zBGJk5et8Y3bmpDgffTzryjX
|
||||
vp0WTiz3msRVJ2CuVWcPHGXhyeH2dCrOlDwVdIXJEwdNu75Tzxcf7YziTDJ2bqu1Fx4Q/v2bCvjP
|
||||
f/311//4TRjUbZZ/tsGAMV/Gf//XqMC/hX8PdfL5/BlDmIakyP/+539PIPz97dv6O/7PsX3nzfD3
|
||||
P38Jf0YN/h7bMfn8P4//tb3oP//1vwAAAP//AwDPjjDU3iAAAA==
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 97d174615cfef96b-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Wed, 10 Sep 2025 19:50:30 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=eYh.U8kiOc9xS0U2L8g4MiopA6w9E7lUuodx4D.rMOU-1757533830-1.0.1.1-YO2od1GbrHRgwOEdJSw3gCcNy8XFBF_O.jT_f8F2z6dWZsBIS7XPLWUpJAzenthO1wXRkx7OZDmVrPCPro2sSj1srJCxCY8KgIwcjw5NWGU;
|
||||
path=/; expires=Wed, 10-Sep-25 20:20:30 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=vkbBikeJy.dDV.o7ZB2HjcJaD_hkp9dDeCEBfHZxG94-1757533830280-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-allow-origin:
|
||||
- '*'
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-model:
|
||||
- text-embedding-3-small
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '172'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
strict-transport-security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
via:
|
||||
- envoy-router-59c745856-z5gxd
|
||||
x-envoy-upstream-service-time:
|
||||
- '267'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '10000000'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '9999996'
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_06f3f9465f1a4af0ae5a4d8a58f19321
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Test Knowledge creation and querying functionality."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -22,7 +23,7 @@ def mock_vector_db():
|
||||
instance = mock.return_value
|
||||
instance.query.return_value = [
|
||||
{
|
||||
"content": "Brandon's favorite color is blue and he likes Mexican food.",
|
||||
"context": "Brandon's favorite color is blue and he likes Mexican food.",
|
||||
"score": 0.9,
|
||||
}
|
||||
]
|
||||
@@ -43,13 +44,13 @@ def test_single_short_string(mock_vector_db):
|
||||
content=content, metadata={"preference": "personal"}
|
||||
)
|
||||
mock_vector_db.sources = [string_source]
|
||||
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
||||
# Perform a query
|
||||
query = "What is Brandon's favorite color?"
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the results contain the expected information
|
||||
assert any("blue" in result["content"].lower() for result in results)
|
||||
assert any("blue" in result["context"].lower() for result in results)
|
||||
# Verify the mock was called
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
@@ -83,14 +84,14 @@ def test_single_2k_character_string(mock_vector_db):
|
||||
content=content, metadata={"preference": "personal"}
|
||||
)
|
||||
mock_vector_db.sources = [string_source]
|
||||
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
||||
|
||||
# Perform a query
|
||||
query = "What is Brandon's favorite movie?"
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the results contain the expected information
|
||||
assert any("inception" in result["content"].lower() for result in results)
|
||||
assert any("inception" in result["context"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -108,7 +109,7 @@ def test_multiple_short_strings(mock_vector_db):
|
||||
|
||||
# Mock the vector db query response
|
||||
mock_vector_db.query.return_value = [
|
||||
{"content": "Brandon has a dog named Max.", "score": 0.9}
|
||||
{"context": "Brandon has a dog named Max.", "score": 0.9}
|
||||
]
|
||||
|
||||
mock_vector_db.sources = string_sources
|
||||
@@ -118,7 +119,7 @@ def test_multiple_short_strings(mock_vector_db):
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("max" in result["content"].lower() for result in results)
|
||||
assert any("max" in result["context"].lower() for result in results)
|
||||
# Verify the mock was called
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
@@ -179,7 +180,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
|
||||
]
|
||||
|
||||
mock_vector_db.sources = string_sources
|
||||
mock_vector_db.query.return_value = [{"content": contents[1], "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"context": contents[1], "score": 0.9}]
|
||||
|
||||
# Perform a query
|
||||
query = "What is Brandon's favorite book?"
|
||||
@@ -187,7 +188,7 @@ def test_multiple_2k_character_strings(mock_vector_db):
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any(
|
||||
"the hitchhiker's guide to the galaxy" in result["content"].lower()
|
||||
"the hitchhiker's guide to the galaxy" in result["context"].lower()
|
||||
for result in results
|
||||
)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
@@ -204,13 +205,13 @@ def test_single_short_file(mock_vector_db, tmpdir):
|
||||
file_paths=[file_path], metadata={"preference": "personal"}
|
||||
)
|
||||
mock_vector_db.sources = [file_source]
|
||||
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
||||
# Perform a query
|
||||
query = "What sport does Brandon like?"
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the results contain the expected information
|
||||
assert any("basketball" in result["content"].lower() for result in results)
|
||||
assert any("basketball" in result["context"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -246,13 +247,13 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
|
||||
file_paths=[file_path], metadata={"preference": "personal"}
|
||||
)
|
||||
mock_vector_db.sources = [file_source]
|
||||
mock_vector_db.query.return_value = [{"content": content, "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
|
||||
# Perform a query
|
||||
query = "What is Brandon's favorite movie?"
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the results contain the expected information
|
||||
assert any("inception" in result["content"].lower() for result in results)
|
||||
assert any("inception" in result["context"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -285,13 +286,13 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
|
||||
]
|
||||
mock_vector_db.sources = file_sources
|
||||
mock_vector_db.query.return_value = [
|
||||
{"content": "Brandon lives in New York.", "score": 0.9}
|
||||
{"context": "Brandon lives in New York.", "score": 0.9}
|
||||
]
|
||||
# Perform a query
|
||||
query = "What city does he reside in?"
|
||||
results = mock_vector_db.query(query)
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("new york" in result["content"].lower() for result in results)
|
||||
assert any("new york" in result["context"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -359,7 +360,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
|
||||
mock_vector_db.sources = file_sources
|
||||
mock_vector_db.query.return_value = [
|
||||
{
|
||||
"content": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
|
||||
"context": "Brandon's favorite book is 'The Hitchhiker's Guide to the Galaxy'.",
|
||||
"score": 0.9,
|
||||
}
|
||||
]
|
||||
@@ -369,7 +370,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any(
|
||||
"the hitchhiker's guide to the galaxy" in result["content"].lower()
|
||||
"the hitchhiker's guide to the galaxy" in result["context"].lower()
|
||||
for result in results
|
||||
)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
@@ -406,14 +407,14 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
|
||||
|
||||
# Combine string and file sources
|
||||
mock_vector_db.sources = string_sources + file_sources
|
||||
mock_vector_db.query.return_value = [{"content": file_contents[1], "score": 0.9}]
|
||||
mock_vector_db.query.return_value = [{"context": file_contents[1], "score": 0.9}]
|
||||
|
||||
# Perform a query
|
||||
query = "What is Brandon's favorite book?"
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("the alchemist" in result["content"].lower() for result in results)
|
||||
assert any("the alchemist" in result["context"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -429,7 +430,7 @@ def test_pdf_knowledge_source(mock_vector_db):
|
||||
)
|
||||
mock_vector_db.sources = [pdf_source]
|
||||
mock_vector_db.query.return_value = [
|
||||
{"content": "crewai create crew latest-ai-development", "score": 0.9}
|
||||
{"context": "crewai create crew latest-ai-development", "score": 0.9}
|
||||
]
|
||||
|
||||
# Perform a query
|
||||
@@ -438,7 +439,7 @@ def test_pdf_knowledge_source(mock_vector_db):
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any(
|
||||
"crewai create crew latest-ai-development" in result["content"].lower()
|
||||
"crewai create crew latest-ai-development" in result["context"].lower()
|
||||
for result in results
|
||||
)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
@@ -466,7 +467,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
|
||||
)
|
||||
mock_vector_db.sources = [csv_source]
|
||||
mock_vector_db.query.return_value = [
|
||||
{"content": "Brandon is 30 years old.", "score": 0.9}
|
||||
{"context": "Brandon is 30 years old.", "score": 0.9}
|
||||
]
|
||||
|
||||
# Perform a query
|
||||
@@ -474,7 +475,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("30" in result["content"] for result in results)
|
||||
assert any("30" in result["context"] for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -501,7 +502,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
|
||||
)
|
||||
mock_vector_db.sources = [json_source]
|
||||
mock_vector_db.query.return_value = [
|
||||
{"content": "Alice lives in Los Angeles.", "score": 0.9}
|
||||
{"context": "Alice lives in Los Angeles.", "score": 0.9}
|
||||
]
|
||||
|
||||
# Perform a query
|
||||
@@ -509,7 +510,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("los angeles" in result["content"].lower() for result in results)
|
||||
assert any("los angeles" in result["context"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -517,7 +518,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
|
||||
"""Test ExcelKnowledgeSource with a simple Excel file."""
|
||||
|
||||
# Create an Excel file with sample data
|
||||
import pandas as pd # type: ignore[import-untyped]
|
||||
import pandas as pd
|
||||
|
||||
excel_data = {
|
||||
"Name": ["Brandon", "Alice", "Bob"],
|
||||
@@ -534,7 +535,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
|
||||
)
|
||||
mock_vector_db.sources = [excel_source]
|
||||
mock_vector_db.query.return_value = [
|
||||
{"content": "Brandon is 30 years old.", "score": 0.9}
|
||||
{"context": "Brandon is 30 years old.", "score": 0.9}
|
||||
]
|
||||
|
||||
# Perform a query
|
||||
@@ -542,7 +543,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
|
||||
results = mock_vector_db.query(query)
|
||||
|
||||
# Assert that the correct information is retrieved
|
||||
assert any("30" in result["content"] for result in results)
|
||||
assert any("30" in result["context"] for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@@ -556,20 +557,20 @@ def test_docling_source(mock_vector_db):
|
||||
mock_vector_db.sources = [docling_source]
|
||||
mock_vector_db.query.return_value = [
|
||||
{
|
||||
"content": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
|
||||
"context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
|
||||
"score": 0.9,
|
||||
}
|
||||
]
|
||||
# Perform a query
|
||||
query = "What is reward hacking?"
|
||||
results = mock_vector_db.query(query)
|
||||
assert any("reward hacking" in result["content"].lower() for result in results)
|
||||
assert any("reward hacking" in result["context"].lower() for result in results)
|
||||
mock_vector_db.query.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.vcr
|
||||
def test_multiple_docling_sources() -> None:
|
||||
urls: list[Path | str] = [
|
||||
def test_multiple_docling_sources():
|
||||
urls: List[Union[Path, str]] = [
|
||||
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
|
||||
"https://lilianweng.github.io/posts/2024-07-07-hallucination/",
|
||||
]
|
||||
|
||||
@@ -1,191 +0,0 @@
"""Tests for Knowledge SearchResult type conversion and integration."""

from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from crewai.knowledge.knowledge import Knowledge  # type: ignore[import-untyped]
from crewai.knowledge.source.string_knowledge_source import (  # type: ignore[import-untyped]
    StringKnowledgeSource,
)
from crewai.knowledge.utils.knowledge_utils import (  # type: ignore[import-untyped]
    extract_knowledge_context,
)


def test_knowledge_query_returns_searchresult() -> None:
    """Test that Knowledge.query returns SearchResult format."""
    with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
        mock_storage = MagicMock()
        mock_storage_class.return_value = mock_storage
        mock_storage.search.return_value = [
            {
                "content": "AI is fascinating",
                "score": 0.9,
                "metadata": {"source": "doc1"},
            },
            {
                "content": "Machine learning rocks",
                "score": 0.8,
                "metadata": {"source": "doc2"},
            },
        ]

        sources = [StringKnowledgeSource(content="Test knowledge content")]
        knowledge = Knowledge(collection_name="test_collection", sources=sources)

        results = knowledge.query(
            ["AI technology"], results_limit=5, score_threshold=0.3
        )

        mock_storage.search.assert_called_once_with(
            ["AI technology"], limit=5, score_threshold=0.3
        )

        assert isinstance(results, list)
        assert len(results) == 2

        for result in results:
            assert isinstance(result, dict)
            assert "content" in result
            assert "score" in result
            assert "metadata" in result

        assert results[0]["content"] == "AI is fascinating"
        assert results[0]["score"] == 0.9
        assert results[1]["content"] == "Machine learning rocks"
        assert results[1]["score"] == 0.8


def test_knowledge_query_with_empty_results() -> None:
    """Test Knowledge.query with empty search results."""
    with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
        mock_storage = MagicMock()
        mock_storage_class.return_value = mock_storage
        mock_storage.search.return_value = []

        sources = [StringKnowledgeSource(content="Test content")]
        knowledge = Knowledge(collection_name="empty_test", sources=sources)

        results = knowledge.query(["nonexistent query"])

        assert isinstance(results, list)
        assert len(results) == 0


def test_extract_knowledge_context_with_searchresult() -> None:
    """Test extract_knowledge_context works with SearchResult format."""
    search_results = [
        {"content": "Python is great for AI", "score": 0.95, "metadata": {}},
        {"content": "Machine learning algorithms", "score": 0.88, "metadata": {}},
        {"content": "Deep learning frameworks", "score": 0.82, "metadata": {}},
    ]

    context = extract_knowledge_context(search_results)

    assert "Additional Information:" in context
    assert "Python is great for AI" in context
    assert "Machine learning algorithms" in context
    assert "Deep learning frameworks" in context

    expected_content = (
        "Python is great for AI\nMachine learning algorithms\nDeep learning frameworks"
    )
    assert expected_content in context


def test_extract_knowledge_context_with_empty_content() -> None:
    """Test extract_knowledge_context handles empty or invalid content."""
    search_results = [
        {"content": "", "score": 0.5, "metadata": {}},
        {"content": None, "score": 0.4, "metadata": {}},
        {"score": 0.3, "metadata": {}},
    ]

    context = extract_knowledge_context(search_results)

    assert context == ""


def test_extract_knowledge_context_filters_invalid_results() -> None:
    """Test that extract_knowledge_context filters out invalid results."""
    search_results: list[dict[str, Any] | None] = [
        {"content": "Valid content 1", "score": 0.9, "metadata": {}},
        {"content": "", "score": 0.8, "metadata": {}},
        {"content": "Valid content 2", "score": 0.7, "metadata": {}},
        None,
        {"content": None, "score": 0.6, "metadata": {}},
    ]

    context = extract_knowledge_context(search_results)

    assert "Additional Information:" in context
    assert "Valid content 1" in context
    assert "Valid content 2" in context
    assert context.count("\n") == 1


@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_storage_exception_handling(
    mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
    """Test Knowledge handles storage exceptions gracefully."""
    mock_storage = MagicMock()
    mock_storage_class.return_value = mock_storage
    mock_storage.search.side_effect = Exception("Storage error")

    sources = [StringKnowledgeSource(content="Test content")]
    knowledge = Knowledge(collection_name="error_test", sources=sources)

    with pytest.raises(ValueError, match="Storage is not initialized"):
        knowledge.storage = None
        knowledge.query(["test query"])


def test_knowledge_add_sources_integration() -> None:
    """Test Knowledge.add_sources integrates properly with storage."""
    with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
        mock_storage = MagicMock()
        mock_storage_class.return_value = mock_storage

        sources = [
            StringKnowledgeSource(content="Content 1"),
            StringKnowledgeSource(content="Content 2"),
        ]
        knowledge = Knowledge(collection_name="add_sources_test", sources=sources)

        knowledge.add_sources()

        for source in sources:
            assert source.storage == mock_storage


def test_knowledge_reset_integration() -> None:
    """Test Knowledge.reset integrates with storage."""
    with patch("crewai.knowledge.knowledge.KnowledgeStorage") as mock_storage_class:
        mock_storage = MagicMock()
        mock_storage_class.return_value = mock_storage

        sources = [StringKnowledgeSource(content="Test content")]
        knowledge = Knowledge(collection_name="reset_test", sources=sources)

        knowledge.reset()

        mock_storage.reset.assert_called_once()


@patch("crewai.rag.config.utils.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.KnowledgeStorage")
def test_knowledge_reset_without_storage(
    mock_storage_class: MagicMock, mock_get_client: MagicMock
) -> None:
    """Test Knowledge.reset raises error when storage is None."""
    sources = [StringKnowledgeSource(content="Test content")]
    knowledge = Knowledge(collection_name="no_storage_test", sources=sources)

    knowledge.storage = None

    with pytest.raises(ValueError, match="Storage is not initialized"):
        knowledge.reset()
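The deleted tests above pin down the contract of extract_knowledge_context: results with empty, None, or missing content are dropped, and the surviving contents are joined under an "Additional Information:" header with one newline between entries. A minimal sketch that would satisfy those assertions (illustrative only, not the actual crewai implementation) is:

from typing import Any


def extract_knowledge_context(results: list[dict[str, Any] | None]) -> str:
    """Illustrative sketch: join non-empty result contents under a header."""
    contents = [
        result["content"]
        for result in results
        if isinstance(result, dict) and result.get("content")
    ]
    if not contents:
        return ""
    return "Additional Information: " + "\n".join(contents)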
@@ -1,196 +0,0 @@
"""Integration tests for KnowledgeStorage RAG client migration."""

from unittest.mock import MagicMock, patch

import pytest

from crewai.knowledge.storage.knowledge_storage import (  # type: ignore[import-untyped]
    KnowledgeStorage,
)


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.create_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_knowledge_storage_uses_rag_client(
    mock_get_embedding: MagicMock,
    mock_create_client: MagicMock,
    mock_get_client: MagicMock,
) -> None:
    """Test that KnowledgeStorage properly integrates with RAG client."""
    mock_client = MagicMock()
    mock_create_client.return_value = mock_client
    mock_get_client.return_value = mock_client
    mock_client.search.return_value = [
        {"content": "test content", "score": 0.9, "metadata": {"source": "test"}}
    ]

    embedder_config = {"provider": "openai", "model": "text-embedding-3-small"}
    storage = KnowledgeStorage(
        embedder=embedder_config, collection_name="test_knowledge"
    )

    mock_create_client.assert_called_once()

    results = storage.search(["test query"], limit=5, score_threshold=0.3)

    mock_get_client.assert_not_called()
    mock_client.search.assert_called_once_with(
        collection_name="knowledge_test_knowledge",
        query="test query",
        limit=5,
        metadata_filter=None,
        score_threshold=0.3,
    )

    assert isinstance(results, list)
    assert len(results) == 1
    assert isinstance(results[0], dict)
    assert "content" in results[0]


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_collection_name_prefixing(mock_get_client: MagicMock) -> None:
    """Test that collection names are properly prefixed."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.return_value = []

    storage = KnowledgeStorage(collection_name="custom_knowledge")
    storage.search(["test"], limit=1)

    mock_client.search.assert_called_once()
    call_kwargs = mock_client.search.call_args.kwargs
    assert call_kwargs["collection_name"] == "knowledge_custom_knowledge"

    mock_client.reset_mock()
    storage_default = KnowledgeStorage()
    storage_default.search(["test"], limit=1)

    call_kwargs = mock_client.search.call_args.kwargs
    assert call_kwargs["collection_name"] == "knowledge"


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_save_documents_integration(mock_get_client: MagicMock) -> None:
    """Test document saving through RAG client."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client

    storage = KnowledgeStorage(collection_name="test_docs")
    documents = ["Document 1 content", "Document 2 content"]

    storage.save(documents)

    mock_client.get_or_create_collection.assert_called_once_with(
        collection_name="knowledge_test_docs"
    )
    mock_client.add_documents.assert_called_once()

    call_kwargs = mock_client.add_documents.call_args.kwargs
    added_docs = call_kwargs["documents"]
    assert len(added_docs) == 2
    assert added_docs[0]["content"] == "Document 1 content"
    assert added_docs[1]["content"] == "Document 2 content"


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_reset_integration(mock_get_client: MagicMock) -> None:
    """Test collection reset through RAG client."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client

    storage = KnowledgeStorage(collection_name="test_reset")
    storage.reset()

    mock_client.delete_collection.assert_called_once_with(
        collection_name="knowledge_test_reset"
    )


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_search_error_handling(mock_get_client: MagicMock) -> None:
    """Test error handling during search operations."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.side_effect = Exception("RAG client error")

    storage = KnowledgeStorage(collection_name="error_test")

    results = storage.search(["test query"])
    assert results == []


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_embedding_configuration_flow(
    mock_get_embedding: MagicMock, mock_get_client: MagicMock
) -> None:
    """Test that embedding configuration flows properly to RAG client."""
    mock_embedding_func = MagicMock()
    mock_get_embedding.return_value = mock_embedding_func
    mock_get_client.return_value = MagicMock()

    embedder_config = {
        "provider": "sentence-transformer",
        "model_name": "all-MiniLM-L6-v2",
    }

    KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")

    mock_get_embedding.assert_called_once_with(embedder_config)


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_query_list_conversion(mock_get_client: MagicMock) -> None:
    """Test that query list is properly converted to string."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.return_value = []

    storage = KnowledgeStorage()

    storage.search(["single query"])
    call_kwargs = mock_client.search.call_args.kwargs
    assert call_kwargs["query"] == "single query"

    mock_client.reset_mock()
    storage.search(["query one", "query two"])
    call_kwargs = mock_client.search.call_args.kwargs
    assert call_kwargs["query"] == "query one query two"


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_metadata_filter_handling(mock_get_client: MagicMock) -> None:
    """Test metadata filter parameter handling."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.return_value = []

    storage = KnowledgeStorage()

    metadata_filter = {"category": "technical", "priority": "high"}
    storage.search(["test"], metadata_filter=metadata_filter)

    call_kwargs = mock_client.search.call_args.kwargs
    assert call_kwargs["metadata_filter"] == metadata_filter

    mock_client.reset_mock()
    storage.search(["test"], metadata_filter=None)

    call_kwargs = mock_client.search.call_args.kwargs
    assert call_kwargs["metadata_filter"] is None


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None:
    """Test specific handling of dimension mismatch errors."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.get_or_create_collection.return_value = None
    mock_client.add_documents.side_effect = Exception("dimension mismatch detected")

    storage = KnowledgeStorage(collection_name="dimension_test")

    with pytest.raises(ValueError, match="Embedding dimension mismatch"):
        storage.save(["test document"])
@@ -1,20 +1,19 @@
from unittest.mock import patch, ANY
from collections import defaultdict
from unittest.mock import ANY, patch

import pytest

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryQueryCompletedEvent,
    MemoryQueryStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveStartedEvent,
)
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.task import Task
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemorySaveStartedEvent,
    MemorySaveCompletedEvent,
    MemoryQueryStartedEvent,
    MemoryQueryCompletedEvent,
)


@pytest.fixture
@@ -39,23 +38,22 @@ def short_term_memory():
def test_short_term_memory_search_events(short_term_memory):
    events = defaultdict(list)

    with patch("crewai.rag.chromadb.client.ChromaDBClient.search", return_value=[]):
    with crewai_event_bus.scoped_handlers():
        with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(MemoryQueryStartedEvent)
        def on_search_started(source, event):
            events["MemoryQueryStartedEvent"].append(event)
            @crewai_event_bus.on(MemoryQueryStartedEvent)
            def on_search_started(source, event):
                events["MemoryQueryStartedEvent"].append(event)

        @crewai_event_bus.on(MemoryQueryCompletedEvent)
        def on_search_completed(source, event):
            events["MemoryQueryCompletedEvent"].append(event)
            @crewai_event_bus.on(MemoryQueryCompletedEvent)
            def on_search_completed(source, event):
                events["MemoryQueryCompletedEvent"].append(event)

        # Call the save method
        short_term_memory.search(
            query="test value",
            limit=3,
            score_threshold=0.35,
        )
            # Call the save method
            short_term_memory.search(
                query="test value",
                limit=3,
                score_threshold=0.35,
            )

    assert len(events["MemoryQueryStartedEvent"]) == 1
    assert len(events["MemoryQueryCompletedEvent"]) == 1
@@ -175,12 +173,12 @@ def test_save_and_search(short_term_memory):

    expected_result = [
        {
            "content": memory.data,
            "context": memory.data,
            "metadata": {"agent": "test_agent"},
            "score": 0.95,
        }
    ]
    with patch.object(ShortTermMemory, "search", return_value=expected_result):
        find = short_term_memory.search("test value", score_threshold=0.01)[0]
        assert find["content"] == memory.data, "Data value mismatch."
        assert find["context"] == memory.data, "Data value mismatch."
        assert find["metadata"]["agent"] == "test_agent", "Agent value mismatch."
@@ -285,43 +285,6 @@ class TestChromaDBClient:
            metadatas=[{"source": "test1"}, {"source": "test2"}],
        )

    def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
        """Test add_documents with documents that have no metadata."""
        mock_collection = Mock()
        mock_chromadb_client.get_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {"content": "Document without metadata"},
            {"content": "Another document", "metadata": None},
            {"content": "Document with metadata", "metadata": {"key": "value"}},
        ]

        client.add_documents(collection_name="test_collection", documents=documents)

        # Verify upsert was called with empty dicts for missing metadata
        mock_collection.upsert.assert_called_once()
        call_args = mock_collection.upsert.call_args
        assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]

    def test_add_documents_all_without_metadata(
        self, client, mock_chromadb_client
    ) -> None:
        """Test add_documents when all documents have no metadata."""
        mock_collection = Mock()
        mock_chromadb_client.get_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {"content": "Document 1"},
            {"content": "Document 2"},
            {"content": "Document 3"},
        ]

        client.add_documents(collection_name="test_collection", documents=documents)

        mock_collection.upsert.assert_called_once()
        call_args = mock_collection.upsert.call_args
        assert call_args[1]["metadatas"] is None

    def test_add_documents_empty_list_raises_error(
        self, client, mock_chromadb_client
    ) -> None:
@@ -395,31 +358,6 @@ class TestChromaDBClient:
            metadatas=[{"source": "test1"}, {"source": "test2"}],
        )

    @pytest.mark.asyncio
    async def test_aadd_documents_without_metadata(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test aadd_documents with documents that have no metadata."""
        mock_collection = AsyncMock()
        mock_async_chromadb_client.get_collection = AsyncMock(
            return_value=mock_collection
        )

        documents: list[BaseRecord] = [
            {"content": "Document without metadata"},
            {"content": "Another document", "metadata": None},
            {"content": "Document with metadata", "metadata": {"key": "value"}},
        ]

        await async_client.aadd_documents(
            collection_name="test_collection", documents=documents
        )

        # Verify upsert was called with empty dicts for missing metadata
        mock_collection.upsert.assert_called_once()
        call_args = mock_collection.upsert.call_args
        assert call_args[1]["metadatas"] == [{}, {}, {"key": "value"}]

    @pytest.mark.asyncio
    async def test_aadd_documents_empty_list_raises_error(
        self, async_client, mock_async_chromadb_client
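The removed tests above describe how missing document metadata was expected to be normalized before the ChromaDB upsert: when no document carries metadata the client passes None, while a mix of present and missing metadata is filled in with empty dicts. A small sketch of that normalization, assuming BaseRecord-style dicts with an optional "metadata" key (illustrative only, not the client's actual code):

from typing import Any


def normalize_metadatas(
    documents: list[dict[str, Any]],
) -> list[dict[str, Any]] | None:
    """Illustrative sketch mirroring the removed tests' expectations."""
    if not any(doc.get("metadata") for doc in documents):
        return None  # no metadata anywhere: pass None to upsert
    # otherwise substitute an empty dict wherever metadata is absent
    return [doc.get("metadata") or {} for doc in documents]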
@@ -1,95 +0,0 @@
"""Tests for ChromaDB utility functions."""

from crewai.rag.chromadb.utils import (
    MAX_COLLECTION_LENGTH,
    MIN_COLLECTION_LENGTH,
    _is_ipv4_pattern,
    _sanitize_collection_name,
)


class TestChromaDBUtils:
    """Test suite for ChromaDB utility functions."""

    def test_sanitize_collection_name_long_name(self) -> None:
        """Test sanitizing a very long collection name."""
        long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
        sanitized = _sanitize_collection_name(long_name)
        assert len(sanitized) <= MAX_COLLECTION_LENGTH
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
        assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)

    def test_sanitize_collection_name_special_chars(self) -> None:
        """Test sanitizing a name with special characters."""
        special_chars = "Agent@123!#$%^&*()"
        sanitized = _sanitize_collection_name(special_chars)
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
        assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)

    def test_sanitize_collection_name_short_name(self) -> None:
        """Test sanitizing a very short name."""
        short_name = "A"
        sanitized = _sanitize_collection_name(short_name)
        assert len(sanitized) >= MIN_COLLECTION_LENGTH
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()

    def test_sanitize_collection_name_bad_ends(self) -> None:
        """Test sanitizing a name with non-alphanumeric start/end."""
        bad_ends = "_Agent_"
        sanitized = _sanitize_collection_name(bad_ends)
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()

    def test_sanitize_collection_name_none(self) -> None:
        """Test sanitizing a None value."""
        sanitized = _sanitize_collection_name(None)
        assert sanitized == "default_collection"

    def test_sanitize_collection_name_ipv4_pattern(self) -> None:
        """Test sanitizing an IPv4 address."""
        ipv4 = "192.168.1.1"
        sanitized = _sanitize_collection_name(ipv4)
        assert sanitized.startswith("ip_")
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
        assert all(c.isalnum() or c in ["_", "-"] for c in sanitized)

    def test_is_ipv4_pattern(self) -> None:
        """Test IPv4 pattern detection."""
        assert _is_ipv4_pattern("192.168.1.1") is True
        assert _is_ipv4_pattern("not.an.ip.address") is False

    def test_sanitize_collection_name_properties(self) -> None:
        """Test that sanitized collection names always meet ChromaDB requirements."""
        test_cases: list[str] = [
            "A" * 100,  # Very long name
            "_start_with_underscore",
            "end_with_underscore_",
            "contains@special#characters",
            "192.168.1.1",  # IPv4 address
            "a" * 2,  # Too short
        ]
        for test_case in test_cases:
            sanitized = _sanitize_collection_name(test_case)
            assert len(sanitized) >= MIN_COLLECTION_LENGTH
            assert len(sanitized) <= MAX_COLLECTION_LENGTH
            assert sanitized[0].isalnum()
            assert sanitized[-1].isalnum()

    def test_sanitize_collection_name_empty_string(self) -> None:
        """Test sanitizing an empty string."""
        sanitized = _sanitize_collection_name("")
        assert sanitized == "default_collection"

    def test_sanitize_collection_name_whitespace_only(self) -> None:
        """Test sanitizing a string with only whitespace."""
        sanitized = _sanitize_collection_name(" ")
        assert (
            sanitized == "a__z"
        )  # Spaces become underscores, padded to meet requirements
        assert len(sanitized) >= MIN_COLLECTION_LENGTH
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
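The deleted tests above encode ChromaDB's collection-name rules: 3 to 63 characters, alphanumeric first and last characters, only alphanumerics, underscores and hyphens, plus special handling for IPv4-looking names and for empty input. A rough sketch of a sanitizer in that spirit follows; the padding choices are hypothetical and the real _sanitize_collection_name may resolve edge cases (such as whitespace-only input) differently:

import re

MIN_COLLECTION_LENGTH = 3   # ChromaDB's documented minimum
MAX_COLLECTION_LENGTH = 63  # ChromaDB's documented maximum
_IPV4 = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")


def sanitize_collection_name(name: str | None) -> str:
    """Illustrative sketch only; padding details differ from the real helper."""
    if not name:
        return "default_collection"
    if _IPV4.match(name):
        name = f"ip_{name}"  # avoid names that look like bare IP addresses
    # keep only alphanumerics, underscores and hyphens, and cap the length
    cleaned = re.sub(r"[^A-Za-z0-9_-]", "_", name)[:MAX_COLLECTION_LENGTH]
    if not cleaned[0].isalnum():
        cleaned = "a" + cleaned[1:]
    if not cleaned[-1].isalnum():
        cleaned = cleaned[:-1] + "z"
    if len(cleaned) < MIN_COLLECTION_LENGTH:
        cleaned = cleaned.ljust(MIN_COLLECTION_LENGTH, "x")
    return cleaned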
@@ -1,250 +0,0 @@
|
||||
"""Enhanced tests for embedding function factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch

import pytest

from crewai.rag.embeddings.factory import (  # type: ignore[import-untyped]
    get_embedding_function,
)
from crewai.rag.embeddings.types import EmbeddingOptions  # type: ignore[import-untyped]


def test_get_embedding_function_default() -> None:
    """Test default embedding function when no config provided."""
    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
        mock_instance = MagicMock()
        mock_openai.return_value = mock_instance

        with patch(
            "crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
        ):
            result = get_embedding_function()

        mock_openai.assert_called_once_with(
            api_key="test-api-key", model_name="text-embedding-3-small"
        )
        assert result == mock_instance


def test_get_embedding_function_with_embedding_options() -> None:
    """Test embedding function creation with EmbeddingOptions object."""
    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
        mock_instance = MagicMock()
        mock_openai.return_value = mock_instance

        options = EmbeddingOptions(
            provider="openai", api_key="test-key", model="text-embedding-3-large"
        )

        result = get_embedding_function(options)

        call_kwargs = mock_openai.call_args.kwargs
        assert "api_key" in call_kwargs
        assert call_kwargs["api_key"].get_secret_value() == "test-key"
        # OpenAI uses model_name parameter, not model
        assert result == mock_instance


def test_get_embedding_function_sentence_transformer() -> None:
    """Test sentence transformer embedding function."""
    with patch(
        "crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction"
    ) as mock_st:
        mock_instance = MagicMock()
        mock_st.return_value = mock_instance

        config = {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"}

        result = get_embedding_function(config)

        mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2")
        assert result == mock_instance


def test_get_embedding_function_ollama() -> None:
    """Test Ollama embedding function."""
    with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as mock_ollama:
        mock_instance = MagicMock()
        mock_ollama.return_value = mock_instance

        config = {
            "provider": "ollama",
            "model_name": "nomic-embed-text",
            "url": "http://localhost:11434",
        }

        result = get_embedding_function(config)

        mock_ollama.assert_called_once_with(
            model_name="nomic-embed-text", url="http://localhost:11434"
        )
        assert result == mock_instance


def test_get_embedding_function_cohere() -> None:
    """Test Cohere embedding function."""
    with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as mock_cohere:
        mock_instance = MagicMock()
        mock_cohere.return_value = mock_instance

        config = {
            "provider": "cohere",
            "api_key": "cohere-key",
            "model_name": "embed-english-v3.0",
        }

        result = get_embedding_function(config)

        mock_cohere.assert_called_once_with(
            api_key="cohere-key", model_name="embed-english-v3.0"
        )
        assert result == mock_instance


def test_get_embedding_function_huggingface() -> None:
    """Test HuggingFace embedding function."""
    with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as mock_hf:
        mock_instance = MagicMock()
        mock_hf.return_value = mock_instance

        config = {
            "provider": "huggingface",
            "api_key": "hf-token",
            "model_name": "sentence-transformers/all-MiniLM-L6-v2",
        }

        result = get_embedding_function(config)

        mock_hf.assert_called_once_with(
            api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        assert result == mock_instance


def test_get_embedding_function_onnx() -> None:
    """Test ONNX embedding function."""
    with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as mock_onnx:
        mock_instance = MagicMock()
        mock_onnx.return_value = mock_instance

        config = {"provider": "onnx"}

        result = get_embedding_function(config)

        mock_onnx.assert_called_once()
        assert result == mock_instance


def test_get_embedding_function_google_palm() -> None:
    """Test Google PaLM embedding function."""
    with patch(
        "crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction"
    ) as mock_palm:
        mock_instance = MagicMock()
        mock_palm.return_value = mock_instance

        config = {"provider": "google-palm", "api_key": "palm-key"}

        result = get_embedding_function(config)

        mock_palm.assert_called_once_with(api_key="palm-key")
        assert result == mock_instance


def test_get_embedding_function_amazon_bedrock() -> None:
    """Test Amazon Bedrock embedding function."""
    with patch(
        "crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction"
    ) as mock_bedrock:
        mock_instance = MagicMock()
        mock_bedrock.return_value = mock_instance

        config = {
            "provider": "amazon-bedrock",
            "region_name": "us-west-2",
            "model_name": "amazon.titan-embed-text-v1",
        }

        result = get_embedding_function(config)

        mock_bedrock.assert_called_once_with(
            region_name="us-west-2", model_name="amazon.titan-embed-text-v1"
        )
        assert result == mock_instance


def test_get_embedding_function_jina() -> None:
    """Test Jina embedding function."""
    with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as mock_jina:
        mock_instance = MagicMock()
        mock_jina.return_value = mock_instance

        config = {
            "provider": "jina",
            "api_key": "jina-key",
            "model_name": "jina-embeddings-v2-base-en",
        }

        result = get_embedding_function(config)

        mock_jina.assert_called_once_with(
            api_key="jina-key", model_name="jina-embeddings-v2-base-en"
        )
        assert result == mock_instance


def test_get_embedding_function_unsupported_provider() -> None:
    """Test handling of unsupported provider."""
    config = {"provider": "unsupported-provider"}

    with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
        get_embedding_function(config)


def test_get_embedding_function_config_modification() -> None:
    """Test that original config dict is not modified."""
    original_config = {
        "provider": "openai",
        "api_key": "test-key",
        "model": "text-embedding-3-small",
    }
    config_copy = original_config.copy()

    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
        get_embedding_function(config_copy)

    assert config_copy == original_config


def test_get_embedding_function_exclude_none_values() -> None:
    """Test that None values are excluded from embedding function calls."""
    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
        mock_instance = MagicMock()
        mock_openai.return_value = mock_instance

        options = EmbeddingOptions(provider="openai", api_key="test-key", model=None)

        result = get_embedding_function(options)

        call_kwargs = mock_openai.call_args.kwargs
        assert "api_key" in call_kwargs
        assert call_kwargs["api_key"].get_secret_value() == "test-key"
        assert "model" not in call_kwargs
        assert result == mock_instance


def test_get_embedding_function_instructor() -> None:
    """Test Instructor embedding function."""
    with patch(
        "crewai.rag.embeddings.factory.InstructorEmbeddingFunction"
    ) as mock_instructor:
        mock_instance = MagicMock()
        mock_instructor.return_value = mock_instance

        config = {"provider": "instructor", "model_name": "hkunlp/instructor-large"}

        result = get_embedding_function(config)

        mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
        assert result == mock_instance
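# A minimal usage sketch, assuming only the call shapes the tests above already
# exercise (a plain dict config vs. a typed EmbeddingOptions); nothing here is
# asserted by the original suite and the values are illustrative placeholders.
def _example_factory_usage_sketch() -> None:
    # Dict-style config, mirroring test_get_embedding_function_ollama.
    ollama_fn = get_embedding_function(
        {
            "provider": "ollama",
            "model_name": "nomic-embed-text",
            "url": "http://localhost:11434",
        }
    )

    # Typed options, mirroring test_get_embedding_function_with_embedding_options.
    openai_fn = get_embedding_function(
        EmbeddingOptions(
            provider="openai", api_key="test-key", model="text-embedding-3-large"
        )
    )

    assert ollama_fn is not None and openai_fn is not None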
@@ -1,218 +0,0 @@
"""Tests for RAG client error handling scenarios."""

from unittest.mock import MagicMock, patch

import pytest

from crewai.knowledge.storage.knowledge_storage import (  # type: ignore[import-untyped]
    KnowledgeStorage,
)
from crewai.memory.storage.rag_storage import RAGStorage  # type: ignore[import-untyped]


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_connection_failure(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles RAG client connection failures."""
    mock_get_client.side_effect = ConnectionError("Unable to connect to ChromaDB")

    storage = KnowledgeStorage(collection_name="connection_test")

    results = storage.search(["test query"])
    assert results == []


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_search_timeout(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles search timeouts gracefully."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.side_effect = TimeoutError("Search operation timed out")

    storage = KnowledgeStorage(collection_name="timeout_test")

    results = storage.search(["test query"])
    assert results == []


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_collection_not_found(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles missing collections."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.side_effect = ValueError(
        "Collection 'knowledge_missing' does not exist"
    )

    storage = KnowledgeStorage(collection_name="missing_collection")

    results = storage.search(["test query"])
    assert results == []


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles invalid embedding configurations."""
    mock_get_client.return_value = MagicMock()

    with patch(
        "crewai.knowledge.storage.knowledge_storage.get_embedding_function"
    ) as mock_get_embedding:
        mock_get_embedding.side_effect = ValueError(
            "Unsupported provider: invalid_provider"
        )

        with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
            KnowledgeStorage(
                embedder={"provider": "invalid_provider"},
                collection_name="invalid_embedding_test",
            )


@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_client_failure(mock_get_client: MagicMock) -> None:
    """Test RAGStorage handles RAG client failures in memory operations."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.side_effect = RuntimeError("ChromaDB server error")

    storage = RAGStorage("short_term", crew=None)

    results = storage.search("test query")
    assert results == []


@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_save_failure(mock_get_client: MagicMock) -> None:
    """Test RAGStorage handles save operation failures."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.add_documents.side_effect = Exception("Failed to add documents")

    storage = RAGStorage("long_term", crew=None)

    storage.save("test memory", {"key": "value"})


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_readonly_database(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage reset handles readonly database errors."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.delete_collection.side_effect = Exception(
        "attempt to write a readonly database"
    )

    storage = KnowledgeStorage(collection_name="readonly_test")

    storage.reset()


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_collection_does_not_exist(
    mock_get_client: MagicMock,
) -> None:
    """Test KnowledgeStorage reset handles non-existent collections."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.delete_collection.side_effect = Exception("Collection does not exist")

    storage = KnowledgeStorage(collection_name="nonexistent_test")

    storage.reset()


@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_reset_failure_propagation(mock_get_client: MagicMock) -> None:
    """Test RAGStorage reset propagates unexpected errors."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.delete_collection.side_effect = Exception("Unexpected database error")

    storage = RAGStorage("entities", crew=None)

    with pytest.raises(
        Exception, match="An error occurred while resetting the entities memory"
    ):
        storage.reset()


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_malformed_search_results(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles malformed search results."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.search.return_value = [
        {"content": "valid result", "metadata": {"source": "test"}},
        {"invalid": "missing content field", "metadata": {"source": "test"}},
        None,
        {"content": None, "metadata": {"source": "test"}},
    ]

    storage = KnowledgeStorage(collection_name="malformed_test")

    results = storage.search(["test query"])

    assert isinstance(results, list)
    assert len(results) == 4


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_network_interruption(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles network interruptions during operations."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client

    mock_client.search.side_effect = [
        ConnectionError("Network interruption"),
        [{"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}],
    ]

    storage = KnowledgeStorage(collection_name="network_test")

    first_attempt = storage.search(["test query"])
    assert first_attempt == []

    mock_client.search.side_effect = None
    mock_client.search.return_value = [
        {"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}
    ]

    second_attempt = storage.search(["test query"])
    assert len(second_attempt) == 1
    assert second_attempt[0]["content"] == "recovered result"


@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_collection_creation_failure(mock_get_client: MagicMock) -> None:
    """Test RAGStorage handles collection creation failures."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.get_or_create_collection.side_effect = Exception(
        "Failed to create collection"
    )

    storage = RAGStorage("user_memory", crew=None)

    storage.save("test data", {"metadata": "test"})


@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_embedding_dimension_mismatch_detailed(
    mock_get_client: MagicMock,
) -> None:
    """Test detailed handling of embedding dimension mismatch errors."""
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client
    mock_client.get_or_create_collection.return_value = None
    mock_client.add_documents.side_effect = Exception(
        "Embedding dimension mismatch: expected 384, got 1536"
    )

    storage = KnowledgeStorage(collection_name="dimension_detailed_test")

    with pytest.raises(ValueError) as exc_info:
        storage.save(["test document"])

    assert "Embedding dimension mismatch" in str(exc_info.value)
    assert "Make sure you're using the same embedding model" in str(exc_info.value)
    assert "crewai reset-memories -a" in str(exc_info.value)
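# A minimal sketch of the graceful-degradation pattern the tests above assert
# (search errors collapse to an empty result list); the class and method names
# here are illustrative and do not reproduce the actual KnowledgeStorage code.
class _GracefulSearchSketch:
    def __init__(self, client) -> None:
        self.client = client

    def search(self, queries: list[str]) -> list[dict]:
        try:
            return self.client.search(queries)
        except Exception:
            # Connection errors, timeouts, and missing collections all fall
            # back to an empty result set instead of raising to the caller.
            return []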
@@ -1,7 +1,8 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from mem0 import Memory, MemoryClient
|
||||
from mem0.client.main import MemoryClient
|
||||
from mem0.memory.main import Memory
|
||||
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
@@ -12,71 +13,11 @@ class MockCrew:
|
||||
self.agents = [MagicMock(role="Test Agent")]
|
||||
|
||||
|
||||
# Test data constants
|
||||
SYSTEM_CONTENT = (
|
||||
"You are Friendly chatbot assistant. You are a kind and "
|
||||
"knowledgeable chatbot assistant. You excel at understanding user needs, "
|
||||
"providing helpful responses, and maintaining engaging conversations. "
|
||||
"You remember previous interactions to provide a personalized experience.\n"
|
||||
"Your personal goal is: Engage in useful and interesting conversations "
|
||||
"with users while remembering context.\n"
|
||||
"To give my best complete final answer to the task respond using the exact "
|
||||
"following format:\n\n"
|
||||
"Thought: I now can give a great answer\n"
|
||||
"Final Answer: Your final answer must be the great and the most complete "
|
||||
"as possible, it must be outcome described.\n\n"
|
||||
"I MUST use these formats, my job depends on it!"
|
||||
)
|
||||
|
||||
USER_CONTENT = (
|
||||
"\nCurrent Task: Respond to user conversation. User message: "
|
||||
"What do you know about me?\n\n"
|
||||
"This is the expected criteria for your final answer: Contextually "
|
||||
"appropriate, helpful, and friendly response.\n"
|
||||
"you MUST return the actual complete content as the final answer, "
|
||||
"not a summary.\n\n"
|
||||
"# Useful context: \nExternal memories:\n"
|
||||
"- User is from India\n"
|
||||
"- User is interested in the solar system\n"
|
||||
"- User name is Vidit Ostwal\n"
|
||||
"- User is interested in French cuisine\n\n"
|
||||
"Begin! This is VERY important to you, use the tools available and give "
|
||||
"your best Final Answer, your job depends on it!\n\n"
|
||||
"Thought:"
|
||||
)
|
||||
|
||||
ASSISTANT_CONTENT = (
|
||||
"I now can give a great answer \n"
|
||||
"Final Answer: Hi Vidit! From our previous conversations, I know you're "
|
||||
"from India and have a great interest in the solar system. It's fascinating "
|
||||
"to explore the wonders of space, isn't it? Also, I remember you have a "
|
||||
"passion for French cuisine, which has so many delightful dishes to explore. "
|
||||
"If there's anything specific you'd like to discuss or learn about—whether "
|
||||
"it's about the solar system or some great French recipes—feel free to let "
|
||||
"me know! I'm here to help."
|
||||
)
|
||||
|
||||
TEST_DESCRIPTION = (
|
||||
"Respond to user conversation. User message: What do you know about me?"
|
||||
)
|
||||
|
||||
# Extracted content (after processing by _get_user_message and _get_assistant_message)
|
||||
EXTRACTED_USER_CONTENT = "What do you know about me?"
|
||||
EXTRACTED_ASSISTANT_CONTENT = (
|
||||
"Hi Vidit! From our previous conversations, I know you're "
|
||||
"from India and have a great interest in the solar system. It's fascinating "
|
||||
"to explore the wonders of space, isn't it? Also, I remember you have a "
|
||||
"passion for French cuisine, which has so many delightful dishes to explore. "
|
||||
"If there's anything specific you'd like to discuss or learn about—whether "
|
||||
"it's about the solar system or some great French recipes—feel free to let "
|
||||
"me know! I'm here to help."
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory():
|
||||
"""Fixture to create a mock Memory instance"""
|
||||
return MagicMock(spec=Memory)
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
return mock_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -84,9 +25,7 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
|
||||
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
|
||||
|
||||
# Patch the Memory class to return our mock
|
||||
with patch(
|
||||
"mem0.Memory.from_config", return_value=mock_mem0_memory
|
||||
) as mock_from_config:
|
||||
with patch("mem0.memory.main.Memory.from_config", return_value=mock_mem0_memory) as mock_from_config:
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "mock_vector_store",
|
||||
@@ -117,14 +56,7 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
|
||||
# Parameters like run_id, includes, and excludes doesn't matter in Memory OSS
|
||||
crew = MockCrew()
|
||||
|
||||
embedder_config = {
|
||||
"user_id": "test_user",
|
||||
"local_mem0_config": config,
|
||||
"run_id": "my_run_id",
|
||||
"includes": "include1",
|
||||
"excludes": "exclude1",
|
||||
"infer": True,
|
||||
}
|
||||
embedder_config={"user_id": "test_user", "local_mem0_config": config, "run_id": "my_run_id", "includes": "include1","excludes": "exclude1", "infer" : True}
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=embedder_config)
|
||||
return mem0_storage, mock_from_config, config
|
||||
@@ -141,7 +73,8 @@ def test_mem0_storage_initialization(mem0_storage_with_mocked_config, mock_mem0_
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory_client():
|
||||
"""Fixture to create a mock MemoryClient instance"""
|
||||
return MagicMock(spec=MemoryClient)
|
||||
mock_memory = MagicMock(spec=MemoryClient)
|
||||
return mock_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -152,35 +85,34 @@ def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_clie
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
crew = MockCrew()
|
||||
|
||||
embedder_config = {
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"run_id": "my_run_id",
|
||||
"includes": "include1",
|
||||
"excludes": "exclude1",
|
||||
"infer": True,
|
||||
}
|
||||
embedder_config={
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"run_id": "my_run_id",
|
||||
"includes": "include1",
|
||||
"excludes": "exclude1",
|
||||
"infer": True
|
||||
}
|
||||
|
||||
return Mem0Storage(type="short_term", crew=crew, config=embedder_config)
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=embedder_config)
|
||||
return mem0_storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mem0_storage_with_memory_client_using_explictly_config(
|
||||
mock_mem0_memory_client, mock_mem0_memory
|
||||
):
|
||||
def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client, mock_mem0_memory):
|
||||
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
|
||||
|
||||
# We need to patch both MemoryClient and Memory to prevent actual initialization
|
||||
with (
|
||||
patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client),
|
||||
patch.object(Memory, "__new__", return_value=mock_mem0_memory),
|
||||
):
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client), \
|
||||
patch.object(Memory, "__new__", return_value=mock_mem0_memory):
|
||||
|
||||
crew = MockCrew()
|
||||
new_config = {"provider": "mem0", "config": {"api_key": "new-api-key"}}
|
||||
|
||||
return Mem0Storage(type="short_term", crew=crew, config=new_config)
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=new_config)
|
||||
return mem0_storage
|
||||
|
||||
|
||||
def test_mem0_storage_with_memory_client_initialization(
|
||||
@@ -210,23 +142,18 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
|
||||
mock_mem0_memory_client.update_project = MagicMock()
|
||||
|
||||
new_categories = [
|
||||
{
|
||||
"lifestyle_management_concerns": (
|
||||
"Tracks daily routines, habits, hobbies and interests "
|
||||
"including cooking, time management and work-life balance"
|
||||
)
|
||||
},
|
||||
{"lifestyle_management_concerns": "Tracks daily routines, habits, hobbies and interests including cooking, time management and work-life balance"},
|
||||
]
|
||||
|
||||
crew = MockCrew()
|
||||
|
||||
config = {
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"custom_categories": new_categories,
|
||||
}
|
||||
config={
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"custom_categories": new_categories
|
||||
}
|
||||
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
_ = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
@@ -236,6 +163,8 @@ def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_cl
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
"""Test save method for different memory types"""
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
@@ -243,134 +172,68 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
|
||||
# Test short_term memory type (already set in fixture)
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
test_metadata = {"key": "value"}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
[{"role": "assistant" , "content": test_value}],
|
||||
infer=True,
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
metadata={"type": "short_term", "key": "value"},
|
||||
run_id="my_run_id",
|
||||
user_id="test_user",
|
||||
agent_id="Test_Agent",
|
||||
agent_id='Test_Agent'
|
||||
)
|
||||
|
||||
|
||||
def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config):
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
mem0_storage.crew.agents = [
|
||||
MagicMock(role="Test Agent"),
|
||||
MagicMock(role="Test Agent 2"),
|
||||
MagicMock(role="Test Agent 3"),
|
||||
]
|
||||
mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")]
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
test_metadata = {"key": "value"}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
[{"role": "assistant" , "content": test_value}],
|
||||
infer=True,
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
metadata={"type": "short_term", "key": "value"},
|
||||
run_id="my_run_id",
|
||||
user_id="test_user",
|
||||
agent_id="Test_Agent_Test_Agent_2_Test_Agent_3",
|
||||
agent_id='Test_Agent_Test_Agent_2_Test_Agent_3'
|
||||
)
|
||||
|
||||
|
||||
def test_save_method_with_memory_client(
|
||||
mem0_storage_with_memory_client_using_config_from_crew,
|
||||
):
|
||||
def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
|
||||
"""Test save method for different memory types"""
|
||||
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
# Test short_term memory type (already set in fixture)
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
test_metadata = {"key": "value"}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
[{'role': 'assistant' , 'content': test_value}],
|
||||
infer=True,
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
metadata={"type": "short_term", "key": "value"},
|
||||
version="v2",
|
||||
run_id="my_run_id",
|
||||
includes="include1",
|
||||
excludes="exclude1",
|
||||
output_format="v1.1",
|
||||
user_id="test_user",
|
||||
agent_id="Test_Agent",
|
||||
output_format='v1.1',
|
||||
user_id='test_user',
|
||||
agent_id='Test_Agent'
|
||||
)
|
||||
|
||||
|
||||
def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
"""Test search method for different memory types"""
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
@@ -379,25 +242,18 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
query="test query",
|
||||
limit=5,
|
||||
user_id="test_user",
|
||||
filters={"AND": [{"run_id": "my_run_id"}]},
|
||||
threshold=0.5,
|
||||
filters={'AND': [{'run_id': 'my_run_id'}]},
|
||||
threshold=0.5
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content"] == "Result 1"
|
||||
assert results[0]["context"] == "Result 1"
|
||||
|
||||
|
||||
def test_search_method_with_memory_client(
|
||||
mem0_storage_with_memory_client_using_config_from_crew,
|
||||
):
|
||||
def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
|
||||
"""Test search method for different memory types"""
|
||||
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
@@ -407,15 +263,15 @@ def test_search_method_with_memory_client(
|
||||
limit=5,
|
||||
metadata={"type": "short_term"},
|
||||
user_id="test_user",
|
||||
version="v2",
|
||||
version='v2',
|
||||
run_id="my_run_id",
|
||||
output_format="v1.1",
|
||||
filters={"AND": [{"run_id": "my_run_id"}]},
|
||||
threshold=0.5,
|
||||
output_format='v1.1',
|
||||
filters={'AND': [{'run_id': 'my_run_id'}]},
|
||||
threshold=0.5
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content"] == "Result 1"
|
||||
assert results[0]["context"] == "Result 1"
|
||||
|
||||
|
||||
def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
|
||||
@@ -423,12 +279,14 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
crew = MockCrew()
|
||||
|
||||
config = {"user_id": "test_user", "api_key": "ABCDEFGH"}
|
||||
config={
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH"
|
||||
}
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
assert mem0_storage.infer is True
|
||||
|
||||
|
||||
def test_save_memory_using_agent_entity(mock_mem0_memory_client):
|
||||
config = {
|
||||
"agent_id": "agent-123",
|
||||
@@ -439,25 +297,19 @@ def test_save_memory_using_agent_entity(mock_mem0_memory_client):
|
||||
mem0_storage = Mem0Storage(type="external", config=config)
|
||||
mem0_storage.save("test memory", {"key": "value"})
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{"role": "assistant", "content": "test memory"}],
|
||||
[{'role': 'assistant' , 'content': 'test memory'}],
|
||||
infer=True,
|
||||
metadata={"type": "external", "key": "value"},
|
||||
agent_id="agent-123",
|
||||
)
|
||||
|
||||
|
||||
def test_search_method_with_agent_entity():
|
||||
config = {
|
||||
"agent_id": "agent-123",
|
||||
}
|
||||
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
|
||||
|
||||
with patch.object(Memory, "__new__", return_value=mock_memory):
|
||||
mem0_storage = Mem0Storage(type="external", config=config)
|
||||
@@ -466,29 +318,22 @@ def test_search_method_with_agent_entity():
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
filters={"AND": [{"agent_id": "agent-123"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
query="test query",
|
||||
limit=5,
|
||||
filters={"AND": [{"agent_id": "agent-123"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content"] == "Result 1"
|
||||
assert results[0]["context"] == "Result 1"
|
||||
|
||||
|
||||
def test_search_method_with_agent_id_and_user_id():
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
|
||||
|
||||
with patch.object(Memory, "__new__", return_value=mock_memory):
|
||||
mem0_storage = Mem0Storage(
|
||||
type="external", config={"agent_id": "agent-123", "user_id": "user-123"}
|
||||
)
|
||||
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"})
|
||||
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
@@ -496,10 +341,10 @@ def test_search_method_with_agent_id_and_user_id():
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
user_id="user-123",
|
||||
user_id='user-123',
|
||||
filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content"] == "Result 1"
|
||||
assert results[0]["context"] == "Result 1"
|
||||
|
||||
@@ -1,216 +0,0 @@
# ruff: noqa: S105

import os
import pytest
from unittest.mock import patch

from crewai.context import (
    set_platform_integration_token,
    get_platform_integration_token,
    platform_context,
    _platform_integration_token,
)


class TestPlatformIntegrationToken:
    def setup_method(self):
        _platform_integration_token.set(None)

    def teardown_method(self):
        _platform_integration_token.set(None)

    def test_set_platform_integration_token(self):
        test_token = "test-token-123"

        assert get_platform_integration_token() is None

        set_platform_integration_token(test_token)

        assert get_platform_integration_token() == test_token

    def test_get_platform_integration_token_from_context_var(self):
        test_token = "context-var-token"

        _platform_integration_token.set(test_token)

        assert get_platform_integration_token() == test_token

    @patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token-456"})
    def test_get_platform_integration_token_from_env_var(self):
        assert _platform_integration_token.get() is None

        assert get_platform_integration_token() == "env-token-456"

    @patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token"})
    def test_context_var_takes_precedence_over_env_var(self):
        context_token = "context-token"

        set_platform_integration_token(context_token)

        assert get_platform_integration_token() == context_token

    @patch.dict(os.environ, {}, clear=True)
    def test_get_platform_integration_token_returns_none_when_not_set(self):
        assert _platform_integration_token.get() is None

        assert get_platform_integration_token() is None

    def test_platform_context_manager_basic_usage(self):
        test_token = "context-manager-token"

        assert get_platform_integration_token() is None

        with platform_context(test_token):
            assert get_platform_integration_token() == test_token

        assert get_platform_integration_token() is None

    def test_platform_context_manager_nested_contexts(self):
        """Test nested platform_context context managers."""
        outer_token = "outer-token"
        inner_token = "inner-token"

        assert get_platform_integration_token() is None

        with platform_context(outer_token):
            assert get_platform_integration_token() == outer_token

            with platform_context(inner_token):
                assert get_platform_integration_token() == inner_token

            assert get_platform_integration_token() == outer_token

        assert get_platform_integration_token() is None

    def test_platform_context_manager_preserves_existing_token(self):
        """Test that platform_context preserves existing token when exiting."""
        initial_token = "initial-token"
        context_token = "context-token"

        set_platform_integration_token(initial_token)
        assert get_platform_integration_token() == initial_token

        with platform_context(context_token):
            assert get_platform_integration_token() == context_token

        assert get_platform_integration_token() == initial_token

    def test_platform_context_manager_exception_handling(self):
        """Test that platform_context properly resets token even when exception occurs."""
        initial_token = "initial-token"
        context_token = "context-token"

        set_platform_integration_token(initial_token)

        with pytest.raises(ValueError):
            with platform_context(context_token):
                assert get_platform_integration_token() == context_token
                raise ValueError("Test exception")

        assert get_platform_integration_token() == initial_token

    def test_platform_context_manager_with_none_initial_state(self):
        """Test platform_context when initial state is None."""
        context_token = "context-token"

        assert get_platform_integration_token() is None

        with pytest.raises(RuntimeError):
            with platform_context(context_token):
                assert get_platform_integration_token() == context_token
                raise RuntimeError("Test exception")

        assert get_platform_integration_token() is None

    @patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-backup"})
    def test_platform_context_with_env_fallback(self):
        """Test platform_context interaction with environment variable fallback."""
        context_token = "context-token"

        assert get_platform_integration_token() == "env-backup"

        with platform_context(context_token):
            assert get_platform_integration_token() == context_token

        assert get_platform_integration_token() == "env-backup"

    def test_multiple_sequential_context_managers(self):
        """Test multiple sequential uses of platform_context."""
        token1 = "token-1"
        token2 = "token-2"
        token3 = "token-3"

        with platform_context(token1):
            assert get_platform_integration_token() == token1

        assert get_platform_integration_token() is None

        with platform_context(token2):
            assert get_platform_integration_token() == token2

        assert get_platform_integration_token() is None

        with platform_context(token3):
            assert get_platform_integration_token() == token3

        assert get_platform_integration_token() is None

    def test_empty_string_token(self):
        empty_token = ""

        set_platform_integration_token(empty_token)
        assert get_platform_integration_token() == ""

        with platform_context(empty_token):
            assert get_platform_integration_token() == ""

    def test_special_characters_in_token(self):
        special_token = "token-with-!@#$%^&*()_+-={}[]|\\:;\"'<>?,./"

        set_platform_integration_token(special_token)
        assert get_platform_integration_token() == special_token

        with platform_context(special_token):
            assert get_platform_integration_token() == special_token

    def test_very_long_token(self):
        long_token = "a" * 10000

        set_platform_integration_token(long_token)
        assert get_platform_integration_token() == long_token

        with platform_context(long_token):
            assert get_platform_integration_token() == long_token

    @patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": ""})
    def test_empty_env_var(self):
        assert _platform_integration_token.get() is None
        assert get_platform_integration_token() == ""

    @patch('crewai.context.os.getenv')
    def test_env_var_access_error_handling(self, mock_getenv):
        mock_getenv.side_effect = OSError("Environment access error")

        with pytest.raises(OSError):
            get_platform_integration_token()

    def test_context_var_isolation_between_tests(self):
        """Test that context variable changes don't leak between test methods."""
        test_token = "isolation-test-token"

        assert get_platform_integration_token() is None

        set_platform_integration_token(test_token)
        assert get_platform_integration_token() == test_token

    def test_context_manager_return_value(self):
        """Test that platform_context can be used in with statement with return value."""
        test_token = "return-value-token"

        with platform_context(test_token):
            assert get_platform_integration_token() == test_token

        with platform_context(test_token) as ctx:
            assert ctx is None
            assert get_platform_integration_token() == test_token
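# A minimal sketch of the ContextVar pattern the tests above exercise, assuming
# a token held in a contextvars.ContextVar with an os.environ fallback; the
# names below are illustrative and do not reproduce the actual crewai.context
# module, only the behavior asserted in this file.
import os as _os_sketch
from contextlib import contextmanager
from contextvars import ContextVar

_token_sketch: ContextVar[str | None] = ContextVar("_token_sketch", default=None)


def get_token_sketch() -> str | None:
    # Context variable wins; otherwise fall back to the environment variable,
    # mirroring the precedence asserted in the tests above.
    value = _token_sketch.get()
    if value is not None:
        return value
    return _os_sketch.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")


@contextmanager
def token_context_sketch(token: str):
    # Set the token for the duration of the block and restore the previous
    # value afterwards, even if the block raises.
    reset = _token_sketch.set(token)
    try:
        yield
    finally:
        _token_sketch.reset(reset)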
@@ -1,11 +1,11 @@
|
||||
"""Test Agent creation and execution basic functionality."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from hashlib import md5
|
||||
from unittest import mock
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
from collections import defaultdict
|
||||
|
||||
import pydantic_core
|
||||
import pytest
|
||||
@@ -14,29 +14,11 @@ from crewai.agent import Agent
|
||||
from crewai.agents import CacheHandler
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.process import Process
|
||||
@@ -45,9 +27,28 @@ from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
|
||||
|
||||
from crewai.events.types.memory_events import (
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
)
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ceo():
|
||||
@@ -363,7 +364,7 @@ def test_hierarchical_process(researcher, writer):
|
||||
|
||||
assert (
|
||||
result.raw
|
||||
== "**1. The Rise of Autonomous AI Agents in Daily Life** \nAs artificial intelligence technology progresses, the integration of autonomous AI agents into everyday life becomes increasingly prominent. These agents, capable of making decisions without human intervention, are reshaping industries from healthcare to finance. Exploring case studies where autonomous AI has successfully decreased operational costs or improved efficiency can reveal not only the benefits but also the ethical implications of delegating decision-making to machines. This topic offers an exciting opportunity to dive into the AI landscape, showcasing current developments such as AI assistants and autonomous vehicles.\n\n**2. Ethical Implications of Generative AI in Creative Industries** \nThe surge of generative AI tools in creative fields, such as art, music, and writing, has sparked a heated debate about authorship and originality. This article could investigate how these tools are being used by artists and creators, examining both the potential for innovation and the risk of devaluing traditional art forms. Highlighting perspectives from creators, legal experts, and ethicists could provide a comprehensive overview of the challenges faced, including copyright concerns and the emotional impact on human artists. This discussion is vital as the creative landscape evolves alongside technological advancements, making it ripe for exploration.\n\n**3. AI in Climate Change Mitigation: Current Solutions and Future Potential** \nAs the world grapples with climate change, AI technology is increasingly being harnessed to develop innovative solutions for sustainability. From predictive analytics that optimize energy consumption to machine learning algorithms that improve carbon capture methods, AI's potential in environmental science is vast. This topic invites an exploration of existing AI applications in climate initiatives, with a focus on groundbreaking research and initiatives aimed at reducing humanity's carbon footprint. Highlighting successful projects and technology partnerships can illustrate the positive impact AI can have on global climate efforts, inspiring further exploration and investment in this area.\n\n**4. The Future of Work: How AI is Reshaping Employment Landscapes** \nThe discussions around AI's impact on the workforce are both urgent and complex, as advances in automation and machine learning continue to transform the job market. This article could delve into the current trends of AI-driven job displacement alongside opportunities for upskilling and the creation of new job roles. By examining case studies of companies that integrate AI effectively and the resulting workforce adaptations, readers can gain valuable insights into preparing for a future where humans and AI collaborate. This exploration highlights the importance of policies that promote workforce resilience in the face of change.\n\n**5. Decentralized AI: Exploring the Role of Blockchain in AI Development** \nAs blockchain technology sweeps through various sectors, its application in AI development presents a fascinating topic worth examining. Decentralized AI could address issues of data privacy, security, and democratization in AI models by allowing users to retain ownership of data while benefiting from AI's capabilities. This article could analyze how decentralized networks are disrupting traditional AI development models, featuring innovative projects that harness the synergy between blockchain and AI. 
Highlighting potential pitfalls and the future landscape of decentralized AI could stimulate discussion among technologists, entrepreneurs, and policymakers alike.\n\nThese topics not only reflect current trends but also probe deeper into ethical and practical considerations, making them timely and relevant for contemporary audiences."
|
||||
== "1. **The Rise of Autonomous AI Agents in Daily Life** \n As artificial intelligence technology progresses, the integration of autonomous AI agents into everyday life becomes increasingly prominent. These agents, capable of making decisions without human intervention, are reshaping industries from healthcare to finance. Exploring case studies where autonomous AI has successfully decreased operational costs or improved efficiency can reveal not only the benefits but also the ethical implications of delegating decision-making to machines. This topic offers an exciting opportunity to dive into the AI landscape, showcasing current developments such as AI assistants and autonomous vehicles.\n\n2. **Ethical Implications of Generative AI in Creative Industries** \n The surge of generative AI tools in creative fields, such as art, music, and writing, has sparked a heated debate about authorship and originality. This article could investigate how these tools are being used by artists and creators, examining both the potential for innovation and the risk of devaluing traditional art forms. Highlighting perspectives from creators, legal experts, and ethicists could provide a comprehensive overview of the challenges faced, including copyright concerns and the emotional impact on human artists. This discussion is vital as the creative landscape evolves alongside technological advancements, making it ripe for exploration.\n\n3. **AI in Climate Change Mitigation: Current Solutions and Future Potential** \n As the world grapples with climate change, AI technology is increasingly being harnessed to develop innovative solutions for sustainability. From predictive analytics that optimize energy consumption to machine learning algorithms that improve carbon capture methods, AI's potential in environmental science is vast. This topic invites an exploration of existing AI applications in climate initiatives, with a focus on groundbreaking research and initiatives aimed at reducing humanity's carbon footprint. Highlighting successful projects and technology partnerships can illustrate the positive impact AI can have on global climate efforts, inspiring further exploration and investment in this area.\n\n4. **The Future of Work: How AI is Reshaping Employment Landscapes** \n The discussions around AI's impact on the workforce are both urgent and complex, as advances in automation and machine learning continue to transform the job market. This article could delve into the current trends of AI-driven job displacement alongside opportunities for upskilling and the creation of new job roles. By examining case studies of companies that integrate AI effectively and the resulting workforce adaptations, readers can gain valuable insights into preparing for a future where humans and AI collaborate. This exploration highlights the importance of policies that promote workforce resilience in the face of change.\n\n5. **Decentralized AI: Exploring the Role of Blockchain in AI Development** \n As blockchain technology sweeps through various sectors, its application in AI development presents a fascinating topic worth examining. Decentralized AI could address issues of data privacy, security, and democratization in AI models by allowing users to retain ownership of data while benefiting from AI's capabilities. This article could analyze how decentralized networks are disrupting traditional AI development models, featuring innovative projects that harness the synergy between blockchain and AI. 
Highlighting potential pitfalls and the future landscape of decentralized AI could stimulate discussion among technologists, entrepreneurs, and policymakers alike."
|
||||
)
|
||||
|
||||
|
||||
@@ -569,6 +570,8 @@ def test_crew_with_delegating_agents(ceo, writer):
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer):
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -581,7 +584,7 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
|
||||
class TestTool(BaseTool):
|
||||
name: str = "Test Tool"
|
||||
description: str = "A test tool that just returns the input"
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
@@ -619,16 +622,18 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
|
||||
_, kwargs = mock_execute_sync.call_args
|
||||
tools = kwargs["tools"]
|
||||
|
||||
assert any(isinstance(tool, TestTool) for tool in tools), (
|
||||
"TestTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in tools
|
||||
), "TestTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer):
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -641,7 +646,7 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
|
||||
class TestTool(BaseTool):
|
||||
name: str = "Test Tool"
|
||||
description: str = "A test tool that just returns the input"
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
@@ -681,16 +686,18 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
|
||||
_, kwargs = mock_execute_sync.call_args
|
||||
tools = kwargs["tools"]
|
||||
|
||||
assert any(isinstance(tool, TestTool) for tool in new_ceo.tools), (
|
||||
"TestTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in new_ceo.tools
|
||||
), "TestTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_task_tools_override_agent_tools(researcher):
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -703,7 +710,7 @@ def test_task_tools_override_agent_tools(researcher):
|
||||
class TestTool(BaseTool):
|
||||
name: str = "Test Tool"
|
||||
description: str = "A test tool that just returns the input"
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
@@ -711,7 +718,7 @@ def test_task_tools_override_agent_tools(researcher):
|
||||
class AnotherTestTool(BaseTool):
|
||||
name: str = "Another Test Tool"
|
||||
description: str = "Another test tool"
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Another processed: {query}"
|
||||
@@ -747,6 +754,7 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
|
||||
"""
|
||||
Test that task tools override agent tools while preserving delegation tools when allow_delegation=True
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -758,7 +766,7 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
|
||||
class TestTool(BaseTool):
|
||||
name: str = "Test Tool"
|
||||
description: str = "A test tool that just returns the input"
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
@@ -766,7 +774,7 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
|
||||
class AnotherTestTool(BaseTool):
|
||||
name: str = "Another Test Tool"
|
||||
description: str = "Another test tool"
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Another processed: {query}"
|
||||
@@ -807,17 +815,17 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Confirm AnotherTestTool is present but TestTool is not
|
||||
assert any(isinstance(tool, AnotherTestTool) for tool in used_tools), (
|
||||
"AnotherTestTool should be present"
|
||||
)
|
||||
assert not any(isinstance(tool, TestTool) for tool in used_tools), (
|
||||
"TestTool should not be present among used tools"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, AnotherTestTool) for tool in used_tools
|
||||
), "AnotherTestTool should be present"
|
||||
assert not any(
|
||||
isinstance(tool, TestTool) for tool in used_tools
|
||||
), "TestTool should not be present among used tools"
|
||||
|
||||
# Confirm delegation tool(s) are present
|
||||
assert any("delegate" in tool.name.lower() for tool in used_tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in used_tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
# Finally, make sure the agent's original tools remain unchanged
|
||||
assert len(researcher_with_delegation.tools) == 1
|
||||
@@ -921,9 +929,9 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
        tool="multiplier", input={"first_number": 2, "second_number": 6}
    )
    assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
    assert cache_calls[1] == expected_call, (
        f"Second call mismatch: {cache_calls[1]}"
    )
    assert (
        cache_calls[1] == expected_call
    ), f"Second call mismatch: {cache_calls[1]}"

@pytest.mark.vcr(filter_headers=["authorization"])
@@ -1034,7 +1042,7 @@ def test_crew_kickoff_streaming_usage_metrics():
    assert result.token_usage.cached_prompt_tokens == 0


def test_agents_rpm_is_never_set_if_crew_max_rpm_is_not_set():
def test_agents_rpm_is_never_set_if_crew_max_RPM_is_not_set():
    agent = Agent(
        role="test role",
        goal="test goal",
@@ -1387,9 +1395,8 @@ def test_kickoff_for_each_error_handling():
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
|
||||
with patch.object(Crew, "kickoff") as mock_kickoff:
|
||||
mock_kickoff.side_effect = [
|
||||
*expected_outputs[:2],
|
||||
Exception("Simulated kickoff error"),
|
||||
mock_kickoff.side_effect = expected_outputs[:2] + [
|
||||
Exception("Simulated kickoff error")
|
||||
]
|
||||
with pytest.raises(Exception, match="Simulated kickoff error"):
|
||||
crew.kickoff_for_each(inputs=inputs)
|
||||
@@ -1667,9 +1674,9 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff():
|
||||
|
||||
# Verify that exactly one tool was used and it was a CodeInterpreterTool
|
||||
assert len(used_tools) == 1, "Should have exactly one tool"
|
||||
assert isinstance(used_tools[0], CodeInterpreterTool), (
|
||||
"Tool should be CodeInterpreterTool"
|
||||
)
|
||||
assert isinstance(
|
||||
used_tools[0], CodeInterpreterTool
|
||||
), "Tool should be CodeInterpreterTool"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -1753,10 +1760,10 @@ def test_agent_usage_metrics_are_captured_for_hierarchical_process():
|
||||
assert result.raw == "Howdy!"
|
||||
|
||||
assert result.token_usage == UsageMetrics(
|
||||
total_tokens=1673,
|
||||
prompt_tokens=1562,
|
||||
completion_tokens=111,
|
||||
successful_requests=3,
|
||||
total_tokens=2390,
|
||||
prompt_tokens=2264,
|
||||
completion_tokens=126,
|
||||
successful_requests=4,
|
||||
cached_prompt_tokens=0,
|
||||
)
|
||||
|
||||
@@ -2172,7 +2179,8 @@ def test_tools_with_custom_caching():
        return first_number * second_number

    def cache_func(args, result):
        return result % 2 == 0
        cache = result % 2 == 0
        return cache

    multiplcation_tool.cache_function = cache_func

@@ -2876,7 +2884,7 @@ def test_manager_agent_with_tools_raises_exception(researcher, writer):
        tasks=[task],
    )

    with pytest.raises(Exception, match="Manager agent should not have tools"):
    with pytest.raises(Exception):
        crew.kickoff()


@@ -3100,7 +3108,7 @@ def test_crew_task_db_init():
        db_handler.load()
        assert True # If we reach this point, no exception was raised
    except Exception as e:
        pytest.fail(f"An exception was raised: {e!s}")
        pytest.fail(f"An exception was raised: {str(e)}")

@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -3486,9 +3494,8 @@ def test_key(researcher, writer):
|
||||
process=Process.sequential,
|
||||
tasks=tasks,
|
||||
)
|
||||
hash = md5(
|
||||
f"{researcher.key}|{writer.key}|{tasks[0].key}|{tasks[1].key}".encode(),
|
||||
usedforsecurity=False,
|
||||
hash = hashlib.md5(
|
||||
f"{researcher.key}|{writer.key}|{tasks[0].key}|{tasks[1].key}".encode()
|
||||
).hexdigest()
|
||||
|
||||
assert crew.key == hash
|
||||
@@ -3527,9 +3534,8 @@ def test_key_with_interpolated_inputs():
|
||||
process=Process.sequential,
|
||||
tasks=tasks,
|
||||
)
|
||||
hash = md5(
|
||||
f"{researcher.key}|{writer.key}|{tasks[0].key}|{tasks[1].key}".encode(),
|
||||
usedforsecurity=False,
|
||||
hash = hashlib.md5(
|
||||
f"{researcher.key}|{writer.key}|{tasks[0].key}|{tasks[1].key}".encode()
|
||||
).hexdigest()
|
||||
|
||||
assert crew.key == hash
|
||||
@@ -3809,15 +3815,16 @@ def test_fetch_inputs():
|
||||
expected_placeholders = {"role_detail", "topic", "field"}
|
||||
actual_placeholders = crew.fetch_inputs()
|
||||
|
||||
assert actual_placeholders == expected_placeholders, (
|
||||
f"Expected {expected_placeholders}, but got {actual_placeholders}"
|
||||
)
|
||||
assert (
|
||||
actual_placeholders == expected_placeholders
|
||||
), f"Expected {expected_placeholders}, but got {actual_placeholders}"
|
||||
|
||||
|
||||
def test_task_tools_preserve_code_execution_tools():
|
||||
"""
|
||||
Test that task tools don't override code execution tools when allow_code_execution=True
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
# Mock embedchain initialization to prevent race conditions in parallel CI execution
|
||||
with patch("embedchain.client.Client.setup"):
|
||||
@@ -3834,7 +3841,7 @@ def test_task_tools_preserve_code_execution_tools():
|
||||
class TestTool(BaseTool):
|
||||
name: str = "Test Tool"
|
||||
description: str = "A test tool that just returns the input"
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
@@ -3885,20 +3892,20 @@ def test_task_tools_preserve_code_execution_tools():
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Verify all expected tools are present
|
||||
assert any(isinstance(tool, TestTool) for tool in used_tools), (
|
||||
"Task's TestTool should be present"
|
||||
)
|
||||
assert any(isinstance(tool, CodeInterpreterTool) for tool in used_tools), (
|
||||
"CodeInterpreterTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in used_tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in used_tools
|
||||
), "Task's TestTool should be present"
|
||||
assert any(
|
||||
isinstance(tool, CodeInterpreterTool) for tool in used_tools
|
||||
), "CodeInterpreterTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in used_tools
|
||||
), "Delegation tool should be present"
|
||||
|
||||
# Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools)
|
||||
assert len(used_tools) == 4, (
|
||||
"Should have TestTool, CodeInterpreter, and 2 delegation tools"
|
||||
)
|
||||
assert (
|
||||
len(used_tools) == 4
|
||||
), "Should have TestTool, CodeInterpreter, and 2 delegation tools"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -3942,9 +3949,9 @@ def test_multimodal_flag_adds_multimodal_tools():
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Check that the multimodal tool was added
|
||||
assert any(isinstance(tool, AddImageTool) for tool in used_tools), (
|
||||
"AddImageTool should be present when agent is multimodal"
|
||||
)
|
||||
assert any(
|
||||
isinstance(tool, AddImageTool) for tool in used_tools
|
||||
), "AddImageTool should be present when agent is multimodal"
|
||||
|
||||
# Verify we have exactly one tool (just the AddImageTool)
|
||||
assert len(used_tools) == 1, "Should only have the AddImageTool"
|
||||
@@ -4208,9 +4215,9 @@ def test_crew_guardrail_feedback_in_context():
|
||||
assert len(execution_contexts) > 1, "Task should have been executed multiple times"
|
||||
|
||||
# Verify that the second execution included the guardrail feedback
|
||||
assert "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1], (
|
||||
"Guardrail feedback should be included in retry context"
|
||||
)
|
||||
assert (
|
||||
"Output must contain the keyword 'IMPORTANT'" in execution_contexts[1]
|
||||
), "Guardrail feedback should be included in retry context"
|
||||
|
||||
# Verify final output meets guardrail requirements
|
||||
assert "IMPORTANT" in result.raw, "Final output should contain required keyword"
|
||||
@@ -4225,11 +4232,13 @@ def test_before_kickoff_callback():
|
||||
|
||||
@CrewBase
|
||||
class TestCrewClass:
|
||||
from typing import List
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.project import CrewBase, agent, before_kickoff, crew, task
|
||||
|
||||
agents: list[BaseAgent]
|
||||
tasks: list[Task]
|
||||
agents: List[BaseAgent]
|
||||
tasks: List[Task]
|
||||
|
||||
agents_config = None
|
||||
tasks_config = None
|
||||
@@ -4253,11 +4262,12 @@ def test_before_kickoff_callback():
|
||||
|
||||
@task
|
||||
def my_task(self):
|
||||
return Task(
|
||||
task = Task(
|
||||
description="Test task description",
|
||||
expected_output="Test expected output",
|
||||
agent=self.my_agent(),
|
||||
)
|
||||
return task
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
@@ -4423,46 +4433,46 @@ def test_crew_copy_with_memory():
|
||||
try:
|
||||
crew_copy = crew.copy()
|
||||
|
||||
assert hasattr(crew_copy, "_short_term_memory"), (
|
||||
"Copied crew should have _short_term_memory"
|
||||
)
|
||||
assert crew_copy._short_term_memory is not None, (
|
||||
"Copied _short_term_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._short_term_memory) != original_short_term_id, (
|
||||
"Copied _short_term_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_short_term_memory"
|
||||
), "Copied crew should have _short_term_memory"
|
||||
assert (
|
||||
crew_copy._short_term_memory is not None
|
||||
), "Copied _short_term_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._short_term_memory) != original_short_term_id
|
||||
), "Copied _short_term_memory should be a new object"
|
||||
|
||||
assert hasattr(crew_copy, "_long_term_memory"), (
|
||||
"Copied crew should have _long_term_memory"
|
||||
)
|
||||
assert crew_copy._long_term_memory is not None, (
|
||||
"Copied _long_term_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._long_term_memory) != original_long_term_id, (
|
||||
"Copied _long_term_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_long_term_memory"
|
||||
), "Copied crew should have _long_term_memory"
|
||||
assert (
|
||||
crew_copy._long_term_memory is not None
|
||||
), "Copied _long_term_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._long_term_memory) != original_long_term_id
|
||||
), "Copied _long_term_memory should be a new object"
|
||||
|
||||
assert hasattr(crew_copy, "_entity_memory"), (
|
||||
"Copied crew should have _entity_memory"
|
||||
)
|
||||
assert crew_copy._entity_memory is not None, (
|
||||
"Copied _entity_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._entity_memory) != original_entity_id, (
|
||||
"Copied _entity_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_entity_memory"
|
||||
), "Copied crew should have _entity_memory"
|
||||
assert (
|
||||
crew_copy._entity_memory is not None
|
||||
), "Copied _entity_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._entity_memory) != original_entity_id
|
||||
), "Copied _entity_memory should be a new object"
|
||||
|
||||
if original_external_id:
|
||||
assert hasattr(crew_copy, "_external_memory"), (
|
||||
"Copied crew should have _external_memory"
|
||||
)
|
||||
assert crew_copy._external_memory is not None, (
|
||||
"Copied _external_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._external_memory) != original_external_id, (
|
||||
"Copied _external_memory should be a new object"
|
||||
)
|
||||
assert hasattr(
|
||||
crew_copy, "_external_memory"
|
||||
), "Copied crew should have _external_memory"
|
||||
assert (
|
||||
crew_copy._external_memory is not None
|
||||
), "Copied _external_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._external_memory) != original_external_id
|
||||
), "Copied _external_memory should be a new object"
|
||||
else:
|
||||
assert (
|
||||
not hasattr(crew_copy, "_external_memory")
|
||||
@@ -4725,25 +4735,21 @@ def test_ensure_exchanged_messages_are_propagated_to_external_memory():
|
||||
) as external_memory_save:
|
||||
crew.kickoff()
|
||||
|
||||
external_memory_save.assert_called_once()
|
||||
|
||||
call_args = external_memory_save.call_args
|
||||
|
||||
assert "value" in call_args.kwargs or len(call_args.args) > 0
|
||||
assert "metadata" in call_args.kwargs or len(call_args.args) > 1
|
||||
|
||||
if "metadata" in call_args.kwargs:
|
||||
metadata = call_args.kwargs["metadata"]
|
||||
else:
|
||||
metadata = call_args.args[1]
|
||||
|
||||
assert "description" in metadata
|
||||
assert "messages" in metadata
|
||||
assert isinstance(metadata["messages"], list)
|
||||
assert len(metadata["messages"]) >= 2
|
||||
|
||||
messages = metadata["messages"]
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "Researcher" in messages[0]["content"]
|
||||
assert messages[1]["role"] == "user"
|
||||
assert "Research a topic to teach a kid aged 6 about math" in messages[1]["content"]
|
||||
expected_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are Researcher. You're an expert in research and you love to learn new things.\nYour personal goal is: You research about math.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "\nCurrent Task: Research a topic to teach a kid aged 6 about math.\n\nThis is the expected criteria for your final answer: A topic, explanation, angle, and examples.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I now can give a great answer \nFinal Answer: \n\n**Topic: Understanding Shapes (Geometry)**\n\n**Explanation:** \nShapes are everywhere around us! They are the special forms that we can see in everyday objects. Teaching a 6-year-old about shapes is not only fun but also a way to help them think about the world around them and develop their spatial awareness. We will focus on basic shapes: circle, square, triangle, and rectangle. Understanding these shapes helps kids recognize and describe their environment.\n\n**Angle:** \nLet’s make learning about shapes an adventure! We can turn it into a treasure hunt where the child has to find objects around the house or outside that match the shapes we learn. This hands-on approach helps make the learning stick!\n\n**Examples:** \n1. **Circle:** \n - Explanation: A circle is round and has no corners. It looks like a wheel or a cookie! \n - Activity: Find objects that are circles, such as a clock, a dinner plate, or a ball. Draw a big circle on a paper and then try to draw smaller circles inside it.\n\n2. **Square:** \n - Explanation: A square has four equal sides and four corners. It looks like a box! \n - Activity: Look for squares in books, in windows, or in building blocks. Try to build a tall tower using square blocks!\n\n3. **Triangle:** \n - Explanation: A triangle has three sides and three corners. It looks like a slice of pizza or a roof! \n - Activity: Use crayons to draw a big triangle and then find things that are shaped like a triangle, like a slice of cheese or a traffic sign.\n\n4. **Rectangle:** \n - Explanation: A rectangle has four sides but only opposite sides are equal. It’s like a stretched square! \n - Activity: Search for rectangles, such as a book cover or a door. You can cut out rectangles from colored paper and create a collage!\n\nBy relating the shapes to fun activities and using real-world examples, we not only make learning more enjoyable but also help the child better remember and understand the concept of shapes in math. This foundation forms the basis of their future learning in geometry!",
|
||||
},
|
||||
]
|
||||
external_memory_save.assert_called_once_with(
|
||||
value=ANY,
|
||||
metadata={"description": ANY, "messages": expected_messages},
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any, ClassVar
|
||||
from typing import List
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
@@ -45,8 +44,8 @@ class InternalCrew:
|
||||
agents_config = "config/agents.yaml"
|
||||
tasks_config = "config/tasks.yaml"
|
||||
|
||||
agents: list[BaseAgent]
|
||||
tasks: list[Task]
|
||||
agents: List[BaseAgent]
|
||||
tasks: List[Task]
|
||||
|
||||
@llm
|
||||
def local_llm(self):
|
||||
@@ -90,8 +89,7 @@ class InternalCrew:
|
||||
|
||||
@CrewBase
|
||||
class InternalCrewWithMCP(InternalCrew):
|
||||
mcp_server_params: ClassVar[dict[str, Any]] = {"host": "localhost", "port": 8000}
|
||||
mcp_connect_timeout = 120
|
||||
mcp_server_params = {"host": "localhost", "port": 8000}
|
||||
|
||||
@agent
|
||||
def reporting_analyst(self):
|
||||
@@ -202,8 +200,8 @@ def test_before_kickoff_with_none_input():
|
||||
def test_multiple_before_after_kickoff():
|
||||
@CrewBase
|
||||
class MultipleHooksCrew:
|
||||
agents: list[BaseAgent]
|
||||
tasks: list[Task]
|
||||
agents: List[BaseAgent]
|
||||
tasks: List[Task]
|
||||
|
||||
agents_config = "config/agents.yaml"
|
||||
tasks_config = "config/tasks.yaml"
|
||||
@@ -286,7 +284,4 @@ def test_internal_crew_with_mcp():
|
||||
assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]
|
||||
assert crew.researcher().tools == [simple_tool]
|
||||
|
||||
adapter_mock.assert_called_once_with(
|
||||
{"host": "localhost", "port": 8000},
|
||||
connect_timeout=120
|
||||
)
|
||||
adapter_mock.assert_called_once_with({"host": "localhost", "port": 8000})
|
||||
|
||||
@@ -1,268 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from crewai.llm import LLM
|
||||
from crewai.crew import Crew
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class TestPromptCaching:
|
||||
"""Test prompt caching functionality."""
|
||||
|
||||
def test_llm_prompt_caching_disabled_by_default(self):
|
||||
"""Test that prompt caching is disabled by default."""
|
||||
llm = LLM(model="gpt-4o")
|
||||
assert llm.enable_prompt_caching is False
|
||||
assert llm.cache_control == {"type": "ephemeral"}
|
||||
|
||||
def test_llm_prompt_caching_enabled(self):
|
||||
"""Test that prompt caching can be enabled."""
|
||||
llm = LLM(model="gpt-4o", enable_prompt_caching=True)
|
||||
assert llm.enable_prompt_caching is True
|
||||
|
||||
def test_llm_custom_cache_control(self):
|
||||
"""Test custom cache_control configuration."""
|
||||
custom_cache_control = {"type": "ephemeral", "ttl": 3600}
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20240620",
|
||||
enable_prompt_caching=True,
|
||||
cache_control=custom_cache_control
|
||||
)
|
||||
assert llm.cache_control == custom_cache_control
|
||||
|
||||
def test_supports_prompt_caching_openai(self):
|
||||
"""Test prompt caching support detection for OpenAI models."""
|
||||
llm = LLM(model="gpt-4o")
|
||||
assert llm._supports_prompt_caching() is True
|
||||
|
||||
def test_supports_prompt_caching_anthropic(self):
|
||||
"""Test prompt caching support detection for Anthropic models."""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20240620")
|
||||
assert llm._supports_prompt_caching() is True
|
||||
|
||||
def test_supports_prompt_caching_bedrock(self):
|
||||
"""Test prompt caching support detection for Bedrock models."""
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||
assert llm._supports_prompt_caching() is True
|
||||
|
||||
def test_supports_prompt_caching_deepseek(self):
|
||||
"""Test prompt caching support detection for Deepseek models."""
|
||||
llm = LLM(model="deepseek/deepseek-chat")
|
||||
assert llm._supports_prompt_caching() is True
|
||||
|
||||
def test_supports_prompt_caching_unsupported(self):
|
||||
"""Test prompt caching support detection for unsupported models."""
|
||||
llm = LLM(model="ollama/llama2")
|
||||
assert llm._supports_prompt_caching() is False
|
||||
|
||||
def test_anthropic_cache_control_formatting_string_content(self):
|
||||
"""Test that cache_control is properly formatted for Anthropic models with string content."""
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20240620",
|
||||
enable_prompt_caching=True
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
|
||||
formatted_messages = llm._format_messages_for_provider(messages)
|
||||
|
||||
system_message = next(m for m in formatted_messages if m["role"] == "system")
|
||||
assert isinstance(system_message["content"], list)
|
||||
assert system_message["content"][0]["type"] == "text"
|
||||
assert system_message["content"][0]["text"] == "You are a helpful assistant."
|
||||
assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}
|
||||
|
||||
user_messages = [m for m in formatted_messages if m["role"] == "user"]
|
||||
actual_user_message = user_messages[1] # Second user message is the actual one
|
||||
assert actual_user_message["content"] == "Hello, how are you?"
|
||||
|
||||
def test_anthropic_cache_control_formatting_list_content(self):
|
||||
"""Test that cache_control is properly formatted for Anthropic models with list content."""
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20240620",
|
||||
enable_prompt_caching=True
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are a helpful assistant."},
|
||||
{"type": "text", "text": "Be concise and accurate."}
|
||||
]
|
||||
},
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
|
||||
formatted_messages = llm._format_messages_for_provider(messages)
|
||||
|
||||
system_message = next(m for m in formatted_messages if m["role"] == "system")
|
||||
assert isinstance(system_message["content"], list)
|
||||
assert len(system_message["content"]) == 2
|
||||
assert "cache_control" not in system_message["content"][0]
|
||||
assert system_message["content"][1]["cache_control"] == {"type": "ephemeral"}
|
||||
|
||||
def test_anthropic_multiple_system_messages_cache_control(self):
|
||||
"""Test that cache_control is only added to the last system message."""
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20240620",
|
||||
enable_prompt_caching=True
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "First system message."},
|
||||
{"role": "system", "content": "Second system message."},
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
|
||||
formatted_messages = llm._format_messages_for_provider(messages)
|
||||
|
||||
first_system = formatted_messages[1] # Index 1 after placeholder user message
|
||||
assert first_system["role"] == "system"
|
||||
assert first_system["content"] == "First system message."
|
||||
|
||||
second_system = formatted_messages[2] # Index 2 after placeholder user message
|
||||
assert second_system["role"] == "system"
|
||||
assert isinstance(second_system["content"], list)
|
||||
assert second_system["content"][0]["cache_control"] == {"type": "ephemeral"}
|
||||
|
||||
def test_openai_prompt_caching_passthrough(self):
|
||||
"""Test that OpenAI prompt caching works without message modification."""
|
||||
llm = LLM(model="gpt-4o", enable_prompt_caching=True)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
|
||||
formatted_messages = llm._format_messages_for_provider(messages)
|
||||
|
||||
assert formatted_messages == messages
|
||||
|
||||
def test_prompt_caching_disabled_passthrough(self):
|
||||
"""Test that when prompt caching is disabled, messages pass through with normal Anthropic formatting."""
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20240620",
|
||||
enable_prompt_caching=False
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
|
||||
formatted_messages = llm._format_messages_for_provider(messages)
|
||||
|
||||
expected_messages = [
|
||||
{"role": "user", "content": "."},
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
assert formatted_messages == expected_messages
|
||||
|
||||
def test_unsupported_model_passthrough(self):
|
||||
"""Test that unsupported models pass through messages unchanged even with caching enabled."""
|
||||
llm = LLM(
|
||||
model="ollama/llama2",
|
||||
enable_prompt_caching=True
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
|
||||
formatted_messages = llm._format_messages_for_provider(messages)
|
||||
|
||||
assert formatted_messages == messages
|
||||
|
||||
@patch('crewai.llm.litellm.completion')
|
||||
def test_anthropic_cache_control_in_completion_call(self, mock_completion):
|
||||
"""Test that cache_control is properly passed to litellm.completion for Anthropic models."""
|
||||
mock_completion.return_value = Mock(
|
||||
choices=[Mock(message=Mock(content="Test response"))],
|
||||
usage=Mock(
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
total_tokens=150
|
||||
)
|
||||
)
|
||||
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20240620",
|
||||
enable_prompt_caching=True
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
|
||||
llm.call(messages)
|
||||
|
||||
call_args = mock_completion.call_args[1]
|
||||
formatted_messages = call_args["messages"]
|
||||
|
||||
system_message = next(m for m in formatted_messages if m["role"] == "system")
|
||||
assert isinstance(system_message["content"], list)
|
||||
assert system_message["content"][0]["cache_control"] == {"type": "ephemeral"}
|
||||
|
||||
def test_crew_with_prompt_caching(self):
|
||||
"""Test that crews can use LLMs with prompt caching enabled."""
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20240620",
|
||||
enable_prompt_caching=True
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm=llm
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
|
||||
assert crew.agents[0].llm.enable_prompt_caching is True
|
||||
|
||||
def test_bedrock_model_detection(self):
|
||||
"""Test that Bedrock models are properly detected for prompt caching."""
|
||||
llm = LLM(
|
||||
model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
enable_prompt_caching=True
|
||||
)
|
||||
|
||||
assert llm._supports_prompt_caching() is True
|
||||
assert llm.is_anthropic is False
|
||||
|
||||
def test_custom_cache_control_parameters(self):
|
||||
"""Test that custom cache_control parameters are properly stored."""
|
||||
custom_cache_control = {
|
||||
"type": "ephemeral",
|
||||
"max_age": 3600,
|
||||
"scope": "session"
|
||||
}
|
||||
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20240620",
|
||||
enable_prompt_caching=True,
|
||||
cache_control=custom_cache_control
|
||||
)
|
||||
|
||||
assert llm.cache_control == custom_cache_control
|
||||
|
||||
messages = [{"role": "system", "content": "Test system message."}]
|
||||
formatted_messages = llm._format_messages_for_provider(messages)
|
||||
|
||||
system_message = formatted_messages[1]
|
||||
assert isinstance(system_message["content"], list)
|
||||
assert system_message["content"][0]["cache_control"] == custom_cache_control
|
||||
@@ -1,11 +1,11 @@
"""Test Agent creation and execution basic functionality."""

import ast
import hashlib
import json
import os
import time
from functools import partial
from hashlib import md5
from typing import Tuple, Union
from unittest.mock import MagicMock, patch

import pytest
@@ -248,7 +248,7 @@ def test_guardrail_type_error():
            return (True, x)

        @staticmethod
        def guardrail_static_fn(x: TaskOutput) -> tuple[bool, str | TaskOutput]:
        def guardrail_static_fn(x: TaskOutput) -> tuple[bool, Union[str, TaskOutput]]:
            return (True, x)

    obj = Object()
@@ -271,7 +271,7 @@ def test_guardrail_type_error():
        guardrail=Object.guardrail_static_fn,
    )

    def error_fn(x: TaskOutput, y: bool) -> tuple[bool, TaskOutput]:
    def error_fn(x: TaskOutput, y: bool) -> Tuple[bool, TaskOutput]:
        return (y, x)

    Task(
@@ -340,7 +340,7 @@ def test_output_pydantic_hierarchical():
|
||||
)
|
||||
result = crew.kickoff()
|
||||
assert isinstance(result.pydantic, ScoreOutput)
|
||||
assert result.to_dict() == {"score": 4}
|
||||
assert result.to_dict() == {"score": 5}
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -401,8 +401,8 @@ def test_output_json_hierarchical():
|
||||
manager_llm="gpt-4o",
|
||||
)
|
||||
result = crew.kickoff()
|
||||
assert result.json == '{"score": 4}'
|
||||
assert result.to_dict() == {"score": 4}
|
||||
assert result.json == '{"score": 5}'
|
||||
assert result.to_dict() == {"score": 5}
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -560,8 +560,8 @@ def test_output_json_dict_hierarchical():
|
||||
manager_llm="gpt-4o",
|
||||
)
|
||||
result = crew.kickoff()
|
||||
assert {"score": 4} == result.json_dict
|
||||
assert result.to_dict() == {"score": 4}
|
||||
assert {"score": 5} == result.json_dict
|
||||
assert result.to_dict() == {"score": 5}
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -900,11 +900,11 @@ def test_conditional_task_copy_preserves_type():
|
||||
assert isinstance(copied_conditional_task, ConditionalTask)
|
||||
|
||||
|
||||
def test_interpolate_inputs(tmp_path):
|
||||
def test_interpolate_inputs():
|
||||
task = Task(
|
||||
description="Give me a list of 5 interesting ideas about {topic} to explore for an article, what makes them unique and interesting.",
|
||||
expected_output="Bullet point list of 5 interesting ideas about {topic}.",
|
||||
output_file=str(tmp_path / "{topic}" / "output_{date}.txt"),
|
||||
output_file="/tmp/{topic}/output_{date}.txt",
|
||||
)
|
||||
|
||||
task.interpolate_inputs_and_add_conversation_history(
|
||||
@@ -915,7 +915,7 @@ def test_interpolate_inputs(tmp_path):
|
||||
== "Give me a list of 5 interesting ideas about AI to explore for an article, what makes them unique and interesting."
|
||||
)
|
||||
assert task.expected_output == "Bullet point list of 5 interesting ideas about AI."
|
||||
assert task.output_file == str(tmp_path / "AI" / "output_2025.txt")
|
||||
assert task.output_file == "/tmp/AI/output_2025.txt"
|
||||
|
||||
task.interpolate_inputs_and_add_conversation_history(
|
||||
inputs={"topic": "ML", "date": "2025"}
|
||||
@@ -925,7 +925,7 @@ def test_interpolate_inputs(tmp_path):
|
||||
== "Give me a list of 5 interesting ideas about ML to explore for an article, what makes them unique and interesting."
|
||||
)
|
||||
assert task.expected_output == "Bullet point list of 5 interesting ideas about ML."
|
||||
assert task.output_file == str(tmp_path / "ML" / "output_2025.txt")
|
||||
assert task.output_file == "/tmp/ML/output_2025.txt"
|
||||
|
||||
|
||||
def test_interpolate_only():
|
||||
@@ -1074,9 +1074,8 @@ def test_key():
|
||||
description=original_description,
|
||||
expected_output=original_expected_output,
|
||||
)
|
||||
hash = md5(
|
||||
f"{original_description}|{original_expected_output}".encode(),
|
||||
usedforsecurity=False,
|
||||
hash = hashlib.md5(
|
||||
f"{original_description}|{original_expected_output}".encode()
|
||||
).hexdigest()
|
||||
|
||||
assert task.key == hash, "The key should be the hash of the description."
|
||||
@@ -1087,7 +1086,7 @@ def test_key():
|
||||
)
|
||||
|
||||
|
||||
def test_output_file_validation(tmp_path):
|
||||
def test_output_file_validation():
|
||||
"""Test output file path validation."""
|
||||
# Valid paths
|
||||
assert (
|
||||
@@ -1098,15 +1097,13 @@ def test_output_file_validation(tmp_path):
|
||||
).output_file
|
||||
== "output.txt"
|
||||
)
|
||||
# Use secure temporary path instead of /tmp
|
||||
temp_file = tmp_path / "output.txt"
|
||||
assert (
|
||||
Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file=str(temp_file),
|
||||
output_file="/tmp/output.txt",
|
||||
).output_file
|
||||
== str(temp_file).lstrip("/") # Remove leading slash to match expected behavior
|
||||
== "tmp/output.txt"
|
||||
)
|
||||
assert (
|
||||
Task(
|
||||
@@ -1323,7 +1320,7 @@ def test_interpolate_with_list_of_dicts():
    }
    result = interpolate_only("{people}", input_data)

    parsed_result = ast.literal_eval(result)
    parsed_result = eval(result)
    assert isinstance(parsed_result, list)
    assert len(parsed_result) == 2
    assert parsed_result[0]["name"] == "Alice"
@@ -1349,7 +1346,7 @@ def test_interpolate_with_nested_structures():
|
||||
}
|
||||
}
|
||||
result = interpolate_only("{company}", input_data)
|
||||
parsed = ast.literal_eval(result)
|
||||
parsed = eval(result)
|
||||
|
||||
assert parsed["name"] == "TechCorp"
|
||||
assert len(parsed["departments"]) == 2
|
||||
@@ -1367,7 +1364,7 @@ def test_interpolate_with_special_characters():
|
||||
}
|
||||
}
|
||||
result = interpolate_only("{special_data}", input_data)
|
||||
parsed = ast.literal_eval(result)
|
||||
parsed = eval(result)
|
||||
|
||||
assert parsed["quotes"] == """This has "double" and 'single' quotes"""
|
||||
assert parsed["unicode"] == "文字化けテスト"
|
||||
@@ -1389,7 +1386,7 @@ def test_interpolate_mixed_types():
|
||||
}
|
||||
}
|
||||
result = interpolate_only("{data}", input_data)
|
||||
parsed = ast.literal_eval(result)
|
||||
parsed = eval(result)
|
||||
|
||||
assert parsed["name"] == "Test Dataset"
|
||||
assert parsed["samples"] == 1000
|
||||
@@ -1412,7 +1409,7 @@ def test_interpolate_complex_combination():
|
||||
]
|
||||
}
|
||||
result = interpolate_only("{report}", input_data)
|
||||
parsed = ast.literal_eval(result)
|
||||
parsed = eval(result)
|
||||
|
||||
assert len(parsed) == 2
|
||||
assert parsed[0]["month"] == "January"
|
||||
@@ -1485,7 +1482,7 @@ def test_interpolate_valid_complex_types():
|
||||
|
||||
# Should not raise any errors
|
||||
result = interpolate_only("{data}", {"data": valid_data})
|
||||
parsed = ast.literal_eval(result)
|
||||
parsed = eval(result)
|
||||
assert parsed["name"] == "Valid Dataset"
|
||||
assert parsed["stats"]["nested"]["deeper"]["b"] == 2.5
|
||||
|
||||
@@ -1515,7 +1512,7 @@ def test_interpolate_valid_types():
|
||||
}
|
||||
|
||||
result = interpolate_only("{data}", {"data": valid_data})
|
||||
parsed = ast.literal_eval(result)
|
||||
parsed = eval(result)
|
||||
|
||||
assert parsed["active"] is True
|
||||
assert parsed["deleted"] is False
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -37,7 +39,6 @@ def test_initialization(basic_function, schema_class):
|
||||
assert tool.func == basic_function
|
||||
assert tool.args_schema == schema_class
|
||||
|
||||
|
||||
def test_from_function(basic_function):
|
||||
"""Test creating tool from function"""
|
||||
tool = CrewStructuredTool.from_function(
|
||||
@@ -49,7 +50,6 @@ def test_from_function(basic_function):
|
||||
assert tool.func == basic_function
|
||||
assert isinstance(tool.args_schema, type(BaseModel))
|
||||
|
||||
|
||||
def test_validate_function_signature(basic_function, schema_class):
|
||||
"""Test function signature validation"""
|
||||
tool = CrewStructuredTool(
|
||||
@@ -62,7 +62,6 @@ def test_validate_function_signature(basic_function, schema_class):
|
||||
# Should not raise any exceptions
|
||||
tool._validate_function_signature()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke(basic_function):
|
||||
"""Test asynchronous invocation"""
|
||||
@@ -71,7 +70,6 @@ async def test_ainvoke(basic_function):
|
||||
result = await tool.ainvoke(input={"param1": "test"})
|
||||
assert result == "test 0"
|
||||
|
||||
|
||||
def test_parse_args_dict(basic_function):
|
||||
"""Test parsing dictionary arguments"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
@@ -80,7 +78,6 @@ def test_parse_args_dict(basic_function):
|
||||
assert parsed["param1"] == "test"
|
||||
assert parsed["param2"] == 42
|
||||
|
||||
|
||||
def test_parse_args_string(basic_function):
|
||||
"""Test parsing string arguments"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
@@ -89,7 +86,6 @@ def test_parse_args_string(basic_function):
|
||||
assert parsed["param1"] == "test"
|
||||
assert parsed["param2"] == 42
|
||||
|
||||
|
||||
def test_complex_types():
|
||||
"""Test handling of complex parameter types"""
|
||||
|
||||
@@ -103,7 +99,6 @@ def test_complex_types():
|
||||
result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]})
|
||||
assert result == "Processed 3 items with 1 nested keys"
|
||||
|
||||
|
||||
def test_schema_inheritance():
|
||||
"""Test tool creation with inherited schema"""
|
||||
|
||||
@@ -124,14 +119,13 @@ def test_schema_inheritance():
|
||||
result = tool.invoke({"base_param": "test", "extra_param": 42})
|
||||
assert result == "test 42"
|
||||
|
||||
|
||||
def test_default_values_in_schema():
|
||||
"""Test handling of default values in schema"""
|
||||
|
||||
def default_func(
|
||||
required_param: str,
|
||||
optional_param: str = "default",
|
||||
nullable_param: int | None = None,
|
||||
nullable_param: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Test function with default values."""
|
||||
return f"{required_param} {optional_param} {nullable_param}"
|
||||
@@ -150,7 +144,6 @@ def test_default_values_in_schema():
|
||||
)
|
||||
assert result == "test custom 42"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_tool_decorator():
|
||||
from crewai.tools import tool
|
||||
@@ -162,7 +155,6 @@ def custom_tool_decorator():
|
||||
|
||||
return custom_tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_tool():
|
||||
from crewai.tools import BaseTool
|
||||
@@ -177,25 +169,17 @@ def custom_tool():
|
||||
|
||||
return CustomTool()
|
||||
|
||||
|
||||
def build_simple_crew(tool):
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai import Agent, Task, Crew
|
||||
|
||||
agent1 = Agent(
|
||||
role="Simple role",
|
||||
goal="Simple goal",
|
||||
backstory="Simple backstory",
|
||||
tools=[tool],
|
||||
)
|
||||
agent1 = Agent(role="Simple role", goal="Simple goal", backstory="Simple backstory", tools=[tool])
|
||||
|
||||
say_hi_task = Task(
|
||||
description="Use the custom tool result as answer.",
|
||||
agent=agent1,
|
||||
expected_output="Use the tool result",
|
||||
description="Use the custom tool result as answer.", agent=agent1, expected_output="Use the tool result"
|
||||
)
|
||||
|
||||
return Crew(agents=[agent1], tasks=[say_hi_task])
|
||||
|
||||
crew = Crew(agents=[agent1], tasks=[say_hi_task])
|
||||
return crew
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_using_within_isolated_crew(custom_tool):
|
||||
@@ -204,7 +188,6 @@ def test_async_tool_using_within_isolated_crew(custom_tool):
|
||||
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
|
||||
crew = build_simple_crew(custom_tool_decorator)
|
||||
@@ -212,7 +195,6 @@ def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
|
||||
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_within_flow(custom_tool):
|
||||
from crewai.flow.flow import Flow
|
||||
@@ -223,7 +205,8 @@ def test_async_tool_within_flow(custom_tool):
|
||||
@start()
|
||||
async def start(self):
|
||||
crew = build_simple_crew(custom_tool)
|
||||
return await crew.kickoff_async()
|
||||
result = await crew.kickoff_async()
|
||||
return result
|
||||
|
||||
flow = StructuredExampleFlow()
|
||||
result = flow.kickoff()
|
||||
@@ -236,141 +219,12 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
|
||||
|
||||
class StructuredExampleFlow(Flow):
|
||||
from crewai.flow.flow import start
|
||||
|
||||
@start()
|
||||
async def start(self):
|
||||
crew = build_simple_crew(custom_tool_decorator)
|
||||
return await crew.kickoff_async()
|
||||
result = await crew.kickoff_async()
|
||||
return result
|
||||
|
||||
flow = StructuredExampleFlow()
|
||||
result = flow.kickoff()
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
|
||||
|
||||
def test_structured_tool_invoke_calls_func_only_once():
|
||||
"""Test that CrewStructuredTool.invoke() calls the underlying function exactly once."""
|
||||
call_count = 0
|
||||
call_history = []
|
||||
|
||||
def counting_function(param: str) -> str:
|
||||
"""Function that tracks how many times it's called."""
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
call_history.append(f"Call #{call_count} with param: {param}")
|
||||
return f"Result from call #{call_count}: {param}"
|
||||
|
||||
# Create CrewStructuredTool directly
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=counting_function,
|
||||
name="direct_test_tool",
|
||||
description="Tool to test direct invoke() method",
|
||||
)
|
||||
|
||||
# Call invoke() directly - this is where the bug was
|
||||
result = tool.invoke({"param": "test_value"})
|
||||
|
||||
# Critical assertions that would catch the duplicate execution bug
|
||||
assert call_count == 1, (
|
||||
f"DUPLICATE EXECUTION BUG: Function was called {call_count} times instead of 1. "
|
||||
f"This means CrewStructuredTool.invoke() has duplicate function calls. "
|
||||
f"Call history: {call_history}"
|
||||
)
|
||||
|
||||
assert len(call_history) == 1, (
|
||||
f"Expected 1 call in history, got {len(call_history)}: {call_history}"
|
||||
)
|
||||
|
||||
assert call_history[0] == "Call #1 with param: test_value", (
|
||||
f"Expected 'Call #1 with param: test_value', got: {call_history[0]}"
|
||||
)
|
||||
|
||||
assert result == "Result from call #1: test_value", (
|
||||
f"Expected result from first call, got: {result}"
|
||||
)
|
||||
|
||||
|
||||
def test_structured_tool_invoke_multiple_calls_increment_correctly():
|
||||
"""Test multiple calls to invoke() to ensure each increments correctly."""
|
||||
call_count = 0
|
||||
|
||||
def incrementing_function(value: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return value + call_count
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=incrementing_function,
|
||||
name="incrementing_tool",
|
||||
description="Tool that increments on each call",
|
||||
)
|
||||
|
||||
result1 = tool.invoke({"value": 10})
|
||||
assert call_count == 1, (
|
||||
f"After first invoke, expected call_count=1, got {call_count}"
|
||||
)
|
||||
assert result1 == 11, f"Expected 11 (10+1), got {result1}"
|
||||
|
||||
result2 = tool.invoke({"value": 20})
|
||||
assert call_count == 2, (
|
||||
f"After second invoke, expected call_count=2, got {call_count}"
|
||||
)
|
||||
assert result2 == 22, f"Expected 22 (20+2), got {result2}"
|
||||
|
||||
result3 = tool.invoke({"value": 30})
|
||||
assert call_count == 3, (
|
||||
f"After third invoke, expected call_count=3, got {call_count}"
|
||||
)
|
||||
assert result3 == 33, f"Expected 33 (30+3), got {result3}"
|
||||
|
||||
|
||||
def test_structured_tool_invoke_with_side_effects():
|
||||
"""Test that side effects only happen once per invoke() call."""
|
||||
side_effects = []
|
||||
|
||||
def side_effect_function(action: str) -> str:
|
||||
side_effects.append(f"SIDE_EFFECT: {action} executed at call")
|
||||
return f"Action {action} completed"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=side_effect_function,
|
||||
name="side_effect_tool",
|
||||
description="Tool with observable side effects",
|
||||
)
|
||||
|
||||
result = tool.invoke({"action": "write_file"})
|
||||
|
||||
assert len(side_effects) == 1, (
|
||||
f"SIDE EFFECT BUG: Expected 1 side effect, got {len(side_effects)}. "
|
||||
f"This indicates the function was called multiple times. "
|
||||
f"Side effects: {side_effects}"
|
||||
)
|
||||
|
||||
assert side_effects[0] == "SIDE_EFFECT: write_file executed at call"
|
||||
assert result == "Action write_file completed"
|
||||
|
||||
|
||||
def test_structured_tool_invoke_exception_handling():
|
||||
"""Test that exceptions don't cause duplicate execution."""
|
||||
call_count = 0
|
||||
|
||||
def failing_function(should_fail: bool) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if should_fail:
|
||||
raise ValueError(f"Intentional failure on call #{call_count}")
|
||||
return f"Success on call #{call_count}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=failing_function, name="failing_tool", description="Tool that can fail"
|
||||
)
|
||||
|
||||
result = tool.invoke({"should_fail": False})
|
||||
assert call_count == 1, f"Expected 1 call for success case, got {call_count}"
|
||||
assert result == "Success on call #1"
|
||||
|
||||
call_count = 0
|
||||
|
||||
with pytest.raises(ValueError, match="Intentional failure on call #1"):
|
||||
tool.invoke({"should_fail": True})
|
||||
|
||||
assert call_count == 1
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
@@ -1,20 +1,17 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.events.listeners.tracing.first_time_trace_handler import (
|
||||
FirstTimeTraceHandler,
|
||||
|
||||
from crewai import Agent, Task, Crew
|
||||
from crewai.flow.flow import Flow, start
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.trace_batch_manager import (
|
||||
TraceBatchManager,
|
||||
)
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.flow.flow import Flow, start
|
||||
|
||||
|
||||
class TestTraceListenerSetup:
|
||||
@@ -284,9 +281,9 @@ class TestTraceListenerSetup:
|
||||
):
|
||||
trace_handlers.append(handler)
|
||||
|
||||
assert len(trace_handlers) == 0, (
|
||||
f"Found {len(trace_handlers)} trace handlers when tracing should be disabled"
|
||||
)
|
||||
assert (
|
||||
len(trace_handlers) == 0
|
||||
), f"Found {len(trace_handlers)} trace handlers when tracing should be disabled"
|
||||
|
||||
def test_trace_listener_setup_correctly_for_crew(self):
|
||||
"""Test that trace listener is set up correctly when enabled"""
|
||||
@@ -406,254 +403,3 @@ class TestTraceListenerSetup:
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
crewai_event_bus._handlers.clear()
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_first_time_user_trace_collection_with_timeout(self, mock_plus_api_calls):
|
||||
"""Test first-time user trace collection logic with timeout behavior"""
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.utils._is_test_environment",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.utils.is_first_execution",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
|
||||
return_value=False,
|
||||
) as mock_prompt,
|
||||
patch(
|
||||
"crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
|
||||
) as mock_mark_completed,
|
||||
):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
task = Task(
|
||||
description="Say hello to the world",
|
||||
expected_output="hello world",
|
||||
agent=agent,
|
||||
)
|
||||
crew = Crew(agents=[agent], tasks=[task], verbose=True)
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
trace_listener = TraceCollectionListener()
|
||||
trace_listener.setup_listeners(crewai_event_bus)
|
||||
|
||||
assert trace_listener.first_time_handler.is_first_time is True
|
||||
assert trace_listener.first_time_handler.collected_events is False
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
trace_listener.first_time_handler,
|
||||
"handle_execution_completion",
|
||||
wraps=trace_listener.first_time_handler.handle_execution_completion,
|
||||
) as mock_handle_completion,
|
||||
patch.object(
|
||||
trace_listener.batch_manager,
|
||||
"add_event",
|
||||
wraps=trace_listener.batch_manager.add_event,
|
||||
) as mock_add_event,
|
||||
):
|
||||
result = crew.kickoff()
|
||||
assert result is not None
|
||||
|
||||
assert mock_handle_completion.call_count >= 1
|
||||
assert mock_add_event.call_count >= 1
|
||||
|
||||
assert trace_listener.first_time_handler.collected_events is True
|
||||
|
||||
mock_prompt.assert_called_once_with(timeout_seconds=20)
|
||||
|
||||
mock_mark_completed.assert_called_once()
|
||||
|
||||
    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_first_time_user_trace_collection_user_accepts(self, mock_plus_api_calls):
        """Test first-time user trace collection when user accepts viewing traces"""

        with (
            patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
            patch(
                "crewai.events.listeners.tracing.utils._is_test_environment",
                return_value=False,
            ),
            patch(
                "crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.utils.is_first_execution",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
            ) as mock_mark_completed,
        ):
            agent = Agent(
                role="Test Agent",
                goal="Test goal",
                backstory="Test backstory",
                llm="gpt-4o-mini",
            )
            task = Task(
                description="Say hello to the world",
                expected_output="hello world",
                agent=agent,
            )
            crew = Crew(agents=[agent], tasks=[task], verbose=True)

            from crewai.events.event_bus import crewai_event_bus

            trace_listener = TraceCollectionListener()
            trace_listener.setup_listeners(crewai_event_bus)

            assert trace_listener.first_time_handler.is_first_time is True

            with (
                patch.object(
                    trace_listener.first_time_handler,
                    "_initialize_backend_and_send_events",
                    wraps=trace_listener.first_time_handler._initialize_backend_and_send_events,
                ) as mock_init_backend,
                patch.object(
                    trace_listener.first_time_handler, "_display_ephemeral_trace_link"
                ) as mock_display_link,
                patch.object(
                    trace_listener.first_time_handler,
                    "handle_execution_completion",
                    wraps=trace_listener.first_time_handler.handle_execution_completion,
                ) as mock_handle_completion,
            ):
                trace_listener.batch_manager.ephemeral_trace_url = (
                    "https://crewai.com/trace/mock-id"
                )

                crew.kickoff()

                assert mock_handle_completion.call_count >= 1, (
                    "handle_execution_completion should be called"
                )

                assert trace_listener.first_time_handler.collected_events is True, (
                    "Events should be marked as collected"
                )

                mock_init_backend.assert_called_once()

                mock_display_link.assert_called_once()

                mock_mark_completed.assert_called_once()

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_first_time_user_trace_consolidation_logic(self, mock_plus_api_calls):
        """Test the consolidation logic for first-time users vs regular tracing"""

        with (
            patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
            patch(
                "crewai.events.listeners.tracing.utils._is_test_environment",
                return_value=False,
            ),
            patch(
                "crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.utils.is_first_execution",
                return_value=True,
            ),
        ):
            from crewai.events.event_bus import crewai_event_bus

            crewai_event_bus._handlers.clear()

            trace_listener = TraceCollectionListener()
            trace_listener.setup_listeners(crewai_event_bus)

            assert trace_listener.first_time_handler.is_first_time is True

            agent = Agent(
                role="Test Agent",
                goal="Test goal",
                backstory="Test backstory",
                llm="gpt-4o-mini",
            )
            task = Task(
                description="Test task", expected_output="test output", agent=agent
            )
            crew = Crew(agents=[agent], tasks=[task])

            with patch.object(TraceBatchManager, "initialize_batch") as mock_initialize:
                result = crew.kickoff()

                assert mock_initialize.call_count >= 1
                assert mock_initialize.call_args_list[0][1]["use_ephemeral"] is True
                assert result is not None

    def test_first_time_handler_timeout_behavior(self):
        """Test the timeout behavior of the first-time trace prompt"""

        with (
            patch(
                "crewai.events.listeners.tracing.utils._is_test_environment",
                return_value=False,
            ),
            patch("threading.Thread") as mock_thread,
        ):
            from crewai.events.listeners.tracing.utils import (
                prompt_user_for_trace_viewing,
            )

            mock_thread_instance = Mock()
            mock_thread_instance.is_alive.return_value = True
            mock_thread.return_value = mock_thread_instance

            result = prompt_user_for_trace_viewing(timeout_seconds=5)

            assert result is False
            mock_thread.assert_called_once()
            call_args = mock_thread.call_args
            assert call_args[1]["daemon"] is True

            mock_thread_instance.start.assert_called_once()
            mock_thread_instance.join.assert_called_once_with(timeout=5)
            mock_thread_instance.is_alive.assert_called_once()

    def test_first_time_handler_graceful_error_handling(self):
        """Test graceful error handling in first-time trace logic"""

        with (
            patch(
                "crewai.events.listeners.tracing.utils.should_auto_collect_first_time_traces",
                return_value=True,
            ),
            patch(
                "crewai.events.listeners.tracing.first_time_trace_handler.prompt_user_for_trace_viewing",
                side_effect=Exception("Prompt failed"),
            ),
            patch(
                "crewai.events.listeners.tracing.first_time_trace_handler.mark_first_execution_completed"
            ) as mock_mark_completed,
        ):
            handler = FirstTimeTraceHandler()
            handler.is_first_time = True
            handler.collected_events = True

            handler.handle_execution_completion()

            mock_mark_completed.assert_called_once()
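For orientation, the timeout test above pins down a specific contract: the prompt runs on a daemon thread, waits at most timeout_seconds via join(), and falls back to False when no answer arrives. Below is a minimal sketch of a function with that behavior; it is only an illustration under those assumptions, not the actual crewai.events.listeners.tracing.utils.prompt_user_for_trace_viewing implementation, and the prompt wording is invented.

# Illustrative sketch only: mirrors the contract asserted by
# test_first_time_handler_timeout_behavior (daemon thread, join(timeout),
# default False on timeout). The real crewai implementation may differ.
import threading


def prompt_user_for_trace_viewing(timeout_seconds: int = 30) -> bool:
    """Ask whether to view the first execution trace; default to False on timeout."""
    answer: list[bool] = []

    def _ask() -> None:
        # input() blocks, so it runs on a throwaway daemon thread
        reply = input("View your first execution trace? [y/N]: ")  # hypothetical prompt text
        answer.append(reply.strip().lower() in ("y", "yes"))

    worker = threading.Thread(target=_ask, daemon=True)
    worker.start()
    worker.join(timeout=timeout_seconds)
    if worker.is_alive():
        # No response within the timeout window: treat it as a decline
        return False
    return bool(answer) and answer[0]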
123
tests/utilities/test_chromadb_utils.py
Normal file
@@ -0,0 +1,123 @@
import multiprocessing
import tempfile
import unittest

from chromadb.config import Settings
from unittest.mock import patch, MagicMock

from crewai.utilities.chromadb import (
    MAX_COLLECTION_LENGTH,
    MIN_COLLECTION_LENGTH,
    is_ipv4_pattern,
    sanitize_collection_name,
    create_persistent_client,
)


def persistent_client_worker(path, queue):
    try:
        create_persistent_client(path=path)
        queue.put(None)
    except Exception as e:
        queue.put(e)


class TestChromadbUtils(unittest.TestCase):
    def test_sanitize_collection_name_long_name(self):
        """Test sanitizing a very long collection name."""
        long_name = "This is an extremely long role name that will definitely exceed the ChromaDB collection name limit of 63 characters and cause an error when used as a collection name"
        sanitized = sanitize_collection_name(long_name)
        self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())
        self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))

    def test_sanitize_collection_name_special_chars(self):
        """Test sanitizing a name with special characters."""
        special_chars = "Agent@123!#$%^&*()"
        sanitized = sanitize_collection_name(special_chars)
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())
        self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))

    def test_sanitize_collection_name_short_name(self):
        """Test sanitizing a very short name."""
        short_name = "A"
        sanitized = sanitize_collection_name(short_name)
        self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())

    def test_sanitize_collection_name_bad_ends(self):
        """Test sanitizing a name with non-alphanumeric start/end."""
        bad_ends = "_Agent_"
        sanitized = sanitize_collection_name(bad_ends)
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())

    def test_sanitize_collection_name_none(self):
        """Test sanitizing a None value."""
        sanitized = sanitize_collection_name(None)
        self.assertEqual(sanitized, "default_collection")

    def test_sanitize_collection_name_ipv4_pattern(self):
        """Test sanitizing an IPv4 address."""
        ipv4 = "192.168.1.1"
        sanitized = sanitize_collection_name(ipv4)
        self.assertTrue(sanitized.startswith("ip_"))
        self.assertTrue(sanitized[0].isalnum())
        self.assertTrue(sanitized[-1].isalnum())
        self.assertTrue(all(c.isalnum() or c in ["_", "-"] for c in sanitized))

    def test_is_ipv4_pattern(self):
        """Test IPv4 pattern detection."""
        self.assertTrue(is_ipv4_pattern("192.168.1.1"))
        self.assertFalse(is_ipv4_pattern("not.an.ip.address"))

    def test_sanitize_collection_name_properties(self):
        """Test that sanitized collection names always meet ChromaDB requirements."""
        test_cases = [
            "A" * 100,  # Very long name
            "_start_with_underscore",
            "end_with_underscore_",
            "contains@special#characters",
            "192.168.1.1",  # IPv4 address
            "a" * 2,  # Too short
        ]
        for test_case in test_cases:
            sanitized = sanitize_collection_name(test_case)
            self.assertGreaterEqual(len(sanitized), MIN_COLLECTION_LENGTH)
            self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
            self.assertTrue(sanitized[0].isalnum())
            self.assertTrue(sanitized[-1].isalnum())

    def test_create_persistent_client_passes_args(self):
        with patch(
            "crewai.utilities.chromadb.PersistentClient"
        ) as mock_persistent_client, tempfile.TemporaryDirectory() as tmpdir:
            mock_instance = MagicMock()
            mock_persistent_client.return_value = mock_instance

            settings = Settings(allow_reset=True)
            client = create_persistent_client(path=tmpdir, settings=settings)

            mock_persistent_client.assert_called_once_with(
                path=tmpdir, settings=settings
            )
            self.assertIs(client, mock_instance)

    def test_create_persistent_client_process_safe(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            queue = multiprocessing.Queue()
            processes = [
                multiprocessing.Process(
                    target=persistent_client_worker, args=(tmpdir, queue)
                )
                for _ in range(5)
            ]

            [p.start() for p in processes]
            [p.join() for p in processes]

            errors = [queue.get(timeout=5) for _ in processes]
            self.assertTrue(all(err is None for err in errors))
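The property-based assertions above spell out what the sanitizer has to guarantee: a length between MIN_COLLECTION_LENGTH and MAX_COLLECTION_LENGTH, alphanumeric first and last characters, only characters from [a-zA-Z0-9_-], an ip_ prefix for IPv4-looking names, and a default_collection fallback for None. The sketch below is one way to satisfy those properties; it is not the shipped crewai.utilities.chromadb implementation, and the 3/63 bounds and padding characters are assumptions taken from the tests and ChromaDB's documented 63-character limit.

# Hypothetical sketch that satisfies the properties exercised by the tests above;
# the shipped crewai.utilities.chromadb helpers may be implemented differently.
import re

MIN_COLLECTION_LENGTH = 3   # assumed from the "too short" test case
MAX_COLLECTION_LENGTH = 63  # ChromaDB's documented collection-name limit


def is_ipv4_pattern(name: str) -> bool:
    """True if the name looks like a dotted-quad IPv4 address."""
    return bool(re.fullmatch(r"(\d{1,3}\.){3}\d{1,3}", name or ""))


def sanitize_collection_name(name: str | None) -> str:
    """Coerce an arbitrary name into a valid ChromaDB collection name."""
    if not name:
        return "default_collection"
    if is_ipv4_pattern(name):
        name = f"ip_{name}"
    # Replace anything outside [a-zA-Z0-9_-], then enforce the length bounds
    cleaned = re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:MAX_COLLECTION_LENGTH]
    cleaned = cleaned.ljust(MIN_COLLECTION_LENGTH, "x")
    # Force alphanumeric first and last characters
    if not cleaned[0].isalnum():
        cleaned = "a" + cleaned[1:]
    if not cleaned[-1].isalnum():
        cleaned = cleaned[:-1] + "z"
    return cleaned

The create_persistent_client tests, by contrast, only require that the helper forwards path and settings straight through to chromadb.PersistentClient and that concurrent construction from several processes does not raise; they do not pin down how that process safety is achieved.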
Some files were not shown because too many files have changed in this diff.