mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-11 06:18:19 +00:00
Compare commits
9 Commits
cursor/cod
...
gl/fix/con
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a15aa0fb97 | ||
|
|
67bc64e82c | ||
|
|
a037ade1ca | ||
|
|
1bc92ebb5f | ||
|
|
0046f9a96f | ||
|
|
e72a80be6e | ||
|
|
7cffcab84a | ||
|
|
f070ce8abd | ||
|
|
d9f6e2222f |
@@ -1,97 +1,316 @@
|
||||
---
|
||||
title: Brave Search
|
||||
description: The `BraveSearchTool` is designed to search the internet using the Brave Search API.
|
||||
title: Brave Search Tools
|
||||
description: A suite of tools for querying the Brave Search API — covering web, news, image, and video search.
|
||||
icon: searchengin
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
# `BraveSearchTool`
|
||||
# Brave Search Tools
|
||||
|
||||
## Description
|
||||
|
||||
This tool is designed to perform web searches using the Brave Search API. It allows you to search the internet with a specified query and retrieve relevant results. The tool supports customizable result counts and country-specific searches.
|
||||
CrewAI offers a family of Brave Search tools, each targeting a specific [Brave Search API](https://brave.com/search/api/) endpoint.
|
||||
Rather than a single catch-all tool, you can pick exactly the tool that matches the kind of results your agent needs:
|
||||
|
||||
| Tool | Endpoint | Use case |
|
||||
| --- | --- | --- |
|
||||
| `BraveWebSearchTool` | Web Search | General web results, snippets, and URLs |
|
||||
| `BraveNewsSearchTool` | News Search | Recent news articles and headlines |
|
||||
| `BraveImageSearchTool` | Image Search | Image results with dimensions and source URLs |
|
||||
| `BraveVideoSearchTool` | Video Search | Video results from across the web |
|
||||
| `BraveLocalPOIsTool` | Local POIs | Find points of interest (e.g., restaurants) |
|
||||
| `BraveLocalPOIsDescriptionTool` | Local POIs | Retrieve AI-generated location descriptions |
|
||||
| `BraveLLMContextTool` | LLM Context | Pre-extracted web content optimized for AI agents, LLM grounding, and RAG pipelines. |
|
||||
|
||||
All tools share a common base class (`BraveSearchToolBase`) that provides consistent behavior — rate limiting, automatic retries on `429` responses, header and parameter validation, and optional file saving.
|
||||
|
||||
<Note>
|
||||
The older `BraveSearchTool` class is still available for backwards compatibility, but it is considered **legacy** and will not receive the same level of attention going forward. We recommend migrating to the specific tools listed above, which offer richer configuration and a more focused interface.
|
||||
</Note>
|
||||
|
||||
<Note>
|
||||
While many tools (e.g., _BraveWebSearchTool_, _BraveNewsSearchTool_, _BraveImageSearchTool_, and _BraveVideoSearchTool_) can be used with a free Brave Search API subscription/plan, some parameters (e.g., `enable_snippets`) and tools (e.g., _BraveLocalPOIsTool_ and _BraveLocalPOIsDescriptionTool_) require a paid plan. Consult your subscription plan's capabilities for clarification.
|
||||
</Note>
|
||||
|
||||
## Installation
|
||||
|
||||
To incorporate this tool into your project, follow the installation instructions below:
|
||||
|
||||
```shell
|
||||
pip install 'crewai[tools]'
|
||||
```
|
||||
|
||||
## Steps to Get Started
|
||||
## Getting Started
|
||||
|
||||
To effectively use the `BraveSearchTool`, follow these steps:
|
||||
1. **Install the package** — confirm that `crewai[tools]` is installed in your Python environment.
|
||||
2. **Get an API key** — sign up at [api-dashboard.search.brave.com/login](https://api-dashboard.search.brave.com/login) to generate a key.
|
||||
3. **Set the environment variable** — store your key as `BRAVE_API_KEY`, or pass it directly via the `api_key` parameter.
|
||||
|
||||
1. **Package Installation**: Confirm that the `crewai[tools]` package is installed in your Python environment.
|
||||
2. **API Key Acquisition**: Acquire a Brave Search API key at https://api.search.brave.com/app/keys (sign in to generate a key).
|
||||
3. **Environment Configuration**: Store your obtained API key in an environment variable named `BRAVE_API_KEY` to facilitate its use by the tool.
|
||||
## Quick Examples
|
||||
|
||||
## Example
|
||||
|
||||
The following example demonstrates how to initialize the tool and execute a search with a given query:
|
||||
### Web Search
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveSearchTool
|
||||
from crewai_tools import BraveWebSearchTool
|
||||
|
||||
# Initialize the tool for internet searching capabilities
|
||||
tool = BraveSearchTool()
|
||||
|
||||
# Execute a search
|
||||
results = tool.run(search_query="CrewAI agent framework")
|
||||
tool = BraveWebSearchTool()
|
||||
results = tool.run(q="CrewAI agent framework")
|
||||
print(results)
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
The `BraveSearchTool` accepts the following parameters:
|
||||
|
||||
- **search_query**: Mandatory. The search query you want to use to search the internet.
|
||||
- **country**: Optional. Specify the country for the search results. Default is empty string.
|
||||
- **n_results**: Optional. Number of search results to return. Default is `10`.
|
||||
- **save_file**: Optional. Whether to save the search results to a file. Default is `False`.
|
||||
|
||||
## Example with Parameters
|
||||
|
||||
Here is an example demonstrating how to use the tool with additional parameters:
|
||||
### News Search
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveSearchTool
|
||||
from crewai_tools import BraveNewsSearchTool
|
||||
|
||||
# Initialize the tool with custom parameters
|
||||
tool = BraveSearchTool(
|
||||
country="US",
|
||||
n_results=5,
|
||||
save_file=True
|
||||
tool = BraveNewsSearchTool()
|
||||
results = tool.run(q="latest AI breakthroughs")
|
||||
print(results)
|
||||
```
|
||||
|
||||
### Image Search
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveImageSearchTool
|
||||
|
||||
tool = BraveImageSearchTool()
|
||||
results = tool.run(q="northern lights photography")
|
||||
print(results)
|
||||
```
|
||||
|
||||
### Video Search
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveVideoSearchTool
|
||||
|
||||
tool = BraveVideoSearchTool()
|
||||
results = tool.run(q="how to build AI agents")
|
||||
print(results)
|
||||
```
|
||||
|
||||
### Location POI Descriptions
|
||||
|
||||
```python Code
|
||||
from crewai_tools import (
|
||||
BraveWebSearchTool,
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
)
|
||||
|
||||
# Execute a search
|
||||
results = tool.run(search_query="Latest AI developments")
|
||||
print(results)
|
||||
web_search = BraveWebSearchTool(raw=True)
|
||||
poi_details = BraveLocalPOIsDescriptionTool()
|
||||
|
||||
results = web_search.run(q="italian restaurants in pensacola, florida")
|
||||
|
||||
if "locations" in results:
|
||||
location_ids = [ loc["id"] for loc in results["locations"]["results"] ]
|
||||
if location_ids:
|
||||
descriptions = poi_details.run(ids=location_ids)
|
||||
print(descriptions)
|
||||
```
|
||||
|
||||
## Common Constructor Parameters
|
||||
|
||||
Every Brave Search tool accepts the following parameters at initialization:
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| --- | --- | --- | --- |
|
||||
| `api_key` | `str \| None` | `None` | Brave API key. Falls back to the `BRAVE_API_KEY` environment variable. |
|
||||
| `headers` | `dict \| None` | `None` | Additional HTTP headers to send with every request (e.g., `api-version`, geolocation headers). |
|
||||
| `requests_per_second` | `float` | `1.0` | Maximum request rate. The tool will sleep between calls to stay within this limit. |
|
||||
| `save_file` | `bool` | `False` | When `True`, each response is written to a timestamped `.txt` file. |
|
||||
| `raw` | `bool` | `False` | When `True`, the full API JSON response is returned without any refinement. |
|
||||
| `timeout` | `int` | `30` | HTTP request timeout in seconds. |
|
||||
| `country` | `str \| None` | `None` | Legacy shorthand for geo-targeting (e.g., `"US"`). Prefer using the `country` query parameter directly. |
|
||||
| `n_results` | `int` | `10` | Legacy shorthand for result count. Prefer using the `count` query parameter directly. |
|
||||
|
||||
<Warning>
|
||||
The `country` and `n_results` constructor parameters exist for backwards compatibility. They are applied as defaults when the corresponding query parameters (`country`, `count`) are not provided at call time. For new code, we recommend passing `country` and `count` directly as query parameters instead.
|
||||
</Warning>
|
||||
|
||||
## Query Parameters
|
||||
|
||||
Each tool validates its query parameters against a Pydantic schema before sending the request.
|
||||
The parameters vary slightly per endpoint — here is a summary of the most commonly used ones:
|
||||
|
||||
### BraveWebSearchTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `q` | **(required)** Search query string (max 400 chars). |
|
||||
| `country` | Two-letter country code for geo-targeting (e.g., `"US"`). |
|
||||
| `search_lang` | Two-letter language code for results (e.g., `"en"`). |
|
||||
| `count` | Max number of results to return (1–20). |
|
||||
| `offset` | Skip the first N pages of results (0–9). |
|
||||
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
|
||||
| `freshness` | Recency filter: `"pd"` (past day), `"pw"` (past week), `"pm"` (past month), `"py"` (past year), or a date range like `"2025-01-01to2025-06-01"`. |
|
||||
| `extra_snippets` | Include up to 5 additional text snippets per result. |
|
||||
| `goggles` | Brave Goggles URL(s) and/or source for custom re-ranking. |
|
||||
|
||||
For the complete parameter and header reference, see the [Brave Web Search API documentation](https://api-dashboard.search.brave.com/api-reference/web/search/get).
|
||||
|
||||
### BraveNewsSearchTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `q` | **(required)** Search query string (max 400 chars). |
|
||||
| `country` | Two-letter country code for geo-targeting. |
|
||||
| `search_lang` | Two-letter language code for results. |
|
||||
| `count` | Max number of results to return (1–50). |
|
||||
| `offset` | Skip the first N pages of results (0–9). |
|
||||
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
|
||||
| `freshness` | Recency filter (same options as Web Search). |
|
||||
| `goggles` | Brave Goggles URL(s) and/or source for custom re-ranking. |
|
||||
|
||||
For the complete parameter and header reference, see the [Brave News Search API documentation](https://api-dashboard.search.brave.com/api-reference/news/news_search/get).
|
||||
|
||||
### BraveImageSearchTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `q` | **(required)** Search query string (max 400 chars). |
|
||||
| `country` | Two-letter country code for geo-targeting. |
|
||||
| `search_lang` | Two-letter language code for results. |
|
||||
| `count` | Max number of results to return (1–200). |
|
||||
| `safesearch` | Content filter: `"off"` or `"strict"`. |
|
||||
| `spellcheck` | Attempt to correct spelling errors in the query. |
|
||||
|
||||
For the complete parameter and header reference, see the [Brave Image Search API documentation](https://api-dashboard.search.brave.com/api-reference/images/image_search).
|
||||
|
||||
### BraveVideoSearchTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `q` | **(required)** Search query string (max 400 chars). |
|
||||
| `country` | Two-letter country code for geo-targeting. |
|
||||
| `search_lang` | Two-letter language code for results. |
|
||||
| `count` | Max number of results to return (1–50). |
|
||||
| `offset` | Skip the first N pages of results (0–9). |
|
||||
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
|
||||
| `freshness` | Recency filter (same options as Web Search). |
|
||||
|
||||
For the complete parameter and header reference, see the [Brave Video Search API documentation](https://api-dashboard.search.brave.com/api-reference/videos/video_search/get).
|
||||
|
||||
### BraveLocalPOIsTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `ids` | **(required)** A list of unique identifiers for the desired locations. |
|
||||
| `search_lang` | Two-letter language code for results. |
|
||||
|
||||
For the complete parameter and header reference, see [Brave Local POIs API documentation](https://api-dashboard.search.brave.com/api-reference/web/local_pois).
|
||||
|
||||
### BraveLocalPOIsDescriptionTool
|
||||
|
||||
| Parameter | Description |
|
||||
| --- | --- |
|
||||
| `ids` | **(required)** A list of unique identifiers for the desired locations. |
|
||||
|
||||
For the complete parameter and header reference, see [Brave POI Descriptions API documentation](https://api-dashboard.search.brave.com/api-reference/web/poi_descriptions).
|
||||
|
||||
## Custom Headers
|
||||
|
||||
All tools support custom HTTP request headers. The Web Search tool, for example, accepts geolocation headers for location-aware results:
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveWebSearchTool
|
||||
|
||||
tool = BraveWebSearchTool(
|
||||
headers={
|
||||
"x-loc-lat": "37.7749",
|
||||
"x-loc-long": "-122.4194",
|
||||
"x-loc-city": "San Francisco",
|
||||
"x-loc-state": "CA",
|
||||
"x-loc-country": "US",
|
||||
}
|
||||
)
|
||||
|
||||
results = tool.run(q="best coffee shops nearby")
|
||||
```
|
||||
|
||||
You can also update headers after initialization using the `set_headers()` method:
|
||||
|
||||
```python Code
|
||||
tool.set_headers({"api-version": "2025-01-01"})
|
||||
```
|
||||
|
||||
## Raw Mode
|
||||
|
||||
By default, each tool refines the API response into a concise list of results. If you need the full, unprocessed API response, enable raw mode:
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveWebSearchTool
|
||||
|
||||
tool = BraveWebSearchTool(raw=True)
|
||||
full_response = tool.run(q="Brave Search API")
|
||||
```
|
||||
|
||||
## Agent Integration Example
|
||||
|
||||
Here's how to integrate the `BraveSearchTool` with a CrewAI agent:
|
||||
Here's how to equip a CrewAI agent with multiple Brave Search tools:
|
||||
|
||||
```python Code
|
||||
from crewai import Agent
|
||||
from crewai.project import agent
|
||||
from crewai_tools import BraveSearchTool
|
||||
from crewai_tools import BraveWebSearchTool, BraveNewsSearchTool
|
||||
|
||||
# Initialize the tool
|
||||
brave_search_tool = BraveSearchTool()
|
||||
web_search = BraveWebSearchTool()
|
||||
news_search = BraveNewsSearchTool()
|
||||
|
||||
# Define an agent with the BraveSearchTool
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config["researcher"],
|
||||
allow_delegation=False,
|
||||
tools=[brave_search_tool]
|
||||
tools=[web_search, news_search],
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced Example
|
||||
|
||||
Combining multiple parameters for a targeted search:
|
||||
|
||||
```python Code
|
||||
from crewai_tools import BraveWebSearchTool
|
||||
|
||||
tool = BraveWebSearchTool(
|
||||
requests_per_second=0.5, # conservative rate limit
|
||||
save_file=True,
|
||||
)
|
||||
|
||||
results = tool.run(
|
||||
q="artificial intelligence news",
|
||||
country="US",
|
||||
search_lang="en",
|
||||
count=5,
|
||||
freshness="pm", # past month only
|
||||
extra_snippets=True,
|
||||
)
|
||||
print(results)
|
||||
```
|
||||
|
||||
## Migrating from `BraveSearchTool` (Legacy)
|
||||
|
||||
If you are currently using `BraveSearchTool`, switching to the new tools is straightforward:
|
||||
|
||||
```python Code
|
||||
# Before (legacy)
|
||||
from crewai_tools import BraveSearchTool
|
||||
|
||||
tool = BraveSearchTool(country="US", n_results=5, save_file=True)
|
||||
results = tool.run(search_query="AI agents")
|
||||
|
||||
# After (recommended)
|
||||
from crewai_tools import BraveWebSearchTool
|
||||
|
||||
tool = BraveWebSearchTool(save_file=True)
|
||||
results = tool.run(q="AI agents", country="US", count=5)
|
||||
```
|
||||
|
||||
Key differences:
|
||||
- **Import**: Use `BraveWebSearchTool` (or the news/image/video variant) instead of `BraveSearchTool`.
|
||||
- **Query parameter**: Use `q` instead of `search_query`. (Both `search_query` and `query` are still accepted for convenience, but `q` is the preferred parameter.)
|
||||
- **Result count**: Pass `count` as a query parameter instead of `n_results` at init time.
|
||||
- **Country**: Pass `country` as a query parameter instead of at init time.
|
||||
- **API key**: Can now be passed directly via `api_key=` in addition to the `BRAVE_API_KEY` environment variable.
|
||||
- **Rate limiting**: Configurable via `requests_per_second` with automatic retry on `429` responses.
|
||||
|
||||
## Conclusion
|
||||
|
||||
By integrating the `BraveSearchTool` into Python projects, users gain the ability to conduct real-time, relevant searches across the internet directly from their applications. The tool provides a simple interface to the powerful Brave Search API, making it easy to retrieve and process search results programmatically. By adhering to the setup and usage guidelines provided, incorporating this tool into projects is streamlined and straightforward.
|
||||
The Brave Search tool suite gives your CrewAI agents flexible, endpoint-specific access to the Brave Search API. Whether you need web pages, breaking news, images, or videos, there is a dedicated tool with validated parameters and built-in resilience. Pick the tool that fits your use case, and refer to the [Brave Search API documentation](https://brave.com/search/api/) for the full details on available parameters and response formats.
|
||||
|
||||
@@ -10,7 +10,18 @@ from crewai_tools.aws.s3.writer_tool import S3WriterTool
|
||||
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
|
||||
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
|
||||
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
|
||||
BraveLLMContextTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
BraveLocalPOIsTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
|
||||
from crewai_tools.tools.brightdata_tool.brightdata_dataset import (
|
||||
BrightDataDatasetTool,
|
||||
)
|
||||
@@ -200,7 +211,14 @@ __all__ = [
|
||||
"ArxivPaperTool",
|
||||
"BedrockInvokeAgentTool",
|
||||
"BedrockKBRetrieverTool",
|
||||
"BraveImageSearchTool",
|
||||
"BraveLLMContextTool",
|
||||
"BraveLocalPOIsDescriptionTool",
|
||||
"BraveLocalPOIsTool",
|
||||
"BraveNewsSearchTool",
|
||||
"BraveSearchTool",
|
||||
"BraveVideoSearchTool",
|
||||
"BraveWebSearchTool",
|
||||
"BrightDataDatasetTool",
|
||||
"BrightDataSearchTool",
|
||||
"BrightDataWebUnlockerTool",
|
||||
|
||||
@@ -1,7 +1,18 @@
|
||||
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
|
||||
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
|
||||
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
|
||||
BraveLLMContextTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
BraveLocalPOIsTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
|
||||
from crewai_tools.tools.brightdata_tool import (
|
||||
BrightDataDatasetTool,
|
||||
BrightDataSearchTool,
|
||||
@@ -185,7 +196,14 @@ __all__ = [
|
||||
"AIMindTool",
|
||||
"ApifyActorsTool",
|
||||
"ArxivPaperTool",
|
||||
"BraveImageSearchTool",
|
||||
"BraveLLMContextTool",
|
||||
"BraveLocalPOIsDescriptionTool",
|
||||
"BraveLocalPOIsTool",
|
||||
"BraveNewsSearchTool",
|
||||
"BraveSearchTool",
|
||||
"BraveVideoSearchTool",
|
||||
"BraveWebSearchTool",
|
||||
"BrightDataDatasetTool",
|
||||
"BrightDataSearchTool",
|
||||
"BrightDataWebUnlockerTool",
|
||||
|
||||
@@ -0,0 +1,322 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
import requests
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Brave API error codes that indicate non-retryable quota/usage exhaustion.
|
||||
_QUOTA_CODES = frozenset({"QUOTA_LIMITED", "USAGE_LIMIT_EXCEEDED"})
|
||||
|
||||
|
||||
def _save_results_to_file(content: str) -> None:
|
||||
"""Saves the search results to a file."""
|
||||
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
|
||||
with open(filename, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
def _parse_error_body(resp: requests.Response) -> dict[str, Any] | None:
|
||||
"""Extract the structured "error" object from a Brave API error response."""
|
||||
try:
|
||||
body = resp.json()
|
||||
error = body.get("error")
|
||||
return error if isinstance(error, dict) else None
|
||||
except (ValueError, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
def _raise_for_error(resp: requests.Response) -> None:
|
||||
"""Brave Search API error responses contain helpful JSON payloads"""
|
||||
status = resp.status_code
|
||||
try:
|
||||
body = json.dumps(resp.json())
|
||||
except (ValueError, KeyError):
|
||||
body = resp.text[:500]
|
||||
|
||||
raise RuntimeError(f"Brave Search API error (HTTP {status}): {body}")
|
||||
|
||||
|
||||
def _is_retryable(resp: requests.Response) -> bool:
|
||||
"""Return True for transient failures that are worth retrying.
|
||||
|
||||
* 429 + RATE_LIMITED — the per-second sliding window is full.
|
||||
* 5xx — transient server-side errors.
|
||||
|
||||
Quota exhaustion (QUOTA_LIMITED, USAGE_LIMIT_EXCEEDED) is
|
||||
explicitly excluded: retrying will never succeed until the billing
|
||||
period resets.
|
||||
"""
|
||||
if resp.status_code == 429:
|
||||
error = _parse_error_body(resp) or {}
|
||||
return error.get("code") not in _QUOTA_CODES
|
||||
return 500 <= resp.status_code < 600
|
||||
|
||||
|
||||
def _retry_delay(resp: requests.Response, attempt: int) -> float:
|
||||
"""Compute wait time before the next retry attempt.
|
||||
|
||||
Prefers the server-supplied Retry-After header when available;
|
||||
falls back to exponential backoff (1s, 2s, 4s, ...).
|
||||
"""
|
||||
retry_after = resp.headers.get("Retry-After")
|
||||
if retry_after is not None:
|
||||
try:
|
||||
return max(0.0, float(retry_after))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return float(2**attempt)
|
||||
|
||||
|
||||
class BraveSearchToolBase(BaseTool, ABC):
|
||||
"""
|
||||
Base class for Brave Search API interactions.
|
||||
|
||||
Individual tool subclasses must provide the following:
|
||||
- search_url
|
||||
- header_schema (pydantic model)
|
||||
- args_schema (pydantic model)
|
||||
- _refine_payload() -> dict[str, Any]
|
||||
"""
|
||||
|
||||
search_url: str
|
||||
raw: bool = False
|
||||
args_schema: type[BaseModel]
|
||||
header_schema: type[BaseModel]
|
||||
|
||||
# Tool options (legacy parameters)
|
||||
country: str | None = None
|
||||
save_file: bool = False
|
||||
n_results: int = 10
|
||||
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
name="BRAVE_API_KEY",
|
||||
description="API key for Brave Search",
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
headers: dict[str, Any] | None = None,
|
||||
requests_per_second: float = 1.0,
|
||||
save_file: bool = False,
|
||||
raw: bool = False,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key or os.environ.get("BRAVE_API_KEY")
|
||||
if not self._api_key:
|
||||
raise ValueError("BRAVE_API_KEY environment variable is required")
|
||||
|
||||
self.raw = bool(raw)
|
||||
self._timeout = int(timeout)
|
||||
self.save_file = bool(save_file)
|
||||
self._requests_per_second = float(requests_per_second)
|
||||
self._headers = self._build_and_validate_headers(headers or {})
|
||||
# Per-instance rate limiting: each instance has its own clock and lock.
|
||||
# Total process rate is the sum of limits of instances you create.
|
||||
self._last_request_time: float = 0
|
||||
self._rate_limit_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def api_key(self) -> str:
|
||||
return self._api_key
|
||||
|
||||
@property
|
||||
def headers(self) -> dict[str, Any]:
|
||||
return self._headers
|
||||
|
||||
def set_headers(self, headers: dict[str, Any]) -> BraveSearchToolBase:
|
||||
merged = {**self._headers, **{k.lower(): v for k, v in headers.items()}}
|
||||
self._headers = self._build_and_validate_headers(merged)
|
||||
return self
|
||||
|
||||
def _build_and_validate_headers(self, headers: dict[str, Any]) -> dict[str, Any]:
|
||||
normalized = {k.lower(): v for k, v in headers.items()}
|
||||
normalized.setdefault("x-subscription-token", self._api_key)
|
||||
normalized.setdefault("accept", "application/json")
|
||||
|
||||
try:
|
||||
self.header_schema(**normalized)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid headers: {e}") from e
|
||||
|
||||
return normalized
|
||||
|
||||
def _rate_limit(self) -> None:
|
||||
"""Enforce minimum interval between requests for this instance. Thread-safe."""
|
||||
if self._requests_per_second <= 0:
|
||||
return
|
||||
|
||||
min_interval = 1.0 / self._requests_per_second
|
||||
with self._rate_limit_lock:
|
||||
now = time.time()
|
||||
next_allowed = self._last_request_time + min_interval
|
||||
if now < next_allowed:
|
||||
time.sleep(next_allowed - now)
|
||||
now = time.time()
|
||||
self._last_request_time = now
|
||||
|
||||
def _make_request(
|
||||
self, params: dict[str, Any], *, _max_retries: int = 3
|
||||
) -> dict[str, Any]:
|
||||
"""Execute an HTTP GET against the Brave Search API with retry logic."""
|
||||
last_resp: requests.Response | None = None
|
||||
|
||||
# Retry the request up to _max_retries times
|
||||
for attempt in range(_max_retries):
|
||||
self._rate_limit()
|
||||
|
||||
# Make the request
|
||||
try:
|
||||
resp = requests.get(
|
||||
self.search_url,
|
||||
headers=self._headers,
|
||||
params=params,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except requests.ConnectionError as exc:
|
||||
raise RuntimeError(
|
||||
f"Brave Search API connection failed: {exc}"
|
||||
) from exc
|
||||
except requests.Timeout as exc:
|
||||
raise RuntimeError(
|
||||
f"Brave Search API request timed out after {self._timeout}s: {exc}"
|
||||
) from exc
|
||||
|
||||
# Log the rate limit headers and request details
|
||||
logger.debug(
|
||||
"Brave Search API request: %s %s -> %d",
|
||||
"GET",
|
||||
resp.url,
|
||||
resp.status_code,
|
||||
)
|
||||
|
||||
# Response was OK, return the JSON body
|
||||
if resp.ok:
|
||||
try:
|
||||
return resp.json()
|
||||
except ValueError as exc:
|
||||
raise RuntimeError(
|
||||
f"Brave Search API returned invalid JSON (HTTP {resp.status_code}): {exc}"
|
||||
) from exc
|
||||
|
||||
# Response was not OK, but is retryable
|
||||
# (e.g., 429 Too Many Requests, 500 Internal Server Error)
|
||||
if _is_retryable(resp) and attempt < _max_retries - 1:
|
||||
delay = _retry_delay(resp, attempt)
|
||||
logger.warning(
|
||||
"Brave Search API returned %d. Retrying in %.1fs (attempt %d/%d)",
|
||||
resp.status_code,
|
||||
delay,
|
||||
attempt + 1,
|
||||
_max_retries,
|
||||
)
|
||||
time.sleep(delay)
|
||||
last_resp = resp
|
||||
continue
|
||||
|
||||
# Response was not OK, nor was it retryable
|
||||
# (e.g., 422 Unprocessable Entity, 400 Bad Request (OPTION_NOT_IN_PLAN))
|
||||
_raise_for_error(resp)
|
||||
|
||||
# All retries exhausted
|
||||
_raise_for_error(last_resp or resp) # type: ignore[possibly-undefined]
|
||||
return {} # unreachable (here to satisfy the type checker and linter)
|
||||
|
||||
def _run(self, q: str | None = None, **params: Any) -> Any:
|
||||
# Allow positional usage: tool.run("latest Brave browser features")
|
||||
if q is not None:
|
||||
params["q"] = q
|
||||
|
||||
params = self._common_payload_refinement(params)
|
||||
|
||||
# Validate only schema fields
|
||||
schema_keys = self.args_schema.model_fields
|
||||
payload_in = {k: v for k, v in params.items() if k in schema_keys}
|
||||
|
||||
try:
|
||||
validated = self.args_schema(**payload_in)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid parameters: {e}") from e
|
||||
|
||||
# The subclass may have additional refinements to apply to the payload, such as goggles or other parameters
|
||||
payload = self._refine_request_payload(validated.model_dump(exclude_none=True))
|
||||
response = self._make_request(payload)
|
||||
|
||||
if not self.raw:
|
||||
response = self._refine_response(response)
|
||||
|
||||
if self.save_file:
|
||||
_save_results_to_file(json.dumps(response, indent=2))
|
||||
|
||||
return response
|
||||
|
||||
@abstractmethod
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Subclass must implement: transform validated params dict into API request params."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _refine_response(self, response: dict[str, Any]) -> Any:
|
||||
"""Subclass must implement: transform response dict into a more useful format."""
|
||||
raise NotImplementedError
|
||||
|
||||
_EMPTY_VALUES: ClassVar[tuple[None, str, str, list[Any]]] = (None, "", "null", [])
|
||||
|
||||
def _common_payload_refinement(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Common payload refinement for all tools."""
|
||||
# crewAI's schema pipeline (ensure_all_properties_required in
|
||||
# pydantic_schema_utils.py) marks every property as required so
|
||||
# that OpenAI strict-mode structured outputs work correctly.
|
||||
# The side-effect is that the LLM fills in *every* parameter —
|
||||
# even truly optional ones — using placeholder values such as
|
||||
# None, "", "null", or []. Only optional fields are affected,
|
||||
# so we limit the check to those.
|
||||
fields = self.args_schema.model_fields
|
||||
params = {
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
# Permit custom and required fields, and fields with non-empty values
|
||||
if k not in fields or fields[k].is_required() or v not in self._EMPTY_VALUES
|
||||
}
|
||||
|
||||
# Make sure params has "q" for query instead of "query" or "search_query"
|
||||
query = params.get("query") or params.get("search_query")
|
||||
if query is not None and "q" not in params:
|
||||
params["q"] = query
|
||||
params.pop("query", None)
|
||||
params.pop("search_query", None)
|
||||
|
||||
# If "count" was not explicitly provided, use n_results
|
||||
# (only when the schema actually supports a "count" field)
|
||||
if "count" in self.args_schema.model_fields:
|
||||
if "count" not in params and self.n_results is not None:
|
||||
params["count"] = self.n_results
|
||||
|
||||
# If "country" was not explicitly provided, but self.country is set, use it
|
||||
# (only when the schema actually supports a "country" field)
|
||||
if "country" in self.args_schema.model_fields:
|
||||
if "country" not in params and self.country is not None:
|
||||
params["country"] = self.country
|
||||
|
||||
return params
|
||||
@@ -0,0 +1,42 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
ImageSearchHeaders,
|
||||
ImageSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveImageSearchTool(BraveSearchToolBase):
|
||||
"""A tool that performs image searches using the Brave Search API."""
|
||||
|
||||
name: str = "Brave Image Search"
|
||||
args_schema: type[BaseModel] = ImageSearchParams
|
||||
header_schema: type[BaseModel] = ImageSearchHeaders
|
||||
|
||||
description: str = (
|
||||
"A tool that performs image searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
|
||||
search_url: str = "https://api.search.brave.com/res/v1/images/search"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
# Make the response more concise, and easier to consume
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
"title": result.get("title"),
|
||||
"url": result.get("properties", {}).get("url"),
|
||||
"dimensions": f"{w}x{h}"
|
||||
if (w := result.get("properties", {}).get("width"))
|
||||
and (h := result.get("properties", {}).get("height"))
|
||||
else None,
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.response_types import LLMContext
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
LLMContextHeaders,
|
||||
LLMContextParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveLLMContextTool(BraveSearchToolBase):
|
||||
"""A tool that retrieves context for LLM usage from the Brave Search API."""
|
||||
|
||||
name: str = "Brave LLM Context"
|
||||
args_schema: type[BaseModel] = LLMContextParams
|
||||
header_schema: type[BaseModel] = LLMContextHeaders
|
||||
|
||||
description: str = (
|
||||
"A tool that retrieves context for LLM usage from the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
|
||||
search_url: str = "https://api.search.brave.com/res/v1/llm/context"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: LLMContext.Response) -> LLMContext.Response:
|
||||
"""The LLM Context response schema is fairly simple. Return as is."""
|
||||
return response
|
||||
@@ -0,0 +1,109 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.response_types import LocalPOIs
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
LocalPOIsDescriptionHeaders,
|
||||
LocalPOIsDescriptionParams,
|
||||
LocalPOIsHeaders,
|
||||
LocalPOIsParams,
|
||||
)
|
||||
|
||||
|
||||
DayOpeningHours = LocalPOIs.DayOpeningHours
|
||||
OpeningHours = LocalPOIs.OpeningHours
|
||||
LocationResult = LocalPOIs.LocationResult
|
||||
LocalPOIsResponse = LocalPOIs.Response
|
||||
|
||||
|
||||
def _flatten_slots(slots: list[DayOpeningHours]) -> list[dict[str, str]]:
|
||||
"""Convert a list of DayOpeningHours dicts into simplified entries."""
|
||||
return [
|
||||
{
|
||||
"day": slot["full_name"].lower(),
|
||||
"opens": slot["opens"],
|
||||
"closes": slot["closes"],
|
||||
}
|
||||
for slot in slots
|
||||
]
|
||||
|
||||
|
||||
def _simplify_opening_hours(result: LocationResult) -> list[dict[str, str]] | None:
|
||||
"""Collapse opening_hours into a flat list of {day, opens, closes} dicts."""
|
||||
hours = result.get("opening_hours")
|
||||
if not hours:
|
||||
return None
|
||||
|
||||
entries: list[dict[str, str]] = []
|
||||
|
||||
current = hours.get("current_day")
|
||||
if current:
|
||||
entries.extend(_flatten_slots(current))
|
||||
|
||||
days = hours.get("days")
|
||||
if days:
|
||||
for day_slots in days:
|
||||
entries.extend(_flatten_slots(day_slots))
|
||||
|
||||
return entries or None
|
||||
|
||||
|
||||
class BraveLocalPOIsTool(BraveSearchToolBase):
|
||||
"""A tool that retrieves local POIs using the Brave Search API."""
|
||||
|
||||
name: str = "Brave Local POIs"
|
||||
args_schema: type[BaseModel] = LocalPOIsParams
|
||||
header_schema: type[BaseModel] = LocalPOIsHeaders
|
||||
description: str = (
|
||||
"A tool that retrieves local POIs using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
search_url: str = "https://api.search.brave.com/res/v1/local/pois"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
"title": result.get("title"),
|
||||
"url": result.get("url"),
|
||||
"description": result.get("description"),
|
||||
"address": result.get("postal_address", {}).get("displayAddress"),
|
||||
"contact": result.get("contact", {}).get("telephone")
|
||||
or result.get("contact", {}).get("email")
|
||||
or None,
|
||||
"opening_hours": _simplify_opening_hours(result),
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
|
||||
|
||||
class BraveLocalPOIsDescriptionTool(BraveSearchToolBase):
|
||||
"""A tool that retrieves AI-generated descriptions for local POIs using the Brave Search API."""
|
||||
|
||||
name: str = "Brave Local POI Descriptions"
|
||||
args_schema: type[BaseModel] = LocalPOIsDescriptionParams
|
||||
header_schema: type[BaseModel] = LocalPOIsDescriptionHeaders
|
||||
description: str = (
|
||||
"A tool that retrieves AI-generated descriptions for local POIs using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
search_url: str = "https://api.search.brave.com/res/v1/local/descriptions"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
|
||||
# Make the response more concise, and easier to consume
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
"id": result.get("id"),
|
||||
"description": result.get("description"),
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
@@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
NewsSearchHeaders,
|
||||
NewsSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveNewsSearchTool(BraveSearchToolBase):
|
||||
"""A tool that performs news searches using the Brave Search API."""
|
||||
|
||||
name: str = "Brave News Search"
|
||||
args_schema: type[BaseModel] = NewsSearchParams
|
||||
header_schema: type[BaseModel] = NewsSearchHeaders
|
||||
|
||||
description: str = (
|
||||
"A tool that performs news searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
|
||||
search_url: str = "https://api.search.brave.com/res/v1/news/search"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
# Make the response more concise, and easier to consume
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
"url": result.get("url"),
|
||||
"title": result.get("title"),
|
||||
"description": result.get("description"),
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
@@ -10,17 +10,13 @@ from pydantic import BaseModel, Field
|
||||
from pydantic.types import StringConstraints
|
||||
import requests
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
|
||||
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _save_results_to_file(content: str) -> None:
|
||||
"""Saves the search results to a file."""
|
||||
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
|
||||
with open(filename, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
FreshnessPreset = Literal["pd", "pw", "pm", "py"]
|
||||
FreshnessRange = Annotated[
|
||||
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
|
||||
@@ -29,51 +25,6 @@ Freshness = FreshnessPreset | FreshnessRange
|
||||
SafeSearch = Literal["off", "moderate", "strict"]
|
||||
|
||||
|
||||
class BraveSearchToolSchema(BaseModel):
|
||||
"""Input for BraveSearchTool"""
|
||||
|
||||
query: str = Field(..., description="Search query to perform")
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
)
|
||||
search_language: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return. Actual number may be less.",
|
||||
)
|
||||
offset: int | None = Field(
|
||||
default=None, description="Skip the first N result sets/pages. Max is 9."
|
||||
)
|
||||
safesearch: SafeSearch | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Options: off/moderate/strict",
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
freshness: Freshness | None = Field(
|
||||
default=None,
|
||||
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
|
||||
)
|
||||
text_decorations: bool | None = Field(
|
||||
default=None,
|
||||
description="Include markup to highlight search terms in the results.",
|
||||
)
|
||||
extra_snippets: bool | None = Field(
|
||||
default=None,
|
||||
description="Include up to 5 text snippets for each page if possible.",
|
||||
)
|
||||
operators: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to apply search operators (e.g., site:example.com).",
|
||||
)
|
||||
|
||||
|
||||
# TODO: Extend support to additional endpoints (e.g., /images, /news, etc.)
|
||||
class BraveSearchTool(BaseTool):
|
||||
"""A tool that performs web searches using the Brave Search API."""
|
||||
@@ -83,7 +34,7 @@ class BraveSearchTool(BaseTool):
|
||||
"A tool that performs web searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
args_schema: type[BaseModel] = BraveSearchToolSchema
|
||||
args_schema: type[BaseModel] = WebSearchParams
|
||||
search_url: str = "https://api.search.brave.com/res/v1/web/search"
|
||||
n_results: int = 10
|
||||
save_file: bool = False
|
||||
@@ -120,8 +71,8 @@ class BraveSearchTool(BaseTool):
|
||||
|
||||
# Construct and send the request
|
||||
try:
|
||||
# Maintain both "search_query" and "query" for backwards compatibility
|
||||
query = kwargs.get("search_query") or kwargs.get("query")
|
||||
# Fallback to "query" or "search_query" for backwards compatibility
|
||||
query = kwargs.get("q") or kwargs.get("query") or kwargs.get("search_query")
|
||||
if not query:
|
||||
raise ValueError("Query is required")
|
||||
|
||||
@@ -130,8 +81,11 @@ class BraveSearchTool(BaseTool):
|
||||
if country := kwargs.get("country"):
|
||||
payload["country"] = country
|
||||
|
||||
if search_language := kwargs.get("search_language"):
|
||||
payload["search_language"] = search_language
|
||||
# Fallback to "search_language" for backwards compatibility
|
||||
if search_lang := kwargs.get("search_lang") or kwargs.get(
|
||||
"search_language"
|
||||
):
|
||||
payload["search_lang"] = search_lang
|
||||
|
||||
# Fallback to deprecated n_results parameter if no count is provided
|
||||
count = kwargs.get("count")
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
VideoSearchHeaders,
|
||||
VideoSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveVideoSearchTool(BraveSearchToolBase):
|
||||
"""A tool that performs video searches using the Brave Search API."""
|
||||
|
||||
name: str = "Brave Video Search"
|
||||
args_schema: type[BaseModel] = VideoSearchParams
|
||||
header_schema: type[BaseModel] = VideoSearchHeaders
|
||||
|
||||
description: str = (
|
||||
"A tool that performs video searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
|
||||
search_url: str = "https://api.search.brave.com/res/v1/videos/search"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
# Make the response more concise, and easier to consume
|
||||
results = response.get("results", [])
|
||||
return [
|
||||
{
|
||||
"url": result.get("url"),
|
||||
"title": result.get("title"),
|
||||
"description": result.get("description"),
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
WebSearchHeaders,
|
||||
WebSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveWebSearchTool(BraveSearchToolBase):
|
||||
"""A tool that performs web searches using the Brave Search API."""
|
||||
|
||||
name: str = "Brave Web Search"
|
||||
args_schema: type[BaseModel] = WebSearchParams
|
||||
header_schema: type[BaseModel] = WebSearchHeaders
|
||||
|
||||
description: str = (
|
||||
"A tool that performs web searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
|
||||
search_url: str = "https://api.search.brave.com/res/v1/web/search"
|
||||
|
||||
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return params
|
||||
|
||||
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
results = response.get("web", {}).get("results", [])
|
||||
refined = []
|
||||
for result in results:
|
||||
snippets = result.get("extra_snippets") or []
|
||||
if not snippets:
|
||||
desc = result.get("description")
|
||||
if desc:
|
||||
snippets = [desc]
|
||||
refined.append(
|
||||
{
|
||||
"url": result.get("url"),
|
||||
"title": result.get("title"),
|
||||
"snippets": snippets,
|
||||
}
|
||||
)
|
||||
return refined
|
||||
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
|
||||
class LocalPOIs:
|
||||
class PostalAddress(TypedDict, total=False):
|
||||
type: Literal["PostalAddress"]
|
||||
country: str
|
||||
postalCode: str
|
||||
streetAddress: str
|
||||
addressRegion: str
|
||||
addressLocality: str
|
||||
displayAddress: str
|
||||
|
||||
class DayOpeningHours(TypedDict):
|
||||
abbr_name: str
|
||||
full_name: str
|
||||
opens: str
|
||||
closes: str
|
||||
|
||||
class OpeningHours(TypedDict, total=False):
|
||||
current_day: list[LocalPOIs.DayOpeningHours]
|
||||
days: list[list[LocalPOIs.DayOpeningHours]]
|
||||
|
||||
class LocationResult(TypedDict, total=False):
|
||||
provider_url: str
|
||||
title: str
|
||||
url: str
|
||||
id: str | None
|
||||
opening_hours: LocalPOIs.OpeningHours | None
|
||||
postal_address: LocalPOIs.PostalAddress | None
|
||||
|
||||
class Response(TypedDict, total=False):
|
||||
type: Literal["local_pois"]
|
||||
results: list[LocalPOIs.LocationResult]
|
||||
|
||||
|
||||
class LLMContext:
|
||||
class LLMContextItem(TypedDict, total=False):
|
||||
snippets: list[str]
|
||||
title: str
|
||||
url: str
|
||||
|
||||
class LLMContextMapItem(TypedDict, total=False):
|
||||
name: str
|
||||
snippets: list[str]
|
||||
title: str
|
||||
url: str
|
||||
|
||||
class LLMContextPOIItem(TypedDict, total=False):
|
||||
name: str
|
||||
snippets: list[str]
|
||||
title: str
|
||||
url: str
|
||||
|
||||
class Grounding(TypedDict, total=False):
|
||||
generic: list[LLMContext.LLMContextItem]
|
||||
poi: LLMContext.LLMContextPOIItem
|
||||
map: list[LLMContext.LLMContextMapItem]
|
||||
|
||||
class Sources(TypedDict, total=False):
|
||||
pass
|
||||
|
||||
class Response(TypedDict, total=False):
|
||||
grounding: LLMContext.Grounding
|
||||
sources: LLMContext.Sources
|
||||
@@ -0,0 +1,525 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.types import StringConstraints
|
||||
|
||||
|
||||
# Common types
|
||||
Units = Literal["metric", "imperial"]
|
||||
SafeSearch = Literal["off", "moderate", "strict"]
|
||||
Freshness = (
|
||||
Literal["pd", "pw", "pm", "py"]
|
||||
| Annotated[
|
||||
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
|
||||
]
|
||||
)
|
||||
ResultFilter = list[
|
||||
Literal[
|
||||
"discussions",
|
||||
"faq",
|
||||
"infobox",
|
||||
"news",
|
||||
"query",
|
||||
"summarizer",
|
||||
"videos",
|
||||
"web",
|
||||
"locations",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class LLMContextParams(BaseModel):
|
||||
"""Parameters for Brave LLM Context endpoint."""
|
||||
|
||||
q: str = Field(
|
||||
description="Search query to perform",
|
||||
min_length=1,
|
||||
max_length=400,
|
||||
)
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
pattern=r"^[A-Z]{2}$",
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return. Actual number may be less.",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
maximum_number_of_urls: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of URLs to include in the context.",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
maximum_number_of_tokens: int | None = Field(
|
||||
default=None,
|
||||
description="The approximate maximum number of tokens to include in the context.",
|
||||
ge=1,
|
||||
le=32768,
|
||||
)
|
||||
maximum_number_of_snippets: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of different snippets to include in the context.",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
context_threshold_mode: (
|
||||
Literal["disabled", "strict", "lenient", "balanced"] | None
|
||||
) = Field(
|
||||
default=None,
|
||||
description="The mode to use for the context thresholding.",
|
||||
)
|
||||
maximum_number_of_tokens_per_url: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of tokens to include for each URL in the context.",
|
||||
ge=1,
|
||||
le=8192,
|
||||
)
|
||||
maximum_number_of_snippets_per_url: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of snippets to include per URL.",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
goggles: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
|
||||
)
|
||||
enable_local: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to enable local recall. Not setting this value means auto-detect and uses local recall if any of the localization headers are provided.",
|
||||
)
|
||||
|
||||
|
||||
class WebSearchParams(BaseModel):
|
||||
"""Parameters for Brave Web Search endpoint."""
|
||||
|
||||
q: str = Field(
|
||||
description="Search query to perform",
|
||||
min_length=1,
|
||||
max_length=400,
|
||||
)
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
pattern=r"^[A-Z]{2}$",
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
ui_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
|
||||
pattern=r"^[a-z]{2}-[A-Z]{2}$",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return. Actual number may be less.",
|
||||
ge=1,
|
||||
le=20,
|
||||
)
|
||||
offset: int | None = Field(
|
||||
default=None,
|
||||
description="Skip the first N result sets/pages. Max is 9.",
|
||||
ge=0,
|
||||
le=9,
|
||||
)
|
||||
safesearch: Literal["off", "moderate", "strict"] | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Options: off/moderate/strict",
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
freshness: Freshness | None = Field(
|
||||
default=None,
|
||||
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
|
||||
)
|
||||
text_decorations: bool | None = Field(
|
||||
default=None,
|
||||
description="Include markup to highlight search terms in the results.",
|
||||
)
|
||||
extra_snippets: bool | None = Field(
|
||||
default=None,
|
||||
description="Include up to 5 text snippets for each page if possible.",
|
||||
)
|
||||
result_filter: ResultFilter | None = Field(
|
||||
default=None,
|
||||
description="Filter the results by type. Options: discussions/faq/infobox/news/query/summarizer/videos/web/locations. Note: The `count` parameter is applied only to the `web` results.",
|
||||
)
|
||||
units: Units | None = Field(
|
||||
default=None,
|
||||
description="The units to use for the results. Options: metric/imperial",
|
||||
)
|
||||
goggles: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
|
||||
)
|
||||
summary: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to generate a summarizer ID for the results.",
|
||||
)
|
||||
enable_rich_callback: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to enable rich callbacks for the results. Requires Pro level subscription.",
|
||||
)
|
||||
include_fetch_metadata: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to include fetch metadata (e.g., last fetch time) in the results.",
|
||||
)
|
||||
operators: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to apply search operators (e.g., site:example.com).",
|
||||
)
|
||||
|
||||
|
||||
class LocalPOIsParams(BaseModel):
|
||||
"""Parameters for Brave Local POIs endpoint."""
|
||||
|
||||
ids: list[str] = Field(
|
||||
description="List of POI IDs to retrieve. Maximum of 20. IDs are valid for 8 hours.",
|
||||
min_length=1,
|
||||
max_length=20,
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
ui_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
|
||||
pattern=r"^[a-z]{2}-[A-Z]{2}$",
|
||||
)
|
||||
units: Units | None = Field(
|
||||
default=None,
|
||||
description="The units to use for the results. Options: metric/imperial",
|
||||
)
|
||||
|
||||
|
||||
class LocalPOIsDescriptionParams(BaseModel):
|
||||
"""Parameters for Brave Local POI Descriptions endpoint."""
|
||||
|
||||
ids: list[str] = Field(
|
||||
description="List of POI IDs to retrieve. Maximum of 20. IDs are valid for 8 hours.",
|
||||
min_length=1,
|
||||
max_length=20,
|
||||
)
|
||||
|
||||
|
||||
class ImageSearchParams(BaseModel):
|
||||
"""Parameters for Brave Image Search endpoint."""
|
||||
|
||||
q: str = Field(
|
||||
description="Search query to perform",
|
||||
min_length=1,
|
||||
max_length=400,
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
pattern=r"^[A-Z]{2}$",
|
||||
)
|
||||
safesearch: Literal["off", "strict"] | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Default is strict.",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return.",
|
||||
ge=1,
|
||||
le=200,
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
|
||||
|
||||
class VideoSearchParams(BaseModel):
|
||||
"""Parameters for Brave Video Search endpoint."""
|
||||
|
||||
q: str = Field(
|
||||
description="Search query to perform",
|
||||
min_length=1,
|
||||
max_length=400,
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
ui_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
|
||||
pattern=r"^[a-z]{2}-[A-Z]{2}$",
|
||||
)
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
pattern=r"^[A-Z]{2}$",
|
||||
)
|
||||
safesearch: SafeSearch | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Options: off/moderate/strict",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return.",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
offset: int | None = Field(
|
||||
default=None,
|
||||
description="Skip the first N result sets/pages. Max is 9.",
|
||||
ge=0,
|
||||
le=9,
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
freshness: Freshness | None = Field(
|
||||
default=None,
|
||||
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
|
||||
)
|
||||
include_fetch_metadata: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to include fetch metadata (e.g., last fetch time) in the results.",
|
||||
)
|
||||
operators: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to apply search operators (e.g., site:example.com).",
|
||||
)
|
||||
|
||||
|
||||
class NewsSearchParams(BaseModel):
|
||||
"""Parameters for Brave News Search endpoint."""
|
||||
|
||||
q: str = Field(
|
||||
description="Search query to perform",
|
||||
min_length=1,
|
||||
max_length=400,
|
||||
)
|
||||
search_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
pattern=r"^[a-z]{2}$",
|
||||
)
|
||||
ui_lang: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
|
||||
pattern=r"^[a-z]{2}-[A-Z]{2}$",
|
||||
)
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
pattern=r"^[A-Z]{2}$",
|
||||
)
|
||||
safesearch: Literal["off", "moderate", "strict"] | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Options: off/moderate/strict",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return.",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
offset: int | None = Field(
|
||||
default=None,
|
||||
description="Skip the first N result sets/pages. Max is 9.",
|
||||
ge=0,
|
||||
le=9,
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
freshness: Freshness | None = Field(
|
||||
default=None,
|
||||
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
|
||||
)
|
||||
extra_snippets: bool | None = Field(
|
||||
default=None,
|
||||
description="Include up to 5 text snippets for each page if possible.",
|
||||
)
|
||||
goggles: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
|
||||
)
|
||||
include_fetch_metadata: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to include fetch metadata in the results.",
|
||||
)
|
||||
operators: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to apply search operators (e.g., site:example.com).",
|
||||
)
|
||||
|
||||
|
||||
class BaseSearchHeaders(BaseModel):
|
||||
"""Common headers for Brave Search endpoints."""
|
||||
|
||||
x_subscription_token: str = Field(
|
||||
alias="x-subscription-token",
|
||||
description="API key for Brave Search",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
alias="api-version",
|
||||
default=None,
|
||||
description="API version to use. Default is latest available.",
|
||||
pattern=r"^\d{4}-\d{2}-\d{2}$", # YYYY-MM-DD
|
||||
)
|
||||
accept: Literal["application/json"] | Literal["*/*"] | None = Field(
|
||||
default=None,
|
||||
description="Accept header for the request.",
|
||||
)
|
||||
cache_control: Literal["no-cache"] | None = Field(
|
||||
alias="cache-control",
|
||||
default=None,
|
||||
description="Cache control header for the request.",
|
||||
)
|
||||
user_agent: str | None = Field(
|
||||
alias="user-agent",
|
||||
default=None,
|
||||
description="User agent for the request.",
|
||||
)
|
||||
|
||||
|
||||
class LLMContextHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave LLM Context endpoint."""
|
||||
|
||||
x_loc_lat: float | None = Field(
|
||||
alias="x-loc-lat",
|
||||
default=None,
|
||||
description="Latitude of the user's location.",
|
||||
ge=-90.0,
|
||||
le=90.0,
|
||||
)
|
||||
x_loc_long: float | None = Field(
|
||||
alias="x-loc-long",
|
||||
default=None,
|
||||
description="Longitude of the user's location.",
|
||||
ge=-180.0,
|
||||
le=180.0,
|
||||
)
|
||||
x_loc_city: str | None = Field(
|
||||
alias="x-loc-city",
|
||||
default=None,
|
||||
description="City of the user's location.",
|
||||
)
|
||||
x_loc_state: str | None = Field(
|
||||
alias="x-loc-state",
|
||||
default=None,
|
||||
description="State of the user's location.",
|
||||
)
|
||||
x_loc_state_name: str | None = Field(
|
||||
alias="x-loc-state-name",
|
||||
default=None,
|
||||
description="Name of the state of the user's location.",
|
||||
)
|
||||
x_loc_country: str | None = Field(
|
||||
alias="x-loc-country",
|
||||
default=None,
|
||||
description="The ISO 3166-1 alpha-2 country code of the user's location.",
|
||||
)
|
||||
|
||||
|
||||
class LocalPOIsHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave Local POIs endpoint."""
|
||||
|
||||
x_loc_lat: float | None = Field(
|
||||
alias="x-loc-lat",
|
||||
default=None,
|
||||
description="Latitude of the user's location.",
|
||||
ge=-90.0,
|
||||
le=90.0,
|
||||
)
|
||||
x_loc_long: float | None = Field(
|
||||
alias="x-loc-long",
|
||||
default=None,
|
||||
description="Longitude of the user's location.",
|
||||
ge=-180.0,
|
||||
le=180.0,
|
||||
)
|
||||
|
||||
|
||||
class LocalPOIsDescriptionHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave Local POI Descriptions endpoint."""
|
||||
|
||||
|
||||
class VideoSearchHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave Video Search endpoint."""
|
||||
|
||||
|
||||
class ImageSearchHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave Image Search endpoint."""
|
||||
|
||||
|
||||
class NewsSearchHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave News Search endpoint."""
|
||||
|
||||
|
||||
class WebSearchHeaders(BaseSearchHeaders):
|
||||
"""Headers for Brave Web Search endpoint."""
|
||||
|
||||
x_loc_lat: float | None = Field(
|
||||
alias="x-loc-lat",
|
||||
default=None,
|
||||
description="Latitude of the user's location.",
|
||||
ge=-90.0,
|
||||
le=90.0,
|
||||
)
|
||||
x_loc_long: float | None = Field(
|
||||
alias="x-loc-long",
|
||||
default=None,
|
||||
description="Longitude of the user's location.",
|
||||
ge=-180.0,
|
||||
le=180.0,
|
||||
)
|
||||
x_loc_timezone: str | None = Field(
|
||||
alias="x-loc-timezone",
|
||||
default=None,
|
||||
description="Timezone of the user's location.",
|
||||
)
|
||||
x_loc_city: str | None = Field(
|
||||
alias="x-loc-city",
|
||||
default=None,
|
||||
description="City of the user's location.",
|
||||
)
|
||||
x_loc_state: str | None = Field(
|
||||
alias="x-loc-state",
|
||||
default=None,
|
||||
description="State of the user's location.",
|
||||
)
|
||||
x_loc_state_name: str | None = Field(
|
||||
alias="x-loc-state-name",
|
||||
default=None,
|
||||
description="Name of the state of the user's location.",
|
||||
)
|
||||
x_loc_country: str | None = Field(
|
||||
alias="x-loc-country",
|
||||
default=None,
|
||||
description="The ISO 3166-1 alpha-2 country code of the user's location.",
|
||||
)
|
||||
x_loc_postal_code: str | None = Field(
|
||||
alias="x-loc-postal-code",
|
||||
default=None,
|
||||
description="The postal code of the user's location.",
|
||||
)
|
||||
@@ -1,27 +1,13 @@
|
||||
# CodeInterpreterTool
|
||||
|
||||
## Description
|
||||
This tool is used to give the Agent the ability to run code (Python3) from the code generated by the Agent itself. The code is executed in a Docker container for secure isolation.
|
||||
This tool is used to give the Agent the ability to run code (Python3) from the code generated by the Agent itself. The code is executed in a sandboxed environment, so it is safe to run any code.
|
||||
|
||||
It is incredibly useful since it allows the Agent to generate code, run it in an isolated environment, get the result and use it to make decisions.
|
||||
|
||||
## ⚠️ Security Requirements
|
||||
|
||||
**Docker is REQUIRED** for safe code execution. The tool will refuse to execute code without Docker to prevent security vulnerabilities.
|
||||
|
||||
### Why Docker is Required
|
||||
|
||||
Previous versions included a "restricted sandbox" fallback when Docker was unavailable. This has been **removed** due to critical security vulnerabilities:
|
||||
|
||||
- The Python-based sandbox could be escaped via object introspection
|
||||
- Attackers could recover the original `__import__` function and access any module
|
||||
- This allowed arbitrary command execution on the host system
|
||||
|
||||
**Docker provides real process isolation** and is the only secure way to execute untrusted code.
|
||||
It is incredible useful since it allows the Agent to generate code, run it in the same environment, get the result and use it to make decisions.
|
||||
|
||||
## Requirements
|
||||
|
||||
- **Docker (REQUIRED)** - Install from [docker.com](https://docs.docker.com/get-docker/)
|
||||
- Docker
|
||||
|
||||
## Installation
|
||||
Install the crewai_tools package
|
||||
@@ -31,9 +17,7 @@ pip install 'crewai[tools]'
|
||||
|
||||
## Example
|
||||
|
||||
Remember that when using this tool, the code must be generated by the Agent itself. The code must be Python3 code. It will take some time the first time to run because it needs to build the Docker image.
|
||||
|
||||
### Basic Usage (Docker Container - Recommended)
|
||||
Remember that when using this tool, the code must be generated by the Agent itself. The code must be a Python3 code. And it will take some time for the first time to run because it needs to build the Docker image.
|
||||
|
||||
```python
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
@@ -44,9 +28,7 @@ Agent(
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Dockerfile
|
||||
|
||||
If you need to pass your own Dockerfile:
|
||||
Or if you need to pass your own Dockerfile just do this
|
||||
|
||||
```python
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
@@ -57,39 +39,15 @@ Agent(
|
||||
)
|
||||
```
|
||||
|
||||
### Manual Docker Host Configuration
|
||||
|
||||
If it is difficult to connect to the Docker daemon automatically (especially for macOS users), you can set up the Docker host manually:
|
||||
If it is difficult to connect to docker daemon automatically (especially for macOS users), you can do this to setup docker host manually
|
||||
|
||||
```python
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
|
||||
Agent(
|
||||
...
|
||||
tools=[CodeInterpreterTool(
|
||||
user_docker_base_url="<Docker Host Base Url>",
|
||||
user_dockerfile_path="<Dockerfile_path>"
|
||||
)],
|
||||
tools=[CodeInterpreterTool(user_docker_base_url="<Docker Host Base Url>",
|
||||
user_dockerfile_path="<Dockerfile_path>")],
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### Unsafe Mode (NOT RECOMMENDED)
|
||||
|
||||
If you absolutely cannot use Docker and **fully trust the code source**, you can use unsafe mode:
|
||||
|
||||
```python
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
|
||||
# WARNING: Only use with fully trusted code!
|
||||
Agent(
|
||||
...
|
||||
tools=[CodeInterpreterTool(unsafe_mode=True)],
|
||||
)
|
||||
```
|
||||
|
||||
**⚠️ SECURITY WARNING:** `unsafe_mode=True` executes code directly on the host without any isolation. Only use this if:
|
||||
- You completely trust the code being executed
|
||||
- You understand the security risks
|
||||
- You cannot install Docker in your environment
|
||||
|
||||
For production use, **always use Docker** (the default mode).
|
||||
|
||||
@@ -50,16 +50,11 @@ class CodeInterpreterSchema(BaseModel):
|
||||
|
||||
|
||||
class SandboxPython:
|
||||
"""INSECURE: A restricted Python execution environment with known vulnerabilities.
|
||||
"""A restricted Python execution environment for running code safely.
|
||||
|
||||
WARNING: This class does NOT provide real security isolation and is vulnerable to
|
||||
sandbox escape attacks via Python object introspection. Attackers can recover the
|
||||
original __import__ function and bypass all restrictions.
|
||||
|
||||
DO NOT USE for untrusted code execution. Use Docker containers instead.
|
||||
|
||||
This class attempts to restrict access to dangerous modules and built-in functions
|
||||
but provides no real security boundary against a motivated attacker.
|
||||
This class provides methods to safely execute Python code by restricting access to
|
||||
potentially dangerous modules and built-in functions. It creates a sandboxed
|
||||
environment where harmful operations are blocked.
|
||||
"""
|
||||
|
||||
BLOCKED_MODULES: ClassVar[set[str]] = {
|
||||
@@ -304,8 +299,8 @@ class CodeInterpreterTool(BaseTool):
|
||||
def run_code_safety(self, code: str, libraries_used: list[str]) -> str:
|
||||
"""Runs code in the safest available environment.
|
||||
|
||||
Requires Docker to be available for secure code execution. Fails closed
|
||||
if Docker is not available to prevent sandbox escape vulnerabilities.
|
||||
Attempts to run code in Docker if available, falls back to a restricted
|
||||
sandbox if Docker is not available.
|
||||
|
||||
Args:
|
||||
code: The Python code to execute as a string.
|
||||
@@ -313,24 +308,10 @@ class CodeInterpreterTool(BaseTool):
|
||||
|
||||
Returns:
|
||||
The output of the executed code as a string.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Docker is not available, as the restricted sandbox
|
||||
is vulnerable to escape attacks and should not be used
|
||||
for untrusted code execution.
|
||||
"""
|
||||
if self._check_docker_available():
|
||||
return self.run_code_in_docker(code, libraries_used)
|
||||
|
||||
error_msg = (
|
||||
"Docker is required for safe code execution but is not available. "
|
||||
"The restricted sandbox fallback has been removed due to security vulnerabilities "
|
||||
"that allow sandbox escape via Python object introspection. "
|
||||
"Please install Docker (https://docs.docker.com/get-docker/) or use unsafe_mode=True "
|
||||
"if you trust the code source and understand the security risks."
|
||||
)
|
||||
Printer.print(error_msg, color="bold_red")
|
||||
raise RuntimeError(error_msg)
|
||||
return self.run_code_in_restricted_sandbox(code)
|
||||
|
||||
def run_code_in_docker(self, code: str, libraries_used: list[str]) -> str:
|
||||
"""Runs Python code in a Docker container for safe isolation.
|
||||
@@ -361,19 +342,10 @@ class CodeInterpreterTool(BaseTool):
|
||||
|
||||
@staticmethod
|
||||
def run_code_in_restricted_sandbox(code: str) -> str:
|
||||
"""DEPRECATED AND INSECURE: Runs Python code in a restricted sandbox environment.
|
||||
"""Runs Python code in a restricted sandbox environment.
|
||||
|
||||
WARNING: This method is vulnerable to sandbox escape attacks via Python object
|
||||
introspection and should NOT be used for untrusted code execution. It has been
|
||||
deprecated and is only kept for backward compatibility with trusted code.
|
||||
|
||||
The "restricted" environment can be bypassed by attackers who can:
|
||||
- Use object graph introspection to recover the original __import__ function
|
||||
- Access any Python module including os, subprocess, sys, etc.
|
||||
- Execute arbitrary commands on the host system
|
||||
|
||||
Use run_code_in_docker() for secure code execution, or run_code_unsafe()
|
||||
if you explicitly acknowledge the security risks.
|
||||
Executes the code with restricted access to potentially dangerous modules and
|
||||
built-in functions for basic safety when Docker is not available.
|
||||
|
||||
Args:
|
||||
code: The Python code to execute as a string.
|
||||
@@ -382,10 +354,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
The value of the 'result' variable from the executed code,
|
||||
or an error message if execution failed.
|
||||
"""
|
||||
Printer.print(
|
||||
"WARNING: Running code in INSECURE restricted sandbox (vulnerable to escape attacks)",
|
||||
color="bold_red"
|
||||
)
|
||||
Printer.print("Running code in restricted sandbox", color="yellow")
|
||||
exec_locals: dict[str, Any] = {}
|
||||
try:
|
||||
SandboxPython.exec(code=code, locals_=exec_locals)
|
||||
|
||||
@@ -1,80 +1,777 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests as requests_lib
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
|
||||
BraveLLMContextTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
|
||||
BraveLocalPOIsTool,
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
WebSearchParams,
|
||||
WebSearchHeaders,
|
||||
ImageSearchParams,
|
||||
ImageSearchHeaders,
|
||||
NewsSearchParams,
|
||||
NewsSearchHeaders,
|
||||
VideoSearchParams,
|
||||
VideoSearchHeaders,
|
||||
LLMContextParams,
|
||||
LLMContextHeaders,
|
||||
LocalPOIsParams,
|
||||
LocalPOIsHeaders,
|
||||
LocalPOIsDescriptionParams,
|
||||
LocalPOIsDescriptionHeaders,
|
||||
)
|
||||
|
||||
|
||||
def _mock_response(
|
||||
status_code: int = 200,
|
||||
json_data: dict | None = None,
|
||||
headers: dict | None = None,
|
||||
text: str = "",
|
||||
) -> MagicMock:
|
||||
"""Build a ``requests.Response``-like mock with the attributes used by ``_make_request``."""
|
||||
resp = MagicMock(spec=requests_lib.Response)
|
||||
resp.status_code = status_code
|
||||
resp.ok = 200 <= status_code < 400
|
||||
resp.url = "https://api.search.brave.com/res/v1/web/search?q=test"
|
||||
resp.text = text or (str(json_data) if json_data else "")
|
||||
resp.headers = headers or {}
|
||||
resp.json.return_value = json_data if json_data is not None else {}
|
||||
return resp
|
||||
|
||||
|
||||
# Fixtures
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _brave_env_and_rate_limit():
|
||||
"""Set BRAVE_API_KEY for every test. Rate limiting is per-instance (each tool starts with a fresh clock)."""
|
||||
with patch.dict(os.environ, {"BRAVE_API_KEY": "test-api-key"}):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def brave_tool():
|
||||
return BraveSearchTool(n_results=2)
|
||||
def web_tool():
|
||||
return BraveWebSearchTool()
|
||||
|
||||
|
||||
def test_brave_tool_initialization():
|
||||
tool = BraveSearchTool()
|
||||
assert tool.n_results == 10
|
||||
@pytest.fixture
|
||||
def image_tool():
|
||||
return BraveImageSearchTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def news_tool():
|
||||
return BraveNewsSearchTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def video_tool():
|
||||
return BraveVideoSearchTool()
|
||||
|
||||
|
||||
# Initialization
|
||||
|
||||
ALL_TOOL_CLASSES = [
|
||||
BraveWebSearchTool,
|
||||
BraveImageSearchTool,
|
||||
BraveNewsSearchTool,
|
||||
BraveVideoSearchTool,
|
||||
BraveLLMContextTool,
|
||||
BraveLocalPOIsTool,
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool_cls", ALL_TOOL_CLASSES)
|
||||
def test_instantiation_with_env_var(tool_cls):
|
||||
"""Each tool can be created when BRAVE_API_KEY is in the environment."""
|
||||
tool = tool_cls()
|
||||
assert tool.api_key == "test-api-key"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool_cls", ALL_TOOL_CLASSES)
|
||||
def test_instantiation_with_explicit_key(tool_cls):
|
||||
"""An explicit api_key takes precedence over the environment."""
|
||||
tool = tool_cls(api_key="explicit-key")
|
||||
assert tool.api_key == "explicit-key"
|
||||
|
||||
|
||||
def test_missing_api_key_raises():
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(ValueError, match="BRAVE_API_KEY"):
|
||||
BraveWebSearchTool()
|
||||
|
||||
|
||||
def test_default_attributes():
|
||||
tool = BraveWebSearchTool()
|
||||
assert tool.save_file is False
|
||||
assert tool.n_results == 10
|
||||
assert tool._timeout == 30
|
||||
assert tool._requests_per_second == 1.0
|
||||
assert tool.raw is False
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_brave_tool_search(mock_get, brave_tool):
|
||||
mock_response = {
|
||||
def test_custom_constructor_args():
|
||||
tool = BraveWebSearchTool(
|
||||
save_file=True,
|
||||
timeout=60,
|
||||
n_results=5,
|
||||
requests_per_second=0.5,
|
||||
raw=True,
|
||||
)
|
||||
assert tool.save_file is True
|
||||
assert tool._timeout == 60
|
||||
assert tool.n_results == 5
|
||||
assert tool._requests_per_second == 0.5
|
||||
assert tool.raw is True
|
||||
|
||||
|
||||
# Headers
|
||||
|
||||
|
||||
def test_default_headers():
|
||||
tool = BraveWebSearchTool()
|
||||
assert tool.headers["x-subscription-token"] == "test-api-key"
|
||||
assert tool.headers["accept"] == "application/json"
|
||||
|
||||
|
||||
def test_set_headers_merges_and_normalizes():
|
||||
tool = BraveWebSearchTool()
|
||||
tool.set_headers({"Cache-Control": "no-cache"})
|
||||
assert tool.headers["cache-control"] == "no-cache"
|
||||
assert tool.headers["x-subscription-token"] == "test-api-key"
|
||||
|
||||
|
||||
def test_set_headers_returns_self_for_chaining():
|
||||
tool = BraveWebSearchTool()
|
||||
assert tool.set_headers({"Cache-Control": "no-cache"}) is tool
|
||||
|
||||
|
||||
def test_invalid_header_value_raises():
|
||||
tool = BraveImageSearchTool()
|
||||
with pytest.raises(ValueError, match="Invalid headers"):
|
||||
tool.set_headers({"Accept": "text/xml"})
|
||||
|
||||
|
||||
# Endpoint & Schema Wiring
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool_cls, expected_url, expected_params, expected_headers",
|
||||
[
|
||||
(
|
||||
BraveWebSearchTool,
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
WebSearchParams,
|
||||
WebSearchHeaders,
|
||||
),
|
||||
(
|
||||
BraveImageSearchTool,
|
||||
"https://api.search.brave.com/res/v1/images/search",
|
||||
ImageSearchParams,
|
||||
ImageSearchHeaders,
|
||||
),
|
||||
(
|
||||
BraveNewsSearchTool,
|
||||
"https://api.search.brave.com/res/v1/news/search",
|
||||
NewsSearchParams,
|
||||
NewsSearchHeaders,
|
||||
),
|
||||
(
|
||||
BraveVideoSearchTool,
|
||||
"https://api.search.brave.com/res/v1/videos/search",
|
||||
VideoSearchParams,
|
||||
VideoSearchHeaders,
|
||||
),
|
||||
(
|
||||
BraveLLMContextTool,
|
||||
"https://api.search.brave.com/res/v1/llm/context",
|
||||
LLMContextParams,
|
||||
LLMContextHeaders,
|
||||
),
|
||||
(
|
||||
BraveLocalPOIsTool,
|
||||
"https://api.search.brave.com/res/v1/local/pois",
|
||||
LocalPOIsParams,
|
||||
LocalPOIsHeaders,
|
||||
),
|
||||
(
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
"https://api.search.brave.com/res/v1/local/descriptions",
|
||||
LocalPOIsDescriptionParams,
|
||||
LocalPOIsDescriptionHeaders,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tool_wiring(tool_cls, expected_url, expected_params, expected_headers):
|
||||
tool = tool_cls()
|
||||
assert tool.search_url == expected_url
|
||||
assert tool.args_schema is expected_params
|
||||
assert tool.header_schema is expected_headers
|
||||
|
||||
|
||||
# Payload Refinement (e.g., `query` -> `q`, `count` fallback, param pass-through)
|
||||
|
||||
|
||||
def test_web_refine_request_payload_passes_all_params(web_tool):
|
||||
params = web_tool._common_payload_refinement(
|
||||
{
|
||||
"query": "test",
|
||||
"country": "US",
|
||||
"search_lang": "en",
|
||||
"count": 5,
|
||||
"offset": 2,
|
||||
"safesearch": "moderate",
|
||||
"freshness": "pw",
|
||||
}
|
||||
)
|
||||
refined_params = web_tool._refine_request_payload(params)
|
||||
|
||||
assert refined_params["q"] == "test"
|
||||
assert "query" not in refined_params
|
||||
assert refined_params["count"] == 5
|
||||
assert refined_params["country"] == "US"
|
||||
assert refined_params["search_lang"] == "en"
|
||||
assert refined_params["offset"] == 2
|
||||
assert refined_params["safesearch"] == "moderate"
|
||||
assert refined_params["freshness"] == "pw"
|
||||
|
||||
|
||||
def test_image_refine_request_payload_passes_all_params(image_tool):
|
||||
params = image_tool._common_payload_refinement(
|
||||
{
|
||||
"query": "cat photos",
|
||||
"country": "US",
|
||||
"search_lang": "en",
|
||||
"safesearch": "strict",
|
||||
"count": 50,
|
||||
"spellcheck": True,
|
||||
}
|
||||
)
|
||||
refined_params = image_tool._refine_request_payload(params)
|
||||
|
||||
assert refined_params["q"] == "cat photos"
|
||||
assert "query" not in refined_params
|
||||
assert refined_params["country"] == "US"
|
||||
assert refined_params["safesearch"] == "strict"
|
||||
assert refined_params["count"] == 50
|
||||
assert refined_params["spellcheck"] is True
|
||||
|
||||
|
||||
def test_news_refine_request_payload_passes_all_params(news_tool):
|
||||
params = news_tool._common_payload_refinement(
|
||||
{
|
||||
"query": "breaking news",
|
||||
"country": "US",
|
||||
"count": 10,
|
||||
"offset": 1,
|
||||
"freshness": "pd",
|
||||
"extra_snippets": True,
|
||||
}
|
||||
)
|
||||
refined_params = news_tool._refine_request_payload(params)
|
||||
|
||||
assert refined_params["q"] == "breaking news"
|
||||
assert "query" not in refined_params
|
||||
assert refined_params["country"] == "US"
|
||||
assert refined_params["offset"] == 1
|
||||
assert refined_params["freshness"] == "pd"
|
||||
assert refined_params["extra_snippets"] is True
|
||||
|
||||
|
||||
def test_video_refine_request_payload_passes_all_params(video_tool):
|
||||
params = video_tool._common_payload_refinement(
|
||||
{
|
||||
"query": "tutorial",
|
||||
"country": "US",
|
||||
"count": 25,
|
||||
"offset": 0,
|
||||
"safesearch": "strict",
|
||||
"freshness": "pm",
|
||||
}
|
||||
)
|
||||
refined_params = video_tool._refine_request_payload(params)
|
||||
|
||||
assert refined_params["q"] == "tutorial"
|
||||
assert "query" not in refined_params
|
||||
assert refined_params["country"] == "US"
|
||||
assert refined_params["offset"] == 0
|
||||
assert refined_params["freshness"] == "pm"
|
||||
|
||||
|
||||
def test_legacy_constructor_params_flow_into_query_params():
|
||||
"""The legacy n_results and country constructor params are applied as defaults
|
||||
when count/country are not explicitly provided at call time."""
|
||||
tool = BraveWebSearchTool(n_results=3, country="BR")
|
||||
params = tool._common_payload_refinement({"query": "test"})
|
||||
|
||||
assert params["count"] == 3
|
||||
assert params["country"] == "BR"
|
||||
|
||||
|
||||
def test_legacy_constructor_params_do_not_override_explicit_query_params():
|
||||
"""Explicit query-time count/country take precedence over constructor defaults."""
|
||||
tool = BraveWebSearchTool(n_results=3, country="BR")
|
||||
params = tool._common_payload_refinement(
|
||||
{"query": "test", "count": 10, "country": "US"}
|
||||
)
|
||||
|
||||
assert params["count"] == 10
|
||||
assert params["country"] == "US"
|
||||
|
||||
|
||||
def test_refine_request_payload_passes_multiple_goggles_as_multiple_params(web_tool):
|
||||
result = web_tool._refine_request_payload(
|
||||
{
|
||||
"query": "test",
|
||||
"goggles": ["goggle1", "goggle2"],
|
||||
}
|
||||
)
|
||||
assert result["goggles"] == ["goggle1", "goggle2"]
|
||||
|
||||
|
||||
# Null-like / empty value stripping
|
||||
#
|
||||
# crewAI's ensure_all_properties_required (pydantic_schema_utils.py) marks
|
||||
# every schema property as required for OpenAI strict-mode compatibility.
|
||||
# Because optional Brave API parameters look required to the LLM, it fills
|
||||
# them with placeholder junk — None, "", "null", or []. The test below
|
||||
# verifies that _common_payload_refinement strips these from optional fields.
|
||||
|
||||
|
||||
def test_common_refinement_strips_null_like_values(web_tool):
|
||||
"""_common_payload_refinement drops optional keys with None / '' / 'null' / []."""
|
||||
params = web_tool._common_payload_refinement(
|
||||
{
|
||||
"query": "test",
|
||||
"country": "US",
|
||||
"search_lang": "",
|
||||
"freshness": "null",
|
||||
"count": 5,
|
||||
"goggles": [],
|
||||
}
|
||||
)
|
||||
assert params["q"] == "test"
|
||||
assert params["country"] == "US"
|
||||
assert params["count"] == 5
|
||||
assert "search_lang" not in params
|
||||
assert "freshness" not in params
|
||||
assert "goggles" not in params
|
||||
|
||||
|
||||
# End-to-End _run() with Mocked HTTP Response
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_web_search_end_to_end(mock_get, web_tool):
|
||||
web_tool.raw = True
|
||||
data = {"web": {"results": [{"title": "R", "url": "http://r.co"}]}}
|
||||
mock_get.return_value = _mock_response(json_data=data)
|
||||
|
||||
result = web_tool._run(query="test")
|
||||
|
||||
mock_get.assert_called_once()
|
||||
call_args = mock_get.call_args.kwargs
|
||||
assert call_args["params"]["q"] == "test"
|
||||
assert call_args["headers"]["x-subscription-token"] == "test-api-key"
|
||||
assert result == data
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_image_search_end_to_end(mock_get, image_tool):
|
||||
image_tool.raw = True
|
||||
data = {"results": [{"url": "http://img.co/a.jpg"}]}
|
||||
mock_get.return_value = _mock_response(json_data=data)
|
||||
|
||||
assert image_tool._run(query="cats") == data
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_news_search_end_to_end(mock_get, news_tool):
|
||||
news_tool.raw = True
|
||||
data = {"results": [{"title": "News", "url": "http://n.co"}]}
|
||||
mock_get.return_value = _mock_response(json_data=data)
|
||||
|
||||
assert news_tool._run(query="headlines") == data
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_video_search_end_to_end(mock_get, video_tool):
|
||||
video_tool.raw = True
|
||||
data = {"results": [{"title": "Vid", "url": "http://v.co"}]}
|
||||
mock_get.return_value = _mock_response(json_data=data)
|
||||
|
||||
assert video_tool._run(query="python tutorial") == data
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_raw_false_calls_refine_response(mock_get, web_tool):
|
||||
"""With raw=False (the default), _refine_response transforms the API response."""
|
||||
api_response = {
|
||||
"web": {
|
||||
"results": [
|
||||
{
|
||||
"title": "Test Title",
|
||||
"url": "http://test.com",
|
||||
"description": "Test Description",
|
||||
"title": "CrewAI",
|
||||
"url": "https://crewai.com",
|
||||
"description": "AI agent framework",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
mock_get.return_value = _mock_response(json_data=api_response)
|
||||
|
||||
result = brave_tool.run(query="test")
|
||||
data = json.loads(result)
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
assert data[0]["title"] == "Test Title"
|
||||
assert data[0]["url"] == "http://test.com"
|
||||
assert web_tool.raw is False
|
||||
result = web_tool._run(query="crewai")
|
||||
|
||||
# The web tool's _refine_response extracts and reshapes results.
|
||||
# The key assertion: we should NOT get back the raw API envelope.
|
||||
assert result != api_response
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_brave_tool(mock_get):
|
||||
mock_response = {
|
||||
"web": {
|
||||
"results": [
|
||||
{
|
||||
"title": "Brave Browser",
|
||||
"url": "https://brave.com",
|
||||
"description": "Brave Browser description",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
|
||||
tool = BraveSearchTool(n_results=2)
|
||||
result = tool.run(query="Brave Browser")
|
||||
assert result is not None
|
||||
|
||||
# Parse JSON so we can examine the structure
|
||||
data = json.loads(result)
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
|
||||
# First item should have expected fields: title, url, and description
|
||||
first = data[0]
|
||||
assert "title" in first
|
||||
assert first["title"] == "Brave Browser"
|
||||
assert "url" in first
|
||||
assert first["url"] == "https://brave.com"
|
||||
assert "description" in first
|
||||
assert first["description"] == "Brave Browser description"
|
||||
# Backward Compatibility & Legacy Parameter Support
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_brave_tool()
|
||||
test_brave_tool_initialization()
|
||||
# test_brave_tool_search(brave_tool)
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_positional_query_argument(mock_get, web_tool):
|
||||
"""tool.run('my query') works as a positional argument."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
web_tool._run("positional test")
|
||||
|
||||
assert mock_get.call_args.kwargs["params"]["q"] == "positional test"
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_search_query_backward_compat(mock_get, web_tool):
|
||||
"""The legacy 'search_query' param is mapped to 'query'."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
web_tool._run(search_query="legacy test")
|
||||
|
||||
assert mock_get.call_args.kwargs["params"]["q"] == "legacy test"
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base._save_results_to_file")
|
||||
def test_save_file_called_when_enabled(mock_save, mock_get):
|
||||
mock_get.return_value = _mock_response(json_data={"results": []})
|
||||
|
||||
tool = BraveWebSearchTool(save_file=True)
|
||||
tool._run(query="test")
|
||||
|
||||
mock_save.assert_called_once()
|
||||
|
||||
|
||||
# Error Handling
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_connection_error_raises_runtime_error(mock_get, web_tool):
|
||||
mock_get.side_effect = requests_lib.exceptions.ConnectionError("refused")
|
||||
with pytest.raises(RuntimeError, match="Brave Search API connection failed"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_timeout_raises_runtime_error(mock_get, web_tool):
|
||||
mock_get.side_effect = requests_lib.exceptions.Timeout("timed out")
|
||||
with pytest.raises(RuntimeError, match="timed out"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_invalid_params_raises_value_error(mock_get, web_tool):
|
||||
"""count=999 exceeds WebSearchParams.count le=20."""
|
||||
with pytest.raises(ValueError, match="Invalid parameters"):
|
||||
web_tool._run(query="test", count=999)
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_4xx_error_raises_with_api_detail(mock_get, web_tool):
|
||||
"""A 422 with a structured error body includes code and detail in the message."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=422,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "abc-123",
|
||||
"status": 422,
|
||||
"code": "OPTION_NOT_IN_PLAN",
|
||||
"detail": "extra_snippets requires a Pro plan",
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="OPTION_NOT_IN_PLAN") as exc_info:
|
||||
web_tool._run(query="test")
|
||||
assert "extra_snippets requires a Pro plan" in str(exc_info.value)
|
||||
assert "HTTP 422" in str(exc_info.value)
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_auth_error_raises_immediately(mock_get, web_tool):
|
||||
"""A 401 with SUBSCRIPTION_TOKEN_INVALID is not retried."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=401,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "xyz",
|
||||
"status": 401,
|
||||
"code": "SUBSCRIPTION_TOKEN_INVALID",
|
||||
"detail": "The subscription token is invalid",
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="SUBSCRIPTION_TOKEN_INVALID"):
|
||||
web_tool._run(query="test")
|
||||
# Should NOT have retried — only one call.
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_quota_limited_429_raises_immediately(mock_get, web_tool):
|
||||
"""A 429 with QUOTA_LIMITED is NOT retried — quota exhaustion is terminal."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=429,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "ql-1",
|
||||
"status": 429,
|
||||
"code": "QUOTA_LIMITED",
|
||||
"detail": "Monthly quota exceeded",
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="QUOTA_LIMITED") as exc_info:
|
||||
web_tool._run(query="test")
|
||||
assert "Monthly quota exceeded" in str(exc_info.value)
|
||||
# Terminal — only one HTTP call, no retries.
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_usage_limit_exceeded_429_raises_immediately(mock_get, web_tool):
|
||||
"""USAGE_LIMIT_EXCEEDED is also non-retryable, just like QUOTA_LIMITED."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=429,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "ule-1",
|
||||
"status": 429,
|
||||
"code": "USAGE_LIMIT_EXCEEDED",
|
||||
}
|
||||
},
|
||||
text="usage limit exceeded",
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="USAGE_LIMIT_EXCEEDED"):
|
||||
web_tool._run(query="test")
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_error_body_is_fully_included_in_message(mock_get, web_tool):
|
||||
"""The full JSON error body is included in the RuntimeError message."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=429,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "x",
|
||||
"status": 429,
|
||||
"code": "QUOTA_LIMITED",
|
||||
"detail": "Exceeded",
|
||||
"meta": {"plan": "free", "limit": 1000},
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
web_tool._run(query="test")
|
||||
msg = str(exc_info.value)
|
||||
assert "HTTP 429" in msg
|
||||
assert "QUOTA_LIMITED" in msg
|
||||
assert "free" in msg
|
||||
assert "1000" in msg
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_error_without_json_body_falls_back_to_text(mock_get, web_tool):
|
||||
"""When the error response isn't valid JSON, resp.text is used as the detail."""
|
||||
resp = _mock_response(status_code=500, text="Internal Server Error")
|
||||
resp.json.side_effect = ValueError("No JSON")
|
||||
mock_get.return_value = resp
|
||||
|
||||
with pytest.raises(RuntimeError, match="Internal Server Error"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_invalid_json_on_success_raises_runtime_error(mock_get, web_tool):
|
||||
"""A 200 OK with a non-JSON body raises RuntimeError."""
|
||||
resp = _mock_response(status_code=200)
|
||||
resp.json.side_effect = ValueError("Expecting value")
|
||||
mock_get.return_value = resp
|
||||
|
||||
with pytest.raises(RuntimeError, match="invalid JSON"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
# Rate Limiting
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_sleeps_when_too_fast(mock_time, mock_get, web_tool):
|
||||
"""Back-to-back calls within the interval trigger a sleep."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
# Simulate: last request was at t=100, "now" is t=100.2 (only 0.2s elapsed).
|
||||
# With default 1 req/s the min interval is 1.0s, so it should sleep ~0.8s.
|
||||
mock_time.time.return_value = 100.2
|
||||
web_tool._last_request_time = 100.0
|
||||
|
||||
web_tool._run(query="test")
|
||||
|
||||
mock_time.sleep.assert_called_once()
|
||||
sleep_duration = mock_time.sleep.call_args[0][0]
|
||||
assert 0.7 < sleep_duration < 0.9 # approximately 0.8s
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_skips_sleep_when_enough_time_passed(mock_time, mock_get, web_tool):
|
||||
"""No sleep when the elapsed time already exceeds the interval."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
# Last request was at t=100, "now" is t=102 (2s elapsed > 1s interval).
|
||||
mock_time.time.return_value = 102.0
|
||||
web_tool._last_request_time = 100.0
|
||||
|
||||
web_tool._run(query="test")
|
||||
|
||||
mock_time.sleep.assert_not_called()
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_disabled_when_zero(mock_time, mock_get, web_tool):
|
||||
"""requests_per_second=0 disables rate limiting entirely."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
web_tool._last_request_time = 100.0
|
||||
mock_time.time.return_value = 100.0 # same instant
|
||||
|
||||
web_tool._run(query="test")
|
||||
|
||||
mock_time.sleep.assert_not_called()
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_per_instance_independent(mock_time, mock_get, web_tool, image_tool):
|
||||
"""Each instance has its own rate-limit clock; a request on one does not delay the other."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
# Web tool fires at t=100 (its clock goes 0 -> 100).
|
||||
mock_time.time.return_value = 100.0
|
||||
web_tool._run(query="test")
|
||||
|
||||
# Image tool fires at t=100.3. Its clock is still 0 (separate instance), so
|
||||
# next_allowed = 1.0 and 100.3 > 1.0 — no sleep. Total process rate can be sum of instance limits.
|
||||
mock_time.time.return_value = 100.3
|
||||
image_tool._run(query="cats")
|
||||
|
||||
mock_time.sleep.assert_not_called()
|
||||
|
||||
|
||||
# Retry Behavior
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_429_rate_limited_retries_then_succeeds(mock_time, mock_get, web_tool):
|
||||
"""A transient RATE_LIMITED 429 is retried; success on the second attempt."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_429 = _mock_response(
|
||||
status_code=429,
|
||||
json_data={"error": {"id": "r", "status": 429, "code": "RATE_LIMITED"}},
|
||||
headers={"Retry-After": "2"},
|
||||
)
|
||||
resp_200 = _mock_response(status_code=200, json_data={"web": {"results": []}})
|
||||
mock_get.side_effect = [resp_429, resp_200]
|
||||
|
||||
web_tool.raw = True
|
||||
result = web_tool._run(query="test")
|
||||
|
||||
assert result == {"web": {"results": []}}
|
||||
assert mock_get.call_count == 2
|
||||
# Slept for the Retry-After value.
|
||||
retry_sleeps = [c for c in mock_time.sleep.call_args_list if c[0][0] == 2.0]
|
||||
assert len(retry_sleeps) == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_5xx_is_retried(mock_time, mock_get, web_tool):
|
||||
"""A 502 server error is retried; success on the second attempt."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_502 = _mock_response(status_code=502, text="Bad Gateway")
|
||||
resp_502.json.side_effect = ValueError("no json")
|
||||
resp_200 = _mock_response(status_code=200, json_data={"web": {"results": []}})
|
||||
mock_get.side_effect = [resp_502, resp_200]
|
||||
|
||||
web_tool.raw = True
|
||||
result = web_tool._run(query="test")
|
||||
|
||||
assert result == {"web": {"results": []}}
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_429_rate_limited_exhausts_retries(mock_time, mock_get, web_tool):
|
||||
"""Persistent RATE_LIMITED 429s exhaust retries and raise RuntimeError."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_429 = _mock_response(
|
||||
status_code=429,
|
||||
json_data={"error": {"id": "r", "status": 429, "code": "RATE_LIMITED"}},
|
||||
)
|
||||
mock_get.return_value = resp_429
|
||||
|
||||
with pytest.raises(RuntimeError, match="RATE_LIMITED"):
|
||||
web_tool._run(query="test")
|
||||
# 3 attempts (default _max_retries).
|
||||
assert mock_get.call_count == 3
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_retry_uses_exponential_backoff_when_no_retry_after(
|
||||
mock_time, mock_get, web_tool
|
||||
):
|
||||
"""Without Retry-After, backoff is 2^attempt (1s, 2s, ...)."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_503 = _mock_response(status_code=503, text="Service Unavailable")
|
||||
resp_503.json.side_effect = ValueError("no json")
|
||||
resp_200 = _mock_response(status_code=200, json_data={"ok": True})
|
||||
mock_get.side_effect = [resp_503, resp_503, resp_200]
|
||||
|
||||
web_tool.raw = True
|
||||
web_tool._run(query="test")
|
||||
|
||||
# Two retries: attempt 0 → sleep(1.0), attempt 1 → sleep(2.0).
|
||||
retry_sleeps = [c[0][0] for c in mock_time.sleep.call_args_list]
|
||||
assert 1.0 in retry_sleeps
|
||||
assert 2.0 in retry_sleeps
|
||||
|
||||
@@ -76,24 +76,24 @@ print("This is line 2")"""
|
||||
)
|
||||
|
||||
|
||||
def test_docker_unavailable_raises_error(printer_mock, docker_unavailable_mock):
|
||||
"""Test that execution fails when Docker is unavailable in safe mode."""
|
||||
def test_restricted_sandbox_basic_code_execution(printer_mock, docker_unavailable_mock):
|
||||
"""Test basic code execution."""
|
||||
tool = CodeInterpreterTool()
|
||||
code = """
|
||||
result = 2 + 2
|
||||
print(result)
|
||||
"""
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
tool.run(code=code, libraries_used=[])
|
||||
|
||||
assert "Docker is required for safe code execution" in str(exc_info.value)
|
||||
assert "sandbox escape" in str(exc_info.value)
|
||||
result = tool.run(code=code, libraries_used=[])
|
||||
printer_mock.assert_called_with(
|
||||
"Running code in restricted sandbox", color="yellow"
|
||||
)
|
||||
assert result == 4
|
||||
|
||||
|
||||
def test_restricted_sandbox_running_with_blocked_modules(
|
||||
printer_mock, docker_unavailable_mock
|
||||
):
|
||||
"""Test that restricted modules cannot be imported when using the deprecated sandbox directly."""
|
||||
"""Test that restricted modules cannot be imported."""
|
||||
tool = CodeInterpreterTool()
|
||||
restricted_modules = SandboxPython.BLOCKED_MODULES
|
||||
|
||||
@@ -102,17 +102,18 @@ def test_restricted_sandbox_running_with_blocked_modules(
|
||||
import {module}
|
||||
result = "Import succeeded"
|
||||
"""
|
||||
# Note: run_code_in_restricted_sandbox is deprecated and insecure
|
||||
# This test verifies the old behavior but should not be used in production
|
||||
result = tool.run_code_in_restricted_sandbox(code)
|
||||
|
||||
result = tool.run(code=code, libraries_used=[])
|
||||
printer_mock.assert_called_with(
|
||||
"Running code in restricted sandbox", color="yellow"
|
||||
)
|
||||
|
||||
assert f"An error occurred: Importing '{module}' is not allowed" in result
|
||||
|
||||
|
||||
def test_restricted_sandbox_running_with_blocked_builtins(
|
||||
printer_mock, docker_unavailable_mock
|
||||
):
|
||||
"""Test that restricted builtins are not available when using the deprecated sandbox directly."""
|
||||
"""Test that restricted builtins are not available."""
|
||||
tool = CodeInterpreterTool()
|
||||
restricted_builtins = SandboxPython.UNSAFE_BUILTINS
|
||||
|
||||
@@ -121,23 +122,25 @@ def test_restricted_sandbox_running_with_blocked_builtins(
|
||||
{builtin}("test")
|
||||
result = "Builtin available"
|
||||
"""
|
||||
# Note: run_code_in_restricted_sandbox is deprecated and insecure
|
||||
# This test verifies the old behavior but should not be used in production
|
||||
result = tool.run_code_in_restricted_sandbox(code)
|
||||
result = tool.run(code=code, libraries_used=[])
|
||||
printer_mock.assert_called_with(
|
||||
"Running code in restricted sandbox", color="yellow"
|
||||
)
|
||||
assert f"An error occurred: name '{builtin}' is not defined" in result
|
||||
|
||||
|
||||
def test_restricted_sandbox_running_with_no_result_variable(
|
||||
printer_mock, docker_unavailable_mock
|
||||
):
|
||||
"""Test behavior when no result variable is set in deprecated sandbox."""
|
||||
"""Test behavior when no result variable is set."""
|
||||
tool = CodeInterpreterTool()
|
||||
code = """
|
||||
x = 10
|
||||
"""
|
||||
# Note: run_code_in_restricted_sandbox is deprecated and insecure
|
||||
# This test verifies the old behavior but should not be used in production
|
||||
result = tool.run_code_in_restricted_sandbox(code)
|
||||
result = tool.run(code=code, libraries_used=[])
|
||||
printer_mock.assert_called_with(
|
||||
"Running code in restricted sandbox", color="yellow"
|
||||
)
|
||||
assert result == "No result variable found."
|
||||
|
||||
|
||||
@@ -169,40 +172,3 @@ result = eval("5/1")
|
||||
"WARNING: Running code in unsafe mode", color="bold_magenta"
|
||||
)
|
||||
assert 5.0 == result
|
||||
|
||||
|
||||
def test_sandbox_escape_vulnerability_demonstration(printer_mock):
|
||||
"""Demonstrate that the restricted sandbox is vulnerable to escape attacks.
|
||||
|
||||
This test shows that an attacker can use Python object introspection to bypass
|
||||
the restricted sandbox and access blocked modules like 'os'. This is why the
|
||||
sandbox should never be used for untrusted code execution.
|
||||
|
||||
NOTE: This test uses the deprecated run_code_in_restricted_sandbox directly
|
||||
to demonstrate the vulnerability. In production, Docker is now required.
|
||||
"""
|
||||
tool = CodeInterpreterTool()
|
||||
|
||||
# Classic Python sandbox escape via object introspection
|
||||
escape_code = """
|
||||
# Recover the real __import__ function via object introspection
|
||||
for cls in ().__class__.__bases__[0].__subclasses__():
|
||||
if cls.__name__ == 'catch_warnings':
|
||||
# Get the real builtins module
|
||||
real_builtins = cls()._module.__builtins__
|
||||
real_import = real_builtins['__import__']
|
||||
# Now we can import os and execute commands
|
||||
os = real_import('os')
|
||||
# Demonstrate we have escaped the sandbox
|
||||
result = "SANDBOX_ESCAPED" if hasattr(os, 'system') else "FAILED"
|
||||
break
|
||||
"""
|
||||
|
||||
# The deprecated sandbox is vulnerable to this attack
|
||||
result = tool.run_code_in_restricted_sandbox(escape_code)
|
||||
|
||||
# This demonstrates the vulnerability - the attacker can escape
|
||||
assert result == "SANDBOX_ESCAPED", (
|
||||
"The restricted sandbox was bypassed via object introspection. "
|
||||
"This is why Docker is now required for safe code execution."
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -38,7 +38,7 @@ from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
|
||||
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
|
||||
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$"
|
||||
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#[\w-]+)?$"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import contextvars
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
from collections.abc import Callable, Coroutine
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
import json
|
||||
|
||||
@@ -408,7 +408,7 @@ def human_feedback(
|
||||
emit=list(emit) if emit else None,
|
||||
default_outcome=default_outcome,
|
||||
metadata=metadata or {},
|
||||
llm=llm if isinstance(llm, str) else None,
|
||||
llm=llm if isinstance(llm, str) else getattr(llm, "model", None),
|
||||
)
|
||||
|
||||
# Determine effective provider:
|
||||
|
||||
@@ -72,7 +72,8 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
|
||||
def init_db(self) -> None:
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
# Main state table
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -136,7 +137,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
@@ -163,7 +164,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT state_json
|
||||
@@ -213,7 +214,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
self.save_state(flow_uuid, context.method_name, state_data)
|
||||
|
||||
# Save pending feedback context
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
# Use INSERT OR REPLACE to handle re-triggering feedback on same flow
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -248,7 +249,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
# Import here to avoid circular imports
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT state_json, context_json
|
||||
@@ -272,7 +273,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
DELETE FROM pending_feedback
|
||||
|
||||
@@ -22,7 +22,12 @@ if TYPE_CHECKING:
|
||||
|
||||
try:
|
||||
from anthropic import Anthropic, AsyncAnthropic, transform_schema
|
||||
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
|
||||
from anthropic.types import (
|
||||
Message,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolUseBlock,
|
||||
)
|
||||
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
|
||||
import httpx
|
||||
except ImportError:
|
||||
@@ -31,6 +36,11 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
|
||||
TOOL_SEARCH_TOOL_TYPES: Final[tuple[str, ...]] = (
|
||||
"tool_search_tool_regex_20251119",
|
||||
"tool_search_tool_bm25_20251119",
|
||||
)
|
||||
|
||||
ANTHROPIC_FILES_API_BETA: Final = "files-api-2025-04-14"
|
||||
ANTHROPIC_STRUCTURED_OUTPUTS_BETA: Final = "structured-outputs-2025-11-13"
|
||||
|
||||
@@ -117,6 +127,22 @@ class AnthropicThinkingConfig(BaseModel):
|
||||
budget_tokens: int | None = None
|
||||
|
||||
|
||||
class AnthropicToolSearchConfig(BaseModel):
|
||||
"""Configuration for Anthropic's server-side tool search.
|
||||
|
||||
When enabled, tools marked with defer_loading=True are not loaded into
|
||||
context immediately. Instead, Claude uses the tool search tool to
|
||||
dynamically discover and load relevant tools on-demand.
|
||||
|
||||
Attributes:
|
||||
type: The tool search variant to use.
|
||||
- "regex": Claude constructs regex patterns to search tool names/descriptions.
|
||||
- "bm25": Claude uses natural language queries to search tools.
|
||||
"""
|
||||
|
||||
type: Literal["regex", "bm25"] = "bm25"
|
||||
|
||||
|
||||
class AnthropicCompletion(BaseLLM):
|
||||
"""Anthropic native completion implementation.
|
||||
|
||||
@@ -140,6 +166,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
tool_search: AnthropicToolSearchConfig | bool | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
@@ -159,6 +186,10 @@ class AnthropicCompletion(BaseLLM):
|
||||
interceptor: HTTP interceptor for modifying requests/responses at transport level.
|
||||
response_format: Pydantic model for structured output. When provided, responses
|
||||
will be validated against this model schema.
|
||||
tool_search: Enable Anthropic's server-side tool search. When True, uses "bm25"
|
||||
variant by default. Pass an AnthropicToolSearchConfig to choose "regex" or
|
||||
"bm25". When enabled, tools are automatically marked with defer_loading=True
|
||||
and a tool search tool is injected into the tools list.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -190,6 +221,13 @@ class AnthropicCompletion(BaseLLM):
|
||||
self.thinking = thinking
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
self.response_format = response_format
|
||||
# Tool search config
|
||||
if tool_search is True:
|
||||
self.tool_search = AnthropicToolSearchConfig()
|
||||
elif isinstance(tool_search, AnthropicToolSearchConfig):
|
||||
self.tool_search = tool_search
|
||||
else:
|
||||
self.tool_search = None
|
||||
# Model-specific settings
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = True
|
||||
@@ -432,10 +470,23 @@ class AnthropicCompletion(BaseLLM):
|
||||
# Handle tools for Claude 3+
|
||||
if tools and self.supports_tools:
|
||||
converted_tools = self._convert_tools_for_interference(tools)
|
||||
|
||||
# When tool_search is enabled and there are 2+ regular tools,
|
||||
# inject the search tool and mark regular tools with defer_loading.
|
||||
# With only 1 tool there's nothing to search — skip tool search
|
||||
# entirely so the normal forced tool_choice optimisation still works.
|
||||
regular_tools = [
|
||||
t
|
||||
for t in converted_tools
|
||||
if t.get("type", "") not in TOOL_SEARCH_TOOL_TYPES
|
||||
]
|
||||
if self.tool_search is not None and len(regular_tools) >= 2:
|
||||
converted_tools = self._apply_tool_search(converted_tools)
|
||||
|
||||
params["tools"] = converted_tools
|
||||
|
||||
if available_functions and len(converted_tools) == 1:
|
||||
tool_name = converted_tools[0].get("name")
|
||||
if available_functions and len(regular_tools) == 1:
|
||||
tool_name = regular_tools[0].get("name")
|
||||
if tool_name and tool_name in available_functions:
|
||||
params["tool_choice"] = {"type": "tool", "name": tool_name}
|
||||
|
||||
@@ -454,6 +505,12 @@ class AnthropicCompletion(BaseLLM):
|
||||
anthropic_tools = []
|
||||
|
||||
for tool in tools:
|
||||
# Pass through tool search tool definitions unchanged
|
||||
tool_type = tool.get("type", "")
|
||||
if tool_type in TOOL_SEARCH_TOOL_TYPES:
|
||||
anthropic_tools.append(tool)
|
||||
continue
|
||||
|
||||
if "input_schema" in tool and "name" in tool and "description" in tool:
|
||||
anthropic_tools.append(tool)
|
||||
continue
|
||||
@@ -466,15 +523,15 @@ class AnthropicCompletion(BaseLLM):
|
||||
logging.error(f"Error converting tool to Anthropic format: {e}")
|
||||
raise e
|
||||
|
||||
anthropic_tool = {
|
||||
anthropic_tool: dict[str, Any] = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
}
|
||||
|
||||
if parameters and isinstance(parameters, dict):
|
||||
anthropic_tool["input_schema"] = parameters # type: ignore[assignment]
|
||||
anthropic_tool["input_schema"] = parameters
|
||||
else:
|
||||
anthropic_tool["input_schema"] = { # type: ignore[assignment]
|
||||
anthropic_tool["input_schema"] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
@@ -484,6 +541,55 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
return anthropic_tools
|
||||
|
||||
def _apply_tool_search(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Inject tool search tool and mark regular tools with defer_loading.
|
||||
|
||||
When tool_search is enabled, this method:
|
||||
1. Adds the appropriate tool search tool definition (regex or bm25)
|
||||
2. Marks all regular tools with defer_loading=True so they are only
|
||||
loaded when Claude discovers them via search
|
||||
|
||||
Args:
|
||||
tools: Converted tool definitions in Anthropic format.
|
||||
|
||||
Returns:
|
||||
Updated tools list with tool search tool prepended and
|
||||
regular tools marked as deferred.
|
||||
"""
|
||||
if self.tool_search is None:
|
||||
return tools
|
||||
|
||||
# Check if a tool search tool is already present (user passed one manually)
|
||||
has_search_tool = any(
|
||||
t.get("type", "") in TOOL_SEARCH_TOOL_TYPES for t in tools
|
||||
)
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
|
||||
if not has_search_tool:
|
||||
# Map config type to API type identifier
|
||||
type_map = {
|
||||
"regex": "tool_search_tool_regex_20251119",
|
||||
"bm25": "tool_search_tool_bm25_20251119",
|
||||
}
|
||||
tool_type = type_map[self.tool_search.type]
|
||||
# Tool search tool names follow the convention: tool_search_tool_{variant}
|
||||
tool_name = f"tool_search_tool_{self.tool_search.type}"
|
||||
result.append({"type": tool_type, "name": tool_name})
|
||||
|
||||
for tool in tools:
|
||||
# Don't modify tool search tools
|
||||
if tool.get("type", "") in TOOL_SEARCH_TOOL_TYPES:
|
||||
result.append(tool)
|
||||
continue
|
||||
|
||||
# Mark regular tools as deferred if not already set
|
||||
if "defer_loading" not in tool:
|
||||
tool = {**tool, "defer_loading": True}
|
||||
result.append(tool)
|
||||
|
||||
return result
|
||||
|
||||
def _extract_thinking_block(
|
||||
self, content_block: Any
|
||||
) -> ThinkingBlock | dict[str, Any] | None:
|
||||
|
||||
@@ -1781,6 +1781,7 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
converse_messages: list[LLMMessage] = []
|
||||
system_message: str | None = None
|
||||
pending_tool_results: list[dict[str, Any]] = []
|
||||
|
||||
for message in formatted_messages:
|
||||
role = message.get("role")
|
||||
@@ -1794,53 +1795,56 @@ class BedrockCompletion(BaseLLM):
|
||||
system_message += f"\n\n{content}"
|
||||
else:
|
||||
system_message = cast(str, content)
|
||||
elif role == "assistant" and tool_calls:
|
||||
# Convert OpenAI-style tool_calls to Bedrock toolUse format
|
||||
bedrock_content = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
tool_use_block = {
|
||||
"toolUse": {
|
||||
"toolUseId": tc.get("id", f"call_{id(tc)}"),
|
||||
"name": func.get("name", ""),
|
||||
"input": func.get("arguments", {})
|
||||
if isinstance(func.get("arguments"), dict)
|
||||
else json.loads(func.get("arguments", "{}") or "{}"),
|
||||
}
|
||||
}
|
||||
bedrock_content.append(tool_use_block)
|
||||
converse_messages.append(
|
||||
{"role": "assistant", "content": bedrock_content}
|
||||
)
|
||||
elif role == "tool":
|
||||
if not tool_call_id:
|
||||
raise ValueError("Tool message missing required tool_call_id")
|
||||
converse_messages.append(
|
||||
pending_tool_results.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": tool_call_id,
|
||||
"content": [
|
||||
{"text": str(content) if content else ""}
|
||||
],
|
||||
}
|
||||
}
|
||||
],
|
||||
"toolResult": {
|
||||
"toolUseId": tool_call_id,
|
||||
"content": [{"text": str(content) if content else ""}],
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Convert to Converse API format with proper content structure
|
||||
if isinstance(content, list):
|
||||
# Already formatted as multimodal content blocks
|
||||
converse_messages.append({"role": role, "content": content})
|
||||
else:
|
||||
# String content - wrap in text block
|
||||
text_content = content if content else ""
|
||||
if pending_tool_results:
|
||||
converse_messages.append(
|
||||
{"role": role, "content": [{"text": text_content}]}
|
||||
{"role": "user", "content": pending_tool_results}
|
||||
)
|
||||
pending_tool_results = []
|
||||
|
||||
if role == "assistant" and tool_calls:
|
||||
# Convert OpenAI-style tool_calls to Bedrock toolUse format
|
||||
bedrock_content = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
tool_use_block = {
|
||||
"toolUse": {
|
||||
"toolUseId": tc.get("id", f"call_{id(tc)}"),
|
||||
"name": func.get("name", ""),
|
||||
"input": func.get("arguments", {})
|
||||
if isinstance(func.get("arguments"), dict)
|
||||
else json.loads(func.get("arguments", "{}") or "{}"),
|
||||
}
|
||||
}
|
||||
bedrock_content.append(tool_use_block)
|
||||
converse_messages.append(
|
||||
{"role": "assistant", "content": bedrock_content}
|
||||
)
|
||||
else:
|
||||
# Convert to Converse API format with proper content structure
|
||||
if isinstance(content, list):
|
||||
# Already formatted as multimodal content blocks
|
||||
converse_messages.append({"role": role, "content": content})
|
||||
else:
|
||||
# String content - wrap in text block
|
||||
text_content = content if content else ""
|
||||
converse_messages.append(
|
||||
{"role": role, "content": [{"text": text_content}]}
|
||||
)
|
||||
|
||||
if pending_tool_results:
|
||||
converse_messages.append({"role": "user", "content": pending_tool_results})
|
||||
|
||||
# CRITICAL: Handle model-specific conversation requirements
|
||||
# Cohere and some other models require conversation to end with user message
|
||||
|
||||
@@ -22,6 +22,7 @@ from crewai.mcp.config import (
|
||||
MCPServerSSE,
|
||||
MCPServerStdio,
|
||||
)
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
@@ -74,10 +75,9 @@ class MCPToolResolver:
|
||||
elif isinstance(mcp_config, str):
|
||||
amp_refs.append(self._parse_amp_ref(mcp_config))
|
||||
else:
|
||||
tools, client = self._resolve_native(mcp_config)
|
||||
tools, clients = self._resolve_native(mcp_config)
|
||||
all_tools.extend(tools)
|
||||
if client:
|
||||
self._clients.append(client)
|
||||
self._clients.extend(clients)
|
||||
|
||||
if amp_refs:
|
||||
tools, clients = self._resolve_amp(amp_refs)
|
||||
@@ -131,7 +131,7 @@ class MCPToolResolver:
|
||||
all_tools: list[BaseTool] = []
|
||||
all_clients: list[Any] = []
|
||||
|
||||
resolved_cache: dict[str, tuple[list[BaseTool], Any | None]] = {}
|
||||
resolved_cache: dict[str, tuple[list[BaseTool], list[Any]]] = {}
|
||||
|
||||
for slug in unique_slugs:
|
||||
config_dict = amp_configs_map.get(slug)
|
||||
@@ -149,10 +149,9 @@ class MCPToolResolver:
|
||||
mcp_server_config = self._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
try:
|
||||
tools, client = self._resolve_native(mcp_server_config)
|
||||
resolved_cache[slug] = (tools, client)
|
||||
if client:
|
||||
all_clients.append(client)
|
||||
tools, clients = self._resolve_native(mcp_server_config)
|
||||
resolved_cache[slug] = (tools, clients)
|
||||
all_clients.extend(clients)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -170,8 +169,9 @@ class MCPToolResolver:
|
||||
|
||||
slug_tools, _ = cached
|
||||
if specific_tool:
|
||||
sanitized = sanitize_tool_name(specific_tool)
|
||||
all_tools.extend(
|
||||
t for t in slug_tools if t.name.endswith(f"_{specific_tool}")
|
||||
t for t in slug_tools if t.name.endswith(f"_{sanitized}")
|
||||
)
|
||||
else:
|
||||
all_tools.extend(slug_tools)
|
||||
@@ -198,7 +198,6 @@ class MCPToolResolver:
|
||||
|
||||
plus_api = PlusAPI(api_key=get_platform_integration_token())
|
||||
response = plus_api.get_mcp_configs(slugs)
|
||||
|
||||
if response.status_code == 200:
|
||||
configs: dict[str, dict[str, Any]] = response.json().get("configs", {})
|
||||
return configs
|
||||
@@ -218,6 +217,7 @@ class MCPToolResolver:
|
||||
|
||||
def _resolve_external(self, mcp_ref: str) -> list[BaseTool]:
|
||||
"""Resolve an HTTPS MCP server URL into tools."""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||
|
||||
if "#" in mcp_ref:
|
||||
@@ -227,6 +227,7 @@ class MCPToolResolver:
|
||||
|
||||
server_params = {"url": server_url}
|
||||
server_name = self._extract_server_name(server_url)
|
||||
sanitized_specific_tool = sanitize_tool_name(specific_tool) if specific_tool else None
|
||||
|
||||
try:
|
||||
tool_schemas = self._get_mcp_tool_schemas(server_params)
|
||||
@@ -239,7 +240,7 @@ class MCPToolResolver:
|
||||
|
||||
tools = []
|
||||
for tool_name, schema in tool_schemas.items():
|
||||
if specific_tool and tool_name != specific_tool:
|
||||
if sanitized_specific_tool and tool_name != sanitized_specific_tool:
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -271,14 +272,16 @@ class MCPToolResolver:
|
||||
)
|
||||
return []
|
||||
|
||||
def _resolve_native(
|
||||
self, mcp_config: MCPServerConfig
|
||||
) -> tuple[list[BaseTool], Any | None]:
|
||||
"""Resolve an ``MCPServerConfig`` into tools, returning the client for cleanup."""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
@staticmethod
|
||||
def _create_transport(
|
||||
mcp_config: MCPServerConfig,
|
||||
) -> tuple[StdioTransport | HTTPTransport | SSETransport, str]:
|
||||
"""Create a fresh transport instance from an MCP server config.
|
||||
|
||||
transport: StdioTransport | HTTPTransport | SSETransport
|
||||
Returns a ``(transport, server_name)`` tuple. Each call produces an
|
||||
independent transport so that parallel tool executions never share
|
||||
state.
|
||||
"""
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
@@ -292,38 +295,54 @@ class MCPToolResolver:
|
||||
headers=mcp_config.headers,
|
||||
streamable=mcp_config.streamable,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
server_name = MCPToolResolver._extract_server_name(mcp_config.url)
|
||||
elif isinstance(mcp_config, MCPServerSSE):
|
||||
transport = SSETransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
server_name = MCPToolResolver._extract_server_name(mcp_config.url)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")
|
||||
return transport, server_name
|
||||
|
||||
client = MCPClient(
|
||||
transport=transport,
|
||||
def _resolve_native(
|
||||
self, mcp_config: MCPServerConfig
|
||||
) -> tuple[list[BaseTool], list[Any]]:
|
||||
"""Resolve an ``MCPServerConfig`` into tools.
|
||||
|
||||
Returns ``(tools, clients)`` where *clients* is always empty for
|
||||
native tools (clients are now created on-demand per invocation).
|
||||
A ``client_factory`` closure is passed to each ``MCPNativeTool`` so
|
||||
every call -- even concurrent calls to the *same* tool -- gets its
|
||||
own ``MCPClient`` + transport with no shared mutable state.
|
||||
"""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
|
||||
discovery_transport, server_name = self._create_transport(mcp_config)
|
||||
discovery_client = MCPClient(
|
||||
transport=discovery_transport,
|
||||
cache_tools_list=mcp_config.cache_tools_list,
|
||||
)
|
||||
|
||||
async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
|
||||
try:
|
||||
if not client.connected:
|
||||
await client.connect()
|
||||
if not discovery_client.connected:
|
||||
await discovery_client.connect()
|
||||
|
||||
tools_list = await client.list_tools()
|
||||
tools_list = await discovery_client.list_tools()
|
||||
|
||||
try:
|
||||
await client.disconnect()
|
||||
await discovery_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during disconnect: {e}")
|
||||
|
||||
return tools_list
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
await client.disconnect()
|
||||
if discovery_client.connected:
|
||||
await discovery_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError(
|
||||
f"Error during setup client and list tools: {e}"
|
||||
@@ -376,6 +395,13 @@ class MCPToolResolver:
|
||||
filtered_tools.append(tool)
|
||||
tools_list = filtered_tools
|
||||
|
||||
def _client_factory() -> MCPClient:
|
||||
transport, _ = self._create_transport(mcp_config)
|
||||
return MCPClient(
|
||||
transport=transport,
|
||||
cache_tools_list=mcp_config.cache_tools_list,
|
||||
)
|
||||
|
||||
tools = []
|
||||
for tool_def in tools_list:
|
||||
tool_name = tool_def.get("name", "")
|
||||
@@ -396,7 +422,7 @@ class MCPToolResolver:
|
||||
|
||||
try:
|
||||
native_tool = MCPNativeTool(
|
||||
mcp_client=client,
|
||||
client_factory=_client_factory,
|
||||
tool_name=tool_name,
|
||||
tool_schema=tool_schema,
|
||||
server_name=server_name,
|
||||
@@ -407,10 +433,10 @@ class MCPToolResolver:
|
||||
self._logger.log("error", f"Failed to create native MCP tool: {e}")
|
||||
continue
|
||||
|
||||
return cast(list[BaseTool], tools), client
|
||||
return cast(list[BaseTool], tools), []
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
asyncio.run(client.disconnect())
|
||||
if discovery_client.connected:
|
||||
asyncio.run(discovery_client.disconnect())
|
||||
|
||||
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||
|
||||
|
||||
@@ -38,7 +38,8 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If database initialization fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -82,7 +83,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -125,7 +126,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If updating the task output fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
|
||||
@@ -166,7 +167,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT *
|
||||
@@ -205,7 +206,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
|
||||
|
||||
@@ -12,6 +12,7 @@ import time
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import lancedb
|
||||
import portalocker
|
||||
|
||||
from crewai.memory.types import MemoryRecord, ScopeInfo
|
||||
|
||||
@@ -90,6 +91,7 @@ class LanceDBStorage:
|
||||
# Raise it proactively so scans on large tables never hit OS error 24.
|
||||
try:
|
||||
import resource
|
||||
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
if soft < 4096:
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
|
||||
@@ -99,7 +101,8 @@ class LanceDBStorage:
|
||||
self._compact_every = compact_every
|
||||
self._save_count = 0
|
||||
|
||||
# Get or create a shared write lock for this database path.
|
||||
self._lockfile = str(self._path / ".lance_write.lock")
|
||||
|
||||
resolved = str(self._path.resolve())
|
||||
with LanceDBStorage._path_locks_guard:
|
||||
if resolved not in LanceDBStorage._path_locks:
|
||||
@@ -110,10 +113,13 @@ class LanceDBStorage:
|
||||
# If no table exists yet, defer creation until the first save so the
|
||||
# dimension can be auto-detected from the embedder's actual output.
|
||||
try:
|
||||
self._table: lancedb.table.Table | None = self._db.open_table(self._table_name)
|
||||
self._table: lancedb.table.Table | None = self._db.open_table(
|
||||
self._table_name
|
||||
)
|
||||
self._vector_dim: int = self._infer_dim_from_table(self._table)
|
||||
# Best-effort: create the scope index if it doesn't exist yet.
|
||||
self._ensure_scope_index()
|
||||
with self._file_lock():
|
||||
self._ensure_scope_index()
|
||||
# Compact in the background if the table has accumulated many
|
||||
# fragments from previous runs (each save() creates one).
|
||||
self._compact_if_needed()
|
||||
@@ -124,7 +130,8 @@ class LanceDBStorage:
|
||||
# Explicit dim provided: create the table immediately if it doesn't exist.
|
||||
if self._table is None and vector_dim is not None:
|
||||
self._vector_dim = vector_dim
|
||||
self._table = self._create_table(vector_dim)
|
||||
with self._file_lock():
|
||||
self._table = self._create_table(vector_dim)
|
||||
|
||||
@property
|
||||
def write_lock(self) -> threading.RLock:
|
||||
@@ -149,18 +156,14 @@ class LanceDBStorage:
|
||||
break
|
||||
return DEFAULT_VECTOR_DIM
|
||||
|
||||
def _retry_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Execute a table operation with retry on LanceDB commit conflicts.
|
||||
def _file_lock(self) -> portalocker.Lock:
|
||||
"""Return a cross-process file lock for serialising writes."""
|
||||
return portalocker.Lock(self._lockfile, timeout=120)
|
||||
|
||||
Args:
|
||||
op: Method name on the table object (e.g. "add", "delete").
|
||||
*args, **kwargs: Passed to the table method.
|
||||
def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Execute a single table write with retry on commit conflicts.
|
||||
|
||||
LanceDB uses optimistic concurrency: if two transactions overlap,
|
||||
the second to commit fails with an ``OSError`` containing
|
||||
"Commit conflict". This helper retries with exponential backoff,
|
||||
refreshing the table reference before each retry so the retried
|
||||
call uses the latest committed version (not a stale reference).
|
||||
Caller must already hold the cross-process file lock.
|
||||
"""
|
||||
delay = _RETRY_BASE_DELAY
|
||||
for attempt in range(_MAX_RETRIES + 1):
|
||||
@@ -171,20 +174,24 @@ class LanceDBStorage:
|
||||
raise
|
||||
_logger.debug(
|
||||
"LanceDB commit conflict on %s (attempt %d/%d), retrying in %.1fs",
|
||||
op, attempt + 1, _MAX_RETRIES, delay,
|
||||
op,
|
||||
attempt + 1,
|
||||
_MAX_RETRIES,
|
||||
delay,
|
||||
)
|
||||
# Refresh table to pick up the latest version before retrying.
|
||||
# The next getattr(self._table, op) will use the fresh table.
|
||||
try:
|
||||
self._table = self._db.open_table(self._table_name)
|
||||
except Exception: # noqa: S110
|
||||
pass # table refresh is best-effort
|
||||
pass
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
return None # unreachable, but satisfies type checker
|
||||
|
||||
def _create_table(self, vector_dim: int) -> lancedb.table.Table:
|
||||
"""Create a new table with the given vector dimension."""
|
||||
"""Create a new table with the given vector dimension.
|
||||
|
||||
Caller must already hold the cross-process file lock.
|
||||
"""
|
||||
placeholder = [
|
||||
{
|
||||
"id": "__schema_placeholder__",
|
||||
@@ -200,8 +207,11 @@ class LanceDBStorage:
|
||||
"vector": [0.0] * vector_dim,
|
||||
}
|
||||
]
|
||||
table = self._db.create_table(self._table_name, placeholder)
|
||||
table.delete("id = '__schema_placeholder__'")
|
||||
try:
|
||||
table = self._db.create_table(self._table_name, placeholder)
|
||||
table.delete("id = '__schema_placeholder__'")
|
||||
except ValueError:
|
||||
table = self._db.open_table(self._table_name)
|
||||
return table
|
||||
|
||||
def _ensure_scope_index(self) -> None:
|
||||
@@ -248,9 +258,9 @@ class LanceDBStorage:
|
||||
"""Run ``table.optimize()`` in a background thread, absorbing errors."""
|
||||
try:
|
||||
if self._table is not None:
|
||||
self._table.optimize()
|
||||
# Refresh the scope index so new fragments are covered.
|
||||
self._ensure_scope_index()
|
||||
with self._file_lock():
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
except Exception:
|
||||
_logger.debug("LanceDB background compaction failed", exc_info=True)
|
||||
|
||||
@@ -280,7 +290,9 @@ class LanceDBStorage:
|
||||
"last_accessed": record.last_accessed.isoformat(),
|
||||
"source": record.source or "",
|
||||
"private": record.private,
|
||||
"vector": record.embedding if record.embedding else [0.0] * self._vector_dim,
|
||||
"vector": record.embedding
|
||||
if record.embedding
|
||||
else [0.0] * self._vector_dim,
|
||||
}
|
||||
|
||||
def _row_to_record(self, row: dict[str, Any]) -> MemoryRecord:
|
||||
@@ -296,7 +308,9 @@ class LanceDBStorage:
|
||||
id=str(row["id"]),
|
||||
content=str(row["content"]),
|
||||
scope=str(row["scope"]),
|
||||
categories=json.loads(row["categories_str"]) if row.get("categories_str") else [],
|
||||
categories=json.loads(row["categories_str"])
|
||||
if row.get("categories_str")
|
||||
else [],
|
||||
metadata=json.loads(row["metadata_str"]) if row.get("metadata_str") else {},
|
||||
importance=float(row.get("importance", 0.5)),
|
||||
created_at=_parse_dt(row.get("created_at")),
|
||||
@@ -316,16 +330,15 @@ class LanceDBStorage:
|
||||
dim = len(r.embedding)
|
||||
break
|
||||
is_new_table = self._table is None
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
self._ensure_table(vector_dim=dim)
|
||||
rows = [self._record_to_row(r) for r in records]
|
||||
for r in rows:
|
||||
if r["vector"] is None or len(r["vector"]) != self._vector_dim:
|
||||
r["vector"] = [0.0] * self._vector_dim
|
||||
self._retry_write("add", rows)
|
||||
# Create the scope index on the first save so it covers the initial dataset.
|
||||
if is_new_table:
|
||||
self._ensure_scope_index()
|
||||
self._do_write("add", rows)
|
||||
if is_new_table:
|
||||
self._ensure_scope_index()
|
||||
# Auto-compact every N saves so fragment files don't pile up.
|
||||
self._save_count += 1
|
||||
if self._compact_every > 0 and self._save_count % self._compact_every == 0:
|
||||
@@ -333,14 +346,14 @@ class LanceDBStorage:
|
||||
|
||||
def update(self, record: MemoryRecord) -> None:
|
||||
"""Update a record by ID. Preserves created_at, updates last_accessed."""
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
self._ensure_table()
|
||||
safe_id = str(record.id).replace("'", "''")
|
||||
self._retry_write("delete", f"id = '{safe_id}'")
|
||||
self._do_write("delete", f"id = '{safe_id}'")
|
||||
row = self._record_to_row(record)
|
||||
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
|
||||
row["vector"] = [0.0] * self._vector_dim
|
||||
self._retry_write("add", [row])
|
||||
self._do_write("add", [row])
|
||||
|
||||
def touch_records(self, record_ids: list[str]) -> None:
|
||||
"""Update last_accessed to now for the given record IDs.
|
||||
@@ -354,11 +367,11 @@ class LanceDBStorage:
|
||||
"""
|
||||
if not record_ids or self._table is None:
|
||||
return
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
now = datetime.utcnow().isoformat()
|
||||
safe_ids = [str(rid).replace("'", "''") for rid in record_ids]
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
|
||||
self._retry_write(
|
||||
self._do_write(
|
||||
"update",
|
||||
where=f"id IN ({ids_expr})",
|
||||
values={"last_accessed": now},
|
||||
@@ -390,13 +403,17 @@ class LanceDBStorage:
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
like_val = prefix + "%"
|
||||
query = query.where(f"scope LIKE '{like_val}'")
|
||||
results = query.limit(limit * 3 if (categories or metadata_filter) else limit).to_list()
|
||||
results = query.limit(
|
||||
limit * 3 if (categories or metadata_filter) else limit
|
||||
).to_list()
|
||||
out: list[tuple[MemoryRecord, float]] = []
|
||||
for row in results:
|
||||
record = self._row_to_record(row)
|
||||
if categories and not any(c in record.categories for c in categories):
|
||||
continue
|
||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
||||
if metadata_filter and not all(
|
||||
record.metadata.get(k) == v for k, v in metadata_filter.items()
|
||||
):
|
||||
continue
|
||||
distance = row.get("_distance", 0.0)
|
||||
score = 1.0 / (1.0 + float(distance)) if distance is not None else 1.0
|
||||
@@ -416,20 +433,24 @@ class LanceDBStorage:
|
||||
) -> int:
|
||||
if self._table is None:
|
||||
return 0
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
if record_ids and not (categories or metadata_filter):
|
||||
before = self._table.count_rows()
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
|
||||
self._retry_write("delete", f"id IN ({ids_expr})")
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
if categories or metadata_filter:
|
||||
rows = self._scan_rows(scope_prefix)
|
||||
to_delete: list[str] = []
|
||||
for row in rows:
|
||||
record = self._row_to_record(row)
|
||||
if categories and not any(c in record.categories for c in categories):
|
||||
if categories and not any(
|
||||
c in record.categories for c in categories
|
||||
):
|
||||
continue
|
||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
||||
if metadata_filter and not all(
|
||||
record.metadata.get(k) == v for k, v in metadata_filter.items()
|
||||
):
|
||||
continue
|
||||
if older_than and record.created_at >= older_than:
|
||||
continue
|
||||
@@ -438,7 +459,7 @@ class LanceDBStorage:
|
||||
return 0
|
||||
before = self._table.count_rows()
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
|
||||
self._retry_write("delete", f"id IN ({ids_expr})")
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
conditions = []
|
||||
if scope_prefix is not None and scope_prefix.strip("/"):
|
||||
@@ -450,11 +471,11 @@ class LanceDBStorage:
|
||||
conditions.append(f"created_at < '{older_than.isoformat()}'")
|
||||
if not conditions:
|
||||
before = self._table.count_rows()
|
||||
self._retry_write("delete", "id != ''")
|
||||
self._do_write("delete", "id != ''")
|
||||
return before - self._table.count_rows()
|
||||
where_expr = " AND ".join(conditions)
|
||||
before = self._table.count_rows()
|
||||
self._retry_write("delete", where_expr)
|
||||
self._do_write("delete", where_expr)
|
||||
return before - self._table.count_rows()
|
||||
|
||||
def _scan_rows(
|
||||
@@ -528,7 +549,7 @@ class LanceDBStorage:
|
||||
for row in rows:
|
||||
sc = str(row.get("scope", ""))
|
||||
if child_prefix and sc.startswith(child_prefix):
|
||||
rest = sc[len(child_prefix):]
|
||||
rest = sc[len(child_prefix) :]
|
||||
first_component = rest.split("/", 1)[0]
|
||||
if first_component:
|
||||
children.add(child_prefix + first_component)
|
||||
@@ -539,7 +560,11 @@ class LanceDBStorage:
|
||||
pass
|
||||
created = row.get("created_at")
|
||||
if created:
|
||||
dt = datetime.fromisoformat(str(created).replace("Z", "+00:00")) if isinstance(created, str) else created
|
||||
dt = (
|
||||
datetime.fromisoformat(str(created).replace("Z", "+00:00"))
|
||||
if isinstance(created, str)
|
||||
else created
|
||||
)
|
||||
if isinstance(dt, datetime):
|
||||
if oldest is None or dt < oldest:
|
||||
oldest = dt
|
||||
@@ -562,7 +587,7 @@ class LanceDBStorage:
|
||||
for row in rows:
|
||||
sc = str(row.get("scope", ""))
|
||||
if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
|
||||
rest = sc[len(prefix):]
|
||||
rest = sc[len(prefix) :]
|
||||
first_component = rest.split("/", 1)[0]
|
||||
if first_component:
|
||||
children.add(prefix + first_component)
|
||||
@@ -590,17 +615,17 @@ class LanceDBStorage:
|
||||
return info.record_count
|
||||
|
||||
def reset(self, scope_prefix: str | None = None) -> None:
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
if self._table is not None:
|
||||
self._db.drop_table(self._table_name)
|
||||
self._table = None
|
||||
# Dimension is preserved; table will be recreated on next save.
|
||||
return
|
||||
if self._table is None:
|
||||
return
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
if prefix:
|
||||
self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uFFFF'")
|
||||
with self._write_lock, self._file_lock():
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
if self._table is not None:
|
||||
self._db.drop_table(self._table_name)
|
||||
self._table = None
|
||||
return
|
||||
if self._table is None:
|
||||
return
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
if prefix:
|
||||
self._do_write("delete", f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'")
|
||||
|
||||
def optimize(self) -> None:
|
||||
"""Compact the table synchronously and refresh the scope index.
|
||||
@@ -614,8 +639,9 @@ class LanceDBStorage:
|
||||
"""
|
||||
if self._table is None:
|
||||
return
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
with self._write_lock, self._file_lock():
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
|
||||
async def asave(self, records: list[MemoryRecord]) -> None:
|
||||
self.save(records)
|
||||
|
||||
@@ -28,7 +28,7 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
|
||||
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
|
||||
|
||||
with portalocker.Lock(lockfile):
|
||||
with portalocker.Lock(lockfile, timeout=120):
|
||||
client = PersistentClient(
|
||||
path=persist_dir,
|
||||
settings=config.settings,
|
||||
|
||||
@@ -1,29 +1,30 @@
|
||||
"""Native MCP tool wrapper for CrewAI agents.
|
||||
|
||||
This module provides a tool wrapper that reuses existing MCP client sessions
|
||||
for better performance and connection management.
|
||||
This module provides a tool wrapper that creates a fresh MCP client for every
|
||||
invocation, ensuring safe parallel execution even when the same tool is called
|
||||
concurrently by the executor.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class MCPNativeTool(BaseTool):
|
||||
"""Native MCP tool that reuses client sessions.
|
||||
"""Native MCP tool that creates a fresh client per invocation.
|
||||
|
||||
This tool wrapper is used when agents connect to MCP servers using
|
||||
structured configurations. It reuses existing client sessions for
|
||||
better performance and proper connection lifecycle management.
|
||||
|
||||
Unlike MCPToolWrapper which connects on-demand, this tool uses
|
||||
a shared MCP client instance that maintains a persistent connection.
|
||||
A ``client_factory`` callable produces an independent ``MCPClient`` +
|
||||
transport for every ``_run_async`` call. This guarantees that parallel
|
||||
invocations -- whether of the *same* tool or *different* tools from the
|
||||
same server -- never share mutable connection state (which would cause
|
||||
anyio cancel-scope errors).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_client: Any,
|
||||
client_factory: Callable[[], Any],
|
||||
tool_name: str,
|
||||
tool_schema: dict[str, Any],
|
||||
server_name: str,
|
||||
@@ -32,19 +33,16 @@ class MCPNativeTool(BaseTool):
|
||||
"""Initialize native MCP tool.
|
||||
|
||||
Args:
|
||||
mcp_client: MCPClient instance with active session.
|
||||
client_factory: Zero-arg callable that returns a new MCPClient.
|
||||
tool_name: Name of the tool (may be prefixed).
|
||||
tool_schema: Schema information for the tool.
|
||||
server_name: Name of the MCP server for prefixing.
|
||||
original_tool_name: Original name of the tool on the MCP server.
|
||||
"""
|
||||
# Create tool name with server prefix to avoid conflicts
|
||||
prefixed_name = f"{server_name}_{tool_name}"
|
||||
|
||||
# Handle args_schema properly - BaseTool expects a BaseModel subclass
|
||||
args_schema = tool_schema.get("args_schema")
|
||||
|
||||
# Only pass args_schema if it's provided
|
||||
kwargs = {
|
||||
"name": prefixed_name,
|
||||
"description": tool_schema.get(
|
||||
@@ -57,16 +55,9 @@ class MCPNativeTool(BaseTool):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set instance attributes after super().__init__
|
||||
self._mcp_client = mcp_client
|
||||
self._client_factory = client_factory
|
||||
self._original_tool_name = original_tool_name or tool_name
|
||||
self._server_name = server_name
|
||||
# self._logger = logging.getLogger(__name__)
|
||||
|
||||
@property
|
||||
def mcp_client(self) -> Any:
|
||||
"""Get the MCP client instance."""
|
||||
return self._mcp_client
|
||||
|
||||
@property
|
||||
def original_tool_name(self) -> str:
|
||||
@@ -108,51 +99,26 @@ class MCPNativeTool(BaseTool):
|
||||
async def _run_async(self, **kwargs) -> str:
|
||||
"""Async implementation of tool execution.
|
||||
|
||||
A fresh ``MCPClient`` is created for every invocation so that
|
||||
concurrent calls never share transport or session state.
|
||||
|
||||
Args:
|
||||
**kwargs: Arguments to pass to the MCP tool.
|
||||
|
||||
Returns:
|
||||
Result from the MCP tool execution.
|
||||
"""
|
||||
# Note: Since we use asyncio.run() which creates a new event loop each time,
|
||||
# Always reconnect on-demand because asyncio.run() creates new event loops per call
|
||||
# All MCP transport context managers (stdio, streamablehttp_client, sse_client)
|
||||
# use anyio.create_task_group() which can't span different event loops
|
||||
if self._mcp_client.connected:
|
||||
await self._mcp_client.disconnect()
|
||||
|
||||
await self._mcp_client.connect()
|
||||
client = self._client_factory()
|
||||
await client.connect()
|
||||
|
||||
try:
|
||||
result = await self._mcp_client.call_tool(self.original_tool_name, kwargs)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"not connected" in error_str
|
||||
or "connection" in error_str
|
||||
or "send" in error_str
|
||||
):
|
||||
await self._mcp_client.disconnect()
|
||||
await self._mcp_client.connect()
|
||||
# Retry the call
|
||||
result = await self._mcp_client.call_tool(
|
||||
self.original_tool_name, kwargs
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
result = await client.call_tool(self.original_tool_name, kwargs)
|
||||
finally:
|
||||
# Always disconnect after tool call to ensure clean context manager lifecycle
|
||||
# This prevents "exit cancel scope in different task" errors
|
||||
# All transport context managers must be exited in the same event loop they were entered
|
||||
await self._mcp_client.disconnect()
|
||||
await client.disconnect()
|
||||
|
||||
# Extract result content
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
|
||||
# Handle various result formats
|
||||
if hasattr(result, "content") and result.content:
|
||||
if isinstance(result.content, list) and len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
|
||||
@@ -2353,3 +2353,68 @@ def test_agent_without_apps_no_platform_tools():
|
||||
|
||||
tools = crew._prepare_tools(agent, task, [])
|
||||
assert tools == []
|
||||
|
||||
|
||||
def test_agent_mcps_accepts_slug_with_specific_tool():
|
||||
"""Agent(mcps=["notion#get_page"]) must pass validation (_SLUG_RE)."""
|
||||
agent = Agent(
|
||||
role="MCP Agent",
|
||||
goal="Test MCP validation",
|
||||
backstory="Test agent",
|
||||
mcps=["notion#get_page"],
|
||||
)
|
||||
assert agent.mcps == ["notion#get_page"]
|
||||
|
||||
|
||||
def test_agent_mcps_accepts_slug_with_hyphenated_tool():
|
||||
agent = Agent(
|
||||
role="MCP Agent",
|
||||
goal="Test MCP validation",
|
||||
backstory="Test agent",
|
||||
mcps=["notion#get-page"],
|
||||
)
|
||||
assert agent.mcps == ["notion#get-page"]
|
||||
|
||||
|
||||
def test_agent_mcps_accepts_multiple_hash_refs():
|
||||
agent = Agent(
|
||||
role="MCP Agent",
|
||||
goal="Test MCP validation",
|
||||
backstory="Test agent",
|
||||
mcps=["notion#get_page", "notion#search", "github#list_repos"],
|
||||
)
|
||||
assert len(agent.mcps) == 3
|
||||
|
||||
|
||||
def test_agent_mcps_accepts_mixed_ref_types():
|
||||
agent = Agent(
|
||||
role="MCP Agent",
|
||||
goal="Test MCP validation",
|
||||
backstory="Test agent",
|
||||
mcps=[
|
||||
"notion#get_page",
|
||||
"notion",
|
||||
"https://mcp.example.com/api",
|
||||
],
|
||||
)
|
||||
assert len(agent.mcps) == 3
|
||||
|
||||
|
||||
def test_agent_mcps_rejects_hash_without_slug():
|
||||
with pytest.raises(ValueError, match="Invalid MCP reference"):
|
||||
Agent(
|
||||
role="MCP Agent",
|
||||
goal="Test MCP validation",
|
||||
backstory="Test agent",
|
||||
mcps=["#get_page"],
|
||||
)
|
||||
|
||||
|
||||
def test_agent_mcps_accepts_legacy_prefix_with_tool():
|
||||
agent = Agent(
|
||||
role="MCP Agent",
|
||||
goal="Test MCP validation",
|
||||
backstory="Test agent",
|
||||
mcps=["crewai-amp:notion#get_page"],
|
||||
)
|
||||
assert agent.mcps == ["crewai-amp:notion#get_page"]
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"What is the weather
|
||||
in Tokyo?"}],"model":"claude-sonnet-4-5","stream":false,"tools":[{"type":"tool_search_tool_bm25_20251119","name":"tool_search_tool_bm25"},{"name":"get_weather","description":"Get
|
||||
current weather conditions for a specified location","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_weather"}},"required":["input"]},"defer_loading":true},{"name":"search_files","description":"Search
|
||||
through files in the workspace by name or content","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for search_files"}},"required":["input"]},"defer_loading":true},{"name":"read_database","description":"Read
|
||||
records from a database table with optional filtering","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_database"}},"required":["input"]},"defer_loading":true},{"name":"write_database","description":"Write
|
||||
or update records in a database table","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for write_database"}},"required":["input"]},"defer_loading":true},{"name":"send_email","description":"Send
|
||||
an email message to one or more recipients","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for send_email"}},"required":["input"]},"defer_loading":true},{"name":"read_email","description":"Read
|
||||
emails from inbox with filtering options","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_email"}},"required":["input"]},"defer_loading":true},{"name":"create_ticket","description":"Create
|
||||
a new support ticket in the ticketing system","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_ticket"}},"required":["input"]},"defer_loading":true},{"name":"update_ticket","description":"Update
|
||||
an existing support ticket status or description","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for update_ticket"}},"required":["input"]},"defer_loading":true},{"name":"list_users","description":"List
|
||||
all users in the system with optional filters","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for list_users"}},"required":["input"]},"defer_loading":true},{"name":"get_user_profile","description":"Get
|
||||
detailed profile information for a specific user","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_user_profile"}},"required":["input"]},"defer_loading":true},{"name":"deploy_service","description":"Deploy
|
||||
a service to the specified environment","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for deploy_service"}},"required":["input"]},"defer_loading":true},{"name":"rollback_service","description":"Rollback
|
||||
a service deployment to a previous version","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for rollback_service"}},"required":["input"]},"defer_loading":true},{"name":"get_service_logs","description":"Get
|
||||
service logs filtered by time range and severity","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_service_logs"}},"required":["input"]},"defer_loading":true},{"name":"run_sql_query","description":"Run
|
||||
a read-only SQL query against the analytics database","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for run_sql_query"}},"required":["input"]},"defer_loading":true},{"name":"create_dashboard","description":"Create
|
||||
a new monitoring dashboard with widgets","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_dashboard"}},"required":["input"]},"defer_loading":true}]}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
anthropic-version:
|
||||
- '2023-06-01'
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '3952'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.anthropic.com
|
||||
x-api-key:
|
||||
- X-API-KEY-XXX
|
||||
x-stainless-arch:
|
||||
- X-STAINLESS-ARCH-XXX
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 0.73.0
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
x-stainless-timeout:
|
||||
- NOT_GIVEN
|
||||
method: POST
|
||||
uri: https://api.anthropic.com/v1/messages
|
||||
response:
|
||||
body:
|
||||
string: '{"model":"claude-sonnet-4-5-20250929","id":"msg_01DAGCoL6C12u6yAgR1UqNAs","type":"message","role":"assistant","content":[{"type":"text","text":"I''ll
|
||||
search for a weather-related tool to help you get the weather information
|
||||
for Tokyo."},{"type":"server_tool_use","id":"srvtoolu_0176qgHeeBpSygYAnUzKHCfh","name":"tool_search_tool_bm25","input":{"query":"weather
|
||||
Tokyo current conditions forecast"},"caller":{"type":"direct"}},{"type":"tool_search_tool_result","tool_use_id":"srvtoolu_0176qgHeeBpSygYAnUzKHCfh","content":{"type":"tool_search_tool_search_result","tool_references":[{"type":"tool_reference","tool_name":"get_weather"}]}},{"type":"text","text":"Great!
|
||||
I found a weather tool. Let me get the current weather conditions for Tokyo."},{"type":"tool_use","id":"toolu_01R3FavQLuTrwNvEk9gMaViK","name":"get_weather","input":{"input":"Tokyo"},"caller":{"type":"direct"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":1566,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":155,"service_tier":"standard","inference_geo":"not_available","server_tool_use":{"web_search_requests":0,"web_fetch_requests":0}}}'
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Security-Policy:
|
||||
- CSP-FILTERED
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Sun, 08 Mar 2026 21:04:12 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Robots-Tag:
|
||||
- none
|
||||
anthropic-organization-id:
|
||||
- ANTHROPIC-ORGANIZATION-ID-XXX
|
||||
anthropic-ratelimit-input-tokens-limit:
|
||||
- ANTHROPIC-RATELIMIT-INPUT-TOKENS-LIMIT-XXX
|
||||
anthropic-ratelimit-input-tokens-remaining:
|
||||
- ANTHROPIC-RATELIMIT-INPUT-TOKENS-REMAINING-XXX
|
||||
anthropic-ratelimit-input-tokens-reset:
|
||||
- ANTHROPIC-RATELIMIT-INPUT-TOKENS-RESET-XXX
|
||||
anthropic-ratelimit-output-tokens-limit:
|
||||
- ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-LIMIT-XXX
|
||||
anthropic-ratelimit-output-tokens-remaining:
|
||||
- ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-REMAINING-XXX
|
||||
anthropic-ratelimit-output-tokens-reset:
|
||||
- ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-RESET-XXX
|
||||
anthropic-ratelimit-requests-limit:
|
||||
- '20000'
|
||||
anthropic-ratelimit-requests-remaining:
|
||||
- '19999'
|
||||
anthropic-ratelimit-requests-reset:
|
||||
- '2026-03-08T21:04:07Z'
|
||||
anthropic-ratelimit-tokens-limit:
|
||||
- ANTHROPIC-RATELIMIT-TOKENS-LIMIT-XXX
|
||||
anthropic-ratelimit-tokens-remaining:
|
||||
- ANTHROPIC-RATELIMIT-TOKENS-REMAINING-XXX
|
||||
anthropic-ratelimit-tokens-reset:
|
||||
- ANTHROPIC-RATELIMIT-TOKENS-RESET-XXX
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
request-id:
|
||||
- REQUEST-ID-XXX
|
||||
strict-transport-security:
|
||||
- STS-XXX
|
||||
vary:
|
||||
- Accept-Encoding
|
||||
x-envoy-upstream-service-time:
|
||||
- '4330'
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,112 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"What is the weather
|
||||
in Tokyo?"}],"model":"claude-sonnet-4-5","stream":false,"tools":[{"name":"get_weather","description":"Get
|
||||
current weather conditions for a specified location","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_weather"}},"required":["input"]}},{"name":"search_files","description":"Search
|
||||
through files in the workspace by name or content","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for search_files"}},"required":["input"]}},{"name":"read_database","description":"Read
|
||||
records from a database table with optional filtering","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_database"}},"required":["input"]}},{"name":"write_database","description":"Write
|
||||
or update records in a database table","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for write_database"}},"required":["input"]}},{"name":"send_email","description":"Send
|
||||
an email message to one or more recipients","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for send_email"}},"required":["input"]}},{"name":"read_email","description":"Read
|
||||
emails from inbox with filtering options","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_email"}},"required":["input"]}},{"name":"create_ticket","description":"Create
|
||||
a new support ticket in the ticketing system","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_ticket"}},"required":["input"]}},{"name":"update_ticket","description":"Update
|
||||
an existing support ticket status or description","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for update_ticket"}},"required":["input"]}},{"name":"list_users","description":"List
|
||||
all users in the system with optional filters","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for list_users"}},"required":["input"]}},{"name":"get_user_profile","description":"Get
|
||||
detailed profile information for a specific user","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_user_profile"}},"required":["input"]}},{"name":"deploy_service","description":"Deploy
|
||||
a service to the specified environment","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for deploy_service"}},"required":["input"]}},{"name":"rollback_service","description":"Rollback
|
||||
a service deployment to a previous version","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for rollback_service"}},"required":["input"]}},{"name":"get_service_logs","description":"Get
|
||||
service logs filtered by time range and severity","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_service_logs"}},"required":["input"]}},{"name":"run_sql_query","description":"Run
|
||||
a read-only SQL query against the analytics database","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for run_sql_query"}},"required":["input"]}},{"name":"create_dashboard","description":"Create
|
||||
a new monitoring dashboard with widgets","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_dashboard"}},"required":["input"]}}]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
anthropic-version:
|
||||
- '2023-06-01'
|
||||
connection:
|
||||
- keep-alive
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.anthropic.com
|
||||
method: POST
|
||||
uri: https://api.anthropic.com/v1/messages
|
||||
response:
|
||||
body:
|
||||
string: '{"model":"claude-sonnet-4-5-20250929","id":"msg_01NoSearch001","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_01NoSearch001","name":"get_weather","input":{"input":"Tokyo"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":1943,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":54,"service_tier":"standard"}}'
|
||||
headers:
|
||||
Content-Type:
|
||||
- application/json
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"What is the weather
|
||||
in Tokyo?"}],"model":"claude-sonnet-4-5","stream":false,"tools":[{"type":"tool_search_tool_bm25_20251119","name":"tool_search_tool_bm25"},{"name":"get_weather","description":"Get
|
||||
current weather conditions for a specified location","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_weather"}},"required":["input"]},"defer_loading":true},{"name":"search_files","description":"Search
|
||||
through files in the workspace by name or content","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for search_files"}},"required":["input"]},"defer_loading":true},{"name":"read_database","description":"Read
|
||||
records from a database table with optional filtering","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_database"}},"required":["input"]},"defer_loading":true},{"name":"write_database","description":"Write
|
||||
or update records in a database table","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for write_database"}},"required":["input"]},"defer_loading":true},{"name":"send_email","description":"Send
|
||||
an email message to one or more recipients","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for send_email"}},"required":["input"]},"defer_loading":true},{"name":"read_email","description":"Read
|
||||
emails from inbox with filtering options","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_email"}},"required":["input"]},"defer_loading":true},{"name":"create_ticket","description":"Create
|
||||
a new support ticket in the ticketing system","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_ticket"}},"required":["input"]},"defer_loading":true},{"name":"update_ticket","description":"Update
|
||||
an existing support ticket status or description","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for update_ticket"}},"required":["input"]},"defer_loading":true},{"name":"list_users","description":"List
|
||||
all users in the system with optional filters","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for list_users"}},"required":["input"]},"defer_loading":true},{"name":"get_user_profile","description":"Get
|
||||
detailed profile information for a specific user","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_user_profile"}},"required":["input"]},"defer_loading":true},{"name":"deploy_service","description":"Deploy
|
||||
a service to the specified environment","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for deploy_service"}},"required":["input"]},"defer_loading":true},{"name":"rollback_service","description":"Rollback
|
||||
a service deployment to a previous version","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for rollback_service"}},"required":["input"]},"defer_loading":true},{"name":"get_service_logs","description":"Get
|
||||
service logs filtered by time range and severity","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_service_logs"}},"required":["input"]},"defer_loading":true},{"name":"run_sql_query","description":"Run
|
||||
a read-only SQL query against the analytics database","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for run_sql_query"}},"required":["input"]},"defer_loading":true},{"name":"create_dashboard","description":"Create
|
||||
a new monitoring dashboard with widgets","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_dashboard"}},"required":["input"]},"defer_loading":true}]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
anthropic-version:
|
||||
- '2023-06-01'
|
||||
connection:
|
||||
- keep-alive
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.anthropic.com
|
||||
method: POST
|
||||
uri: https://api.anthropic.com/v1/messages
|
||||
response:
|
||||
body:
|
||||
string: '{"model":"claude-sonnet-4-5-20250929","id":"msg_01WithSearch001","type":"message","role":"assistant","content":[{"type":"text","text":"I''ll search for a weather tool."},{"type":"server_tool_use","id":"srvtoolu_01Search001","name":"tool_search_tool_bm25","input":{"query":"weather conditions"},"caller":{"type":"direct"}},{"type":"tool_search_tool_result","tool_use_id":"srvtoolu_01Search001","content":{"type":"tool_search_tool_search_result","tool_references":[{"type":"tool_reference","tool_name":"get_weather"}]}},{"type":"text","text":"Found it. Let me get the weather for Tokyo."},{"type":"tool_use","id":"toolu_01WithSearch001","name":"get_weather","input":{"input":"Tokyo"},"caller":{"type":"direct"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":1566,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":155,"service_tier":"standard"}}'
|
||||
headers:
|
||||
Content-Type:
|
||||
- application/json
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -1121,3 +1121,345 @@ def test_anthropic_cached_prompt_tokens_with_tools():
|
||||
assert usage.successful_requests == 2
|
||||
# The second call should have cached prompt tokens
|
||||
assert usage.cached_prompt_tokens > 0
|
||||
|
||||
|
||||
# ---- Tool Search Tool Tests ----
|
||||
|
||||
|
||||
def test_tool_search_true_injects_bm25_and_defer_loading():
|
||||
"""tool_search=True should inject bm25 tool search and defer all tools."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
|
||||
crewai_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculator",
|
||||
"description": "Perform math calculations",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"expression": {"type": "string"}},
|
||||
"required": ["expression"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_anthropic(
|
||||
[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
params = llm._prepare_completion_params(
|
||||
formatted_messages, system_message, crewai_tools
|
||||
)
|
||||
|
||||
tools = params["tools"]
|
||||
# Should have 3 tools: tool_search + 2 regular
|
||||
assert len(tools) == 3
|
||||
|
||||
# First tool should be the bm25 tool search tool
|
||||
assert tools[0]["type"] == "tool_search_tool_bm25_20251119"
|
||||
assert tools[0]["name"] == "tool_search_tool_bm25"
|
||||
assert "input_schema" not in tools[0]
|
||||
|
||||
# All regular tools should have defer_loading=True
|
||||
for t in tools[1:]:
|
||||
assert t.get("defer_loading") is True, f"Tool {t['name']} missing defer_loading"
|
||||
|
||||
|
||||
def test_tool_search_regex_config():
|
||||
"""tool_search with regex config should use regex variant."""
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicToolSearchConfig
|
||||
|
||||
config = AnthropicToolSearchConfig(type="regex")
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=config)
|
||||
|
||||
crewai_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_a",
|
||||
"description": "First tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_b",
|
||||
"description": "Second tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_anthropic(
|
||||
[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
params = llm._prepare_completion_params(
|
||||
formatted_messages, system_message, crewai_tools
|
||||
)
|
||||
|
||||
tools = params["tools"]
|
||||
assert tools[0]["type"] == "tool_search_tool_regex_20251119"
|
||||
assert tools[0]["name"] == "tool_search_tool_regex"
|
||||
|
||||
|
||||
def test_tool_search_disabled_by_default():
|
||||
"""tool_search=None (default) should NOT inject anything."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5")
|
||||
|
||||
crewai_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_anthropic(
|
||||
[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
params = llm._prepare_completion_params(
|
||||
formatted_messages, system_message, crewai_tools
|
||||
)
|
||||
|
||||
tools = params["tools"]
|
||||
assert len(tools) == 1
|
||||
for t in tools:
|
||||
assert t.get("type", "") not in (
|
||||
"tool_search_tool_bm25_20251119",
|
||||
"tool_search_tool_regex_20251119",
|
||||
)
|
||||
assert "defer_loading" not in t
|
||||
|
||||
|
||||
def test_tool_search_no_duplicate_when_manually_provided():
|
||||
"""If user passes a tool search tool manually, don't inject a duplicate."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
|
||||
# User manually includes a tool search tool
|
||||
tools_with_search = [
|
||||
{"type": "tool_search_tool_regex_20251119", "name": "tool_search_tool_regex"},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_anthropic(
|
||||
[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
params = llm._prepare_completion_params(
|
||||
formatted_messages, system_message, tools_with_search
|
||||
)
|
||||
|
||||
tools = params["tools"]
|
||||
search_tools = [
|
||||
t for t in tools
|
||||
if t.get("type", "").startswith("tool_search_tool")
|
||||
]
|
||||
# Should only have 1 tool search tool (the user's manual one)
|
||||
assert len(search_tools) == 1
|
||||
assert search_tools[0]["type"] == "tool_search_tool_regex_20251119"
|
||||
|
||||
|
||||
def test_tool_search_passthrough_preserves_tool_search_type():
|
||||
"""_convert_tools_for_interference should pass through tool search tools unchanged."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5")
|
||||
|
||||
tools = [
|
||||
{"type": "tool_search_tool_regex_20251119", "name": "tool_search_tool_regex"},
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
converted = llm._convert_tools_for_interference(tools)
|
||||
assert len(converted) == 2
|
||||
# Tool search tool should be passed through exactly
|
||||
assert converted[0] == {
|
||||
"type": "tool_search_tool_regex_20251119",
|
||||
"name": "tool_search_tool_regex",
|
||||
}
|
||||
# Regular tool should be preserved
|
||||
assert converted[1]["name"] == "get_weather"
|
||||
assert "input_schema" in converted[1]
|
||||
|
||||
|
||||
def test_tool_search_single_tool_skips_search_and_forces_choice():
|
||||
"""With only 1 tool, tool_search is skipped (nothing to search) and the
|
||||
normal forced tool_choice optimisation still applies."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
|
||||
crewai_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_anthropic(
|
||||
[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
params = llm._prepare_completion_params(
|
||||
formatted_messages,
|
||||
system_message,
|
||||
crewai_tools,
|
||||
available_functions={"test_tool": lambda q: "result"},
|
||||
)
|
||||
|
||||
# Single tool — tool_search skipped, tool_choice forced as normal
|
||||
assert "tool_choice" in params
|
||||
assert params["tool_choice"]["name"] == "test_tool"
|
||||
|
||||
# No tool search tool should be injected
|
||||
tool_types = [t.get("type", "") for t in params["tools"]]
|
||||
for ts_type in ("tool_search_tool_bm25_20251119", "tool_search_tool_regex_20251119"):
|
||||
assert ts_type not in tool_types
|
||||
|
||||
# No defer_loading on the single tool
|
||||
assert "defer_loading" not in params["tools"][0]
|
||||
|
||||
|
||||
def test_tool_search_via_llm_class():
|
||||
"""Verify tool_search param passes through LLM class correctly."""
|
||||
from crewai.llms.providers.anthropic.completion import (
|
||||
AnthropicCompletion,
|
||||
AnthropicToolSearchConfig,
|
||||
)
|
||||
|
||||
# Test with True
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
assert llm.tool_search is not None
|
||||
assert llm.tool_search.type == "bm25"
|
||||
|
||||
# Test with config
|
||||
llm2 = LLM(
|
||||
model="anthropic/claude-sonnet-4-5",
|
||||
tool_search=AnthropicToolSearchConfig(type="regex"),
|
||||
)
|
||||
assert llm2.tool_search is not None
|
||||
assert llm2.tool_search.type == "regex"
|
||||
|
||||
# Test without (default)
|
||||
llm3 = LLM(model="anthropic/claude-sonnet-4-5")
|
||||
assert llm3.tool_search is None
|
||||
|
||||
|
||||
# Many tools shared by the VCR tests below
|
||||
_MANY_TOOLS = [
|
||||
{
|
||||
"name": name,
|
||||
"description": desc,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"input": {"type": "string", "description": f"Input for {name}"}},
|
||||
"required": ["input"],
|
||||
},
|
||||
}
|
||||
for name, desc in [
|
||||
("get_weather", "Get current weather conditions for a specified location"),
|
||||
("search_files", "Search through files in the workspace by name or content"),
|
||||
("read_database", "Read records from a database table with optional filtering"),
|
||||
("write_database", "Write or update records in a database table"),
|
||||
("send_email", "Send an email message to one or more recipients"),
|
||||
("read_email", "Read emails from inbox with filtering options"),
|
||||
("create_ticket", "Create a new support ticket in the ticketing system"),
|
||||
("update_ticket", "Update an existing support ticket status or description"),
|
||||
("list_users", "List all users in the system with optional filters"),
|
||||
("get_user_profile", "Get detailed profile information for a specific user"),
|
||||
("deploy_service", "Deploy a service to the specified environment"),
|
||||
("rollback_service", "Rollback a service deployment to a previous version"),
|
||||
("get_service_logs", "Get service logs filtered by time range and severity"),
|
||||
("run_sql_query", "Run a read-only SQL query against the analytics database"),
|
||||
("create_dashboard", "Create a new monitoring dashboard with widgets"),
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_tool_search_discovers_and_calls_tool():
|
||||
"""Tool search should discover the right tool and return a tool_use block."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
|
||||
result = llm.call(
|
||||
"What is the weather in Tokyo?",
|
||||
tools=_MANY_TOOLS,
|
||||
)
|
||||
|
||||
# Should return tool_use blocks (list) since no available_functions provided
|
||||
assert isinstance(result, list)
|
||||
assert len(result) >= 1
|
||||
# The discovered tool should be get_weather
|
||||
tool_names = [getattr(block, "name", None) for block in result]
|
||||
assert "get_weather" in tool_names
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_tool_search_saves_input_tokens():
|
||||
"""Tool search with deferred loading should use fewer input tokens than loading all tools."""
|
||||
# Call WITHOUT tool search — all 15 tools loaded upfront
|
||||
llm_no_search = LLM(model="anthropic/claude-sonnet-4-5")
|
||||
llm_no_search.call("What is the weather in Tokyo?", tools=_MANY_TOOLS)
|
||||
usage_no_search = llm_no_search.get_token_usage_summary()
|
||||
|
||||
# Call WITH tool search — tools deferred
|
||||
llm_search = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
llm_search.call("What is the weather in Tokyo?", tools=_MANY_TOOLS)
|
||||
usage_search = llm_search.get_token_usage_summary()
|
||||
|
||||
# Tool search should use fewer input tokens
|
||||
assert usage_search.prompt_tokens < usage_no_search.prompt_tokens, (
|
||||
f"Expected tool_search ({usage_search.prompt_tokens}) to use fewer input tokens "
|
||||
f"than no search ({usage_no_search.prompt_tokens})"
|
||||
)
|
||||
|
||||
@@ -967,3 +967,211 @@ def test_bedrock_agent_kickoff_structured_output_with_tools():
|
||||
assert result.pydantic.result == 42, f"Expected result 42 but got {result.pydantic.result}"
|
||||
assert result.pydantic.operation, "Operation should not be empty"
|
||||
assert result.pydantic.explanation, "Explanation should not be empty"
|
||||
|
||||
|
||||
def test_bedrock_groups_three_tool_results():
|
||||
"""Consecutive tool results should be grouped into one Bedrock user message."""
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Use all three tools, then continue."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tool-1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup_weather",
|
||||
"arguments": '{"location": "New York"}',
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "tool-2",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup_news",
|
||||
"arguments": '{"topic": "AI"}',
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "tool-3",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup_stock",
|
||||
"arguments": '{"ticker": "AMZN"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "tool-1", "content": "72F and sunny"},
|
||||
{"role": "tool", "tool_call_id": "tool-2", "content": "AI news summary"},
|
||||
{"role": "tool", "tool_call_id": "tool-3", "content": "AMZN up 1.2%"},
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_converse(messages)
|
||||
|
||||
assert system_message is None
|
||||
assert [message["role"] for message in formatted_messages] == [
|
||||
"user",
|
||||
"assistant",
|
||||
"user",
|
||||
]
|
||||
assert len(formatted_messages[1]["content"]) == 3
|
||||
|
||||
tool_results = formatted_messages[2]["content"]
|
||||
assert len(tool_results) == 3
|
||||
assert [block["toolResult"]["toolUseId"] for block in tool_results] == [
|
||||
"tool-1",
|
||||
"tool-2",
|
||||
"tool-3",
|
||||
]
|
||||
assert [block["toolResult"]["content"][0]["text"] for block in tool_results] == [
|
||||
"72F and sunny",
|
||||
"AI news summary",
|
||||
"AMZN up 1.2%",
|
||||
]
|
||||
|
||||
|
||||
def test_bedrock_parallel_tool_results_grouped():
|
||||
"""Regression test for issue #4749.
|
||||
|
||||
When an assistant message contains multiple parallel tool calls,
|
||||
Bedrock requires all corresponding tool results to be grouped
|
||||
in a single user message. Previously each tool result was emitted
|
||||
as a separate user message, causing:
|
||||
ValidationException: Expected toolResult blocks at messages.2.content
|
||||
"""
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Calculate 25 + 17 AND 10 * 5"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_add",
|
||||
"type": "function",
|
||||
"function": {"name": "add_tool", "arguments": '{"a": 25, "b": 17}'},
|
||||
},
|
||||
{
|
||||
"id": "call_mul",
|
||||
"type": "function",
|
||||
"function": {"name": "multiply_tool", "arguments": '{"a": 10, "b": 5}'},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_add", "content": "42"},
|
||||
{"role": "tool", "tool_call_id": "call_mul", "content": "50"},
|
||||
]
|
||||
|
||||
converse_msgs, system_msg = llm._format_messages_for_converse(messages)
|
||||
|
||||
# Find the user message that contains toolResult blocks
|
||||
tool_result_messages = [
|
||||
m for m in converse_msgs
|
||||
if m.get("role") == "user"
|
||||
and any("toolResult" in b for b in m.get("content", []))
|
||||
]
|
||||
|
||||
# There must be exactly ONE user message with tool results (not two)
|
||||
assert len(tool_result_messages) == 1, (
|
||||
f"Expected 1 grouped tool-result message, got {len(tool_result_messages)}. "
|
||||
"Bedrock requires all parallel tool results in a single user message."
|
||||
)
|
||||
|
||||
# That single message must contain both tool results
|
||||
tool_results = tool_result_messages[0]["content"]
|
||||
assert len(tool_results) == 2, (
|
||||
f"Expected 2 toolResult blocks in grouped message, got {len(tool_results)}"
|
||||
)
|
||||
|
||||
# Verify the tool use IDs match
|
||||
tool_use_ids = {
|
||||
block["toolResult"]["toolUseId"] for block in tool_results
|
||||
}
|
||||
assert tool_use_ids == {"call_add", "call_mul"}
|
||||
|
||||
|
||||
def test_bedrock_single_tool_result_still_works():
|
||||
"""Ensure single tool call still produces a single-block user message."""
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Add 1 + 2"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_single",
|
||||
"type": "function",
|
||||
"function": {"name": "add_tool", "arguments": '{"a": 1, "b": 2}'},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_single", "content": "3"},
|
||||
]
|
||||
|
||||
converse_msgs, _ = llm._format_messages_for_converse(messages)
|
||||
|
||||
tool_result_messages = [
|
||||
m for m in converse_msgs
|
||||
if m.get("role") == "user"
|
||||
and any("toolResult" in b for b in m.get("content", []))
|
||||
]
|
||||
assert len(tool_result_messages) == 1
|
||||
assert len(tool_result_messages[0]["content"]) == 1
|
||||
assert tool_result_messages[0]["content"][0]["toolResult"]["toolUseId"] == "call_single"
|
||||
|
||||
|
||||
def test_bedrock_tool_results_not_merged_across_assistant_messages():
|
||||
"""Tool results from different assistant turns must NOT be merged."""
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "First task"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_a",
|
||||
"type": "function",
|
||||
"function": {"name": "tool_a", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_a", "content": "result_a"},
|
||||
{"role": "assistant", "content": "Now doing second task"},
|
||||
{"role": "user", "content": "Second task"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_b",
|
||||
"type": "function",
|
||||
"function": {"name": "tool_b", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_b", "content": "result_b"},
|
||||
]
|
||||
|
||||
converse_msgs, _ = llm._format_messages_for_converse(messages)
|
||||
|
||||
tool_result_messages = [
|
||||
m for m in converse_msgs
|
||||
if m.get("role") == "user"
|
||||
and any("toolResult" in b for b in m.get("content", []))
|
||||
]
|
||||
|
||||
# Two separate tool-result messages (one per assistant turn)
|
||||
assert len(tool_result_messages) == 2, (
|
||||
"Tool results from different assistant turns must remain separate"
|
||||
)
|
||||
assert tool_result_messages[0]["content"][0]["toolResult"]["toolUseId"] == "call_a"
|
||||
assert tool_result_messages[1]["content"][0]["toolResult"]["toolUseId"] == "call_b"
|
||||
|
||||
@@ -268,6 +268,54 @@ class TestGetMCPToolsAmpIntegration:
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "mcp_notion_so_sse_search"
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_tool_filter_with_hyphenated_hash_syntax(
|
||||
self, mock_fetch, mock_client_class, agent
|
||||
):
|
||||
"""notion#get-page must match the tool whose sanitized name is get_page."""
|
||||
mock_fetch.return_value = {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
},
|
||||
}
|
||||
|
||||
hyphenated_tool_definitions = [
|
||||
{
|
||||
"name": "get_page",
|
||||
"original_name": "get-page",
|
||||
"description": "Get a page",
|
||||
"inputSchema": {},
|
||||
},
|
||||
{
|
||||
"name": "search",
|
||||
"original_name": "search",
|
||||
"description": "Search tool",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=hyphenated_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
tools = agent.get_mcp_tools(["notion#get-page"])
|
||||
|
||||
mock_fetch.assert_called_once_with(["notion"])
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name.endswith("_get_page")
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_deduplicates_slugs(
|
||||
@@ -371,3 +419,87 @@ class TestGetMCPToolsAmpIntegration:
|
||||
mock_external.assert_called_once_with("https://external.mcp.com/api")
|
||||
# 2 from notion + 1 from external + 2 from http_config
|
||||
assert len(tools) == 5
|
||||
|
||||
|
||||
class TestResolveExternalToolFilter:
|
||||
"""Tests for _resolve_external with #tool-name filtering."""
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self):
|
||||
return Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def resolver(self, agent):
|
||||
return MCPToolResolver(agent=agent, logger=agent._logger)
|
||||
|
||||
@patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
|
||||
def test_filters_hyphenated_tool_name(self, mock_schemas, resolver):
|
||||
"""https://...#get-page must match the sanitized key get_page in schemas."""
|
||||
mock_schemas.return_value = {
|
||||
"get_page": {
|
||||
"description": "Get a page",
|
||||
"args_schema": None,
|
||||
},
|
||||
"search": {
|
||||
"description": "Search tool",
|
||||
"args_schema": None,
|
||||
},
|
||||
}
|
||||
|
||||
tools = resolver._resolve_external("https://mcp.example.com/api#get-page")
|
||||
|
||||
assert len(tools) == 1
|
||||
assert "get_page" in tools[0].name
|
||||
|
||||
@patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
|
||||
def test_filters_underscored_tool_name(self, mock_schemas, resolver):
|
||||
"""https://...#get_page must also match the sanitized key get_page."""
|
||||
mock_schemas.return_value = {
|
||||
"get_page": {
|
||||
"description": "Get a page",
|
||||
"args_schema": None,
|
||||
},
|
||||
"search": {
|
||||
"description": "Search tool",
|
||||
"args_schema": None,
|
||||
},
|
||||
}
|
||||
|
||||
tools = resolver._resolve_external("https://mcp.example.com/api#get_page")
|
||||
|
||||
assert len(tools) == 1
|
||||
assert "get_page" in tools[0].name
|
||||
|
||||
@patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
|
||||
def test_returns_all_tools_without_hash(self, mock_schemas, resolver):
|
||||
mock_schemas.return_value = {
|
||||
"get_page": {
|
||||
"description": "Get a page",
|
||||
"args_schema": None,
|
||||
},
|
||||
"search": {
|
||||
"description": "Search tool",
|
||||
"args_schema": None,
|
||||
},
|
||||
}
|
||||
|
||||
tools = resolver._resolve_external("https://mcp.example.com/api")
|
||||
|
||||
assert len(tools) == 2
|
||||
|
||||
@patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
|
||||
def test_returns_empty_for_nonexistent_tool(self, mock_schemas, resolver):
|
||||
mock_schemas.return_value = {
|
||||
"search": {
|
||||
"description": "Search tool",
|
||||
"args_schema": None,
|
||||
},
|
||||
}
|
||||
|
||||
tools = resolver._resolve_external("https://mcp.example.com/api#nonexistent")
|
||||
|
||||
assert len(tools) == 0
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -30,6 +31,17 @@ def mock_tool_definitions():
|
||||
]
|
||||
|
||||
|
||||
def _make_mock_client(tool_definitions):
|
||||
"""Create a mock MCPClient that returns *tool_definitions*."""
|
||||
client = AsyncMock()
|
||||
client.list_tools = AsyncMock(return_value=tool_definitions)
|
||||
client.connected = False
|
||||
client.connect = AsyncMock()
|
||||
client.disconnect = AsyncMock()
|
||||
client.call_tool = AsyncMock(return_value="test result")
|
||||
return client
|
||||
|
||||
|
||||
def test_agent_with_stdio_mcp_config(mock_tool_definitions):
|
||||
"""Test agent setup with MCPServerStdio configuration."""
|
||||
stdio_config = MCPServerStdio(
|
||||
@@ -45,14 +57,8 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions):
|
||||
mcps=[stdio_config],
|
||||
)
|
||||
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False # Will trigger connect
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
||||
|
||||
tools = agent.get_mcp_tools([stdio_config])
|
||||
|
||||
@@ -60,8 +66,7 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions):
|
||||
assert all(isinstance(tool, BaseTool) for tool in tools)
|
||||
|
||||
mock_client_class.assert_called_once()
|
||||
call_args = mock_client_class.call_args
|
||||
transport = call_args.kwargs["transport"]
|
||||
transport = mock_client_class.call_args.kwargs["transport"]
|
||||
assert transport.command == "python"
|
||||
assert transport.args == ["server.py"]
|
||||
assert transport.env == {"API_KEY": "test_key"}
|
||||
@@ -83,12 +88,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions):
|
||||
)
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False # Will trigger connect
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
||||
|
||||
tools = agent.get_mcp_tools([http_config])
|
||||
|
||||
@@ -96,8 +96,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions):
|
||||
assert all(isinstance(tool, BaseTool) for tool in tools)
|
||||
|
||||
mock_client_class.assert_called_once()
|
||||
call_args = mock_client_class.call_args
|
||||
transport = call_args.kwargs["transport"]
|
||||
transport = mock_client_class.call_args.kwargs["transport"]
|
||||
assert transport.url == "https://api.example.com/mcp"
|
||||
assert transport.headers == {"Authorization": "Bearer test_token"}
|
||||
assert transport.streamable is True
|
||||
@@ -118,12 +117,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions):
|
||||
)
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
||||
|
||||
tools = agent.get_mcp_tools([sse_config])
|
||||
|
||||
@@ -131,8 +125,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions):
|
||||
assert all(isinstance(tool, BaseTool) for tool in tools)
|
||||
|
||||
mock_client_class.assert_called_once()
|
||||
call_args = mock_client_class.call_args
|
||||
transport = call_args.kwargs["transport"]
|
||||
transport = mock_client_class.call_args.kwargs["transport"]
|
||||
assert transport.url == "https://api.example.com/mcp/sse"
|
||||
assert transport.headers == {"Authorization": "Bearer test_token"}
|
||||
|
||||
@@ -142,13 +135,7 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
|
||||
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value="test result")
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
@@ -160,12 +147,12 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
|
||||
tools = agent.get_mcp_tools([http_config])
|
||||
assert len(tools) == 2
|
||||
|
||||
|
||||
tool = tools[0]
|
||||
result = tool.run(query="test query")
|
||||
|
||||
assert result == "test result"
|
||||
mock_client.call_tool.assert_called()
|
||||
# 1 discovery + 1 for the run() invocation
|
||||
assert mock_client_class.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -174,13 +161,7 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
|
||||
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value="test result")
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
@@ -192,9 +173,129 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
|
||||
tools = agent.get_mcp_tools([http_config])
|
||||
assert len(tools) == 2
|
||||
|
||||
|
||||
tool = tools[0]
|
||||
result = tool.run(query="test query")
|
||||
|
||||
assert result == "test result"
|
||||
mock_client.call_tool.assert_called()
|
||||
assert mock_client_class.call_count == 2
|
||||
|
||||
|
||||
def test_each_invocation_gets_fresh_client(mock_tool_definitions):
|
||||
"""Every tool.run() must create its own MCPClient (no shared state)."""
|
||||
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
||||
|
||||
clients_created: list = []
|
||||
|
||||
def _make_client(**kwargs):
|
||||
client = _make_mock_client(mock_tool_definitions)
|
||||
clients_created.append(client)
|
||||
return client
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
mcps=[http_config],
|
||||
)
|
||||
|
||||
tools = agent.get_mcp_tools([http_config])
|
||||
assert len(tools) == 2
|
||||
# 1 discovery client so far
|
||||
assert len(clients_created) == 1
|
||||
|
||||
# Two sequential calls to the same tool must create 2 new clients
|
||||
tools[0].run(query="q1")
|
||||
tools[0].run(query="q2")
|
||||
assert len(clients_created) == 3
|
||||
assert clients_created[1] is not clients_created[2]
|
||||
|
||||
|
||||
def test_parallel_mcp_tool_execution_same_tool(mock_tool_definitions):
|
||||
"""Parallel calls to the *same* tool must not interfere."""
|
||||
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
||||
|
||||
call_log: list[str] = []
|
||||
|
||||
def _make_client(**kwargs):
|
||||
client = AsyncMock()
|
||||
client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
client.connected = False
|
||||
client.connect = AsyncMock()
|
||||
client.disconnect = AsyncMock()
|
||||
|
||||
async def _call_tool(name, args):
|
||||
call_log.append(name)
|
||||
await asyncio.sleep(0.05)
|
||||
return f"result-{name}"
|
||||
|
||||
client.call_tool = AsyncMock(side_effect=_call_tool)
|
||||
return client
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
mcps=[http_config],
|
||||
)
|
||||
|
||||
tools = agent.get_mcp_tools([http_config])
|
||||
assert len(tools) >= 1
|
||||
tool = tools[0]
|
||||
|
||||
# Call the SAME tool concurrently -- the exact scenario from the bug
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
|
||||
futures = [
|
||||
pool.submit(tool.run, query="q1"),
|
||||
pool.submit(tool.run, query="q2"),
|
||||
]
|
||||
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||||
|
||||
assert len(results) == 2
|
||||
assert all("result-" in r for r in results)
|
||||
assert len(call_log) == 2
|
||||
|
||||
|
||||
def test_parallel_mcp_tool_execution_different_tools(mock_tool_definitions):
|
||||
"""Parallel calls to different tools from the same server must not interfere."""
|
||||
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
||||
|
||||
call_log: list[str] = []
|
||||
|
||||
def _make_client(**kwargs):
|
||||
client = AsyncMock()
|
||||
client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
client.connected = False
|
||||
client.connect = AsyncMock()
|
||||
client.disconnect = AsyncMock()
|
||||
|
||||
async def _call_tool(name, args):
|
||||
call_log.append(name)
|
||||
await asyncio.sleep(0.05)
|
||||
return f"result-{name}"
|
||||
|
||||
client.call_tool = AsyncMock(side_effect=_call_tool)
|
||||
return client
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
mcps=[http_config],
|
||||
)
|
||||
|
||||
tools = agent.get_mcp_tools([http_config])
|
||||
assert len(tools) == 2
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
|
||||
futures = [
|
||||
pool.submit(tools[0].run, query="q1"),
|
||||
pool.submit(tools[1].run, query="q2"),
|
||||
]
|
||||
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||||
|
||||
assert len(results) == 2
|
||||
assert all("result-" in r for r in results)
|
||||
assert len(call_log) == 2
|
||||
|
||||
268
lib/crewai/tests/memory/test_concurrent_storage.py
Normal file
268
lib/crewai/tests/memory/test_concurrent_storage.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Stress tests for concurrent multi-process storage access.
|
||||
|
||||
Simulates the Airflow pattern: N worker processes each writing to the
|
||||
same storage directory simultaneously. Verifies no LockException and
|
||||
data integrity after all writes complete.
|
||||
|
||||
Uses temp files for IPC instead of multiprocessing.Manager (which uses
|
||||
sockets blocked by pytest_recording).
|
||||
"""
|
||||
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File-based IPC helpers (avoids Manager sockets)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _write_result(result_dir: str, worker_id: int, success: bool, error: str = ""):
|
||||
path = os.path.join(result_dir, f"worker-{worker_id}.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump({"success": success, "error": error}, f)
|
||||
|
||||
|
||||
def _collect_results(result_dir: str, n_workers: int):
|
||||
errors = {}
|
||||
successes = 0
|
||||
for wid in range(n_workers):
|
||||
path = os.path.join(result_dir, f"worker-{wid}.json")
|
||||
if not os.path.exists(path):
|
||||
errors[wid] = "Process produced no output (crashed or timed out)"
|
||||
continue
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
if data["success"]:
|
||||
successes += 1
|
||||
else:
|
||||
errors[wid] = data["error"]
|
||||
return successes, errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Worker functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _lancedb_worker(path: str, worker_id: int, n_records: int, result_dir: str):
|
||||
try:
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
from crewai.memory.types import MemoryRecord
|
||||
|
||||
storage = LanceDBStorage(path=path, table_name="memories", vector_dim=8)
|
||||
records = [
|
||||
MemoryRecord(
|
||||
id=f"worker-{worker_id}-record-{i}",
|
||||
content=f"content from worker {worker_id} record {i}",
|
||||
scope=f"/test/worker-{worker_id}",
|
||||
categories=["test"],
|
||||
metadata={"worker": worker_id},
|
||||
importance=0.5,
|
||||
embedding=[float(worker_id)] * 8,
|
||||
)
|
||||
for i in range(n_records)
|
||||
]
|
||||
storage.save(records)
|
||||
_write_result(result_dir, worker_id, True)
|
||||
except Exception as e:
|
||||
_write_result(result_dir, worker_id, False, f"{type(e).__name__}: {e}")
|
||||
|
||||
|
||||
def _sqlite_kickoff_worker(db_path: str, worker_id: int, n_writes: int, result_dir: str):
|
||||
try:
|
||||
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
KickoffTaskOutputsSQLiteStorage,
|
||||
)
|
||||
|
||||
KickoffTaskOutputsSQLiteStorage(db_path=db_path)
|
||||
for i in range(n_writes):
|
||||
with sqlite3.connect(db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute(
|
||||
"""INSERT OR REPLACE INTO latest_kickoff_task_outputs
|
||||
(task_id, expected_output, output, task_index, inputs, was_replayed)
|
||||
VALUES (?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
f"worker-{worker_id}-task-{i}",
|
||||
"expected output",
|
||||
'{"result": "ok"}',
|
||||
worker_id * 1000 + i,
|
||||
"{}",
|
||||
False,
|
||||
),
|
||||
)
|
||||
_write_result(result_dir, worker_id, True)
|
||||
except Exception as e:
|
||||
_write_result(result_dir, worker_id, False, f"{type(e).__name__}: {e}")
|
||||
|
||||
|
||||
def _sqlite_flow_worker(db_path: str, worker_id: int, n_writes: int, result_dir: str):
|
||||
try:
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
|
||||
persistence = SQLiteFlowPersistence(db_path=db_path)
|
||||
for i in range(n_writes):
|
||||
persistence.save_state(
|
||||
flow_uuid=f"flow-{worker_id}-{i}",
|
||||
method_name="test_method",
|
||||
state_data={"worker": worker_id, "iteration": i},
|
||||
)
|
||||
_write_result(result_dir, worker_id, True)
|
||||
except Exception as e:
|
||||
_write_result(result_dir, worker_id, False, f"{type(e).__name__}: {e}")
|
||||
|
||||
|
||||
def _chromadb_worker(persist_dir: str, worker_id: int, result_dir: str):
|
||||
try:
|
||||
from hashlib import md5
|
||||
|
||||
from chromadb import PersistentClient
|
||||
from chromadb.config import Settings
|
||||
import portalocker
|
||||
|
||||
settings = Settings(
|
||||
persist_directory=persist_dir,
|
||||
anonymized_telemetry=False,
|
||||
is_persistent=True,
|
||||
)
|
||||
|
||||
# Test the actual locking path directly (same as factory.py)
|
||||
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
|
||||
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
|
||||
with portalocker.Lock(lockfile, timeout=120):
|
||||
PersistentClient(path=persist_dir, settings=settings)
|
||||
|
||||
_write_result(result_dir, worker_id, True)
|
||||
except Exception as e:
|
||||
_write_result(result_dir, worker_id, False, f"{type(e).__name__}: {e}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
N_WORKERS = 6
|
||||
N_RECORDS = 20
|
||||
|
||||
|
||||
def _run_workers(target, args_fn, n_workers=N_WORKERS, timeout=120):
|
||||
"""Spawn n_workers processes and collect results via temp files."""
|
||||
with tempfile.TemporaryDirectory() as result_dir:
|
||||
procs = []
|
||||
for wid in range(n_workers):
|
||||
p = multiprocessing.Process(
|
||||
target=target,
|
||||
args=args_fn(wid, result_dir),
|
||||
)
|
||||
procs.append(p)
|
||||
|
||||
for p in procs:
|
||||
p.start()
|
||||
for p in procs:
|
||||
p.join(timeout=timeout)
|
||||
|
||||
successes, errors = _collect_results(result_dir, n_workers)
|
||||
return successes, errors
|
||||
|
||||
|
||||
class TestConcurrentLanceDB:
|
||||
"""Concurrent multi-process writes to LanceDB."""
|
||||
|
||||
def test_concurrent_saves_no_lock_exception(self, tmp_path):
|
||||
db_path = str(tmp_path / "lancedb_concurrent")
|
||||
|
||||
successes, errors = _run_workers(
|
||||
_lancedb_worker,
|
||||
lambda wid, rd: (db_path, wid, N_RECORDS, rd),
|
||||
)
|
||||
|
||||
assert not errors, f"Workers failed: {errors}"
|
||||
assert successes == N_WORKERS
|
||||
|
||||
def test_data_integrity_after_concurrent_saves(self, tmp_path):
|
||||
db_path = str(tmp_path / "lancedb_integrity")
|
||||
|
||||
successes, errors = _run_workers(
|
||||
_lancedb_worker,
|
||||
lambda wid, rd: (db_path, wid, N_RECORDS, rd),
|
||||
)
|
||||
|
||||
assert not errors, f"Workers failed: {errors}"
|
||||
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
storage = LanceDBStorage(path=db_path, table_name="memories", vector_dim=8)
|
||||
total = storage.count()
|
||||
expected = N_WORKERS * N_RECORDS
|
||||
assert total == expected, f"Expected {expected} records, got {total}"
|
||||
|
||||
|
||||
class TestConcurrentSQLiteKickoff:
|
||||
"""Concurrent multi-process writes to kickoff task outputs SQLite."""
|
||||
|
||||
def test_concurrent_writes_no_error(self, tmp_path):
|
||||
db_path = str(tmp_path / "kickoff.db")
|
||||
|
||||
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
KickoffTaskOutputsSQLiteStorage,
|
||||
)
|
||||
|
||||
KickoffTaskOutputsSQLiteStorage(db_path=db_path)
|
||||
|
||||
successes, errors = _run_workers(
|
||||
_sqlite_kickoff_worker,
|
||||
lambda wid, rd: (db_path, wid, N_RECORDS, rd),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
assert not errors, f"Workers failed: {errors}"
|
||||
assert successes == N_WORKERS
|
||||
|
||||
with sqlite3.connect(db_path, timeout=30) as conn:
|
||||
count = conn.execute(
|
||||
"SELECT COUNT(*) FROM latest_kickoff_task_outputs"
|
||||
).fetchone()[0]
|
||||
expected = N_WORKERS * N_RECORDS
|
||||
assert count == expected, f"Expected {expected} rows, got {count}"
|
||||
|
||||
|
||||
class TestConcurrentSQLiteFlow:
|
||||
"""Concurrent multi-process writes to flow persistence SQLite."""
|
||||
|
||||
def test_concurrent_writes_no_error(self, tmp_path):
|
||||
db_path = str(tmp_path / "flow_states.db")
|
||||
|
||||
successes, errors = _run_workers(
|
||||
_sqlite_flow_worker,
|
||||
lambda wid, rd: (db_path, wid, N_RECORDS, rd),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
assert not errors, f"Workers failed: {errors}"
|
||||
assert successes == N_WORKERS
|
||||
|
||||
with sqlite3.connect(db_path, timeout=30) as conn:
|
||||
count = conn.execute("SELECT COUNT(*) FROM flow_states").fetchone()[0]
|
||||
expected = N_WORKERS * N_RECORDS
|
||||
assert count == expected, f"Expected {expected} rows, got {count}"
|
||||
|
||||
|
||||
class TestConcurrentChromaDB:
|
||||
"""Concurrent multi-process ChromaDB client creation."""
|
||||
|
||||
def test_concurrent_client_creation_no_lock_exception(self, tmp_path):
|
||||
persist_dir = str(tmp_path / "chromadb_concurrent")
|
||||
os.makedirs(persist_dir, exist_ok=True)
|
||||
|
||||
successes, errors = _run_workers(
|
||||
_chromadb_worker,
|
||||
lambda wid, rd: (persist_dir, wid, rd),
|
||||
)
|
||||
|
||||
assert not errors, f"Workers failed: {errors}"
|
||||
assert successes == N_WORKERS
|
||||
@@ -971,6 +971,128 @@ class TestCollapseToOutcomeJsonParsing:
|
||||
assert mock_llm.call.call_count == 2
|
||||
|
||||
|
||||
class TestLLMObjectPreservedInContext:
|
||||
"""Tests that BaseLLM objects have their model string preserved in PendingFeedbackContext."""
|
||||
|
||||
@patch("crewai.flow.flow.crewai_event_bus.emit")
|
||||
def test_basellm_object_model_string_survives_roundtrip(self, mock_emit: MagicMock) -> None:
|
||||
"""Test that when llm is a BaseLLM object, its model string is stored in context
|
||||
so that outcome collapsing works after async pause/resume.
|
||||
|
||||
This is the exact bug: locally the sync path keeps the LLM object in memory,
|
||||
but in production the async path serializes the context and the LLM object was
|
||||
discarded (stored as None), causing resume to skip classification and always
|
||||
fall back to emit[0].
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
# Create a mock BaseLLM object (not a string)
|
||||
mock_llm_obj = MagicMock()
|
||||
mock_llm_obj.model = "gemini/gemini-2.0-flash"
|
||||
|
||||
class PausingProvider:
|
||||
def __init__(self, persistence: SQLiteFlowPersistence):
|
||||
self.persistence = persistence
|
||||
self.captured_context: PendingFeedbackContext | None = None
|
||||
|
||||
def request_feedback(
|
||||
self, context: PendingFeedbackContext, flow: Flow
|
||||
) -> str:
|
||||
self.captured_context = context
|
||||
self.persistence.save_pending_feedback(
|
||||
flow_uuid=context.flow_id,
|
||||
context=context,
|
||||
state_data=flow.state if isinstance(flow.state, dict) else flow.state.model_dump(),
|
||||
)
|
||||
raise HumanFeedbackPending(context=context)
|
||||
|
||||
provider = PausingProvider(persistence)
|
||||
|
||||
class TestFlow(Flow):
|
||||
result_path: str = ""
|
||||
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Approve?",
|
||||
emit=["needs_changes", "approved"],
|
||||
llm=mock_llm_obj,
|
||||
default_outcome="approved",
|
||||
provider=provider,
|
||||
)
|
||||
def review(self):
|
||||
return "content for review"
|
||||
|
||||
@listen("approved")
|
||||
def handle_approved(self):
|
||||
self.result_path = "approved"
|
||||
return "Approved!"
|
||||
|
||||
@listen("needs_changes")
|
||||
def handle_changes(self):
|
||||
self.result_path = "needs_changes"
|
||||
return "Changes needed"
|
||||
|
||||
# Phase 1: Start flow (should pause)
|
||||
flow1 = TestFlow(persistence=persistence)
|
||||
result = flow1.kickoff()
|
||||
assert isinstance(result, HumanFeedbackPending)
|
||||
|
||||
# Verify the context stored the model STRING, not None
|
||||
assert provider.captured_context is not None
|
||||
assert provider.captured_context.llm == "gemini/gemini-2.0-flash"
|
||||
|
||||
# Verify it survives persistence roundtrip
|
||||
flow_id = result.context.flow_id
|
||||
loaded = persistence.load_pending_feedback(flow_id)
|
||||
assert loaded is not None
|
||||
_, loaded_context = loaded
|
||||
assert loaded_context.llm == "gemini/gemini-2.0-flash"
|
||||
|
||||
# Phase 2: Resume with positive feedback - should use LLM to classify
|
||||
flow2 = TestFlow.from_pending(flow_id, persistence)
|
||||
assert flow2._pending_feedback_context is not None
|
||||
assert flow2._pending_feedback_context.llm == "gemini/gemini-2.0-flash"
|
||||
|
||||
# Mock _collapse_to_outcome to verify it gets called (not skipped)
|
||||
with patch.object(flow2, "_collapse_to_outcome", return_value="approved") as mock_collapse:
|
||||
flow2.resume("this looks good, proceed!")
|
||||
|
||||
# The key assertion: _collapse_to_outcome was called (not skipped due to llm=None)
|
||||
mock_collapse.assert_called_once_with(
|
||||
feedback="this looks good, proceed!",
|
||||
outcomes=["needs_changes", "approved"],
|
||||
llm="gemini/gemini-2.0-flash",
|
||||
)
|
||||
assert flow2.last_human_feedback.outcome == "approved"
|
||||
assert flow2.result_path == "approved"
|
||||
|
||||
def test_string_llm_still_works(self) -> None:
|
||||
"""Test that passing llm as a string still works correctly."""
|
||||
context = PendingFeedbackContext(
|
||||
flow_id="str-llm-test",
|
||||
flow_class="test.Flow",
|
||||
method_name="review",
|
||||
method_output="output",
|
||||
message="Review:",
|
||||
emit=["approved", "rejected"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
|
||||
serialized = context.to_dict()
|
||||
restored = PendingFeedbackContext.from_dict(serialized)
|
||||
assert restored.llm == "gpt-4o-mini"
|
||||
|
||||
def test_none_llm_when_no_model_attr(self) -> None:
|
||||
"""Test that llm is None when object has no model attribute."""
|
||||
mock_obj = MagicMock(spec=[]) # No attributes
|
||||
|
||||
# Simulate what the decorator does
|
||||
llm_value = mock_obj if isinstance(mock_obj, str) else getattr(mock_obj, "model", None)
|
||||
assert llm_value is None
|
||||
|
||||
|
||||
class TestAsyncHumanFeedbackEdgeCases:
|
||||
"""Edge case tests for async human feedback."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user