Compare commits

...

2 Commits

Author SHA1 Message Date
Sampson
d9f6e2222f Introduce more Brave Search tools (#4446)
Some checks are pending
CodeQL Advanced / Analyze (actions) (push) Waiting to run
CodeQL Advanced / Analyze (python) (push) Waiting to run
Check Documentation Broken Links / Check broken links (push) Waiting to run
* feat: add dedicated Brave Search tools for web, news, image, video, local POIs, and Brave's newest LLM Context endpoint

* fix: normalize transformed response shape

* revert legacy tool name

* fix: schema change prevented property resolution

* Update tool.specs.json

* fix: add fallback for search_langugage

* simplify exports

* makes rate-limiting logic per-instance

* fix(brave-tools): correct _refine_response return type annotations

The abstract method and subclasses annotated _refine_response as returning
dict[str, Any] but most implementations actually return list[dict[str, Any]].
Updated base to return Any, and each subclass to match its actual return type.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Joao Moura <joaomdmoura@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-10 01:38:54 -03:00
Lucas Gomide
adef605410 fix: add missing list/dict methods to LockedListProxy and LockedDictProxy
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
2026-03-09 09:38:35 -04:00
19 changed files with 4374 additions and 172 deletions

View File

@@ -1,97 +1,316 @@
---
title: Brave Search
description: The `BraveSearchTool` is designed to search the internet using the Brave Search API.
title: Brave Search Tools
description: A suite of tools for querying the Brave Search API — covering web, news, image, and video search.
icon: searchengin
mode: "wide"
---
# `BraveSearchTool`
# Brave Search Tools
## Description
This tool is designed to perform web searches using the Brave Search API. It allows you to search the internet with a specified query and retrieve relevant results. The tool supports customizable result counts and country-specific searches.
CrewAI offers a family of Brave Search tools, each targeting a specific [Brave Search API](https://brave.com/search/api/) endpoint.
Rather than a single catch-all tool, you can pick exactly the tool that matches the kind of results your agent needs:
| Tool | Endpoint | Use case |
| --- | --- | --- |
| `BraveWebSearchTool` | Web Search | General web results, snippets, and URLs |
| `BraveNewsSearchTool` | News Search | Recent news articles and headlines |
| `BraveImageSearchTool` | Image Search | Image results with dimensions and source URLs |
| `BraveVideoSearchTool` | Video Search | Video results from across the web |
| `BraveLocalPOIsTool` | Local POIs | Find points of interest (e.g., restaurants) |
| `BraveLocalPOIsDescriptionTool` | Local POIs | Retrieve AI-generated location descriptions |
| `BraveLLMContextTool` | LLM Context | Pre-extracted web content optimized for AI agents, LLM grounding, and RAG pipelines. |
All tools share a common base class (`BraveSearchToolBase`) that provides consistent behavior — rate limiting, automatic retries on `429` responses, header and parameter validation, and optional file saving.
<Note>
The older `BraveSearchTool` class is still available for backwards compatibility, but it is considered **legacy** and will not receive the same level of attention going forward. We recommend migrating to the specific tools listed above, which offer richer configuration and a more focused interface.
</Note>
<Note>
While many tools (e.g., _BraveWebSearchTool_, _BraveNewsSearchTool_, _BraveImageSearchTool_, and _BraveVideoSearchTool_) can be used with a free Brave Search API subscription/plan, some parameters (e.g., `enable_snippets`) and tools (e.g., _BraveLocalPOIsTool_ and _BraveLocalPOIsDescriptionTool_) require a paid plan. Consult your subscription plan's capabilities for clarification.
</Note>
## Installation
To incorporate this tool into your project, follow the installation instructions below:
```shell
pip install 'crewai[tools]'
```
## Steps to Get Started
## Getting Started
To effectively use the `BraveSearchTool`, follow these steps:
1. **Install the package** — confirm that `crewai[tools]` is installed in your Python environment.
2. **Get an API key** — sign up at [api-dashboard.search.brave.com/login](https://api-dashboard.search.brave.com/login) to generate a key.
3. **Set the environment variable** — store your key as `BRAVE_API_KEY`, or pass it directly via the `api_key` parameter.
1. **Package Installation**: Confirm that the `crewai[tools]` package is installed in your Python environment.
2. **API Key Acquisition**: Acquire a Brave Search API key at https://api.search.brave.com/app/keys (sign in to generate a key).
3. **Environment Configuration**: Store your obtained API key in an environment variable named `BRAVE_API_KEY` to facilitate its use by the tool.
## Quick Examples
## Example
The following example demonstrates how to initialize the tool and execute a search with a given query:
### Web Search
```python Code
from crewai_tools import BraveSearchTool
from crewai_tools import BraveWebSearchTool
# Initialize the tool for internet searching capabilities
tool = BraveSearchTool()
# Execute a search
results = tool.run(search_query="CrewAI agent framework")
tool = BraveWebSearchTool()
results = tool.run(q="CrewAI agent framework")
print(results)
```
## Parameters
The `BraveSearchTool` accepts the following parameters:
- **search_query**: Mandatory. The search query you want to use to search the internet.
- **country**: Optional. Specify the country for the search results. Default is empty string.
- **n_results**: Optional. Number of search results to return. Default is `10`.
- **save_file**: Optional. Whether to save the search results to a file. Default is `False`.
## Example with Parameters
Here is an example demonstrating how to use the tool with additional parameters:
### News Search
```python Code
from crewai_tools import BraveSearchTool
from crewai_tools import BraveNewsSearchTool
# Initialize the tool with custom parameters
tool = BraveSearchTool(
country="US",
n_results=5,
save_file=True
tool = BraveNewsSearchTool()
results = tool.run(q="latest AI breakthroughs")
print(results)
```
### Image Search
```python Code
from crewai_tools import BraveImageSearchTool
tool = BraveImageSearchTool()
results = tool.run(q="northern lights photography")
print(results)
```
### Video Search
```python Code
from crewai_tools import BraveVideoSearchTool
tool = BraveVideoSearchTool()
results = tool.run(q="how to build AI agents")
print(results)
```
### Location POI Descriptions
```python Code
from crewai_tools import (
BraveWebSearchTool,
BraveLocalPOIsDescriptionTool,
)
# Execute a search
results = tool.run(search_query="Latest AI developments")
print(results)
web_search = BraveWebSearchTool(raw=True)
poi_details = BraveLocalPOIsDescriptionTool()
results = web_search.run(q="italian restaurants in pensacola, florida")
if "locations" in results:
location_ids = [ loc["id"] for loc in results["locations"]["results"] ]
if location_ids:
descriptions = poi_details.run(ids=location_ids)
print(descriptions)
```
## Common Constructor Parameters
Every Brave Search tool accepts the following parameters at initialization:
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `api_key` | `str \| None` | `None` | Brave API key. Falls back to the `BRAVE_API_KEY` environment variable. |
| `headers` | `dict \| None` | `None` | Additional HTTP headers to send with every request (e.g., `api-version`, geolocation headers). |
| `requests_per_second` | `float` | `1.0` | Maximum request rate. The tool will sleep between calls to stay within this limit. |
| `save_file` | `bool` | `False` | When `True`, each response is written to a timestamped `.txt` file. |
| `raw` | `bool` | `False` | When `True`, the full API JSON response is returned without any refinement. |
| `timeout` | `int` | `30` | HTTP request timeout in seconds. |
| `country` | `str \| None` | `None` | Legacy shorthand for geo-targeting (e.g., `"US"`). Prefer using the `country` query parameter directly. |
| `n_results` | `int` | `10` | Legacy shorthand for result count. Prefer using the `count` query parameter directly. |
<Warning>
The `country` and `n_results` constructor parameters exist for backwards compatibility. They are applied as defaults when the corresponding query parameters (`country`, `count`) are not provided at call time. For new code, we recommend passing `country` and `count` directly as query parameters instead.
</Warning>
## Query Parameters
Each tool validates its query parameters against a Pydantic schema before sending the request.
The parameters vary slightly per endpoint — here is a summary of the most commonly used ones:
### BraveWebSearchTool
| Parameter | Description |
| --- | --- |
| `q` | **(required)** Search query string (max 400 chars). |
| `country` | Two-letter country code for geo-targeting (e.g., `"US"`). |
| `search_lang` | Two-letter language code for results (e.g., `"en"`). |
| `count` | Max number of results to return (120). |
| `offset` | Skip the first N pages of results (09). |
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
| `freshness` | Recency filter: `"pd"` (past day), `"pw"` (past week), `"pm"` (past month), `"py"` (past year), or a date range like `"2025-01-01to2025-06-01"`. |
| `extra_snippets` | Include up to 5 additional text snippets per result. |
| `goggles` | Brave Goggles URL(s) and/or source for custom re-ranking. |
For the complete parameter and header reference, see the [Brave Web Search API documentation](https://api-dashboard.search.brave.com/api-reference/web/search/get).
### BraveNewsSearchTool
| Parameter | Description |
| --- | --- |
| `q` | **(required)** Search query string (max 400 chars). |
| `country` | Two-letter country code for geo-targeting. |
| `search_lang` | Two-letter language code for results. |
| `count` | Max number of results to return (150). |
| `offset` | Skip the first N pages of results (09). |
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
| `freshness` | Recency filter (same options as Web Search). |
| `goggles` | Brave Goggles URL(s) and/or source for custom re-ranking. |
For the complete parameter and header reference, see the [Brave News Search API documentation](https://api-dashboard.search.brave.com/api-reference/news/news_search/get).
### BraveImageSearchTool
| Parameter | Description |
| --- | --- |
| `q` | **(required)** Search query string (max 400 chars). |
| `country` | Two-letter country code for geo-targeting. |
| `search_lang` | Two-letter language code for results. |
| `count` | Max number of results to return (1200). |
| `safesearch` | Content filter: `"off"` or `"strict"`. |
| `spellcheck` | Attempt to correct spelling errors in the query. |
For the complete parameter and header reference, see the [Brave Image Search API documentation](https://api-dashboard.search.brave.com/api-reference/images/image_search).
### BraveVideoSearchTool
| Parameter | Description |
| --- | --- |
| `q` | **(required)** Search query string (max 400 chars). |
| `country` | Two-letter country code for geo-targeting. |
| `search_lang` | Two-letter language code for results. |
| `count` | Max number of results to return (150). |
| `offset` | Skip the first N pages of results (09). |
| `safesearch` | Content filter: `"off"`, `"moderate"`, or `"strict"`. |
| `freshness` | Recency filter (same options as Web Search). |
For the complete parameter and header reference, see the [Brave Video Search API documentation](https://api-dashboard.search.brave.com/api-reference/videos/video_search/get).
### BraveLocalPOIsTool
| Parameter | Description |
| --- | --- |
| `ids` | **(required)** A list of unique identifiers for the desired locations. |
| `search_lang` | Two-letter language code for results. |
For the complete parameter and header reference, see [Brave Local POIs API documentation](https://api-dashboard.search.brave.com/api-reference/web/local_pois).
### BraveLocalPOIsDescriptionTool
| Parameter | Description |
| --- | --- |
| `ids` | **(required)** A list of unique identifiers for the desired locations. |
For the complete parameter and header reference, see [Brave POI Descriptions API documentation](https://api-dashboard.search.brave.com/api-reference/web/poi_descriptions).
## Custom Headers
All tools support custom HTTP request headers. The Web Search tool, for example, accepts geolocation headers for location-aware results:
```python Code
from crewai_tools import BraveWebSearchTool
tool = BraveWebSearchTool(
headers={
"x-loc-lat": "37.7749",
"x-loc-long": "-122.4194",
"x-loc-city": "San Francisco",
"x-loc-state": "CA",
"x-loc-country": "US",
}
)
results = tool.run(q="best coffee shops nearby")
```
You can also update headers after initialization using the `set_headers()` method:
```python Code
tool.set_headers({"api-version": "2025-01-01"})
```
## Raw Mode
By default, each tool refines the API response into a concise list of results. If you need the full, unprocessed API response, enable raw mode:
```python Code
from crewai_tools import BraveWebSearchTool
tool = BraveWebSearchTool(raw=True)
full_response = tool.run(q="Brave Search API")
```
## Agent Integration Example
Here's how to integrate the `BraveSearchTool` with a CrewAI agent:
Here's how to equip a CrewAI agent with multiple Brave Search tools:
```python Code
from crewai import Agent
from crewai.project import agent
from crewai_tools import BraveSearchTool
from crewai_tools import BraveWebSearchTool, BraveNewsSearchTool
# Initialize the tool
brave_search_tool = BraveSearchTool()
web_search = BraveWebSearchTool()
news_search = BraveNewsSearchTool()
# Define an agent with the BraveSearchTool
@agent
def researcher(self) -> Agent:
return Agent(
config=self.agents_config["researcher"],
allow_delegation=False,
tools=[brave_search_tool]
tools=[web_search, news_search],
)
```
## Advanced Example
Combining multiple parameters for a targeted search:
```python Code
from crewai_tools import BraveWebSearchTool
tool = BraveWebSearchTool(
requests_per_second=0.5, # conservative rate limit
save_file=True,
)
results = tool.run(
q="artificial intelligence news",
country="US",
search_lang="en",
count=5,
freshness="pm", # past month only
extra_snippets=True,
)
print(results)
```
## Migrating from `BraveSearchTool` (Legacy)
If you are currently using `BraveSearchTool`, switching to the new tools is straightforward:
```python Code
# Before (legacy)
from crewai_tools import BraveSearchTool
tool = BraveSearchTool(country="US", n_results=5, save_file=True)
results = tool.run(search_query="AI agents")
# After (recommended)
from crewai_tools import BraveWebSearchTool
tool = BraveWebSearchTool(save_file=True)
results = tool.run(q="AI agents", country="US", count=5)
```
Key differences:
- **Import**: Use `BraveWebSearchTool` (or the news/image/video variant) instead of `BraveSearchTool`.
- **Query parameter**: Use `q` instead of `search_query`. (Both `search_query` and `query` are still accepted for convenience, but `q` is the preferred parameter.)
- **Result count**: Pass `count` as a query parameter instead of `n_results` at init time.
- **Country**: Pass `country` as a query parameter instead of at init time.
- **API key**: Can now be passed directly via `api_key=` in addition to the `BRAVE_API_KEY` environment variable.
- **Rate limiting**: Configurable via `requests_per_second` with automatic retry on `429` responses.
## Conclusion
By integrating the `BraveSearchTool` into Python projects, users gain the ability to conduct real-time, relevant searches across the internet directly from their applications. The tool provides a simple interface to the powerful Brave Search API, making it easy to retrieve and process search results programmatically. By adhering to the setup and usage guidelines provided, incorporating this tool into projects is streamlined and straightforward.
The Brave Search tool suite gives your CrewAI agents flexible, endpoint-specific access to the Brave Search API. Whether you need web pages, breaking news, images, or videos, there is a dedicated tool with validated parameters and built-in resilience. Pick the tool that fits your use case, and refer to the [Brave Search API documentation](https://brave.com/search/api/) for the full details on available parameters and response formats.

View File

@@ -10,7 +10,18 @@ from crewai_tools.aws.s3.writer_tool import S3WriterTool
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
BraveLLMContextTool,
)
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
BraveLocalPOIsDescriptionTool,
BraveLocalPOIsTool,
)
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
from crewai_tools.tools.brightdata_tool.brightdata_dataset import (
BrightDataDatasetTool,
)
@@ -200,7 +211,14 @@ __all__ = [
"ArxivPaperTool",
"BedrockInvokeAgentTool",
"BedrockKBRetrieverTool",
"BraveImageSearchTool",
"BraveLLMContextTool",
"BraveLocalPOIsDescriptionTool",
"BraveLocalPOIsTool",
"BraveNewsSearchTool",
"BraveSearchTool",
"BraveVideoSearchTool",
"BraveWebSearchTool",
"BrightDataDatasetTool",
"BrightDataSearchTool",
"BrightDataWebUnlockerTool",

View File

@@ -1,7 +1,18 @@
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
BraveLLMContextTool,
)
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
BraveLocalPOIsDescriptionTool,
BraveLocalPOIsTool,
)
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
from crewai_tools.tools.brightdata_tool import (
BrightDataDatasetTool,
BrightDataSearchTool,
@@ -185,7 +196,14 @@ __all__ = [
"AIMindTool",
"ApifyActorsTool",
"ArxivPaperTool",
"BraveImageSearchTool",
"BraveLLMContextTool",
"BraveLocalPOIsDescriptionTool",
"BraveLocalPOIsTool",
"BraveNewsSearchTool",
"BraveSearchTool",
"BraveVideoSearchTool",
"BraveWebSearchTool",
"BrightDataDatasetTool",
"BrightDataSearchTool",
"BrightDataWebUnlockerTool",

View File

@@ -0,0 +1,322 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
import json
import logging
import os
import threading
import time
from typing import Any, ClassVar
from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field
import requests
logger = logging.getLogger(__name__)
# Brave API error codes that indicate non-retryable quota/usage exhaustion.
_QUOTA_CODES = frozenset({"QUOTA_LIMITED", "USAGE_LIMIT_EXCEEDED"})
def _save_results_to_file(content: str) -> None:
"""Saves the search results to a file."""
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
with open(filename, "w") as file:
file.write(content)
def _parse_error_body(resp: requests.Response) -> dict[str, Any] | None:
"""Extract the structured "error" object from a Brave API error response."""
try:
body = resp.json()
error = body.get("error")
return error if isinstance(error, dict) else None
except (ValueError, KeyError):
return None
def _raise_for_error(resp: requests.Response) -> None:
"""Brave Search API error responses contain helpful JSON payloads"""
status = resp.status_code
try:
body = json.dumps(resp.json())
except (ValueError, KeyError):
body = resp.text[:500]
raise RuntimeError(f"Brave Search API error (HTTP {status}): {body}")
def _is_retryable(resp: requests.Response) -> bool:
"""Return True for transient failures that are worth retrying.
* 429 + RATE_LIMITED — the per-second sliding window is full.
* 5xx — transient server-side errors.
Quota exhaustion (QUOTA_LIMITED, USAGE_LIMIT_EXCEEDED) is
explicitly excluded: retrying will never succeed until the billing
period resets.
"""
if resp.status_code == 429:
error = _parse_error_body(resp) or {}
return error.get("code") not in _QUOTA_CODES
return 500 <= resp.status_code < 600
def _retry_delay(resp: requests.Response, attempt: int) -> float:
"""Compute wait time before the next retry attempt.
Prefers the server-supplied Retry-After header when available;
falls back to exponential backoff (1s, 2s, 4s, ...).
"""
retry_after = resp.headers.get("Retry-After")
if retry_after is not None:
try:
return max(0.0, float(retry_after))
except (ValueError, TypeError):
pass
return float(2**attempt)
class BraveSearchToolBase(BaseTool, ABC):
"""
Base class for Brave Search API interactions.
Individual tool subclasses must provide the following:
- search_url
- header_schema (pydantic model)
- args_schema (pydantic model)
- _refine_payload() -> dict[str, Any]
"""
search_url: str
raw: bool = False
args_schema: type[BaseModel]
header_schema: type[BaseModel]
# Tool options (legacy parameters)
country: str | None = None
save_file: bool = False
n_results: int = 10
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
name="BRAVE_API_KEY",
description="API key for Brave Search",
required=True,
),
]
)
def __init__(
self,
*,
api_key: str | None = None,
headers: dict[str, Any] | None = None,
requests_per_second: float = 1.0,
save_file: bool = False,
raw: bool = False,
timeout: int = 30,
**kwargs: Any,
):
super().__init__(**kwargs)
self._api_key = api_key or os.environ.get("BRAVE_API_KEY")
if not self._api_key:
raise ValueError("BRAVE_API_KEY environment variable is required")
self.raw = bool(raw)
self._timeout = int(timeout)
self.save_file = bool(save_file)
self._requests_per_second = float(requests_per_second)
self._headers = self._build_and_validate_headers(headers or {})
# Per-instance rate limiting: each instance has its own clock and lock.
# Total process rate is the sum of limits of instances you create.
self._last_request_time: float = 0
self._rate_limit_lock = threading.Lock()
@property
def api_key(self) -> str:
return self._api_key
@property
def headers(self) -> dict[str, Any]:
return self._headers
def set_headers(self, headers: dict[str, Any]) -> BraveSearchToolBase:
merged = {**self._headers, **{k.lower(): v for k, v in headers.items()}}
self._headers = self._build_and_validate_headers(merged)
return self
def _build_and_validate_headers(self, headers: dict[str, Any]) -> dict[str, Any]:
normalized = {k.lower(): v for k, v in headers.items()}
normalized.setdefault("x-subscription-token", self._api_key)
normalized.setdefault("accept", "application/json")
try:
self.header_schema(**normalized)
except Exception as e:
raise ValueError(f"Invalid headers: {e}") from e
return normalized
def _rate_limit(self) -> None:
"""Enforce minimum interval between requests for this instance. Thread-safe."""
if self._requests_per_second <= 0:
return
min_interval = 1.0 / self._requests_per_second
with self._rate_limit_lock:
now = time.time()
next_allowed = self._last_request_time + min_interval
if now < next_allowed:
time.sleep(next_allowed - now)
now = time.time()
self._last_request_time = now
def _make_request(
self, params: dict[str, Any], *, _max_retries: int = 3
) -> dict[str, Any]:
"""Execute an HTTP GET against the Brave Search API with retry logic."""
last_resp: requests.Response | None = None
# Retry the request up to _max_retries times
for attempt in range(_max_retries):
self._rate_limit()
# Make the request
try:
resp = requests.get(
self.search_url,
headers=self._headers,
params=params,
timeout=self._timeout,
)
except requests.ConnectionError as exc:
raise RuntimeError(
f"Brave Search API connection failed: {exc}"
) from exc
except requests.Timeout as exc:
raise RuntimeError(
f"Brave Search API request timed out after {self._timeout}s: {exc}"
) from exc
# Log the rate limit headers and request details
logger.debug(
"Brave Search API request: %s %s -> %d",
"GET",
resp.url,
resp.status_code,
)
# Response was OK, return the JSON body
if resp.ok:
try:
return resp.json()
except ValueError as exc:
raise RuntimeError(
f"Brave Search API returned invalid JSON (HTTP {resp.status_code}): {exc}"
) from exc
# Response was not OK, but is retryable
# (e.g., 429 Too Many Requests, 500 Internal Server Error)
if _is_retryable(resp) and attempt < _max_retries - 1:
delay = _retry_delay(resp, attempt)
logger.warning(
"Brave Search API returned %d. Retrying in %.1fs (attempt %d/%d)",
resp.status_code,
delay,
attempt + 1,
_max_retries,
)
time.sleep(delay)
last_resp = resp
continue
# Response was not OK, nor was it retryable
# (e.g., 422 Unprocessable Entity, 400 Bad Request (OPTION_NOT_IN_PLAN))
_raise_for_error(resp)
# All retries exhausted
_raise_for_error(last_resp or resp) # type: ignore[possibly-undefined]
return {} # unreachable (here to satisfy the type checker and linter)
def _run(self, q: str | None = None, **params: Any) -> Any:
# Allow positional usage: tool.run("latest Brave browser features")
if q is not None:
params["q"] = q
params = self._common_payload_refinement(params)
# Validate only schema fields
schema_keys = self.args_schema.model_fields
payload_in = {k: v for k, v in params.items() if k in schema_keys}
try:
validated = self.args_schema(**payload_in)
except Exception as e:
raise ValueError(f"Invalid parameters: {e}") from e
# The subclass may have additional refinements to apply to the payload, such as goggles or other parameters
payload = self._refine_request_payload(validated.model_dump(exclude_none=True))
response = self._make_request(payload)
if not self.raw:
response = self._refine_response(response)
if self.save_file:
_save_results_to_file(json.dumps(response, indent=2))
return response
@abstractmethod
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
"""Subclass must implement: transform validated params dict into API request params."""
raise NotImplementedError
@abstractmethod
def _refine_response(self, response: dict[str, Any]) -> Any:
"""Subclass must implement: transform response dict into a more useful format."""
raise NotImplementedError
_EMPTY_VALUES: ClassVar[tuple[None, str, str, list[Any]]] = (None, "", "null", [])
def _common_payload_refinement(self, params: dict[str, Any]) -> dict[str, Any]:
"""Common payload refinement for all tools."""
# crewAI's schema pipeline (ensure_all_properties_required in
# pydantic_schema_utils.py) marks every property as required so
# that OpenAI strict-mode structured outputs work correctly.
# The side-effect is that the LLM fills in *every* parameter —
# even truly optional ones — using placeholder values such as
# None, "", "null", or []. Only optional fields are affected,
# so we limit the check to those.
fields = self.args_schema.model_fields
params = {
k: v
for k, v in params.items()
# Permit custom and required fields, and fields with non-empty values
if k not in fields or fields[k].is_required() or v not in self._EMPTY_VALUES
}
# Make sure params has "q" for query instead of "query" or "search_query"
query = params.get("query") or params.get("search_query")
if query is not None and "q" not in params:
params["q"] = query
params.pop("query", None)
params.pop("search_query", None)
# If "count" was not explicitly provided, use n_results
# (only when the schema actually supports a "count" field)
if "count" in self.args_schema.model_fields:
if "count" not in params and self.n_results is not None:
params["count"] = self.n_results
# If "country" was not explicitly provided, but self.country is set, use it
# (only when the schema actually supports a "country" field)
if "country" in self.args_schema.model_fields:
if "country" not in params and self.country is not None:
params["country"] = self.country
return params

View File

@@ -0,0 +1,42 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.schemas import (
ImageSearchHeaders,
ImageSearchParams,
)
class BraveImageSearchTool(BraveSearchToolBase):
"""A tool that performs image searches using the Brave Search API."""
name: str = "Brave Image Search"
args_schema: type[BaseModel] = ImageSearchParams
header_schema: type[BaseModel] = ImageSearchHeaders
description: str = (
"A tool that performs image searches using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/images/search"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
return [
{
"title": result.get("title"),
"url": result.get("properties", {}).get("url"),
"dimensions": f"{w}x{h}"
if (w := result.get("properties", {}).get("width"))
and (h := result.get("properties", {}).get("height"))
else None,
}
for result in results
]

View File

@@ -0,0 +1,32 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.response_types import LLMContext
from crewai_tools.tools.brave_search_tool.schemas import (
LLMContextHeaders,
LLMContextParams,
)
class BraveLLMContextTool(BraveSearchToolBase):
"""A tool that retrieves context for LLM usage from the Brave Search API."""
name: str = "Brave LLM Context"
args_schema: type[BaseModel] = LLMContextParams
header_schema: type[BaseModel] = LLMContextHeaders
description: str = (
"A tool that retrieves context for LLM usage from the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/llm/context"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: LLMContext.Response) -> LLMContext.Response:
"""The LLM Context response schema is fairly simple. Return as is."""
return response

View File

@@ -0,0 +1,109 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.response_types import LocalPOIs
from crewai_tools.tools.brave_search_tool.schemas import (
LocalPOIsDescriptionHeaders,
LocalPOIsDescriptionParams,
LocalPOIsHeaders,
LocalPOIsParams,
)
DayOpeningHours = LocalPOIs.DayOpeningHours
OpeningHours = LocalPOIs.OpeningHours
LocationResult = LocalPOIs.LocationResult
LocalPOIsResponse = LocalPOIs.Response
def _flatten_slots(slots: list[DayOpeningHours]) -> list[dict[str, str]]:
"""Convert a list of DayOpeningHours dicts into simplified entries."""
return [
{
"day": slot["full_name"].lower(),
"opens": slot["opens"],
"closes": slot["closes"],
}
for slot in slots
]
def _simplify_opening_hours(result: LocationResult) -> list[dict[str, str]] | None:
"""Collapse opening_hours into a flat list of {day, opens, closes} dicts."""
hours = result.get("opening_hours")
if not hours:
return None
entries: list[dict[str, str]] = []
current = hours.get("current_day")
if current:
entries.extend(_flatten_slots(current))
days = hours.get("days")
if days:
for day_slots in days:
entries.extend(_flatten_slots(day_slots))
return entries or None
class BraveLocalPOIsTool(BraveSearchToolBase):
"""A tool that retrieves local POIs using the Brave Search API."""
name: str = "Brave Local POIs"
args_schema: type[BaseModel] = LocalPOIsParams
header_schema: type[BaseModel] = LocalPOIsHeaders
description: str = (
"A tool that retrieves local POIs using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/local/pois"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
results = response.get("results", [])
return [
{
"title": result.get("title"),
"url": result.get("url"),
"description": result.get("description"),
"address": result.get("postal_address", {}).get("displayAddress"),
"contact": result.get("contact", {}).get("telephone")
or result.get("contact", {}).get("email")
or None,
"opening_hours": _simplify_opening_hours(result),
}
for result in results
]
class BraveLocalPOIsDescriptionTool(BraveSearchToolBase):
"""A tool that retrieves AI-generated descriptions for local POIs using the Brave Search API."""
name: str = "Brave Local POI Descriptions"
args_schema: type[BaseModel] = LocalPOIsDescriptionParams
header_schema: type[BaseModel] = LocalPOIsDescriptionHeaders
description: str = (
"A tool that retrieves AI-generated descriptions for local POIs using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/local/descriptions"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
return [
{
"id": result.get("id"),
"description": result.get("description"),
}
for result in results
]

View File

@@ -0,0 +1,39 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.schemas import (
NewsSearchHeaders,
NewsSearchParams,
)
class BraveNewsSearchTool(BraveSearchToolBase):
"""A tool that performs news searches using the Brave Search API."""
name: str = "Brave News Search"
args_schema: type[BaseModel] = NewsSearchParams
header_schema: type[BaseModel] = NewsSearchHeaders
description: str = (
"A tool that performs news searches using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/news/search"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
return [
{
"url": result.get("url"),
"title": result.get("title"),
"description": result.get("description"),
}
for result in results
]

View File

@@ -10,17 +10,13 @@ from pydantic import BaseModel, Field
from pydantic.types import StringConstraints
import requests
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
load_dotenv()
def _save_results_to_file(content: str) -> None:
"""Saves the search results to a file."""
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
with open(filename, "w") as file:
file.write(content)
FreshnessPreset = Literal["pd", "pw", "pm", "py"]
FreshnessRange = Annotated[
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
@@ -29,51 +25,6 @@ Freshness = FreshnessPreset | FreshnessRange
SafeSearch = Literal["off", "moderate", "strict"]
class BraveSearchToolSchema(BaseModel):
"""Input for BraveSearchTool"""
query: str = Field(..., description="Search query to perform")
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
)
search_language: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return. Actual number may be less.",
)
offset: int | None = Field(
default=None, description="Skip the first N result sets/pages. Max is 9."
)
safesearch: SafeSearch | None = Field(
default=None,
description="Filter out explicit content. Options: off/moderate/strict",
)
spellcheck: bool | None = Field(
default=None,
description="Attempt to correct spelling errors in the search query.",
)
freshness: Freshness | None = Field(
default=None,
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
)
text_decorations: bool | None = Field(
default=None,
description="Include markup to highlight search terms in the results.",
)
extra_snippets: bool | None = Field(
default=None,
description="Include up to 5 text snippets for each page if possible.",
)
operators: bool | None = Field(
default=None,
description="Whether to apply search operators (e.g., site:example.com).",
)
# TODO: Extend support to additional endpoints (e.g., /images, /news, etc.)
class BraveSearchTool(BaseTool):
"""A tool that performs web searches using the Brave Search API."""
@@ -83,7 +34,7 @@ class BraveSearchTool(BaseTool):
"A tool that performs web searches using the Brave Search API. "
"Results are returned as structured JSON data."
)
args_schema: type[BaseModel] = BraveSearchToolSchema
args_schema: type[BaseModel] = WebSearchParams
search_url: str = "https://api.search.brave.com/res/v1/web/search"
n_results: int = 10
save_file: bool = False
@@ -120,8 +71,8 @@ class BraveSearchTool(BaseTool):
# Construct and send the request
try:
# Maintain both "search_query" and "query" for backwards compatibility
query = kwargs.get("search_query") or kwargs.get("query")
# Fallback to "query" or "search_query" for backwards compatibility
query = kwargs.get("q") or kwargs.get("query") or kwargs.get("search_query")
if not query:
raise ValueError("Query is required")
@@ -130,8 +81,11 @@ class BraveSearchTool(BaseTool):
if country := kwargs.get("country"):
payload["country"] = country
if search_language := kwargs.get("search_language"):
payload["search_language"] = search_language
# Fallback to "search_language" for backwards compatibility
if search_lang := kwargs.get("search_lang") or kwargs.get(
"search_language"
):
payload["search_lang"] = search_lang
# Fallback to deprecated n_results parameter if no count is provided
count = kwargs.get("count")

View File

@@ -0,0 +1,39 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.schemas import (
VideoSearchHeaders,
VideoSearchParams,
)
class BraveVideoSearchTool(BraveSearchToolBase):
"""A tool that performs video searches using the Brave Search API."""
name: str = "Brave Video Search"
args_schema: type[BaseModel] = VideoSearchParams
header_schema: type[BaseModel] = VideoSearchHeaders
description: str = (
"A tool that performs video searches using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/videos/search"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
return [
{
"url": result.get("url"),
"title": result.get("title"),
"description": result.get("description"),
}
for result in results
]

View File

@@ -0,0 +1,45 @@
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.schemas import (
WebSearchHeaders,
WebSearchParams,
)
class BraveWebSearchTool(BraveSearchToolBase):
"""A tool that performs web searches using the Brave Search API."""
name: str = "Brave Web Search"
args_schema: type[BaseModel] = WebSearchParams
header_schema: type[BaseModel] = WebSearchHeaders
description: str = (
"A tool that performs web searches using the Brave Search API. "
"Results are returned as structured JSON data."
)
search_url: str = "https://api.search.brave.com/res/v1/web/search"
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
results = response.get("web", {}).get("results", [])
refined = []
for result in results:
snippets = result.get("extra_snippets") or []
if not snippets:
desc = result.get("description")
if desc:
snippets = [desc]
refined.append(
{
"url": result.get("url"),
"title": result.get("title"),
"snippets": snippets,
}
)
return refined

View File

@@ -0,0 +1,67 @@
from __future__ import annotations
from typing import Literal, TypedDict
class LocalPOIs:
class PostalAddress(TypedDict, total=False):
type: Literal["PostalAddress"]
country: str
postalCode: str
streetAddress: str
addressRegion: str
addressLocality: str
displayAddress: str
class DayOpeningHours(TypedDict):
abbr_name: str
full_name: str
opens: str
closes: str
class OpeningHours(TypedDict, total=False):
current_day: list[LocalPOIs.DayOpeningHours]
days: list[list[LocalPOIs.DayOpeningHours]]
class LocationResult(TypedDict, total=False):
provider_url: str
title: str
url: str
id: str | None
opening_hours: LocalPOIs.OpeningHours | None
postal_address: LocalPOIs.PostalAddress | None
class Response(TypedDict, total=False):
type: Literal["local_pois"]
results: list[LocalPOIs.LocationResult]
class LLMContext:
class LLMContextItem(TypedDict, total=False):
snippets: list[str]
title: str
url: str
class LLMContextMapItem(TypedDict, total=False):
name: str
snippets: list[str]
title: str
url: str
class LLMContextPOIItem(TypedDict, total=False):
name: str
snippets: list[str]
title: str
url: str
class Grounding(TypedDict, total=False):
generic: list[LLMContext.LLMContextItem]
poi: LLMContext.LLMContextPOIItem
map: list[LLMContext.LLMContextMapItem]
class Sources(TypedDict, total=False):
pass
class Response(TypedDict, total=False):
grounding: LLMContext.Grounding
sources: LLMContext.Sources

View File

@@ -0,0 +1,525 @@
from typing import Annotated, Literal
from pydantic import BaseModel, Field
from pydantic.types import StringConstraints
# Common types
Units = Literal["metric", "imperial"]
SafeSearch = Literal["off", "moderate", "strict"]
Freshness = (
Literal["pd", "pw", "pm", "py"]
| Annotated[
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
]
)
ResultFilter = list[
Literal[
"discussions",
"faq",
"infobox",
"news",
"query",
"summarizer",
"videos",
"web",
"locations",
]
]
class LLMContextParams(BaseModel):
"""Parameters for Brave LLM Context endpoint."""
q: str = Field(
description="Search query to perform",
min_length=1,
max_length=400,
)
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
pattern=r"^[A-Z]{2}$",
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return. Actual number may be less.",
ge=1,
le=50,
)
maximum_number_of_urls: int | None = Field(
default=None,
description="The maximum number of URLs to include in the context.",
ge=1,
le=50,
)
maximum_number_of_tokens: int | None = Field(
default=None,
description="The approximate maximum number of tokens to include in the context.",
ge=1,
le=32768,
)
maximum_number_of_snippets: int | None = Field(
default=None,
description="The maximum number of different snippets to include in the context.",
ge=1,
le=100,
)
context_threshold_mode: (
Literal["disabled", "strict", "lenient", "balanced"] | None
) = Field(
default=None,
description="The mode to use for the context thresholding.",
)
maximum_number_of_tokens_per_url: int | None = Field(
default=None,
description="The maximum number of tokens to include for each URL in the context.",
ge=1,
le=8192,
)
maximum_number_of_snippets_per_url: int | None = Field(
default=None,
description="The maximum number of snippets to include per URL.",
ge=1,
le=100,
)
goggles: str | list[str] | None = Field(
default=None,
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
)
enable_local: bool | None = Field(
default=None,
description="Whether to enable local recall. Not setting this value means auto-detect and uses local recall if any of the localization headers are provided.",
)
class WebSearchParams(BaseModel):
"""Parameters for Brave Web Search endpoint."""
q: str = Field(
description="Search query to perform",
min_length=1,
max_length=400,
)
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
pattern=r"^[A-Z]{2}$",
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
ui_lang: str | None = Field(
default=None,
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
pattern=r"^[a-z]{2}-[A-Z]{2}$",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return. Actual number may be less.",
ge=1,
le=20,
)
offset: int | None = Field(
default=None,
description="Skip the first N result sets/pages. Max is 9.",
ge=0,
le=9,
)
safesearch: Literal["off", "moderate", "strict"] | None = Field(
default=None,
description="Filter out explicit content. Options: off/moderate/strict",
)
spellcheck: bool | None = Field(
default=None,
description="Attempt to correct spelling errors in the search query.",
)
freshness: Freshness | None = Field(
default=None,
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
)
text_decorations: bool | None = Field(
default=None,
description="Include markup to highlight search terms in the results.",
)
extra_snippets: bool | None = Field(
default=None,
description="Include up to 5 text snippets for each page if possible.",
)
result_filter: ResultFilter | None = Field(
default=None,
description="Filter the results by type. Options: discussions/faq/infobox/news/query/summarizer/videos/web/locations. Note: The `count` parameter is applied only to the `web` results.",
)
units: Units | None = Field(
default=None,
description="The units to use for the results. Options: metric/imperial",
)
goggles: str | list[str] | None = Field(
default=None,
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
)
summary: bool | None = Field(
default=None,
description="Whether to generate a summarizer ID for the results.",
)
enable_rich_callback: bool | None = Field(
default=None,
description="Whether to enable rich callbacks for the results. Requires Pro level subscription.",
)
include_fetch_metadata: bool | None = Field(
default=None,
description="Whether to include fetch metadata (e.g., last fetch time) in the results.",
)
operators: bool | None = Field(
default=None,
description="Whether to apply search operators (e.g., site:example.com).",
)
class LocalPOIsParams(BaseModel):
"""Parameters for Brave Local POIs endpoint."""
ids: list[str] = Field(
description="List of POI IDs to retrieve. Maximum of 20. IDs are valid for 8 hours.",
min_length=1,
max_length=20,
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
ui_lang: str | None = Field(
default=None,
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
pattern=r"^[a-z]{2}-[A-Z]{2}$",
)
units: Units | None = Field(
default=None,
description="The units to use for the results. Options: metric/imperial",
)
class LocalPOIsDescriptionParams(BaseModel):
"""Parameters for Brave Local POI Descriptions endpoint."""
ids: list[str] = Field(
description="List of POI IDs to retrieve. Maximum of 20. IDs are valid for 8 hours.",
min_length=1,
max_length=20,
)
class ImageSearchParams(BaseModel):
"""Parameters for Brave Image Search endpoint."""
q: str = Field(
description="Search query to perform",
min_length=1,
max_length=400,
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
pattern=r"^[A-Z]{2}$",
)
safesearch: Literal["off", "strict"] | None = Field(
default=None,
description="Filter out explicit content. Default is strict.",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return.",
ge=1,
le=200,
)
spellcheck: bool | None = Field(
default=None,
description="Attempt to correct spelling errors in the search query.",
)
class VideoSearchParams(BaseModel):
"""Parameters for Brave Video Search endpoint."""
q: str = Field(
description="Search query to perform",
min_length=1,
max_length=400,
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
ui_lang: str | None = Field(
default=None,
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
pattern=r"^[a-z]{2}-[A-Z]{2}$",
)
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
pattern=r"^[A-Z]{2}$",
)
safesearch: SafeSearch | None = Field(
default=None,
description="Filter out explicit content. Options: off/moderate/strict",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return.",
ge=1,
le=50,
)
offset: int | None = Field(
default=None,
description="Skip the first N result sets/pages. Max is 9.",
ge=0,
le=9,
)
spellcheck: bool | None = Field(
default=None,
description="Attempt to correct spelling errors in the search query.",
)
freshness: Freshness | None = Field(
default=None,
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
)
include_fetch_metadata: bool | None = Field(
default=None,
description="Whether to include fetch metadata (e.g., last fetch time) in the results.",
)
operators: bool | None = Field(
default=None,
description="Whether to apply search operators (e.g., site:example.com).",
)
class NewsSearchParams(BaseModel):
"""Parameters for Brave News Search endpoint."""
q: str = Field(
description="Search query to perform",
min_length=1,
max_length=400,
)
search_lang: str | None = Field(
default=None,
description="Language code for the search results (e.g., 'en', 'es').",
pattern=r"^[a-z]{2}$",
)
ui_lang: str | None = Field(
default=None,
description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
pattern=r"^[a-z]{2}-[A-Z]{2}$",
)
country: str | None = Field(
default=None,
description="Country code for geo-targeting (e.g., 'US', 'BR').",
pattern=r"^[A-Z]{2}$",
)
safesearch: Literal["off", "moderate", "strict"] | None = Field(
default=None,
description="Filter out explicit content. Options: off/moderate/strict",
)
count: int | None = Field(
default=None,
description="The maximum number of results to return.",
ge=1,
le=50,
)
offset: int | None = Field(
default=None,
description="Skip the first N result sets/pages. Max is 9.",
ge=0,
le=9,
)
spellcheck: bool | None = Field(
default=None,
description="Attempt to correct spelling errors in the search query.",
)
freshness: Freshness | None = Field(
default=None,
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
)
extra_snippets: bool | None = Field(
default=None,
description="Include up to 5 text snippets for each page if possible.",
)
goggles: str | list[str] | None = Field(
default=None,
description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
)
include_fetch_metadata: bool | None = Field(
default=None,
description="Whether to include fetch metadata in the results.",
)
operators: bool | None = Field(
default=None,
description="Whether to apply search operators (e.g., site:example.com).",
)
class BaseSearchHeaders(BaseModel):
"""Common headers for Brave Search endpoints."""
x_subscription_token: str = Field(
alias="x-subscription-token",
description="API key for Brave Search",
)
api_version: str | None = Field(
alias="api-version",
default=None,
description="API version to use. Default is latest available.",
pattern=r"^\d{4}-\d{2}-\d{2}$", # YYYY-MM-DD
)
accept: Literal["application/json"] | Literal["*/*"] | None = Field(
default=None,
description="Accept header for the request.",
)
cache_control: Literal["no-cache"] | None = Field(
alias="cache-control",
default=None,
description="Cache control header for the request.",
)
user_agent: str | None = Field(
alias="user-agent",
default=None,
description="User agent for the request.",
)
class LLMContextHeaders(BaseSearchHeaders):
"""Headers for Brave LLM Context endpoint."""
x_loc_lat: float | None = Field(
alias="x-loc-lat",
default=None,
description="Latitude of the user's location.",
ge=-90.0,
le=90.0,
)
x_loc_long: float | None = Field(
alias="x-loc-long",
default=None,
description="Longitude of the user's location.",
ge=-180.0,
le=180.0,
)
x_loc_city: str | None = Field(
alias="x-loc-city",
default=None,
description="City of the user's location.",
)
x_loc_state: str | None = Field(
alias="x-loc-state",
default=None,
description="State of the user's location.",
)
x_loc_state_name: str | None = Field(
alias="x-loc-state-name",
default=None,
description="Name of the state of the user's location.",
)
x_loc_country: str | None = Field(
alias="x-loc-country",
default=None,
description="The ISO 3166-1 alpha-2 country code of the user's location.",
)
class LocalPOIsHeaders(BaseSearchHeaders):
"""Headers for Brave Local POIs endpoint."""
x_loc_lat: float | None = Field(
alias="x-loc-lat",
default=None,
description="Latitude of the user's location.",
ge=-90.0,
le=90.0,
)
x_loc_long: float | None = Field(
alias="x-loc-long",
default=None,
description="Longitude of the user's location.",
ge=-180.0,
le=180.0,
)
class LocalPOIsDescriptionHeaders(BaseSearchHeaders):
"""Headers for Brave Local POI Descriptions endpoint."""
class VideoSearchHeaders(BaseSearchHeaders):
"""Headers for Brave Video Search endpoint."""
class ImageSearchHeaders(BaseSearchHeaders):
"""Headers for Brave Image Search endpoint."""
class NewsSearchHeaders(BaseSearchHeaders):
"""Headers for Brave News Search endpoint."""
class WebSearchHeaders(BaseSearchHeaders):
"""Headers for Brave Web Search endpoint."""
x_loc_lat: float | None = Field(
alias="x-loc-lat",
default=None,
description="Latitude of the user's location.",
ge=-90.0,
le=90.0,
)
x_loc_long: float | None = Field(
alias="x-loc-long",
default=None,
description="Longitude of the user's location.",
ge=-180.0,
le=180.0,
)
x_loc_timezone: str | None = Field(
alias="x-loc-timezone",
default=None,
description="Timezone of the user's location.",
)
x_loc_city: str | None = Field(
alias="x-loc-city",
default=None,
description="City of the user's location.",
)
x_loc_state: str | None = Field(
alias="x-loc-state",
default=None,
description="State of the user's location.",
)
x_loc_state_name: str | None = Field(
alias="x-loc-state-name",
default=None,
description="Name of the state of the user's location.",
)
x_loc_country: str | None = Field(
alias="x-loc-country",
default=None,
description="The ISO 3166-1 alpha-2 country code of the user's location.",
)
x_loc_postal_code: str | None = Field(
alias="x-loc-postal-code",
default=None,
description="The postal code of the user's location.",
)

View File

@@ -1,80 +1,777 @@
import json
from unittest.mock import patch
import os
from unittest.mock import MagicMock, patch
import pytest
import requests as requests_lib
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
BraveLLMContextTool,
)
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
BraveLocalPOIsTool,
BraveLocalPOIsDescriptionTool,
)
from crewai_tools.tools.brave_search_tool.schemas import (
WebSearchParams,
WebSearchHeaders,
ImageSearchParams,
ImageSearchHeaders,
NewsSearchParams,
NewsSearchHeaders,
VideoSearchParams,
VideoSearchHeaders,
LLMContextParams,
LLMContextHeaders,
LocalPOIsParams,
LocalPOIsHeaders,
LocalPOIsDescriptionParams,
LocalPOIsDescriptionHeaders,
)
def _mock_response(
status_code: int = 200,
json_data: dict | None = None,
headers: dict | None = None,
text: str = "",
) -> MagicMock:
"""Build a ``requests.Response``-like mock with the attributes used by ``_make_request``."""
resp = MagicMock(spec=requests_lib.Response)
resp.status_code = status_code
resp.ok = 200 <= status_code < 400
resp.url = "https://api.search.brave.com/res/v1/web/search?q=test"
resp.text = text or (str(json_data) if json_data else "")
resp.headers = headers or {}
resp.json.return_value = json_data if json_data is not None else {}
return resp
# Fixtures
@pytest.fixture(autouse=True)
def _brave_env_and_rate_limit():
"""Set BRAVE_API_KEY for every test. Rate limiting is per-instance (each tool starts with a fresh clock)."""
with patch.dict(os.environ, {"BRAVE_API_KEY": "test-api-key"}):
yield
@pytest.fixture
def brave_tool():
return BraveSearchTool(n_results=2)
def web_tool():
return BraveWebSearchTool()
def test_brave_tool_initialization():
tool = BraveSearchTool()
assert tool.n_results == 10
@pytest.fixture
def image_tool():
return BraveImageSearchTool()
@pytest.fixture
def news_tool():
return BraveNewsSearchTool()
@pytest.fixture
def video_tool():
return BraveVideoSearchTool()
# Initialization
ALL_TOOL_CLASSES = [
BraveWebSearchTool,
BraveImageSearchTool,
BraveNewsSearchTool,
BraveVideoSearchTool,
BraveLLMContextTool,
BraveLocalPOIsTool,
BraveLocalPOIsDescriptionTool,
]
@pytest.mark.parametrize("tool_cls", ALL_TOOL_CLASSES)
def test_instantiation_with_env_var(tool_cls):
"""Each tool can be created when BRAVE_API_KEY is in the environment."""
tool = tool_cls()
assert tool.api_key == "test-api-key"
@pytest.mark.parametrize("tool_cls", ALL_TOOL_CLASSES)
def test_instantiation_with_explicit_key(tool_cls):
"""An explicit api_key takes precedence over the environment."""
tool = tool_cls(api_key="explicit-key")
assert tool.api_key == "explicit-key"
def test_missing_api_key_raises():
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(ValueError, match="BRAVE_API_KEY"):
BraveWebSearchTool()
def test_default_attributes():
tool = BraveWebSearchTool()
assert tool.save_file is False
assert tool.n_results == 10
assert tool._timeout == 30
assert tool._requests_per_second == 1.0
assert tool.raw is False
@patch("requests.get")
def test_brave_tool_search(mock_get, brave_tool):
mock_response = {
def test_custom_constructor_args():
tool = BraveWebSearchTool(
save_file=True,
timeout=60,
n_results=5,
requests_per_second=0.5,
raw=True,
)
assert tool.save_file is True
assert tool._timeout == 60
assert tool.n_results == 5
assert tool._requests_per_second == 0.5
assert tool.raw is True
# Headers
def test_default_headers():
tool = BraveWebSearchTool()
assert tool.headers["x-subscription-token"] == "test-api-key"
assert tool.headers["accept"] == "application/json"
def test_set_headers_merges_and_normalizes():
tool = BraveWebSearchTool()
tool.set_headers({"Cache-Control": "no-cache"})
assert tool.headers["cache-control"] == "no-cache"
assert tool.headers["x-subscription-token"] == "test-api-key"
def test_set_headers_returns_self_for_chaining():
tool = BraveWebSearchTool()
assert tool.set_headers({"Cache-Control": "no-cache"}) is tool
def test_invalid_header_value_raises():
tool = BraveImageSearchTool()
with pytest.raises(ValueError, match="Invalid headers"):
tool.set_headers({"Accept": "text/xml"})
# Endpoint & Schema Wiring
@pytest.mark.parametrize(
"tool_cls, expected_url, expected_params, expected_headers",
[
(
BraveWebSearchTool,
"https://api.search.brave.com/res/v1/web/search",
WebSearchParams,
WebSearchHeaders,
),
(
BraveImageSearchTool,
"https://api.search.brave.com/res/v1/images/search",
ImageSearchParams,
ImageSearchHeaders,
),
(
BraveNewsSearchTool,
"https://api.search.brave.com/res/v1/news/search",
NewsSearchParams,
NewsSearchHeaders,
),
(
BraveVideoSearchTool,
"https://api.search.brave.com/res/v1/videos/search",
VideoSearchParams,
VideoSearchHeaders,
),
(
BraveLLMContextTool,
"https://api.search.brave.com/res/v1/llm/context",
LLMContextParams,
LLMContextHeaders,
),
(
BraveLocalPOIsTool,
"https://api.search.brave.com/res/v1/local/pois",
LocalPOIsParams,
LocalPOIsHeaders,
),
(
BraveLocalPOIsDescriptionTool,
"https://api.search.brave.com/res/v1/local/descriptions",
LocalPOIsDescriptionParams,
LocalPOIsDescriptionHeaders,
),
],
)
def test_tool_wiring(tool_cls, expected_url, expected_params, expected_headers):
tool = tool_cls()
assert tool.search_url == expected_url
assert tool.args_schema is expected_params
assert tool.header_schema is expected_headers
# Payload Refinement (e.g., `query` -> `q`, `count` fallback, param pass-through)
def test_web_refine_request_payload_passes_all_params(web_tool):
params = web_tool._common_payload_refinement(
{
"query": "test",
"country": "US",
"search_lang": "en",
"count": 5,
"offset": 2,
"safesearch": "moderate",
"freshness": "pw",
}
)
refined_params = web_tool._refine_request_payload(params)
assert refined_params["q"] == "test"
assert "query" not in refined_params
assert refined_params["count"] == 5
assert refined_params["country"] == "US"
assert refined_params["search_lang"] == "en"
assert refined_params["offset"] == 2
assert refined_params["safesearch"] == "moderate"
assert refined_params["freshness"] == "pw"
def test_image_refine_request_payload_passes_all_params(image_tool):
params = image_tool._common_payload_refinement(
{
"query": "cat photos",
"country": "US",
"search_lang": "en",
"safesearch": "strict",
"count": 50,
"spellcheck": True,
}
)
refined_params = image_tool._refine_request_payload(params)
assert refined_params["q"] == "cat photos"
assert "query" not in refined_params
assert refined_params["country"] == "US"
assert refined_params["safesearch"] == "strict"
assert refined_params["count"] == 50
assert refined_params["spellcheck"] is True
def test_news_refine_request_payload_passes_all_params(news_tool):
params = news_tool._common_payload_refinement(
{
"query": "breaking news",
"country": "US",
"count": 10,
"offset": 1,
"freshness": "pd",
"extra_snippets": True,
}
)
refined_params = news_tool._refine_request_payload(params)
assert refined_params["q"] == "breaking news"
assert "query" not in refined_params
assert refined_params["country"] == "US"
assert refined_params["offset"] == 1
assert refined_params["freshness"] == "pd"
assert refined_params["extra_snippets"] is True
def test_video_refine_request_payload_passes_all_params(video_tool):
params = video_tool._common_payload_refinement(
{
"query": "tutorial",
"country": "US",
"count": 25,
"offset": 0,
"safesearch": "strict",
"freshness": "pm",
}
)
refined_params = video_tool._refine_request_payload(params)
assert refined_params["q"] == "tutorial"
assert "query" not in refined_params
assert refined_params["country"] == "US"
assert refined_params["offset"] == 0
assert refined_params["freshness"] == "pm"
def test_legacy_constructor_params_flow_into_query_params():
"""The legacy n_results and country constructor params are applied as defaults
when count/country are not explicitly provided at call time."""
tool = BraveWebSearchTool(n_results=3, country="BR")
params = tool._common_payload_refinement({"query": "test"})
assert params["count"] == 3
assert params["country"] == "BR"
def test_legacy_constructor_params_do_not_override_explicit_query_params():
"""Explicit query-time count/country take precedence over constructor defaults."""
tool = BraveWebSearchTool(n_results=3, country="BR")
params = tool._common_payload_refinement(
{"query": "test", "count": 10, "country": "US"}
)
assert params["count"] == 10
assert params["country"] == "US"
def test_refine_request_payload_passes_multiple_goggles_as_multiple_params(web_tool):
result = web_tool._refine_request_payload(
{
"query": "test",
"goggles": ["goggle1", "goggle2"],
}
)
assert result["goggles"] == ["goggle1", "goggle2"]
# Null-like / empty value stripping
#
# crewAI's ensure_all_properties_required (pydantic_schema_utils.py) marks
# every schema property as required for OpenAI strict-mode compatibility.
# Because optional Brave API parameters look required to the LLM, it fills
# them with placeholder junk — None, "", "null", or []. The test below
# verifies that _common_payload_refinement strips these from optional fields.
def test_common_refinement_strips_null_like_values(web_tool):
"""_common_payload_refinement drops optional keys with None / '' / 'null' / []."""
params = web_tool._common_payload_refinement(
{
"query": "test",
"country": "US",
"search_lang": "",
"freshness": "null",
"count": 5,
"goggles": [],
}
)
assert params["q"] == "test"
assert params["country"] == "US"
assert params["count"] == 5
assert "search_lang" not in params
assert "freshness" not in params
assert "goggles" not in params
# End-to-End _run() with Mocked HTTP Response
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_web_search_end_to_end(mock_get, web_tool):
web_tool.raw = True
data = {"web": {"results": [{"title": "R", "url": "http://r.co"}]}}
mock_get.return_value = _mock_response(json_data=data)
result = web_tool._run(query="test")
mock_get.assert_called_once()
call_args = mock_get.call_args.kwargs
assert call_args["params"]["q"] == "test"
assert call_args["headers"]["x-subscription-token"] == "test-api-key"
assert result == data
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_image_search_end_to_end(mock_get, image_tool):
image_tool.raw = True
data = {"results": [{"url": "http://img.co/a.jpg"}]}
mock_get.return_value = _mock_response(json_data=data)
assert image_tool._run(query="cats") == data
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_news_search_end_to_end(mock_get, news_tool):
news_tool.raw = True
data = {"results": [{"title": "News", "url": "http://n.co"}]}
mock_get.return_value = _mock_response(json_data=data)
assert news_tool._run(query="headlines") == data
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_video_search_end_to_end(mock_get, video_tool):
video_tool.raw = True
data = {"results": [{"title": "Vid", "url": "http://v.co"}]}
mock_get.return_value = _mock_response(json_data=data)
assert video_tool._run(query="python tutorial") == data
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_raw_false_calls_refine_response(mock_get, web_tool):
"""With raw=False (the default), _refine_response transforms the API response."""
api_response = {
"web": {
"results": [
{
"title": "Test Title",
"url": "http://test.com",
"description": "Test Description",
"title": "CrewAI",
"url": "https://crewai.com",
"description": "AI agent framework",
}
]
}
}
mock_get.return_value.json.return_value = mock_response
mock_get.return_value = _mock_response(json_data=api_response)
result = brave_tool.run(query="test")
data = json.loads(result)
assert isinstance(data, list)
assert len(data) >= 1
assert data[0]["title"] == "Test Title"
assert data[0]["url"] == "http://test.com"
assert web_tool.raw is False
result = web_tool._run(query="crewai")
# The web tool's _refine_response extracts and reshapes results.
# The key assertion: we should NOT get back the raw API envelope.
assert result != api_response
@patch("requests.get")
def test_brave_tool(mock_get):
mock_response = {
"web": {
"results": [
{
"title": "Brave Browser",
"url": "https://brave.com",
"description": "Brave Browser description",
}
]
}
}
mock_get.return_value.json.return_value = mock_response
tool = BraveSearchTool(n_results=2)
result = tool.run(query="Brave Browser")
assert result is not None
# Parse JSON so we can examine the structure
data = json.loads(result)
assert isinstance(data, list)
assert len(data) >= 1
# First item should have expected fields: title, url, and description
first = data[0]
assert "title" in first
assert first["title"] == "Brave Browser"
assert "url" in first
assert first["url"] == "https://brave.com"
assert "description" in first
assert first["description"] == "Brave Browser description"
# Backward Compatibility & Legacy Parameter Support
if __name__ == "__main__":
test_brave_tool()
test_brave_tool_initialization()
# test_brave_tool_search(brave_tool)
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_positional_query_argument(mock_get, web_tool):
"""tool.run('my query') works as a positional argument."""
mock_get.return_value = _mock_response(json_data={})
web_tool._run("positional test")
assert mock_get.call_args.kwargs["params"]["q"] == "positional test"
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_search_query_backward_compat(mock_get, web_tool):
"""The legacy 'search_query' param is mapped to 'query'."""
mock_get.return_value = _mock_response(json_data={})
web_tool._run(search_query="legacy test")
assert mock_get.call_args.kwargs["params"]["q"] == "legacy test"
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base._save_results_to_file")
def test_save_file_called_when_enabled(mock_save, mock_get):
mock_get.return_value = _mock_response(json_data={"results": []})
tool = BraveWebSearchTool(save_file=True)
tool._run(query="test")
mock_save.assert_called_once()
# Error Handling
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_connection_error_raises_runtime_error(mock_get, web_tool):
mock_get.side_effect = requests_lib.exceptions.ConnectionError("refused")
with pytest.raises(RuntimeError, match="Brave Search API connection failed"):
web_tool._run(query="test")
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_timeout_raises_runtime_error(mock_get, web_tool):
mock_get.side_effect = requests_lib.exceptions.Timeout("timed out")
with pytest.raises(RuntimeError, match="timed out"):
web_tool._run(query="test")
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_invalid_params_raises_value_error(mock_get, web_tool):
"""count=999 exceeds WebSearchParams.count le=20."""
with pytest.raises(ValueError, match="Invalid parameters"):
web_tool._run(query="test", count=999)
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_4xx_error_raises_with_api_detail(mock_get, web_tool):
"""A 422 with a structured error body includes code and detail in the message."""
mock_get.return_value = _mock_response(
status_code=422,
json_data={
"error": {
"id": "abc-123",
"status": 422,
"code": "OPTION_NOT_IN_PLAN",
"detail": "extra_snippets requires a Pro plan",
}
},
)
with pytest.raises(RuntimeError, match="OPTION_NOT_IN_PLAN") as exc_info:
web_tool._run(query="test")
assert "extra_snippets requires a Pro plan" in str(exc_info.value)
assert "HTTP 422" in str(exc_info.value)
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_auth_error_raises_immediately(mock_get, web_tool):
"""A 401 with SUBSCRIPTION_TOKEN_INVALID is not retried."""
mock_get.return_value = _mock_response(
status_code=401,
json_data={
"error": {
"id": "xyz",
"status": 401,
"code": "SUBSCRIPTION_TOKEN_INVALID",
"detail": "The subscription token is invalid",
}
},
)
with pytest.raises(RuntimeError, match="SUBSCRIPTION_TOKEN_INVALID"):
web_tool._run(query="test")
# Should NOT have retried — only one call.
assert mock_get.call_count == 1
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_quota_limited_429_raises_immediately(mock_get, web_tool):
"""A 429 with QUOTA_LIMITED is NOT retried — quota exhaustion is terminal."""
mock_get.return_value = _mock_response(
status_code=429,
json_data={
"error": {
"id": "ql-1",
"status": 429,
"code": "QUOTA_LIMITED",
"detail": "Monthly quota exceeded",
}
},
)
with pytest.raises(RuntimeError, match="QUOTA_LIMITED") as exc_info:
web_tool._run(query="test")
assert "Monthly quota exceeded" in str(exc_info.value)
# Terminal — only one HTTP call, no retries.
assert mock_get.call_count == 1
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_usage_limit_exceeded_429_raises_immediately(mock_get, web_tool):
"""USAGE_LIMIT_EXCEEDED is also non-retryable, just like QUOTA_LIMITED."""
mock_get.return_value = _mock_response(
status_code=429,
json_data={
"error": {
"id": "ule-1",
"status": 429,
"code": "USAGE_LIMIT_EXCEEDED",
}
},
text="usage limit exceeded",
)
with pytest.raises(RuntimeError, match="USAGE_LIMIT_EXCEEDED"):
web_tool._run(query="test")
assert mock_get.call_count == 1
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_error_body_is_fully_included_in_message(mock_get, web_tool):
"""The full JSON error body is included in the RuntimeError message."""
mock_get.return_value = _mock_response(
status_code=429,
json_data={
"error": {
"id": "x",
"status": 429,
"code": "QUOTA_LIMITED",
"detail": "Exceeded",
"meta": {"plan": "free", "limit": 1000},
}
},
)
with pytest.raises(RuntimeError) as exc_info:
web_tool._run(query="test")
msg = str(exc_info.value)
assert "HTTP 429" in msg
assert "QUOTA_LIMITED" in msg
assert "free" in msg
assert "1000" in msg
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_error_without_json_body_falls_back_to_text(mock_get, web_tool):
"""When the error response isn't valid JSON, resp.text is used as the detail."""
resp = _mock_response(status_code=500, text="Internal Server Error")
resp.json.side_effect = ValueError("No JSON")
mock_get.return_value = resp
with pytest.raises(RuntimeError, match="Internal Server Error"):
web_tool._run(query="test")
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_invalid_json_on_success_raises_runtime_error(mock_get, web_tool):
"""A 200 OK with a non-JSON body raises RuntimeError."""
resp = _mock_response(status_code=200)
resp.json.side_effect = ValueError("Expecting value")
mock_get.return_value = resp
with pytest.raises(RuntimeError, match="invalid JSON"):
web_tool._run(query="test")
# Rate Limiting
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_rate_limit_sleeps_when_too_fast(mock_time, mock_get, web_tool):
"""Back-to-back calls within the interval trigger a sleep."""
mock_get.return_value = _mock_response(json_data={})
# Simulate: last request was at t=100, "now" is t=100.2 (only 0.2s elapsed).
# With default 1 req/s the min interval is 1.0s, so it should sleep ~0.8s.
mock_time.time.return_value = 100.2
web_tool._last_request_time = 100.0
web_tool._run(query="test")
mock_time.sleep.assert_called_once()
sleep_duration = mock_time.sleep.call_args[0][0]
assert 0.7 < sleep_duration < 0.9 # approximately 0.8s
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_rate_limit_skips_sleep_when_enough_time_passed(mock_time, mock_get, web_tool):
"""No sleep when the elapsed time already exceeds the interval."""
mock_get.return_value = _mock_response(json_data={})
# Last request was at t=100, "now" is t=102 (2s elapsed > 1s interval).
mock_time.time.return_value = 102.0
web_tool._last_request_time = 100.0
web_tool._run(query="test")
mock_time.sleep.assert_not_called()
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_rate_limit_disabled_when_zero(mock_time, mock_get, web_tool):
"""requests_per_second=0 disables rate limiting entirely."""
mock_get.return_value = _mock_response(json_data={})
web_tool._last_request_time = 100.0
mock_time.time.return_value = 100.0 # same instant
web_tool._run(query="test")
mock_time.sleep.assert_not_called()
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_rate_limit_per_instance_independent(mock_time, mock_get, web_tool, image_tool):
"""Each instance has its own rate-limit clock; a request on one does not delay the other."""
mock_get.return_value = _mock_response(json_data={})
# Web tool fires at t=100 (its clock goes 0 -> 100).
mock_time.time.return_value = 100.0
web_tool._run(query="test")
# Image tool fires at t=100.3. Its clock is still 0 (separate instance), so
# next_allowed = 1.0 and 100.3 > 1.0 — no sleep. Total process rate can be sum of instance limits.
mock_time.time.return_value = 100.3
image_tool._run(query="cats")
mock_time.sleep.assert_not_called()
# Retry Behavior
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_429_rate_limited_retries_then_succeeds(mock_time, mock_get, web_tool):
"""A transient RATE_LIMITED 429 is retried; success on the second attempt."""
mock_time.time.return_value = 200.0
resp_429 = _mock_response(
status_code=429,
json_data={"error": {"id": "r", "status": 429, "code": "RATE_LIMITED"}},
headers={"Retry-After": "2"},
)
resp_200 = _mock_response(status_code=200, json_data={"web": {"results": []}})
mock_get.side_effect = [resp_429, resp_200]
web_tool.raw = True
result = web_tool._run(query="test")
assert result == {"web": {"results": []}}
assert mock_get.call_count == 2
# Slept for the Retry-After value.
retry_sleeps = [c for c in mock_time.sleep.call_args_list if c[0][0] == 2.0]
assert len(retry_sleeps) == 1
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_5xx_is_retried(mock_time, mock_get, web_tool):
"""A 502 server error is retried; success on the second attempt."""
mock_time.time.return_value = 200.0
resp_502 = _mock_response(status_code=502, text="Bad Gateway")
resp_502.json.side_effect = ValueError("no json")
resp_200 = _mock_response(status_code=200, json_data={"web": {"results": []}})
mock_get.side_effect = [resp_502, resp_200]
web_tool.raw = True
result = web_tool._run(query="test")
assert result == {"web": {"results": []}}
assert mock_get.call_count == 2
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_429_rate_limited_exhausts_retries(mock_time, mock_get, web_tool):
"""Persistent RATE_LIMITED 429s exhaust retries and raise RuntimeError."""
mock_time.time.return_value = 200.0
resp_429 = _mock_response(
status_code=429,
json_data={"error": {"id": "r", "status": 429, "code": "RATE_LIMITED"}},
)
mock_get.return_value = resp_429
with pytest.raises(RuntimeError, match="RATE_LIMITED"):
web_tool._run(query="test")
# 3 attempts (default _max_retries).
assert mock_get.call_count == 3
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
@patch("crewai_tools.tools.brave_search_tool.base.time")
def test_retry_uses_exponential_backoff_when_no_retry_after(
mock_time, mock_get, web_tool
):
"""Without Retry-After, backoff is 2^attempt (1s, 2s, ...)."""
mock_time.time.return_value = 200.0
resp_503 = _mock_response(status_code=503, text="Service Unavailable")
resp_503.json.side_effect = ValueError("no json")
resp_200 = _mock_response(status_code=200, json_data={"ok": True})
mock_get.side_effect = [resp_503, resp_503, resp_200]
web_tool.raw = True
web_tool._run(query="test")
# Two retries: attempt 0 → sleep(1.0), attempt 1 → sleep(2.0).
retry_sleeps = [c[0][0] for c in mock_time.sleep.call_args_list]
assert 1.0 in retry_sleeps
assert 2.0 in retry_sleeps

File diff suppressed because it is too large Load Diff

View File

@@ -8,8 +8,8 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable
import contextvars
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
import inspect
import logging
from typing import TYPE_CHECKING, Any, Literal, cast

View File

@@ -1,9 +1,9 @@
from __future__ import annotations
import asyncio
import contextvars
from collections.abc import Callable, Coroutine
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
from datetime import datetime
import inspect
import json

View File

@@ -497,6 +497,50 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
def __bool__(self) -> bool:
return bool(self._list)
def index(self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None) -> int: # type: ignore[override]
if stop is None:
return self._list.index(value, start)
return self._list.index(value, start, stop)
def count(self, value: T) -> int:
return self._list.count(value)
def sort(self, *, key: Any = None, reverse: bool = False) -> None:
with self._lock:
self._list.sort(key=key, reverse=reverse)
def reverse(self) -> None:
with self._lock:
self._list.reverse()
def copy(self) -> list[T]:
return self._list.copy()
def __add__(self, other: list[T]) -> list[T]:
return self._list + other
def __radd__(self, other: list[T]) -> list[T]:
return other + self._list
def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]:
with self._lock:
self._list += list(other)
return self
def __mul__(self, n: SupportsIndex) -> list[T]:
return self._list * n
def __rmul__(self, n: SupportsIndex) -> list[T]:
return self._list * n
def __imul__(self, n: SupportsIndex) -> LockedListProxy[T]:
with self._lock:
self._list *= n
return self
def __reversed__(self) -> Iterator[T]:
return reversed(self._list)
def __eq__(self, other: object) -> bool:
"""Compare based on the underlying list contents."""
if isinstance(other, LockedListProxy):
@@ -579,6 +623,23 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
def __bool__(self) -> bool:
return bool(self._dict)
def copy(self) -> dict[str, T]:
return self._dict.copy()
def __or__(self, other: dict[str, T]) -> dict[str, T]:
return self._dict | other
def __ror__(self, other: dict[str, T]) -> dict[str, T]:
return other | self._dict
def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]:
with self._lock:
self._dict |= other
return self
def __reversed__(self) -> Iterator[str]:
return reversed(self._dict)
def __eq__(self, other: object) -> bool:
"""Compare based on the underlying dict contents."""
if isinstance(other, LockedDictProxy):
@@ -620,6 +681,10 @@ class StateProxy(Generic[T]):
if name in ("_proxy_state", "_proxy_lock"):
object.__setattr__(self, name, value)
else:
if isinstance(value, LockedListProxy):
value = value._list
elif isinstance(value, LockedDictProxy):
value = value._dict
with object.__getattribute__(self, "_proxy_lock"):
setattr(object.__getattribute__(self, "_proxy_state"), name, value)

View File

@@ -1893,3 +1893,163 @@ def test_or_condition_self_listen_fires_once():
flow = OrSelfListenFlow()
flow.kickoff()
assert call_count == 1
class ListState(BaseModel):
items: list = []
class DictState(BaseModel):
data: dict = {}
class _ListFlow(Flow[ListState]):
@start()
def populate(self):
self.state.items = [3, 1, 4, 1, 5, 9, 2, 6]
class _DictFlow(Flow[DictState]):
@start()
def populate(self):
self.state.data = {"a": 1, "b": 2, "c": 3}
def _make_list_flow():
flow = _ListFlow()
flow.kickoff()
return flow
def _make_dict_flow():
flow = _DictFlow()
flow.kickoff()
return flow
def test_locked_list_proxy_index():
flow = _make_list_flow()
assert flow.state.items.index(4) == 2
assert flow.state.items.index(1, 2) == 3
def test_locked_list_proxy_index_missing_raises():
flow = _make_list_flow()
with pytest.raises(ValueError):
flow.state.items.index(999)
def test_locked_list_proxy_count():
flow = _make_list_flow()
assert flow.state.items.count(1) == 2
assert flow.state.items.count(999) == 0
def test_locked_list_proxy_sort():
flow = _make_list_flow()
flow.state.items.sort()
assert list(flow.state.items) == [1, 1, 2, 3, 4, 5, 6, 9]
def test_locked_list_proxy_sort_reverse():
flow = _make_list_flow()
flow.state.items.sort(reverse=True)
assert list(flow.state.items) == [9, 6, 5, 4, 3, 2, 1, 1]
def test_locked_list_proxy_sort_key():
flow = _make_list_flow()
flow.state.items.sort(key=lambda x: -x)
assert list(flow.state.items) == [9, 6, 5, 4, 3, 2, 1, 1]
def test_locked_list_proxy_reverse():
flow = _make_list_flow()
original = list(flow.state.items)
flow.state.items.reverse()
assert list(flow.state.items) == list(reversed(original))
def test_locked_list_proxy_copy():
flow = _make_list_flow()
copied = flow.state.items.copy()
assert copied == [3, 1, 4, 1, 5, 9, 2, 6]
assert isinstance(copied, list)
copied.append(999)
assert 999 not in flow.state.items
def test_locked_list_proxy_add():
flow = _make_list_flow()
result = flow.state.items + [10, 11]
assert result == [3, 1, 4, 1, 5, 9, 2, 6, 10, 11]
assert len(flow.state.items) == 8
def test_locked_list_proxy_radd():
flow = _make_list_flow()
result = [0] + flow.state.items
assert result[0] == 0
assert len(result) == 9
def test_locked_list_proxy_iadd():
flow = _make_list_flow()
flow.state.items += [10]
assert 10 in flow.state.items
# Verify no deadlock: mutations must still work after +=
flow.state.items.append(99)
assert 99 in flow.state.items
def test_locked_list_proxy_mul():
flow = _make_list_flow()
result = flow.state.items * 2
assert len(result) == 16
def test_locked_list_proxy_rmul():
flow = _make_list_flow()
result = 2 * flow.state.items
assert len(result) == 16
def test_locked_list_proxy_reversed():
flow = _make_list_flow()
original = list(flow.state.items)
assert list(reversed(flow.state.items)) == list(reversed(original))
def test_locked_dict_proxy_copy():
flow = _make_dict_flow()
copied = flow.state.data.copy()
assert copied == {"a": 1, "b": 2, "c": 3}
assert isinstance(copied, dict)
copied["z"] = 99
assert "z" not in flow.state.data
def test_locked_dict_proxy_or():
flow = _make_dict_flow()
result = flow.state.data | {"d": 4}
assert result == {"a": 1, "b": 2, "c": 3, "d": 4}
assert "d" not in flow.state.data
def test_locked_dict_proxy_ror():
flow = _make_dict_flow()
result = {"z": 0} | flow.state.data
assert result == {"z": 0, "a": 1, "b": 2, "c": 3}
def test_locked_dict_proxy_ior():
flow = _make_dict_flow()
flow.state.data |= {"d": 4}
assert flow.state.data["d"] == 4
# Verify no deadlock: mutations must still work after |=
flow.state.data["e"] = 5
assert flow.state.data["e"] == 5
def test_locked_dict_proxy_reversed():
flow = _make_dict_flow()
assert list(reversed(flow.state.data)) == ["c", "b", "a"]