mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-17 01:08:15 +00:00
Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5053fae8a1 | ||
|
|
9facd96aad | ||
|
|
9acb327d9f | ||
|
|
aca0817421 | ||
|
|
4d21c6e4ad | ||
|
|
32d7b4a8d4 | ||
|
|
fb2323b3de | ||
|
|
e1d7de0dba | ||
|
|
96b07bfc84 | ||
|
|
b8d7942675 | ||
|
|
88fd859c26 | ||
|
|
3413f2e671 | ||
|
|
326ec15d54 | ||
|
|
c5a8fef118 | ||
|
|
b7af26ff60 | ||
|
|
48eb7c6937 | ||
|
|
d8e38f2f0b | ||
|
|
542afe61a8 | ||
|
|
8a5b3bc237 | ||
|
|
534f0707ca | ||
|
|
0046f9a96f | ||
|
|
e72a80be6e | ||
|
|
7cffcab84a | ||
|
|
f070ce8abd | ||
|
|
d9f6e2222f | ||
|
|
adef605410 | ||
|
|
cd42bcf035 | ||
|
|
bc45a7fbe3 | ||
|
|
87759cdb14 | ||
|
|
059cb93aeb | ||
|
|
cebc52694e |
127
.github/workflows/nightly.yml
vendored
Normal file
127
.github/workflows/nightly.yml
vendored
Normal file
@@ -0,0 +1,127 @@
|
||||
name: Nightly Canary Release
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 6 * * *' # daily at 6am UTC
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
check:
|
||||
name: Check for new commits
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
outputs:
|
||||
has_changes: ${{ steps.check.outputs.has_changes }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Check for commits in last 24h
|
||||
id: check
|
||||
run: |
|
||||
RECENT=$(git log --since="24 hours ago" --oneline | head -1)
|
||||
if [ -n "$RECENT" ]; then
|
||||
echo "has_changes=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "has_changes=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
build:
|
||||
name: Build nightly packages
|
||||
needs: check
|
||||
if: needs.check.outputs.has_changes == 'true' || github.event_name == 'workflow_dispatch'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Stamp nightly versions
|
||||
run: |
|
||||
DATE=$(date +%Y%m%d)
|
||||
for init_file in \
|
||||
lib/crewai/src/crewai/__init__.py \
|
||||
lib/crewai-tools/src/crewai_tools/__init__.py \
|
||||
lib/crewai-files/src/crewai_files/__init__.py; do
|
||||
CURRENT=$(python -c "
|
||||
import re
|
||||
text = open('$init_file').read()
|
||||
print(re.search(r'__version__\s*=\s*\"(.*?)\"\s*$', text, re.MULTILINE).group(1))
|
||||
")
|
||||
NIGHTLY="${CURRENT}.dev${DATE}"
|
||||
sed -i "s/__version__ = .*/__version__ = \"${NIGHTLY}\"/" "$init_file"
|
||||
echo "$init_file: $CURRENT -> $NIGHTLY"
|
||||
done
|
||||
|
||||
# Update cross-package dependency pins to nightly versions
|
||||
sed -i "s/\"crewai-tools==[^\"]*\"/\"crewai-tools==${NIGHTLY}\"/" lib/crewai/pyproject.toml
|
||||
sed -i "s/\"crewai==[^\"]*\"/\"crewai==${NIGHTLY}\"/" lib/crewai-tools/pyproject.toml
|
||||
echo "Updated cross-package dependency pins to ${NIGHTLY}"
|
||||
|
||||
- name: Build packages
|
||||
run: |
|
||||
uv build --all-packages
|
||||
rm dist/.gitignore
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist/
|
||||
|
||||
publish:
|
||||
name: Publish nightly to PyPI
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/crewai
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
version: "0.8.4"
|
||||
python-version: "3.12"
|
||||
enable-cache: false
|
||||
|
||||
- name: Download artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
|
||||
run: |
|
||||
failed=0
|
||||
for package in dist/*; do
|
||||
if [[ "$package" == *"crewai_devtools"* ]]; then
|
||||
echo "Skipping private package: $package"
|
||||
continue
|
||||
fi
|
||||
echo "Publishing $package"
|
||||
if ! uv publish "$package"; then
|
||||
echo "Failed to publish $package"
|
||||
failed=1
|
||||
fi
|
||||
done
|
||||
if [ $failed -eq 1 ]; then
|
||||
echo "Some packages failed to publish"
|
||||
exit 1
|
||||
fi
|
||||
71
.github/workflows/publish.yml
vendored
71
.github/workflows/publish.yml
vendored
@@ -59,6 +59,8 @@ jobs:
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.release_tag || github.ref }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
@@ -93,3 +95,72 @@ jobs:
|
||||
echo "Some packages failed to publish"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Build Slack payload
|
||||
if: success()
|
||||
id: slack
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
RELEASE_TAG: ${{ inputs.release_tag }}
|
||||
run: |
|
||||
payload=$(uv run python -c "
|
||||
import json, re, subprocess, sys
|
||||
|
||||
with open('lib/crewai/src/crewai/__init__.py') as f:
|
||||
m = re.search(r\"__version__\s*=\s*[\\\"']([^\\\"']+)\", f.read())
|
||||
version = m.group(1) if m else 'unknown'
|
||||
|
||||
import os
|
||||
tag = os.environ.get('RELEASE_TAG') or version
|
||||
|
||||
try:
|
||||
r = subprocess.run(['gh','release','view',tag,'--json','body','-q','.body'],
|
||||
capture_output=True, text=True, check=True)
|
||||
body = r.stdout.strip()
|
||||
except Exception:
|
||||
body = ''
|
||||
|
||||
blocks = [
|
||||
{'type':'section','text':{'type':'mrkdwn',
|
||||
'text':f':rocket: \`crewai v{version}\` published to PyPI'}},
|
||||
{'type':'section','text':{'type':'mrkdwn',
|
||||
'text':f'<https://pypi.org/project/crewai/{version}/|View on PyPI> · <https://github.com/crewAIInc/crewAI/releases/tag/{tag}|Release notes>'}},
|
||||
{'type':'divider'},
|
||||
]
|
||||
|
||||
if body:
|
||||
heading, items = '', []
|
||||
for line in body.split('\n'):
|
||||
line = line.strip()
|
||||
if not line: continue
|
||||
hm = re.match(r'^#{2,3}\s+(.*)', line)
|
||||
if hm:
|
||||
if heading and items:
|
||||
skip = heading in ('What\\'s Changed','') or 'Contributors' in heading
|
||||
if not skip:
|
||||
txt = f'*{heading}*\n' + '\n'.join(f'• {i}' for i in items)
|
||||
blocks.append({'type':'section','text':{'type':'mrkdwn','text':txt}})
|
||||
heading, items = hm.group(1), []
|
||||
elif line.startswith('- ') or line.startswith('* '):
|
||||
items.append(re.sub(r'\*\*([^*]*)\*\*', r'*\1*', line[2:]))
|
||||
if heading and items:
|
||||
skip = heading in ('What\\'s Changed','') or 'Contributors' in heading
|
||||
if not skip:
|
||||
txt = f'*{heading}*\n' + '\n'.join(f'• {i}' for i in items)
|
||||
blocks.append({'type':'section','text':{'type':'mrkdwn','text':txt}})
|
||||
|
||||
blocks.append({'type':'divider'})
|
||||
blocks.append({'type':'section','text':{'type':'mrkdwn',
|
||||
'text':f'\`\`\`uv add \"crewai[tools]=={version}\"\`\`\`'}})
|
||||
|
||||
print(json.dumps({'blocks':blocks}))
|
||||
")
|
||||
echo "payload=$payload" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Notify Slack
|
||||
if: success()
|
||||
uses: slackapi/slack-github-action@v2.1.0
|
||||
with:
|
||||
webhook: ${{ secrets.SLACK_WEBHOOK_URL }}
|
||||
webhook-type: incoming-webhook
|
||||
payload: ${{ steps.slack.outputs.payload }}
|
||||
|
||||
1457
docs/docs.json
1457
docs/docs.json
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,137 @@ description: "Product updates, improvements, and bug fixes for CrewAI"
|
||||
icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
<Update label="Mar 15, 2026">
|
||||
## v1.11.0rc1
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.11.0rc1)
|
||||
|
||||
## What's Changed
|
||||
|
||||
### Features
|
||||
- Add Plus API token authentication in a2a
|
||||
- Implement plan execute pattern
|
||||
|
||||
### Bug Fixes
|
||||
- Resolve code interpreter sandbox escape issue
|
||||
|
||||
### Documentation
|
||||
- Update changelog and version for v1.10.2rc2
|
||||
|
||||
## Contributors
|
||||
|
||||
@Copilot, @greysonlalonde, @lorenzejay, @theCyberTech
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Mar 14, 2026">
|
||||
## v1.10.2rc2
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
|
||||
|
||||
## What's Changed
|
||||
|
||||
### Bug Fixes
|
||||
- Remove exclusive locks from read-only storage operations
|
||||
|
||||
### Documentation
|
||||
- Update changelog and version for v1.10.2rc1
|
||||
|
||||
## Contributors
|
||||
|
||||
@greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Mar 13, 2026">
|
||||
## v1.10.2rc1
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
|
||||
|
||||
## What's Changed
|
||||
|
||||
### Features
|
||||
- Add release command and trigger PyPI publish
|
||||
|
||||
### Bug Fixes
|
||||
- Fix cross-process and thread-safe locking to unprotected I/O
|
||||
- Propagate contextvars across all thread and executor boundaries
|
||||
- Propagate ContextVars into async task threads
|
||||
|
||||
### Documentation
|
||||
- Update changelog and version for v1.10.2a1
|
||||
|
||||
## Contributors
|
||||
|
||||
@danglies007, @greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Mar 11, 2026">
|
||||
## v1.10.2a1
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2a1)
|
||||
|
||||
## What's Changed
|
||||
|
||||
### Features
|
||||
- Add support for tool search, saving tokens, and dynamically injecting appropriate tools during execution for Anthropics.
|
||||
- Introduce more Brave Search tools.
|
||||
- Create action for nightly releases.
|
||||
|
||||
### Bug Fixes
|
||||
- Fix LockException under concurrent multi-process execution.
|
||||
- Resolve issues with grouping parallel tool results in a single user message.
|
||||
- Address MCP tools resolutions and eliminate all shared mutable connections.
|
||||
- Update LLM parameter handling in the human_feedback function.
|
||||
- Add missing list/dict methods to LockedListProxy and LockedDictProxy.
|
||||
- Propagate contextvars context to parallel tool call threads.
|
||||
- Bump gitpython dependency to >=3.1.41 to resolve CVE path traversal vulnerability.
|
||||
|
||||
### Refactoring
|
||||
- Refactor memory classes to be serializable.
|
||||
|
||||
### Documentation
|
||||
- Update changelog and version for v1.10.1.
|
||||
|
||||
## Contributors
|
||||
|
||||
@akaKuruma, @github-actions[bot], @giulio-leone, @greysonlalonde, @joaomdmoura, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Mar 04, 2026">
|
||||
## v1.10.1
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.1)
|
||||
|
||||
## What's Changed
|
||||
|
||||
### Features
|
||||
- Upgrade Gemini GenAI
|
||||
|
||||
### Bug Fixes
|
||||
- Adjust executor listener value to avoid recursion
|
||||
- Group parallel function response parts in a single Content object in Gemini
|
||||
- Surface thought output from thinking models in Gemini
|
||||
- Load MCP and platform tools when agent tools are None
|
||||
- Support Jupyter environments with running event loops in A2A
|
||||
- Use anonymous ID for ephemeral traces
|
||||
- Conditionally pass plus header
|
||||
- Skip signal handler registration in non-main threads for telemetry
|
||||
- Inject tool errors as observations and resolve name collisions
|
||||
- Upgrade pypdf from 4.x to 6.7.4 to resolve Dependabot alerts
|
||||
- Resolve critical and high Dependabot security alerts
|
||||
|
||||
### Documentation
|
||||
- Sync Composio tool documentation across locales
|
||||
|
||||
## Contributors
|
||||
|
||||
@giulio-leone, @greysonlalonde, @haxzie, @joaomdmoura, @lorenzejay, @mattatcha, @mplachta, @nicoferdi96
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Feb 27, 2026">
|
||||
## v1.10.1a1
|
||||
|
||||
|
||||
@@ -219,6 +219,16 @@ CrewAI provides a wide range of events that you can listen for:
|
||||
- **ToolExecutionErrorEvent**: Emitted when a tool execution encounters an error
|
||||
- **ToolSelectionErrorEvent**: Emitted when there's an error selecting a tool
|
||||
|
||||
### MCP Events
|
||||
|
||||
- **MCPConnectionStartedEvent**: Emitted when starting to connect to an MCP server. Contains the server name, URL, transport type, connection timeout, and whether it's a reconnection attempt.
|
||||
- **MCPConnectionCompletedEvent**: Emitted when successfully connected to an MCP server. Contains the server name, connection duration in milliseconds, and whether it was a reconnection.
|
||||
- **MCPConnectionFailedEvent**: Emitted when connection to an MCP server fails. Contains the server name, error message, and error type (`timeout`, `authentication`, `network`, etc.).
|
||||
- **MCPToolExecutionStartedEvent**: Emitted when starting to execute an MCP tool. Contains the server name, tool name, and tool arguments.
|
||||
- **MCPToolExecutionCompletedEvent**: Emitted when MCP tool execution completes successfully. Contains the server name, tool name, result, and execution duration in milliseconds.
|
||||
- **MCPToolExecutionFailedEvent**: Emitted when MCP tool execution fails. Contains the server name, tool name, error message, and error type (`timeout`, `validation`, `server_error`, etc.).
|
||||
- **MCPConfigFetchFailedEvent**: Emitted when fetching an MCP server configuration fails (e.g., the MCP is not connected in your account, API error, or connection failure after config was fetched). Contains the slug, error message, and error type (`not_connected`, `api_error`, `connection_failed`).
|
||||
|
||||
### Knowledge Events
|
||||
|
||||
- **KnowledgeRetrievalStartedEvent**: Emitted when a knowledge retrieval is started
|
||||
|
||||
@@ -62,22 +62,22 @@ Use the `#` syntax to select specific tools from a server:
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key#web_search_exa"
|
||||
```
|
||||
|
||||
### CrewAI AMP Marketplace
|
||||
### Connected MCP Integrations
|
||||
|
||||
Access tools from the CrewAI AMP marketplace:
|
||||
Connect MCP servers from the CrewAI catalog or bring your own. Once connected in your account, reference them by slug:
|
||||
|
||||
```python
|
||||
# Full service with all tools
|
||||
"crewai-amp:financial-data"
|
||||
# Connected MCP with all tools
|
||||
"snowflake"
|
||||
|
||||
# Specific tool from AMP service
|
||||
"crewai-amp:research-tools#pubmed_search"
|
||||
# Specific tool from a connected MCP
|
||||
"stripe#list_invoices"
|
||||
|
||||
# Multiple AMP services
|
||||
# Multiple connected MCPs
|
||||
mcps=[
|
||||
"crewai-amp:weather-insights",
|
||||
"crewai-amp:market-analysis",
|
||||
"crewai-amp:social-media-monitoring"
|
||||
"snowflake",
|
||||
"stripe",
|
||||
"github"
|
||||
]
|
||||
```
|
||||
|
||||
@@ -99,10 +99,10 @@ multi_source_agent = Agent(
|
||||
"https://mcp.exa.ai/mcp?api_key=your_exa_key&profile=research",
|
||||
"https://weather.api.com/mcp#get_current_conditions",
|
||||
|
||||
# CrewAI AMP marketplace
|
||||
"crewai-amp:financial-insights",
|
||||
"crewai-amp:academic-research#pubmed_search",
|
||||
"crewai-amp:market-intelligence#competitor_analysis"
|
||||
# Connected MCPs from catalog
|
||||
"snowflake",
|
||||
"stripe#list_invoices",
|
||||
"github#search_repositories"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -147,7 +147,7 @@ agent = Agent(
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=key", # Tools: mcp_exa_ai_*
|
||||
"https://weather.service.com/mcp", # Tools: weather_service_com_*
|
||||
"crewai-amp:financial-data" # Tools: financial_data_*
|
||||
"snowflake" # Tools: snowflake_*
|
||||
]
|
||||
)
|
||||
|
||||
@@ -170,7 +170,7 @@ agent = Agent(
|
||||
"https://primary-server.com/mcp", # Primary data source
|
||||
"https://backup-server.com/mcp", # Backup if primary fails
|
||||
"https://unreachable-server.com/mcp", # Will be skipped with warning
|
||||
"crewai-amp:reliable-service" # Reliable AMP service
|
||||
"snowflake" # Connected MCP from catalog
|
||||
]
|
||||
)
|
||||
|
||||
@@ -254,7 +254,7 @@ agent = Agent(
|
||||
apps=["gmail", "slack"], # Platform integrations
|
||||
mcps=[ # MCP servers
|
||||
"https://mcp.exa.ai/mcp?api_key=key",
|
||||
"crewai-amp:research-tools"
|
||||
"snowflake"
|
||||
],
|
||||
|
||||
verbose=True,
|
||||
@@ -298,7 +298,7 @@ agent = Agent(
|
||||
mcps=[
|
||||
"https://primary-api.com/mcp", # Primary choice
|
||||
"https://backup-api.com/mcp", # Backup option
|
||||
"crewai-amp:reliable-service" # AMP fallback
|
||||
"snowflake" # Connected MCP fallback
|
||||
]
|
||||
```
|
||||
|
||||
@@ -311,7 +311,7 @@ agent = Agent(
|
||||
backstory="Financial analyst with access to weather data for agricultural market insights",
|
||||
mcps=[
|
||||
"https://weather.service.com/mcp#get_forecast",
|
||||
"crewai-amp:financial-data#stock_analysis"
|
||||
"stripe#list_invoices"
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
@@ -17,7 +17,7 @@ Use the `mcps` field directly on agents for seamless MCP tool integration. The D
|
||||
|
||||
#### String-Based References (Quick Setup)
|
||||
|
||||
Perfect for remote HTTPS servers and CrewAI AMP marketplace:
|
||||
Perfect for remote HTTPS servers and connected MCP integrations from the CrewAI catalog:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
@@ -29,8 +29,8 @@ agent = Agent(
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key", # External MCP server
|
||||
"https://api.weather.com/mcp#get_forecast", # Specific tool from server
|
||||
"crewai-amp:financial-data", # CrewAI AMP marketplace
|
||||
"crewai-amp:research-tools#pubmed_search" # Specific AMP tool
|
||||
"snowflake", # Connected MCP from catalog
|
||||
"stripe#list_invoices" # Specific tool from connected MCP
|
||||
]
|
||||
)
|
||||
# MCP tools are now automatically available to your agent!
|
||||
@@ -127,7 +127,7 @@ research_agent = Agent(
|
||||
backstory="Expert researcher with access to multiple data sources",
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key&profile=your_profile",
|
||||
"crewai-amp:weather-service#current_conditions"
|
||||
"snowflake#run_query"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -204,19 +204,22 @@ mcps=[
|
||||
]
|
||||
```
|
||||
|
||||
#### CrewAI AMP Marketplace
|
||||
#### Connected MCP Integrations
|
||||
|
||||
Connect MCP servers from the CrewAI catalog or bring your own. Once connected in your account, reference them by slug:
|
||||
|
||||
```python
|
||||
mcps=[
|
||||
# Full AMP MCP service - get all available tools
|
||||
"crewai-amp:financial-data",
|
||||
# Connected MCP - get all available tools
|
||||
"snowflake",
|
||||
|
||||
# Specific tool from AMP service using # syntax
|
||||
"crewai-amp:research-tools#pubmed_search",
|
||||
# Specific tool from a connected MCP using # syntax
|
||||
"stripe#list_invoices",
|
||||
|
||||
# Multiple AMP services
|
||||
"crewai-amp:weather-service",
|
||||
"crewai-amp:market-analysis"
|
||||
# Multiple connected MCPs
|
||||
"snowflake",
|
||||
"stripe",
|
||||
"github"
|
||||
]
|
||||
```
|
||||
|
||||
@@ -299,7 +302,7 @@ from crewai.mcp import MCPServerStdio, MCPServerHTTP
|
||||
mcps=[
|
||||
# String references
|
||||
"https://external-api.com/mcp", # External server
|
||||
"crewai-amp:financial-insights", # AMP service
|
||||
"snowflake", # Connected MCP from catalog
|
||||
|
||||
# Structured configurations
|
||||
MCPServerStdio(
|
||||
@@ -409,7 +412,7 @@ agent = Agent(
|
||||
# String references
|
||||
"https://reliable-server.com/mcp", # Will work
|
||||
"https://unreachable-server.com/mcp", # Will be skipped gracefully
|
||||
"crewai-amp:working-service", # Will work
|
||||
"snowflake", # Connected MCP from catalog
|
||||
|
||||
# Structured configs
|
||||
MCPServerStdio(
|
||||
|
||||
@@ -1,97 +1,316 @@
|
||||
---
|
||||
title: Brave Search
|
||||
description: The `BraveSearchTool` is designed to search the internet using the Brave Search API.
|
||||
title: Brave Search Tools
|
||||
description: A suite of tools for querying the Brave Search API — covering web, news, image, and video search.
|
||||
icon: searchengin
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
# `BraveSearchTool`
|
||||
# Brave Search Tools
|
||||
|
||||
## Description
|
||||
|
||||
This tool is designed to perform web searches using the Brave Search API. It allows you to search the internet with a specified query and retrieve relevant results. The tool supports customizable result counts and country-specific searches.
|
||||
CrewAI offers a family of Brave Search tools, each targeting a specific [Brave Search API](https://brave.com/search/api/) endpoint.
|
||||
Rather than a single catch-all tool, you can pick exactly the tool that matches the kind of results your agent needs:
|
||||
|
||||
| Tool | Endpoint | Use case |
|
||||
| --- | --- | --- |
|
||||
| `BraveWebSearchTool` | Web Search | General web results, snippets, and URLs |
|
||||
| `BraveNewsSearchTool` | News Search | Recent news articles and headlines |
|
||||
| `BraveImageSearchTool` | Image Search | Image results with dimensions and source URLs |
|
||||
| `BraveVideoSearchTool` | Video Search | Video results from across the web |
|
||||
| `BraveLocalPOIsTool` | Local POIs | Find points of interest (e.g., restaurants) |
|
||||
| `BraveLocalPOIsDescriptionTool` | Local POIs | Retrieve AI-generated location descriptions |
|
||||
| `BraveLLMContextTool` | LLM Context | Pre-extracted web content optimized for AI agents, LLM grounding, and RAG pipelines. |
|
||||
|
||||
All tools share a common base class (`BraveSearchToolBase`) that provides consistent behavior — rate limiting, automatic retries on `429` responses, header and parameter validation, and optional file saving.
|
||||
|
||||
<Note>
|
||||
The older `BraveSearchTool` class is still available for backwards compatibility, but it is considered **legacy** and will not receive the same level of attention going forward. We recommend migrating to the specific tools listed above, which offer richer configuration and a more focused interface.
|
||||
</Note>
|
||||
|
||||
<Note>
|
||||
While many tools (e.g., _BraveWebSearchTool_, _BraveNewsSearchTool_, _BraveImageSearchTool_, and _BraveVideoSearchTool_) can be used with a free Brave Search API subscription/plan, some parameters (e.g., `enable_snippets`) and tools (e.g., _BraveLocalPOIsTool_ and _BraveLocalPOIsDescriptionTool_) require a paid plan. Consult your subscription plan's capabilities for clarification.
|
||||
</Note>
|
||||
|
||||
## Installation
|
||||
|
||||
To incorporate this tool into your project, follow the installation instructions below:
|
||||
|
||||
```shell
|
||||
pip install 'crewai[tools]'
|
||||
```
|
||||
|
||||
## Steps to Get Started
|
||||
## Getting Started
|
||||
|
||||
To effectively use the `BraveSearchTool`, follow these steps:
|
||||
1. **Install the package** — confirm that `crewai[tools]` is installed in your Python environment.
|
||||
2. **Get an API key** — sign up at [api-dashboard.search.brave.com/login](https://api-dashboard.search.brave.com/login) to generate a key.
|
||||
3. **Set the environment variable** — store your key as `BRAVE_API_KEY`, or pass it directly via the `api_key` parameter.
|
||||
|
||||
1. **Package Installation**: Confirm that the `crewai[tools]` package is installed in your Python environment.
|
||||
2. **API Key Acquisition**: Acquire a Brave Search API key at https://api.search.brave.com/app/keys (sign in to generate a key).
|
||||
3. **Environment Configuration**: Store your obtained API key in an environment variable named `BRAVE_API_KEY` to facilitate its use by the tool.
|
||||
## Quick Examples
|
||||
|
||||
## Example
|
||||
|
||||
The following example demonstrates how to initialize the tool and execute a search with a given query:
|
||||
### Web Search
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveSearchTool
|
||||
from crewai_tools import BraveWebSearchTool
|
||||
|
||||
# Initialize the tool for internet searching capabilities
|
||||
tool = BraveSearchTool()
|
||||
|
||||
# Execute a search
|
||||
results = tool.run(search_query="CrewAI agent framework")
|
||||
tool = BraveWebSearchTool()
|
||||
results = tool.run(q="CrewAI agent framework")
|
||||
print(results)
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
The `BraveSearchTool` accepts the following parameters:
|
||||
|
||||
- **search_query**: Mandatory. The search query you want to use to search the internet.
|
||||
- **country**: Optional. Specify the country for the search results. Default is empty string.
|
||||
- **n_results**: Optional. Number of search results to return. Default is `10`.
|
||||
- **save_file**: Optional. Whether to save the search results to a file. Default is `False`.
|
||||
|
||||
## Example with Parameters
|
||||
|
||||
Here is an example demonstrating how to use the tool with additional parameters:
|
||||
### News Search
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveSearchTool
|
||||
from crewai_tools import BraveNewsSearchTool
|
||||
|
||||
# Initialize the tool with custom parameters
|
||||
tool = BraveSearchTool(
|
||||
country="US",
|
||||
n_results=5,
|
||||
save_file=True
|
||||
tool = BraveNewsSearchTool()
|
||||
results = tool.run(q="latest AI breakthroughs")
|
||||
print(results)
|
||||
```
|
||||
|
||||
### Image Search
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveImageSearchTool
|
||||
|
||||
tool = BraveImageSearchTool()
|
||||
results = tool.run(q="northern lights photography")
|
||||
print(results)
|
||||
```
|
||||
|
||||
### Video Search
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveVideoSearchTool
|
||||
|
||||
tool = BraveVideoSearchTool()
|
||||
results = tool.run(q="how to build AI agents")
|
||||
print(results)
|
||||
```
|
||||
|
||||
### Location POI Descriptions
|
||||
|
||||
```python Code
|
||||
from crewai_tools import (
|
||||
BraveWebSearchTool,
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
)
|
||||
|
||||
# Execute a search
|
||||
results = tool.run(search_query="Latest AI developments")
|
||||
print(results)
|
||||
web_search = BraveWebSearchTool(raw=True)
|
||||
poi_details = BraveLocalPOIsDescriptionTool()
|
||||
|
||||
results = web_search.run(q="italian restaurants in pensacola, florida")
|
||||
|
||||
if "locations" in results:
|
||||
location_ids = [ loc["id"] for loc in results["locations"]["results"] ]
|
||||
if location_ids:
|
||||
descriptions = poi_details.run(ids=location_ids)
|
||||
print(descriptions)
|
||||
```
|
||||
|
||||
## Common Constructor Parameters
|
||||
|
||||
Every Brave Search tool accepts the following parameters at initialization:
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| --- | --- | --- | --- |
|
||||
| `api_key` | `str \| None` | `None` | Brave API key. Falls back to the `BRAVE_API_KEY` environment variable. |
|
||||
| `headers` | `dict \| None` | `None` | Additional HTTP headers to send with every request (e.g., `api-version`, geolocation headers). |
|
||||
| `requests_per_second` | `float` | `1.0` | Maximum request rate. The tool will sleep between calls to stay within this limit. |
|
||||
| `save_file` | `bool` | `False` | When `True`, each response is written to a timestamped `.txt` file. |
|
||||
| `raw` | `bool` | `False` | When `True`, the full API JSON response is returned without any refinement. |
|
||||
| `timeout` | `int` | `30` | HTTP request timeout in seconds. |
|
||||
| `country` | `str \| None` | `None` | Legacy shorthand for geo-targeting (e.g., `"US"`). Prefer using the `country` query parameter directly. |
|
||||
| `n_results` | `int` | `10` | Legacy shorthand for result count. Prefer using the `count` query parameter directly. |
|
||||
|
||||
<Warning>
|
||||
The `country` and `n_results` constructor parameters exist for backwards compatibility. They are applied as defaults when the corresponding query parameters (`country`, `count`) are not provided at call time. For new code, we recommend passing `country` and `count` directly as query parameters instead.
|
||||
</Warning>
|
||||
|
||||
## Query Parameters
|
||||
|
||||
Each tool validates its query parameters against a Pydantic schema before sending the request.
|
||||
The parameters vary slightly per endpoint — here is a summary of the most commonly used ones:
|
||||
|
||||
### BraveWebSearchTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `q` | **(required)** Search query string (max 400 chars). |
|
||||
| `country` | Two-letter country code for geo-targeting (e.g., `"US"`). |
|
||||
| `search_lang` | Two-letter language code for results (e.g., `"en"`). |
|
||||
| `count` | Max number of results to return (1–20). |
|
||||
| `offset` | Skip the first N pages of results (0–9). |
|
||||
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
|
||||
| `freshness` | Recency filter: `"pd"` (past day), `"pw"` (past week), `"pm"` (past month), `"py"` (past year), or a date range like `"2025-01-01to2025-06-01"`. |
|
||||
| `extra_snippets` | Include up to 5 additional text snippets per result. |
|
||||
| `goggles` | Brave Goggles URL(s) and/or source for custom re-ranking. |
|
||||
|
||||
For the complete parameter and header reference, see the [Brave Web Search API documentation](https://api-dashboard.search.brave.com/api-reference/web/search/get).
|
||||
|
||||
### BraveNewsSearchTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `q` | **(required)** Search query string (max 400 chars). |
|
||||
| `country` | Two-letter country code for geo-targeting. |
|
||||
| `search_lang` | Two-letter language code for results. |
|
||||
| `count` | Max number of results to return (1–50). |
|
||||
| `offset` | Skip the first N pages of results (0–9). |
|
||||
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
|
||||
| `freshness` | Recency filter (same options as Web Search). |
|
||||
| `goggles` | Brave Goggles URL(s) and/or source for custom re-ranking. |
|
||||
|
||||
For the complete parameter and header reference, see the [Brave News Search API documentation](https://api-dashboard.search.brave.com/api-reference/news/news_search/get).
|
||||
|
||||
### BraveImageSearchTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `q` | **(required)** Search query string (max 400 chars). |
|
||||
| `country` | Two-letter country code for geo-targeting. |
|
||||
| `search_lang` | Two-letter language code for results. |
|
||||
| `count` | Max number of results to return (1–200). |
|
||||
| `safesearch` | Content filter: `"off"` or `"strict"`. |
|
||||
| `spellcheck` | Attempt to correct spelling errors in the query. |
|
||||
|
||||
For the complete parameter and header reference, see the [Brave Image Search API documentation](https://api-dashboard.search.brave.com/api-reference/images/image_search).
|
||||
|
||||
### BraveVideoSearchTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `q` | **(required)** Search query string (max 400 chars). |
|
||||
| `country` | Two-letter country code for geo-targeting. |
|
||||
| `search_lang` | Two-letter language code for results. |
|
||||
| `count` | Max number of results to return (1–50). |
|
||||
| `offset` | Skip the first N pages of results (0–9). |
|
||||
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
|
||||
| `freshness` | Recency filter (same options as Web Search). |
|
||||
|
||||
For the complete parameter and header reference, see the [Brave Video Search API documentation](https://api-dashboard.search.brave.com/api-reference/videos/video_search/get).
|
||||
|
||||
### BraveLocalPOIsTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `ids` | **(required)** A list of unique identifiers for the desired locations. |
|
||||
| `search_lang` | Two-letter language code for results. |
|
||||
|
||||
For the complete parameter and header reference, see [Brave Local POIs API documentation](https://api-dashboard.search.brave.com/api-reference/web/local_pois).
|
||||
|
||||
### BraveLocalPOIsDescriptionTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `ids` | **(required)** A list of unique identifiers for the desired locations. |
|
||||
|
||||
For the complete parameter and header reference, see [Brave POI Descriptions API documentation](https://api-dashboard.search.brave.com/api-reference/web/poi_descriptions).
|
||||
|
||||
## Custom Headers
|
||||
|
||||
All tools support custom HTTP request headers. The Web Search tool, for example, accepts geolocation headers for location-aware results:
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveWebSearchTool
|
||||
|
||||
tool = BraveWebSearchTool(
|
||||
headers={
|
||||
"x-loc-lat": "37.7749",
|
||||
"x-loc-long": "-122.4194",
|
||||
"x-loc-city": "San Francisco",
|
||||
"x-loc-state": "CA",
|
||||
"x-loc-country": "US",
|
||||
}
|
||||
)
|
||||
|
||||
results = tool.run(q="best coffee shops nearby")
|
||||
```
|
||||
|
||||
You can also update headers after initialization using the `set_headers()` method:
|
||||
|
||||
```python Code
|
||||
tool.set_headers({"api-version": "2025-01-01"})
|
||||
```
|
||||
|
||||
## Raw Mode
|
||||
|
||||
By default, each tool refines the API response into a concise list of results. If you need the full, unprocessed API response, enable raw mode:
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveWebSearchTool
|
||||
|
||||
tool = BraveWebSearchTool(raw=True)
|
||||
full_response = tool.run(q="Brave Search API")
|
||||
```
|
||||
|
||||
## Agent Integration Example
|
||||
|
||||
Here's how to integrate the `BraveSearchTool` with a CrewAI agent:
|
||||
Here's how to equip a CrewAI agent with multiple Brave Search tools:
|
||||
|
||||
```python Code
|
||||
from crewai import Agent
|
||||
from crewai.project import agent
|
||||
from crewai_tools import BraveSearchTool
|
||||
from crewai_tools import BraveWebSearchTool, BraveNewsSearchTool
|
||||
|
||||
# Initialize the tool
|
||||
brave_search_tool = BraveSearchTool()
|
||||
web_search = BraveWebSearchTool()
|
||||
news_search = BraveNewsSearchTool()
|
||||
|
||||
# Define an agent with the BraveSearchTool
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config["researcher"],
|
||||
allow_delegation=False,
|
||||
tools=[brave_search_tool]
|
||||
tools=[web_search, news_search],
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced Example
|
||||
|
||||
Combining multiple parameters for a targeted search:
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveWebSearchTool
|
||||
|
||||
tool = BraveWebSearchTool(
|
||||
requests_per_second=0.5, # conservative rate limit
|
||||
save_file=True,
|
||||
)
|
||||
|
||||
results = tool.run(
|
||||
q="artificial intelligence news",
|
||||
country="US",
|
||||
search_lang="en",
|
||||
count=5,
|
||||
freshness="pm", # past month only
|
||||
extra_snippets=True,
|
||||
)
|
||||
print(results)
|
||||
```
|
||||
|
||||
## Migrating from `BraveSearchTool` (Legacy)
|
||||
|
||||
If you are currently using `BraveSearchTool`, switching to the new tools is straightforward:
|
||||
|
||||
```python Code
|
||||
# Before (legacy)
|
||||
from crewai_tools import BraveSearchTool
|
||||
|
||||
tool = BraveSearchTool(country="US", n_results=5, save_file=True)
|
||||
results = tool.run(search_query="AI agents")
|
||||
|
||||
# After (recommended)
|
||||
from crewai_tools import BraveWebSearchTool
|
||||
|
||||
tool = BraveWebSearchTool(save_file=True)
|
||||
results = tool.run(q="AI agents", country="US", count=5)
|
||||
```
|
||||
|
||||
Key differences:
|
||||
- **Import**: Use `BraveWebSearchTool` (or the news/image/video variant) instead of `BraveSearchTool`.
|
||||
- **Query parameter**: Use `q` instead of `search_query`. (Both `search_query` and `query` are still accepted for convenience, but `q` is the preferred parameter.)
|
||||
- **Result count**: Pass `count` as a query parameter instead of `n_results` at init time.
|
||||
- **Country**: Pass `country` as a query parameter instead of at init time.
|
||||
- **API key**: Can now be passed directly via `api_key=` in addition to the `BRAVE_API_KEY` environment variable.
|
||||
- **Rate limiting**: Configurable via `requests_per_second` with automatic retry on `429` responses.
|
||||
|
||||
## Conclusion
|
||||
|
||||
By integrating the `BraveSearchTool` into Python projects, users gain the ability to conduct real-time, relevant searches across the internet directly from their applications. The tool provides a simple interface to the powerful Brave Search API, making it easy to retrieve and process search results programmatically. By adhering to the setup and usage guidelines provided, incorporating this tool into projects is streamlined and straightforward.
|
||||
The Brave Search tool suite gives your CrewAI agents flexible, endpoint-specific access to the Brave Search API. Whether you need web pages, breaking news, images, or videos, there is a dedicated tool with validated parameters and built-in resilience. Pick the tool that fits your use case, and refer to the [Brave Search API documentation](https://brave.com/search/api/) for the full details on available parameters and response formats.
|
||||
|
||||
@@ -4,6 +4,137 @@ description: "CrewAI의 제품 업데이트, 개선 사항 및 버그 수정"
|
||||
icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
<Update label="2026년 3월 15일">
|
||||
## v1.11.0rc1
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.11.0rc1)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
### 기능
|
||||
- Plus API 토큰 인증 추가
|
||||
- 에서 계획 실행 패턴 구현
|
||||
|
||||
### 버그 수정
|
||||
- 코드 인터프리터 샌드박스 탈출 문제 해결
|
||||
|
||||
### 문서
|
||||
- v1.10.2rc2의 변경 로그 및 버전 업데이트
|
||||
|
||||
## 기여자
|
||||
|
||||
@Copilot, @greysonlalonde, @lorenzejay, @theCyberTech
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2026년 3월 14일">
|
||||
## v1.10.2rc2
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
### 버그 수정
|
||||
- 읽기 전용 스토리지 작업에서 독점 잠금 제거
|
||||
|
||||
### 문서
|
||||
- v1.10.2rc1에 대한 변경 로그 및 버전 업데이트
|
||||
|
||||
## 기여자
|
||||
|
||||
@greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2026년 3월 13일">
|
||||
## v1.10.2rc1
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
### 기능
|
||||
- 릴리스 명령 추가 및 PyPI 게시 트리거
|
||||
|
||||
### 버그 수정
|
||||
- 보호되지 않은 I/O에 대한 프로세스 간 및 스레드 안전 잠금 수정
|
||||
- 모든 스레드 및 실행기 경계를 넘는 contextvars 전파
|
||||
- async 작업 스레드로 ContextVars 전파
|
||||
|
||||
### 문서
|
||||
- v1.10.2a1에 대한 변경 로그 및 버전 업데이트
|
||||
|
||||
## 기여자
|
||||
|
||||
@danglies007, @greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2026년 3월 11일">
|
||||
## v1.10.2a1
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2a1)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
### 기능
|
||||
- Anthropics에 대한 도구 검색 지원 추가, 토큰 저장, 실행 중 적절한 도구를 동적으로 주입하는 기능 추가.
|
||||
- 더 많은 Brave Search 도구 도입.
|
||||
- 야간 릴리스를 위한 액션 생성.
|
||||
|
||||
### 버그 수정
|
||||
- 동시 다중 프로세스 실행 중 LockException 수정.
|
||||
- 단일 사용자 메시지에서 병렬 도구 결과 그룹화 문제 해결.
|
||||
- MCP 도구 해상도 문제 해결 및 모든 공유 가변 연결 제거.
|
||||
- human_feedback 함수에서 LLM 매개변수 처리 업데이트.
|
||||
- LockedListProxy 및 LockedDictProxy에 누락된 list/dict 메서드 추가.
|
||||
- 병렬 도구 호출 스레드에 contextvars 컨텍스트 전파.
|
||||
- CVE 경로 탐색 취약점을 해결하기 위해 gitpython 의존성을 >=3.1.41로 업데이트.
|
||||
|
||||
### 리팩토링
|
||||
- 메모리 클래스를 직렬화 가능하도록 리팩토링.
|
||||
|
||||
### 문서
|
||||
- v1.10.1에 대한 변경 로그 및 버전 업데이트.
|
||||
|
||||
## 기여자
|
||||
|
||||
@akaKuruma, @github-actions[bot], @giulio-leone, @greysonlalonde, @joaomdmoura, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2026년 3월 4일">
|
||||
## v1.10.1
|
||||
|
||||
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.1)
|
||||
|
||||
## 변경 사항
|
||||
|
||||
### 기능
|
||||
- Gemini GenAI 업그레이드
|
||||
|
||||
### 버그 수정
|
||||
- 재귀를 피하기 위해 실행기 리스너 값을 조정
|
||||
- Gemini에서 병렬 함수 응답 부분을 단일 Content 객체로 그룹화
|
||||
- Gemini에서 사고 모델의 사고 출력을 표시
|
||||
- 에이전트 도구가 None일 때 MCP 및 플랫폼 도구 로드
|
||||
- A2A에서 실행 이벤트 루프가 있는 Jupyter 환경 지원
|
||||
- 일시적인 추적을 위해 익명 ID 사용
|
||||
- 조건부로 플러스 헤더 전달
|
||||
- 원격 측정을 위해 비주 스레드에서 신호 처리기 등록 건너뛰기
|
||||
- 도구 오류를 관찰로 주입하고 이름 충돌 해결
|
||||
- Dependabot 경고를 해결하기 위해 pypdf를 4.x에서 6.7.4로 업그레이드
|
||||
- 심각 및 높은 Dependabot 보안 경고 해결
|
||||
|
||||
### 문서
|
||||
- Composio 도구 문서를 지역별로 동기화
|
||||
|
||||
## 기여자
|
||||
|
||||
@giulio-leone, @greysonlalonde, @haxzie, @joaomdmoura, @lorenzejay, @mattatcha, @mplachta, @nicoferdi96
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="2026년 2월 27일">
|
||||
## v1.10.1a1
|
||||
|
||||
|
||||
@@ -62,22 +62,22 @@ agent = Agent(
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key#web_search_exa"
|
||||
```
|
||||
|
||||
### CrewAI AMP 마켓플레이스
|
||||
### 연결된 MCP 통합
|
||||
|
||||
CrewAI AMP 마켓플레이스의 도구에 액세스하세요:
|
||||
CrewAI 카탈로그에서 MCP 서버를 연결하거나 직접 가져올 수 있습니다. 계정에 연결한 후 슬러그로 참조하세요:
|
||||
|
||||
```python
|
||||
# 모든 도구가 포함된 전체 서비스
|
||||
"crewai-amp:financial-data"
|
||||
# 모든 도구가 포함된 연결된 MCP
|
||||
"snowflake"
|
||||
|
||||
# AMP 서비스의 특정 도구
|
||||
"crewai-amp:research-tools#pubmed_search"
|
||||
# 연결된 MCP의 특정 도구
|
||||
"stripe#list_invoices"
|
||||
|
||||
# 다중 AMP 서비스
|
||||
# 여러 연결된 MCP
|
||||
mcps=[
|
||||
"crewai-amp:weather-insights",
|
||||
"crewai-amp:market-analysis",
|
||||
"crewai-amp:social-media-monitoring"
|
||||
"snowflake",
|
||||
"stripe",
|
||||
"github"
|
||||
]
|
||||
```
|
||||
|
||||
@@ -99,10 +99,10 @@ multi_source_agent = Agent(
|
||||
"https://mcp.exa.ai/mcp?api_key=your_exa_key&profile=research",
|
||||
"https://weather.api.com/mcp#get_current_conditions",
|
||||
|
||||
# CrewAI AMP 마켓플레이스
|
||||
"crewai-amp:financial-insights",
|
||||
"crewai-amp:academic-research#pubmed_search",
|
||||
"crewai-amp:market-intelligence#competitor_analysis"
|
||||
# 카탈로그에서 연결된 MCP
|
||||
"snowflake",
|
||||
"stripe#list_invoices",
|
||||
"github#search_repositories"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -154,7 +154,7 @@ agent = Agent(
|
||||
"https://reliable-server.com/mcp", # 작동할 것
|
||||
"https://unreachable-server.com/mcp", # 우아하게 건너뛸 것
|
||||
"https://slow-server.com/mcp", # 우아하게 타임아웃될 것
|
||||
"crewai-amp:working-service" # 작동할 것
|
||||
"snowflake" # 카탈로그에서 연결된 MCP
|
||||
]
|
||||
)
|
||||
# 에이전트는 작동하는 서버의 도구를 사용하고 실패한 서버에 대한 경고를 로그에 남깁니다
|
||||
@@ -229,6 +229,6 @@ agent = Agent(
|
||||
mcps=[
|
||||
"https://primary-api.com/mcp", # 주요 선택
|
||||
"https://backup-api.com/mcp", # 백업 옵션
|
||||
"crewai-amp:reliable-service" # AMP 폴백
|
||||
"snowflake" # 연결된 MCP 폴백
|
||||
]
|
||||
```
|
||||
|
||||
@@ -25,8 +25,8 @@ agent = Agent(
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key", # 외부 MCP 서버
|
||||
"https://api.weather.com/mcp#get_forecast", # 서버의 특정 도구
|
||||
"crewai-amp:financial-data", # CrewAI AMP 마켓플레이스
|
||||
"crewai-amp:research-tools#pubmed_search" # 특정 AMP 도구
|
||||
"snowflake", # 카탈로그에서 연결된 MCP
|
||||
"stripe#list_invoices" # 연결된 MCP의 특정 도구
|
||||
]
|
||||
)
|
||||
# MCP 도구들이 이제 자동으로 에이전트에서 사용 가능합니다!
|
||||
|
||||
@@ -4,6 +4,137 @@ description: "Atualizações de produto, melhorias e correções do CrewAI"
|
||||
icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
<Update label="15 mar 2026">
|
||||
## v1.11.0rc1
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.11.0rc1)
|
||||
|
||||
## O que Mudou
|
||||
|
||||
### Funcionalidades
|
||||
- Adicionar autenticação de token da API Plus
|
||||
- Implementar padrão de execução de plano
|
||||
|
||||
### Correções de Bugs
|
||||
- Resolver problema de escape do sandbox do interpretador de código
|
||||
|
||||
### Documentação
|
||||
- Atualizar changelog e versão para v1.10.2rc2
|
||||
|
||||
## Contribuidores
|
||||
|
||||
@Copilot, @greysonlalonde, @lorenzejay, @theCyberTech
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="14 mar 2026">
|
||||
## v1.10.2rc2
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc2)
|
||||
|
||||
## O que Mudou
|
||||
|
||||
### Correções de Bugs
|
||||
- Remover bloqueios exclusivos de operações de armazenamento somente leitura
|
||||
|
||||
### Documentação
|
||||
- Atualizar changelog e versão para v1.10.2rc1
|
||||
|
||||
## Contribuidores
|
||||
|
||||
@greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="13 mar 2026">
|
||||
## v1.10.2rc1
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2rc1)
|
||||
|
||||
## O que Mudou
|
||||
|
||||
### Funcionalidades
|
||||
- Adicionar comando de lançamento e acionar publicação no PyPI
|
||||
|
||||
### Correções de Bugs
|
||||
- Corrigir bloqueio seguro entre processos e threads para I/O não protegido
|
||||
- Propagar contextvars através de todos os limites de thread e executor
|
||||
- Propagar ContextVars para threads de tarefas assíncronas
|
||||
|
||||
### Documentação
|
||||
- Atualizar changelog e versão para v1.10.2a1
|
||||
|
||||
## Contribuidores
|
||||
|
||||
@danglies007, @greysonlalonde
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="11 mar 2026">
|
||||
## v1.10.2a1
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.2a1)
|
||||
|
||||
## O que mudou
|
||||
|
||||
### Recursos
|
||||
- Adicionar suporte para busca de ferramentas, salvamento de tokens e injeção dinâmica de ferramentas apropriadas durante a execução para Anthropics.
|
||||
- Introduzir mais ferramentas de Busca Brave.
|
||||
- Criar ação para lançamentos noturnos.
|
||||
|
||||
### Correções de Bugs
|
||||
- Corrigir LockException durante a execução concorrente de múltiplos processos.
|
||||
- Resolver problemas com a agrupação de resultados de ferramentas paralelas em uma única mensagem de usuário.
|
||||
- Abordar resoluções de ferramentas MCP e eliminar todas as conexões mutáveis compartilhadas.
|
||||
- Atualizar o manuseio de parâmetros LLM na função human_feedback.
|
||||
- Adicionar métodos de lista/dicionário ausentes a LockedListProxy e LockedDictProxy.
|
||||
- Propagar o contexto de contextvars para as threads de chamada de ferramentas paralelas.
|
||||
- Atualizar a dependência gitpython para >=3.1.41 para resolver a vulnerabilidade de travessia de diretórios CVE.
|
||||
|
||||
### Refatoração
|
||||
- Refatorar classes de memória para serem serializáveis.
|
||||
|
||||
### Documentação
|
||||
- Atualizar o changelog e a versão para v1.10.1.
|
||||
|
||||
## Contribuidores
|
||||
|
||||
@akaKuruma, @github-actions[bot], @giulio-leone, @greysonlalonde, @joaomdmoura, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="04 mar 2026">
|
||||
## v1.10.1
|
||||
|
||||
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.1)
|
||||
|
||||
## O que mudou
|
||||
|
||||
### Recursos
|
||||
- Atualizar Gemini GenAI
|
||||
|
||||
### Correções de Bugs
|
||||
- Ajustar o valor do listener do executor para evitar recursão
|
||||
- Agrupar partes da resposta da função paralela em um único objeto Content no Gemini
|
||||
- Exibir a saída de pensamento dos modelos de pensamento no Gemini
|
||||
- Carregar ferramentas MCP e da plataforma quando as ferramentas do agente forem None
|
||||
- Suportar ambientes Jupyter com loops de eventos em A2A
|
||||
- Usar ID anônimo para rastreamentos efêmeros
|
||||
- Passar condicionalmente o cabeçalho plus
|
||||
- Ignorar o registro do manipulador de sinal em threads não principais para telemetria
|
||||
- Injetar erros de ferramentas como observações e resolver colisões de nomes
|
||||
- Atualizar pypdf de 4.x para 6.7.4 para resolver alertas do Dependabot
|
||||
- Resolver alertas de segurança críticos e altos do Dependabot
|
||||
|
||||
### Documentação
|
||||
- Sincronizar a documentação da ferramenta Composio entre locais
|
||||
|
||||
## Contribuidores
|
||||
|
||||
@giulio-leone, @greysonlalonde, @haxzie, @joaomdmoura, @lorenzejay, @mattatcha, @mplachta, @nicoferdi96
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="27 fev 2026">
|
||||
## v1.10.1a1
|
||||
|
||||
|
||||
@@ -62,22 +62,22 @@ Use a sintaxe `#` para selecionar ferramentas específicas de um servidor:
|
||||
"https://mcp.exa.ai/mcp?api_key=sua_chave#web_search_exa"
|
||||
```
|
||||
|
||||
### Marketplace CrewAI AMP
|
||||
### Integrações MCP Conectadas
|
||||
|
||||
Acesse ferramentas do marketplace CrewAI AMP:
|
||||
Conecte servidores MCP do catálogo CrewAI ou traga os seus próprios. Uma vez conectados em sua conta, referencie-os pelo slug:
|
||||
|
||||
```python
|
||||
# Serviço completo com todas as ferramentas
|
||||
"crewai-amp:financial-data"
|
||||
# MCP conectado com todas as ferramentas
|
||||
"snowflake"
|
||||
|
||||
# Ferramenta específica do serviço AMP
|
||||
"crewai-amp:research-tools#pubmed_search"
|
||||
# Ferramenta específica de um MCP conectado
|
||||
"stripe#list_invoices"
|
||||
|
||||
# Múltiplos serviços AMP
|
||||
# Múltiplos MCPs conectados
|
||||
mcps=[
|
||||
"crewai-amp:weather-insights",
|
||||
"crewai-amp:market-analysis",
|
||||
"crewai-amp:social-media-monitoring"
|
||||
"snowflake",
|
||||
"stripe",
|
||||
"github"
|
||||
]
|
||||
```
|
||||
|
||||
@@ -99,10 +99,10 @@ agente_multi_fonte = Agent(
|
||||
"https://mcp.exa.ai/mcp?api_key=sua_chave_exa&profile=pesquisa",
|
||||
"https://weather.api.com/mcp#get_current_conditions",
|
||||
|
||||
# Marketplace CrewAI AMP
|
||||
"crewai-amp:financial-insights",
|
||||
"crewai-amp:academic-research#pubmed_search",
|
||||
"crewai-amp:market-intelligence#competitor_analysis"
|
||||
# MCPs conectados do catálogo
|
||||
"snowflake",
|
||||
"stripe#list_invoices",
|
||||
"github#search_repositories"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -154,7 +154,7 @@ agente = Agent(
|
||||
"https://servidor-confiavel.com/mcp", # Vai funcionar
|
||||
"https://servidor-inalcancavel.com/mcp", # Será ignorado graciosamente
|
||||
"https://servidor-lento.com/mcp", # Timeout gracioso
|
||||
"crewai-amp:servico-funcionando" # Vai funcionar
|
||||
"snowflake" # MCP conectado do catálogo
|
||||
]
|
||||
)
|
||||
# O agente usará ferramentas de servidores funcionais e registrará avisos para os que falharem
|
||||
@@ -229,6 +229,6 @@ agente = Agent(
|
||||
mcps=[
|
||||
"https://api-principal.com/mcp", # Escolha principal
|
||||
"https://api-backup.com/mcp", # Opção de backup
|
||||
"crewai-amp:servico-confiavel" # Fallback AMP
|
||||
"snowflake" # Fallback MCP conectado
|
||||
]
|
||||
```
|
||||
|
||||
@@ -25,8 +25,8 @@ agent = Agent(
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=sua_chave", # Servidor MCP externo
|
||||
"https://api.weather.com/mcp#get_forecast", # Ferramenta específica do servidor
|
||||
"crewai-amp:financial-data", # Marketplace CrewAI AMP
|
||||
"crewai-amp:research-tools#pubmed_search" # Ferramenta AMP específica
|
||||
"snowflake", # MCP conectado do catálogo
|
||||
"stripe#list_invoices" # Ferramenta específica de MCP conectado
|
||||
]
|
||||
)
|
||||
# Ferramentas MCP agora estão automaticamente disponíveis para seu agente!
|
||||
|
||||
@@ -152,4 +152,4 @@ __all__ = [
|
||||
"wrap_file_source",
|
||||
]
|
||||
|
||||
__version__ = "1.10.1"
|
||||
__version__ = "1.11.0rc1"
|
||||
|
||||
@@ -11,7 +11,7 @@ dependencies = [
|
||||
"pytube~=15.0.0",
|
||||
"requests~=2.32.5",
|
||||
"docker~=7.1.0",
|
||||
"crewai==1.10.1",
|
||||
"crewai==1.11.0rc1",
|
||||
"tiktoken~=0.8.0",
|
||||
"beautifulsoup4~=4.13.4",
|
||||
"python-docx~=1.2.0",
|
||||
@@ -108,7 +108,7 @@ stagehand = [
|
||||
"stagehand>=0.4.1",
|
||||
]
|
||||
github = [
|
||||
"gitpython==3.1.38",
|
||||
"gitpython>=3.1.41,<4",
|
||||
"PyGithub==1.59.1",
|
||||
]
|
||||
rag = [
|
||||
|
||||
@@ -10,7 +10,18 @@ from crewai_tools.aws.s3.writer_tool import S3WriterTool
|
||||
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
|
||||
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
|
||||
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
|
||||
BraveLLMContextTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
BraveLocalPOIsTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
|
||||
from crewai_tools.tools.brightdata_tool.brightdata_dataset import (
|
||||
BrightDataDatasetTool,
|
||||
)
|
||||
@@ -200,7 +211,14 @@ __all__ = [
|
||||
"ArxivPaperTool",
|
||||
"BedrockInvokeAgentTool",
|
||||
"BedrockKBRetrieverTool",
|
||||
"BraveImageSearchTool",
|
||||
"BraveLLMContextTool",
|
||||
"BraveLocalPOIsDescriptionTool",
|
||||
"BraveLocalPOIsTool",
|
||||
"BraveNewsSearchTool",
|
||||
"BraveSearchTool",
|
||||
"BraveVideoSearchTool",
|
||||
"BraveWebSearchTool",
|
||||
"BrightDataDatasetTool",
|
||||
"BrightDataSearchTool",
|
||||
"BrightDataWebUnlockerTool",
|
||||
@@ -291,4 +309,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.10.1"
|
||||
__version__ = "1.11.0rc1"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from collections.abc import Callable
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from lancedb import ( # type: ignore[import-untyped]
|
||||
DBConnection as LanceDBConnection,
|
||||
connect as lancedb_connect,
|
||||
@@ -33,10 +35,12 @@ class LanceDBAdapter(Adapter):
|
||||
|
||||
_db: LanceDBConnection = PrivateAttr()
|
||||
_table: LanceDBTable = PrivateAttr()
|
||||
_lock_name: str = PrivateAttr(default="")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._db = lancedb_connect(self.uri)
|
||||
self._table = self._db.open_table(self.table_name)
|
||||
self._lock_name = f"lancedb:{os.path.realpath(str(self.uri))}"
|
||||
|
||||
super().model_post_init(__context)
|
||||
|
||||
@@ -56,4 +60,5 @@ class LanceDBAdapter(Adapter):
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._table.add(*args, **kwargs)
|
||||
with store_lock(self._lock_name):
|
||||
self._table.add(*args, **kwargs)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -18,6 +21,9 @@ class BrowserSessionManager:
|
||||
This class maintains separate browser sessions for different threads,
|
||||
enabling concurrent usage of browsers in multi-threaded environments.
|
||||
Browsers are created lazily only when needed by tools.
|
||||
|
||||
Uses per-key events to serialize creation for the same thread_id without
|
||||
blocking unrelated callers or wasting resources on duplicate sessions.
|
||||
"""
|
||||
|
||||
def __init__(self, region: str = "us-west-2"):
|
||||
@@ -27,8 +33,10 @@ class BrowserSessionManager:
|
||||
region: AWS region for browser client
|
||||
"""
|
||||
self.region = region
|
||||
self._lock = threading.Lock()
|
||||
self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {}
|
||||
self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {}
|
||||
self._creating: dict[str, threading.Event] = {}
|
||||
|
||||
async def get_async_browser(self, thread_id: str) -> AsyncBrowser:
|
||||
"""Get or create an async browser for the specified thread.
|
||||
@@ -39,10 +47,29 @@ class BrowserSessionManager:
|
||||
Returns:
|
||||
An async browser instance specific to the thread
|
||||
"""
|
||||
if thread_id in self._async_sessions:
|
||||
return self._async_sessions[thread_id][1]
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
with self._lock:
|
||||
if thread_id in self._async_sessions:
|
||||
return self._async_sessions[thread_id][1]
|
||||
if thread_id not in self._creating:
|
||||
self._creating[thread_id] = threading.Event()
|
||||
break
|
||||
event = self._creating[thread_id]
|
||||
ctx = contextvars.copy_context()
|
||||
await loop.run_in_executor(None, ctx.run, event.wait)
|
||||
|
||||
return await self._create_async_browser_session(thread_id)
|
||||
try:
|
||||
browser_client, browser = await self._create_async_browser_session(
|
||||
thread_id
|
||||
)
|
||||
with self._lock:
|
||||
self._async_sessions[thread_id] = (browser_client, browser)
|
||||
return browser
|
||||
finally:
|
||||
with self._lock:
|
||||
evt = self._creating.pop(thread_id)
|
||||
evt.set()
|
||||
|
||||
def get_sync_browser(self, thread_id: str) -> SyncBrowser:
|
||||
"""Get or create a sync browser for the specified thread.
|
||||
@@ -53,19 +80,33 @@ class BrowserSessionManager:
|
||||
Returns:
|
||||
A sync browser instance specific to the thread
|
||||
"""
|
||||
if thread_id in self._sync_sessions:
|
||||
return self._sync_sessions[thread_id][1]
|
||||
while True:
|
||||
with self._lock:
|
||||
if thread_id in self._sync_sessions:
|
||||
return self._sync_sessions[thread_id][1]
|
||||
if thread_id not in self._creating:
|
||||
self._creating[thread_id] = threading.Event()
|
||||
break
|
||||
event = self._creating[thread_id]
|
||||
event.wait()
|
||||
|
||||
return self._create_sync_browser_session(thread_id)
|
||||
try:
|
||||
return self._create_sync_browser_session(thread_id)
|
||||
finally:
|
||||
with self._lock:
|
||||
evt = self._creating.pop(thread_id)
|
||||
evt.set()
|
||||
|
||||
async def _create_async_browser_session(self, thread_id: str) -> AsyncBrowser:
|
||||
async def _create_async_browser_session(
|
||||
self, thread_id: str
|
||||
) -> tuple[BrowserClient, AsyncBrowser]:
|
||||
"""Create a new async browser session for the specified thread.
|
||||
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
|
||||
Returns:
|
||||
The newly created async browser instance
|
||||
Tuple of (BrowserClient, AsyncBrowser).
|
||||
|
||||
Raises:
|
||||
Exception: If browser session creation fails
|
||||
@@ -75,10 +116,8 @@ class BrowserSessionManager:
|
||||
browser_client = BrowserClient(region=self.region)
|
||||
|
||||
try:
|
||||
# Start browser session
|
||||
browser_client.start()
|
||||
|
||||
# Get WebSocket connection info
|
||||
ws_url, headers = browser_client.generate_ws_headers()
|
||||
|
||||
logger.info(
|
||||
@@ -87,7 +126,6 @@ class BrowserSessionManager:
|
||||
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
# Connect to browser using Playwright
|
||||
playwright = await async_playwright().start()
|
||||
browser = await playwright.chromium.connect_over_cdp(
|
||||
endpoint_url=ws_url, headers=headers, timeout=30000
|
||||
@@ -96,17 +134,13 @@ class BrowserSessionManager:
|
||||
f"Successfully connected to async browser for thread {thread_id}"
|
||||
)
|
||||
|
||||
# Store session resources
|
||||
self._async_sessions[thread_id] = (browser_client, browser)
|
||||
|
||||
return browser
|
||||
return browser_client, browser
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create async browser session for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Clean up resources if session creation fails
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -132,10 +166,8 @@ class BrowserSessionManager:
|
||||
browser_client = BrowserClient(region=self.region)
|
||||
|
||||
try:
|
||||
# Start browser session
|
||||
browser_client.start()
|
||||
|
||||
# Get WebSocket connection info
|
||||
ws_url, headers = browser_client.generate_ws_headers()
|
||||
|
||||
logger.info(
|
||||
@@ -144,7 +176,6 @@ class BrowserSessionManager:
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
# Connect to browser using Playwright
|
||||
playwright = sync_playwright().start()
|
||||
browser = playwright.chromium.connect_over_cdp(
|
||||
endpoint_url=ws_url, headers=headers, timeout=30000
|
||||
@@ -153,8 +184,8 @@ class BrowserSessionManager:
|
||||
f"Successfully connected to sync browser for thread {thread_id}"
|
||||
)
|
||||
|
||||
# Store session resources
|
||||
self._sync_sessions[thread_id] = (browser_client, browser)
|
||||
with self._lock:
|
||||
self._sync_sessions[thread_id] = (browser_client, browser)
|
||||
|
||||
return browser
|
||||
|
||||
@@ -163,7 +194,6 @@ class BrowserSessionManager:
|
||||
f"Failed to create sync browser session for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Clean up resources if session creation fails
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -178,13 +208,13 @@ class BrowserSessionManager:
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
"""
|
||||
if thread_id not in self._async_sessions:
|
||||
logger.warning(f"No async browser session found for thread {thread_id}")
|
||||
return
|
||||
with self._lock:
|
||||
if thread_id not in self._async_sessions:
|
||||
logger.warning(f"No async browser session found for thread {thread_id}")
|
||||
return
|
||||
|
||||
browser_client, browser = self._async_sessions[thread_id]
|
||||
browser_client, browser = self._async_sessions.pop(thread_id)
|
||||
|
||||
# Close browser
|
||||
if browser:
|
||||
try:
|
||||
await browser.close()
|
||||
@@ -193,7 +223,6 @@ class BrowserSessionManager:
|
||||
f"Error closing async browser for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Stop browser client
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -202,8 +231,6 @@ class BrowserSessionManager:
|
||||
f"Error stopping browser client for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Remove session from dictionary
|
||||
del self._async_sessions[thread_id]
|
||||
logger.info(f"Async browser session cleaned up for thread {thread_id}")
|
||||
|
||||
def close_sync_browser(self, thread_id: str) -> None:
|
||||
@@ -212,13 +239,13 @@ class BrowserSessionManager:
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
"""
|
||||
if thread_id not in self._sync_sessions:
|
||||
logger.warning(f"No sync browser session found for thread {thread_id}")
|
||||
return
|
||||
with self._lock:
|
||||
if thread_id not in self._sync_sessions:
|
||||
logger.warning(f"No sync browser session found for thread {thread_id}")
|
||||
return
|
||||
|
||||
browser_client, browser = self._sync_sessions[thread_id]
|
||||
browser_client, browser = self._sync_sessions.pop(thread_id)
|
||||
|
||||
# Close browser
|
||||
if browser:
|
||||
try:
|
||||
browser.close()
|
||||
@@ -227,7 +254,6 @@ class BrowserSessionManager:
|
||||
f"Error closing sync browser for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Stop browser client
|
||||
if browser_client:
|
||||
try:
|
||||
browser_client.stop()
|
||||
@@ -236,19 +262,17 @@ class BrowserSessionManager:
|
||||
f"Error stopping browser client for thread {thread_id}: {e}"
|
||||
)
|
||||
|
||||
# Remove session from dictionary
|
||||
del self._sync_sessions[thread_id]
|
||||
logger.info(f"Sync browser session cleaned up for thread {thread_id}")
|
||||
|
||||
async def close_all_browsers(self) -> None:
|
||||
"""Close all browser sessions."""
|
||||
# Close all async browsers
|
||||
async_thread_ids = list(self._async_sessions.keys())
|
||||
with self._lock:
|
||||
async_thread_ids = list(self._async_sessions.keys())
|
||||
sync_thread_ids = list(self._sync_sessions.keys())
|
||||
|
||||
for thread_id in async_thread_ids:
|
||||
await self.close_async_browser(thread_id)
|
||||
|
||||
# Close all sync browsers
|
||||
sync_thread_ids = list(self._sync_sessions.keys())
|
||||
for thread_id in sync_thread_ids:
|
||||
self.close_sync_browser(thread_id)
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import chromadb
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader
|
||||
@@ -38,22 +40,32 @@ class RAG(Adapter):
|
||||
_client: Any = PrivateAttr()
|
||||
_collection: Any = PrivateAttr()
|
||||
_embedding_service: EmbeddingService = PrivateAttr()
|
||||
_lock_name: str = PrivateAttr(default="")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
try:
|
||||
if self.persist_directory:
|
||||
self._client = chromadb.PersistentClient(path=self.persist_directory)
|
||||
else:
|
||||
self._client = chromadb.Client()
|
||||
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={
|
||||
"hnsw:space": "cosine",
|
||||
"description": "CrewAI Knowledge Base",
|
||||
},
|
||||
self._lock_name = (
|
||||
f"chromadb:{os.path.realpath(self.persist_directory)}"
|
||||
if self.persist_directory
|
||||
else "chromadb:ephemeral"
|
||||
)
|
||||
|
||||
with store_lock(self._lock_name):
|
||||
if self.persist_directory:
|
||||
self._client = chromadb.PersistentClient(
|
||||
path=self.persist_directory
|
||||
)
|
||||
else:
|
||||
self._client = chromadb.Client()
|
||||
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={
|
||||
"hnsw:space": "cosine",
|
||||
"description": "CrewAI Knowledge Base",
|
||||
},
|
||||
)
|
||||
|
||||
self._embedding_service = EmbeddingService(
|
||||
provider=self.embedding_provider,
|
||||
model=self.embedding_model,
|
||||
@@ -87,29 +99,8 @@ class RAG(Adapter):
|
||||
loader_result = loader.load(source_content)
|
||||
doc_id = loader_result.doc_id
|
||||
|
||||
existing_doc = self._collection.get(
|
||||
where={"source": source_content.source_ref}, limit=1
|
||||
)
|
||||
existing_doc_id = (
|
||||
existing_doc and existing_doc["metadatas"][0]["doc_id"]
|
||||
if existing_doc["metadatas"]
|
||||
else None
|
||||
)
|
||||
|
||||
if existing_doc_id == doc_id:
|
||||
logger.warning(
|
||||
f"Document with source {loader_result.source} already exists"
|
||||
)
|
||||
return
|
||||
|
||||
# Document with same source ref does exists but the content has changed, deleting the oldest reference
|
||||
if existing_doc_id and existing_doc_id != loader_result.doc_id:
|
||||
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
|
||||
self._collection.delete(where={"doc_id": existing_doc_id})
|
||||
|
||||
documents = []
|
||||
|
||||
chunks = chunker.chunk(loader_result.content)
|
||||
documents = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
doc_metadata = (metadata or {}).copy()
|
||||
doc_metadata["chunk_index"] = i
|
||||
@@ -136,7 +127,6 @@ class RAG(Adapter):
|
||||
|
||||
ids = [doc.id for doc in documents]
|
||||
metadatas = []
|
||||
|
||||
for doc in documents:
|
||||
doc_metadata = doc.metadata.copy()
|
||||
doc_metadata.update(
|
||||
@@ -148,16 +138,36 @@ class RAG(Adapter):
|
||||
)
|
||||
metadatas.append(doc_metadata)
|
||||
|
||||
try:
|
||||
self._collection.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=contents,
|
||||
metadatas=metadatas,
|
||||
with store_lock(self._lock_name):
|
||||
existing_doc = self._collection.get(
|
||||
where={"source": source_content.source_ref}, limit=1
|
||||
)
|
||||
logger.info(f"Added {len(documents)} documents to knowledge base")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||
existing_doc_id = (
|
||||
existing_doc and existing_doc["metadatas"][0]["doc_id"]
|
||||
if existing_doc["metadatas"]
|
||||
else None
|
||||
)
|
||||
|
||||
if existing_doc_id == doc_id:
|
||||
logger.warning(
|
||||
f"Document with source {loader_result.source} already exists"
|
||||
)
|
||||
return
|
||||
|
||||
if existing_doc_id and existing_doc_id != loader_result.doc_id:
|
||||
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
|
||||
self._collection.delete(where={"doc_id": existing_doc_id})
|
||||
|
||||
try:
|
||||
self._collection.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=contents,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
logger.info(f"Added {len(documents)} documents to knowledge base")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||
|
||||
def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
|
||||
try:
|
||||
@@ -201,7 +211,8 @@ class RAG(Adapter):
|
||||
|
||||
def delete_collection(self) -> None:
|
||||
try:
|
||||
self._client.delete_collection(self.collection_name)
|
||||
with store_lock(self._lock_name):
|
||||
self._client.delete_collection(self.collection_name)
|
||||
logger.info(f"Deleted collection: {self.collection_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
||||
@@ -1,7 +1,18 @@
|
||||
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
|
||||
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
|
||||
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
|
||||
BraveLLMContextTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
BraveLocalPOIsTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
|
||||
from crewai_tools.tools.brightdata_tool import (
|
||||
BrightDataDatasetTool,
|
||||
BrightDataSearchTool,
|
||||
@@ -185,7 +196,14 @@ __all__ = [
|
||||
"AIMindTool",
|
||||
"ApifyActorsTool",
|
||||
"ArxivPaperTool",
|
||||
"BraveImageSearchTool",
|
||||
"BraveLLMContextTool",
|
||||
"BraveLocalPOIsDescriptionTool",
|
||||
"BraveLocalPOIsTool",
|
||||
"BraveNewsSearchTool",
|
||||
"BraveSearchTool",
|
||||
"BraveVideoSearchTool",
|
||||
"BraveWebSearchTool",
|
||||
"BrightDataDatasetTool",
|
||||
"BrightDataSearchTool",
|
||||
"BrightDataWebUnlockerTool",
|
||||
|
||||
@@ -0,0 +1,322 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
import requests
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Brave API error codes that indicate non-retryable quota/usage exhaustion.
|
||||
_QUOTA_CODES = frozenset({"QUOTA_LIMITED", "USAGE_LIMIT_EXCEEDED"})
|
||||
|
||||
|
||||
def _save_results_to_file(content: str) -> None:
|
||||
"""Saves the search results to a file."""
|
||||
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
|
||||
with open(filename, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
def _parse_error_body(resp: requests.Response) -> dict[str, Any] | None:
|
||||
"""Extract the structured "error" object from a Brave API error response."""
|
||||
try:
|
||||
body = resp.json()
|
||||
error = body.get("error")
|
||||
return error if isinstance(error, dict) else None
|
||||
except (ValueError, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
def _raise_for_error(resp: requests.Response) -> None:
|
||||
"""Brave Search API error responses contain helpful JSON payloads"""
|
||||
status = resp.status_code
|
||||
try:
|
||||
body = json.dumps(resp.json())
|
||||
except (ValueError, KeyError):
|
||||
body = resp.text[:500]
|
||||
|
||||
raise RuntimeError(f"Brave Search API error (HTTP {status}): {body}")
|
||||
|
||||
|
||||
def _is_retryable(resp: requests.Response) -> bool:
|
||||
"""Return True for transient failures that are worth retrying.
|
||||
|
||||
* 429 + RATE_LIMITED — the per-second sliding window is full.
|
||||
* 5xx — transient server-side errors.
|
||||
|
||||
Quota exhaustion (QUOTA_LIMITED, USAGE_LIMIT_EXCEEDED) is
|
||||
explicitly excluded: retrying will never succeed until the billing
|
||||
period resets.
|
||||
"""
|
||||
if resp.status_code == 429:
|
||||
error = _parse_error_body(resp) or {}
|
||||
return error.get("code") not in _QUOTA_CODES
|
||||
return 500 <= resp.status_code < 600
|
||||
|
||||
|
||||
def _retry_delay(resp: requests.Response, attempt: int) -> float:
|
||||
"""Compute wait time before the next retry attempt.
|
||||
|
||||
Prefers the server-supplied Retry-After header when available;
|
||||
falls back to exponential backoff (1s, 2s, 4s, ...).
|
||||
"""
|
||||
retry_after = resp.headers.get("Retry-After")
|
||||
if retry_after is not None:
|
||||
try:
|
||||
return max(0.0, float(retry_after))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return float(2**attempt)
|
||||
|
||||
|
||||
class BraveSearchToolBase(BaseTool, ABC):
|
||||
"""
|
||||
Base class for Brave Search API interactions.
|
||||
|
||||
Individual tool subclasses must provide the following:
|
||||
- search_url
|
||||
- header_schema (pydantic model)
|
||||
- args_schema (pydantic model)
|
||||
- _refine_payload() -> dict[str, Any]
|
||||
"""
|
||||
|
||||
search_url: str
|
||||
raw: bool = False
|
||||
args_schema: type[BaseModel]
|
||||
header_schema: type[BaseModel]
|
||||
|
||||
# Tool options (legacy parameters)
|
||||
country: str | None = None
|
||||
save_file: bool = False
|
||||
n_results: int = 10
|
||||
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
name="BRAVE_API_KEY",
|
||||
description="API key for Brave Search",
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
headers: dict[str, Any] | None = None,
|
||||
requests_per_second: float = 1.0,
|
||||
save_file: bool = False,
|
||||
raw: bool = False,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key or os.environ.get("BRAVE_API_KEY")
|
||||
if not self._api_key:
|
||||
raise ValueError("BRAVE_API_KEY environment variable is required")
|
||||
|
||||
self.raw = bool(raw)
|
||||
self._timeout = int(timeout)
|
||||
self.save_file = bool(save_file)
|
||||
self._requests_per_second = float(requests_per_second)
|
||||
self._headers = self._build_and_validate_headers(headers or {})
|
||||
# Per-instance rate limiting: each instance has its own clock and lock.
|
||||
# Total process rate is the sum of limits of instances you create.
|
||||
self._last_request_time: float = 0
|
||||
self._rate_limit_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def api_key(self) -> str:
|
||||
return self._api_key
|
||||
|
||||
@property
|
||||
def headers(self) -> dict[str, Any]:
|
||||
return self._headers
|
||||
|
||||
def set_headers(self, headers: dict[str, Any]) -> BraveSearchToolBase:
|
||||
merged = {**self._headers, **{k.lower(): v for k, v in headers.items()}}
|
||||
self._headers = self._build_and_validate_headers(merged)
|
||||
return self
|
||||
|
||||
def _build_and_validate_headers(self, headers: dict[str, Any]) -> dict[str, Any]:
|
||||
normalized = {k.lower(): v for k, v in headers.items()}
|
||||
normalized.setdefault("x-subscription-token", self._api_key)
|
||||
normalized.setdefault("accept", "application/json")
|
||||
|
||||
try:
|
||||
self.header_schema(**normalized)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid headers: {e}") from e
|
||||
|
||||
return normalized
|
||||
|
||||
def _rate_limit(self) -> None:
|
||||
"""Enforce minimum interval between requests for this instance. Thread-safe."""
|
||||
if self._requests_per_second <= 0:
|
||||
return
|
||||
|
||||
min_interval = 1.0 / self._requests_per_second
|
||||
with self._rate_limit_lock:
|
||||
now = time.time()
|
||||
next_allowed = self._last_request_time + min_interval
|
||||
if now < next_allowed:
|
||||
time.sleep(next_allowed - now)
|
||||
now = time.time()
|
||||
self._last_request_time = now
|
||||
|
||||
def _make_request(
|
||||
self, params: dict[str, Any], *, _max_retries: int = 3
|
||||
) -> dict[str, Any]:
|
||||
"""Execute an HTTP GET against the Brave Search API with retry logic."""
|
||||
last_resp: requests.Response | None = None
|
||||
|
||||
# Retry the request up to _max_retries times
|
||||
for attempt in range(_max_retries):
|
||||
self._rate_limit()
|
||||
|
||||
# Make the request
|
||||
try:
|
||||
resp = requests.get(
|
||||
self.search_url,
|
||||
headers=self._headers,
|
||||
params=params,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except requests.ConnectionError as exc:
|
||||
raise RuntimeError(
|
||||
f"Brave Search API connection failed: {exc}"
|
||||
) from exc
|
||||
except requests.Timeout as exc:
|
||||
raise RuntimeError(
|
||||
f"Brave Search API request timed out after {self._timeout}s: {exc}"
|
||||
) from exc
|
||||
|
||||
# Log the rate limit headers and request details
|
||||
logger.debug(
|
||||
"Brave Search API request: %s %s -> %d",
|
||||
"GET",
|
||||
resp.url,
|
||||
resp.status_code,
|
||||
)
|
||||
|
||||
# Response was OK, return the JSON body
|
||||
if resp.ok:
|
||||
try:
|
||||
return resp.json()
|
||||
except ValueError as exc:
|
||||
raise RuntimeError(
|
||||
f"Brave Search API returned invalid JSON (HTTP {resp.status_code}): {exc}"
|
||||
) from exc
|
||||
|
||||
# Response was not OK, but is retryable
|
||||
# (e.g., 429 Too Many Requests, 500 Internal Server Error)
|
||||
if _is_retryable(resp) and attempt < _max_retries - 1:
|
||||
delay = _retry_delay(resp, attempt)
|
||||
logger.warning(
|
||||
"Brave Search API returned %d. Retrying in %.1fs (attempt %d/%d)",
|
||||
resp.status_code,
|
||||
delay,
|
||||
attempt + 1,
|
||||
_max_retries,
|
||||
)
|
||||
time.sleep(delay)
|
||||
last_resp = resp
|
||||
continue
|
||||
|
||||
# Response was not OK, nor was it retryable
|
||||
# (e.g., 422 Unprocessable Entity, 400 Bad Request (OPTION_NOT_IN_PLAN))
|
||||
_raise_for_error(resp)
|
||||
|
||||
# All retries exhausted
|
||||
_raise_for_error(last_resp or resp) # type: ignore[possibly-undefined]
|
||||
return {} # unreachable (here to satisfy the type checker and linter)
|
||||
|
||||
def _run(self, q: str | None = None, **params: Any) -> Any:
|
||||
# Allow positional usage: tool.run("latest Brave browser features")
|
||||
if q is not None:
|
||||
params["q"] = q
|
||||
|
||||
params = self._common_payload_refinement(params)
|
||||
|
||||
# Validate only schema fields
|
||||
schema_keys = self.args_schema.model_fields
|
||||
payload_in = {k: v for k, v in params.items() if k in schema_keys}
|
||||
|
||||
try:
|
||||
validated = self.args_schema(**payload_in)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid parameters: {e}") from e
|
||||
|
||||
# The subclass may have additional refinements to apply to the payload, such as goggles or other parameters
|
||||
payload = self._refine_request_payload(validated.model_dump(exclude_none=True))
|
||||
response = self._make_request(payload)
|
||||
|
||||
if not self.raw:
|
||||
response = self._refine_response(response)
|
||||
|
||||
if self.save_file:
|
||||
_save_results_to_file(json.dumps(response, indent=2))
|
||||
|
||||
return response
|
||||
|
||||
@abstractmethod
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Subclass must implement: transform validated params dict into API request params."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _refine_response(self, response: dict[str, Any]) -> Any:
|
||||
"""Subclass must implement: transform response dict into a more useful format."""
|
||||
raise NotImplementedError
|
||||
|
||||
_EMPTY_VALUES: ClassVar[tuple[None, str, str, list[Any]]] = (None, "", "null", [])
|
||||
|
||||
def _common_payload_refinement(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Common payload refinement for all tools."""
|
||||
# crewAI's schema pipeline (ensure_all_properties_required in
|
||||
# pydantic_schema_utils.py) marks every property as required so
|
||||
# that OpenAI strict-mode structured outputs work correctly.
|
||||
# The side-effect is that the LLM fills in *every* parameter —
|
||||
# even truly optional ones — using placeholder values such as
|
||||
# None, "", "null", or []. Only optional fields are affected,
|
||||
# so we limit the check to those.
|
||||
fields = self.args_schema.model_fields
|
||||
params = {
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
# Permit custom and required fields, and fields with non-empty values
|
||||
if k not in fields or fields[k].is_required() or v not in self._EMPTY_VALUES
|
||||
}
|
||||
|
||||
# Make sure params has "q" for query instead of "query" or "search_query"
|
||||
query = params.get("query") or params.get("search_query")
|
||||
if query is not None and "q" not in params:
|
||||
params["q"] = query
|
||||
params.pop("query", None)
|
||||
params.pop("search_query", None)
|
||||
|
||||
# If "count" was not explicitly provided, use n_results
|
||||
# (only when the schema actually supports a "count" field)
|
||||
if "count" in self.args_schema.model_fields:
|
||||
if "count" not in params and self.n_results is not None:
|
||||
params["count"] = self.n_results
|
||||
|
||||
# If "country" was not explicitly provided, but self.country is set, use it
|
||||
# (only when the schema actually supports a "country" field)
|
||||
if "country" in self.args_schema.model_fields:
|
||||
if "country" not in params and self.country is not None:
|
||||
params["country"] = self.country
|
||||
|
||||
return params
|
||||
@@ -0,0 +1,42 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
ImageSearchHeaders,
|
||||
ImageSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveImageSearchTool(BraveSearchToolBase):
|
||||
"""A tool that performs image searches using the Brave Search API."""
|
||||
|
||||
name: str = "Brave Image Search"
|
||||
args_schema: type[BaseModel] = ImageSearchParams
|
||||
header_schema: type[BaseModel] = ImageSearchHeaders
|
||||
|
||||
description: str = (
|
||||
"A tool that performs image searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
|
||||
search_url: str = "https://api.search.brave.com/res/v1/images/search"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
# Make the response more concise, and easier to consume
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
"title": result.get("title"),
|
||||
"url": result.get("properties", {}).get("url"),
|
||||
"dimensions": f"{w}x{h}"
|
||||
if (w := result.get("properties", {}).get("width"))
|
||||
and (h := result.get("properties", {}).get("height"))
|
||||
else None,
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.response_types import LLMContext
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
LLMContextHeaders,
|
||||
LLMContextParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveLLMContextTool(BraveSearchToolBase):
|
||||
"""A tool that retrieves context for LLM usage from the Brave Search API."""
|
||||
|
||||
name: str = "Brave LLM Context"
|
||||
args_schema: type[BaseModel] = LLMContextParams
|
||||
header_schema: type[BaseModel] = LLMContextHeaders
|
||||
|
||||
description: str = (
|
||||
"A tool that retrieves context for LLM usage from the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
|
||||
search_url: str = "https://api.search.brave.com/res/v1/llm/context"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: LLMContext.Response) -> LLMContext.Response:
|
||||
"""The LLM Context response schema is fairly simple. Return as is."""
|
||||
return response
|
||||
@@ -0,0 +1,109 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.response_types import LocalPOIs
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
LocalPOIsDescriptionHeaders,
|
||||
LocalPOIsDescriptionParams,
|
||||
LocalPOIsHeaders,
|
||||
LocalPOIsParams,
|
||||
)
|
||||
|
||||
|
||||
DayOpeningHours = LocalPOIs.DayOpeningHours
|
||||
OpeningHours = LocalPOIs.OpeningHours
|
||||
LocationResult = LocalPOIs.LocationResult
|
||||
LocalPOIsResponse = LocalPOIs.Response
|
||||
|
||||
|
||||
def _flatten_slots(slots: list[DayOpeningHours]) -> list[dict[str, str]]:
|
||||
"""Convert a list of DayOpeningHours dicts into simplified entries."""
|
||||
return [
|
||||
{
|
||||
"day": slot["full_name"].lower(),
|
||||
"opens": slot["opens"],
|
||||
"closes": slot["closes"],
|
||||
}
|
||||
for slot in slots
|
||||
]
|
||||
|
||||
|
||||
def _simplify_opening_hours(result: LocationResult) -> list[dict[str, str]] | None:
|
||||
"""Collapse opening_hours into a flat list of {day, opens, closes} dicts."""
|
||||
hours = result.get("opening_hours")
|
||||
if not hours:
|
||||
return None
|
||||
|
||||
entries: list[dict[str, str]] = []
|
||||
|
||||
current = hours.get("current_day")
|
||||
if current:
|
||||
entries.extend(_flatten_slots(current))
|
||||
|
||||
days = hours.get("days")
|
||||
if days:
|
||||
for day_slots in days:
|
||||
entries.extend(_flatten_slots(day_slots))
|
||||
|
||||
return entries or None
|
||||
|
||||
|
||||
class BraveLocalPOIsTool(BraveSearchToolBase):
|
||||
"""A tool that retrieves local POIs using the Brave Search API."""
|
||||
|
||||
name: str = "Brave Local POIs"
|
||||
args_schema: type[BaseModel] = LocalPOIsParams
|
||||
header_schema: type[BaseModel] = LocalPOIsHeaders
|
||||
description: str = (
|
||||
"A tool that retrieves local POIs using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
search_url: str = "https://api.search.brave.com/res/v1/local/pois"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
"title": result.get("title"),
|
||||
"url": result.get("url"),
|
||||
"description": result.get("description"),
|
||||
"address": result.get("postal_address", {}).get("displayAddress"),
|
||||
"contact": result.get("contact", {}).get("telephone")
|
||||
or result.get("contact", {}).get("email")
|
||||
or None,
|
||||
"opening_hours": _simplify_opening_hours(result),
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
|
||||
|
||||
class BraveLocalPOIsDescriptionTool(BraveSearchToolBase):
|
||||
"""A tool that retrieves AI-generated descriptions for local POIs using the Brave Search API."""
|
||||
|
||||
name: str = "Brave Local POI Descriptions"
|
||||
args_schema: type[BaseModel] = LocalPOIsDescriptionParams
|
||||
header_schema: type[BaseModel] = LocalPOIsDescriptionHeaders
|
||||
description: str = (
|
||||
"A tool that retrieves AI-generated descriptions for local POIs using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
search_url: str = "https://api.search.brave.com/res/v1/local/descriptions"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
|
||||
# Make the response more concise, and easier to consume
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
"id": result.get("id"),
|
||||
"description": result.get("description"),
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
@@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
NewsSearchHeaders,
|
||||
NewsSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveNewsSearchTool(BraveSearchToolBase):
|
||||
"""A tool that performs news searches using the Brave Search API."""
|
||||
|
||||
name: str = "Brave News Search"
|
||||
args_schema: type[BaseModel] = NewsSearchParams
|
||||
header_schema: type[BaseModel] = NewsSearchHeaders
|
||||
|
||||
description: str = (
|
||||
"A tool that performs news searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
|
||||
search_url: str = "https://api.search.brave.com/res/v1/news/search"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
# Make the response more concise, and easier to consume
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
"url": result.get("url"),
|
||||
"title": result.get("title"),
|
||||
"description": result.get("description"),
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
@@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
@@ -10,17 +9,13 @@ from pydantic import BaseModel, Field
|
||||
from pydantic.types import StringConstraints
|
||||
import requests
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
|
||||
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _save_results_to_file(content: str) -> None:
|
||||
"""Saves the search results to a file."""
|
||||
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
|
||||
with open(filename, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
FreshnessPreset = Literal["pd", "pw", "pm", "py"]
|
||||
FreshnessRange = Annotated[
|
||||
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
|
||||
@@ -29,51 +24,6 @@ Freshness = FreshnessPreset | FreshnessRange
|
||||
SafeSearch = Literal["off", "moderate", "strict"]
|
||||
|
||||
|
||||
class BraveSearchToolSchema(BaseModel):
|
||||
"""Input for BraveSearchTool"""
|
||||
|
||||
query: str = Field(..., description="Search query to perform")
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
)
|
||||
search_language: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return. Actual number may be less.",
|
||||
)
|
||||
offset: int | None = Field(
|
||||
default=None, description="Skip the first N result sets/pages. Max is 9."
|
||||
)
|
||||
safesearch: SafeSearch | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Options: off/moderate/strict",
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
freshness: Freshness | None = Field(
|
||||
default=None,
|
||||
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
|
||||
)
|
||||
text_decorations: bool | None = Field(
|
||||
default=None,
|
||||
description="Include markup to highlight search terms in the results.",
|
||||
)
|
||||
extra_snippets: bool | None = Field(
|
||||
default=None,
|
||||
description="Include up to 5 text snippets for each page if possible.",
|
||||
)
|
||||
operators: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to apply search operators (e.g., site:example.com).",
|
||||
)
|
||||
|
||||
|
||||
# TODO: Extend support to additional endpoints (e.g., /images, /news, etc.)
|
||||
class BraveSearchTool(BaseTool):
|
||||
"""A tool that performs web searches using the Brave Search API."""
|
||||
@@ -83,7 +33,7 @@ class BraveSearchTool(BaseTool):
|
||||
"A tool that performs web searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
args_schema: type[BaseModel] = BraveSearchToolSchema
|
||||
args_schema: type[BaseModel] = WebSearchParams
|
||||
search_url: str = "https://api.search.brave.com/res/v1/web/search"
|
||||
n_results: int = 10
|
||||
save_file: bool = False
|
||||
@@ -120,8 +70,8 @@ class BraveSearchTool(BaseTool):
|
||||
|
||||
# Construct and send the request
|
||||
try:
|
||||
# Maintain both "search_query" and "query" for backwards compatibility
|
||||
query = kwargs.get("search_query") or kwargs.get("query")
|
||||
# Fallback to "query" or "search_query" for backwards compatibility
|
||||
query = kwargs.get("q") or kwargs.get("query") or kwargs.get("search_query")
|
||||
if not query:
|
||||
raise ValueError("Query is required")
|
||||
|
||||
@@ -130,8 +80,11 @@ class BraveSearchTool(BaseTool):
|
||||
if country := kwargs.get("country"):
|
||||
payload["country"] = country
|
||||
|
||||
if search_language := kwargs.get("search_language"):
|
||||
payload["search_language"] = search_language
|
||||
# Fallback to "search_language" for backwards compatibility
|
||||
if search_lang := kwargs.get("search_lang") or kwargs.get(
|
||||
"search_language"
|
||||
):
|
||||
payload["search_lang"] = search_lang
|
||||
|
||||
# Fallback to deprecated n_results parameter if no count is provided
|
||||
count = kwargs.get("count")
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
VideoSearchHeaders,
|
||||
VideoSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveVideoSearchTool(BraveSearchToolBase):
|
||||
"""A tool that performs video searches using the Brave Search API."""
|
||||
|
||||
name: str = "Brave Video Search"
|
||||
args_schema: type[BaseModel] = VideoSearchParams
|
||||
header_schema: type[BaseModel] = VideoSearchHeaders
|
||||
|
||||
description: str = (
|
||||
"A tool that performs video searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
|
||||
search_url: str = "https://api.search.brave.com/res/v1/videos/search"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
# Make the response more concise, and easier to consume
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
"url": result.get("url"),
|
||||
"title": result.get("title"),
|
||||
"description": result.get("description"),
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
WebSearchHeaders,
|
||||
WebSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveWebSearchTool(BraveSearchToolBase):
|
||||
"""A tool that performs web searches using the Brave Search API."""
|
||||
|
||||
name: str = "Brave Web Search"
|
||||
args_schema: type[BaseModel] = WebSearchParams
|
||||
header_schema: type[BaseModel] = WebSearchHeaders
|
||||
|
||||
description: str = (
|
||||
"A tool that performs web searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
|
||||
search_url: str = "https://api.search.brave.com/res/v1/web/search"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
results = response.get("web", {}).get("results", [])
|
||||
refined = []
|
||||
for result in results:
|
||||
snippets = result.get("extra_snippets") or []
|
||||
if not snippets:
|
||||
desc = result.get("description")
|
||||
if desc:
|
||||
snippets = [desc]
|
||||
refined.append(
|
||||
{
|
||||
"url": result.get("url"),
|
||||
"title": result.get("title"),
|
||||
"snippets": snippets,
|
||||
}
|
||||
)
|
||||
return refined
|
||||
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
|
||||
class LocalPOIs:
|
||||
class PostalAddress(TypedDict, total=False):
|
||||
type: Literal["PostalAddress"]
|
||||
country: str
|
||||
postalCode: str
|
||||
streetAddress: str
|
||||
addressRegion: str
|
||||
addressLocality: str
|
||||
displayAddress: str
|
||||
|
||||
class DayOpeningHours(TypedDict):
|
||||
abbr_name: str
|
||||
full_name: str
|
||||
opens: str
|
||||
closes: str
|
||||
|
||||
class OpeningHours(TypedDict, total=False):
|
||||
current_day: list[LocalPOIs.DayOpeningHours]
|
||||
days: list[list[LocalPOIs.DayOpeningHours]]
|
||||
|
||||
class LocationResult(TypedDict, total=False):
|
||||
provider_url: str
|
||||
title: str
|
||||
url: str
|
||||
id: str | None
|
||||
opening_hours: LocalPOIs.OpeningHours | None
|
||||
postal_address: LocalPOIs.PostalAddress | None
|
||||
|
||||
class Response(TypedDict, total=False):
|
||||
type: Literal["local_pois"]
|
||||
results: list[LocalPOIs.LocationResult]
|
||||
|
||||
|
||||
class LLMContext:
|
||||
class LLMContextItem(TypedDict, total=False):
|
||||
snippets: list[str]
|
||||
title: str
|
||||
url: str
|
||||
|
||||
class LLMContextMapItem(TypedDict, total=False):
|
||||
name: str
|
||||
snippets: list[str]
|
||||
title: str
|
||||
url: str
|
||||
|
||||
class LLMContextPOIItem(TypedDict, total=False):
|
||||
name: str
|
||||
snippets: list[str]
|
||||
title: str
|
||||
url: str
|
||||
|
||||
class Grounding(TypedDict, total=False):
|
||||
generic: list[LLMContext.LLMContextItem]
|
||||
poi: LLMContext.LLMContextPOIItem
|
||||
map: list[LLMContext.LLMContextMapItem]
|
||||
|
||||
class Sources(TypedDict, total=False):
|
||||
pass
|
||||
|
||||
class Response(TypedDict, total=False):
|
||||
grounding: LLMContext.Grounding
|
||||
sources: LLMContext.Sources
|
||||
@@ -0,0 +1,525 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.types import StringConstraints
|
||||
|
||||
|
||||
# Common types
|
||||
Units = Literal["metric", "imperial"]
|
||||
SafeSearch = Literal["off", "moderate", "strict"]
|
||||
Freshness = (
|
||||
Literal["pd", "pw", "pm", "py"]
|
||||
| Annotated[
|
||||
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
|
||||
]
|
||||
)
|
||||
ResultFilter = list[
|
||||
Literal[
|
||||
"discussions",
|
||||
"faq",
|
||||
"infobox",
|
||||
"news",
|
||||
"query",
|
||||
"summarizer",
|
||||
"videos",
|
||||
"web",
|
||||
"locations",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class LLMContextParams(BaseModel):
|
||||
"""Parameters for Brave LLM Context endpoint."""
|
||||
|
||||
q: str = Field(
|
||||
description="Search query to perform",
|
||||
min_length=1,
|
||||
max_length=400,
|
||||
)
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
pattern=r"^[A-Z]{2}$",
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return. Actual number may be less.",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
maximum_number_of_urls: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of URLs to include in the context.",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
maximum_number_of_tokens: int | None = Field(
|
||||
default=None,
|
||||
description="The approximate maximum number of tokens to include in the context.",
|
||||
ge=1,
|
||||
le=32768,
|
||||
)
|
||||
maximum_number_of_snippets: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of different snippets to include in the context.",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
context_threshold_mode: (
|
||||
Literal["disabled", "strict", "lenient", "balanced"] | None
|
||||
) = Field(
|
||||
default=None,
|
||||
description="The mode to use for the context thresholding.",
|
||||
)
|
||||
maximum_number_of_tokens_per_url: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of tokens to include for each URL in the context.",
|
||||
ge=1,
|
||||
le=8192,
|
||||
)
|
||||
maximum_number_of_snippets_per_url: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of snippets to include per URL.",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
goggles: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
|
||||
)
|
||||
enable_local: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to enable local recall. Not setting this value means auto-detect and uses local recall if any of the localization headers are provided.",
|
||||
)
|
||||
|
||||
|
||||
class WebSearchParams(BaseModel):
|
||||
"""Parameters for Brave Web Search endpoint."""
|
||||
|
||||
q: str = Field(
|
||||
description="Search query to perform",
|
||||
min_length=1,
|
||||
max_length=400,
|
||||
)
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
pattern=r"^[A-Z]{2}$",
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
ui_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
|
||||
pattern=r"^[a-z]{2}-[A-Z]{2}$",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return. Actual number may be less.",
|
||||
ge=1,
|
||||
le=20,
|
||||
)
|
||||
offset: int | None = Field(
|
||||
default=None,
|
||||
description="Skip the first N result sets/pages. Max is 9.",
|
||||
ge=0,
|
||||
le=9,
|
||||
)
|
||||
safesearch: Literal["off", "moderate", "strict"] | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Options: off/moderate/strict",
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
freshness: Freshness | None = Field(
|
||||
default=None,
|
||||
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
|
||||
)
|
||||
text_decorations: bool | None = Field(
|
||||
default=None,
|
||||
description="Include markup to highlight search terms in the results.",
|
||||
)
|
||||
extra_snippets: bool | None = Field(
|
||||
default=None,
|
||||
description="Include up to 5 text snippets for each page if possible.",
|
||||
)
|
||||
result_filter: ResultFilter | None = Field(
|
||||
default=None,
|
||||
description="Filter the results by type. Options: discussions/faq/infobox/news/query/summarizer/videos/web/locations. Note: The `count` parameter is applied only to the `web` results.",
|
||||
)
|
||||
units: Units | None = Field(
|
||||
default=None,
|
||||
description="The units to use for the results. Options: metric/imperial",
|
||||
)
|
||||
goggles: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
|
||||
)
|
||||
summary: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to generate a summarizer ID for the results.",
|
||||
)
|
||||
enable_rich_callback: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to enable rich callbacks for the results. Requires Pro level subscription.",
|
||||
)
|
||||
include_fetch_metadata: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to include fetch metadata (e.g., last fetch time) in the results.",
|
||||
)
|
||||
operators: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to apply search operators (e.g., site:example.com).",
|
||||
)
|
||||
|
||||
|
||||
class LocalPOIsParams(BaseModel):
|
||||
"""Parameters for Brave Local POIs endpoint."""
|
||||
|
||||
ids: list[str] = Field(
|
||||
description="List of POI IDs to retrieve. Maximum of 20. IDs are valid for 8 hours.",
|
||||
min_length=1,
|
||||
max_length=20,
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
ui_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
|
||||
pattern=r"^[a-z]{2}-[A-Z]{2}$",
|
||||
)
|
||||
units: Units | None = Field(
|
||||
default=None,
|
||||
description="The units to use for the results. Options: metric/imperial",
|
||||
)
|
||||
|
||||
|
||||
class LocalPOIsDescriptionParams(BaseModel):
|
||||
"""Parameters for Brave Local POI Descriptions endpoint."""
|
||||
|
||||
ids: list[str] = Field(
|
||||
description="List of POI IDs to retrieve. Maximum of 20. IDs are valid for 8 hours.",
|
||||
min_length=1,
|
||||
max_length=20,
|
||||
)
|
||||
|
||||
|
||||
class ImageSearchParams(BaseModel):
|
||||
"""Parameters for Brave Image Search endpoint."""
|
||||
|
||||
q: str = Field(
|
||||
description="Search query to perform",
|
||||
min_length=1,
|
||||
max_length=400,
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
pattern=r"^[A-Z]{2}$",
|
||||
)
|
||||
safesearch: Literal["off", "strict"] | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Default is strict.",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return.",
|
||||
ge=1,
|
||||
le=200,
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
|
||||
|
||||
class VideoSearchParams(BaseModel):
|
||||
"""Parameters for Brave Video Search endpoint."""
|
||||
|
||||
q: str = Field(
|
||||
description="Search query to perform",
|
||||
min_length=1,
|
||||
max_length=400,
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
ui_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
|
||||
pattern=r"^[a-z]{2}-[A-Z]{2}$",
|
||||
)
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
pattern=r"^[A-Z]{2}$",
|
||||
)
|
||||
safesearch: SafeSearch | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Options: off/moderate/strict",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return.",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
offset: int | None = Field(
|
||||
default=None,
|
||||
description="Skip the first N result sets/pages. Max is 9.",
|
||||
ge=0,
|
||||
le=9,
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
freshness: Freshness | None = Field(
|
||||
default=None,
|
||||
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
|
||||
)
|
||||
include_fetch_metadata: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to include fetch metadata (e.g., last fetch time) in the results.",
|
||||
)
|
||||
operators: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to apply search operators (e.g., site:example.com).",
|
||||
)
|
||||
|
||||
|
||||
class NewsSearchParams(BaseModel):
|
||||
"""Parameters for Brave News Search endpoint."""
|
||||
|
||||
q: str = Field(
|
||||
description="Search query to perform",
|
||||
min_length=1,
|
||||
max_length=400,
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
ui_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
|
||||
pattern=r"^[a-z]{2}-[A-Z]{2}$",
|
||||
)
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
pattern=r"^[A-Z]{2}$",
|
||||
)
|
||||
safesearch: Literal["off", "moderate", "strict"] | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Options: off/moderate/strict",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return.",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
offset: int | None = Field(
|
||||
default=None,
|
||||
description="Skip the first N result sets/pages. Max is 9.",
|
||||
ge=0,
|
||||
le=9,
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
freshness: Freshness | None = Field(
|
||||
default=None,
|
||||
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
|
||||
)
|
||||
extra_snippets: bool | None = Field(
|
||||
default=None,
|
||||
description="Include up to 5 text snippets for each page if possible.",
|
||||
)
|
||||
goggles: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
|
||||
)
|
||||
include_fetch_metadata: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to include fetch metadata in the results.",
|
||||
)
|
||||
operators: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to apply search operators (e.g., site:example.com).",
|
||||
)
|
||||
|
||||
|
||||
class BaseSearchHeaders(BaseModel):
|
||||
"""Common headers for Brave Search endpoints."""
|
||||
|
||||
x_subscription_token: str = Field(
|
||||
alias="x-subscription-token",
|
||||
description="API key for Brave Search",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
alias="api-version",
|
||||
default=None,
|
||||
description="API version to use. Default is latest available.",
|
||||
pattern=r"^\d{4}-\d{2}-\d{2}$", # YYYY-MM-DD
|
||||
)
|
||||
accept: Literal["application/json"] | Literal["*/*"] | None = Field(
|
||||
default=None,
|
||||
description="Accept header for the request.",
|
||||
)
|
||||
cache_control: Literal["no-cache"] | None = Field(
|
||||
alias="cache-control",
|
||||
default=None,
|
||||
description="Cache control header for the request.",
|
||||
)
|
||||
user_agent: str | None = Field(
|
||||
alias="user-agent",
|
||||
default=None,
|
||||
description="User agent for the request.",
|
||||
)
|
||||
|
||||
|
||||
class LLMContextHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave LLM Context endpoint."""
|
||||
|
||||
x_loc_lat: float | None = Field(
|
||||
alias="x-loc-lat",
|
||||
default=None,
|
||||
description="Latitude of the user's location.",
|
||||
ge=-90.0,
|
||||
le=90.0,
|
||||
)
|
||||
x_loc_long: float | None = Field(
|
||||
alias="x-loc-long",
|
||||
default=None,
|
||||
description="Longitude of the user's location.",
|
||||
ge=-180.0,
|
||||
le=180.0,
|
||||
)
|
||||
x_loc_city: str | None = Field(
|
||||
alias="x-loc-city",
|
||||
default=None,
|
||||
description="City of the user's location.",
|
||||
)
|
||||
x_loc_state: str | None = Field(
|
||||
alias="x-loc-state",
|
||||
default=None,
|
||||
description="State of the user's location.",
|
||||
)
|
||||
x_loc_state_name: str | None = Field(
|
||||
alias="x-loc-state-name",
|
||||
default=None,
|
||||
description="Name of the state of the user's location.",
|
||||
)
|
||||
x_loc_country: str | None = Field(
|
||||
alias="x-loc-country",
|
||||
default=None,
|
||||
description="The ISO 3166-1 alpha-2 country code of the user's location.",
|
||||
)
|
||||
|
||||
|
||||
class LocalPOIsHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave Local POIs endpoint."""
|
||||
|
||||
x_loc_lat: float | None = Field(
|
||||
alias="x-loc-lat",
|
||||
default=None,
|
||||
description="Latitude of the user's location.",
|
||||
ge=-90.0,
|
||||
le=90.0,
|
||||
)
|
||||
x_loc_long: float | None = Field(
|
||||
alias="x-loc-long",
|
||||
default=None,
|
||||
description="Longitude of the user's location.",
|
||||
ge=-180.0,
|
||||
le=180.0,
|
||||
)
|
||||
|
||||
|
||||
class LocalPOIsDescriptionHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave Local POI Descriptions endpoint."""
|
||||
|
||||
|
||||
class VideoSearchHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave Video Search endpoint."""
|
||||
|
||||
|
||||
class ImageSearchHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave Image Search endpoint."""
|
||||
|
||||
|
||||
class NewsSearchHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave News Search endpoint."""
|
||||
|
||||
|
||||
class WebSearchHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave Web Search endpoint."""
|
||||
|
||||
x_loc_lat: float | None = Field(
|
||||
alias="x-loc-lat",
|
||||
default=None,
|
||||
description="Latitude of the user's location.",
|
||||
ge=-90.0,
|
||||
le=90.0,
|
||||
)
|
||||
x_loc_long: float | None = Field(
|
||||
alias="x-loc-long",
|
||||
default=None,
|
||||
description="Longitude of the user's location.",
|
||||
ge=-180.0,
|
||||
le=180.0,
|
||||
)
|
||||
x_loc_timezone: str | None = Field(
|
||||
alias="x-loc-timezone",
|
||||
default=None,
|
||||
description="Timezone of the user's location.",
|
||||
)
|
||||
x_loc_city: str | None = Field(
|
||||
alias="x-loc-city",
|
||||
default=None,
|
||||
description="City of the user's location.",
|
||||
)
|
||||
x_loc_state: str | None = Field(
|
||||
alias="x-loc-state",
|
||||
default=None,
|
||||
description="State of the user's location.",
|
||||
)
|
||||
x_loc_state_name: str | None = Field(
|
||||
alias="x-loc-state-name",
|
||||
default=None,
|
||||
description="Name of the state of the user's location.",
|
||||
)
|
||||
x_loc_country: str | None = Field(
|
||||
alias="x-loc-country",
|
||||
default=None,
|
||||
description="The ISO 3166-1 alpha-2 country code of the user's location.",
|
||||
)
|
||||
x_loc_postal_code: str | None = Field(
|
||||
alias="x-loc-postal-code",
|
||||
default=None,
|
||||
description="The postal code of the user's location.",
|
||||
)
|
||||
@@ -1,13 +1,27 @@
|
||||
# CodeInterpreterTool
|
||||
|
||||
## Description
|
||||
This tool is used to give the Agent the ability to run code (Python3) from the code generated by the Agent itself. The code is executed in a sandboxed environment, so it is safe to run any code.
|
||||
This tool is used to give the Agent the ability to run code (Python3) from the code generated by the Agent itself. The code is executed in a Docker container for secure isolation.
|
||||
|
||||
It is incredible useful since it allows the Agent to generate code, run it in the same environment, get the result and use it to make decisions.
|
||||
It is incredibly useful since it allows the Agent to generate code, run it in an isolated environment, get the result and use it to make decisions.
|
||||
|
||||
## ⚠️ Security Requirements
|
||||
|
||||
**Docker is REQUIRED** for safe code execution. The tool will refuse to execute code without Docker to prevent security vulnerabilities.
|
||||
|
||||
### Why Docker is Required
|
||||
|
||||
Previous versions included a "restricted sandbox" fallback when Docker was unavailable. This has been **removed** due to critical security vulnerabilities:
|
||||
|
||||
- The Python-based sandbox could be escaped via object introspection
|
||||
- Attackers could recover the original `__import__` function and access any module
|
||||
- This allowed arbitrary command execution on the host system
|
||||
|
||||
**Docker provides real process isolation** and is the only secure way to execute untrusted code.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Docker
|
||||
- **Docker (REQUIRED)** - Install from [docker.com](https://docs.docker.com/get-docker/)
|
||||
|
||||
## Installation
|
||||
Install the crewai_tools package
|
||||
@@ -17,7 +31,9 @@ pip install 'crewai[tools]'
|
||||
|
||||
## Example
|
||||
|
||||
Remember that when using this tool, the code must be generated by the Agent itself. The code must be a Python3 code. And it will take some time for the first time to run because it needs to build the Docker image.
|
||||
Remember that when using this tool, the code must be generated by the Agent itself. The code must be Python3 code. It will take some time the first time to run because it needs to build the Docker image.
|
||||
|
||||
### Basic Usage (Docker Container - Recommended)
|
||||
|
||||
```python
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
@@ -28,7 +44,9 @@ Agent(
|
||||
)
|
||||
```
|
||||
|
||||
Or if you need to pass your own Dockerfile just do this
|
||||
### Custom Dockerfile
|
||||
|
||||
If you need to pass your own Dockerfile:
|
||||
|
||||
```python
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
@@ -39,15 +57,39 @@ Agent(
|
||||
)
|
||||
```
|
||||
|
||||
If it is difficult to connect to docker daemon automatically (especially for macOS users), you can do this to setup docker host manually
|
||||
### Manual Docker Host Configuration
|
||||
|
||||
If it is difficult to connect to the Docker daemon automatically (especially for macOS users), you can set up the Docker host manually:
|
||||
|
||||
```python
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
|
||||
Agent(
|
||||
...
|
||||
tools=[CodeInterpreterTool(user_docker_base_url="<Docker Host Base Url>",
|
||||
user_dockerfile_path="<Dockerfile_path>")],
|
||||
tools=[CodeInterpreterTool(
|
||||
user_docker_base_url="<Docker Host Base Url>",
|
||||
user_dockerfile_path="<Dockerfile_path>"
|
||||
)],
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### Unsafe Mode (NOT RECOMMENDED)
|
||||
|
||||
If you absolutely cannot use Docker and **fully trust the code source**, you can use unsafe mode:
|
||||
|
||||
```python
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
|
||||
# WARNING: Only use with fully trusted code!
|
||||
Agent(
|
||||
...
|
||||
tools=[CodeInterpreterTool(unsafe_mode=True)],
|
||||
)
|
||||
```
|
||||
|
||||
**⚠️ SECURITY WARNING:** `unsafe_mode=True` executes code directly on the host without any isolation. Only use this if:
|
||||
- You completely trust the code being executed
|
||||
- You understand the security risks
|
||||
- You cannot install Docker in your environment
|
||||
|
||||
For production use, **always use Docker** (the default mode).
|
||||
|
||||
@@ -8,6 +8,7 @@ potentially unsafe operations and importing restricted modules.
|
||||
import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Any, ClassVar, TypedDict
|
||||
|
||||
@@ -50,11 +51,16 @@ class CodeInterpreterSchema(BaseModel):
|
||||
|
||||
|
||||
class SandboxPython:
|
||||
"""A restricted Python execution environment for running code safely.
|
||||
"""INSECURE: A restricted Python execution environment with known vulnerabilities.
|
||||
|
||||
This class provides methods to safely execute Python code by restricting access to
|
||||
potentially dangerous modules and built-in functions. It creates a sandboxed
|
||||
environment where harmful operations are blocked.
|
||||
WARNING: This class does NOT provide real security isolation and is vulnerable to
|
||||
sandbox escape attacks via Python object introspection. Attackers can recover the
|
||||
original __import__ function and bypass all restrictions.
|
||||
|
||||
DO NOT USE for untrusted code execution. Use Docker containers instead.
|
||||
|
||||
This class attempts to restrict access to dangerous modules and built-in functions
|
||||
but provides no real security boundary against a motivated attacker.
|
||||
"""
|
||||
|
||||
BLOCKED_MODULES: ClassVar[set[str]] = {
|
||||
@@ -299,8 +305,8 @@ class CodeInterpreterTool(BaseTool):
|
||||
def run_code_safety(self, code: str, libraries_used: list[str]) -> str:
|
||||
"""Runs code in the safest available environment.
|
||||
|
||||
Attempts to run code in Docker if available, falls back to a restricted
|
||||
sandbox if Docker is not available.
|
||||
Requires Docker to be available for secure code execution. Fails closed
|
||||
if Docker is not available to prevent sandbox escape vulnerabilities.
|
||||
|
||||
Args:
|
||||
code: The Python code to execute as a string.
|
||||
@@ -308,10 +314,24 @@ class CodeInterpreterTool(BaseTool):
|
||||
|
||||
Returns:
|
||||
The output of the executed code as a string.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Docker is not available, as the restricted sandbox
|
||||
is vulnerable to escape attacks and should not be used
|
||||
for untrusted code execution.
|
||||
"""
|
||||
if self._check_docker_available():
|
||||
return self.run_code_in_docker(code, libraries_used)
|
||||
return self.run_code_in_restricted_sandbox(code)
|
||||
|
||||
error_msg = (
|
||||
"Docker is required for safe code execution but is not available. "
|
||||
"The restricted sandbox fallback has been removed due to security vulnerabilities "
|
||||
"that allow sandbox escape via Python object introspection. "
|
||||
"Please install Docker (https://docs.docker.com/get-docker/) or use unsafe_mode=True "
|
||||
"if you trust the code source and understand the security risks."
|
||||
)
|
||||
Printer.print(error_msg, color="bold_red")
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
def run_code_in_docker(self, code: str, libraries_used: list[str]) -> str:
|
||||
"""Runs Python code in a Docker container for safe isolation.
|
||||
@@ -342,10 +362,19 @@ class CodeInterpreterTool(BaseTool):
|
||||
|
||||
@staticmethod
|
||||
def run_code_in_restricted_sandbox(code: str) -> str:
|
||||
"""Runs Python code in a restricted sandbox environment.
|
||||
"""DEPRECATED AND INSECURE: Runs Python code in a restricted sandbox environment.
|
||||
|
||||
Executes the code with restricted access to potentially dangerous modules and
|
||||
built-in functions for basic safety when Docker is not available.
|
||||
WARNING: This method is vulnerable to sandbox escape attacks via Python object
|
||||
introspection and should NOT be used for untrusted code execution. It has been
|
||||
deprecated and is only kept for backward compatibility with trusted code.
|
||||
|
||||
The "restricted" environment can be bypassed by attackers who can:
|
||||
- Use object graph introspection to recover the original __import__ function
|
||||
- Access any Python module including os, subprocess, sys, etc.
|
||||
- Execute arbitrary commands on the host system
|
||||
|
||||
Use run_code_in_docker() for secure code execution, or run_code_unsafe()
|
||||
if you explicitly acknowledge the security risks.
|
||||
|
||||
Args:
|
||||
code: The Python code to execute as a string.
|
||||
@@ -354,7 +383,10 @@ class CodeInterpreterTool(BaseTool):
|
||||
The value of the 'result' variable from the executed code,
|
||||
or an error message if execution failed.
|
||||
"""
|
||||
Printer.print("Running code in restricted sandbox", color="yellow")
|
||||
Printer.print(
|
||||
"WARNING: Running code in INSECURE restricted sandbox (vulnerable to escape attacks)",
|
||||
color="bold_red"
|
||||
)
|
||||
exec_locals: dict[str, Any] = {}
|
||||
try:
|
||||
SandboxPython.exec(code=code, locals_=exec_locals)
|
||||
@@ -380,7 +412,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
Printer.print("WARNING: Running code in unsafe mode", color="bold_magenta")
|
||||
# Install libraries on the host machine
|
||||
for library in libraries_used:
|
||||
os.system(f"pip install {library}") # noqa: S605
|
||||
subprocess.run([sys.executable, "-m", "pip", "install", library], check=False) # noqa: S603
|
||||
|
||||
# Execute the code
|
||||
try:
|
||||
|
||||
@@ -30,9 +30,8 @@ class FileWriterTool(BaseTool):
|
||||
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
try:
|
||||
# Create the directory if it doesn't exist
|
||||
if kwargs.get("directory") and not os.path.exists(kwargs["directory"]):
|
||||
os.makedirs(kwargs["directory"])
|
||||
if kwargs.get("directory"):
|
||||
os.makedirs(kwargs["directory"], exist_ok=True)
|
||||
|
||||
# Construct the full path
|
||||
filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"])
|
||||
|
||||
@@ -99,8 +99,8 @@ class FileCompressorTool(BaseTool):
|
||||
def _prepare_output(output_path: str, overwrite: bool) -> bool:
|
||||
"""Ensures output path is ready for writing."""
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
if os.path.exists(output_path) and not overwrite:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -18,7 +18,6 @@ class MergeAgentHandlerToolError(Exception):
|
||||
"""Base exception for Merge Agent Handler tool errors."""
|
||||
|
||||
|
||||
|
||||
class MergeAgentHandlerTool(BaseTool):
|
||||
"""
|
||||
Wrapper for Merge Agent Handler tools.
|
||||
@@ -174,7 +173,7 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
>>> tool = MergeAgentHandlerTool.from_tool_name(
|
||||
... tool_name="linear__create_issue",
|
||||
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa"
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
|
||||
... )
|
||||
"""
|
||||
# Create an empty args schema model (proper BaseModel subclass)
|
||||
@@ -210,7 +209,10 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
if "parameters" in tool_schema:
|
||||
try:
|
||||
params = tool_schema["parameters"]
|
||||
if params.get("type") == "object" and "properties" in params:
|
||||
if (
|
||||
params.get("type") == "object"
|
||||
and "properties" in params
|
||||
):
|
||||
# Build field definitions for Pydantic
|
||||
fields = {}
|
||||
properties = params["properties"]
|
||||
@@ -298,7 +300,7 @@ class MergeAgentHandlerTool(BaseTool):
|
||||
>>> tools = MergeAgentHandlerTool.from_tool_pack(
|
||||
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
|
||||
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
|
||||
... tool_names=["linear__create_issue", "linear__get_issues"]
|
||||
... tool_names=["linear__create_issue", "linear__get_issues"],
|
||||
... )
|
||||
"""
|
||||
# Create a temporary instance to fetch the tool list
|
||||
|
||||
@@ -110,11 +110,13 @@ class QdrantVectorSearchTool(BaseTool):
|
||||
self.custom_embedding_fn(query)
|
||||
if self.custom_embedding_fn
|
||||
else (
|
||||
lambda: __import__("openai")
|
||||
.Client(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
.embeddings.create(input=[query], model="text-embedding-3-large")
|
||||
.data[0]
|
||||
.embedding
|
||||
lambda: (
|
||||
__import__("openai")
|
||||
.Client(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
.embeddings.create(input=[query], model="text-embedding-3-large")
|
||||
.data[0]
|
||||
.embedding
|
||||
)
|
||||
)()
|
||||
)
|
||||
results = self.client.query_points(
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
@@ -33,6 +34,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache for query results
|
||||
_query_cache: dict[str, list[dict[str, Any]]] = {}
|
||||
_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
class SnowflakeConfig(BaseModel):
|
||||
@@ -102,7 +104,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
)
|
||||
|
||||
_connection_pool: list[SnowflakeConnection] | None = None
|
||||
_pool_lock: asyncio.Lock | None = None
|
||||
_pool_lock: threading.Lock | None = None
|
||||
_thread_pool: ThreadPoolExecutor | None = None
|
||||
_model_rebuilt: bool = False
|
||||
package_dependencies: list[str] = Field(
|
||||
@@ -122,7 +124,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
try:
|
||||
if SNOWFLAKE_AVAILABLE:
|
||||
self._connection_pool = []
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self._pool_lock = threading.Lock()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||
else:
|
||||
raise ImportError
|
||||
@@ -147,7 +149,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
)
|
||||
|
||||
self._connection_pool = []
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self._pool_lock = threading.Lock()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise ImportError("Failed to install Snowflake dependencies") from e
|
||||
@@ -163,13 +165,12 @@ class SnowflakeSearchTool(BaseTool):
|
||||
raise RuntimeError("Pool lock not initialized")
|
||||
if self._connection_pool is None:
|
||||
raise RuntimeError("Connection pool not initialized")
|
||||
async with self._pool_lock:
|
||||
if not self._connection_pool:
|
||||
conn = await asyncio.get_event_loop().run_in_executor(
|
||||
self._thread_pool, self._create_connection
|
||||
)
|
||||
self._connection_pool.append(conn)
|
||||
return self._connection_pool.pop()
|
||||
with self._pool_lock:
|
||||
if self._connection_pool:
|
||||
return self._connection_pool.pop()
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
self._thread_pool, self._create_connection
|
||||
)
|
||||
|
||||
def _create_connection(self) -> SnowflakeConnection:
|
||||
"""Create a new Snowflake connection."""
|
||||
@@ -204,9 +205,10 @@ class SnowflakeSearchTool(BaseTool):
|
||||
"""Execute a query with retries and return results."""
|
||||
if self.enable_caching:
|
||||
cache_key = self._get_cache_key(query, timeout)
|
||||
if cache_key in _query_cache:
|
||||
logger.info("Returning cached result")
|
||||
return _query_cache[cache_key]
|
||||
with _cache_lock:
|
||||
if cache_key in _query_cache:
|
||||
logger.info("Returning cached result")
|
||||
return _query_cache[cache_key]
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
@@ -225,7 +227,8 @@ class SnowflakeSearchTool(BaseTool):
|
||||
]
|
||||
|
||||
if self.enable_caching:
|
||||
_query_cache[self._get_cache_key(query, timeout)] = results
|
||||
with _cache_lock:
|
||||
_query_cache[self._get_cache_key(query, timeout)] = results
|
||||
|
||||
return results
|
||||
finally:
|
||||
@@ -234,7 +237,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
self._pool_lock is not None
|
||||
and self._connection_pool is not None
|
||||
):
|
||||
async with self._pool_lock:
|
||||
with self._pool_lock:
|
||||
self._connection_pool.append(conn)
|
||||
except (DatabaseError, OperationalError) as e: # noqa: PERF203
|
||||
if attempt == self.max_retries - 1:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -137,7 +138,9 @@ class StagehandTool(BaseTool):
|
||||
- 'observe': For finding elements in a specific area
|
||||
"""
|
||||
args_schema: type[BaseModel] = StagehandToolSchema
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["stagehand<=0.5.9"])
|
||||
package_dependencies: list[str] = Field(
|
||||
default_factory=lambda: ["stagehand<=0.5.9"]
|
||||
)
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
@@ -620,9 +623,12 @@ class StagehandTool(BaseTool):
|
||||
# We're in an existing event loop, use it
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run, self._async_run(instruction, url, command_type)
|
||||
ctx.run,
|
||||
asyncio.run,
|
||||
self._async_run(instruction, url, command_type),
|
||||
)
|
||||
result = future.result()
|
||||
else:
|
||||
@@ -706,11 +712,12 @@ class StagehandTool(BaseTool):
|
||||
if loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with (
|
||||
concurrent.futures.ThreadPoolExecutor() as executor
|
||||
):
|
||||
future = executor.submit(
|
||||
asyncio.run, self._async_close()
|
||||
ctx.run, asyncio.run, self._async_close()
|
||||
)
|
||||
future.result()
|
||||
else:
|
||||
|
||||
@@ -1,80 +1,777 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests as requests_lib
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
|
||||
BraveLLMContextTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
|
||||
BraveLocalPOIsTool,
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
WebSearchParams,
|
||||
WebSearchHeaders,
|
||||
ImageSearchParams,
|
||||
ImageSearchHeaders,
|
||||
NewsSearchParams,
|
||||
NewsSearchHeaders,
|
||||
VideoSearchParams,
|
||||
VideoSearchHeaders,
|
||||
LLMContextParams,
|
||||
LLMContextHeaders,
|
||||
LocalPOIsParams,
|
||||
LocalPOIsHeaders,
|
||||
LocalPOIsDescriptionParams,
|
||||
LocalPOIsDescriptionHeaders,
|
||||
)
|
||||
|
||||
|
||||
def _mock_response(
|
||||
status_code: int = 200,
|
||||
json_data: dict | None = None,
|
||||
headers: dict | None = None,
|
||||
text: str = "",
|
||||
) -> MagicMock:
|
||||
"""Build a ``requests.Response``-like mock with the attributes used by ``_make_request``."""
|
||||
resp = MagicMock(spec=requests_lib.Response)
|
||||
resp.status_code = status_code
|
||||
resp.ok = 200 <= status_code < 400
|
||||
resp.url = "https://api.search.brave.com/res/v1/web/search?q=test"
|
||||
resp.text = text or (str(json_data) if json_data else "")
|
||||
resp.headers = headers or {}
|
||||
resp.json.return_value = json_data if json_data is not None else {}
|
||||
return resp
|
||||
|
||||
|
||||
# Fixtures
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _brave_env_and_rate_limit():
|
||||
"""Set BRAVE_API_KEY for every test. Rate limiting is per-instance (each tool starts with a fresh clock)."""
|
||||
with patch.dict(os.environ, {"BRAVE_API_KEY": "test-api-key"}):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def brave_tool():
|
||||
return BraveSearchTool(n_results=2)
|
||||
def web_tool():
|
||||
return BraveWebSearchTool()
|
||||
|
||||
|
||||
def test_brave_tool_initialization():
|
||||
tool = BraveSearchTool()
|
||||
assert tool.n_results == 10
|
||||
@pytest.fixture
|
||||
def image_tool():
|
||||
return BraveImageSearchTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def news_tool():
|
||||
return BraveNewsSearchTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def video_tool():
|
||||
return BraveVideoSearchTool()
|
||||
|
||||
|
||||
# Initialization
|
||||
|
||||
ALL_TOOL_CLASSES = [
|
||||
BraveWebSearchTool,
|
||||
BraveImageSearchTool,
|
||||
BraveNewsSearchTool,
|
||||
BraveVideoSearchTool,
|
||||
BraveLLMContextTool,
|
||||
BraveLocalPOIsTool,
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool_cls", ALL_TOOL_CLASSES)
|
||||
def test_instantiation_with_env_var(tool_cls):
|
||||
"""Each tool can be created when BRAVE_API_KEY is in the environment."""
|
||||
tool = tool_cls()
|
||||
assert tool.api_key == "test-api-key"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool_cls", ALL_TOOL_CLASSES)
|
||||
def test_instantiation_with_explicit_key(tool_cls):
|
||||
"""An explicit api_key takes precedence over the environment."""
|
||||
tool = tool_cls(api_key="explicit-key")
|
||||
assert tool.api_key == "explicit-key"
|
||||
|
||||
|
||||
def test_missing_api_key_raises():
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(ValueError, match="BRAVE_API_KEY"):
|
||||
BraveWebSearchTool()
|
||||
|
||||
|
||||
def test_default_attributes():
|
||||
tool = BraveWebSearchTool()
|
||||
assert tool.save_file is False
|
||||
assert tool.n_results == 10
|
||||
assert tool._timeout == 30
|
||||
assert tool._requests_per_second == 1.0
|
||||
assert tool.raw is False
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_brave_tool_search(mock_get, brave_tool):
|
||||
mock_response = {
|
||||
def test_custom_constructor_args():
|
||||
tool = BraveWebSearchTool(
|
||||
save_file=True,
|
||||
timeout=60,
|
||||
n_results=5,
|
||||
requests_per_second=0.5,
|
||||
raw=True,
|
||||
)
|
||||
assert tool.save_file is True
|
||||
assert tool._timeout == 60
|
||||
assert tool.n_results == 5
|
||||
assert tool._requests_per_second == 0.5
|
||||
assert tool.raw is True
|
||||
|
||||
|
||||
# Headers
|
||||
|
||||
|
||||
def test_default_headers():
|
||||
tool = BraveWebSearchTool()
|
||||
assert tool.headers["x-subscription-token"] == "test-api-key"
|
||||
assert tool.headers["accept"] == "application/json"
|
||||
|
||||
|
||||
def test_set_headers_merges_and_normalizes():
|
||||
tool = BraveWebSearchTool()
|
||||
tool.set_headers({"Cache-Control": "no-cache"})
|
||||
assert tool.headers["cache-control"] == "no-cache"
|
||||
assert tool.headers["x-subscription-token"] == "test-api-key"
|
||||
|
||||
|
||||
def test_set_headers_returns_self_for_chaining():
|
||||
tool = BraveWebSearchTool()
|
||||
assert tool.set_headers({"Cache-Control": "no-cache"}) is tool
|
||||
|
||||
|
||||
def test_invalid_header_value_raises():
|
||||
tool = BraveImageSearchTool()
|
||||
with pytest.raises(ValueError, match="Invalid headers"):
|
||||
tool.set_headers({"Accept": "text/xml"})
|
||||
|
||||
|
||||
# Endpoint & Schema Wiring
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool_cls, expected_url, expected_params, expected_headers",
|
||||
[
|
||||
(
|
||||
BraveWebSearchTool,
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
WebSearchParams,
|
||||
WebSearchHeaders,
|
||||
),
|
||||
(
|
||||
BraveImageSearchTool,
|
||||
"https://api.search.brave.com/res/v1/images/search",
|
||||
ImageSearchParams,
|
||||
ImageSearchHeaders,
|
||||
),
|
||||
(
|
||||
BraveNewsSearchTool,
|
||||
"https://api.search.brave.com/res/v1/news/search",
|
||||
NewsSearchParams,
|
||||
NewsSearchHeaders,
|
||||
),
|
||||
(
|
||||
BraveVideoSearchTool,
|
||||
"https://api.search.brave.com/res/v1/videos/search",
|
||||
VideoSearchParams,
|
||||
VideoSearchHeaders,
|
||||
),
|
||||
(
|
||||
BraveLLMContextTool,
|
||||
"https://api.search.brave.com/res/v1/llm/context",
|
||||
LLMContextParams,
|
||||
LLMContextHeaders,
|
||||
),
|
||||
(
|
||||
BraveLocalPOIsTool,
|
||||
"https://api.search.brave.com/res/v1/local/pois",
|
||||
LocalPOIsParams,
|
||||
LocalPOIsHeaders,
|
||||
),
|
||||
(
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
"https://api.search.brave.com/res/v1/local/descriptions",
|
||||
LocalPOIsDescriptionParams,
|
||||
LocalPOIsDescriptionHeaders,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tool_wiring(tool_cls, expected_url, expected_params, expected_headers):
|
||||
tool = tool_cls()
|
||||
assert tool.search_url == expected_url
|
||||
assert tool.args_schema is expected_params
|
||||
assert tool.header_schema is expected_headers
|
||||
|
||||
|
||||
# Payload Refinement (e.g., `query` -> `q`, `count` fallback, param pass-through)
|
||||
|
||||
|
||||
def test_web_refine_request_payload_passes_all_params(web_tool):
|
||||
params = web_tool._common_payload_refinement(
|
||||
{
|
||||
"query": "test",
|
||||
"country": "US",
|
||||
"search_lang": "en",
|
||||
"count": 5,
|
||||
"offset": 2,
|
||||
"safesearch": "moderate",
|
||||
"freshness": "pw",
|
||||
}
|
||||
)
|
||||
refined_params = web_tool._refine_request_payload(params)
|
||||
|
||||
assert refined_params["q"] == "test"
|
||||
assert "query" not in refined_params
|
||||
assert refined_params["count"] == 5
|
||||
assert refined_params["country"] == "US"
|
||||
assert refined_params["search_lang"] == "en"
|
||||
assert refined_params["offset"] == 2
|
||||
assert refined_params["safesearch"] == "moderate"
|
||||
assert refined_params["freshness"] == "pw"
|
||||
|
||||
|
||||
def test_image_refine_request_payload_passes_all_params(image_tool):
|
||||
params = image_tool._common_payload_refinement(
|
||||
{
|
||||
"query": "cat photos",
|
||||
"country": "US",
|
||||
"search_lang": "en",
|
||||
"safesearch": "strict",
|
||||
"count": 50,
|
||||
"spellcheck": True,
|
||||
}
|
||||
)
|
||||
refined_params = image_tool._refine_request_payload(params)
|
||||
|
||||
assert refined_params["q"] == "cat photos"
|
||||
assert "query" not in refined_params
|
||||
assert refined_params["country"] == "US"
|
||||
assert refined_params["safesearch"] == "strict"
|
||||
assert refined_params["count"] == 50
|
||||
assert refined_params["spellcheck"] is True
|
||||
|
||||
|
||||
def test_news_refine_request_payload_passes_all_params(news_tool):
|
||||
params = news_tool._common_payload_refinement(
|
||||
{
|
||||
"query": "breaking news",
|
||||
"country": "US",
|
||||
"count": 10,
|
||||
"offset": 1,
|
||||
"freshness": "pd",
|
||||
"extra_snippets": True,
|
||||
}
|
||||
)
|
||||
refined_params = news_tool._refine_request_payload(params)
|
||||
|
||||
assert refined_params["q"] == "breaking news"
|
||||
assert "query" not in refined_params
|
||||
assert refined_params["country"] == "US"
|
||||
assert refined_params["offset"] == 1
|
||||
assert refined_params["freshness"] == "pd"
|
||||
assert refined_params["extra_snippets"] is True
|
||||
|
||||
|
||||
def test_video_refine_request_payload_passes_all_params(video_tool):
|
||||
params = video_tool._common_payload_refinement(
|
||||
{
|
||||
"query": "tutorial",
|
||||
"country": "US",
|
||||
"count": 25,
|
||||
"offset": 0,
|
||||
"safesearch": "strict",
|
||||
"freshness": "pm",
|
||||
}
|
||||
)
|
||||
refined_params = video_tool._refine_request_payload(params)
|
||||
|
||||
assert refined_params["q"] == "tutorial"
|
||||
assert "query" not in refined_params
|
||||
assert refined_params["country"] == "US"
|
||||
assert refined_params["offset"] == 0
|
||||
assert refined_params["freshness"] == "pm"
|
||||
|
||||
|
||||
def test_legacy_constructor_params_flow_into_query_params():
|
||||
"""The legacy n_results and country constructor params are applied as defaults
|
||||
when count/country are not explicitly provided at call time."""
|
||||
tool = BraveWebSearchTool(n_results=3, country="BR")
|
||||
params = tool._common_payload_refinement({"query": "test"})
|
||||
|
||||
assert params["count"] == 3
|
||||
assert params["country"] == "BR"
|
||||
|
||||
|
||||
def test_legacy_constructor_params_do_not_override_explicit_query_params():
|
||||
"""Explicit query-time count/country take precedence over constructor defaults."""
|
||||
tool = BraveWebSearchTool(n_results=3, country="BR")
|
||||
params = tool._common_payload_refinement(
|
||||
{"query": "test", "count": 10, "country": "US"}
|
||||
)
|
||||
|
||||
assert params["count"] == 10
|
||||
assert params["country"] == "US"
|
||||
|
||||
|
||||
def test_refine_request_payload_passes_multiple_goggles_as_multiple_params(web_tool):
|
||||
result = web_tool._refine_request_payload(
|
||||
{
|
||||
"query": "test",
|
||||
"goggles": ["goggle1", "goggle2"],
|
||||
}
|
||||
)
|
||||
assert result["goggles"] == ["goggle1", "goggle2"]
|
||||
|
||||
|
||||
# Null-like / empty value stripping
|
||||
#
|
||||
# crewAI's ensure_all_properties_required (pydantic_schema_utils.py) marks
|
||||
# every schema property as required for OpenAI strict-mode compatibility.
|
||||
# Because optional Brave API parameters look required to the LLM, it fills
|
||||
# them with placeholder junk — None, "", "null", or []. The test below
|
||||
# verifies that _common_payload_refinement strips these from optional fields.
|
||||
|
||||
|
||||
def test_common_refinement_strips_null_like_values(web_tool):
|
||||
"""_common_payload_refinement drops optional keys with None / '' / 'null' / []."""
|
||||
params = web_tool._common_payload_refinement(
|
||||
{
|
||||
"query": "test",
|
||||
"country": "US",
|
||||
"search_lang": "",
|
||||
"freshness": "null",
|
||||
"count": 5,
|
||||
"goggles": [],
|
||||
}
|
||||
)
|
||||
assert params["q"] == "test"
|
||||
assert params["country"] == "US"
|
||||
assert params["count"] == 5
|
||||
assert "search_lang" not in params
|
||||
assert "freshness" not in params
|
||||
assert "goggles" not in params
|
||||
|
||||
|
||||
# End-to-End _run() with Mocked HTTP Response
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_web_search_end_to_end(mock_get, web_tool):
|
||||
web_tool.raw = True
|
||||
data = {"web": {"results": [{"title": "R", "url": "http://r.co"}]}}
|
||||
mock_get.return_value = _mock_response(json_data=data)
|
||||
|
||||
result = web_tool._run(query="test")
|
||||
|
||||
mock_get.assert_called_once()
|
||||
call_args = mock_get.call_args.kwargs
|
||||
assert call_args["params"]["q"] == "test"
|
||||
assert call_args["headers"]["x-subscription-token"] == "test-api-key"
|
||||
assert result == data
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_image_search_end_to_end(mock_get, image_tool):
|
||||
image_tool.raw = True
|
||||
data = {"results": [{"url": "http://img.co/a.jpg"}]}
|
||||
mock_get.return_value = _mock_response(json_data=data)
|
||||
|
||||
assert image_tool._run(query="cats") == data
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_news_search_end_to_end(mock_get, news_tool):
|
||||
news_tool.raw = True
|
||||
data = {"results": [{"title": "News", "url": "http://n.co"}]}
|
||||
mock_get.return_value = _mock_response(json_data=data)
|
||||
|
||||
assert news_tool._run(query="headlines") == data
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_video_search_end_to_end(mock_get, video_tool):
|
||||
video_tool.raw = True
|
||||
data = {"results": [{"title": "Vid", "url": "http://v.co"}]}
|
||||
mock_get.return_value = _mock_response(json_data=data)
|
||||
|
||||
assert video_tool._run(query="python tutorial") == data
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_raw_false_calls_refine_response(mock_get, web_tool):
|
||||
"""With raw=False (the default), _refine_response transforms the API response."""
|
||||
api_response = {
|
||||
"web": {
|
||||
"results": [
|
||||
{
|
||||
"title": "Test Title",
|
||||
"url": "http://test.com",
|
||||
"description": "Test Description",
|
||||
"title": "CrewAI",
|
||||
"url": "https://crewai.com",
|
||||
"description": "AI agent framework",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
mock_get.return_value = _mock_response(json_data=api_response)
|
||||
|
||||
result = brave_tool.run(query="test")
|
||||
data = json.loads(result)
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
assert data[0]["title"] == "Test Title"
|
||||
assert data[0]["url"] == "http://test.com"
|
||||
assert web_tool.raw is False
|
||||
result = web_tool._run(query="crewai")
|
||||
|
||||
# The web tool's _refine_response extracts and reshapes results.
|
||||
# The key assertion: we should NOT get back the raw API envelope.
|
||||
assert result != api_response
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_brave_tool(mock_get):
|
||||
mock_response = {
|
||||
"web": {
|
||||
"results": [
|
||||
{
|
||||
"title": "Brave Browser",
|
||||
"url": "https://brave.com",
|
||||
"description": "Brave Browser description",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
|
||||
tool = BraveSearchTool(n_results=2)
|
||||
result = tool.run(query="Brave Browser")
|
||||
assert result is not None
|
||||
|
||||
# Parse JSON so we can examine the structure
|
||||
data = json.loads(result)
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
|
||||
# First item should have expected fields: title, url, and description
|
||||
first = data[0]
|
||||
assert "title" in first
|
||||
assert first["title"] == "Brave Browser"
|
||||
assert "url" in first
|
||||
assert first["url"] == "https://brave.com"
|
||||
assert "description" in first
|
||||
assert first["description"] == "Brave Browser description"
|
||||
# Backward Compatibility & Legacy Parameter Support
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_brave_tool()
|
||||
test_brave_tool_initialization()
|
||||
# test_brave_tool_search(brave_tool)
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_positional_query_argument(mock_get, web_tool):
|
||||
"""tool.run('my query') works as a positional argument."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
web_tool._run("positional test")
|
||||
|
||||
assert mock_get.call_args.kwargs["params"]["q"] == "positional test"
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_search_query_backward_compat(mock_get, web_tool):
|
||||
"""The legacy 'search_query' param is mapped to 'query'."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
web_tool._run(search_query="legacy test")
|
||||
|
||||
assert mock_get.call_args.kwargs["params"]["q"] == "legacy test"
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base._save_results_to_file")
|
||||
def test_save_file_called_when_enabled(mock_save, mock_get):
|
||||
mock_get.return_value = _mock_response(json_data={"results": []})
|
||||
|
||||
tool = BraveWebSearchTool(save_file=True)
|
||||
tool._run(query="test")
|
||||
|
||||
mock_save.assert_called_once()
|
||||
|
||||
|
||||
# Error Handling
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_connection_error_raises_runtime_error(mock_get, web_tool):
|
||||
mock_get.side_effect = requests_lib.exceptions.ConnectionError("refused")
|
||||
with pytest.raises(RuntimeError, match="Brave Search API connection failed"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_timeout_raises_runtime_error(mock_get, web_tool):
|
||||
mock_get.side_effect = requests_lib.exceptions.Timeout("timed out")
|
||||
with pytest.raises(RuntimeError, match="timed out"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_invalid_params_raises_value_error(mock_get, web_tool):
|
||||
"""count=999 exceeds WebSearchParams.count le=20."""
|
||||
with pytest.raises(ValueError, match="Invalid parameters"):
|
||||
web_tool._run(query="test", count=999)
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_4xx_error_raises_with_api_detail(mock_get, web_tool):
|
||||
"""A 422 with a structured error body includes code and detail in the message."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=422,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "abc-123",
|
||||
"status": 422,
|
||||
"code": "OPTION_NOT_IN_PLAN",
|
||||
"detail": "extra_snippets requires a Pro plan",
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="OPTION_NOT_IN_PLAN") as exc_info:
|
||||
web_tool._run(query="test")
|
||||
assert "extra_snippets requires a Pro plan" in str(exc_info.value)
|
||||
assert "HTTP 422" in str(exc_info.value)
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_auth_error_raises_immediately(mock_get, web_tool):
|
||||
"""A 401 with SUBSCRIPTION_TOKEN_INVALID is not retried."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=401,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "xyz",
|
||||
"status": 401,
|
||||
"code": "SUBSCRIPTION_TOKEN_INVALID",
|
||||
"detail": "The subscription token is invalid",
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="SUBSCRIPTION_TOKEN_INVALID"):
|
||||
web_tool._run(query="test")
|
||||
# Should NOT have retried — only one call.
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_quota_limited_429_raises_immediately(mock_get, web_tool):
|
||||
"""A 429 with QUOTA_LIMITED is NOT retried — quota exhaustion is terminal."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=429,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "ql-1",
|
||||
"status": 429,
|
||||
"code": "QUOTA_LIMITED",
|
||||
"detail": "Monthly quota exceeded",
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="QUOTA_LIMITED") as exc_info:
|
||||
web_tool._run(query="test")
|
||||
assert "Monthly quota exceeded" in str(exc_info.value)
|
||||
# Terminal — only one HTTP call, no retries.
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_usage_limit_exceeded_429_raises_immediately(mock_get, web_tool):
|
||||
"""USAGE_LIMIT_EXCEEDED is also non-retryable, just like QUOTA_LIMITED."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=429,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "ule-1",
|
||||
"status": 429,
|
||||
"code": "USAGE_LIMIT_EXCEEDED",
|
||||
}
|
||||
},
|
||||
text="usage limit exceeded",
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="USAGE_LIMIT_EXCEEDED"):
|
||||
web_tool._run(query="test")
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_error_body_is_fully_included_in_message(mock_get, web_tool):
|
||||
"""The full JSON error body is included in the RuntimeError message."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=429,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "x",
|
||||
"status": 429,
|
||||
"code": "QUOTA_LIMITED",
|
||||
"detail": "Exceeded",
|
||||
"meta": {"plan": "free", "limit": 1000},
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
web_tool._run(query="test")
|
||||
msg = str(exc_info.value)
|
||||
assert "HTTP 429" in msg
|
||||
assert "QUOTA_LIMITED" in msg
|
||||
assert "free" in msg
|
||||
assert "1000" in msg
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_error_without_json_body_falls_back_to_text(mock_get, web_tool):
|
||||
"""When the error response isn't valid JSON, resp.text is used as the detail."""
|
||||
resp = _mock_response(status_code=500, text="Internal Server Error")
|
||||
resp.json.side_effect = ValueError("No JSON")
|
||||
mock_get.return_value = resp
|
||||
|
||||
with pytest.raises(RuntimeError, match="Internal Server Error"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_invalid_json_on_success_raises_runtime_error(mock_get, web_tool):
|
||||
"""A 200 OK with a non-JSON body raises RuntimeError."""
|
||||
resp = _mock_response(status_code=200)
|
||||
resp.json.side_effect = ValueError("Expecting value")
|
||||
mock_get.return_value = resp
|
||||
|
||||
with pytest.raises(RuntimeError, match="invalid JSON"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
# Rate Limiting
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_sleeps_when_too_fast(mock_time, mock_get, web_tool):
|
||||
"""Back-to-back calls within the interval trigger a sleep."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
# Simulate: last request was at t=100, "now" is t=100.2 (only 0.2s elapsed).
|
||||
# With default 1 req/s the min interval is 1.0s, so it should sleep ~0.8s.
|
||||
mock_time.time.return_value = 100.2
|
||||
web_tool._last_request_time = 100.0
|
||||
|
||||
web_tool._run(query="test")
|
||||
|
||||
mock_time.sleep.assert_called_once()
|
||||
sleep_duration = mock_time.sleep.call_args[0][0]
|
||||
assert 0.7 < sleep_duration < 0.9 # approximately 0.8s
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_skips_sleep_when_enough_time_passed(mock_time, mock_get, web_tool):
|
||||
"""No sleep when the elapsed time already exceeds the interval."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
# Last request was at t=100, "now" is t=102 (2s elapsed > 1s interval).
|
||||
mock_time.time.return_value = 102.0
|
||||
web_tool._last_request_time = 100.0
|
||||
|
||||
web_tool._run(query="test")
|
||||
|
||||
mock_time.sleep.assert_not_called()
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_disabled_when_zero(mock_time, mock_get, web_tool):
|
||||
"""requests_per_second=0 disables rate limiting entirely."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
web_tool._last_request_time = 100.0
|
||||
mock_time.time.return_value = 100.0 # same instant
|
||||
|
||||
web_tool._run(query="test")
|
||||
|
||||
mock_time.sleep.assert_not_called()
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_per_instance_independent(mock_time, mock_get, web_tool, image_tool):
|
||||
"""Each instance has its own rate-limit clock; a request on one does not delay the other."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
# Web tool fires at t=100 (its clock goes 0 -> 100).
|
||||
mock_time.time.return_value = 100.0
|
||||
web_tool._run(query="test")
|
||||
|
||||
# Image tool fires at t=100.3. Its clock is still 0 (separate instance), so
|
||||
# next_allowed = 1.0 and 100.3 > 1.0 — no sleep. Total process rate can be sum of instance limits.
|
||||
mock_time.time.return_value = 100.3
|
||||
image_tool._run(query="cats")
|
||||
|
||||
mock_time.sleep.assert_not_called()
|
||||
|
||||
|
||||
# Retry Behavior
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_429_rate_limited_retries_then_succeeds(mock_time, mock_get, web_tool):
|
||||
"""A transient RATE_LIMITED 429 is retried; success on the second attempt."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_429 = _mock_response(
|
||||
status_code=429,
|
||||
json_data={"error": {"id": "r", "status": 429, "code": "RATE_LIMITED"}},
|
||||
headers={"Retry-After": "2"},
|
||||
)
|
||||
resp_200 = _mock_response(status_code=200, json_data={"web": {"results": []}})
|
||||
mock_get.side_effect = [resp_429, resp_200]
|
||||
|
||||
web_tool.raw = True
|
||||
result = web_tool._run(query="test")
|
||||
|
||||
assert result == {"web": {"results": []}}
|
||||
assert mock_get.call_count == 2
|
||||
# Slept for the Retry-After value.
|
||||
retry_sleeps = [c for c in mock_time.sleep.call_args_list if c[0][0] == 2.0]
|
||||
assert len(retry_sleeps) == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_5xx_is_retried(mock_time, mock_get, web_tool):
|
||||
"""A 502 server error is retried; success on the second attempt."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_502 = _mock_response(status_code=502, text="Bad Gateway")
|
||||
resp_502.json.side_effect = ValueError("no json")
|
||||
resp_200 = _mock_response(status_code=200, json_data={"web": {"results": []}})
|
||||
mock_get.side_effect = [resp_502, resp_200]
|
||||
|
||||
web_tool.raw = True
|
||||
result = web_tool._run(query="test")
|
||||
|
||||
assert result == {"web": {"results": []}}
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_429_rate_limited_exhausts_retries(mock_time, mock_get, web_tool):
|
||||
"""Persistent RATE_LIMITED 429s exhaust retries and raise RuntimeError."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_429 = _mock_response(
|
||||
status_code=429,
|
||||
json_data={"error": {"id": "r", "status": 429, "code": "RATE_LIMITED"}},
|
||||
)
|
||||
mock_get.return_value = resp_429
|
||||
|
||||
with pytest.raises(RuntimeError, match="RATE_LIMITED"):
|
||||
web_tool._run(query="test")
|
||||
# 3 attempts (default _max_retries).
|
||||
assert mock_get.call_count == 3
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_retry_uses_exponential_backoff_when_no_retry_after(
|
||||
mock_time, mock_get, web_tool
|
||||
):
|
||||
"""Without Retry-After, backoff is 2^attempt (1s, 2s, ...)."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_503 = _mock_response(status_code=503, text="Service Unavailable")
|
||||
resp_503.json.side_effect = ValueError("no json")
|
||||
resp_200 = _mock_response(status_code=200, json_data={"ok": True})
|
||||
mock_get.side_effect = [resp_503, resp_503, resp_200]
|
||||
|
||||
web_tool.raw = True
|
||||
web_tool._run(query="test")
|
||||
|
||||
# Two retries: attempt 0 → sleep(1.0), attempt 1 → sleep(2.0).
|
||||
retry_sleeps = [c[0][0] for c in mock_time.sleep.call_args_list]
|
||||
assert 1.0 in retry_sleeps
|
||||
assert 2.0 in retry_sleeps
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import (
|
||||
@@ -76,24 +77,22 @@ print("This is line 2")"""
|
||||
)
|
||||
|
||||
|
||||
def test_restricted_sandbox_basic_code_execution(printer_mock, docker_unavailable_mock):
|
||||
"""Test basic code execution."""
|
||||
def test_docker_unavailable_raises_error(printer_mock, docker_unavailable_mock):
|
||||
"""Test that execution fails when Docker is unavailable in safe mode."""
|
||||
tool = CodeInterpreterTool()
|
||||
code = """
|
||||
result = 2 + 2
|
||||
print(result)
|
||||
"""
|
||||
result = tool.run(code=code, libraries_used=[])
|
||||
printer_mock.assert_called_with(
|
||||
"Running code in restricted sandbox", color="yellow"
|
||||
)
|
||||
assert result == 4
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
tool.run(code=code, libraries_used=[])
|
||||
|
||||
assert "Docker is required for safe code execution" in str(exc_info.value)
|
||||
assert "sandbox escape" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_restricted_sandbox_running_with_blocked_modules(
|
||||
printer_mock, docker_unavailable_mock
|
||||
):
|
||||
"""Test that restricted modules cannot be imported."""
|
||||
def test_restricted_sandbox_running_with_blocked_modules():
|
||||
"""Test that restricted modules cannot be imported when using the deprecated sandbox directly."""
|
||||
tool = CodeInterpreterTool()
|
||||
restricted_modules = SandboxPython.BLOCKED_MODULES
|
||||
|
||||
@@ -102,18 +101,15 @@ def test_restricted_sandbox_running_with_blocked_modules(
|
||||
import {module}
|
||||
result = "Import succeeded"
|
||||
"""
|
||||
result = tool.run(code=code, libraries_used=[])
|
||||
printer_mock.assert_called_with(
|
||||
"Running code in restricted sandbox", color="yellow"
|
||||
)
|
||||
|
||||
# Note: run_code_in_restricted_sandbox is deprecated and insecure
|
||||
# This test verifies the old behavior but should not be used in production
|
||||
result = tool.run_code_in_restricted_sandbox(code)
|
||||
|
||||
assert f"An error occurred: Importing '{module}' is not allowed" in result
|
||||
|
||||
|
||||
def test_restricted_sandbox_running_with_blocked_builtins(
|
||||
printer_mock, docker_unavailable_mock
|
||||
):
|
||||
"""Test that restricted builtins are not available."""
|
||||
def test_restricted_sandbox_running_with_blocked_builtins():
|
||||
"""Test that restricted builtins are not available when using the deprecated sandbox directly."""
|
||||
tool = CodeInterpreterTool()
|
||||
restricted_builtins = SandboxPython.UNSAFE_BUILTINS
|
||||
|
||||
@@ -122,25 +118,23 @@ def test_restricted_sandbox_running_with_blocked_builtins(
|
||||
{builtin}("test")
|
||||
result = "Builtin available"
|
||||
"""
|
||||
result = tool.run(code=code, libraries_used=[])
|
||||
printer_mock.assert_called_with(
|
||||
"Running code in restricted sandbox", color="yellow"
|
||||
)
|
||||
# Note: run_code_in_restricted_sandbox is deprecated and insecure
|
||||
# This test verifies the old behavior but should not be used in production
|
||||
result = tool.run_code_in_restricted_sandbox(code)
|
||||
assert f"An error occurred: name '{builtin}' is not defined" in result
|
||||
|
||||
|
||||
def test_restricted_sandbox_running_with_no_result_variable(
|
||||
printer_mock, docker_unavailable_mock
|
||||
):
|
||||
"""Test behavior when no result variable is set."""
|
||||
"""Test behavior when no result variable is set in deprecated sandbox."""
|
||||
tool = CodeInterpreterTool()
|
||||
code = """
|
||||
x = 10
|
||||
"""
|
||||
result = tool.run(code=code, libraries_used=[])
|
||||
printer_mock.assert_called_with(
|
||||
"Running code in restricted sandbox", color="yellow"
|
||||
)
|
||||
# Note: run_code_in_restricted_sandbox is deprecated and insecure
|
||||
# This test verifies the old behavior but should not be used in production
|
||||
result = tool.run_code_in_restricted_sandbox(code)
|
||||
assert result == "No result variable found."
|
||||
|
||||
|
||||
@@ -159,6 +153,44 @@ x = 10
|
||||
assert result == "No result variable found."
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.subprocess.run")
|
||||
def test_unsafe_mode_installs_libraries_without_shell(
|
||||
subprocess_run_mock, printer_mock, docker_unavailable_mock
|
||||
):
|
||||
"""Test that library installation uses subprocess.run with shell=False, not os.system."""
|
||||
tool = CodeInterpreterTool(unsafe_mode=True)
|
||||
code = "result = 1"
|
||||
libraries_used = ["numpy", "pandas"]
|
||||
|
||||
tool.run(code=code, libraries_used=libraries_used)
|
||||
|
||||
assert subprocess_run_mock.call_count == 2
|
||||
for call, library in zip(subprocess_run_mock.call_args_list, libraries_used):
|
||||
args, kwargs = call
|
||||
# Must be list form (no shell expansion possible)
|
||||
assert args[0] == [sys.executable, "-m", "pip", "install", library]
|
||||
# shell= must not be True (defaults to False)
|
||||
assert kwargs.get("shell", False) is False
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.subprocess.run")
|
||||
def test_unsafe_mode_library_name_with_shell_metacharacters_does_not_invoke_shell(
|
||||
subprocess_run_mock, printer_mock, docker_unavailable_mock
|
||||
):
|
||||
"""Test that a malicious library name cannot inject shell commands."""
|
||||
tool = CodeInterpreterTool(unsafe_mode=True)
|
||||
code = "result = 1"
|
||||
malicious_library = "numpy; rm -rf /"
|
||||
|
||||
tool.run(code=code, libraries_used=[malicious_library])
|
||||
|
||||
subprocess_run_mock.assert_called_once()
|
||||
args, kwargs = subprocess_run_mock.call_args
|
||||
# The entire malicious string is passed as a single argument — no shell parsing
|
||||
assert args[0] == [sys.executable, "-m", "pip", "install", malicious_library]
|
||||
assert kwargs.get("shell", False) is False
|
||||
|
||||
|
||||
def test_unsafe_mode_running_unsafe_code(printer_mock, docker_unavailable_mock):
|
||||
"""Test behavior when no result variable is set."""
|
||||
tool = CodeInterpreterTool(unsafe_mode=True)
|
||||
@@ -172,3 +204,50 @@ result = eval("5/1")
|
||||
"WARNING: Running code in unsafe mode", color="bold_magenta"
|
||||
)
|
||||
assert 5.0 == result
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"run_code_in_restricted_sandbox is known to be vulnerable to sandbox "
|
||||
"escape via object introspection. This test encodes the desired secure "
|
||||
"behavior (no escape possible) and will start passing once the "
|
||||
"vulnerability is fixed or the function is removed."
|
||||
)
|
||||
)
|
||||
def test_sandbox_escape_vulnerability_demonstration(printer_mock):
|
||||
"""Demonstrate that the restricted sandbox is vulnerable to escape attacks.
|
||||
|
||||
This test shows that an attacker can use Python object introspection to bypass
|
||||
the restricted sandbox and access blocked modules like 'os'. This is why the
|
||||
sandbox should never be used for untrusted code execution.
|
||||
|
||||
NOTE: This test uses the deprecated run_code_in_restricted_sandbox directly
|
||||
to demonstrate the vulnerability. In production, Docker is now required.
|
||||
"""
|
||||
tool = CodeInterpreterTool()
|
||||
|
||||
# Classic Python sandbox escape via object introspection
|
||||
escape_code = """
|
||||
# Recover the real __import__ function via object introspection
|
||||
for cls in ().__class__.__bases__[0].__subclasses__():
|
||||
if cls.__name__ == 'catch_warnings':
|
||||
# Get the real builtins module
|
||||
real_builtins = cls()._module.__builtins__
|
||||
real_import = real_builtins['__import__']
|
||||
# Now we can import os and execute commands
|
||||
os = real_import('os')
|
||||
# Demonstrate we have escaped the sandbox
|
||||
result = "SANDBOX_ESCAPED" if hasattr(os, 'system') else "FAILED"
|
||||
break
|
||||
"""
|
||||
|
||||
# The deprecated sandbox is vulnerable to this attack
|
||||
result = tool.run_code_in_restricted_sandbox(escape_code)
|
||||
|
||||
# Desired behavior: the restricted sandbox should prevent this escape.
|
||||
# If this assertion fails, run_code_in_restricted_sandbox remains vulnerable.
|
||||
assert result != "SANDBOX_ESCAPED", (
|
||||
"The restricted sandbox was bypassed via object introspection. "
|
||||
"This indicates run_code_in_restricted_sandbox is still vulnerable and "
|
||||
"is why Docker is now required for safe code execution."
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -53,7 +53,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.10.1",
|
||||
"crewai-tools==1.11.0rc1",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import contextvars
|
||||
import threading
|
||||
from typing import Any
|
||||
import urllib.request
|
||||
import warnings
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.agent.planning_config import PlanningConfig
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.flow.flow import Flow
|
||||
@@ -40,7 +42,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.10.1"
|
||||
__version__ = "1.11.0rc1"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
@@ -66,7 +68,8 @@ def _track_install() -> None:
|
||||
def _track_install_async() -> None:
|
||||
"""Track installation in background thread to avoid blocking imports."""
|
||||
if not Telemetry._is_telemetry_disabled():
|
||||
thread = threading.Thread(target=_track_install, daemon=True)
|
||||
ctx = contextvars.copy_context()
|
||||
thread = threading.Thread(target=ctx.run, args=(_track_install,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
@@ -100,6 +103,7 @@ __all__ = [
|
||||
"Knowledge",
|
||||
"LLMGuardrail",
|
||||
"Memory",
|
||||
"PlanningConfig",
|
||||
"Process",
|
||||
"Task",
|
||||
"TaskOutput",
|
||||
|
||||
@@ -13,6 +13,7 @@ from crewai.a2a.auth.client_schemes import (
|
||||
)
|
||||
from crewai.a2a.auth.server_schemes import (
|
||||
AuthenticatedUser,
|
||||
EnterpriseTokenAuth,
|
||||
OIDCAuth,
|
||||
ServerAuthScheme,
|
||||
SimpleTokenAuth,
|
||||
@@ -25,6 +26,7 @@ __all__ = [
|
||||
"AuthenticatedUser",
|
||||
"BearerTokenAuth",
|
||||
"ClientAuthScheme",
|
||||
"EnterpriseTokenAuth",
|
||||
"HTTPBasicAuth",
|
||||
"HTTPDigestAuth",
|
||||
"OAuth2AuthorizationCode",
|
||||
|
||||
@@ -4,6 +4,7 @@ These schemes validate incoming requests to A2A server endpoints.
|
||||
|
||||
Supported authentication methods:
|
||||
- Simple token validation with static bearer tokens
|
||||
- Enterprise token validation (via PlusAPI)
|
||||
- OpenID Connect with JWT validation using JWKS
|
||||
- OAuth2 with JWT validation or token introspection
|
||||
"""
|
||||
@@ -16,6 +17,7 @@ import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
from pydantic import (
|
||||
@@ -33,6 +35,7 @@ from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import OAuth2SecurityScheme
|
||||
from jwt.types import Options
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -183,6 +186,24 @@ class SimpleTokenAuth(ServerAuthScheme):
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseTokenAuth(ServerAuthScheme):
|
||||
"""Enterprise token authentication.
|
||||
|
||||
Validates tokens via the PlusAPI enterprise verification endpoint.
|
||||
"""
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using enterprise token verification.
|
||||
|
||||
Args:
|
||||
token: The bearer token to authenticate.
|
||||
|
||||
Raises:
|
||||
NotImplementedError
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class OIDCAuth(ServerAuthScheme):
|
||||
"""OpenID Connect authentication.
|
||||
|
||||
@@ -475,7 +496,7 @@ class OAuth2ServerAuth(ServerAuthScheme):
|
||||
try:
|
||||
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
|
||||
|
||||
decode_options: dict[str, Any] = {
|
||||
decode_options: Options = {
|
||||
"require": self.required_claims,
|
||||
}
|
||||
|
||||
@@ -556,7 +577,6 @@ class OAuth2ServerAuth(ServerAuthScheme):
|
||||
|
||||
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
|
||||
import httpx
|
||||
|
||||
if not self.introspection_url:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -633,6 +633,10 @@ class A2AServerConfig(BaseModel):
|
||||
default=False,
|
||||
description="Whether agent provides extended card to authenticated users",
|
||||
)
|
||||
extended_skills: list[AgentSkill] = Field(
|
||||
default_factory=list,
|
||||
description="Additional skills visible only to authenticated users in the extended card",
|
||||
)
|
||||
url: Url | None = Field(
|
||||
default=None,
|
||||
description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
|
||||
|
||||
@@ -63,6 +63,9 @@ class A2AErrorCode(IntEnum):
|
||||
INVALID_AGENT_RESPONSE = -32006
|
||||
"""The agent produced an invalid response."""
|
||||
|
||||
AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED = -32007
|
||||
"""Authenticated extended card feature is not configured."""
|
||||
|
||||
# CrewAI Custom Extensions (-32768 to -32100)
|
||||
UNSUPPORTED_VERSION = -32009
|
||||
"""The requested A2A protocol version is not supported."""
|
||||
@@ -108,6 +111,7 @@ ERROR_MESSAGES: dict[int, str] = {
|
||||
A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported",
|
||||
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types",
|
||||
A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response",
|
||||
A2AErrorCode.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED: "Authenticated Extended Card is not configured",
|
||||
A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version",
|
||||
A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions",
|
||||
A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required",
|
||||
@@ -284,6 +288,15 @@ class InvalidAgentResponseError(A2AError):
|
||||
code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthenticatedExtendedCardNotConfiguredError(A2AError):
|
||||
"""Authenticated extended card is not configured."""
|
||||
|
||||
code: int = field(
|
||||
default=A2AErrorCode.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED, init=False
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedVersionError(A2AError):
|
||||
"""The requested A2A version is not supported."""
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import MutableMapping
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
from functools import lru_cache
|
||||
import ssl
|
||||
import time
|
||||
@@ -147,8 +148,9 @@ def fetch_agent_card(
|
||||
has_running_loop = False
|
||||
|
||||
if has_running_loop:
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return pool.submit(ctx.run, asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
@@ -215,8 +217,9 @@ def _fetch_agent_card_cached(
|
||||
has_running_loop = False
|
||||
|
||||
if has_running_loop:
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return pool.submit(ctx.run, asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import base64
|
||||
from collections.abc import AsyncIterator, Callable, MutableMapping
|
||||
import concurrent.futures
|
||||
from contextlib import asynccontextmanager
|
||||
import contextvars
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
import uuid
|
||||
@@ -229,8 +230,9 @@ def execute_a2a_delegation(
|
||||
has_running_loop = False
|
||||
|
||||
if has_running_loop:
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return pool.submit(ctx.run, asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
from functools import wraps
|
||||
import json
|
||||
from types import MethodType
|
||||
@@ -278,7 +279,9 @@ def _fetch_agent_cards_concurrently(
|
||||
max_workers = min(len(a2a_agents), 10)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(_fetch_card_from_config, config): config
|
||||
executor.submit(
|
||||
contextvars.copy_context().run, _fetch_card_from_config, config
|
||||
): config
|
||||
for config in a2a_agents
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Sequence
|
||||
import contextvars
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
@@ -22,6 +23,7 @@ from pydantic import (
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agent.planning_config import PlanningConfig
|
||||
from crewai.agent.utils import (
|
||||
ahandle_knowledge_retrieval,
|
||||
apply_training_data,
|
||||
@@ -191,13 +193,23 @@ class Agent(BaseAgent):
|
||||
default="safe",
|
||||
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
|
||||
)
|
||||
reasoning: bool = Field(
|
||||
planning_config: PlanningConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for agent planning before task execution.",
|
||||
)
|
||||
planning: bool = Field(
|
||||
default=False,
|
||||
description="Whether the agent should reflect and create a plan before executing a task.",
|
||||
)
|
||||
reasoning: bool = Field(
|
||||
default=False,
|
||||
description="[DEPRECATED: Use planning_config instead] Whether the agent should reflect and create a plan before executing a task.",
|
||||
deprecated=True,
|
||||
)
|
||||
max_reasoning_attempts: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
|
||||
description="[DEPRECATED: Use planning_config.max_attempts instead] Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
|
||||
deprecated=True,
|
||||
)
|
||||
embedder: EmbedderConfig | None = Field(
|
||||
default=None,
|
||||
@@ -264,8 +276,26 @@ class Agent(BaseAgent):
|
||||
if self.allow_code_execution:
|
||||
self._validate_docker_installation()
|
||||
|
||||
# Handle backward compatibility: convert reasoning=True to planning_config
|
||||
if self.reasoning and self.planning_config is None:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"The 'reasoning' parameter is deprecated. Use 'planning_config=PlanningConfig()' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.planning_config = PlanningConfig(
|
||||
max_attempts=self.max_reasoning_attempts,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def planning_enabled(self) -> bool:
|
||||
"""Check if planning is enabled for this agent."""
|
||||
return self.planning_config is not None or self.planning
|
||||
|
||||
def _setup_agent_executor(self) -> None:
|
||||
if not self.cache_handler:
|
||||
self.cache_handler = CacheHandler()
|
||||
@@ -334,7 +364,11 @@ class Agent(BaseAgent):
|
||||
ValueError: If the max execution time is not a positive integer.
|
||||
RuntimeError: If the agent execution fails for other reasons.
|
||||
"""
|
||||
handle_reasoning(self, task)
|
||||
# Only call handle_reasoning for legacy CrewAgentExecutor
|
||||
# For AgentExecutor, planning is handled in AgentExecutor.generate_plan()
|
||||
if self.executor_class is not AgentExecutor:
|
||||
handle_reasoning(self, task)
|
||||
|
||||
self._inject_date_to_task(task)
|
||||
|
||||
if self.tools_handler:
|
||||
@@ -513,9 +547,13 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
self._execute_without_timeout, task_prompt=task_prompt, task=task
|
||||
ctx.run,
|
||||
self._execute_without_timeout,
|
||||
task_prompt=task_prompt,
|
||||
task=task,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -572,7 +610,10 @@ class Agent(BaseAgent):
|
||||
ValueError: If the max execution time is not a positive integer.
|
||||
RuntimeError: If the agent execution fails for other reasons.
|
||||
"""
|
||||
handle_reasoning(self, task)
|
||||
if self.executor_class is not AgentExecutor:
|
||||
handle_reasoning(
|
||||
self, task
|
||||
) # we need this till CrewAgentExecutor migrates to AgentExecutor
|
||||
self._inject_date_to_task(task)
|
||||
|
||||
if self.tools_handler:
|
||||
@@ -1418,17 +1459,19 @@ class Agent(BaseAgent):
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Failed to save kickoff result to memory: {e}")
|
||||
|
||||
def _execute_and_build_output(
|
||||
def _build_output_from_result(
|
||||
self,
|
||||
result: dict[str, Any],
|
||||
executor: AgentExecutor,
|
||||
inputs: dict[str, str],
|
||||
response_format: type[Any] | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""Execute the agent and build the output object.
|
||||
"""Build a LiteAgentOutput from an executor result dict.
|
||||
|
||||
Shared logic used by both sync and async execution paths.
|
||||
|
||||
Args:
|
||||
result: The result dictionary from executor.invoke / invoke_async.
|
||||
executor: The executor instance.
|
||||
inputs: Input dictionary for execution.
|
||||
response_format: Optional response format.
|
||||
|
||||
Returns:
|
||||
@@ -1436,8 +1479,6 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
import json
|
||||
|
||||
# Execute the agent (this is called from sync path, so invoke returns dict)
|
||||
result = cast(dict[str, Any], executor.invoke(inputs))
|
||||
output = result.get("output", "")
|
||||
|
||||
# Handle response format conversion
|
||||
@@ -1485,91 +1526,39 @@ class Agent(BaseAgent):
|
||||
else str(raw_output)
|
||||
)
|
||||
|
||||
todo_results = LiteAgentOutput.from_todo_items(executor.state.todos.items)
|
||||
|
||||
return LiteAgentOutput(
|
||||
raw=raw_str,
|
||||
pydantic=formatted_result,
|
||||
agent_role=self.role,
|
||||
usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
|
||||
messages=executor.messages,
|
||||
messages=list(executor.state.messages),
|
||||
plan=executor.state.plan,
|
||||
todos=todo_results,
|
||||
replan_count=executor.state.replan_count,
|
||||
last_replan_reason=executor.state.last_replan_reason,
|
||||
)
|
||||
|
||||
def _execute_and_build_output(
|
||||
self,
|
||||
executor: AgentExecutor,
|
||||
inputs: dict[str, str],
|
||||
response_format: type[Any] | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""Execute the agent synchronously and build the output object."""
|
||||
result = cast(dict[str, Any], executor.invoke(inputs))
|
||||
return self._build_output_from_result(result, executor, response_format)
|
||||
|
||||
async def _execute_and_build_output_async(
|
||||
self,
|
||||
executor: AgentExecutor,
|
||||
inputs: dict[str, str],
|
||||
response_format: type[Any] | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""Execute the agent asynchronously and build the output object.
|
||||
|
||||
This is the async version of _execute_and_build_output that uses
|
||||
invoke_async() for native async execution within event loops.
|
||||
|
||||
Args:
|
||||
executor: The executor instance.
|
||||
inputs: Input dictionary for execution.
|
||||
response_format: Optional response format.
|
||||
|
||||
Returns:
|
||||
LiteAgentOutput with raw output, formatted result, and metrics.
|
||||
"""
|
||||
import json
|
||||
|
||||
# Execute the agent asynchronously
|
||||
"""Execute the agent asynchronously and build the output object."""
|
||||
result = await executor.invoke_async(inputs)
|
||||
output = result.get("output", "")
|
||||
|
||||
# Handle response format conversion
|
||||
formatted_result: BaseModel | None = None
|
||||
raw_output: str
|
||||
|
||||
if isinstance(output, BaseModel):
|
||||
formatted_result = output
|
||||
raw_output = output.model_dump_json()
|
||||
elif response_format:
|
||||
raw_output = str(output) if not isinstance(output, str) else output
|
||||
try:
|
||||
model_schema = generate_model_description(response_format)
|
||||
schema = json.dumps(model_schema, indent=2)
|
||||
instructions = self.i18n.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
|
||||
converter = Converter(
|
||||
llm=self.llm,
|
||||
text=raw_output,
|
||||
model=response_format,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
conversion_result = converter.to_pydantic()
|
||||
if isinstance(conversion_result, BaseModel):
|
||||
formatted_result = conversion_result
|
||||
except ConverterError:
|
||||
pass # Keep raw output if conversion fails
|
||||
else:
|
||||
raw_output = str(output) if not isinstance(output, str) else output
|
||||
|
||||
# Get token usage metrics
|
||||
if isinstance(self.llm, BaseLLM):
|
||||
usage_metrics = self.llm.get_token_usage_summary()
|
||||
else:
|
||||
usage_metrics = self._token_process.get_summary()
|
||||
|
||||
raw_str = (
|
||||
raw_output
|
||||
if isinstance(raw_output, str)
|
||||
else raw_output.model_dump_json()
|
||||
if isinstance(raw_output, BaseModel)
|
||||
else str(raw_output)
|
||||
)
|
||||
|
||||
return LiteAgentOutput(
|
||||
raw=raw_str,
|
||||
pydantic=formatted_result,
|
||||
agent_role=self.role,
|
||||
usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
|
||||
messages=executor.messages,
|
||||
)
|
||||
return self._build_output_from_result(result, executor, response_format)
|
||||
|
||||
def _process_kickoff_guardrail(
|
||||
self,
|
||||
|
||||
138
lib/crewai/src/crewai/agent/planning_config.py
Normal file
138
lib/crewai/src/crewai/agent/planning_config.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
|
||||
class PlanningConfig(BaseModel):
|
||||
"""Configuration for agent planning/reasoning before task execution.
|
||||
|
||||
This allows users to customize the planning behavior including prompts,
|
||||
iteration limits, the LLM used for planning, and the reasoning effort
|
||||
level that controls post-step observation and replanning behavior.
|
||||
|
||||
Note: To disable planning, don't pass a planning_config or set planning=False
|
||||
on the Agent. The presence of a PlanningConfig enables planning.
|
||||
|
||||
Attributes:
|
||||
reasoning_effort: Controls observation and replanning after each step.
|
||||
- "low": Observe each step (validates success), but skip the
|
||||
decide/replan/refine pipeline. Steps are marked complete and
|
||||
execution continues linearly. Fastest option.
|
||||
- "medium": Observe each step. On failure, trigger replanning.
|
||||
On success, skip refinement and continue. Balanced option.
|
||||
- "high": Full observation pipeline — observe every step, then
|
||||
route through decide_next_action which can trigger early goal
|
||||
achievement, full replanning, or lightweight refinement.
|
||||
Most adaptive but adds latency per step.
|
||||
max_attempts: Maximum number of planning refinement attempts.
|
||||
If None, will continue until the agent indicates readiness.
|
||||
max_steps: Maximum number of steps in the generated plan.
|
||||
system_prompt: Custom system prompt for planning. Uses default if None.
|
||||
plan_prompt: Custom prompt for creating the initial plan.
|
||||
refine_prompt: Custom prompt for refining the plan.
|
||||
llm: LLM to use for planning. Uses agent's LLM if None.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai.agent.planning_config import PlanningConfig
|
||||
|
||||
# Simple usage — fast, linear execution (default)
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research topics",
|
||||
backstory="Expert researcher",
|
||||
planning_config=PlanningConfig(),
|
||||
)
|
||||
|
||||
# Balanced — replan only when steps fail
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research topics",
|
||||
backstory="Expert researcher",
|
||||
planning_config=PlanningConfig(
|
||||
reasoning_effort="medium",
|
||||
),
|
||||
)
|
||||
|
||||
# Full adaptive planning with refinement and replanning
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research topics",
|
||||
backstory="Expert researcher",
|
||||
planning_config=PlanningConfig(
|
||||
reasoning_effort="high",
|
||||
max_attempts=3,
|
||||
max_steps=10,
|
||||
plan_prompt="Create a focused plan for: {description}",
|
||||
llm="gpt-4o-mini", # Use cheaper model for planning
|
||||
),
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
reasoning_effort: Literal["low", "medium", "high"] = Field(
|
||||
default="medium",
|
||||
description=(
|
||||
"Controls post-step observation and replanning behavior. "
|
||||
"'low' observes steps but skips replanning/refinement (fastest). "
|
||||
"'medium' observes and replans only on step failure (balanced). "
|
||||
"'high' runs full observation pipeline with replanning, refinement, "
|
||||
"and early goal detection (most adaptive, highest latency)."
|
||||
),
|
||||
)
|
||||
max_attempts: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Maximum number of planning refinement attempts. "
|
||||
"If None, will continue until the agent indicates readiness."
|
||||
),
|
||||
)
|
||||
max_steps: int = Field(
|
||||
default=20,
|
||||
description="Maximum number of steps in the generated plan.",
|
||||
ge=1,
|
||||
)
|
||||
system_prompt: str | None = Field(
|
||||
default=None,
|
||||
description="Custom system prompt for planning. Uses default if None.",
|
||||
)
|
||||
plan_prompt: str | None = Field(
|
||||
default=None,
|
||||
description="Custom prompt for creating the initial plan.",
|
||||
)
|
||||
refine_prompt: str | None = Field(
|
||||
default=None,
|
||||
description="Custom prompt for refining the plan.",
|
||||
)
|
||||
max_replans: int = Field(
|
||||
default=3,
|
||||
description="Maximum number of full replanning attempts before finalizing.",
|
||||
ge=0,
|
||||
)
|
||||
max_step_iterations: int = Field(
|
||||
default=15,
|
||||
description=(
|
||||
"Maximum LLM iterations per step in the StepExecutor multi-turn loop. "
|
||||
"Lower values make steps faster but less thorough."
|
||||
),
|
||||
ge=1,
|
||||
)
|
||||
step_timeout: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Maximum wall-clock seconds for a single step execution. "
|
||||
"If exceeded, the step is marked as failed and observation decides "
|
||||
"whether to continue or replan. None means no per-step timeout."
|
||||
),
|
||||
)
|
||||
llm: str | BaseLLM | None = Field(
|
||||
default=None,
|
||||
description="LLM to use for planning. Uses agent's LLM if None.",
|
||||
)
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
@@ -28,13 +28,20 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def handle_reasoning(agent: Agent, task: Task) -> None:
|
||||
"""Handle the reasoning process for an agent before task execution.
|
||||
"""Handle the reasoning/planning process for an agent before task execution.
|
||||
|
||||
This function checks if planning is enabled for the agent and, if so,
|
||||
creates a plan that gets appended to the task description.
|
||||
|
||||
Note: This function is used by CrewAgentExecutor (legacy path).
|
||||
For AgentExecutor, planning is handled in AgentExecutor.generate_plan().
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task: The task to execute.
|
||||
"""
|
||||
if not agent.reasoning:
|
||||
# Check if planning is enabled using the planning_enabled property
|
||||
if not getattr(agent, "planning_enabled", False):
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -43,13 +50,13 @@ def handle_reasoning(agent: Agent, task: Task) -> None:
|
||||
AgentReasoningOutput,
|
||||
)
|
||||
|
||||
reasoning_handler = AgentReasoning(task=task, agent=agent)
|
||||
reasoning_output: AgentReasoningOutput = (
|
||||
reasoning_handler.handle_agent_reasoning()
|
||||
planning_handler = AgentReasoning(agent=agent, task=task)
|
||||
planning_output: AgentReasoningOutput = (
|
||||
planning_handler.handle_agent_reasoning()
|
||||
)
|
||||
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
||||
task.description += f"\n\nPlanning:\n{planning_output.plan.plan}"
|
||||
except Exception as e:
|
||||
agent._logger.log("error", f"Error during reasoning process: {e!s}")
|
||||
agent._logger.log("error", f"Error during planning: {e!s}")
|
||||
|
||||
|
||||
def build_task_prompt_with_schema(task: Task, task_prompt: str, i18n: I18N) -> str:
|
||||
|
||||
@@ -38,7 +38,7 @@ from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
|
||||
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
|
||||
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$"
|
||||
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#[\w-]+)?$"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -30,12 +30,9 @@ class CrewAgentExecutorMixin:
|
||||
memory = getattr(self.agent, "memory", None) or (
|
||||
getattr(self.crew, "_memory", None) if self.crew else None
|
||||
)
|
||||
if memory is None or not self.task or getattr(memory, "_read_only", False):
|
||||
if memory is None or not self.task or memory.read_only:
|
||||
return
|
||||
if (
|
||||
f"Action: {sanitize_tool_name('Delegate work to coworker')}"
|
||||
in output.text
|
||||
):
|
||||
if f"Action: {sanitize_tool_name('Delegate work to coworker')}" in output.text:
|
||||
return
|
||||
try:
|
||||
raw = (
|
||||
@@ -48,6 +45,4 @@ class CrewAgentExecutorMixin:
|
||||
if extracted:
|
||||
memory.remember_many(extracted, agent_role=self.agent.role)
|
||||
except Exception as e:
|
||||
self.agent._logger.log(
|
||||
"error", f"Failed to save to memory: {e}"
|
||||
)
|
||||
self.agent._logger.log("error", f"Failed to save to memory: {e}")
|
||||
|
||||
@@ -9,6 +9,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
@@ -755,6 +756,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = {
|
||||
pool.submit(
|
||||
contextvars.copy_context().run,
|
||||
self._execute_single_native_tool_call,
|
||||
call_id=call_id,
|
||||
func_name=func_name,
|
||||
@@ -893,7 +895,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
|
||||
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
|
||||
args_dict, parse_error = parse_tool_call_args(
|
||||
func_args, func_name, call_id, original_tool
|
||||
)
|
||||
if parse_error is not None:
|
||||
return parse_error
|
||||
|
||||
|
||||
345
lib/crewai/src/crewai/agents/planner_observer.py
Normal file
345
lib/crewai/src/crewai/agents/planner_observer.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""PlannerObserver: Observation phase after each step execution.
|
||||
|
||||
Implements the "Observe" phase. After every step execution, the Planner
|
||||
analyzes what happened, what new information was learned, and whether the
|
||||
remaining plan is still valid.
|
||||
|
||||
This is NOT an error detector — it runs on every step, including successes,
|
||||
to incorporate runtime observations into the remaining plan.
|
||||
|
||||
Refinements are structured (StepRefinement objects) and applied directly
|
||||
from the observation result — no second LLM call required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.observation_events import (
|
||||
StepObservationCompletedEvent,
|
||||
StepObservationFailedEvent,
|
||||
StepObservationStartedEvent,
|
||||
)
|
||||
from crewai.utilities.agent_utils import extract_task_section
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.planning_types import StepObservation, TodoItem
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlannerObserver:
|
||||
"""Observes step execution results and decides on plan continuation.
|
||||
|
||||
After EVERY step execution, this class:
|
||||
1. Analyzes what the step accomplished
|
||||
2. Identifies new information learned
|
||||
3. Decides if the remaining plan is still valid
|
||||
4. Suggests lightweight refinements or triggers full replanning
|
||||
|
||||
LLM resolution (magical fallback):
|
||||
- If ``agent.planning_config.llm`` is explicitly set → use that
|
||||
- Otherwise → fall back to ``agent.llm`` (same LLM for everything)
|
||||
|
||||
Args:
|
||||
agent: The agent instance (for LLM resolution and config).
|
||||
task: Optional task context (for description and expected output).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: Agent,
|
||||
task: Task | None = None,
|
||||
kickoff_input: str = "",
|
||||
) -> None:
|
||||
self.agent = agent
|
||||
self.task = task
|
||||
self.kickoff_input = kickoff_input
|
||||
self.llm = self._resolve_llm()
|
||||
self._i18n: I18N = get_i18n()
|
||||
|
||||
def _resolve_llm(self) -> Any:
|
||||
"""Resolve which LLM to use for observation/planning.
|
||||
|
||||
Mirrors AgentReasoning._resolve_llm(): uses planning_config.llm
|
||||
if explicitly set, otherwise falls back to agent.llm.
|
||||
|
||||
Returns:
|
||||
The resolved LLM instance.
|
||||
"""
|
||||
from crewai.llm import LLM
|
||||
|
||||
config = getattr(self.agent, "planning_config", None)
|
||||
if config is not None and config.llm is not None:
|
||||
if isinstance(config.llm, LLM):
|
||||
return config.llm
|
||||
return create_llm(config.llm)
|
||||
return self.agent.llm
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def observe(
|
||||
self,
|
||||
completed_step: TodoItem,
|
||||
result: str,
|
||||
all_completed: list[TodoItem],
|
||||
remaining_todos: list[TodoItem],
|
||||
) -> StepObservation:
|
||||
"""Observe a step's result and decide on plan continuation.
|
||||
|
||||
This runs after EVERY step execution — not just failures.
|
||||
|
||||
Args:
|
||||
completed_step: The todo item that was just executed.
|
||||
result: The final result string from the step.
|
||||
all_completed: All previously completed todos (for context).
|
||||
remaining_todos: The pending todos still in the plan.
|
||||
|
||||
Returns:
|
||||
StepObservation with the Planner's analysis. Any suggested
|
||||
refinements are structured StepRefinement objects ready for
|
||||
direct application — no second LLM call needed.
|
||||
"""
|
||||
agent_role = self.agent.role
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
event=StepObservationStartedEvent(
|
||||
agent_role=agent_role,
|
||||
step_number=completed_step.step_number,
|
||||
step_description=completed_step.description,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
),
|
||||
)
|
||||
|
||||
messages = self._build_observation_messages(
|
||||
completed_step, result, all_completed, remaining_todos
|
||||
)
|
||||
|
||||
try:
|
||||
response = self.llm.call(
|
||||
messages,
|
||||
response_model=StepObservation,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
)
|
||||
|
||||
observation = self._parse_observation_response(response)
|
||||
|
||||
refinement_summaries = (
|
||||
[
|
||||
f"Step {r.step_number}: {r.new_description}"
|
||||
for r in observation.suggested_refinements
|
||||
]
|
||||
if observation.suggested_refinements
|
||||
else None
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
event=StepObservationCompletedEvent(
|
||||
agent_role=agent_role,
|
||||
step_number=completed_step.step_number,
|
||||
step_description=completed_step.description,
|
||||
step_completed_successfully=observation.step_completed_successfully,
|
||||
key_information_learned=observation.key_information_learned,
|
||||
remaining_plan_still_valid=observation.remaining_plan_still_valid,
|
||||
needs_full_replan=observation.needs_full_replan,
|
||||
replan_reason=observation.replan_reason,
|
||||
goal_already_achieved=observation.goal_already_achieved,
|
||||
suggested_refinements=refinement_summaries,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
),
|
||||
)
|
||||
|
||||
return observation
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Observation LLM call failed: {e}. Defaulting to conservative replan."
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
event=StepObservationFailedEvent(
|
||||
agent_role=agent_role,
|
||||
step_number=completed_step.step_number,
|
||||
step_description=completed_step.description,
|
||||
error=str(e),
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
),
|
||||
)
|
||||
|
||||
# Don't force a full replan — the step may have succeeded even if the
|
||||
# observer LLM failed to parse the result. Defaulting to "continue" is
|
||||
# far less disruptive than wiping the entire plan on every observer error.
|
||||
return StepObservation(
|
||||
step_completed_successfully=True,
|
||||
key_information_learned="",
|
||||
remaining_plan_still_valid=True,
|
||||
needs_full_replan=False,
|
||||
)
|
||||
|
||||
def apply_refinements(
|
||||
self,
|
||||
observation: StepObservation,
|
||||
remaining_todos: list[TodoItem],
|
||||
) -> list[TodoItem]:
|
||||
"""Apply structured refinements from the observation directly to todo descriptions.
|
||||
|
||||
No LLM call needed — refinements are already structured StepRefinement
|
||||
objects produced by the observation call. This is a pure in-memory update.
|
||||
|
||||
Args:
|
||||
observation: The observation containing structured refinements.
|
||||
remaining_todos: The pending todos to update in-place.
|
||||
|
||||
Returns:
|
||||
The same todo list with updated descriptions where refinements applied.
|
||||
"""
|
||||
if not observation.suggested_refinements:
|
||||
return remaining_todos
|
||||
|
||||
todo_by_step: dict[int, TodoItem] = {t.step_number: t for t in remaining_todos}
|
||||
for refinement in observation.suggested_refinements:
|
||||
if refinement.step_number in todo_by_step and refinement.new_description:
|
||||
todo_by_step[
|
||||
refinement.step_number
|
||||
].description = refinement.new_description
|
||||
|
||||
return remaining_todos
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal: Message building
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_observation_messages(
|
||||
self,
|
||||
completed_step: TodoItem,
|
||||
result: str,
|
||||
all_completed: list[TodoItem],
|
||||
remaining_todos: list[TodoItem],
|
||||
) -> list[LLMMessage]:
|
||||
"""Build messages for the observation LLM call."""
|
||||
task_desc = ""
|
||||
task_goal = ""
|
||||
if self.task:
|
||||
task_desc = self.task.description or ""
|
||||
task_goal = self.task.expected_output or ""
|
||||
elif self.kickoff_input:
|
||||
# Standalone kickoff path — no Task object, but we have the raw input.
|
||||
# Extract just the ## Task section so the observer sees the actual goal,
|
||||
# not the full enriched instruction with env/tools/verification noise.
|
||||
task_desc = extract_task_section(self.kickoff_input)
|
||||
task_goal = "Complete the task successfully"
|
||||
|
||||
system_prompt = self._i18n.retrieve("planning", "observation_system_prompt")
|
||||
|
||||
# Build context of what's been done
|
||||
completed_summary = ""
|
||||
if all_completed:
|
||||
completed_lines = []
|
||||
for todo in all_completed:
|
||||
result_preview = (todo.result or "")[:200]
|
||||
completed_lines.append(
|
||||
f" Step {todo.step_number}: {todo.description}\n"
|
||||
f" Result: {result_preview}"
|
||||
)
|
||||
completed_summary = "\n## Previously completed steps:\n" + "\n".join(
|
||||
completed_lines
|
||||
)
|
||||
|
||||
# Build remaining plan
|
||||
remaining_summary = ""
|
||||
if remaining_todos:
|
||||
remaining_lines = [
|
||||
f" Step {todo.step_number}: {todo.description}"
|
||||
for todo in remaining_todos
|
||||
]
|
||||
remaining_summary = "\n## Remaining plan steps:\n" + "\n".join(
|
||||
remaining_lines
|
||||
)
|
||||
|
||||
user_prompt = self._i18n.retrieve("planning", "observation_user_prompt").format(
|
||||
task_description=task_desc,
|
||||
task_goal=task_goal,
|
||||
completed_summary=completed_summary,
|
||||
step_number=completed_step.step_number,
|
||||
step_description=completed_step.description,
|
||||
step_result=result,
|
||||
remaining_summary=remaining_summary,
|
||||
)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _parse_observation_response(response: Any) -> StepObservation:
|
||||
"""Parse the LLM response into a StepObservation.
|
||||
|
||||
The LLM may return:
|
||||
- A StepObservation instance directly (streaming + litellm path)
|
||||
- A JSON string (non-streaming path serialises model_dump_json())
|
||||
- A dict (some provider paths)
|
||||
- Something else (unexpected)
|
||||
|
||||
We handle all cases to avoid silently falling back to a
|
||||
hardcoded success default.
|
||||
"""
|
||||
|
||||
if isinstance(response, StepObservation):
|
||||
return response
|
||||
|
||||
# JSON string path — most common miss before this fix
|
||||
if isinstance(response, str):
|
||||
text = response.strip()
|
||||
try:
|
||||
return StepObservation.model_validate_json(text)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
# Some LLMs wrap the JSON in markdown fences
|
||||
if text.startswith("```"):
|
||||
lines = text.split("\n")
|
||||
# Strip first and last lines (``` markers)
|
||||
inner = "\n".join(
|
||||
lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
|
||||
)
|
||||
try:
|
||||
return StepObservation.model_validate_json(inner.strip())
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
# Dict path
|
||||
if isinstance(response, dict):
|
||||
try:
|
||||
return StepObservation.model_validate(response)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
# Last resort — log what we got so it's diagnosable
|
||||
logger.warning(
|
||||
"Could not parse observation response (type=%s). "
|
||||
"Falling back to default failure observation. Preview: %.200s",
|
||||
type(response).__name__,
|
||||
str(response),
|
||||
)
|
||||
return StepObservation(
|
||||
step_completed_successfully=False,
|
||||
key_information_learned=str(response) if response else "",
|
||||
remaining_plan_still_valid=False,
|
||||
)
|
||||
629
lib/crewai/src/crewai/agents/step_executor.py
Normal file
629
lib/crewai/src/crewai/agents/step_executor.py
Normal file
@@ -0,0 +1,629 @@
|
||||
"""StepExecutor: Isolated executor for a single plan step.
|
||||
|
||||
Implements the direct-action execution pattern from Plan-and-Act
|
||||
(arxiv 2503.09572): the Executor receives one step description,
|
||||
makes a single LLM call, executes any tool call returned, and
|
||||
returns the result immediately.
|
||||
|
||||
There is no inner loop. Recovery from failure (retry, replan) is
|
||||
the responsibility of PlannerObserver and AgentExecutor — keeping
|
||||
this class single-purpose and fast.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.agents.parser import AgentAction, AgentFinish
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.utilities.agent_utils import (
|
||||
build_tool_calls_assistant_message,
|
||||
check_native_tool_support,
|
||||
enforce_rpm_limit,
|
||||
execute_single_native_tool_call,
|
||||
extract_task_section,
|
||||
format_message_for_llm,
|
||||
is_tool_call_list,
|
||||
process_llm_response,
|
||||
setup_native_tools,
|
||||
)
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.planning_types import TodoItem
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.step_execution_context import StepExecutionContext, StepResult
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.crew import Crew
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
|
||||
|
||||
class StepExecutor:
|
||||
"""Executes a SINGLE todo item using direct-action execution.
|
||||
|
||||
The StepExecutor owns its own message list per invocation. It never reads
|
||||
or writes the AgentExecutor's state. Results flow back via StepResult.
|
||||
|
||||
Execution pattern (per Plan-and-Act, arxiv 2503.09572):
|
||||
1. Build messages from todo + context
|
||||
2. Call LLM once (with or without native tools)
|
||||
3. If tool call → execute it → return tool result
|
||||
4. If text answer → return it directly
|
||||
No inner loop — recovery is PlannerObserver's responsibility.
|
||||
|
||||
Args:
|
||||
llm: The language model to use for execution.
|
||||
tools: Structured tools available to the executor.
|
||||
agent: The agent instance (for role/goal/verbose/config).
|
||||
original_tools: Original BaseTool instances (needed for native tool schema).
|
||||
tools_handler: Optional tools handler for caching and delegation tracking.
|
||||
task: Optional task context.
|
||||
crew: Optional crew context.
|
||||
function_calling_llm: Optional separate LLM for function calling.
|
||||
request_within_rpm_limit: Optional RPM limit function.
|
||||
callbacks: Optional list of callbacks.
|
||||
i18n: Optional i18n instance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
tools: list[CrewStructuredTool],
|
||||
agent: Agent,
|
||||
original_tools: list[BaseTool] | None = None,
|
||||
tools_handler: ToolsHandler | None = None,
|
||||
task: Task | None = None,
|
||||
crew: Crew | None = None,
|
||||
function_calling_llm: BaseLLM | None = None,
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
i18n: I18N | None = None,
|
||||
) -> None:
|
||||
self.llm = llm
|
||||
self.tools = tools
|
||||
self.agent = agent
|
||||
self.original_tools = original_tools or []
|
||||
self.tools_handler = tools_handler
|
||||
self.task = task
|
||||
self.crew = crew
|
||||
self.function_calling_llm = function_calling_llm
|
||||
self.request_within_rpm_limit = request_within_rpm_limit
|
||||
self.callbacks = callbacks or []
|
||||
self._i18n: I18N = i18n or get_i18n()
|
||||
self._printer: Printer = Printer()
|
||||
|
||||
# Native tool support — set up once
|
||||
self._use_native_tools = check_native_tool_support(
|
||||
self.llm, self.original_tools
|
||||
)
|
||||
self._openai_tools: list[dict[str, Any]] = []
|
||||
self._available_functions: dict[str, Callable[..., Any]] = {}
|
||||
if self._use_native_tools and self.original_tools:
|
||||
(
|
||||
self._openai_tools,
|
||||
self._available_functions,
|
||||
_,
|
||||
) = setup_native_tools(self.original_tools)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(
|
||||
self,
|
||||
todo: TodoItem,
|
||||
context: StepExecutionContext,
|
||||
max_step_iterations: int = 15,
|
||||
step_timeout: int | None = None,
|
||||
) -> StepResult:
|
||||
"""Execute a single todo item using a multi-turn action loop.
|
||||
|
||||
Enforces the RPM limit, builds a fresh message list, then iterates
|
||||
LLM call → tool execution → observation until the LLM signals it is
|
||||
done (text answer) or max_step_iterations is reached. Never touches
|
||||
external AgentExecutor state.
|
||||
|
||||
Args:
|
||||
todo: The todo item to execute.
|
||||
context: Immutable context with task info and dependency results.
|
||||
max_step_iterations: Maximum LLM iterations in the multi-turn loop.
|
||||
step_timeout: Maximum wall-clock seconds for this step. None = no limit.
|
||||
|
||||
Returns:
|
||||
StepResult with the outcome.
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
tool_calls_made: list[str] = []
|
||||
|
||||
try:
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
messages = self._build_isolated_messages(todo, context)
|
||||
|
||||
if self._use_native_tools:
|
||||
result_text = self._execute_native(
|
||||
messages,
|
||||
tool_calls_made,
|
||||
max_step_iterations=max_step_iterations,
|
||||
step_timeout=step_timeout,
|
||||
start_time=start_time,
|
||||
)
|
||||
else:
|
||||
result_text = self._execute_text_parsed(
|
||||
messages,
|
||||
tool_calls_made,
|
||||
max_step_iterations=max_step_iterations,
|
||||
step_timeout=step_timeout,
|
||||
start_time=start_time,
|
||||
)
|
||||
self._validate_expected_tool_usage(todo, tool_calls_made)
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
return StepResult(
|
||||
success=True,
|
||||
result=result_text,
|
||||
tool_calls_made=tool_calls_made,
|
||||
execution_time=elapsed,
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
return StepResult(
|
||||
success=False,
|
||||
result="",
|
||||
error=str(e),
|
||||
tool_calls_made=tool_calls_made,
|
||||
execution_time=elapsed,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal: Message building
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_isolated_messages(
|
||||
self, todo: TodoItem, context: StepExecutionContext
|
||||
) -> list[LLMMessage]:
|
||||
"""Build a fresh message list for this step's execution.
|
||||
|
||||
System prompt tells the LLM it is an Executor focused on one step.
|
||||
User prompt provides the step description, dependencies, and tools.
|
||||
"""
|
||||
system_prompt = self._build_system_prompt()
|
||||
user_prompt = self._build_user_prompt(todo, context)
|
||||
|
||||
return [
|
||||
format_message_for_llm(system_prompt, role="system"),
|
||||
format_message_for_llm(user_prompt, role="user"),
|
||||
]
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build the Executor's system prompt."""
|
||||
role = self.agent.role if self.agent else "Assistant"
|
||||
goal = self.agent.goal if self.agent else "Complete tasks efficiently"
|
||||
backstory = getattr(self.agent, "backstory", "") or ""
|
||||
|
||||
tools_section = ""
|
||||
if self.tools and not self._use_native_tools:
|
||||
tool_names = ", ".join(sanitize_tool_name(t.name) for t in self.tools)
|
||||
tools_section = self._i18n.retrieve(
|
||||
"planning", "step_executor_tools_section"
|
||||
).format(tool_names=tool_names)
|
||||
elif self.tools:
|
||||
tool_names = ", ".join(sanitize_tool_name(t.name) for t in self.tools)
|
||||
tools_section = f"\n\nAvailable tools: {tool_names}"
|
||||
|
||||
return self._i18n.retrieve("planning", "step_executor_system_prompt").format(
|
||||
role=role,
|
||||
backstory=backstory,
|
||||
goal=goal,
|
||||
tools_section=tools_section,
|
||||
)
|
||||
|
||||
def _build_user_prompt(self, todo: TodoItem, context: StepExecutionContext) -> str:
|
||||
"""Build the user prompt for this specific step."""
|
||||
parts: list[str] = []
|
||||
|
||||
# Include overall task context so the executor knows the full goal and
|
||||
# required output format/location — critical for knowing WHAT to produce.
|
||||
# We extract only the task body (not tool instructions or verification
|
||||
# sections) to avoid duplicating directives already in the system prompt.
|
||||
if context.task_description:
|
||||
task_section = extract_task_section(context.task_description)
|
||||
if task_section:
|
||||
parts.append(
|
||||
self._i18n.retrieve(
|
||||
"planning", "step_executor_task_context"
|
||||
).format(
|
||||
task_context=task_section,
|
||||
)
|
||||
)
|
||||
|
||||
parts.append(
|
||||
self._i18n.retrieve("planning", "step_executor_user_prompt").format(
|
||||
step_description=todo.description,
|
||||
)
|
||||
)
|
||||
|
||||
if todo.tool_to_use:
|
||||
parts.append(
|
||||
self._i18n.retrieve("planning", "step_executor_suggested_tool").format(
|
||||
tool_to_use=todo.tool_to_use,
|
||||
)
|
||||
)
|
||||
|
||||
# Include dependency results (final results only, no traces)
|
||||
if context.dependency_results:
|
||||
parts.append(
|
||||
self._i18n.retrieve("planning", "step_executor_context_header")
|
||||
)
|
||||
for step_num, result in sorted(context.dependency_results.items()):
|
||||
parts.append(
|
||||
self._i18n.retrieve(
|
||||
"planning", "step_executor_context_entry"
|
||||
).format(step_number=step_num, result=result)
|
||||
)
|
||||
|
||||
parts.append(self._i18n.retrieve("planning", "step_executor_complete_step"))
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal: Multi-turn execution loop
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _execute_text_parsed(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
tool_calls_made: list[str],
|
||||
max_step_iterations: int = 15,
|
||||
step_timeout: int | None = None,
|
||||
start_time: float | None = None,
|
||||
) -> str:
|
||||
"""Execute step using text-parsed tool calling with a multi-turn loop.
|
||||
|
||||
Iterates LLM call → tool execution → observation until the LLM
|
||||
produces a Final Answer or max_step_iterations is reached.
|
||||
This allows the agent to: run a command, see the output, adjust its
|
||||
approach, and run another command — all within a single plan step.
|
||||
"""
|
||||
use_stop_words = self.llm.supports_stop_words() if self.llm else False
|
||||
last_tool_result = ""
|
||||
|
||||
for _ in range(max_step_iterations):
|
||||
# Check step timeout
|
||||
if step_timeout and start_time:
|
||||
elapsed = time.monotonic() - start_time
|
||||
if elapsed >= step_timeout:
|
||||
return last_tool_result or f"Step timed out after {elapsed:.0f}s"
|
||||
answer = self.llm.call(
|
||||
messages,
|
||||
callbacks=self.callbacks,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
)
|
||||
|
||||
if not answer:
|
||||
raise ValueError("Empty response from LLM")
|
||||
|
||||
answer_str = str(answer)
|
||||
formatted = process_llm_response(answer_str, use_stop_words)
|
||||
|
||||
if isinstance(formatted, AgentFinish):
|
||||
return str(formatted.output)
|
||||
|
||||
if isinstance(formatted, AgentAction):
|
||||
tool_calls_made.append(formatted.tool)
|
||||
tool_result = self._execute_text_tool_with_events(formatted)
|
||||
last_tool_result = tool_result
|
||||
# Append the assistant's reasoning + action, then the observation.
|
||||
# _build_observation_message handles vision sentinels so the LLM
|
||||
# receives an image content block instead of raw base64 text.
|
||||
messages.append({"role": "assistant", "content": answer_str})
|
||||
messages.append(self._build_observation_message(tool_result))
|
||||
continue
|
||||
|
||||
# Raw text response with no Final Answer marker — treat as done
|
||||
return answer_str
|
||||
|
||||
# Max iterations reached — return the last tool result we accumulated
|
||||
return last_tool_result
|
||||
|
||||
def _execute_text_tool_with_events(self, formatted: AgentAction) -> str:
|
||||
"""Execute text-parsed tool calls with tool usage events."""
|
||||
args_dict = self._parse_tool_args(formatted.tool_input)
|
||||
agent_key = getattr(self.agent, "key", "unknown") if self.agent else "unknown"
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=formatted.tool,
|
||||
tool_args=args_dict,
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
agent_key=agent_key,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
fingerprint_context = {}
|
||||
if (
|
||||
self.agent
|
||||
and hasattr(self.agent, "security_config")
|
||||
and hasattr(self.agent.security_config, "fingerprint")
|
||||
):
|
||||
fingerprint_context = {
|
||||
"agent_fingerprint": str(self.agent.security_config.fingerprint)
|
||||
}
|
||||
|
||||
tool_result = execute_tool_and_check_finality(
|
||||
agent_action=formatted,
|
||||
fingerprint_context=fingerprint_context,
|
||||
tools=self.tools,
|
||||
i18n=self._i18n,
|
||||
agent_key=self.agent.key if self.agent else None,
|
||||
agent_role=self.agent.role if self.agent else None,
|
||||
tools_handler=self.tools_handler,
|
||||
task=self.task,
|
||||
agent=self.agent,
|
||||
function_calling_llm=self.function_calling_llm,
|
||||
crew=self.crew,
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageErrorEvent(
|
||||
tool_name=formatted.tool,
|
||||
tool_args=args_dict,
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
agent_key=agent_key,
|
||||
error=e,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=str(tool_result.result),
|
||||
tool_name=formatted.tool,
|
||||
tool_args=args_dict,
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
agent_key=agent_key,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
return str(tool_result.result)
|
||||
|
||||
def _parse_tool_args(self, tool_input: Any) -> dict[str, Any]:
|
||||
"""Parse tool args from the parser output into a dict payload for events."""
|
||||
if isinstance(tool_input, dict):
|
||||
return tool_input
|
||||
if isinstance(tool_input, str):
|
||||
stripped_input = tool_input.strip()
|
||||
if not stripped_input:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(stripped_input)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return {"input": parsed}
|
||||
except json.JSONDecodeError:
|
||||
return {"input": stripped_input}
|
||||
return {"input": str(tool_input)}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal: Vision support
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _parse_vision_sentinel(raw: str) -> tuple[str, str] | None:
|
||||
"""Parse a VISION_IMAGE sentinel into (media_type, base64_data), or None."""
|
||||
prefix = "VISION_IMAGE:"
|
||||
if not raw.startswith(prefix):
|
||||
return None
|
||||
rest = raw[len(prefix) :]
|
||||
sep = rest.find(":")
|
||||
if sep <= 0:
|
||||
return None
|
||||
return rest[:sep], rest[sep + 1 :]
|
||||
|
||||
@staticmethod
|
||||
def _build_observation_message(tool_result: str) -> LLMMessage:
|
||||
"""Build an observation message, converting vision sentinels to image blocks.
|
||||
|
||||
When a tool returns a VISION_IMAGE sentinel (e.g. from read_image),
|
||||
we build a multimodal content block so the LLM can actually *see*
|
||||
the image rather than receiving a wall of base64 text.
|
||||
|
||||
Uses the standard image_url / data-URI format so each LLM provider's
|
||||
SDK (OpenAI, LiteLLM, etc.) handles the provider-specific conversion.
|
||||
|
||||
Format: ``VISION_IMAGE:<media_type>:<base64_data>``
|
||||
"""
|
||||
parsed = StepExecutor._parse_vision_sentinel(tool_result)
|
||||
if parsed:
|
||||
media_type, b64_data = parsed
|
||||
return {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Observation: Here is the image:"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{media_type};base64,{b64_data}",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
return {"role": "user", "content": f"Observation: {tool_result}"}
|
||||
|
||||
def _validate_expected_tool_usage(
|
||||
self,
|
||||
todo: TodoItem,
|
||||
tool_calls_made: list[str],
|
||||
) -> None:
|
||||
"""Fail step execution when a required tool is configured but not called."""
|
||||
expected_tool = getattr(todo, "tool_to_use", None)
|
||||
if not expected_tool:
|
||||
return
|
||||
expected_tool_name = sanitize_tool_name(expected_tool)
|
||||
available_tool_names = {
|
||||
sanitize_tool_name(tool.name)
|
||||
for tool in self.tools
|
||||
if getattr(tool, "name", "")
|
||||
} | set(self._available_functions.keys())
|
||||
if expected_tool_name not in available_tool_names:
|
||||
return
|
||||
called_names = {sanitize_tool_name(name) for name in tool_calls_made}
|
||||
if expected_tool_name not in called_names:
|
||||
raise ValueError(
|
||||
f"Expected tool '{expected_tool_name}' was not called "
|
||||
f"for step {todo.step_number}."
|
||||
)
|
||||
|
||||
def _execute_native(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
tool_calls_made: list[str],
|
||||
max_step_iterations: int = 15,
|
||||
step_timeout: int | None = None,
|
||||
start_time: float | None = None,
|
||||
) -> str:
|
||||
"""Execute step using native function calling with a multi-turn loop.
|
||||
|
||||
Iterates LLM call → tool execution → appended results until the LLM
|
||||
returns a text answer (no more tool calls) or max_step_iterations is
|
||||
reached. This lets the agent run a shell command, observe the output,
|
||||
correct mistakes, and issue follow-up commands — all within one step.
|
||||
"""
|
||||
accumulated_results: list[str] = []
|
||||
|
||||
for _ in range(max_step_iterations):
|
||||
# Check step timeout
|
||||
if step_timeout and start_time:
|
||||
elapsed = time.monotonic() - start_time
|
||||
if elapsed >= step_timeout:
|
||||
return (
|
||||
"\n\n".join(accumulated_results)
|
||||
if accumulated_results
|
||||
else f"Step timed out after {elapsed:.0f}s"
|
||||
)
|
||||
answer = self.llm.call(
|
||||
messages,
|
||||
tools=self._openai_tools,
|
||||
callbacks=self.callbacks,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
)
|
||||
|
||||
if not answer:
|
||||
raise ValueError("Empty response from LLM")
|
||||
|
||||
if isinstance(answer, BaseModel):
|
||||
return answer.model_dump_json()
|
||||
|
||||
if isinstance(answer, list) and answer and is_tool_call_list(answer):
|
||||
# _execute_native_tool_calls appends assistant + tool messages
|
||||
# to `messages` as a side-effect, so the next LLM call will
|
||||
# see the full conversation history including tool outputs.
|
||||
result = self._execute_native_tool_calls(
|
||||
answer, messages, tool_calls_made
|
||||
)
|
||||
accumulated_results.append(result)
|
||||
continue
|
||||
|
||||
# Text answer → LLM decided the step is done
|
||||
return str(answer)
|
||||
|
||||
# Max iterations reached — return everything we accumulated
|
||||
return "\n".join(filter(None, accumulated_results))
|
||||
|
||||
def _execute_native_tool_calls(
|
||||
self,
|
||||
tool_calls: list[Any],
|
||||
messages: list[LLMMessage],
|
||||
tool_calls_made: list[str],
|
||||
) -> str:
|
||||
"""Execute a batch of native tool calls and return their results.
|
||||
|
||||
Returns the result of the first tool marked result_as_answer if any,
|
||||
otherwise returns all tool results concatenated.
|
||||
"""
|
||||
assistant_message, _reports = build_tool_calls_assistant_message(tool_calls)
|
||||
if assistant_message:
|
||||
messages.append(assistant_message)
|
||||
|
||||
tool_results: list[str] = []
|
||||
for tool_call in tool_calls:
|
||||
call_result = execute_single_native_tool_call(
|
||||
tool_call,
|
||||
available_functions=self._available_functions,
|
||||
original_tools=self.original_tools,
|
||||
structured_tools=self.tools,
|
||||
tools_handler=self.tools_handler,
|
||||
agent=self.agent,
|
||||
task=self.task,
|
||||
crew=self.crew,
|
||||
event_source=self,
|
||||
printer=self._printer,
|
||||
verbose=bool(self.agent and self.agent.verbose),
|
||||
)
|
||||
|
||||
if call_result.func_name:
|
||||
tool_calls_made.append(call_result.func_name)
|
||||
|
||||
if call_result.result_as_answer:
|
||||
return str(call_result.result)
|
||||
|
||||
if call_result.tool_message:
|
||||
raw_content = call_result.tool_message.get("content", "")
|
||||
if isinstance(raw_content, str):
|
||||
parsed = self._parse_vision_sentinel(raw_content)
|
||||
if parsed:
|
||||
media_type, b64_data = parsed
|
||||
# Replace the sentinel with a standard image_url content block.
|
||||
# Each provider's _format_messages handles conversion to
|
||||
# its native format (e.g. Anthropic image blocks).
|
||||
modified: LLMMessage = cast(
|
||||
LLMMessage, dict(call_result.tool_message)
|
||||
)
|
||||
modified["content"] = [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{media_type};base64,{b64_data}",
|
||||
},
|
||||
}
|
||||
]
|
||||
messages.append(modified)
|
||||
tool_results.append("[image]")
|
||||
else:
|
||||
messages.append(call_result.tool_message)
|
||||
if raw_content:
|
||||
tool_results.append(raw_content)
|
||||
else:
|
||||
messages.append(call_result.tool_message)
|
||||
if raw_content:
|
||||
tool_results.append(str(raw_content))
|
||||
|
||||
return "\n".join(tool_results) if tool_results else ""
|
||||
@@ -182,15 +182,24 @@ def log_tasks_outputs() -> None:
|
||||
@crewai.command()
|
||||
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
|
||||
@click.option(
|
||||
"-l", "--long", is_flag=True, hidden=True,
|
||||
"-l",
|
||||
"--long",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option(
|
||||
"-s", "--short", is_flag=True, hidden=True,
|
||||
"-s",
|
||||
"--short",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option(
|
||||
"-e", "--entities", is_flag=True, hidden=True,
|
||||
"-e",
|
||||
"--entities",
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
|
||||
@@ -218,7 +227,13 @@ def reset_memories(
|
||||
# Treat legacy flags as --memory with a deprecation warning
|
||||
if long or short or entities:
|
||||
legacy_used = [
|
||||
f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v
|
||||
f
|
||||
for f, v in [
|
||||
("--long", long),
|
||||
("--short", short),
|
||||
("--entities", entities),
|
||||
]
|
||||
if v
|
||||
]
|
||||
click.echo(
|
||||
f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
|
||||
@@ -238,9 +253,7 @@ def reset_memories(
|
||||
"Please specify at least one memory type to reset using the appropriate flags."
|
||||
)
|
||||
return
|
||||
reset_memories_command(
|
||||
memory, knowledge, agent_knowledge, kickoff_outputs, all
|
||||
)
|
||||
reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all)
|
||||
except Exception as e:
|
||||
click.echo(f"An error occurred while resetting memories: {e}", err=True)
|
||||
|
||||
@@ -669,18 +682,11 @@ def traces_enable():
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
_load_user_data,
|
||||
_save_user_data,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import update_user_data
|
||||
|
||||
console = Console()
|
||||
|
||||
# Update user data to enable traces
|
||||
user_data = _load_user_data()
|
||||
user_data["trace_consent"] = True
|
||||
user_data["first_execution_done"] = True
|
||||
_save_user_data(user_data)
|
||||
update_user_data({"trace_consent": True, "first_execution_done": True})
|
||||
|
||||
panel = Panel(
|
||||
"✅ Trace collection has been enabled!\n\n"
|
||||
@@ -699,18 +705,11 @@ def traces_disable():
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
_load_user_data,
|
||||
_save_user_data,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import update_user_data
|
||||
|
||||
console = Console()
|
||||
|
||||
# Update user data to disable traces
|
||||
user_data = _load_user_data()
|
||||
user_data["trace_consent"] = False
|
||||
user_data["first_execution_done"] = True
|
||||
_save_user_data(user_data)
|
||||
update_user_data({"trace_consent": False, "first_execution_done": True})
|
||||
|
||||
panel = Panel(
|
||||
"❌ Trace collection has been disabled!\n\n"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import contextvars
|
||||
import json
|
||||
from pathlib import Path
|
||||
import platform
|
||||
@@ -80,7 +81,10 @@ def run_chat() -> None:
|
||||
|
||||
# Start loading indicator
|
||||
loading_complete = threading.Event()
|
||||
loading_thread = threading.Thread(target=show_loading, args=(loading_complete,))
|
||||
ctx = contextvars.copy_context()
|
||||
loading_thread = threading.Thread(
|
||||
target=ctx.run, args=(show_loading, loading_complete)
|
||||
)
|
||||
loading_thread.start()
|
||||
|
||||
try:
|
||||
|
||||
@@ -125,13 +125,19 @@ class MemoryTUI(App[None]):
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
storage = LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
|
||||
storage = (
|
||||
LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
|
||||
)
|
||||
embedder = None
|
||||
if embedder_config is not None:
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
embedder = build_embedder(embedder_config)
|
||||
self._memory = Memory(storage=storage, embedder=embedder) if embedder else Memory(storage=storage)
|
||||
self._memory = (
|
||||
Memory(storage=storage, embedder=embedder)
|
||||
if embedder
|
||||
else Memory(storage=storage)
|
||||
)
|
||||
except Exception as e:
|
||||
self._init_error = str(e)
|
||||
|
||||
@@ -200,11 +206,7 @@ class MemoryTUI(App[None]):
|
||||
if len(record.content) > 80
|
||||
else record.content
|
||||
)
|
||||
label = (
|
||||
f"{date_str} "
|
||||
f"[bold]{record.importance:.1f}[/] "
|
||||
f"{preview}"
|
||||
)
|
||||
label = f"{date_str} [bold]{record.importance:.1f}[/] {preview}"
|
||||
option_list.add_option(label)
|
||||
|
||||
def _populate_recall_list(self) -> None:
|
||||
@@ -220,9 +222,7 @@ class MemoryTUI(App[None]):
|
||||
else m.record.content
|
||||
)
|
||||
label = (
|
||||
f"[bold]\\[{m.score:.2f}][/] "
|
||||
f"{preview} "
|
||||
f"[dim]scope={m.record.scope}[/]"
|
||||
f"[bold]\\[{m.score:.2f}][/] {preview} [dim]scope={m.record.scope}[/]"
|
||||
)
|
||||
option_list.add_option(label)
|
||||
|
||||
@@ -251,8 +251,7 @@ class MemoryTUI(App[None]):
|
||||
lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]")
|
||||
lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]")
|
||||
lines.append(
|
||||
f"[dim]Created:[/] "
|
||||
f"{record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
f"[dim]Created:[/] {record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
lines.append(
|
||||
f"[dim]Last accessed:[/] "
|
||||
@@ -362,17 +361,11 @@ class MemoryTUI(App[None]):
|
||||
panel = self.query_one("#info-panel", Static)
|
||||
panel.loading = True
|
||||
try:
|
||||
scope = (
|
||||
self._selected_scope
|
||||
if self._selected_scope != "/"
|
||||
else None
|
||||
)
|
||||
scope = self._selected_scope if self._selected_scope != "/" else None
|
||||
loop = asyncio.get_event_loop()
|
||||
matches = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._memory.recall(
|
||||
query, scope=scope, limit=10, depth="deep"
|
||||
),
|
||||
lambda: self._memory.recall(query, scope=scope, limit=10, depth="deep"),
|
||||
)
|
||||
self._recall_matches = matches or []
|
||||
self._view_mode = "recall"
|
||||
|
||||
@@ -95,9 +95,7 @@ def reset_memories_command(
|
||||
continue
|
||||
if memory:
|
||||
_reset_flow_memory(flow)
|
||||
click.echo(
|
||||
f"[Flow ({flow_name})] Memory has been reset."
|
||||
)
|
||||
click.echo(f"[Flow ({flow_name})] Memory has been reset.")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.1"
|
||||
"crewai[tools]==1.11.0rc1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.1"
|
||||
"crewai[tools]==1.11.0rc1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.1"
|
||||
"crewai[tools]==1.11.0rc1"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -442,9 +442,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
for search_path in search_paths:
|
||||
for root, dirs, files in os.walk(search_path):
|
||||
dirs[:] = [
|
||||
d
|
||||
for d in dirs
|
||||
if d not in _SKIP_DIRS and not d.startswith(".")
|
||||
d for d in dirs if d not in _SKIP_DIRS and not d.startswith(".")
|
||||
]
|
||||
if flow_path in files and "cli/templates" not in root:
|
||||
file_os_path = os.path.join(root, flow_path)
|
||||
@@ -464,9 +462,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
for attr_name in dir(module):
|
||||
module_attr = getattr(module, attr_name)
|
||||
try:
|
||||
if flow_instance := get_flow_instance(
|
||||
module_attr
|
||||
):
|
||||
if flow_instance := get_flow_instance(module_attr):
|
||||
flow_instances.append(flow_instance)
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
|
||||
@@ -1410,9 +1410,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
|
||||
return tools
|
||||
|
||||
def _add_memory_tools(
|
||||
self, tools: list[BaseTool], memory: Any
|
||||
) -> list[BaseTool]:
|
||||
def _add_memory_tools(self, tools: list[BaseTool], memory: Any) -> list[BaseTool]:
|
||||
"""Add recall and remember tools when memory is available.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -75,6 +75,14 @@ from crewai.events.types.mcp_events import (
|
||||
MCPToolExecutionFailedEvent,
|
||||
MCPToolExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.observation_events import (
|
||||
GoalAchievedEarlyEvent,
|
||||
PlanRefinementEvent,
|
||||
PlanReplanTriggeredEvent,
|
||||
StepObservationCompletedEvent,
|
||||
StepObservationFailedEvent,
|
||||
StepObservationStartedEvent,
|
||||
)
|
||||
from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
@@ -535,6 +543,64 @@ class EventListener(BaseEventListener):
|
||||
event.error,
|
||||
)
|
||||
|
||||
# ----------- OBSERVATION EVENTS (Plan-and-Execute) -----------
|
||||
|
||||
@crewai_event_bus.on(StepObservationStartedEvent)
|
||||
def on_step_observation_started(
|
||||
_: Any, event: StepObservationStartedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_observation_started(
|
||||
event.agent_role,
|
||||
event.step_number,
|
||||
event.step_description,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(StepObservationCompletedEvent)
|
||||
def on_step_observation_completed(
|
||||
_: Any, event: StepObservationCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_observation_completed(
|
||||
event.agent_role,
|
||||
event.step_number,
|
||||
event.step_completed_successfully,
|
||||
event.remaining_plan_still_valid,
|
||||
event.key_information_learned,
|
||||
event.needs_full_replan,
|
||||
event.goal_already_achieved,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(StepObservationFailedEvent)
|
||||
def on_step_observation_failed(
|
||||
_: Any, event: StepObservationFailedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_observation_failed(
|
||||
event.step_number,
|
||||
event.error,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(PlanRefinementEvent)
|
||||
def on_plan_refinement(_: Any, event: PlanRefinementEvent) -> None:
|
||||
self.formatter.handle_plan_refinement(
|
||||
event.step_number,
|
||||
event.refined_step_count,
|
||||
event.refinements,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(PlanReplanTriggeredEvent)
|
||||
def on_plan_replan_triggered(_: Any, event: PlanReplanTriggeredEvent) -> None:
|
||||
self.formatter.handle_plan_replan(
|
||||
event.replan_reason,
|
||||
event.replan_count,
|
||||
event.completed_steps_preserved,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(GoalAchievedEarlyEvent)
|
||||
def on_goal_achieved_early(_: Any, event: GoalAchievedEarlyEvent) -> None:
|
||||
self.formatter.handle_goal_achieved_early(
|
||||
event.steps_completed,
|
||||
event.steps_remaining,
|
||||
)
|
||||
|
||||
# ----------- AGENT LOGGING EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(AgentLogsStartedEvent)
|
||||
|
||||
@@ -93,6 +93,14 @@ from crewai.events.types.memory_events import (
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.events.types.observation_events import (
|
||||
GoalAchievedEarlyEvent,
|
||||
PlanRefinementEvent,
|
||||
PlanReplanTriggeredEvent,
|
||||
StepObservationCompletedEvent,
|
||||
StepObservationFailedEvent,
|
||||
StepObservationStartedEvent,
|
||||
)
|
||||
from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
@@ -437,6 +445,39 @@ class TraceCollectionListener(BaseEventListener):
|
||||
) -> None:
|
||||
self._handle_action_event("agent_reasoning_failed", source, event)
|
||||
|
||||
# Observation events (Plan-and-Execute)
|
||||
@event_bus.on(StepObservationStartedEvent)
|
||||
def on_step_observation_started(
|
||||
source: Any, event: StepObservationStartedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("step_observation_started", source, event)
|
||||
|
||||
@event_bus.on(StepObservationCompletedEvent)
|
||||
def on_step_observation_completed(
|
||||
source: Any, event: StepObservationCompletedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("step_observation_completed", source, event)
|
||||
|
||||
@event_bus.on(StepObservationFailedEvent)
|
||||
def on_step_observation_failed(
|
||||
source: Any, event: StepObservationFailedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("step_observation_failed", source, event)
|
||||
|
||||
@event_bus.on(PlanRefinementEvent)
|
||||
def on_plan_refinement(source: Any, event: PlanRefinementEvent) -> None:
|
||||
self._handle_action_event("plan_refinement", source, event)
|
||||
|
||||
@event_bus.on(PlanReplanTriggeredEvent)
|
||||
def on_plan_replan_triggered(
|
||||
source: Any, event: PlanReplanTriggeredEvent
|
||||
) -> None:
|
||||
self._handle_action_event("plan_replan_triggered", source, event)
|
||||
|
||||
@event_bus.on(GoalAchievedEarlyEvent)
|
||||
def on_goal_achieved_early(source: Any, event: GoalAchievedEarlyEvent) -> None:
|
||||
self._handle_action_event("goal_achieved_early", source, event)
|
||||
|
||||
@event_bus.on(KnowledgeRetrievalStartedEvent)
|
||||
def on_knowledge_retrieval_started(
|
||||
source: Any, event: KnowledgeRetrievalStartedEvent
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
import contextvars
|
||||
from contextvars import ContextVar, Token
|
||||
from datetime import datetime
|
||||
import getpass
|
||||
@@ -18,6 +19,7 @@ from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.serialization import to_serializable
|
||||
|
||||
@@ -137,12 +139,25 @@ def _load_user_data() -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_user_data(data: dict[str, Any]) -> None:
|
||||
def _user_data_lock_name() -> str:
|
||||
"""Return a stable lock name for the user data file."""
|
||||
return f"file:{os.path.realpath(_user_data_file())}"
|
||||
|
||||
|
||||
def update_user_data(updates: dict[str, Any]) -> None:
|
||||
"""Atomically read-modify-write the user data file.
|
||||
|
||||
Args:
|
||||
updates: Key-value pairs to merge into the existing user data.
|
||||
"""
|
||||
try:
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
with store_lock(_user_data_lock_name()):
|
||||
data = _load_user_data()
|
||||
data.update(updates)
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
except (OSError, PermissionError) as e:
|
||||
logger.warning(f"Failed to save user data: {e}")
|
||||
logger.warning(f"Failed to update user data: {e}")
|
||||
|
||||
|
||||
def has_user_declined_tracing() -> bool:
|
||||
@@ -357,24 +372,30 @@ def _get_generic_system_id() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_user_id() -> str:
|
||||
"""Stable, anonymized user identifier with caching."""
|
||||
data = _load_user_data()
|
||||
|
||||
if "user_id" in data:
|
||||
return cast(str, data["user_id"])
|
||||
|
||||
def _generate_user_id() -> str:
|
||||
"""Compute an anonymized user identifier from username and machine ID."""
|
||||
try:
|
||||
username = getpass.getuser()
|
||||
except Exception:
|
||||
username = "unknown"
|
||||
|
||||
seed = f"{username}|{_get_machine_id()}"
|
||||
uid = hashlib.sha256(seed.encode()).hexdigest()
|
||||
return hashlib.sha256(seed.encode()).hexdigest()
|
||||
|
||||
data["user_id"] = uid
|
||||
_save_user_data(data)
|
||||
return uid
|
||||
|
||||
def get_user_id() -> str:
|
||||
"""Stable, anonymized user identifier with caching."""
|
||||
with store_lock(_user_data_lock_name()):
|
||||
data = _load_user_data()
|
||||
|
||||
if "user_id" in data:
|
||||
return cast(str, data["user_id"])
|
||||
|
||||
uid = _generate_user_id()
|
||||
data["user_id"] = uid
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
return uid
|
||||
|
||||
|
||||
def is_first_execution() -> bool:
|
||||
@@ -389,20 +410,23 @@ def mark_first_execution_done(user_consented: bool = False) -> None:
|
||||
Args:
|
||||
user_consented: Whether the user consented to trace collection.
|
||||
"""
|
||||
data = _load_user_data()
|
||||
if data.get("first_execution_done", False):
|
||||
return
|
||||
with store_lock(_user_data_lock_name()):
|
||||
data = _load_user_data()
|
||||
if data.get("first_execution_done", False):
|
||||
return
|
||||
|
||||
data.update(
|
||||
{
|
||||
"first_execution_done": True,
|
||||
"first_execution_at": datetime.now().timestamp(),
|
||||
"user_id": get_user_id(),
|
||||
"machine_id": _get_machine_id(),
|
||||
"trace_consent": user_consented,
|
||||
}
|
||||
)
|
||||
_save_user_data(data)
|
||||
uid = data.get("user_id") or _generate_user_id()
|
||||
data.update(
|
||||
{
|
||||
"first_execution_done": True,
|
||||
"first_execution_at": datetime.now().timestamp(),
|
||||
"user_id": uid,
|
||||
"machine_id": _get_machine_id(),
|
||||
"trace_consent": user_consented,
|
||||
}
|
||||
)
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
def safe_serialize_to_dict(obj: Any, exclude: set[str] | None = None) -> dict[str, Any]:
|
||||
@@ -509,7 +533,8 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
# Handle all input-related errors silently
|
||||
result[0] = False
|
||||
|
||||
input_thread = threading.Thread(target=get_input, daemon=True)
|
||||
ctx = contextvars.copy_context()
|
||||
input_thread = threading.Thread(target=ctx.run, args=(get_input,), daemon=True)
|
||||
input_thread.start()
|
||||
input_thread.join(timeout=timeout_seconds)
|
||||
|
||||
|
||||
99
lib/crewai/src/crewai/events/types/observation_events.py
Normal file
99
lib/crewai/src/crewai/events/types/observation_events.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Observation events for the Plan-and-Execute architecture.
|
||||
|
||||
Emitted during the Observation phase (PLAN-AND-ACT Section 3.3) when the
|
||||
PlannerObserver analyzes step execution results and decides on plan
|
||||
continuation, refinement, or replanning.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class ObservationEvent(BaseEvent):
|
||||
"""Base event for observation phase events."""
|
||||
|
||||
type: str
|
||||
agent_role: str
|
||||
step_number: int
|
||||
step_description: str = ""
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
self._set_task_params(data)
|
||||
self._set_agent_params(data)
|
||||
|
||||
|
||||
class StepObservationStartedEvent(ObservationEvent):
|
||||
"""Emitted when the Planner begins observing a step's result.
|
||||
|
||||
Fires after every step execution, before the observation LLM call.
|
||||
"""
|
||||
|
||||
type: str = "step_observation_started"
|
||||
|
||||
|
||||
class StepObservationCompletedEvent(ObservationEvent):
|
||||
"""Emitted when the Planner finishes observing a step's result.
|
||||
|
||||
Contains the full observation analysis: what was learned, whether
|
||||
the plan is still valid, and what action to take next.
|
||||
"""
|
||||
|
||||
type: str = "step_observation_completed"
|
||||
step_completed_successfully: bool = True
|
||||
key_information_learned: str = ""
|
||||
remaining_plan_still_valid: bool = True
|
||||
needs_full_replan: bool = False
|
||||
replan_reason: str | None = None
|
||||
goal_already_achieved: bool = False
|
||||
suggested_refinements: list[str] | None = None
|
||||
|
||||
|
||||
class StepObservationFailedEvent(ObservationEvent):
|
||||
"""Emitted when the observation LLM call itself fails.
|
||||
|
||||
The system defaults to continuing the plan when this happens,
|
||||
but the event allows monitoring/alerting on observation failures.
|
||||
"""
|
||||
|
||||
type: str = "step_observation_failed"
|
||||
error: str = ""
|
||||
|
||||
|
||||
class PlanRefinementEvent(ObservationEvent):
|
||||
"""Emitted when the Planner refines upcoming step descriptions.
|
||||
|
||||
This is the lightweight refinement path — no full replan, just
|
||||
sharpening pending todo descriptions based on new information.
|
||||
"""
|
||||
|
||||
type: str = "plan_refinement"
|
||||
refined_step_count: int = 0
|
||||
refinements: list[str] | None = None
|
||||
|
||||
|
||||
class PlanReplanTriggeredEvent(ObservationEvent):
|
||||
"""Emitted when the Planner triggers a full replan.
|
||||
|
||||
The remaining plan was deemed fundamentally wrong and will be
|
||||
regenerated from scratch, preserving completed step results.
|
||||
"""
|
||||
|
||||
type: str = "plan_replan_triggered"
|
||||
replan_reason: str = ""
|
||||
replan_count: int = 0
|
||||
completed_steps_preserved: int = 0
|
||||
|
||||
|
||||
class GoalAchievedEarlyEvent(ObservationEvent):
|
||||
"""Emitted when the Planner detects the goal was achieved early.
|
||||
|
||||
Remaining steps will be skipped and execution will finalize.
|
||||
"""
|
||||
|
||||
type: str = "goal_achieved_early"
|
||||
steps_remaining: int = 0
|
||||
steps_completed: int = 0
|
||||
@@ -9,7 +9,7 @@ class ReasoningEvent(BaseEvent):
|
||||
type: str
|
||||
attempt: int = 1
|
||||
agent_role: str
|
||||
task_id: str
|
||||
task_id: str | None = None
|
||||
task_name: str | None = None
|
||||
from_task: Any | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
@@ -43,6 +43,7 @@ def should_suppress_console_output() -> bool:
|
||||
|
||||
class ConsoleFormatter:
|
||||
tool_usage_counts: ClassVar[dict[str, int]] = {}
|
||||
_tool_counts_lock: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
current_a2a_turn_count: int = 0
|
||||
_pending_a2a_message: str | None = None
|
||||
@@ -445,9 +446,11 @@ To enable tracing, do any one of these:
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
# Update tool usage count
|
||||
self.tool_usage_counts[tool_name] = self.tool_usage_counts.get(tool_name, 0) + 1
|
||||
iteration = self.tool_usage_counts[tool_name]
|
||||
with self._tool_counts_lock:
|
||||
self.tool_usage_counts[tool_name] = (
|
||||
self.tool_usage_counts.get(tool_name, 0) + 1
|
||||
)
|
||||
iteration = self.tool_usage_counts[tool_name]
|
||||
|
||||
content = Text()
|
||||
content.append("Tool: ", style="white")
|
||||
@@ -474,7 +477,8 @@ To enable tracing, do any one of these:
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
with self._tool_counts_lock:
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
|
||||
content = Text()
|
||||
content.append("Tool Completed\n", style="green bold")
|
||||
@@ -500,7 +504,8 @@ To enable tracing, do any one of these:
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
with self._tool_counts_lock:
|
||||
iteration = self.tool_usage_counts.get(tool_name, 1)
|
||||
|
||||
content = Text()
|
||||
content.append("Tool Failed\n", style="red bold")
|
||||
@@ -936,6 +941,152 @@ To enable tracing, do any one of these:
|
||||
)
|
||||
self.print_panel(error_content, "❌ Reasoning Error", "red")
|
||||
|
||||
# ----------- OBSERVATION EVENTS (Plan-and-Execute) -----------
|
||||
|
||||
def handle_observation_started(
|
||||
self,
|
||||
agent_role: str,
|
||||
step_number: int,
|
||||
step_description: str,
|
||||
) -> None:
|
||||
"""Handle step observation started event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = Text()
|
||||
content.append("Observation Started\n", style="cyan bold")
|
||||
content.append("Agent: ", style="white")
|
||||
content.append(f"{agent_role}\n", style="cyan")
|
||||
content.append("Step: ", style="white")
|
||||
content.append(f"{step_number}\n", style="cyan")
|
||||
if step_description:
|
||||
desc_preview = step_description[:80] + (
|
||||
"..." if len(step_description) > 80 else ""
|
||||
)
|
||||
content.append("Description: ", style="white")
|
||||
content.append(f"{desc_preview}\n", style="cyan")
|
||||
|
||||
self.print_panel(content, "🔍 Observing Step Result", "cyan")
|
||||
|
||||
def handle_observation_completed(
|
||||
self,
|
||||
agent_role: str,
|
||||
step_number: int,
|
||||
step_completed: bool,
|
||||
plan_valid: bool,
|
||||
key_info: str,
|
||||
needs_replan: bool,
|
||||
goal_achieved: bool,
|
||||
) -> None:
|
||||
"""Handle step observation completed event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
if goal_achieved:
|
||||
style = "green"
|
||||
status = "Goal Achieved Early"
|
||||
elif needs_replan:
|
||||
style = "yellow"
|
||||
status = "Replan Needed"
|
||||
elif plan_valid:
|
||||
style = "green"
|
||||
status = "Plan Valid — Continue"
|
||||
else:
|
||||
style = "red"
|
||||
status = "Step Failed"
|
||||
|
||||
content = Text()
|
||||
content.append("Observation Complete\n", style=f"{style} bold")
|
||||
content.append("Step: ", style="white")
|
||||
content.append(f"{step_number}\n", style=style)
|
||||
content.append("Status: ", style="white")
|
||||
content.append(f"{status}\n", style=style)
|
||||
if key_info:
|
||||
info_preview = key_info[:120] + ("..." if len(key_info) > 120 else "")
|
||||
content.append("Learned: ", style="white")
|
||||
content.append(f"{info_preview}\n", style=style)
|
||||
|
||||
self.print_panel(content, "🔍 Observation Result", style)
|
||||
|
||||
def handle_observation_failed(
|
||||
self,
|
||||
step_number: int,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Handle step observation failure event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
error_content = self.create_status_content(
|
||||
"Observation Failed",
|
||||
"Error",
|
||||
"red",
|
||||
Step=str(step_number),
|
||||
Error=error,
|
||||
)
|
||||
self.print_panel(error_content, "❌ Observation Error", "red")
|
||||
|
||||
def handle_plan_refinement(
|
||||
self,
|
||||
step_number: int,
|
||||
refined_count: int,
|
||||
refinements: list[str] | None,
|
||||
) -> None:
|
||||
"""Handle plan refinement event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = Text()
|
||||
content.append("Plan Refined\n", style="cyan bold")
|
||||
content.append("After Step: ", style="white")
|
||||
content.append(f"{step_number}\n", style="cyan")
|
||||
content.append("Steps Updated: ", style="white")
|
||||
content.append(f"{refined_count}\n", style="cyan")
|
||||
if refinements:
|
||||
for r in refinements[:3]:
|
||||
content.append(f" • {r[:80]}\n", style="white")
|
||||
|
||||
self.print_panel(content, "✏️ Plan Refinement", "cyan")
|
||||
|
||||
def handle_plan_replan(
|
||||
self,
|
||||
reason: str,
|
||||
replan_count: int,
|
||||
preserved_count: int,
|
||||
) -> None:
|
||||
"""Handle plan replan triggered event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = Text()
|
||||
content.append("Full Replan Triggered\n", style="yellow bold")
|
||||
content.append("Reason: ", style="white")
|
||||
content.append(f"{reason}\n", style="yellow")
|
||||
content.append("Replan #: ", style="white")
|
||||
content.append(f"{replan_count}\n", style="yellow")
|
||||
content.append("Preserved Steps: ", style="white")
|
||||
content.append(f"{preserved_count}\n", style="yellow")
|
||||
|
||||
self.print_panel(content, "🔄 Dynamic Replan", "yellow")
|
||||
|
||||
def handle_goal_achieved_early(
|
||||
self,
|
||||
steps_completed: int,
|
||||
steps_remaining: int,
|
||||
) -> None:
|
||||
"""Handle goal achieved early event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = Text()
|
||||
content.append("Goal Achieved Early!\n", style="green bold")
|
||||
content.append("Completed: ", style="white")
|
||||
content.append(f"{steps_completed} steps\n", style="green")
|
||||
content.append("Skipped: ", style="white")
|
||||
content.append(f"{steps_remaining} remaining steps\n", style="green")
|
||||
|
||||
self.print_panel(content, "🎯 Early Goal Achievement", "green")
|
||||
|
||||
# ----------- AGENT LOGGING EVENTS -----------
|
||||
|
||||
def handle_agent_logs_started(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -34,6 +34,7 @@ class ConsoleProvider:
|
||||
```python
|
||||
from crewai.flow.async_feedback import ConsoleProvider
|
||||
|
||||
|
||||
@human_feedback(
|
||||
message="Review this:",
|
||||
provider=ConsoleProvider(),
|
||||
@@ -46,6 +47,7 @@ class ConsoleProvider:
|
||||
```python
|
||||
from crewai.flow import Flow, start
|
||||
|
||||
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def gather_info(self):
|
||||
|
||||
@@ -17,6 +17,7 @@ from collections.abc import (
|
||||
ValuesView,
|
||||
)
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import contextvars
|
||||
import copy
|
||||
import enum
|
||||
import inspect
|
||||
@@ -497,6 +498,52 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._list)
|
||||
|
||||
def index(
|
||||
self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None
|
||||
) -> int: # type: ignore[override]
|
||||
if stop is None:
|
||||
return self._list.index(value, start)
|
||||
return self._list.index(value, start, stop)
|
||||
|
||||
def count(self, value: T) -> int:
|
||||
return self._list.count(value)
|
||||
|
||||
def sort(self, *, key: Any = None, reverse: bool = False) -> None:
|
||||
with self._lock:
|
||||
self._list.sort(key=key, reverse=reverse)
|
||||
|
||||
def reverse(self) -> None:
|
||||
with self._lock:
|
||||
self._list.reverse()
|
||||
|
||||
def copy(self) -> list[T]:
|
||||
return self._list.copy()
|
||||
|
||||
def __add__(self, other: list[T]) -> list[T]:
|
||||
return self._list + other
|
||||
|
||||
def __radd__(self, other: list[T]) -> list[T]:
|
||||
return other + self._list
|
||||
|
||||
def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]:
|
||||
with self._lock:
|
||||
self._list += list(other)
|
||||
return self
|
||||
|
||||
def __mul__(self, n: SupportsIndex) -> list[T]:
|
||||
return self._list * n
|
||||
|
||||
def __rmul__(self, n: SupportsIndex) -> list[T]:
|
||||
return self._list * n
|
||||
|
||||
def __imul__(self, n: SupportsIndex) -> LockedListProxy[T]:
|
||||
with self._lock:
|
||||
self._list *= n
|
||||
return self
|
||||
|
||||
def __reversed__(self) -> Iterator[T]:
|
||||
return reversed(self._list)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare based on the underlying list contents."""
|
||||
if isinstance(other, LockedListProxy):
|
||||
@@ -579,6 +626,23 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._dict)
|
||||
|
||||
def copy(self) -> dict[str, T]:
|
||||
return self._dict.copy()
|
||||
|
||||
def __or__(self, other: dict[str, T]) -> dict[str, T]:
|
||||
return self._dict | other
|
||||
|
||||
def __ror__(self, other: dict[str, T]) -> dict[str, T]:
|
||||
return other | self._dict
|
||||
|
||||
def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]:
|
||||
with self._lock:
|
||||
self._dict |= other
|
||||
return self
|
||||
|
||||
def __reversed__(self) -> Iterator[str]:
|
||||
return reversed(self._dict)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare based on the underlying dict contents."""
|
||||
if isinstance(other, LockedDictProxy):
|
||||
@@ -620,6 +684,10 @@ class StateProxy(Generic[T]):
|
||||
if name in ("_proxy_state", "_proxy_lock"):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
if isinstance(value, LockedListProxy):
|
||||
value = value._list
|
||||
elif isinstance(value, LockedDictProxy):
|
||||
value = value._dict
|
||||
with object.__getattribute__(self, "_proxy_lock"):
|
||||
setattr(object.__getattribute__(self, "_proxy_state"), name, value)
|
||||
|
||||
@@ -1746,8 +1814,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
ctx = contextvars.copy_context()
|
||||
with ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, _run_flow()).result()
|
||||
return pool.submit(ctx.run, asyncio.run, _run_flow()).result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(_run_flow())
|
||||
|
||||
@@ -2171,8 +2240,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
else:
|
||||
# Run sync methods in thread pool for isolation
|
||||
# This allows Agent.kickoff() to work synchronously inside Flow methods
|
||||
import contextvars
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
result = await asyncio.to_thread(ctx.run, method, *args, **kwargs)
|
||||
finally:
|
||||
@@ -2649,7 +2716,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
if not isinstance(e, HumanFeedbackPending):
|
||||
logger.error(f"Error executing listener {listener_name}: {e}")
|
||||
if not getattr(e, "_flow_listener_logged", False):
|
||||
logger.error(f"Error executing listener {listener_name}: {e}")
|
||||
e._flow_listener_logged = True # type: ignore[attr-defined]
|
||||
raise
|
||||
|
||||
# ── User Input (self.ask) ────────────────────────────────────────
|
||||
@@ -2791,8 +2860,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Manual executor management to avoid shutdown(wait=True)
|
||||
# deadlock when the provider call outlives the timeout.
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
ctx = contextvars.copy_context()
|
||||
future = executor.submit(
|
||||
provider.request_input, message, self, metadata
|
||||
ctx.run, provider.request_input, message, self, metadata
|
||||
)
|
||||
try:
|
||||
raw = future.result(timeout=timeout)
|
||||
|
||||
@@ -188,7 +188,7 @@ def human_feedback(
|
||||
metadata: dict[str, Any] | None = None,
|
||||
provider: HumanFeedbackProvider | None = None,
|
||||
learn: bool = False,
|
||||
learn_source: str = "hitl"
|
||||
learn_source: str = "hitl",
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator for Flow methods that require human feedback.
|
||||
|
||||
@@ -328,9 +328,7 @@ def human_feedback(
|
||||
"""Recall past HITL lessons and use LLM to pre-review the output."""
|
||||
try:
|
||||
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
|
||||
matches = flow_instance.memory.recall(
|
||||
query, source=learn_source
|
||||
)
|
||||
matches = flow_instance.memory.recall(query, source=learn_source)
|
||||
if not matches:
|
||||
return method_output
|
||||
|
||||
@@ -341,7 +339,10 @@ def human_feedback(
|
||||
lessons=lessons,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": _get_hitl_prompt("hitl_pre_review_system")},
|
||||
{
|
||||
"role": "system",
|
||||
"content": _get_hitl_prompt("hitl_pre_review_system"),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
if getattr(llm_inst, "supports_function_calling", lambda: False)():
|
||||
@@ -366,7 +367,10 @@ def human_feedback(
|
||||
feedback=raw_feedback,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": _get_hitl_prompt("hitl_distill_system")},
|
||||
{
|
||||
"role": "system",
|
||||
"content": _get_hitl_prompt("hitl_distill_system"),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
@@ -408,7 +412,7 @@ def human_feedback(
|
||||
emit=list(emit) if emit else None,
|
||||
default_outcome=default_outcome,
|
||||
metadata=metadata or {},
|
||||
llm=llm if isinstance(llm, str) else None,
|
||||
llm=llm if isinstance(llm, str) else getattr(llm, "model", None),
|
||||
)
|
||||
|
||||
# Determine effective provider:
|
||||
@@ -487,7 +491,11 @@ def human_feedback(
|
||||
result = _process_feedback(self, method_output, raw_feedback)
|
||||
|
||||
# Distill: extract lessons from output + feedback, store in memory
|
||||
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
|
||||
if (
|
||||
learn
|
||||
and getattr(self, "memory", None) is not None
|
||||
and raw_feedback.strip()
|
||||
):
|
||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||
|
||||
return result
|
||||
@@ -507,7 +515,11 @@ def human_feedback(
|
||||
result = _process_feedback(self, method_output, raw_feedback)
|
||||
|
||||
# Distill: extract lessons from output + feedback, store in memory
|
||||
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
|
||||
if (
|
||||
learn
|
||||
and getattr(self, "memory", None) is not None
|
||||
and raw_feedback.strip()
|
||||
):
|
||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||
|
||||
return result
|
||||
@@ -534,7 +546,7 @@ def human_feedback(
|
||||
metadata=metadata,
|
||||
provider=provider,
|
||||
learn=learn,
|
||||
learn_source=learn_source
|
||||
learn_source=learn_source,
|
||||
)
|
||||
wrapper.__is_flow_method__ = True
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""
|
||||
SQLite-based implementation of flow state persistence.
|
||||
"""
|
||||
"""SQLite-based implementation of flow state persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -13,6 +12,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@@ -68,11 +68,16 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
raise ValueError("Database path must be provided")
|
||||
|
||||
self.db_path = path # Now mypy knows this is str
|
||||
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
|
||||
self.init_db()
|
||||
|
||||
def init_db(self) -> None:
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
# Main state table
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -113,6 +118,49 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
"""
|
||||
)
|
||||
|
||||
def _save_state_sql(
|
||||
self,
|
||||
conn: sqlite3.Connection,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
) -> None:
|
||||
"""Execute the save-state INSERT without acquiring the lock.
|
||||
|
||||
Args:
|
||||
conn: An open SQLite connection.
|
||||
flow_uuid: Unique identifier for the flow instance.
|
||||
method_name: Name of the method that just completed.
|
||||
state_dict: State data as a plain dict.
|
||||
"""
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
flow_uuid,
|
||||
method_name,
|
||||
timestamp,
|
||||
state_json
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
flow_uuid,
|
||||
method_name,
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
json.dumps(state_dict),
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _to_state_dict(state_data: dict[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
"""Convert state_data to a plain dict."""
|
||||
if isinstance(state_data, BaseModel):
|
||||
return state_data.model_dump()
|
||||
if isinstance(state_data, dict):
|
||||
return state_data
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
|
||||
def save_state(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
@@ -126,33 +174,13 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
method_name: Name of the method that just completed
|
||||
state_data: Current state data (either dict or Pydantic model)
|
||||
"""
|
||||
# Convert state_data to dict, handling both Pydantic and dict cases
|
||||
if isinstance(state_data, BaseModel):
|
||||
state_dict = state_data.model_dump()
|
||||
elif isinstance(state_data, dict):
|
||||
state_dict = state_data
|
||||
else:
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
state_dict = self._to_state_dict(state_data)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
flow_uuid,
|
||||
method_name,
|
||||
timestamp,
|
||||
state_json
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
flow_uuid,
|
||||
method_name,
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
json.dumps(state_dict),
|
||||
),
|
||||
)
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
self._save_state_sql(conn, flow_uuid, method_name, state_dict)
|
||||
|
||||
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||
"""Load the most recent state for a given flow UUID.
|
||||
@@ -163,7 +191,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT state_json
|
||||
@@ -197,24 +225,14 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
context: The pending feedback context with all resume information
|
||||
state_data: Current state data
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
state_dict = self._to_state_dict(state_data)
|
||||
|
||||
# Convert state_data to dict
|
||||
if isinstance(state_data, BaseModel):
|
||||
state_dict = state_data.model_dump()
|
||||
elif isinstance(state_data, dict):
|
||||
state_dict = state_data
|
||||
else:
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
self._save_state_sql(conn, flow_uuid, context.method_name, state_dict)
|
||||
|
||||
# Also save to regular state table for consistency
|
||||
self.save_state(flow_uuid, context.method_name, state_data)
|
||||
|
||||
# Save pending feedback context
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
# Use INSERT OR REPLACE to handle re-triggering feedback on same flow
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO pending_feedback (
|
||||
@@ -248,7 +266,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
# Import here to avoid circular imports
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT state_json, context_json
|
||||
@@ -272,7 +290,10 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with (
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
conn.execute(
|
||||
"""
|
||||
DELETE FROM pending_feedback
|
||||
|
||||
@@ -600,7 +600,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
def _save_to_memory(self, output_text: str) -> None:
|
||||
"""Extract discrete memories from the run and remember each. No-op if _memory is None or read-only."""
|
||||
if self._memory is None or getattr(self._memory, "_read_only", False):
|
||||
if self._memory is None or self._memory.read_only:
|
||||
return
|
||||
input_str = self._get_last_user_content() or "User request"
|
||||
try:
|
||||
|
||||
@@ -6,9 +6,27 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities.planning_types import TodoItem
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class TodoExecutionResult(BaseModel):
|
||||
"""Summary of a single todo execution."""
|
||||
|
||||
step_number: int = Field(description="Step number in the plan")
|
||||
description: str = Field(description="What the todo was supposed to do")
|
||||
tool_used: str | None = Field(
|
||||
default=None, description="Tool that was used for this step"
|
||||
)
|
||||
status: str = Field(description="Final status: completed, failed, pending")
|
||||
result: str | None = Field(
|
||||
default=None, description="Result or error message from execution"
|
||||
)
|
||||
depends_on: list[int] = Field(
|
||||
default_factory=list, description="Step numbers this depended on"
|
||||
)
|
||||
|
||||
|
||||
class LiteAgentOutput(BaseModel):
|
||||
"""Class that represents the result of a LiteAgent execution."""
|
||||
|
||||
@@ -24,12 +42,75 @@ class LiteAgentOutput(BaseModel):
|
||||
)
|
||||
messages: list[LLMMessage] = Field(description="Messages of the agent", default=[])
|
||||
|
||||
plan: str | None = Field(
|
||||
default=None, description="The execution plan that was generated, if any"
|
||||
)
|
||||
todos: list[TodoExecutionResult] = Field(
|
||||
default_factory=list,
|
||||
description="List of todos that were executed with their results",
|
||||
)
|
||||
replan_count: int = Field(
|
||||
default=0, description="Number of times the plan was regenerated"
|
||||
)
|
||||
last_replan_reason: str | None = Field(
|
||||
default=None, description="Reason for the last replan, if any"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_todo_items(cls, todo_items: list[TodoItem]) -> list[TodoExecutionResult]:
|
||||
"""Convert TodoItem objects to TodoExecutionResult summaries.
|
||||
|
||||
Args:
|
||||
todo_items: List of TodoItem objects from execution.
|
||||
|
||||
Returns:
|
||||
List of TodoExecutionResult summaries.
|
||||
"""
|
||||
return [
|
||||
TodoExecutionResult(
|
||||
step_number=item.step_number,
|
||||
description=item.description,
|
||||
tool_used=item.tool_to_use,
|
||||
status=item.status,
|
||||
result=item.result,
|
||||
depends_on=item.depends_on,
|
||||
)
|
||||
for item in todo_items
|
||||
]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert pydantic_output to a dictionary."""
|
||||
if self.pydantic:
|
||||
return self.pydantic.model_dump()
|
||||
return {}
|
||||
|
||||
@property
|
||||
def completed_todos(self) -> list[TodoExecutionResult]:
|
||||
"""Get only the completed todos."""
|
||||
return [t for t in self.todos if t.status == "completed"]
|
||||
|
||||
@property
|
||||
def failed_todos(self) -> list[TodoExecutionResult]:
|
||||
"""Get only the failed todos."""
|
||||
return [t for t in self.todos if t.status == "failed"]
|
||||
|
||||
@property
|
||||
def had_plan(self) -> bool:
|
||||
"""Check if the agent executed with a plan."""
|
||||
return self.plan is not None or len(self.todos) > 0
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the raw output as a string."""
|
||||
return self.raw
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a detailed representation including todo summary."""
|
||||
parts = [f"LiteAgentOutput(role={self.agent_role!r}"]
|
||||
if self.todos:
|
||||
completed = len(self.completed_todos)
|
||||
total = len(self.todos)
|
||||
parts.append(f", todos={completed}/{total} completed")
|
||||
if self.replan_count > 0:
|
||||
parts.append(f", replans={self.replan_count}")
|
||||
parts.append(")")
|
||||
return "".join(parts)
|
||||
|
||||
@@ -22,7 +22,12 @@ if TYPE_CHECKING:
|
||||
|
||||
try:
|
||||
from anthropic import Anthropic, AsyncAnthropic, transform_schema
|
||||
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
|
||||
from anthropic.types import (
|
||||
Message,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolUseBlock,
|
||||
)
|
||||
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
|
||||
import httpx
|
||||
except ImportError:
|
||||
@@ -31,6 +36,11 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
|
||||
TOOL_SEARCH_TOOL_TYPES: Final[tuple[str, ...]] = (
|
||||
"tool_search_tool_regex_20251119",
|
||||
"tool_search_tool_bm25_20251119",
|
||||
)
|
||||
|
||||
ANTHROPIC_FILES_API_BETA: Final = "files-api-2025-04-14"
|
||||
ANTHROPIC_STRUCTURED_OUTPUTS_BETA: Final = "structured-outputs-2025-11-13"
|
||||
|
||||
@@ -117,6 +127,22 @@ class AnthropicThinkingConfig(BaseModel):
|
||||
budget_tokens: int | None = None
|
||||
|
||||
|
||||
class AnthropicToolSearchConfig(BaseModel):
|
||||
"""Configuration for Anthropic's server-side tool search.
|
||||
|
||||
When enabled, tools marked with defer_loading=True are not loaded into
|
||||
context immediately. Instead, Claude uses the tool search tool to
|
||||
dynamically discover and load relevant tools on-demand.
|
||||
|
||||
Attributes:
|
||||
type: The tool search variant to use.
|
||||
- "regex": Claude constructs regex patterns to search tool names/descriptions.
|
||||
- "bm25": Claude uses natural language queries to search tools.
|
||||
"""
|
||||
|
||||
type: Literal["regex", "bm25"] = "bm25"
|
||||
|
||||
|
||||
class AnthropicCompletion(BaseLLM):
|
||||
"""Anthropic native completion implementation.
|
||||
|
||||
@@ -140,6 +166,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
tool_search: AnthropicToolSearchConfig | bool | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
@@ -159,6 +186,10 @@ class AnthropicCompletion(BaseLLM):
|
||||
interceptor: HTTP interceptor for modifying requests/responses at transport level.
|
||||
response_format: Pydantic model for structured output. When provided, responses
|
||||
will be validated against this model schema.
|
||||
tool_search: Enable Anthropic's server-side tool search. When True, uses "bm25"
|
||||
variant by default. Pass an AnthropicToolSearchConfig to choose "regex" or
|
||||
"bm25". When enabled, tools are automatically marked with defer_loading=True
|
||||
and a tool search tool is injected into the tools list.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -190,6 +221,13 @@ class AnthropicCompletion(BaseLLM):
|
||||
self.thinking = thinking
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
self.response_format = response_format
|
||||
# Tool search config
|
||||
if tool_search is True:
|
||||
self.tool_search = AnthropicToolSearchConfig()
|
||||
elif isinstance(tool_search, AnthropicToolSearchConfig):
|
||||
self.tool_search = tool_search
|
||||
else:
|
||||
self.tool_search = None
|
||||
# Model-specific settings
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = True
|
||||
@@ -432,10 +470,23 @@ class AnthropicCompletion(BaseLLM):
|
||||
# Handle tools for Claude 3+
|
||||
if tools and self.supports_tools:
|
||||
converted_tools = self._convert_tools_for_interference(tools)
|
||||
|
||||
# When tool_search is enabled and there are 2+ regular tools,
|
||||
# inject the search tool and mark regular tools with defer_loading.
|
||||
# With only 1 tool there's nothing to search — skip tool search
|
||||
# entirely so the normal forced tool_choice optimisation still works.
|
||||
regular_tools = [
|
||||
t
|
||||
for t in converted_tools
|
||||
if t.get("type", "") not in TOOL_SEARCH_TOOL_TYPES
|
||||
]
|
||||
if self.tool_search is not None and len(regular_tools) >= 2:
|
||||
converted_tools = self._apply_tool_search(converted_tools)
|
||||
|
||||
params["tools"] = converted_tools
|
||||
|
||||
if available_functions and len(converted_tools) == 1:
|
||||
tool_name = converted_tools[0].get("name")
|
||||
if available_functions and len(regular_tools) == 1:
|
||||
tool_name = regular_tools[0].get("name")
|
||||
if tool_name and tool_name in available_functions:
|
||||
params["tool_choice"] = {"type": "tool", "name": tool_name}
|
||||
|
||||
@@ -454,6 +505,12 @@ class AnthropicCompletion(BaseLLM):
|
||||
anthropic_tools = []
|
||||
|
||||
for tool in tools:
|
||||
# Pass through tool search tool definitions unchanged
|
||||
tool_type = tool.get("type", "")
|
||||
if tool_type in TOOL_SEARCH_TOOL_TYPES:
|
||||
anthropic_tools.append(tool)
|
||||
continue
|
||||
|
||||
if "input_schema" in tool and "name" in tool and "description" in tool:
|
||||
anthropic_tools.append(tool)
|
||||
continue
|
||||
@@ -466,15 +523,15 @@ class AnthropicCompletion(BaseLLM):
|
||||
logging.error(f"Error converting tool to Anthropic format: {e}")
|
||||
raise e
|
||||
|
||||
anthropic_tool = {
|
||||
anthropic_tool: dict[str, Any] = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
}
|
||||
|
||||
if parameters and isinstance(parameters, dict):
|
||||
anthropic_tool["input_schema"] = parameters # type: ignore[assignment]
|
||||
anthropic_tool["input_schema"] = parameters
|
||||
else:
|
||||
anthropic_tool["input_schema"] = { # type: ignore[assignment]
|
||||
anthropic_tool["input_schema"] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
@@ -484,6 +541,55 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
return anthropic_tools
|
||||
|
||||
def _apply_tool_search(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Inject tool search tool and mark regular tools with defer_loading.
|
||||
|
||||
When tool_search is enabled, this method:
|
||||
1. Adds the appropriate tool search tool definition (regex or bm25)
|
||||
2. Marks all regular tools with defer_loading=True so they are only
|
||||
loaded when Claude discovers them via search
|
||||
|
||||
Args:
|
||||
tools: Converted tool definitions in Anthropic format.
|
||||
|
||||
Returns:
|
||||
Updated tools list with tool search tool prepended and
|
||||
regular tools marked as deferred.
|
||||
"""
|
||||
if self.tool_search is None:
|
||||
return tools
|
||||
|
||||
# Check if a tool search tool is already present (user passed one manually)
|
||||
has_search_tool = any(
|
||||
t.get("type", "") in TOOL_SEARCH_TOOL_TYPES for t in tools
|
||||
)
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
|
||||
if not has_search_tool:
|
||||
# Map config type to API type identifier
|
||||
type_map = {
|
||||
"regex": "tool_search_tool_regex_20251119",
|
||||
"bm25": "tool_search_tool_bm25_20251119",
|
||||
}
|
||||
tool_type = type_map[self.tool_search.type]
|
||||
# Tool search tool names follow the convention: tool_search_tool_{variant}
|
||||
tool_name = f"tool_search_tool_{self.tool_search.type}"
|
||||
result.append({"type": tool_type, "name": tool_name})
|
||||
|
||||
for tool in tools:
|
||||
# Don't modify tool search tools
|
||||
if tool.get("type", "") in TOOL_SEARCH_TOOL_TYPES:
|
||||
result.append(tool)
|
||||
continue
|
||||
|
||||
# Mark regular tools as deferred if not already set
|
||||
if "defer_loading" not in tool:
|
||||
tool = {**tool, "defer_loading": True}
|
||||
result.append(tool)
|
||||
|
||||
return result
|
||||
|
||||
def _extract_thinking_block(
|
||||
self, content_block: Any
|
||||
) -> ThinkingBlock | dict[str, Any] | None:
|
||||
@@ -512,6 +618,50 @@ class AnthropicCompletion(BaseLLM):
|
||||
return redacted_block
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _convert_image_blocks(content: Any) -> Any:
|
||||
"""Convert OpenAI-style image_url blocks to Anthropic image blocks.
|
||||
|
||||
Upstream code (e.g. StepExecutor) uses the standard ``image_url``
|
||||
format with a ``data:`` URI. Anthropic rejects that — it requires
|
||||
``{"type": "image", "source": {"type": "base64", ...}}``.
|
||||
|
||||
Non-list content and blocks that are not ``image_url`` are passed
|
||||
through unchanged.
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return content
|
||||
|
||||
converted: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict) or block.get("type") != "image_url":
|
||||
converted.append(block)
|
||||
continue
|
||||
|
||||
image_info = block.get("image_url", {})
|
||||
url = image_info.get("url", "") if isinstance(image_info, dict) else ""
|
||||
if url.startswith("data:") and ";base64," in url:
|
||||
# Parse data:<media_type>;base64,<data>
|
||||
header, b64_data = url.split(";base64,", 1)
|
||||
media_type = (
|
||||
header.split("data:", 1)[1] if "data:" in header else "image/png"
|
||||
)
|
||||
converted.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": b64_data,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Non-data URI — pass through as-is (Anthropic supports url source)
|
||||
converted.append(block)
|
||||
|
||||
return converted
|
||||
|
||||
def _format_messages_for_anthropic(
|
||||
self, messages: str | list[LLMMessage]
|
||||
) -> tuple[list[LLMMessage], str | None]:
|
||||
@@ -550,10 +700,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
tool_call_id = message.get("tool_call_id", "")
|
||||
if not tool_call_id:
|
||||
raise ValueError("Tool message missing required tool_call_id")
|
||||
tool_content = self._convert_image_blocks(content) if content else ""
|
||||
tool_result = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_call_id,
|
||||
"content": content if content else "",
|
||||
"content": tool_content,
|
||||
}
|
||||
pending_tool_results.append(tool_result)
|
||||
elif role == "assistant":
|
||||
@@ -612,7 +763,12 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
role_str = role if role is not None else "user"
|
||||
if isinstance(content, list):
|
||||
formatted_messages.append({"role": role_str, "content": content})
|
||||
formatted_messages.append(
|
||||
{
|
||||
"role": role_str,
|
||||
"content": self._convert_image_blocks(content),
|
||||
}
|
||||
)
|
||||
else:
|
||||
content_str = content if content is not None else ""
|
||||
formatted_messages.append(
|
||||
|
||||
@@ -1781,6 +1781,7 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
converse_messages: list[LLMMessage] = []
|
||||
system_message: str | None = None
|
||||
pending_tool_results: list[dict[str, Any]] = []
|
||||
|
||||
for message in formatted_messages:
|
||||
role = message.get("role")
|
||||
@@ -1794,56 +1795,62 @@ class BedrockCompletion(BaseLLM):
|
||||
system_message += f"\n\n{content}"
|
||||
else:
|
||||
system_message = cast(str, content)
|
||||
elif role == "assistant" and tool_calls:
|
||||
# Convert OpenAI-style tool_calls to Bedrock toolUse format
|
||||
bedrock_content = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
tool_use_block = {
|
||||
"toolUse": {
|
||||
"toolUseId": tc.get("id", f"call_{id(tc)}"),
|
||||
"name": func.get("name", ""),
|
||||
"input": func.get("arguments", {})
|
||||
if isinstance(func.get("arguments"), dict)
|
||||
else json.loads(func.get("arguments", "{}") or "{}"),
|
||||
}
|
||||
}
|
||||
bedrock_content.append(tool_use_block)
|
||||
converse_messages.append(
|
||||
{"role": "assistant", "content": bedrock_content}
|
||||
)
|
||||
elif role == "tool":
|
||||
if not tool_call_id:
|
||||
raise ValueError("Tool message missing required tool_call_id")
|
||||
converse_messages.append(
|
||||
pending_tool_results.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": tool_call_id,
|
||||
"content": [
|
||||
{"text": str(content) if content else ""}
|
||||
],
|
||||
}
|
||||
}
|
||||
],
|
||||
"toolResult": {
|
||||
"toolUseId": tool_call_id,
|
||||
"content": [{"text": str(content) if content else ""}],
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Convert to Converse API format with proper content structure
|
||||
if isinstance(content, list):
|
||||
# Already formatted as multimodal content blocks
|
||||
converse_messages.append({"role": role, "content": content})
|
||||
else:
|
||||
# String content - wrap in text block
|
||||
text_content = content if content else ""
|
||||
if pending_tool_results:
|
||||
converse_messages.append(
|
||||
{"role": role, "content": [{"text": text_content}]}
|
||||
{"role": "user", "content": pending_tool_results}
|
||||
)
|
||||
pending_tool_results = []
|
||||
|
||||
if role == "assistant" and tool_calls:
|
||||
# Convert OpenAI-style tool_calls to Bedrock toolUse format
|
||||
bedrock_content = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
tool_use_block = {
|
||||
"toolUse": {
|
||||
"toolUseId": tc.get("id", f"call_{id(tc)}"),
|
||||
"name": func.get("name", ""),
|
||||
"input": func.get("arguments", {})
|
||||
if isinstance(func.get("arguments"), dict)
|
||||
else json.loads(func.get("arguments", "{}") or "{}"),
|
||||
}
|
||||
}
|
||||
bedrock_content.append(tool_use_block)
|
||||
converse_messages.append(
|
||||
{"role": "assistant", "content": bedrock_content}
|
||||
)
|
||||
else:
|
||||
# Convert to Converse API format with proper content structure
|
||||
if isinstance(content, list):
|
||||
# Already formatted as multimodal content blocks
|
||||
converse_messages.append({"role": role, "content": content})
|
||||
else:
|
||||
# String content - wrap in text block
|
||||
text_content = content if content else ""
|
||||
converse_messages.append(
|
||||
{"role": role, "content": [{"text": text_content}]}
|
||||
)
|
||||
|
||||
if pending_tool_results:
|
||||
converse_messages.append({"role": "user", "content": pending_tool_results})
|
||||
|
||||
# CRITICAL: Handle model-specific conversation requirements
|
||||
# Cohere and some other models require conversation to end with user message
|
||||
# Cohere and some other models require conversation to end with user message.
|
||||
# Anthropic models on Bedrock also reject assistant messages in the final
|
||||
# position when tools are present ("pre-filling the assistant response is
|
||||
# not supported").
|
||||
if converse_messages:
|
||||
last_message = converse_messages[-1]
|
||||
if last_message["role"] == "assistant":
|
||||
@@ -1870,6 +1877,20 @@ class BedrockCompletion(BaseLLM):
|
||||
"content": [{"text": "Continue your response."}],
|
||||
}
|
||||
)
|
||||
# Anthropic (Claude) models reject assistant-last messages when
|
||||
# tools are in the request. Append a user message so the
|
||||
# Converse API accepts the payload.
|
||||
elif "anthropic" in self.model.lower() or "claude" in self.model.lower():
|
||||
converse_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"text": "Please continue and provide your final answer."
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# Ensure first message is from user (required by Converse API)
|
||||
if not converse_messages:
|
||||
|
||||
@@ -11,6 +11,7 @@ into a standalone MCPToolResolver. It handles three flavours of MCP reference:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Final, cast
|
||||
from urllib.parse import urlparse
|
||||
@@ -25,6 +26,7 @@ from crewai.mcp.config import (
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -74,10 +76,9 @@ class MCPToolResolver:
|
||||
elif isinstance(mcp_config, str):
|
||||
amp_refs.append(self._parse_amp_ref(mcp_config))
|
||||
else:
|
||||
tools, client = self._resolve_native(mcp_config)
|
||||
tools, clients = self._resolve_native(mcp_config)
|
||||
all_tools.extend(tools)
|
||||
if client:
|
||||
self._clients.append(client)
|
||||
self._clients.extend(clients)
|
||||
|
||||
if amp_refs:
|
||||
tools, clients = self._resolve_amp(amp_refs)
|
||||
@@ -131,7 +132,7 @@ class MCPToolResolver:
|
||||
all_tools: list[BaseTool] = []
|
||||
all_clients: list[Any] = []
|
||||
|
||||
resolved_cache: dict[str, tuple[list[BaseTool], Any | None]] = {}
|
||||
resolved_cache: dict[str, tuple[list[BaseTool], list[Any]]] = {}
|
||||
|
||||
for slug in unique_slugs:
|
||||
config_dict = amp_configs_map.get(slug)
|
||||
@@ -149,10 +150,9 @@ class MCPToolResolver:
|
||||
mcp_server_config = self._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
try:
|
||||
tools, client = self._resolve_native(mcp_server_config)
|
||||
resolved_cache[slug] = (tools, client)
|
||||
if client:
|
||||
all_clients.append(client)
|
||||
tools, clients = self._resolve_native(mcp_server_config)
|
||||
resolved_cache[slug] = (tools, clients)
|
||||
all_clients.extend(clients)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -170,8 +170,9 @@ class MCPToolResolver:
|
||||
|
||||
slug_tools, _ = cached
|
||||
if specific_tool:
|
||||
sanitized = sanitize_tool_name(specific_tool)
|
||||
all_tools.extend(
|
||||
t for t in slug_tools if t.name.endswith(f"_{specific_tool}")
|
||||
t for t in slug_tools if t.name.endswith(f"_{sanitized}")
|
||||
)
|
||||
else:
|
||||
all_tools.extend(slug_tools)
|
||||
@@ -198,7 +199,6 @@ class MCPToolResolver:
|
||||
|
||||
plus_api = PlusAPI(api_key=get_platform_integration_token())
|
||||
response = plus_api.get_mcp_configs(slugs)
|
||||
|
||||
if response.status_code == 200:
|
||||
configs: dict[str, dict[str, Any]] = response.json().get("configs", {})
|
||||
return configs
|
||||
@@ -218,6 +218,7 @@ class MCPToolResolver:
|
||||
|
||||
def _resolve_external(self, mcp_ref: str) -> list[BaseTool]:
|
||||
"""Resolve an HTTPS MCP server URL into tools."""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||
|
||||
if "#" in mcp_ref:
|
||||
@@ -227,6 +228,9 @@ class MCPToolResolver:
|
||||
|
||||
server_params = {"url": server_url}
|
||||
server_name = self._extract_server_name(server_url)
|
||||
sanitized_specific_tool = (
|
||||
sanitize_tool_name(specific_tool) if specific_tool else None
|
||||
)
|
||||
|
||||
try:
|
||||
tool_schemas = self._get_mcp_tool_schemas(server_params)
|
||||
@@ -239,7 +243,7 @@ class MCPToolResolver:
|
||||
|
||||
tools = []
|
||||
for tool_name, schema in tool_schemas.items():
|
||||
if specific_tool and tool_name != specific_tool:
|
||||
if sanitized_specific_tool and tool_name != sanitized_specific_tool:
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -271,14 +275,16 @@ class MCPToolResolver:
|
||||
)
|
||||
return []
|
||||
|
||||
def _resolve_native(
|
||||
self, mcp_config: MCPServerConfig
|
||||
) -> tuple[list[BaseTool], Any | None]:
|
||||
"""Resolve an ``MCPServerConfig`` into tools, returning the client for cleanup."""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
@staticmethod
|
||||
def _create_transport(
|
||||
mcp_config: MCPServerConfig,
|
||||
) -> tuple[StdioTransport | HTTPTransport | SSETransport, str]:
|
||||
"""Create a fresh transport instance from an MCP server config.
|
||||
|
||||
transport: StdioTransport | HTTPTransport | SSETransport
|
||||
Returns a ``(transport, server_name)`` tuple. Each call produces an
|
||||
independent transport so that parallel tool executions never share
|
||||
state.
|
||||
"""
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
@@ -292,38 +298,54 @@ class MCPToolResolver:
|
||||
headers=mcp_config.headers,
|
||||
streamable=mcp_config.streamable,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
server_name = MCPToolResolver._extract_server_name(mcp_config.url)
|
||||
elif isinstance(mcp_config, MCPServerSSE):
|
||||
transport = SSETransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
server_name = MCPToolResolver._extract_server_name(mcp_config.url)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")
|
||||
return transport, server_name
|
||||
|
||||
client = MCPClient(
|
||||
transport=transport,
|
||||
def _resolve_native(
|
||||
self, mcp_config: MCPServerConfig
|
||||
) -> tuple[list[BaseTool], list[Any]]:
|
||||
"""Resolve an ``MCPServerConfig`` into tools.
|
||||
|
||||
Returns ``(tools, clients)`` where *clients* is always empty for
|
||||
native tools (clients are now created on-demand per invocation).
|
||||
A ``client_factory`` closure is passed to each ``MCPNativeTool`` so
|
||||
every call -- even concurrent calls to the *same* tool -- gets its
|
||||
own ``MCPClient`` + transport with no shared mutable state.
|
||||
"""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
|
||||
discovery_transport, server_name = self._create_transport(mcp_config)
|
||||
discovery_client = MCPClient(
|
||||
transport=discovery_transport,
|
||||
cache_tools_list=mcp_config.cache_tools_list,
|
||||
)
|
||||
|
||||
async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
|
||||
try:
|
||||
if not client.connected:
|
||||
await client.connect()
|
||||
if not discovery_client.connected:
|
||||
await discovery_client.connect()
|
||||
|
||||
tools_list = await client.list_tools()
|
||||
tools_list = await discovery_client.list_tools()
|
||||
|
||||
try:
|
||||
await client.disconnect()
|
||||
await discovery_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during disconnect: {e}")
|
||||
|
||||
return tools_list
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
await client.disconnect()
|
||||
if discovery_client.connected:
|
||||
await discovery_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError(
|
||||
f"Error during setup client and list tools: {e}"
|
||||
@@ -334,9 +356,10 @@ class MCPToolResolver:
|
||||
asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run, _setup_client_and_list_tools()
|
||||
ctx.run, asyncio.run, _setup_client_and_list_tools()
|
||||
)
|
||||
tools_list = future.result()
|
||||
except RuntimeError:
|
||||
@@ -376,6 +399,13 @@ class MCPToolResolver:
|
||||
filtered_tools.append(tool)
|
||||
tools_list = filtered_tools
|
||||
|
||||
def _client_factory() -> MCPClient:
|
||||
transport, _ = self._create_transport(mcp_config)
|
||||
return MCPClient(
|
||||
transport=transport,
|
||||
cache_tools_list=mcp_config.cache_tools_list,
|
||||
)
|
||||
|
||||
tools = []
|
||||
for tool_def in tools_list:
|
||||
tool_name = tool_def.get("name", "")
|
||||
@@ -396,7 +426,7 @@ class MCPToolResolver:
|
||||
|
||||
try:
|
||||
native_tool = MCPNativeTool(
|
||||
mcp_client=client,
|
||||
client_factory=_client_factory,
|
||||
tool_name=tool_name,
|
||||
tool_schema=tool_schema,
|
||||
server_name=server_name,
|
||||
@@ -407,10 +437,10 @@ class MCPToolResolver:
|
||||
self._logger.log("error", f"Failed to create native MCP tool: {e}")
|
||||
continue
|
||||
|
||||
return cast(list[BaseTool], tools), client
|
||||
return cast(list[BaseTool], tools), []
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
asyncio.run(client.disconnect())
|
||||
if discovery_client.connected:
|
||||
asyncio.run(discovery_client.disconnect())
|
||||
|
||||
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||
|
||||
|
||||
@@ -308,7 +308,9 @@ def analyze_for_save(
|
||||
return MemoryAnalysis.model_validate(response)
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
"Memory save analysis failed, using defaults: %s", e, exc_info=False,
|
||||
"Memory save analysis failed, using defaults: %s",
|
||||
e,
|
||||
exc_info=False,
|
||||
)
|
||||
return _SAVE_DEFAULTS
|
||||
|
||||
@@ -366,6 +368,8 @@ def analyze_for_consolidation(
|
||||
return ConsolidationPlan.model_validate(response)
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
"Consolidation analysis failed, defaulting to insert: %s", e, exc_info=False,
|
||||
"Consolidation analysis failed, defaulting to insert: %s",
|
||||
e,
|
||||
exc_info=False,
|
||||
)
|
||||
return _CONSOLIDATION_DEFAULT
|
||||
|
||||
@@ -11,7 +11,9 @@ Orchestrates the encoding side of memory in a single Flow with 5 steps:
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import math
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
@@ -28,6 +30,8 @@ from crewai.memory.analyze import (
|
||||
from crewai.memory.types import MemoryConfig, MemoryRecord, embed_texts
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# State models
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -164,14 +168,20 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
def parallel_find_similar(self) -> None:
|
||||
"""Search storage for similar records, concurrently for all active items."""
|
||||
items = list(self.state.items)
|
||||
active = [(i, item) for i, item in enumerate(items) if not item.dropped and item.embedding]
|
||||
active = [
|
||||
(i, item)
|
||||
for i, item in enumerate(items)
|
||||
if not item.dropped and item.embedding
|
||||
]
|
||||
|
||||
if not active:
|
||||
return
|
||||
|
||||
def _search_one(item: ItemState) -> list[tuple[MemoryRecord, float]]:
|
||||
def _search_one(
|
||||
item: ItemState,
|
||||
) -> list[tuple[MemoryRecord, float]]:
|
||||
scope_prefix = item.scope if item.scope and item.scope.strip("/") else None
|
||||
return self._storage.search(
|
||||
return self._storage.search( # type: ignore[no-any-return]
|
||||
item.embedding,
|
||||
scope_prefix=scope_prefix,
|
||||
categories=None,
|
||||
@@ -181,14 +191,37 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
|
||||
if len(active) == 1:
|
||||
_, item = active[0]
|
||||
raw = _search_one(item)
|
||||
try:
|
||||
raw = _search_one(item)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Storage search failed in parallel_find_similar, "
|
||||
"treating item as new",
|
||||
exc_info=True,
|
||||
)
|
||||
raw = []
|
||||
item.similar_records = [r for r, _ in raw]
|
||||
item.top_similarity = float(raw[0][1]) if raw else 0.0
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=min(len(active), 8)) as pool:
|
||||
futures = [(i, item, pool.submit(_search_one, item)) for i, item in active]
|
||||
futures = [
|
||||
(
|
||||
i,
|
||||
item,
|
||||
pool.submit(contextvars.copy_context().run, _search_one, item),
|
||||
)
|
||||
for i, item in active
|
||||
]
|
||||
for _, item, future in futures:
|
||||
raw = future.result()
|
||||
try:
|
||||
raw = future.result()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Storage search failed in parallel_find_similar, "
|
||||
"treating item as new",
|
||||
exc_info=True,
|
||||
)
|
||||
raw = []
|
||||
item.similar_records = [r for r, _ in raw]
|
||||
item.top_similarity = float(raw[0][1]) if raw else 0.0
|
||||
|
||||
@@ -250,24 +283,38 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
# Group B: consolidation only
|
||||
self._apply_defaults(item)
|
||||
consol_futures[i] = pool.submit(
|
||||
contextvars.copy_context().run,
|
||||
analyze_for_consolidation,
|
||||
item.content, list(item.similar_records), self._llm,
|
||||
item.content,
|
||||
list(item.similar_records),
|
||||
self._llm,
|
||||
)
|
||||
elif not fields_provided and not has_similar:
|
||||
# Group C: field resolution only
|
||||
save_futures[i] = pool.submit(
|
||||
contextvars.copy_context().run,
|
||||
analyze_for_save,
|
||||
item.content, existing_scopes, existing_categories, self._llm,
|
||||
item.content,
|
||||
existing_scopes,
|
||||
existing_categories,
|
||||
self._llm,
|
||||
)
|
||||
else:
|
||||
# Group D: both in parallel
|
||||
save_futures[i] = pool.submit(
|
||||
contextvars.copy_context().run,
|
||||
analyze_for_save,
|
||||
item.content, existing_scopes, existing_categories, self._llm,
|
||||
item.content,
|
||||
existing_scopes,
|
||||
existing_categories,
|
||||
self._llm,
|
||||
)
|
||||
consol_futures[i] = pool.submit(
|
||||
contextvars.copy_context().run,
|
||||
analyze_for_consolidation,
|
||||
item.content, list(item.similar_records), self._llm,
|
||||
item.content,
|
||||
list(item.similar_records),
|
||||
self._llm,
|
||||
)
|
||||
|
||||
# Collect field-resolution results
|
||||
@@ -300,8 +347,8 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
item.plan = ConsolidationPlan(actions=[], insert_new=True)
|
||||
|
||||
# Collect consolidation results
|
||||
for i, future in consol_futures.items():
|
||||
items[i].plan = future.result()
|
||||
for i, consol_future in consol_futures.items():
|
||||
items[i].plan = consol_future.result()
|
||||
finally:
|
||||
pool.shutdown(wait=False)
|
||||
|
||||
@@ -339,7 +386,9 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
# similar_records overlap). Collect one action per record_id, first wins.
|
||||
# Also build a map from record_id to the original MemoryRecord for updates.
|
||||
dedup_deletes: set[str] = set() # record_ids to delete
|
||||
dedup_updates: dict[str, tuple[int, str]] = {} # record_id -> (item_idx, new_content)
|
||||
dedup_updates: dict[
|
||||
str, tuple[int, str]
|
||||
] = {} # record_id -> (item_idx, new_content)
|
||||
all_similar: dict[str, MemoryRecord] = {} # record_id -> MemoryRecord
|
||||
|
||||
for i, item in enumerate(items):
|
||||
@@ -350,13 +399,24 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
all_similar[r.id] = r
|
||||
for action in item.plan.actions:
|
||||
rid = action.record_id
|
||||
if action.action == "delete" and rid not in dedup_deletes and rid not in dedup_updates:
|
||||
if (
|
||||
action.action == "delete"
|
||||
and rid not in dedup_deletes
|
||||
and rid not in dedup_updates
|
||||
):
|
||||
dedup_deletes.add(rid)
|
||||
elif action.action == "update" and action.new_content and rid not in dedup_deletes and rid not in dedup_updates:
|
||||
elif (
|
||||
action.action == "update"
|
||||
and action.new_content
|
||||
and rid not in dedup_deletes
|
||||
and rid not in dedup_updates
|
||||
):
|
||||
dedup_updates[rid] = (i, action.new_content)
|
||||
|
||||
# --- Batch re-embed all update contents in ONE call ---
|
||||
update_list = list(dedup_updates.items()) # [(record_id, (item_idx, new_content)), ...]
|
||||
update_list = list(
|
||||
dedup_updates.items()
|
||||
) # [(record_id, (item_idx, new_content)), ...]
|
||||
update_embeddings: list[list[float]] = []
|
||||
if update_list:
|
||||
update_contents = [content for _, (_, content) in update_list]
|
||||
@@ -377,51 +437,52 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
if item.dropped or item.plan is None:
|
||||
continue
|
||||
if item.plan.insert_new:
|
||||
to_insert.append((i, MemoryRecord(
|
||||
content=item.content,
|
||||
scope=item.resolved_scope,
|
||||
categories=item.resolved_categories,
|
||||
metadata=item.resolved_metadata,
|
||||
importance=item.resolved_importance,
|
||||
embedding=item.embedding if item.embedding else None,
|
||||
source=item.resolved_source,
|
||||
private=item.resolved_private,
|
||||
)))
|
||||
|
||||
# All storage mutations under one lock so no other pipeline can
|
||||
# interleave and cause version conflicts. The lock is reentrant
|
||||
# (RLock) so the individual storage methods re-acquire it safely.
|
||||
updated_records: dict[str, MemoryRecord] = {}
|
||||
with self._storage.write_lock:
|
||||
if dedup_deletes:
|
||||
self._storage.delete(record_ids=list(dedup_deletes))
|
||||
self.state.records_deleted += len(dedup_deletes)
|
||||
|
||||
for rid, (_item_idx, new_content) in dedup_updates.items():
|
||||
existing = all_similar.get(rid)
|
||||
if existing is not None:
|
||||
new_emb = update_emb_map.get(rid, [])
|
||||
updated = MemoryRecord(
|
||||
id=existing.id,
|
||||
content=new_content,
|
||||
scope=existing.scope,
|
||||
categories=existing.categories,
|
||||
metadata=existing.metadata,
|
||||
importance=existing.importance,
|
||||
created_at=existing.created_at,
|
||||
last_accessed=now,
|
||||
embedding=new_emb if new_emb else existing.embedding,
|
||||
to_insert.append(
|
||||
(
|
||||
i,
|
||||
MemoryRecord(
|
||||
content=item.content,
|
||||
scope=item.resolved_scope,
|
||||
categories=item.resolved_categories,
|
||||
metadata=item.resolved_metadata,
|
||||
importance=item.resolved_importance,
|
||||
embedding=item.embedding if item.embedding else None,
|
||||
source=item.resolved_source,
|
||||
private=item.resolved_private,
|
||||
),
|
||||
)
|
||||
self._storage.update(updated)
|
||||
self.state.records_updated += 1
|
||||
updated_records[rid] = updated
|
||||
)
|
||||
|
||||
if to_insert:
|
||||
records = [r for _, r in to_insert]
|
||||
self._storage.save(records)
|
||||
self.state.records_inserted += len(records)
|
||||
for idx, record in to_insert:
|
||||
items[idx].result_record = record
|
||||
updated_records: dict[str, MemoryRecord] = {}
|
||||
if dedup_deletes:
|
||||
self._storage.delete(record_ids=list(dedup_deletes))
|
||||
self.state.records_deleted += len(dedup_deletes)
|
||||
|
||||
for rid, (_item_idx, new_content) in dedup_updates.items():
|
||||
existing = all_similar.get(rid)
|
||||
if existing is not None:
|
||||
new_emb = update_emb_map.get(rid, [])
|
||||
updated = MemoryRecord(
|
||||
id=existing.id,
|
||||
content=new_content,
|
||||
scope=existing.scope,
|
||||
categories=existing.categories,
|
||||
metadata=existing.metadata,
|
||||
importance=existing.importance,
|
||||
created_at=existing.created_at,
|
||||
last_accessed=now,
|
||||
embedding=new_emb if new_emb else existing.embedding,
|
||||
)
|
||||
self._storage.update(updated)
|
||||
self.state.records_updated += 1
|
||||
updated_records[rid] = updated
|
||||
|
||||
if to_insert:
|
||||
records = [r for _, r in to_insert]
|
||||
self._storage.save(records)
|
||||
self.state.records_inserted += len(records)
|
||||
for idx, record in to_insert:
|
||||
items[idx].result_record = record
|
||||
|
||||
# Set result_record for non-insert items (after lock, using updated_records)
|
||||
for _i, item in enumerate(items):
|
||||
|
||||
@@ -3,11 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
|
||||
from crewai.memory.types import (
|
||||
_RECALL_OVERSAMPLE_FACTOR,
|
||||
@@ -15,22 +13,38 @@ from crewai.memory.types import (
|
||||
MemoryRecord,
|
||||
ScopeInfo,
|
||||
)
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
|
||||
class MemoryScope:
|
||||
class MemoryScope(BaseModel):
|
||||
"""View of Memory restricted to a root path. All operations are scoped under that path."""
|
||||
|
||||
def __init__(self, memory: Memory, root_path: str) -> None:
|
||||
"""Initialize scope.
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
Args:
|
||||
memory: The underlying Memory instance.
|
||||
root_path: Root path for this scope (e.g. /agent/1).
|
||||
"""
|
||||
self._memory = memory
|
||||
self._root = root_path.rstrip("/") or ""
|
||||
if self._root and not self._root.startswith("/"):
|
||||
self._root = "/" + self._root
|
||||
root_path: str = Field(default="/")
|
||||
|
||||
_memory: Memory = PrivateAttr()
|
||||
_root: str = PrivateAttr()
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
def _accept_memory(cls, data: Any, handler: Any) -> MemoryScope:
|
||||
"""Extract memory dependency and normalize root path before validation."""
|
||||
if isinstance(data, MemoryScope):
|
||||
return data
|
||||
memory = data.pop("memory")
|
||||
instance: MemoryScope = handler(data)
|
||||
instance._memory = memory
|
||||
root = instance.root_path.rstrip("/") or ""
|
||||
if root and not root.startswith("/"):
|
||||
root = "/" + root
|
||||
instance._root = root
|
||||
return instance
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
"""Whether the underlying memory is read-only."""
|
||||
return self._memory.read_only
|
||||
|
||||
def _scope_path(self, scope: str | None) -> str:
|
||||
if not scope or scope == "/":
|
||||
@@ -52,7 +66,7 @@ class MemoryScope:
|
||||
importance: float | None = None,
|
||||
source: str | None = None,
|
||||
private: bool = False,
|
||||
) -> MemoryRecord:
|
||||
) -> MemoryRecord | None:
|
||||
"""Remember content; scope is relative to this scope's root."""
|
||||
path = self._scope_path(scope)
|
||||
return self._memory.remember(
|
||||
@@ -71,7 +85,7 @@ class MemoryScope:
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
depth: str = "deep",
|
||||
depth: Literal["shallow", "deep"] = "deep",
|
||||
source: str | None = None,
|
||||
include_private: bool = False,
|
||||
) -> list[MemoryMatch]:
|
||||
@@ -138,34 +152,34 @@ class MemoryScope:
|
||||
"""Return a narrower scope under this scope."""
|
||||
child = path.strip("/")
|
||||
if not child:
|
||||
return MemoryScope(self._memory, self._root or "/")
|
||||
return MemoryScope(memory=self._memory, root_path=self._root or "/")
|
||||
base = self._root.rstrip("/") or ""
|
||||
new_root = f"{base}/{child}" if base else f"/{child}"
|
||||
return MemoryScope(self._memory, new_root)
|
||||
return MemoryScope(memory=self._memory, root_path=new_root)
|
||||
|
||||
|
||||
class MemorySlice:
|
||||
class MemorySlice(BaseModel):
|
||||
"""View over multiple scopes: recall searches all, remember is a no-op when read_only."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memory: Memory,
|
||||
scopes: list[str],
|
||||
categories: list[str] | None = None,
|
||||
read_only: bool = True,
|
||||
) -> None:
|
||||
"""Initialize slice.
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
Args:
|
||||
memory: The underlying Memory instance.
|
||||
scopes: List of scope paths to include.
|
||||
categories: Optional category filter for recall.
|
||||
read_only: If True, remember() is a silent no-op.
|
||||
"""
|
||||
self._memory = memory
|
||||
self._scopes = [s.rstrip("/") or "/" for s in scopes]
|
||||
self._categories = categories
|
||||
self._read_only = read_only
|
||||
scopes: list[str] = Field(default_factory=list)
|
||||
categories: list[str] | None = Field(default=None)
|
||||
read_only: bool = Field(default=True)
|
||||
|
||||
_memory: Memory = PrivateAttr()
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
def _accept_memory(cls, data: Any, handler: Any) -> MemorySlice:
|
||||
"""Extract memory dependency and normalize scopes before validation."""
|
||||
if isinstance(data, MemorySlice):
|
||||
return data
|
||||
memory = data.pop("memory")
|
||||
data["scopes"] = [s.rstrip("/") or "/" for s in data.get("scopes", [])]
|
||||
instance: MemorySlice = handler(data)
|
||||
instance._memory = memory
|
||||
return instance
|
||||
|
||||
def remember(
|
||||
self,
|
||||
@@ -178,7 +192,7 @@ class MemorySlice:
|
||||
private: bool = False,
|
||||
) -> MemoryRecord | None:
|
||||
"""Remember into an explicit scope. No-op when read_only=True."""
|
||||
if self._read_only:
|
||||
if self.read_only:
|
||||
return None
|
||||
return self._memory.remember(
|
||||
content,
|
||||
@@ -196,14 +210,14 @@ class MemorySlice:
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
depth: str = "deep",
|
||||
depth: Literal["shallow", "deep"] = "deep",
|
||||
source: str | None = None,
|
||||
include_private: bool = False,
|
||||
) -> list[MemoryMatch]:
|
||||
"""Recall across all slice scopes; results merged and re-ranked."""
|
||||
cats = categories or self._categories
|
||||
cats = categories or self.categories
|
||||
all_matches: list[MemoryMatch] = []
|
||||
for sc in self._scopes:
|
||||
for sc in self.scopes:
|
||||
matches = self._memory.recall(
|
||||
query,
|
||||
scope=sc,
|
||||
@@ -231,7 +245,7 @@ class MemorySlice:
|
||||
def list_scopes(self, path: str = "/") -> list[str]:
|
||||
"""List scopes across all slice roots."""
|
||||
out: list[str] = []
|
||||
for sc in self._scopes:
|
||||
for sc in self.scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
out.extend(self._memory.list_scopes(full))
|
||||
return sorted(set(out))
|
||||
@@ -243,15 +257,23 @@ class MemorySlice:
|
||||
oldest: datetime | None = None
|
||||
newest: datetime | None = None
|
||||
children: list[str] = []
|
||||
for sc in self._scopes:
|
||||
for sc in self.scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
inf = self._memory.info(full)
|
||||
total_records += inf.record_count
|
||||
all_categories.update(inf.categories)
|
||||
if inf.oldest_record:
|
||||
oldest = inf.oldest_record if oldest is None else min(oldest, inf.oldest_record)
|
||||
oldest = (
|
||||
inf.oldest_record
|
||||
if oldest is None
|
||||
else min(oldest, inf.oldest_record)
|
||||
)
|
||||
if inf.newest_record:
|
||||
newest = inf.newest_record if newest is None else max(newest, inf.newest_record)
|
||||
newest = (
|
||||
inf.newest_record
|
||||
if newest is None
|
||||
else max(newest, inf.newest_record)
|
||||
)
|
||||
children.extend(inf.child_scopes)
|
||||
return ScopeInfo(
|
||||
path=path,
|
||||
@@ -265,7 +287,7 @@ class MemorySlice:
|
||||
def list_categories(self, path: str | None = None) -> dict[str, int]:
|
||||
"""Categories and counts across slice scopes."""
|
||||
counts: dict[str, int] = {}
|
||||
for sc in self._scopes:
|
||||
for sc in self.scopes:
|
||||
full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc
|
||||
for k, v in self._memory.list_categories(full).items():
|
||||
counts[k] = counts.get(k, 0) + v
|
||||
|
||||
@@ -11,7 +11,9 @@ Implements adaptive-depth retrieval with:
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -29,6 +31,9 @@ from crewai.memory.types import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RecallState(BaseModel):
|
||||
"""State for the recall flow."""
|
||||
|
||||
@@ -103,13 +108,12 @@ class RecallFlow(Flow[RecallState]):
|
||||
)
|
||||
# Post-filter by time cutoff
|
||||
if self.state.time_cutoff and raw:
|
||||
raw = [
|
||||
(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff
|
||||
]
|
||||
raw = [(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff]
|
||||
# Privacy filter
|
||||
if not self.state.include_private and raw:
|
||||
raw = [
|
||||
(r, s) for r, s in raw
|
||||
(r, s)
|
||||
for r, s in raw
|
||||
if not r.private or r.source == self.state.source
|
||||
]
|
||||
return scope, raw
|
||||
@@ -125,38 +129,57 @@ class RecallFlow(Flow[RecallState]):
|
||||
|
||||
if len(tasks) <= 1:
|
||||
for emb, sc in tasks:
|
||||
scope, results = _search_one(emb, sc)
|
||||
try:
|
||||
scope, results = _search_one(emb, sc)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Storage search failed in recall flow, skipping scope",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
if results:
|
||||
top_composite, _ = compute_composite_score(
|
||||
results[0][0], results[0][1], self._config
|
||||
)
|
||||
findings.append({
|
||||
"scope": scope,
|
||||
"results": results,
|
||||
"top_score": top_composite,
|
||||
})
|
||||
findings.append(
|
||||
{
|
||||
"scope": scope,
|
||||
"results": results,
|
||||
"top_score": top_composite,
|
||||
}
|
||||
)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=min(len(tasks), 4)) as pool:
|
||||
futures = {
|
||||
pool.submit(_search_one, emb, sc): (emb, sc)
|
||||
pool.submit(contextvars.copy_context().run, _search_one, emb, sc): (
|
||||
emb,
|
||||
sc,
|
||||
)
|
||||
for emb, sc in tasks
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
scope, results = future.result()
|
||||
try:
|
||||
scope, results = future.result()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Storage search failed in recall flow, skipping scope",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
if results:
|
||||
top_composite, _ = compute_composite_score(
|
||||
results[0][0], results[0][1], self._config
|
||||
)
|
||||
findings.append({
|
||||
"scope": scope,
|
||||
"results": results,
|
||||
"top_score": top_composite,
|
||||
})
|
||||
findings.append(
|
||||
{
|
||||
"scope": scope,
|
||||
"results": results,
|
||||
"top_score": top_composite,
|
||||
}
|
||||
)
|
||||
|
||||
self.state.chunk_findings = findings
|
||||
self.state.confidence = max(
|
||||
(f["top_score"] for f in findings), default=0.0
|
||||
)
|
||||
self.state.confidence = max((f["top_score"] for f in findings), default=0.0)
|
||||
return findings
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -210,12 +233,16 @@ class RecallFlow(Flow[RecallState]):
|
||||
# Parse time_filter into a datetime cutoff
|
||||
if analysis.time_filter:
|
||||
try:
|
||||
self.state.time_cutoff = datetime.fromisoformat(analysis.time_filter)
|
||||
self.state.time_cutoff = datetime.fromisoformat(
|
||||
analysis.time_filter
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Batch-embed all sub-queries in ONE call
|
||||
queries = analysis.recall_queries if analysis.recall_queries else [self.state.query]
|
||||
queries = (
|
||||
analysis.recall_queries if analysis.recall_queries else [self.state.query]
|
||||
)
|
||||
queries = queries[:3]
|
||||
embeddings = embed_texts(self._embedder, queries)
|
||||
pairs: list[tuple[str, list[float]]] = [
|
||||
@@ -237,13 +264,17 @@ class RecallFlow(Flow[RecallState]):
|
||||
if analysis and analysis.suggested_scopes:
|
||||
candidates = [s for s in analysis.suggested_scopes if s]
|
||||
else:
|
||||
candidates = self._storage.list_scopes(scope_prefix)
|
||||
try:
|
||||
candidates = self._storage.list_scopes(scope_prefix)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Storage list_scopes failed in filter_and_chunk, "
|
||||
"falling back to scope prefix",
|
||||
exc_info=True,
|
||||
)
|
||||
candidates = []
|
||||
if not candidates:
|
||||
info = self._storage.get_scope_info(scope_prefix)
|
||||
if info.record_count > 0:
|
||||
candidates = [scope_prefix]
|
||||
else:
|
||||
candidates = [scope_prefix]
|
||||
candidates = [scope_prefix]
|
||||
self.state.candidate_scopes = candidates[:20]
|
||||
return self.state.candidate_scopes
|
||||
|
||||
@@ -296,17 +327,21 @@ class RecallFlow(Flow[RecallState]):
|
||||
response = self._llm.call([{"role": "user", "content": prompt}])
|
||||
if isinstance(response, str) and "missing" in response.lower():
|
||||
self.state.evidence_gaps.append(response[:200])
|
||||
enhanced.append({
|
||||
"scope": finding["scope"],
|
||||
"extraction": response,
|
||||
"results": finding["results"],
|
||||
})
|
||||
enhanced.append(
|
||||
{
|
||||
"scope": finding["scope"],
|
||||
"extraction": response,
|
||||
"results": finding["results"],
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
enhanced.append({
|
||||
"scope": finding["scope"],
|
||||
"extraction": "",
|
||||
"results": finding["results"],
|
||||
})
|
||||
enhanced.append(
|
||||
{
|
||||
"scope": finding["scope"],
|
||||
"extraction": "",
|
||||
"results": finding["results"],
|
||||
}
|
||||
)
|
||||
self.state.chunk_findings = enhanced
|
||||
return enhanced
|
||||
|
||||
@@ -318,7 +353,7 @@ class RecallFlow(Flow[RecallState]):
|
||||
@router(re_search)
|
||||
def re_decide_depth(self) -> str:
|
||||
"""Re-evaluate depth after re-search. Same logic as decide_depth."""
|
||||
return self.decide_depth()
|
||||
return self.decide_depth() # type: ignore[call-arg]
|
||||
|
||||
@listen("synthesize")
|
||||
def synthesize_results(self) -> list[MemoryMatch]:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
@@ -8,6 +9,7 @@ from crewai.task import Task
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
|
||||
from crewai.utilities.errors import DatabaseError, DatabaseOperationError
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@@ -24,6 +26,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
|
||||
self.db_path = db_path
|
||||
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
|
||||
self._printer: Printer = Printer()
|
||||
self._initialize_db()
|
||||
|
||||
@@ -38,23 +41,25 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If database initialization fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
expected_output TEXT,
|
||||
output JSON,
|
||||
task_index INTEGER,
|
||||
inputs JSON,
|
||||
was_replayed BOOLEAN,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
expected_output TEXT,
|
||||
output JSON,
|
||||
task_index INTEGER,
|
||||
inputs JSON,
|
||||
was_replayed BOOLEAN,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
@@ -82,25 +87,26 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO latest_kickoff_task_outputs
|
||||
(task_id, expected_output, output, task_index, inputs, was_replayed)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(task.id),
|
||||
task.expected_output,
|
||||
json.dumps(output, cls=CrewJSONEncoder),
|
||||
task_index,
|
||||
json.dumps(inputs, cls=CrewJSONEncoder),
|
||||
was_replayed,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO latest_kickoff_task_outputs
|
||||
(task_id, expected_output, output, task_index, inputs, was_replayed)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(task.id),
|
||||
task.expected_output,
|
||||
json.dumps(output, cls=CrewJSONEncoder),
|
||||
task_index,
|
||||
json.dumps(inputs, cls=CrewJSONEncoder),
|
||||
was_replayed,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
@@ -125,30 +131,31 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If updating the task output fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
|
||||
fields = []
|
||||
values = []
|
||||
for key, value in kwargs.items():
|
||||
fields.append(f"{key} = ?")
|
||||
values.append(
|
||||
json.dumps(value, cls=CrewJSONEncoder)
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
)
|
||||
fields = []
|
||||
values = []
|
||||
for key, value in kwargs.items():
|
||||
fields.append(f"{key} = ?")
|
||||
values.append(
|
||||
json.dumps(value, cls=CrewJSONEncoder)
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
)
|
||||
|
||||
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
|
||||
values.append(task_index)
|
||||
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
|
||||
values.append(task_index)
|
||||
|
||||
cursor.execute(query, tuple(values))
|
||||
conn.commit()
|
||||
cursor.execute(query, tuple(values))
|
||||
conn.commit()
|
||||
|
||||
if cursor.rowcount == 0:
|
||||
logger.warning(
|
||||
f"No row found with task_index {task_index}. No update performed."
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
logger.warning(
|
||||
f"No row found with task_index {task_index}. No update performed."
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
@@ -166,7 +173,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT *
|
||||
@@ -205,11 +212,12 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
|
||||
conn.commit()
|
||||
with store_lock(self._lock_name):
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
@@ -9,11 +10,12 @@ import os
|
||||
from pathlib import Path
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any
|
||||
|
||||
import lancedb
|
||||
import lancedb # type: ignore[import-untyped]
|
||||
|
||||
from crewai.memory.types import MemoryRecord, ScopeInfo
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@@ -39,15 +41,6 @@ _RETRY_BASE_DELAY = 0.2 # seconds; doubles on each retry
|
||||
class LanceDBStorage:
|
||||
"""LanceDB-backed storage for the unified memory system."""
|
||||
|
||||
# Class-level registry: maps resolved database path -> shared write lock.
|
||||
# When multiple Memory instances (e.g. agent + crew) independently create
|
||||
# LanceDBStorage pointing at the same directory, they share one lock so
|
||||
# their writes don't conflict.
|
||||
# Uses RLock (reentrant) so callers can hold the lock for a batch of
|
||||
# operations while the individual methods re-acquire it without deadlocking.
|
||||
_path_locks: ClassVar[dict[str, threading.RLock]] = {}
|
||||
_path_locks_guard: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str | Path | None = None,
|
||||
@@ -83,13 +76,9 @@ class LanceDBStorage:
|
||||
self._table_name = table_name
|
||||
self._db = lancedb.connect(str(self._path))
|
||||
|
||||
# On macOS and Linux the default per-process open-file limit is 256.
|
||||
# A LanceDB table stores one file per fragment (one fragment per save()
|
||||
# call by default). With hundreds of fragments, a single full-table
|
||||
# scan opens all of them simultaneously, exhausting the limit.
|
||||
# Raise it proactively so scans on large tables never hit OS error 24.
|
||||
try:
|
||||
import resource
|
||||
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
if soft < 4096:
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
|
||||
@@ -99,68 +88,46 @@ class LanceDBStorage:
|
||||
self._compact_every = compact_every
|
||||
self._save_count = 0
|
||||
|
||||
# Get or create a shared write lock for this database path.
|
||||
resolved = str(self._path.resolve())
|
||||
with LanceDBStorage._path_locks_guard:
|
||||
if resolved not in LanceDBStorage._path_locks:
|
||||
LanceDBStorage._path_locks[resolved] = threading.RLock()
|
||||
self._write_lock = LanceDBStorage._path_locks[resolved]
|
||||
self._lock_name = f"lancedb:{self._path.resolve()}"
|
||||
|
||||
# Try to open an existing table and infer dimension from its schema.
|
||||
# If no table exists yet, defer creation until the first save so the
|
||||
# dimension can be auto-detected from the embedder's actual output.
|
||||
try:
|
||||
self._table: lancedb.table.Table | None = self._db.open_table(self._table_name)
|
||||
self._table: Any = self._db.open_table(self._table_name)
|
||||
self._vector_dim: int = self._infer_dim_from_table(self._table)
|
||||
# Best-effort: create the scope index if it doesn't exist yet.
|
||||
self._ensure_scope_index()
|
||||
# Compact in the background if the table has accumulated many
|
||||
# fragments from previous runs (each save() creates one).
|
||||
with store_lock(self._lock_name):
|
||||
self._ensure_scope_index()
|
||||
self._compact_if_needed()
|
||||
except Exception:
|
||||
_logger.debug(
|
||||
"Failed to open existing LanceDB table %r", table_name, exc_info=True
|
||||
)
|
||||
self._table = None
|
||||
self._vector_dim = vector_dim or 0 # 0 = not yet known
|
||||
|
||||
# Explicit dim provided: create the table immediately if it doesn't exist.
|
||||
if self._table is None and vector_dim is not None:
|
||||
self._vector_dim = vector_dim
|
||||
self._table = self._create_table(vector_dim)
|
||||
|
||||
@property
|
||||
def write_lock(self) -> threading.RLock:
|
||||
"""The shared reentrant write lock for this database path.
|
||||
|
||||
Callers can acquire this to hold the lock across multiple storage
|
||||
operations (e.g. delete + update + save as one atomic batch).
|
||||
Individual methods also acquire it internally, but since it's
|
||||
reentrant (RLock), the same thread won't deadlock.
|
||||
"""
|
||||
return self._write_lock
|
||||
with store_lock(self._lock_name):
|
||||
self._table = self._create_table(vector_dim)
|
||||
|
||||
@staticmethod
|
||||
def _infer_dim_from_table(table: lancedb.table.Table) -> int:
|
||||
def _infer_dim_from_table(table: Any) -> int:
|
||||
"""Read vector dimension from an existing table's schema."""
|
||||
schema = table.schema
|
||||
for field in schema:
|
||||
if field.name == "vector":
|
||||
try:
|
||||
return field.type.list_size
|
||||
return int(field.type.list_size)
|
||||
except Exception:
|
||||
break
|
||||
return DEFAULT_VECTOR_DIM
|
||||
|
||||
def _retry_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Execute a table operation with retry on LanceDB commit conflicts.
|
||||
def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Execute a single table write with retry on commit conflicts.
|
||||
|
||||
Args:
|
||||
op: Method name on the table object (e.g. "add", "delete").
|
||||
*args, **kwargs: Passed to the table method.
|
||||
|
||||
LanceDB uses optimistic concurrency: if two transactions overlap,
|
||||
the second to commit fails with an ``OSError`` containing
|
||||
"Commit conflict". This helper retries with exponential backoff,
|
||||
refreshing the table reference before each retry so the retried
|
||||
call uses the latest committed version (not a stale reference).
|
||||
Caller must already hold ``store_lock(self._lock_name)``.
|
||||
"""
|
||||
delay = _RETRY_BASE_DELAY
|
||||
for attempt in range(_MAX_RETRIES + 1):
|
||||
@@ -171,20 +138,24 @@ class LanceDBStorage:
|
||||
raise
|
||||
_logger.debug(
|
||||
"LanceDB commit conflict on %s (attempt %d/%d), retrying in %.1fs",
|
||||
op, attempt + 1, _MAX_RETRIES, delay,
|
||||
op,
|
||||
attempt + 1,
|
||||
_MAX_RETRIES,
|
||||
delay,
|
||||
)
|
||||
# Refresh table to pick up the latest version before retrying.
|
||||
# The next getattr(self._table, op) will use the fresh table.
|
||||
try:
|
||||
self._table = self._db.open_table(self._table_name)
|
||||
except Exception: # noqa: S110
|
||||
pass # table refresh is best-effort
|
||||
except Exception:
|
||||
_logger.debug("Failed to re-open table during retry", exc_info=True)
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
return None # unreachable, but satisfies type checker
|
||||
|
||||
def _create_table(self, vector_dim: int) -> lancedb.table.Table:
|
||||
"""Create a new table with the given vector dimension."""
|
||||
def _create_table(self, vector_dim: int) -> Any:
|
||||
"""Create a new table with the given vector dimension.
|
||||
|
||||
Caller must already hold ``store_lock(self._lock_name)``.
|
||||
"""
|
||||
placeholder = [
|
||||
{
|
||||
"id": "__schema_placeholder__",
|
||||
@@ -200,8 +171,12 @@ class LanceDBStorage:
|
||||
"vector": [0.0] * vector_dim,
|
||||
}
|
||||
]
|
||||
table = self._db.create_table(self._table_name, placeholder)
|
||||
table.delete("id = '__schema_placeholder__'")
|
||||
try:
|
||||
table = self._db.create_table(self._table_name, placeholder)
|
||||
except ValueError:
|
||||
table = self._db.open_table(self._table_name)
|
||||
else:
|
||||
table.delete("id = '__schema_placeholder__'")
|
||||
return table
|
||||
|
||||
def _ensure_scope_index(self) -> None:
|
||||
@@ -217,8 +192,10 @@ class LanceDBStorage:
|
||||
return
|
||||
try:
|
||||
self._table.create_scalar_index("scope", index_type="BTREE", replace=False)
|
||||
except Exception: # noqa: S110
|
||||
pass # index already exists, table empty, or unsupported version
|
||||
except Exception:
|
||||
_logger.debug(
|
||||
"Scope index creation skipped (may already exist)", exc_info=True
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Automatic background compaction
|
||||
@@ -238,8 +215,10 @@ class LanceDBStorage:
|
||||
|
||||
def _compact_async(self) -> None:
|
||||
"""Fire-and-forget: compact the table in a daemon background thread."""
|
||||
ctx = contextvars.copy_context()
|
||||
threading.Thread(
|
||||
target=self._compact_safe,
|
||||
target=ctx.run,
|
||||
args=(self._compact_safe,),
|
||||
daemon=True,
|
||||
name="lancedb-compact",
|
||||
).start()
|
||||
@@ -248,13 +227,13 @@ class LanceDBStorage:
|
||||
"""Run ``table.optimize()`` in a background thread, absorbing errors."""
|
||||
try:
|
||||
if self._table is not None:
|
||||
self._table.optimize()
|
||||
# Refresh the scope index so new fragments are covered.
|
||||
self._ensure_scope_index()
|
||||
with store_lock(self._lock_name):
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
except Exception:
|
||||
_logger.debug("LanceDB background compaction failed", exc_info=True)
|
||||
|
||||
def _ensure_table(self, vector_dim: int | None = None) -> lancedb.table.Table:
|
||||
def _ensure_table(self, vector_dim: int | None = None) -> Any:
|
||||
"""Return the table, creating it lazily if needed.
|
||||
|
||||
Args:
|
||||
@@ -280,7 +259,9 @@ class LanceDBStorage:
|
||||
"last_accessed": record.last_accessed.isoformat(),
|
||||
"source": record.source or "",
|
||||
"private": record.private,
|
||||
"vector": record.embedding if record.embedding else [0.0] * self._vector_dim,
|
||||
"vector": record.embedding
|
||||
if record.embedding
|
||||
else [0.0] * self._vector_dim,
|
||||
}
|
||||
|
||||
def _row_to_record(self, row: dict[str, Any]) -> MemoryRecord:
|
||||
@@ -296,7 +277,9 @@ class LanceDBStorage:
|
||||
id=str(row["id"]),
|
||||
content=str(row["content"]),
|
||||
scope=str(row["scope"]),
|
||||
categories=json.loads(row["categories_str"]) if row.get("categories_str") else [],
|
||||
categories=json.loads(row["categories_str"])
|
||||
if row.get("categories_str")
|
||||
else [],
|
||||
metadata=json.loads(row["metadata_str"]) if row.get("metadata_str") else {},
|
||||
importance=float(row.get("importance", 0.5)),
|
||||
created_at=_parse_dt(row.get("created_at")),
|
||||
@@ -316,16 +299,15 @@ class LanceDBStorage:
|
||||
dim = len(r.embedding)
|
||||
break
|
||||
is_new_table = self._table is None
|
||||
with self._write_lock:
|
||||
with store_lock(self._lock_name):
|
||||
self._ensure_table(vector_dim=dim)
|
||||
rows = [self._record_to_row(r) for r in records]
|
||||
for r in rows:
|
||||
if r["vector"] is None or len(r["vector"]) != self._vector_dim:
|
||||
r["vector"] = [0.0] * self._vector_dim
|
||||
self._retry_write("add", rows)
|
||||
# Create the scope index on the first save so it covers the initial dataset.
|
||||
if is_new_table:
|
||||
self._ensure_scope_index()
|
||||
rows = [self._record_to_row(rec) for rec in records]
|
||||
for row in rows:
|
||||
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
|
||||
row["vector"] = [0.0] * self._vector_dim
|
||||
self._do_write("add", rows)
|
||||
if is_new_table:
|
||||
self._ensure_scope_index()
|
||||
# Auto-compact every N saves so fragment files don't pile up.
|
||||
self._save_count += 1
|
||||
if self._compact_every > 0 and self._save_count % self._compact_every == 0:
|
||||
@@ -333,14 +315,14 @@ class LanceDBStorage:
|
||||
|
||||
def update(self, record: MemoryRecord) -> None:
|
||||
"""Update a record by ID. Preserves created_at, updates last_accessed."""
|
||||
with self._write_lock:
|
||||
with store_lock(self._lock_name):
|
||||
self._ensure_table()
|
||||
safe_id = str(record.id).replace("'", "''")
|
||||
self._retry_write("delete", f"id = '{safe_id}'")
|
||||
self._do_write("delete", f"id = '{safe_id}'")
|
||||
row = self._record_to_row(record)
|
||||
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
|
||||
row["vector"] = [0.0] * self._vector_dim
|
||||
self._retry_write("add", [row])
|
||||
self._do_write("add", [row])
|
||||
|
||||
def touch_records(self, record_ids: list[str]) -> None:
|
||||
"""Update last_accessed to now for the given record IDs.
|
||||
@@ -354,11 +336,11 @@ class LanceDBStorage:
|
||||
"""
|
||||
if not record_ids or self._table is None:
|
||||
return
|
||||
with self._write_lock:
|
||||
with store_lock(self._lock_name):
|
||||
now = datetime.utcnow().isoformat()
|
||||
safe_ids = [str(rid).replace("'", "''") for rid in record_ids]
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
|
||||
self._retry_write(
|
||||
self._do_write(
|
||||
"update",
|
||||
where=f"id IN ({ids_expr})",
|
||||
values={"last_accessed": now},
|
||||
@@ -390,13 +372,17 @@ class LanceDBStorage:
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
like_val = prefix + "%"
|
||||
query = query.where(f"scope LIKE '{like_val}'")
|
||||
results = query.limit(limit * 3 if (categories or metadata_filter) else limit).to_list()
|
||||
results = query.limit(
|
||||
limit * 3 if (categories or metadata_filter) else limit
|
||||
).to_list()
|
||||
out: list[tuple[MemoryRecord, float]] = []
|
||||
for row in results:
|
||||
record = self._row_to_record(row)
|
||||
if categories and not any(c in record.categories for c in categories):
|
||||
continue
|
||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
||||
if metadata_filter and not all(
|
||||
record.metadata.get(k) == v for k, v in metadata_filter.items()
|
||||
):
|
||||
continue
|
||||
distance = row.get("_distance", 0.0)
|
||||
score = 1.0 / (1.0 + float(distance)) if distance is not None else 1.0
|
||||
@@ -416,30 +402,34 @@ class LanceDBStorage:
|
||||
) -> int:
|
||||
if self._table is None:
|
||||
return 0
|
||||
with self._write_lock:
|
||||
with store_lock(self._lock_name):
|
||||
if record_ids and not (categories or metadata_filter):
|
||||
before = self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
|
||||
self._retry_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - int(self._table.count_rows())
|
||||
if categories or metadata_filter:
|
||||
rows = self._scan_rows(scope_prefix)
|
||||
to_delete: list[str] = []
|
||||
for row in rows:
|
||||
record = self._row_to_record(row)
|
||||
if categories and not any(c in record.categories for c in categories):
|
||||
if categories and not any(
|
||||
c in record.categories for c in categories
|
||||
):
|
||||
continue
|
||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
||||
if metadata_filter and not all(
|
||||
record.metadata.get(k) == v for k, v in metadata_filter.items()
|
||||
):
|
||||
continue
|
||||
if older_than and record.created_at >= older_than:
|
||||
continue
|
||||
to_delete.append(record.id)
|
||||
if not to_delete:
|
||||
return 0
|
||||
before = self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
|
||||
self._retry_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - int(self._table.count_rows())
|
||||
conditions = []
|
||||
if scope_prefix is not None and scope_prefix.strip("/"):
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
@@ -449,13 +439,13 @@ class LanceDBStorage:
|
||||
if older_than is not None:
|
||||
conditions.append(f"created_at < '{older_than.isoformat()}'")
|
||||
if not conditions:
|
||||
before = self._table.count_rows()
|
||||
self._retry_write("delete", "id != ''")
|
||||
return before - self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
self._do_write("delete", "id != ''")
|
||||
return before - int(self._table.count_rows())
|
||||
where_expr = " AND ".join(conditions)
|
||||
before = self._table.count_rows()
|
||||
self._retry_write("delete", where_expr)
|
||||
return before - self._table.count_rows()
|
||||
before = int(self._table.count_rows())
|
||||
self._do_write("delete", where_expr)
|
||||
return before - int(self._table.count_rows())
|
||||
|
||||
def _scan_rows(
|
||||
self,
|
||||
@@ -482,7 +472,8 @@ class LanceDBStorage:
|
||||
q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
|
||||
if columns is not None:
|
||||
q = q.select(columns)
|
||||
return q.limit(limit).to_list()
|
||||
result: list[dict[str, Any]] = q.limit(limit).to_list()
|
||||
return result
|
||||
|
||||
def list_records(
|
||||
self, scope_prefix: str | None = None, limit: int = 200, offset: int = 0
|
||||
@@ -528,7 +519,7 @@ class LanceDBStorage:
|
||||
for row in rows:
|
||||
sc = str(row.get("scope", ""))
|
||||
if child_prefix and sc.startswith(child_prefix):
|
||||
rest = sc[len(child_prefix):]
|
||||
rest = sc[len(child_prefix) :]
|
||||
first_component = rest.split("/", 1)[0]
|
||||
if first_component:
|
||||
children.add(child_prefix + first_component)
|
||||
@@ -539,7 +530,11 @@ class LanceDBStorage:
|
||||
pass
|
||||
created = row.get("created_at")
|
||||
if created:
|
||||
dt = datetime.fromisoformat(str(created).replace("Z", "+00:00")) if isinstance(created, str) else created
|
||||
dt = (
|
||||
datetime.fromisoformat(str(created).replace("Z", "+00:00"))
|
||||
if isinstance(created, str)
|
||||
else created
|
||||
)
|
||||
if isinstance(dt, datetime):
|
||||
if oldest is None or dt < oldest:
|
||||
oldest = dt
|
||||
@@ -562,7 +557,7 @@ class LanceDBStorage:
|
||||
for row in rows:
|
||||
sc = str(row.get("scope", ""))
|
||||
if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
|
||||
rest = sc[len(prefix):]
|
||||
rest = sc[len(prefix) :]
|
||||
first_component = rest.split("/", 1)[0]
|
||||
if first_component:
|
||||
children.add(prefix + first_component)
|
||||
@@ -585,22 +580,24 @@ class LanceDBStorage:
|
||||
if self._table is None:
|
||||
return 0
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
return self._table.count_rows()
|
||||
return int(self._table.count_rows())
|
||||
info = self.get_scope_info(scope_prefix)
|
||||
return info.record_count
|
||||
|
||||
def reset(self, scope_prefix: str | None = None) -> None:
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
if self._table is not None:
|
||||
self._db.drop_table(self._table_name)
|
||||
self._table = None
|
||||
# Dimension is preserved; table will be recreated on next save.
|
||||
return
|
||||
if self._table is None:
|
||||
return
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
if prefix:
|
||||
self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uFFFF'")
|
||||
with store_lock(self._lock_name):
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
if self._table is not None:
|
||||
self._db.drop_table(self._table_name)
|
||||
self._table = None
|
||||
return
|
||||
if self._table is None:
|
||||
return
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
if prefix:
|
||||
self._do_write(
|
||||
"delete", f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'"
|
||||
)
|
||||
|
||||
def optimize(self) -> None:
|
||||
"""Compact the table synchronously and refresh the scope index.
|
||||
@@ -614,8 +611,9 @@ class LanceDBStorage:
|
||||
"""
|
||||
if self._table is None:
|
||||
return
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
with store_lock(self._lock_name):
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
|
||||
async def asave(self, records: list[MemoryRecord]) -> None:
|
||||
self.save(records)
|
||||
|
||||
@@ -3,10 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PlainValidator, PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
@@ -39,13 +42,18 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
def _passthrough(v: Any) -> Any:
|
||||
"""PlainValidator that accepts any value, bypassing strict union discrimination."""
|
||||
return v
|
||||
|
||||
|
||||
def _default_embedder() -> OpenAIEmbeddingFunction:
|
||||
"""Build default OpenAI embedder for memory."""
|
||||
spec: OpenAIProviderSpec = {"provider": "openai", "config": {}}
|
||||
return build_embedder(spec)
|
||||
|
||||
|
||||
class Memory:
|
||||
class Memory(BaseModel):
|
||||
"""Unified memory: standalone, LLM-analyzed, with intelligent recall flow.
|
||||
|
||||
Works without agent/crew. Uses LLM to infer scope, categories, importance on save.
|
||||
@@ -53,116 +61,119 @@ class Memory:
|
||||
pluggable storage (LanceDB default).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM | str = "gpt-4o-mini",
|
||||
storage: StorageBackend | str = "lancedb",
|
||||
embedder: Any = None,
|
||||
# -- Scoring weights --
|
||||
# These three weights control how recall results are ranked.
|
||||
# The composite score is: semantic_weight * similarity + recency_weight * decay + importance_weight * importance.
|
||||
# They should sum to ~1.0 for intuitive scoring.
|
||||
recency_weight: float = 0.3,
|
||||
semantic_weight: float = 0.5,
|
||||
importance_weight: float = 0.2,
|
||||
# How quickly old memories lose relevance. The recency score halves every
|
||||
# N days (exponential decay). Lower = faster forgetting; higher = longer relevance.
|
||||
recency_half_life_days: int = 30,
|
||||
# -- Consolidation --
|
||||
# When remembering new content, if an existing record has similarity >= this
|
||||
# threshold, the LLM is asked to merge/update/delete. Set to 1.0 to disable.
|
||||
consolidation_threshold: float = 0.85,
|
||||
# Max existing records to compare against when checking for consolidation.
|
||||
consolidation_limit: int = 5,
|
||||
# -- Save defaults --
|
||||
# Importance assigned to new memories when no explicit value is given and
|
||||
# the LLM analysis path is skipped (all fields provided by the caller).
|
||||
default_importance: float = 0.5,
|
||||
# -- Recall depth control --
|
||||
# These thresholds govern the RecallFlow router that decides between
|
||||
# returning results immediately ("synthesize") vs. doing an extra
|
||||
# LLM-driven exploration round ("explore_deeper").
|
||||
# confidence >= confidence_threshold_high => always synthesize
|
||||
# confidence < confidence_threshold_low => explore deeper (if budget > 0)
|
||||
# complex query + confidence < complex_query_threshold => explore deeper
|
||||
confidence_threshold_high: float = 0.8,
|
||||
confidence_threshold_low: float = 0.5,
|
||||
complex_query_threshold: float = 0.7,
|
||||
# How many LLM-driven exploration rounds the RecallFlow is allowed to run.
|
||||
# 0 = always shallow (vector search only); higher = more thorough but slower.
|
||||
exploration_budget: int = 1,
|
||||
# Queries shorter than this skip LLM analysis (saving ~1-3s).
|
||||
# Longer queries (full task descriptions) benefit from LLM distillation.
|
||||
query_analysis_threshold: int = 200,
|
||||
# When True, all write operations (remember, remember_many) are silently
|
||||
# skipped. Useful for sharing a read-only view of memory across agents
|
||||
# without any of them persisting new memories.
|
||||
read_only: bool = False,
|
||||
) -> None:
|
||||
"""Initialize Memory.
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
Args:
|
||||
llm: LLM for analysis (model name or BaseLLM instance).
|
||||
storage: Backend: "lancedb" or a StorageBackend instance.
|
||||
embedder: Embedding callable, provider config dict, or None (default OpenAI).
|
||||
recency_weight: Weight for recency in the composite relevance score.
|
||||
semantic_weight: Weight for semantic similarity in the composite relevance score.
|
||||
importance_weight: Weight for importance in the composite relevance score.
|
||||
recency_half_life_days: Recency score halves every N days (exponential decay).
|
||||
consolidation_threshold: Similarity above which consolidation is triggered on save.
|
||||
consolidation_limit: Max existing records to compare during consolidation.
|
||||
default_importance: Default importance when not provided or inferred.
|
||||
confidence_threshold_high: Recall confidence above which results are returned directly.
|
||||
confidence_threshold_low: Recall confidence below which deeper exploration is triggered.
|
||||
complex_query_threshold: For complex queries, explore deeper below this confidence.
|
||||
exploration_budget: Number of LLM-driven exploration rounds during deep recall.
|
||||
query_analysis_threshold: Queries shorter than this skip LLM analysis during deep recall.
|
||||
read_only: If True, remember() and remember_many() are silent no-ops.
|
||||
"""
|
||||
self._read_only = read_only
|
||||
llm: Annotated[BaseLLM | str, PlainValidator(_passthrough)] = Field(
|
||||
default="gpt-4o-mini",
|
||||
description="LLM for analysis (model name or BaseLLM instance).",
|
||||
)
|
||||
storage: Annotated[StorageBackend | str, PlainValidator(_passthrough)] = Field(
|
||||
default="lancedb",
|
||||
description="Storage backend instance or path string.",
|
||||
)
|
||||
embedder: Any = Field(
|
||||
default=None,
|
||||
description="Embedding callable, provider config dict, or None for default OpenAI.",
|
||||
)
|
||||
recency_weight: float = Field(
|
||||
default=0.3,
|
||||
description="Weight for recency in the composite relevance score.",
|
||||
)
|
||||
semantic_weight: float = Field(
|
||||
default=0.5,
|
||||
description="Weight for semantic similarity in the composite relevance score.",
|
||||
)
|
||||
importance_weight: float = Field(
|
||||
default=0.2,
|
||||
description="Weight for importance in the composite relevance score.",
|
||||
)
|
||||
recency_half_life_days: int = Field(
|
||||
default=30,
|
||||
description="Recency score halves every N days (exponential decay).",
|
||||
)
|
||||
consolidation_threshold: float = Field(
|
||||
default=0.85,
|
||||
description="Similarity above which consolidation is triggered on save.",
|
||||
)
|
||||
consolidation_limit: int = Field(
|
||||
default=5,
|
||||
description="Max existing records to compare during consolidation.",
|
||||
)
|
||||
default_importance: float = Field(
|
||||
default=0.5,
|
||||
description="Default importance when not provided or inferred.",
|
||||
)
|
||||
confidence_threshold_high: float = Field(
|
||||
default=0.8,
|
||||
description="Recall confidence above which results are returned directly.",
|
||||
)
|
||||
confidence_threshold_low: float = Field(
|
||||
default=0.5,
|
||||
description="Recall confidence below which deeper exploration is triggered.",
|
||||
)
|
||||
complex_query_threshold: float = Field(
|
||||
default=0.7,
|
||||
description="For complex queries, explore deeper below this confidence.",
|
||||
)
|
||||
exploration_budget: int = Field(
|
||||
default=1,
|
||||
description="Number of LLM-driven exploration rounds during deep recall.",
|
||||
)
|
||||
query_analysis_threshold: int = Field(
|
||||
default=200,
|
||||
description="Queries shorter than this skip LLM analysis during deep recall.",
|
||||
)
|
||||
read_only: bool = Field(
|
||||
default=False,
|
||||
description="If True, remember() and remember_many() are silent no-ops.",
|
||||
)
|
||||
|
||||
_config: MemoryConfig = PrivateAttr()
|
||||
_llm_instance: BaseLLM | None = PrivateAttr(default=None)
|
||||
_embedder_instance: Any = PrivateAttr(default=None)
|
||||
_storage: StorageBackend = PrivateAttr()
|
||||
_save_pool: ThreadPoolExecutor = PrivateAttr(
|
||||
default_factory=lambda: ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="memory-save"
|
||||
)
|
||||
)
|
||||
_pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list)
|
||||
_pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Initialize runtime state from field values."""
|
||||
self._config = MemoryConfig(
|
||||
recency_weight=recency_weight,
|
||||
semantic_weight=semantic_weight,
|
||||
importance_weight=importance_weight,
|
||||
recency_half_life_days=recency_half_life_days,
|
||||
consolidation_threshold=consolidation_threshold,
|
||||
consolidation_limit=consolidation_limit,
|
||||
default_importance=default_importance,
|
||||
confidence_threshold_high=confidence_threshold_high,
|
||||
confidence_threshold_low=confidence_threshold_low,
|
||||
complex_query_threshold=complex_query_threshold,
|
||||
exploration_budget=exploration_budget,
|
||||
query_analysis_threshold=query_analysis_threshold,
|
||||
recency_weight=self.recency_weight,
|
||||
semantic_weight=self.semantic_weight,
|
||||
importance_weight=self.importance_weight,
|
||||
recency_half_life_days=self.recency_half_life_days,
|
||||
consolidation_threshold=self.consolidation_threshold,
|
||||
consolidation_limit=self.consolidation_limit,
|
||||
default_importance=self.default_importance,
|
||||
confidence_threshold_high=self.confidence_threshold_high,
|
||||
confidence_threshold_low=self.confidence_threshold_low,
|
||||
complex_query_threshold=self.complex_query_threshold,
|
||||
exploration_budget=self.exploration_budget,
|
||||
query_analysis_threshold=self.query_analysis_threshold,
|
||||
)
|
||||
|
||||
# Store raw config for lazy initialization. LLM and embedder are only
|
||||
# built on first access so that Memory() never fails at construction
|
||||
# time (e.g. when auto-created by Flow without an API key set).
|
||||
self._llm_config: BaseLLM | str = llm
|
||||
self._llm_instance: BaseLLM | None = None if isinstance(llm, str) else llm
|
||||
self._embedder_config: Any = embedder
|
||||
self._embedder_instance: Any = (
|
||||
embedder
|
||||
if (embedder is not None and not isinstance(embedder, dict))
|
||||
self._llm_instance = None if isinstance(self.llm, str) else self.llm
|
||||
self._embedder_instance = (
|
||||
self.embedder
|
||||
if (self.embedder is not None and not isinstance(self.embedder, dict))
|
||||
else None
|
||||
)
|
||||
|
||||
if isinstance(storage, str):
|
||||
if isinstance(self.storage, str):
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
self._storage = LanceDBStorage() if storage == "lancedb" else LanceDBStorage(path=storage)
|
||||
self._storage = (
|
||||
LanceDBStorage()
|
||||
if self.storage == "lancedb"
|
||||
else LanceDBStorage(path=self.storage)
|
||||
)
|
||||
else:
|
||||
self._storage = storage
|
||||
|
||||
# Background save queue. max_workers=1 serializes saves to avoid
|
||||
# concurrent storage mutations (two saves finding the same similar
|
||||
# record and both trying to update/delete it). Within each save,
|
||||
# the parallel LLM calls still run on their own thread pool.
|
||||
self._save_pool = ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="memory-save"
|
||||
)
|
||||
self._pending_saves: list[Future[Any]] = []
|
||||
self._pending_lock = threading.Lock()
|
||||
self._storage = self.storage
|
||||
|
||||
_MEMORY_DOCS_URL = "https://docs.crewai.com/concepts/memory"
|
||||
|
||||
@@ -173,11 +184,7 @@ class Memory:
|
||||
from crewai.llm import LLM
|
||||
|
||||
try:
|
||||
model_name = (
|
||||
self._llm_config
|
||||
if isinstance(self._llm_config, str)
|
||||
else str(self._llm_config)
|
||||
)
|
||||
model_name = self.llm if isinstance(self.llm, str) else str(self.llm)
|
||||
self._llm_instance = LLM(model=model_name)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
@@ -197,8 +204,8 @@ class Memory:
|
||||
"""Lazy embedder initialization -- only created when first needed."""
|
||||
if self._embedder_instance is None:
|
||||
try:
|
||||
if isinstance(self._embedder_config, dict):
|
||||
self._embedder_instance = build_embedder(self._embedder_config)
|
||||
if isinstance(self.embedder, dict):
|
||||
self._embedder_instance = build_embedder(self.embedder)
|
||||
else:
|
||||
self._embedder_instance = _default_embedder()
|
||||
except Exception as e:
|
||||
@@ -223,8 +230,9 @@ class Memory:
|
||||
If the pool has been shut down (e.g. after ``close()``), the save
|
||||
runs synchronously as a fallback so late saves still succeed.
|
||||
"""
|
||||
ctx = contextvars.copy_context()
|
||||
try:
|
||||
future: Future[Any] = self._save_pool.submit(fn, *args, **kwargs)
|
||||
future: Future[Any] = self._save_pool.submit(ctx.run, fn, *args, **kwargs)
|
||||
except RuntimeError:
|
||||
# Pool shut down -- run synchronously as fallback
|
||||
future = Future()
|
||||
@@ -356,7 +364,7 @@ class Memory:
|
||||
Raises:
|
||||
Exception: On save failure (events emitted).
|
||||
"""
|
||||
if self._read_only:
|
||||
if self.read_only:
|
||||
return None
|
||||
_source_type = "unified_memory"
|
||||
try:
|
||||
@@ -444,7 +452,7 @@ class Memory:
|
||||
Returns:
|
||||
Empty list (records are not available until the background save completes).
|
||||
"""
|
||||
if not contents or self._read_only:
|
||||
if not contents or self.read_only:
|
||||
return []
|
||||
|
||||
self._submit_save(
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import contextvars
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
|
||||
@@ -169,8 +170,9 @@ def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return pool.submit(ctx.run, asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import contextvars
|
||||
from functools import partial
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
@@ -146,8 +147,9 @@ def _resolve_result(result: Any) -> Any:
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return pool.submit(ctx.run, asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""ChromaDB client implementation."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import AbstractContextManager, asynccontextmanager, nullcontext
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -29,6 +32,7 @@ from crewai.rag.core.base_client import (
|
||||
BaseCollectionParams,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
|
||||
|
||||
@@ -52,6 +56,7 @@ class ChromaDBClient(BaseClient):
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
lock_name: str = "",
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
@@ -61,12 +66,32 @@ class ChromaDBClient(BaseClient):
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
default_batch_size: Default batch size for adding documents.
|
||||
lock_name: Optional lock name for cross-process synchronization.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
self.default_batch_size = default_batch_size
|
||||
self._lock_name = lock_name
|
||||
|
||||
def _locked(self) -> AbstractContextManager[None]:
|
||||
"""Return a cross-process lock context manager, or nullcontext if no lock name."""
|
||||
return store_lock(self._lock_name) if self._lock_name else nullcontext()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _alocked(self) -> AsyncIterator[None]:
|
||||
"""Async cross-process lock that acquires/releases in an executor."""
|
||||
if not self._lock_name:
|
||||
yield
|
||||
return
|
||||
lock_cm = store_lock(self._lock_name)
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, lock_cm.__enter__)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await loop.run_in_executor(None, lock_cm.__exit__, None, None, None)
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
@@ -313,23 +338,24 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
with self._locked():
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
@@ -363,22 +389,23 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
async with self._alocked():
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
await collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
await collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
@@ -531,7 +558,10 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
self.client.delete_collection(name=_sanitize_collection_name(collection_name))
|
||||
with self._locked():
|
||||
self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
@@ -561,9 +591,10 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
await self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
async with self._alocked():
|
||||
await self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
@@ -586,7 +617,8 @@ class ChromaDBClient(BaseClient):
|
||||
"Use areset() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
self.client.reset()
|
||||
with self._locked():
|
||||
self.client.reset()
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
@@ -612,4 +644,5 @@ class ChromaDBClient(BaseClient):
|
||||
"Use reset() for ClientAPI."
|
||||
)
|
||||
|
||||
await self.client.reset()
|
||||
async with self._alocked():
|
||||
await self.client.reset()
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
"""Factory functions for creating ChromaDB clients."""
|
||||
|
||||
from hashlib import md5
|
||||
import os
|
||||
|
||||
from chromadb import PersistentClient
|
||||
import portalocker
|
||||
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.utilities.lock_store import lock
|
||||
|
||||
|
||||
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
@@ -25,10 +24,8 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
|
||||
persist_dir = config.settings.persist_directory
|
||||
os.makedirs(persist_dir, exist_ok=True)
|
||||
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
|
||||
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
|
||||
|
||||
with portalocker.Lock(lockfile):
|
||||
with lock(f"chromadb:{persist_dir}"):
|
||||
client = PersistentClient(
|
||||
path=persist_dir,
|
||||
settings=config.settings,
|
||||
@@ -42,4 +39,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
default_batch_size=config.batch_size,
|
||||
lock_name=f"chromadb:{persist_dir}",
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import Future
|
||||
import contextvars
|
||||
from copy import copy as shallow_copy
|
||||
import datetime
|
||||
from hashlib import md5
|
||||
@@ -524,10 +525,11 @@ class Task(BaseModel):
|
||||
) -> Future[TaskOutput]:
|
||||
"""Execute the task asynchronously."""
|
||||
future: Future[TaskOutput] = Future()
|
||||
ctx = contextvars.copy_context()
|
||||
threading.Thread(
|
||||
daemon=True,
|
||||
target=self._execute_task_async,
|
||||
args=(agent, context, tools, future),
|
||||
target=ctx.run,
|
||||
args=(self._execute_task_async, agent, context, tools, future),
|
||||
).start()
|
||||
return future
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from inspect import Parameter, signature
|
||||
import json
|
||||
import threading
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
@@ -18,6 +19,7 @@ from pydantic import (
|
||||
BaseModel as PydanticBaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
create_model,
|
||||
field_validator,
|
||||
)
|
||||
@@ -94,6 +96,7 @@ class BaseTool(BaseModel, ABC):
|
||||
default=0,
|
||||
description="Current number of times this tool has been used.",
|
||||
)
|
||||
_usage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
|
||||
@field_validator("args_schema", mode="before")
|
||||
@classmethod
|
||||
@@ -173,6 +176,25 @@ class BaseTool(BaseModel, ABC):
|
||||
) from e
|
||||
return kwargs
|
||||
|
||||
def _claim_usage(self) -> str | None:
|
||||
"""Atomically check max usage and increment the counter.
|
||||
|
||||
Returns:
|
||||
None if usage was claimed successfully, or an error message
|
||||
string if the tool has reached its usage limit.
|
||||
"""
|
||||
with self._usage_lock:
|
||||
if (
|
||||
self.max_usage_count is not None
|
||||
and self.current_usage_count >= self.max_usage_count
|
||||
):
|
||||
return (
|
||||
f"Tool '{self.name}' has reached its usage limit of "
|
||||
f"{self.max_usage_count} times and cannot be used anymore."
|
||||
)
|
||||
self.current_usage_count += 1
|
||||
return None
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -181,13 +203,15 @@ class BaseTool(BaseModel, ABC):
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs)
|
||||
|
||||
limit_error = self._claim_usage()
|
||||
if limit_error:
|
||||
return limit_error
|
||||
|
||||
result = self._run(*args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.run(result)
|
||||
|
||||
self.current_usage_count += 1
|
||||
|
||||
return result
|
||||
|
||||
async def arun(
|
||||
@@ -206,9 +230,12 @@ class BaseTool(BaseModel, ABC):
|
||||
"""
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs)
|
||||
result = await self._arun(*args, **kwargs)
|
||||
self.current_usage_count += 1
|
||||
return result
|
||||
|
||||
limit_error = self._claim_usage()
|
||||
if limit_error:
|
||||
return limit_error
|
||||
|
||||
return await self._arun(*args, **kwargs)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
@@ -361,12 +388,15 @@ class Tool(BaseTool, Generic[P, R]):
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs) # type: ignore[assignment]
|
||||
|
||||
limit_error = self._claim_usage()
|
||||
if limit_error:
|
||||
return limit_error # type: ignore[return-value]
|
||||
|
||||
result = self.func(*args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.run(result)
|
||||
|
||||
self.current_usage_count += 1
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
def _run(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
@@ -393,9 +423,12 @@ class Tool(BaseTool, Generic[P, R]):
|
||||
"""
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs) # type: ignore[assignment]
|
||||
result = await self._arun(*args, **kwargs)
|
||||
self.current_usage_count += 1
|
||||
return result
|
||||
|
||||
limit_error = self._claim_usage()
|
||||
if limit_error:
|
||||
return limit_error # type: ignore[return-value]
|
||||
|
||||
return await self._arun(*args, **kwargs)
|
||||
|
||||
async def _arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Executes the wrapped function asynchronously.
|
||||
|
||||
@@ -1,29 +1,31 @@
|
||||
"""Native MCP tool wrapper for CrewAI agents.
|
||||
|
||||
This module provides a tool wrapper that reuses existing MCP client sessions
|
||||
for better performance and connection management.
|
||||
This module provides a tool wrapper that creates a fresh MCP client for every
|
||||
invocation, ensuring safe parallel execution even when the same tool is called
|
||||
concurrently by the executor.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import contextvars
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class MCPNativeTool(BaseTool):
|
||||
"""Native MCP tool that reuses client sessions.
|
||||
"""Native MCP tool that creates a fresh client per invocation.
|
||||
|
||||
This tool wrapper is used when agents connect to MCP servers using
|
||||
structured configurations. It reuses existing client sessions for
|
||||
better performance and proper connection lifecycle management.
|
||||
|
||||
Unlike MCPToolWrapper which connects on-demand, this tool uses
|
||||
a shared MCP client instance that maintains a persistent connection.
|
||||
A ``client_factory`` callable produces an independent ``MCPClient`` +
|
||||
transport for every ``_run_async`` call. This guarantees that parallel
|
||||
invocations -- whether of the *same* tool or *different* tools from the
|
||||
same server -- never share mutable connection state (which would cause
|
||||
anyio cancel-scope errors).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_client: Any,
|
||||
client_factory: Callable[[], Any],
|
||||
tool_name: str,
|
||||
tool_schema: dict[str, Any],
|
||||
server_name: str,
|
||||
@@ -32,19 +34,16 @@ class MCPNativeTool(BaseTool):
|
||||
"""Initialize native MCP tool.
|
||||
|
||||
Args:
|
||||
mcp_client: MCPClient instance with active session.
|
||||
client_factory: Zero-arg callable that returns a new MCPClient.
|
||||
tool_name: Name of the tool (may be prefixed).
|
||||
tool_schema: Schema information for the tool.
|
||||
server_name: Name of the MCP server for prefixing.
|
||||
original_tool_name: Original name of the tool on the MCP server.
|
||||
"""
|
||||
# Create tool name with server prefix to avoid conflicts
|
||||
prefixed_name = f"{server_name}_{tool_name}"
|
||||
|
||||
# Handle args_schema properly - BaseTool expects a BaseModel subclass
|
||||
args_schema = tool_schema.get("args_schema")
|
||||
|
||||
# Only pass args_schema if it's provided
|
||||
kwargs = {
|
||||
"name": prefixed_name,
|
||||
"description": tool_schema.get(
|
||||
@@ -57,16 +56,9 @@ class MCPNativeTool(BaseTool):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set instance attributes after super().__init__
|
||||
self._mcp_client = mcp_client
|
||||
self._client_factory = client_factory
|
||||
self._original_tool_name = original_tool_name or tool_name
|
||||
self._server_name = server_name
|
||||
# self._logger = logging.getLogger(__name__)
|
||||
|
||||
@property
|
||||
def mcp_client(self) -> Any:
|
||||
"""Get the MCP client instance."""
|
||||
return self._mcp_client
|
||||
|
||||
@property
|
||||
def original_tool_name(self) -> str:
|
||||
@@ -93,9 +85,10 @@ class MCPNativeTool(BaseTool):
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
coro = self._run_async(**kwargs)
|
||||
future = executor.submit(asyncio.run, coro)
|
||||
future = executor.submit(ctx.run, asyncio.run, coro)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(self._run_async(**kwargs))
|
||||
@@ -108,51 +101,26 @@ class MCPNativeTool(BaseTool):
|
||||
async def _run_async(self, **kwargs) -> str:
|
||||
"""Async implementation of tool execution.
|
||||
|
||||
A fresh ``MCPClient`` is created for every invocation so that
|
||||
concurrent calls never share transport or session state.
|
||||
|
||||
Args:
|
||||
**kwargs: Arguments to pass to the MCP tool.
|
||||
|
||||
Returns:
|
||||
Result from the MCP tool execution.
|
||||
"""
|
||||
# Note: Since we use asyncio.run() which creates a new event loop each time,
|
||||
# Always reconnect on-demand because asyncio.run() creates new event loops per call
|
||||
# All MCP transport context managers (stdio, streamablehttp_client, sse_client)
|
||||
# use anyio.create_task_group() which can't span different event loops
|
||||
if self._mcp_client.connected:
|
||||
await self._mcp_client.disconnect()
|
||||
|
||||
await self._mcp_client.connect()
|
||||
client = self._client_factory()
|
||||
await client.connect()
|
||||
|
||||
try:
|
||||
result = await self._mcp_client.call_tool(self.original_tool_name, kwargs)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"not connected" in error_str
|
||||
or "connection" in error_str
|
||||
or "send" in error_str
|
||||
):
|
||||
await self._mcp_client.disconnect()
|
||||
await self._mcp_client.connect()
|
||||
# Retry the call
|
||||
result = await self._mcp_client.call_tool(
|
||||
self.original_tool_name, kwargs
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
result = await client.call_tool(self.original_tool_name, kwargs)
|
||||
finally:
|
||||
# Always disconnect after tool call to ensure clean context manager lifecycle
|
||||
# This prevents "exit cancel scope in different task" errors
|
||||
# All transport context managers must be exited in the same event loop they were entered
|
||||
await self._mcp_client.disconnect()
|
||||
await client.disconnect()
|
||||
|
||||
# Extract result content
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
|
||||
# Handle various result formats
|
||||
if hasattr(result, "content") and result.content:
|
||||
if isinstance(result.content, list) and len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
|
||||
@@ -121,7 +121,7 @@ def create_memory_tools(memory: Any) -> list[BaseTool]:
|
||||
description=i18n.tools("recall_memory"),
|
||||
),
|
||||
]
|
||||
if not getattr(memory, "_read_only", False):
|
||||
if not memory.read_only:
|
||||
tools.append(
|
||||
RememberTool(
|
||||
memory=memory,
|
||||
|
||||
@@ -74,9 +74,28 @@
|
||||
"consolidation_user": "New content to consider storing:\n{new_content}\n\nExisting similar memories:\n{records_summary}\n\nReturn the consolidation plan as structured output."
|
||||
},
|
||||
"reasoning": {
|
||||
"initial_plan": "You are {role}, a professional with the following background: {backstory}\n\nYour primary goal is: {goal}\n\nAs {role}, you are creating a strategic plan for a task that requires your expertise and unique perspective.",
|
||||
"refine_plan": "You are {role}, a professional with the following background: {backstory}\n\nYour primary goal is: {goal}\n\nAs {role}, you are refining a strategic plan for a task that requires your expertise and unique perspective.",
|
||||
"create_plan_prompt": "You are {role} with this background: {backstory}\n\nYour primary goal is: {goal}\n\nYou have been assigned the following task:\n{description}\n\nExpected output:\n{expected_output}\n\nAvailable tools: {tools}\n\nBefore executing this task, create a detailed plan that leverages your expertise as {role} and outlines:\n1. Your understanding of the task from your professional perspective\n2. The key steps you'll take to complete it, drawing on your background and skills\n3. How you'll approach any challenges that might arise, considering your expertise\n4. How you'll strategically use the available tools based on your experience, exactly what tools to use and how to use them\n5. The expected outcome and how it aligns with your goal\n\nAfter creating your plan, assess whether you feel ready to execute the task or if you could do better.\nConclude with one of these statements:\n- \"READY: I am ready to execute the task.\"\n- \"NOT READY: I need to refine my plan because [specific reason].\"",
|
||||
"refine_plan_prompt": "You are {role} with this background: {backstory}\n\nYour primary goal is: {goal}\n\nYou created the following plan for this task:\n{current_plan}\n\nHowever, you indicated that you're not ready to execute the task yet.\n\nPlease refine your plan further, drawing on your expertise as {role} to address any gaps or uncertainties. As you refine your plan, be specific about which available tools you will use, how you will use them, and why they are the best choices for each step. Clearly outline your tool usage strategy as part of your improved plan.\n\nAfter refining your plan, assess whether you feel ready to execute the task.\nConclude with one of these statements:\n- \"READY: I am ready to execute the task.\"\n- \"NOT READY: I need to refine my plan further because [specific reason].\""
|
||||
"initial_plan": "You are {role}. Create a focused execution plan using only the essential steps needed.",
|
||||
"refine_plan": "You are {role}. Refine your plan to address the specific gap while keeping it minimal.",
|
||||
"create_plan_prompt": "You are {role}.\n\nTask: {description}\n\nExpected output: {expected_output}\n\nAvailable tools: {tools}\n\nCreate a focused plan with ONLY the essential steps needed. Most tasks require just 2-5 steps. Do NOT pad with unnecessary steps like \"review\", \"verify\", \"document\", or \"finalize\" unless explicitly required.\n\nFor each step, specify the action and which tool to use (if any).\n\nConclude with:\n- \"READY: I am ready to execute the task.\"\n- \"NOT READY: I need to refine my plan because [specific reason].\"",
|
||||
"refine_plan_prompt": "Your plan:\n{current_plan}\n\nYou indicated you're not ready. Address the specific gap while keeping the plan minimal.\n\nConclude with READY or NOT READY."
|
||||
},
|
||||
"planning": {
|
||||
"system_prompt": "You are a strategic planning assistant. Create concrete, executable plans where every step produces a verifiable result.",
|
||||
"create_plan_prompt": "Create an execution plan for the following task:\n\n## Task\n{description}\n\n## Expected Output\n{expected_output}\n\n## Available Tools\n{tools}\n\n## Planning Principles\nFocus on CONCRETE, EXECUTABLE steps. Each step must clearly state WHAT ACTION to take and HOW to verify it succeeded. The number of steps should match the task complexity. Hard limit: {max_steps} steps.\n\n## Rules:\n- Each step must have a clear DONE criterion\n- Do NOT group unrelated actions: if steps can fail independently, keep them separate\n- NO standalone \"thinking\" or \"planning\" steps — act, don't just observe\n- The last step must produce the required output\n\nAfter your plan, state READY or NOT READY.",
|
||||
"refine_plan_prompt": "Your previous plan:\n{current_plan}\n\nYou indicated you weren't ready. Refine your plan to address the specific gap.\n\nKeep the plan minimal - only add steps that directly address the issue.\n\nConclude with READY or NOT READY as before.",
|
||||
"observation_system_prompt": "You are a Planning Agent observing execution progress. After each step completes, you analyze what happened and decide whether the remaining plan is still valid.\n\nReason step-by-step about:\n1. Did this step produce a concrete, verifiable result? (file created, command succeeded, service running, etc.) — or did it only explore without acting?\n2. What new information was learned from this step's result?\n3. Whether the remaining steps still make sense given this new information\n4. What refinements, if any, are needed for upcoming steps\n5. Whether the overall goal has already been achieved\n\nCritical: mark `step_completed_successfully=false` if:\n- The step result is only exploratory (ls, pwd, cat) without producing the required artifact or action\n- A command returned a non-zero exit code and the error was not recovered\n- The step description required creating/building/starting something and the result shows it was not done\n\nBe conservative about triggering full replans — only do so when the remaining plan is fundamentally wrong, not just suboptimal.\n\nIMPORTANT: Set step_completed_successfully=false if:\n- The step's stated goal was NOT achieved (even if other things were done)\n- The first meaningful action returned an error (file not found, command not found, etc.)\n- The result is exploration/discovery output rather than the concrete action the step required\n- The step ran out of attempts without producing the required output\nSet needs_full_replan=true if the current plan's remaining steps reference paths or state that don't exist yet and need to be created first.",
|
||||
"observation_user_prompt": "## Original task\n{task_description}\n\n## Expected output\n{task_goal}\n{completed_summary}\n\n## Just completed step {step_number}\nDescription: {step_description}\nResult: {step_result}\n{remaining_summary}\n\nAnalyze this step's result and provide your observation.",
|
||||
"step_executor_system_prompt": "You are {role}. {backstory}\n\nYour goal: {goal}\n\nYou are executing ONE specific step in a larger plan. Your ONLY job is to fully complete this step — not to plan ahead.\n\nKey rules:\n- **ACT FIRST.** Execute the primary action of this step immediately. Do NOT read or explore files before attempting the main action unless exploration IS the step's goal.\n- If the step says 'run X', run X NOW. If it says 'write file Y', write Y NOW.\n- If the step requires producing an output file (e.g. /app/move.txt, report.jsonl, summary.csv), you MUST write that file using a tool call — do NOT just state the answer in text.\n- You may use tools MULTIPLE TIMES. After each tool use, check the result. If it failed, try a different approach.\n- Only output your Final Answer AFTER the concrete outcome is verified (file written, build succeeded, command exited 0).\n- If a command is not found or a path does not exist, fix it (different PATH, install missing deps, use absolute paths).\n- Do NOT spend more than 3 tool calls on exploration/analysis before attempting the primary action.{tools_section}",
|
||||
"step_executor_tools_section": "\n\nAvailable tools: {tool_names}\n\nYou may call tools multiple times in sequence. Use this format for EACH tool call:\nThought: <what you observed and what you will try next>\nAction: <tool_name>\nAction Input: <input>\n\nAfter observing each result, decide: is the step complete? If yes:\nThought: The step is done because <evidence>\nFinal Answer: <concise summary of what was accomplished and the key result>",
|
||||
"step_executor_user_prompt": "## Current Step\n{step_description}",
|
||||
"step_executor_suggested_tool": "\nSuggested tool: {tool_to_use}",
|
||||
"step_executor_context_header": "\n## Context from previous steps:",
|
||||
"step_executor_context_entry": "Step {step_number} result: {result}",
|
||||
"step_executor_complete_step": "\n**Execute the primary action of this step NOW.** If the step requires writing a file, write it. If it requires running a command, run it. Verify the outcome with a follow-up tool call, then give your Final Answer. Your Final Answer must confirm what was DONE (file created at path X, command succeeded), not just what should be done.",
|
||||
"todo_system_prompt": "You are {role}. Your goal: {goal}\n\nYou are executing a specific step in a multi-step plan. Focus only on completing the current step. Use the suggested tool if one is provided. Be concise and provide clear results that can be used by subsequent steps.",
|
||||
"synthesis_system_prompt": "You are {role}. You have completed a multi-step task. Synthesize the results from all steps into a single, coherent final response that directly addresses the original task. Do NOT list step numbers or say 'Step 1 result'. Produce a clean, polished answer as if you did it all at once.",
|
||||
"synthesis_user_prompt": "## Original Task\n{task_description}\n\n## Results from each step\n{combined_steps}\n\nSynthesize these results into a single, coherent final answer.",
|
||||
"replan_enhancement_prompt": "\n\nIMPORTANT: Previous execution attempt did not fully succeed. Please create a revised plan that accounts for the following context from the previous attempt:\n\n{previous_context}\n\nConsider:\n1. What steps succeeded and can be built upon\n2. What steps failed and why they might have failed\n3. Alternative approaches that might work better\n4. Whether dependencies need to be restructured",
|
||||
"step_executor_task_context": "## Task Context\nThe following is the full task you are helping complete. Keep this in mind — especially any required output files, exact filenames, and expected formats.\n\n{task_context}\n\n---\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user