mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
Merge branch 'main' of github.com:crewAIInc/crewAI into lorenze/feat/plan-execute-pattern
This commit is contained in:
@@ -152,4 +152,4 @@ __all__ = [
|
||||
"wrap_file_source",
|
||||
]
|
||||
|
||||
__version__ = "1.10.1"
|
||||
__version__ = "1.10.2a1"
|
||||
|
||||
@@ -11,7 +11,7 @@ dependencies = [
|
||||
"pytube~=15.0.0",
|
||||
"requests~=2.32.5",
|
||||
"docker~=7.1.0",
|
||||
"crewai==1.10.1",
|
||||
"crewai==1.10.2a1",
|
||||
"tiktoken~=0.8.0",
|
||||
"beautifulsoup4~=4.13.4",
|
||||
"python-docx~=1.2.0",
|
||||
@@ -108,7 +108,7 @@ stagehand = [
|
||||
"stagehand>=0.4.1",
|
||||
]
|
||||
github = [
|
||||
"gitpython==3.1.38",
|
||||
"gitpython>=3.1.41,<4",
|
||||
"PyGithub==1.59.1",
|
||||
]
|
||||
rag = [
|
||||
|
||||
@@ -10,7 +10,18 @@ from crewai_tools.aws.s3.writer_tool import S3WriterTool
|
||||
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
|
||||
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
|
||||
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
|
||||
BraveLLMContextTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
BraveLocalPOIsTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
|
||||
from crewai_tools.tools.brightdata_tool.brightdata_dataset import (
|
||||
BrightDataDatasetTool,
|
||||
)
|
||||
@@ -200,7 +211,14 @@ __all__ = [
|
||||
"ArxivPaperTool",
|
||||
"BedrockInvokeAgentTool",
|
||||
"BedrockKBRetrieverTool",
|
||||
"BraveImageSearchTool",
|
||||
"BraveLLMContextTool",
|
||||
"BraveLocalPOIsDescriptionTool",
|
||||
"BraveLocalPOIsTool",
|
||||
"BraveNewsSearchTool",
|
||||
"BraveSearchTool",
|
||||
"BraveVideoSearchTool",
|
||||
"BraveWebSearchTool",
|
||||
"BrightDataDatasetTool",
|
||||
"BrightDataSearchTool",
|
||||
"BrightDataWebUnlockerTool",
|
||||
@@ -291,4 +309,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.10.1"
|
||||
__version__ = "1.10.2a1"
|
||||
|
||||
@@ -1,7 +1,18 @@
|
||||
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
|
||||
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
|
||||
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
|
||||
BraveLLMContextTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
BraveLocalPOIsTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
|
||||
from crewai_tools.tools.brightdata_tool import (
|
||||
BrightDataDatasetTool,
|
||||
BrightDataSearchTool,
|
||||
@@ -185,7 +196,14 @@ __all__ = [
|
||||
"AIMindTool",
|
||||
"ApifyActorsTool",
|
||||
"ArxivPaperTool",
|
||||
"BraveImageSearchTool",
|
||||
"BraveLLMContextTool",
|
||||
"BraveLocalPOIsDescriptionTool",
|
||||
"BraveLocalPOIsTool",
|
||||
"BraveNewsSearchTool",
|
||||
"BraveSearchTool",
|
||||
"BraveVideoSearchTool",
|
||||
"BraveWebSearchTool",
|
||||
"BrightDataDatasetTool",
|
||||
"BrightDataSearchTool",
|
||||
"BrightDataWebUnlockerTool",
|
||||
|
||||
@@ -0,0 +1,322 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
import requests
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Error codes Brave reports when a plan's quota/usage allowance is exhausted.
# Requests rejected with one of these are not retried.
_QUOTA_CODES = frozenset(("QUOTA_LIMITED", "USAGE_LIMIT_EXCEEDED"))
|
||||
|
||||
|
||||
def _save_results_to_file(content: str) -> None:
    """Write *content* to a timestamped text file in the current directory.

    The file is named ``search_results_<YYYY-MM-DD_HH-MM-SS>.txt`` from the
    local time of the call. It is opened in "w" mode, so a second call within
    the same second overwrites the earlier file.

    Args:
        content: Text to store, written as-is.
    """
    filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
    # Pin the encoding: the platform default (e.g. cp1252) may be unable to
    # encode every character found in search results.
    with open(filename, "w", encoding="utf-8") as file:
        file.write(content)
|
||||
|
||||
|
||||
def _parse_error_body(resp: requests.Response) -> dict[str, Any] | None:
    """Extract the structured "error" object from a Brave API error response.

    Args:
        resp: The HTTP response whose body should be inspected.

    Returns:
        The value stored under the top-level ``"error"`` key when the body is
        a JSON object and that value is itself a dict; ``None`` in every other
        case (body is not JSON, is not an object, or has no dict ``"error"``).
    """
    try:
        body = resp.json()
    except (ValueError, KeyError):
        return None

    # Valid JSON is not necessarily an object: a list, string, number or null
    # body has no .get() and must not escape as an AttributeError.
    if not isinstance(body, dict):
        return None

    error = body.get("error")
    return error if isinstance(error, dict) else None
|
||||
|
||||
|
||||
def _raise_for_error(resp: requests.Response) -> None:
    """Raise a RuntimeError describing a failed Brave Search API response.

    Brave error responses usually carry a JSON payload, which is re-serialized
    into the exception message. When the body cannot be decoded as JSON, the
    first 500 characters of the raw text are used instead.

    Args:
        resp: The non-successful HTTP response.

    Raises:
        RuntimeError: Always; the message holds the HTTP status and the body.
    """
    try:
        detail = json.dumps(resp.json())
    except (ValueError, KeyError):
        detail = resp.text[:500]

    raise RuntimeError(f"Brave Search API error (HTTP {resp.status_code}): {detail}")
|
||||
|
||||
|
||||
def _is_retryable(resp: requests.Response) -> bool:
    """Tell whether a failed response is a transient error worth retrying.

    * Any 5xx status is treated as a transient server-side failure.
    * A 429 is retryable unless its error code is one of ``_QUOTA_CODES``
      (QUOTA_LIMITED, USAGE_LIMIT_EXCEEDED): quota exhaustion will not clear
      until the billing period resets, so retrying cannot succeed.
    * Everything else is not retryable.
    """
    status = resp.status_code
    if status != 429:
        return 500 <= status < 600

    # A 429 without a parseable error object is assumed to be plain rate limiting.
    details = _parse_error_body(resp) or {}
    return details.get("code") not in _QUOTA_CODES
|
||||
|
||||
|
||||
def _retry_delay(resp: requests.Response, attempt: int) -> float:
    """Compute the wait time, in seconds, before the next retry attempt.

    Prefers the server-supplied ``Retry-After`` header when it holds a finite
    number (negative values are clamped to 0). Any other header value — an
    HTTP date, free text, ``inf`` or ``nan`` — is ignored in favour of
    exponential backoff (1s, 2s, 4s, ...).

    Args:
        resp: The response that triggered the retry.
        attempt: Zero-based index of the attempt that just failed.

    Returns:
        Number of seconds to sleep.
    """
    import math

    retry_after = resp.headers.get("Retry-After")
    if retry_after is not None:
        try:
            delay = float(retry_after)
        except (ValueError, TypeError):
            pass
        else:
            # float() happily parses "inf"/"nan"; an infinite delay would be
            # handed to time.sleep(), so only finite values are honoured.
            if math.isfinite(delay):
                return max(0.0, delay)
    return float(2**attempt)
|
||||
|
||||
|
||||
class BraveSearchToolBase(BaseTool, ABC):
    """Base class for tools that query a Brave Search API endpoint.

    The base class owns header validation, per-instance rate limiting, the
    HTTP GET with retry handling, parameter clean-up/validation and optional
    saving of results to a file.

    Individual tool subclasses must provide the following:
        - search_url: endpoint URL requested by ``_make_request()``.
        - header_schema (pydantic model): validates the request headers.
        - args_schema (pydantic model): validates the request parameters.
        - _refine_request_payload() -> dict[str, Any]
        - _refine_response() -> Any (skipped when ``raw`` is True)
    """

    search_url: str
    raw: bool = False
    args_schema: type[BaseModel]
    header_schema: type[BaseModel]

    # Tool options (legacy parameters)
    country: str | None = None
    save_file: bool = False
    n_results: int = 10

    env_vars: list[EnvVar] = Field(
        default_factory=lambda: [
            EnvVar(
                name="BRAVE_API_KEY",
                description="API key for Brave Search",
                required=True,
            ),
        ]
    )

    def __init__(
        self,
        *,
        api_key: str | None = None,
        headers: dict[str, Any] | None = None,
        requests_per_second: float = 1.0,
        save_file: bool = False,
        raw: bool = False,
        timeout: int = 30,
        **kwargs: Any,
    ):
        """Configure the tool.

        Args:
            api_key: Brave API key; falls back to the BRAVE_API_KEY
                environment variable when omitted or empty.
            headers: Extra request headers. Keys are lower-cased and the
                result is validated against ``header_schema``.
            requests_per_second: Maximum request rate for this instance;
                a value <= 0 disables rate limiting.
            save_file: When True, each result is also written to a file.
            raw: When True, ``_run`` returns the API response unrefined.
            timeout: Timeout in seconds passed to ``requests.get``.
            **kwargs: Forwarded to the parent initializer.

        Raises:
            ValueError: If no API key is available or the headers are invalid.
        """
        super().__init__(**kwargs)

        self._api_key = api_key or os.environ.get("BRAVE_API_KEY")
        if not self._api_key:
            raise ValueError("BRAVE_API_KEY environment variable is required")

        self.raw = bool(raw)
        self._timeout = int(timeout)
        self.save_file = bool(save_file)
        self._requests_per_second = float(requests_per_second)
        self._headers = self._build_and_validate_headers(headers or {})
        # Per-instance rate limiting: each instance has its own clock and lock.
        # Total process rate is the sum of limits of instances you create.
        self._last_request_time: float = 0
        self._rate_limit_lock = threading.Lock()

    @property
    def api_key(self) -> str:
        """The API key resolved at construction time."""
        return self._api_key

    @property
    def headers(self) -> dict[str, Any]:
        """The validated, lower-cased request headers currently in use."""
        return self._headers

    def set_headers(self, headers: dict[str, Any]) -> BraveSearchToolBase:
        """Merge *headers* over the current ones, re-validate, and return self.

        Raises:
            ValueError: If the merged headers fail ``header_schema`` validation.
        """
        merged = {**self._headers, **{k.lower(): v for k, v in headers.items()}}
        self._headers = self._build_and_validate_headers(merged)
        return self

    def _build_and_validate_headers(self, headers: dict[str, Any]) -> dict[str, Any]:
        """Lower-case header names, fill in defaults and validate the result.

        ``x-subscription-token`` defaults to the API key and ``accept`` to
        ``application/json``; caller-supplied values take precedence.

        Raises:
            ValueError: If ``header_schema`` rejects the headers.
        """
        normalized = {k.lower(): v for k, v in headers.items()}
        normalized.setdefault("x-subscription-token", self._api_key)
        normalized.setdefault("accept", "application/json")

        try:
            self.header_schema(**normalized)
        except Exception as e:
            raise ValueError(f"Invalid headers: {e}") from e

        return normalized

    def _rate_limit(self) -> None:
        """Enforce minimum interval between requests for this instance. Thread-safe."""
        if self._requests_per_second <= 0:
            return

        min_interval = 1.0 / self._requests_per_second
        # The sleep happens while holding the lock so that concurrent callers
        # of the same instance are spaced out one after another.
        with self._rate_limit_lock:
            now = time.time()
            next_allowed = self._last_request_time + min_interval
            if now < next_allowed:
                time.sleep(next_allowed - now)
                now = time.time()
            self._last_request_time = now

    def _make_request(
        self, params: dict[str, Any], *, _max_retries: int = 3
    ) -> dict[str, Any]:
        """Execute an HTTP GET against the Brave Search API with retry logic.

        Args:
            params: Query parameters for the request.
            _max_retries: Total number of attempts; must be at least 1.

        Returns:
            The decoded JSON body of the first successful response.

        Raises:
            ValueError: If ``_max_retries`` is less than 1.
            RuntimeError: On connection failure, timeout, an invalid JSON
                body, a non-retryable error response, or when the last
                attempt still fails.
        """
        # Without at least one attempt there is no response to report; fail
        # with a clear message instead of an unbound-variable error.
        if _max_retries < 1:
            raise ValueError("_max_retries must be at least 1")

        # Retry the request up to _max_retries times
        for attempt in range(_max_retries):
            self._rate_limit()

            # Make the request
            try:
                resp = requests.get(
                    self.search_url,
                    headers=self._headers,
                    params=params,
                    timeout=self._timeout,
                )
            except requests.ConnectionError as exc:
                raise RuntimeError(
                    f"Brave Search API connection failed: {exc}"
                ) from exc
            except requests.Timeout as exc:
                raise RuntimeError(
                    f"Brave Search API request timed out after {self._timeout}s: {exc}"
                ) from exc

            # Log the request details
            logger.debug(
                "Brave Search API request: %s %s -> %d",
                "GET",
                resp.url,
                resp.status_code,
            )

            # Response was OK, return the JSON body
            if resp.ok:
                try:
                    return resp.json()
                except ValueError as exc:
                    raise RuntimeError(
                        f"Brave Search API returned invalid JSON (HTTP {resp.status_code}): {exc}"
                    ) from exc

            # Response was not OK, but is retryable
            # (e.g., 429 Too Many Requests, 500 Internal Server Error)
            if _is_retryable(resp) and attempt < _max_retries - 1:
                delay = _retry_delay(resp, attempt)
                logger.warning(
                    "Brave Search API returned %d. Retrying in %.1fs (attempt %d/%d)",
                    resp.status_code,
                    delay,
                    attempt + 1,
                    _max_retries,
                )
                time.sleep(delay)
                continue

            # Response was not retryable (e.g., 422 Unprocessable Entity,
            # 400 Bad Request (OPTION_NOT_IN_PLAN)), or it was the last attempt.
            _raise_for_error(resp)

        # Unreachable: each iteration returns, raises, or continues to a later
        # attempt, and the final attempt can never continue. Kept so the
        # function visibly never falls off the end.
        raise RuntimeError("Brave Search API request failed: retries exhausted")

    def _run(self, q: str | None = None, **params: Any) -> Any:
        """Validate the parameters, call the API and return the result.

        Args:
            q: The search query; allows ``tool.run("some query")``.
            **params: Further request parameters. Keys that are not fields
                of ``args_schema`` are dropped before validation.

        Returns:
            The raw JSON response when ``raw`` is True, otherwise whatever
            ``_refine_response`` produces.

        Raises:
            ValueError: If the parameters fail ``args_schema`` validation.
        """
        # Allow positional usage: tool.run("latest Brave browser features")
        if q is not None:
            params["q"] = q

        params = self._common_payload_refinement(params)

        # Validate only schema fields
        schema_keys = self.args_schema.model_fields
        payload_in = {k: v for k, v in params.items() if k in schema_keys}

        try:
            validated = self.args_schema(**payload_in)
        except Exception as e:
            raise ValueError(f"Invalid parameters: {e}") from e

        # The subclass may have additional refinements to apply to the payload, such as goggles or other parameters
        payload = self._refine_request_payload(validated.model_dump(exclude_none=True))
        response = self._make_request(payload)

        if not self.raw:
            response = self._refine_response(response)

        if self.save_file:
            _save_results_to_file(json.dumps(response, indent=2))

        return response

    @abstractmethod
    def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
        """Subclass must implement: transform validated params dict into API request params."""
        raise NotImplementedError

    @abstractmethod
    def _refine_response(self, response: dict[str, Any]) -> Any:
        """Subclass must implement: transform response dict into a more useful format."""
        raise NotImplementedError

    # Placeholder values an LLM may supply for optional parameters it did not
    # actually intend to set.
    _EMPTY_VALUES: ClassVar[tuple[None, str, str, list[Any]]] = (None, "", "null", [])

    def _common_payload_refinement(self, params: dict[str, Any]) -> dict[str, Any]:
        """Common payload refinement for all tools.

        Drops placeholder values from optional fields, maps the legacy
        ``query`` / ``search_query`` keys onto ``q``, and fills ``count`` and
        ``country`` from the instance defaults when the schema supports them
        and the caller did not provide them.
        """
        # crewAI's schema pipeline (ensure_all_properties_required in
        # pydantic_schema_utils.py) marks every property as required so
        # that OpenAI strict-mode structured outputs work correctly.
        # The side-effect is that the LLM fills in *every* parameter —
        # even truly optional ones — using placeholder values such as
        # None, "", "null", or []. Only optional fields are affected,
        # so we limit the check to those.
        fields = self.args_schema.model_fields
        params = {
            k: v
            for k, v in params.items()
            # Permit custom and required fields, and fields with non-empty values
            if k not in fields or fields[k].is_required() or v not in self._EMPTY_VALUES
        }

        # Make sure params has "q" for query instead of "query" or "search_query"
        query = params.get("query") or params.get("search_query")
        if query is not None and "q" not in params:
            params["q"] = query
        params.pop("query", None)
        params.pop("search_query", None)

        # If "count" was not explicitly provided, use n_results
        # (only when the schema actually supports a "count" field)
        if "count" in fields:
            if "count" not in params and self.n_results is not None:
                params["count"] = self.n_results

        # If "country" was not explicitly provided, but self.country is set, use it
        # (only when the schema actually supports a "country" field)
        if "country" in fields:
            if "country" not in params and self.country is not None:
                params["country"] = self.country

        return params
|
||||
@@ -0,0 +1,42 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
ImageSearchHeaders,
|
||||
ImageSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveImageSearchTool(BraveSearchToolBase):
    """Image search tool backed by the Brave Search API images endpoint."""

    name: str = "Brave Image Search"
    args_schema: type[BaseModel] = ImageSearchParams
    header_schema: type[BaseModel] = ImageSearchHeaders

    description: str = (
        "A tool that performs image searches using the Brave Search API. "
        "Results are returned as structured JSON data."
    )

    search_url: str = "https://api.search.brave.com/res/v1/images/search"

    def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
        """Return the validated parameters unchanged."""
        return params

    def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
        """Reduce each result to its title, image URL and "WxH" dimensions.

        ``dimensions`` is None unless both width and height are present and truthy.
        """
        concise: list[dict[str, Any]] = []
        for item in response.get("results", []):
            props = item.get("properties", {})
            title = item.get("title")
            url = props.get("url")
            width = props.get("width")
            height = props.get("height")
            dimensions = f"{width}x{height}" if width and height else None
            concise.append({"title": title, "url": url, "dimensions": dimensions})
        return concise
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.response_types import LLMContext
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
LLMContextHeaders,
|
||||
LLMContextParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveLLMContextTool(BraveSearchToolBase):
    """Tool that fetches LLM-oriented context from the Brave Search API."""

    name: str = "Brave LLM Context"
    args_schema: type[BaseModel] = LLMContextParams
    header_schema: type[BaseModel] = LLMContextHeaders

    description: str = (
        "A tool that retrieves context for LLM usage from the Brave Search API. "
        "Results are returned as structured JSON data."
    )

    search_url: str = "https://api.search.brave.com/res/v1/llm/context"

    def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
        """Return the validated parameters unchanged."""
        return params

    def _refine_response(self, response: LLMContext.Response) -> LLMContext.Response:
        """Return the response as received; it is passed through without reshaping."""
        return response
|
||||
@@ -0,0 +1,109 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.response_types import LocalPOIs
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
LocalPOIsDescriptionHeaders,
|
||||
LocalPOIsDescriptionParams,
|
||||
LocalPOIsHeaders,
|
||||
LocalPOIsParams,
|
||||
)
|
||||
|
||||
|
||||
DayOpeningHours = LocalPOIs.DayOpeningHours
|
||||
OpeningHours = LocalPOIs.OpeningHours
|
||||
LocationResult = LocalPOIs.LocationResult
|
||||
LocalPOIsResponse = LocalPOIs.Response
|
||||
|
||||
|
||||
def _flatten_slots(slots: list[DayOpeningHours]) -> list[dict[str, str]]:
    """Reduce opening-hour slots to ``{"day", "opens", "closes"}`` dicts.

    ``day`` is the slot's ``full_name`` in lower case; ``opens`` and
    ``closes`` are copied as-is. Every slot must contain all three keys.
    """
    simplified: list[dict[str, str]] = []
    for entry in slots:
        day_name = entry["full_name"].lower()
        simplified.append(
            {"day": day_name, "opens": entry["opens"], "closes": entry["closes"]}
        )
    return simplified
|
||||
|
||||
|
||||
def _simplify_opening_hours(result: LocationResult) -> list[dict[str, str]] | None:
    """Flatten a result's ``opening_hours`` into {day, opens, closes} dicts.

    Slots from ``current_day`` come first, followed by the slots of every
    entry in ``days``. Returns None when there are no opening hours or no
    slots at all.
    """
    hours = result.get("opening_hours")
    if not hours:
        return None

    # Collect the slot groups in output order, skipping absent/empty ones.
    slot_groups = []
    today = hours.get("current_day")
    if today:
        slot_groups.append(today)
    week = hours.get("days")
    if week:
        slot_groups.extend(week)

    flattened: list[dict[str, str]] = [
        entry for group in slot_groups for entry in _flatten_slots(group)
    ]
    return flattened if flattened else None
|
||||
|
||||
|
||||
class BraveLocalPOIsTool(BraveSearchToolBase):
    """A tool that retrieves local POIs using the Brave Search API."""

    name: str = "Brave Local POIs"
    args_schema: type[BaseModel] = LocalPOIsParams
    header_schema: type[BaseModel] = LocalPOIsHeaders
    description: str = (
        "A tool that retrieves local POIs using the Brave Search API. "
        "Results are returned as structured JSON data."
    )
    search_url: str = "https://api.search.brave.com/res/v1/local/pois"

    def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
        """Return the validated parameters unchanged."""
        return params

    def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
        """Reduce each POI to title, url, description, address, contact and hours.

        ``contact`` is the telephone number when present, otherwise the email
        address, otherwise None. ``opening_hours`` is a flat list of
        {day, opens, closes} dicts, or None.
        """
        refined: list[dict[str, Any]] = []
        for result in response.get("results") or []:
            # These values are nullable: a key that is present with None would
            # defeat a `.get(key, {})` default, so coerce None to {} explicitly.
            address = result.get("postal_address") or {}
            contact = result.get("contact") or {}
            refined.append(
                {
                    "title": result.get("title"),
                    "url": result.get("url"),
                    "description": result.get("description"),
                    "address": address.get("displayAddress"),
                    "contact": contact.get("telephone")
                    or contact.get("email")
                    or None,
                    "opening_hours": _simplify_opening_hours(result),
                }
            )
        return refined
|
||||
|
||||
|
||||
class BraveLocalPOIsDescriptionTool(BraveSearchToolBase):
    """Tool that fetches AI-generated descriptions for local POIs from the Brave Search API."""

    name: str = "Brave Local POI Descriptions"
    args_schema: type[BaseModel] = LocalPOIsDescriptionParams
    header_schema: type[BaseModel] = LocalPOIsDescriptionHeaders
    description: str = (
        "A tool that retrieves AI-generated descriptions for local POIs using the Brave Search API. "
        "Results are returned as structured JSON data."
    )
    search_url: str = "https://api.search.brave.com/res/v1/local/descriptions"

    def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
        """Return the validated parameters unchanged."""
        return params

    def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
        """Keep only the ``id`` and ``description`` of each result."""
        kept_keys = ("id", "description")
        condensed: list[dict[str, Any]] = []
        for item in response.get("results", []):
            condensed.append({key: item.get(key) for key in kept_keys})
        return condensed
|
||||
@@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
NewsSearchHeaders,
|
||||
NewsSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveNewsSearchTool(BraveSearchToolBase):
    """News search tool backed by the Brave Search API news endpoint."""

    name: str = "Brave News Search"
    args_schema: type[BaseModel] = NewsSearchParams
    header_schema: type[BaseModel] = NewsSearchHeaders

    description: str = (
        "A tool that performs news searches using the Brave Search API. "
        "Results are returned as structured JSON data."
    )

    search_url: str = "https://api.search.brave.com/res/v1/news/search"

    def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
        """Return the validated parameters unchanged."""
        return params

    def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
        """Keep only the ``url``, ``title`` and ``description`` of each result."""
        kept_keys = ("url", "title", "description")
        condensed: list[dict[str, Any]] = []
        for item in response.get("results", []):
            condensed.append({key: item.get(key) for key in kept_keys})
        return condensed
|
||||
@@ -10,17 +10,13 @@ from pydantic import BaseModel, Field
|
||||
from pydantic.types import StringConstraints
|
||||
import requests
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
|
||||
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _save_results_to_file(content: str) -> None:
|
||||
"""Saves the search results to a file."""
|
||||
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
|
||||
with open(filename, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
FreshnessPreset = Literal["pd", "pw", "pm", "py"]
|
||||
FreshnessRange = Annotated[
|
||||
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
|
||||
@@ -29,51 +25,6 @@ Freshness = FreshnessPreset | FreshnessRange
|
||||
SafeSearch = Literal["off", "moderate", "strict"]
|
||||
|
||||
|
||||
class BraveSearchToolSchema(BaseModel):
|
||||
"""Input for BraveSearchTool"""
|
||||
|
||||
query: str = Field(..., description="Search query to perform")
|
||||
country: str | None = Field(
|
||||
default=None,
|
||||
description="Country code for geo-targeting (e.g., 'US', 'BR').",
|
||||
)
|
||||
search_language: str | None = Field(
|
||||
default=None,
|
||||
description="Language code for the search results (e.g., 'en', 'es').",
|
||||
)
|
||||
count: int | None = Field(
|
||||
default=None,
|
||||
description="The maximum number of results to return. Actual number may be less.",
|
||||
)
|
||||
offset: int | None = Field(
|
||||
default=None, description="Skip the first N result sets/pages. Max is 9."
|
||||
)
|
||||
safesearch: SafeSearch | None = Field(
|
||||
default=None,
|
||||
description="Filter out explicit content. Options: off/moderate/strict",
|
||||
)
|
||||
spellcheck: bool | None = Field(
|
||||
default=None,
|
||||
description="Attempt to correct spelling errors in the search query.",
|
||||
)
|
||||
freshness: Freshness | None = Field(
|
||||
default=None,
|
||||
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
|
||||
)
|
||||
text_decorations: bool | None = Field(
|
||||
default=None,
|
||||
description="Include markup to highlight search terms in the results.",
|
||||
)
|
||||
extra_snippets: bool | None = Field(
|
||||
default=None,
|
||||
description="Include up to 5 text snippets for each page if possible.",
|
||||
)
|
||||
operators: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to apply search operators (e.g., site:example.com).",
|
||||
)
|
||||
|
||||
|
||||
# TODO: Extend support to additional endpoints (e.g., /images, /news, etc.)
|
||||
class BraveSearchTool(BaseTool):
|
||||
"""A tool that performs web searches using the Brave Search API."""
|
||||
@@ -83,7 +34,7 @@ class BraveSearchTool(BaseTool):
|
||||
"A tool that performs web searches using the Brave Search API. "
|
||||
"Results are returned as structured JSON data."
|
||||
)
|
||||
args_schema: type[BaseModel] = BraveSearchToolSchema
|
||||
args_schema: type[BaseModel] = WebSearchParams
|
||||
search_url: str = "https://api.search.brave.com/res/v1/web/search"
|
||||
n_results: int = 10
|
||||
save_file: bool = False
|
||||
@@ -120,8 +71,8 @@ class BraveSearchTool(BaseTool):
|
||||
|
||||
# Construct and send the request
|
||||
try:
|
||||
# Maintain both "search_query" and "query" for backwards compatibility
|
||||
query = kwargs.get("search_query") or kwargs.get("query")
|
||||
# Fallback to "query" or "search_query" for backwards compatibility
|
||||
query = kwargs.get("q") or kwargs.get("query") or kwargs.get("search_query")
|
||||
if not query:
|
||||
raise ValueError("Query is required")
|
||||
|
||||
@@ -130,8 +81,11 @@ class BraveSearchTool(BaseTool):
|
||||
if country := kwargs.get("country"):
|
||||
payload["country"] = country
|
||||
|
||||
if search_language := kwargs.get("search_language"):
|
||||
payload["search_language"] = search_language
|
||||
# Fallback to "search_language" for backwards compatibility
|
||||
if search_lang := kwargs.get("search_lang") or kwargs.get(
|
||||
"search_language"
|
||||
):
|
||||
payload["search_lang"] = search_lang
|
||||
|
||||
# Fallback to deprecated n_results parameter if no count is provided
|
||||
count = kwargs.get("count")
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
VideoSearchHeaders,
|
||||
VideoSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveVideoSearchTool(BraveSearchToolBase):
    """Video search tool backed by the Brave Search API videos endpoint."""

    name: str = "Brave Video Search"
    args_schema: type[BaseModel] = VideoSearchParams
    header_schema: type[BaseModel] = VideoSearchHeaders

    description: str = (
        "A tool that performs video searches using the Brave Search API. "
        "Results are returned as structured JSON data."
    )

    search_url: str = "https://api.search.brave.com/res/v1/videos/search"

    def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
        """Return the validated parameters unchanged."""
        return params

    def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
        """Keep only the ``url``, ``title`` and ``description`` of each result."""
        kept_keys = ("url", "title", "description")
        condensed: list[dict[str, Any]] = []
        for item in response.get("results", []):
            condensed.append({key: item.get(key) for key in kept_keys})
        return condensed
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
WebSearchHeaders,
|
||||
WebSearchParams,
|
||||
)
|
||||
|
||||
|
||||
class BraveWebSearchTool(BraveSearchToolBase):
    """Web search tool backed by the Brave Search API web endpoint."""

    name: str = "Brave Web Search"
    args_schema: type[BaseModel] = WebSearchParams
    header_schema: type[BaseModel] = WebSearchHeaders

    description: str = (
        "A tool that performs web searches using the Brave Search API. "
        "Results are returned as structured JSON data."
    )

    search_url: str = "https://api.search.brave.com/res/v1/web/search"

    def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
        """Return the validated parameters unchanged."""
        return params

    def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
        """Reduce each web result to its url, title and a list of snippets.

        ``snippets`` is the result's ``extra_snippets`` when non-empty;
        otherwise it holds the description alone, or is empty when there is
        no description either.
        """
        hits = response.get("web", {}).get("results", [])
        condensed: list[dict[str, Any]] = []
        for hit in hits:
            texts = hit.get("extra_snippets") or []
            if not texts:
                # No extra snippets: fall back to the description, if any.
                summary = hit.get("description")
                texts = [summary] if summary else texts
            entry = {
                "url": hit.get("url"),
                "title": hit.get("title"),
                "snippets": texts,
            }
            condensed.append(entry)
        return condensed
|
||||
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
|
||||
class LocalPOIs:
    """Namespace for the TypedDicts describing a Brave Local POIs response."""

    class PostalAddress(TypedDict, total=False):
        """Postal address of a location; every key is optional."""

        type: Literal["PostalAddress"]
        country: str
        postalCode: str
        streetAddress: str
        addressRegion: str
        addressLocality: str
        displayAddress: str

    class Contact(TypedDict, total=False):
        """Contact details of a location; every key is optional."""

        # NOTE(review): keys mirror what BraveLocalPOIsTool._refine_response
        # reads; the value types are assumed — confirm against the API docs.
        telephone: str
        email: str

    class DayOpeningHours(TypedDict):
        """One opening slot of one day; all four keys are required."""

        abbr_name: str
        full_name: str
        opens: str
        closes: str

    class OpeningHours(TypedDict, total=False):
        """Opening slots for the current day and, per day, for the week."""

        current_day: list[LocalPOIs.DayOpeningHours]
        days: list[list[LocalPOIs.DayOpeningHours]]

    class LocationResult(TypedDict, total=False):
        """A single location entry of the response; every key is optional."""

        provider_url: str
        title: str
        url: str
        id: str | None
        # `description` and `contact` are read by the POI tools, so they are
        # declared here to keep the type in step with its consumers.
        # NOTE(review): nullability is assumed — confirm against the API docs.
        description: str | None
        contact: LocalPOIs.Contact | None
        opening_hours: LocalPOIs.OpeningHours | None
        postal_address: LocalPOIs.PostalAddress | None

    class Response(TypedDict, total=False):
        """Top-level response body."""

        type: Literal["local_pois"]
        results: list[LocalPOIs.LocationResult]
|
||||
|
||||
|
||||
class LLMContext:
    # Namespace grouping the TypedDict shapes of an LLM-context payload.
    # Nested references such as ``LLMContext.Grounding`` rely on the module's
    # ``from __future__ import annotations`` (lazy annotations).

    class LLMContextItem(TypedDict, total=False):
        snippets: list[str]
        title: str
        url: str

    class LLMContextMapItem(TypedDict, total=False):
        name: str
        snippets: list[str]
        title: str
        url: str

    class LLMContextPOIItem(TypedDict, total=False):
        name: str
        snippets: list[str]
        title: str
        url: str

    class Grounding(TypedDict, total=False):
        # ``generic`` and ``map`` are lists; ``poi`` is a single item.
        generic: list[LLMContext.LLMContextItem]
        poi: LLMContext.LLMContextPOIItem
        map: list[LLMContext.LLMContextMapItem]

    class Sources(TypedDict, total=False):
        # No keys are declared for this shape.
        pass

    class Response(TypedDict, total=False):
        grounding: LLMContext.Grounding
        sources: LLMContext.Sources
|
||||
@@ -0,0 +1,525 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.types import StringConstraints
|
||||
|
||||
|
||||
# Common types
# Measurement system selector.
Units = Literal["metric", "imperial"]
# Explicit-content filter levels.
SafeSearch = Literal["off", "moderate", "strict"]
# Either a short window code or an explicit YYYY-MM-DDtoYYYY-MM-DD range.
Freshness = Literal["pd", "pw", "pm", "py"] | Annotated[
    str,
    StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$"),
]
# List of result sections that may be requested.
ResultFilter = list[
    Literal[
        "discussions",
        "faq",
        "infobox",
        "news",
        "query",
        "summarizer",
        "videos",
        "web",
        "locations",
    ]
]
|
||||
|
||||
|
||||
class LLMContextParams(BaseModel):
    """Parameters for Brave LLM Context endpoint."""

    # Required query text (1-400 characters).
    q: str = Field(
        min_length=1,
        max_length=400,
        description="Search query to perform",
    )

    # Optional localisation filters.
    country: str | None = Field(
        default=None,
        pattern=r"^[A-Z]{2}$",
        description="Country code for geo-targeting (e.g., 'US', 'BR').",
    )
    search_lang: str | None = Field(
        default=None,
        pattern=r"^[a-z]{2}$",
        description="Language code for the search results (e.g., 'en', 'es').",
    )

    # Optional size limits; each is range-checked by pydantic.
    count: int | None = Field(
        default=None,
        ge=1,
        le=50,
        description="The maximum number of results to return. Actual number may be less.",
    )
    maximum_number_of_urls: int | None = Field(
        default=None,
        ge=1,
        le=50,
        description="The maximum number of URLs to include in the context.",
    )
    maximum_number_of_tokens: int | None = Field(
        default=None,
        ge=1,
        le=32768,
        description="The approximate maximum number of tokens to include in the context.",
    )
    maximum_number_of_snippets: int | None = Field(
        default=None,
        ge=1,
        le=100,
        description="The maximum number of different snippets to include in the context.",
    )
    context_threshold_mode: (
        Literal["disabled", "strict", "lenient", "balanced"] | None
    ) = Field(
        default=None,
        description="The mode to use for the context thresholding.",
    )
    maximum_number_of_tokens_per_url: int | None = Field(
        default=None,
        ge=1,
        le=8192,
        description="The maximum number of tokens to include for each URL in the context.",
    )
    maximum_number_of_snippets_per_url: int | None = Field(
        default=None,
        ge=1,
        le=100,
        description="The maximum number of snippets to include per URL.",
    )

    # Optional re-ranking and local-recall switches.
    goggles: str | list[str] | None = Field(
        default=None,
        description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
    )
    enable_local: bool | None = Field(
        default=None,
        description="Whether to enable local recall. Not setting this value means auto-detect and uses local recall if any of the localization headers are provided.",
    )
|
||||
|
||||
|
||||
class WebSearchParams(BaseModel):
    """Parameters for Brave Web Search endpoint."""

    # Required query text (1-400 characters).
    q: str = Field(
        min_length=1,
        max_length=400,
        description="Search query to perform",
    )

    # Optional localisation filters.
    country: str | None = Field(
        default=None,
        pattern=r"^[A-Z]{2}$",
        description="Country code for geo-targeting (e.g., 'US', 'BR').",
    )
    # NOTE(review): the two-letter pattern rejects region-qualified codes
    # (e.g. 'pt-br') — confirm against Brave's supported language list.
    search_lang: str | None = Field(
        default=None,
        pattern=r"^[a-z]{2}$",
        description="Language code for the search results (e.g., 'en', 'es').",
    )
    ui_lang: str | None = Field(
        default=None,
        pattern=r"^[a-z]{2}-[A-Z]{2}$",
        description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
    )

    # Optional paging.
    count: int | None = Field(
        default=None,
        ge=1,
        le=20,
        description="The maximum number of results to return. Actual number may be less.",
    )
    offset: int | None = Field(
        default=None,
        ge=0,
        le=9,
        description="Skip the first N result sets/pages. Max is 9.",
    )

    # Uses the shared ``SafeSearch`` alias (same literal values as before) so
    # this model stays consistent with ``VideoSearchParams``.
    safesearch: SafeSearch | None = Field(
        default=None,
        description="Filter out explicit content. Options: off/moderate/strict",
    )
    spellcheck: bool | None = Field(
        default=None,
        description="Attempt to correct spelling errors in the search query.",
    )
    freshness: Freshness | None = Field(
        default=None,
        description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
    )
    text_decorations: bool | None = Field(
        default=None,
        description="Include markup to highlight search terms in the results.",
    )
    extra_snippets: bool | None = Field(
        default=None,
        description="Include up to 5 text snippets for each page if possible.",
    )
    result_filter: ResultFilter | None = Field(
        default=None,
        description="Filter the results by type. Options: discussions/faq/infobox/news/query/summarizer/videos/web/locations. Note: The `count` parameter is applied only to the `web` results.",
    )
    units: Units | None = Field(
        default=None,
        description="The units to use for the results. Options: metric/imperial",
    )
    goggles: str | list[str] | None = Field(
        default=None,
        description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
    )
    summary: bool | None = Field(
        default=None,
        description="Whether to generate a summarizer ID for the results.",
    )
    enable_rich_callback: bool | None = Field(
        default=None,
        description="Whether to enable rich callbacks for the results. Requires Pro level subscription.",
    )
    include_fetch_metadata: bool | None = Field(
        default=None,
        description="Whether to include fetch metadata (e.g., last fetch time) in the results.",
    )
    operators: bool | None = Field(
        default=None,
        description="Whether to apply search operators (e.g., site:example.com).",
    )
|
||||
|
||||
|
||||
class LocalPOIsParams(BaseModel):
    """Parameters for Brave Local POIs endpoint."""

    # Required: between 1 and 20 POI identifiers.
    ids: list[str] = Field(
        min_length=1,
        max_length=20,
        description="List of POI IDs to retrieve. Maximum of 20. IDs are valid for 8 hours.",
    )

    # Optional presentation settings.
    search_lang: str | None = Field(
        default=None,
        pattern=r"^[a-z]{2}$",
        description="Language code for the search results (e.g., 'en', 'es').",
    )
    ui_lang: str | None = Field(
        default=None,
        pattern=r"^[a-z]{2}-[A-Z]{2}$",
        description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
    )
    units: Units | None = Field(
        default=None,
        description="The units to use for the results. Options: metric/imperial",
    )
|
||||
|
||||
|
||||
class LocalPOIsDescriptionParams(BaseModel):
    """Parameters for Brave Local POI Descriptions endpoint."""

    # The only parameter: between 1 and 20 POI identifiers.
    ids: list[str] = Field(
        min_length=1,
        max_length=20,
        description="List of POI IDs to retrieve. Maximum of 20. IDs are valid for 8 hours.",
    )
|
||||
|
||||
|
||||
class ImageSearchParams(BaseModel):
    """Parameters for Brave Image Search endpoint."""

    # Required query text (1-400 characters).
    q: str = Field(
        min_length=1,
        max_length=400,
        description="Search query to perform",
    )

    # Optional localisation filters.
    search_lang: str | None = Field(
        default=None,
        pattern=r"^[a-z]{2}$",
        description="Language code for the search results (e.g., 'en', 'es').",
    )
    country: str | None = Field(
        default=None,
        pattern=r"^[A-Z]{2}$",
        description="Country code for geo-targeting (e.g., 'US', 'BR').",
    )

    # Only two levels are accepted here (no "moderate").
    safesearch: Literal["off", "strict"] | None = Field(
        default=None,
        description="Filter out explicit content. Default is strict.",
    )
    count: int | None = Field(
        default=None,
        ge=1,
        le=200,
        description="The maximum number of results to return.",
    )
    spellcheck: bool | None = Field(
        default=None,
        description="Attempt to correct spelling errors in the search query.",
    )
|
||||
|
||||
|
||||
class VideoSearchParams(BaseModel):
    """Parameters for Brave Video Search endpoint."""

    # Required query text (1-400 characters).
    q: str = Field(
        min_length=1,
        max_length=400,
        description="Search query to perform",
    )

    # Optional localisation filters.
    search_lang: str | None = Field(
        default=None,
        pattern=r"^[a-z]{2}$",
        description="Language code for the search results (e.g., 'en', 'es').",
    )
    ui_lang: str | None = Field(
        default=None,
        pattern=r"^[a-z]{2}-[A-Z]{2}$",
        description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
    )
    country: str | None = Field(
        default=None,
        pattern=r"^[A-Z]{2}$",
        description="Country code for geo-targeting (e.g., 'US', 'BR').",
    )

    safesearch: SafeSearch | None = Field(
        default=None,
        description="Filter out explicit content. Options: off/moderate/strict",
    )

    # Optional paging.
    count: int | None = Field(
        default=None,
        ge=1,
        le=50,
        description="The maximum number of results to return.",
    )
    offset: int | None = Field(
        default=None,
        ge=0,
        le=9,
        description="Skip the first N result sets/pages. Max is 9.",
    )

    spellcheck: bool | None = Field(
        default=None,
        description="Attempt to correct spelling errors in the search query.",
    )
    freshness: Freshness | None = Field(
        default=None,
        description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
    )
    include_fetch_metadata: bool | None = Field(
        default=None,
        description="Whether to include fetch metadata (e.g., last fetch time) in the results.",
    )
    operators: bool | None = Field(
        default=None,
        description="Whether to apply search operators (e.g., site:example.com).",
    )
|
||||
|
||||
|
||||
class NewsSearchParams(BaseModel):
    """Parameters for Brave News Search endpoint."""

    # Required query text (1-400 characters).
    q: str = Field(
        min_length=1,
        max_length=400,
        description="Search query to perform",
    )

    # Optional localisation filters.
    search_lang: str | None = Field(
        default=None,
        pattern=r"^[a-z]{2}$",
        description="Language code for the search results (e.g., 'en', 'es').",
    )
    ui_lang: str | None = Field(
        default=None,
        pattern=r"^[a-z]{2}-[A-Z]{2}$",
        description="Language code for the user interface (e.g., 'en-US', 'es-AR').",
    )
    country: str | None = Field(
        default=None,
        pattern=r"^[A-Z]{2}$",
        description="Country code for geo-targeting (e.g., 'US', 'BR').",
    )

    # Uses the shared ``SafeSearch`` alias (same literal values as before) so
    # this model stays consistent with ``VideoSearchParams``.
    safesearch: SafeSearch | None = Field(
        default=None,
        description="Filter out explicit content. Options: off/moderate/strict",
    )

    # Optional paging.
    count: int | None = Field(
        default=None,
        ge=1,
        le=50,
        description="The maximum number of results to return.",
    )
    offset: int | None = Field(
        default=None,
        ge=0,
        le=9,
        description="Skip the first N result sets/pages. Max is 9.",
    )

    spellcheck: bool | None = Field(
        default=None,
        description="Attempt to correct spelling errors in the search query.",
    )
    freshness: Freshness | None = Field(
        default=None,
        description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
    )
    extra_snippets: bool | None = Field(
        default=None,
        description="Include up to 5 text snippets for each page if possible.",
    )
    goggles: str | list[str] | None = Field(
        default=None,
        description="Goggles act as a custom re-ranking mechanism. Goggle source or URLs.",
    )
    include_fetch_metadata: bool | None = Field(
        default=None,
        description="Whether to include fetch metadata in the results.",
    )
    operators: bool | None = Field(
        default=None,
        description="Whether to apply search operators (e.g., site:example.com).",
    )
|
||||
|
||||
|
||||
class BaseSearchHeaders(BaseModel):
    """Common headers for Brave Search endpoints."""

    # Each hyphenated alias is the wire name of the header; the Python
    # attribute uses underscores.

    # The only required header.
    x_subscription_token: str = Field(
        alias="x-subscription-token",
        description="API key for Brave Search",
    )
    api_version: str | None = Field(
        default=None,
        alias="api-version",
        pattern=r"^\d{4}-\d{2}-\d{2}$",  # YYYY-MM-DD
        description="API version to use. Default is latest available.",
    )
    accept: Literal["application/json"] | Literal["*/*"] | None = Field(
        default=None,
        description="Accept header for the request.",
    )
    cache_control: Literal["no-cache"] | None = Field(
        default=None,
        alias="cache-control",
        description="Cache control header for the request.",
    )
    user_agent: str | None = Field(
        default=None,
        alias="user-agent",
        description="User agent for the request.",
    )
|
||||
|
||||
|
||||
class LLMContextHeaders(BaseSearchHeaders):
    """Headers for Brave LLM Context endpoint."""

    # Optional location hints; coordinates are range-checked.
    x_loc_lat: float | None = Field(
        default=None,
        alias="x-loc-lat",
        ge=-90.0,
        le=90.0,
        description="Latitude of the user's location.",
    )
    x_loc_long: float | None = Field(
        default=None,
        alias="x-loc-long",
        ge=-180.0,
        le=180.0,
        description="Longitude of the user's location.",
    )
    x_loc_city: str | None = Field(
        default=None,
        alias="x-loc-city",
        description="City of the user's location.",
    )
    x_loc_state: str | None = Field(
        default=None,
        alias="x-loc-state",
        description="State of the user's location.",
    )
    x_loc_state_name: str | None = Field(
        default=None,
        alias="x-loc-state-name",
        description="Name of the state of the user's location.",
    )
    x_loc_country: str | None = Field(
        default=None,
        alias="x-loc-country",
        description="The ISO 3166-1 alpha-2 country code of the user's location.",
    )
|
||||
|
||||
|
||||
class LocalPOIsHeaders(BaseSearchHeaders):
    """Headers for Brave Local POIs endpoint."""

    # Optional, range-checked coordinates.
    x_loc_lat: float | None = Field(
        default=None,
        alias="x-loc-lat",
        ge=-90.0,
        le=90.0,
        description="Latitude of the user's location.",
    )
    x_loc_long: float | None = Field(
        default=None,
        alias="x-loc-long",
        ge=-180.0,
        le=180.0,
        description="Longitude of the user's location.",
    )
|
||||
|
||||
|
||||
class LocalPOIsDescriptionHeaders(BaseSearchHeaders):
    """Headers for Brave Local POI Descriptions endpoint."""

    # Declares no fields of its own; only the inherited ones apply.
|
||||
|
||||
|
||||
class VideoSearchHeaders(BaseSearchHeaders):
    """Headers for Brave Video Search endpoint."""

    # Declares no fields of its own; only the inherited ones apply.
|
||||
|
||||
|
||||
class ImageSearchHeaders(BaseSearchHeaders):
    """Headers for Brave Image Search endpoint."""

    # Declares no fields of its own; only the inherited ones apply.
|
||||
|
||||
|
||||
class NewsSearchHeaders(BaseSearchHeaders):
    """Headers for Brave News Search endpoint."""

    # Declares no fields of its own; only the inherited ones apply.
|
||||
|
||||
|
||||
class WebSearchHeaders(BaseSearchHeaders):
    """Headers for Brave Web Search endpoint."""

    # Optional location hints; coordinates are range-checked.
    x_loc_lat: float | None = Field(
        default=None,
        alias="x-loc-lat",
        ge=-90.0,
        le=90.0,
        description="Latitude of the user's location.",
    )
    x_loc_long: float | None = Field(
        default=None,
        alias="x-loc-long",
        ge=-180.0,
        le=180.0,
        description="Longitude of the user's location.",
    )
    x_loc_timezone: str | None = Field(
        default=None,
        alias="x-loc-timezone",
        description="Timezone of the user's location.",
    )
    x_loc_city: str | None = Field(
        default=None,
        alias="x-loc-city",
        description="City of the user's location.",
    )
    x_loc_state: str | None = Field(
        default=None,
        alias="x-loc-state",
        description="State of the user's location.",
    )
    x_loc_state_name: str | None = Field(
        default=None,
        alias="x-loc-state-name",
        description="Name of the state of the user's location.",
    )
    x_loc_country: str | None = Field(
        default=None,
        alias="x-loc-country",
        description="The ISO 3166-1 alpha-2 country code of the user's location.",
    )
    x_loc_postal_code: str | None = Field(
        default=None,
        alias="x-loc-postal-code",
        description="The postal code of the user's location.",
    )
|
||||
@@ -1,80 +1,777 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests as requests_lib
|
||||
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
|
||||
from crewai_tools.tools.brave_search_tool.brave_web_tool import BraveWebSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_image_tool import BraveImageSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_news_tool import BraveNewsSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_video_tool import BraveVideoSearchTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_llm_context_tool import (
|
||||
BraveLLMContextTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.brave_local_pois_tool import (
|
||||
BraveLocalPOIsTool,
|
||||
BraveLocalPOIsDescriptionTool,
|
||||
)
|
||||
from crewai_tools.tools.brave_search_tool.schemas import (
|
||||
WebSearchParams,
|
||||
WebSearchHeaders,
|
||||
ImageSearchParams,
|
||||
ImageSearchHeaders,
|
||||
NewsSearchParams,
|
||||
NewsSearchHeaders,
|
||||
VideoSearchParams,
|
||||
VideoSearchHeaders,
|
||||
LLMContextParams,
|
||||
LLMContextHeaders,
|
||||
LocalPOIsParams,
|
||||
LocalPOIsHeaders,
|
||||
LocalPOIsDescriptionParams,
|
||||
LocalPOIsDescriptionHeaders,
|
||||
)
|
||||
|
||||
|
||||
def _mock_response(
    status_code: int = 200,
    json_data: dict | None = None,
    headers: dict | None = None,
    text: str = "",
) -> MagicMock:
    """Create a ``MagicMock`` constrained to the ``requests.Response`` interface.

    ``text`` falls back to ``str(json_data)`` when empty, ``headers`` falls
    back to ``{}``, and ``json()`` returns ``json_data`` (``{}`` when None).
    """
    if text:
        body_text = text
    elif json_data:
        body_text = str(json_data)
    else:
        body_text = ""

    mock = MagicMock(spec=requests_lib.Response)
    mock.status_code = status_code
    mock.ok = status_code >= 200 and status_code < 400
    mock.url = "https://api.search.brave.com/res/v1/web/search?q=test"
    mock.text = body_text
    mock.headers = headers or {}
    mock.json.return_value = {} if json_data is None else json_data
    return mock
|
||||
|
||||
|
||||
# Fixtures
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _brave_env_and_rate_limit():
    """Set BRAVE_API_KEY for every test. Rate limiting is per-instance (each tool starts with a fresh clock)."""
    env_patch = patch.dict(os.environ, {"BRAVE_API_KEY": "test-api-key"})
    # Using the patcher as a context manager restores the environment on exit.
    with env_patch:
        yield
|
||||
|
||||
|
||||
@pytest.fixture
def brave_tool():
    """Provide a ``BraveSearchTool`` limited to two results."""
    tool = BraveSearchTool(n_results=2)
    return tool
|
||||
@pytest.fixture
def web_tool():
    """Provide a default-configured ``BraveWebSearchTool``.

    Registered as a fixture because the tests in this module request
    ``web_tool`` as a parameter.
    """
    return BraveWebSearchTool()
|
||||
|
||||
|
||||
def test_brave_tool_initialization():
    """A tool built without arguments has ``n_results`` equal to 10."""
    default_tool = BraveSearchTool()
    assert default_tool.n_results == 10
|
||||
@pytest.fixture
def image_tool():
    """Provide a default-configured ``BraveImageSearchTool``."""
    tool = BraveImageSearchTool()
    return tool
|
||||
|
||||
|
||||
@pytest.fixture
def news_tool():
    """Provide a default-configured ``BraveNewsSearchTool``."""
    tool = BraveNewsSearchTool()
    return tool
|
||||
|
||||
|
||||
@pytest.fixture
def video_tool():
    """Provide a default-configured ``BraveVideoSearchTool``."""
    tool = BraveVideoSearchTool()
    return tool
|
||||
|
||||
|
||||
# Initialization
|
||||
|
||||
# Tool classes fed to the parametrized instantiation tests.
ALL_TOOL_CLASSES = [
    BraveWebSearchTool,
    BraveImageSearchTool,
    BraveNewsSearchTool,
    BraveVideoSearchTool,
    BraveLLMContextTool,
    BraveLocalPOIsTool,
    BraveLocalPOIsDescriptionTool,
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool_cls", ALL_TOOL_CLASSES)
def test_instantiation_with_env_var(tool_cls):
    """With no explicit key, the tool's api_key equals the BRAVE_API_KEY value."""
    instance = tool_cls()
    assert instance.api_key == "test-api-key"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool_cls", ALL_TOOL_CLASSES)
def test_instantiation_with_explicit_key(tool_cls):
    """A key passed to the constructor wins over the environment value."""
    instance = tool_cls(api_key="explicit-key")
    assert instance.api_key == "explicit-key"
|
||||
|
||||
|
||||
def test_missing_api_key_raises():
    """Constructing a tool with an empty environment raises ValueError."""
    empty_env = patch.dict(os.environ, {}, clear=True)
    expect_error = pytest.raises(ValueError, match="BRAVE_API_KEY")
    with empty_env, expect_error:
        BraveWebSearchTool()
|
||||
|
||||
|
||||
def test_default_attributes():
    """A tool built without arguments exposes the expected default settings."""
    default_tool = BraveWebSearchTool()

    # Boolean flags are off by default.
    assert default_tool.save_file is False
    assert default_tool.raw is False

    # Numeric defaults.
    assert default_tool.n_results == 10
    assert default_tool._timeout == 30
    assert default_tool._requests_per_second == 1.0
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_brave_tool_search(mock_get, brave_tool):
|
||||
mock_response = {
|
||||
def test_custom_constructor_args():
    """Constructor keyword arguments are stored on the tool."""
    overrides = {
        "save_file": True,
        "timeout": 60,
        "n_results": 5,
        "requests_per_second": 0.5,
        "raw": True,
    }
    custom_tool = BraveWebSearchTool(**overrides)

    assert custom_tool.save_file is True
    assert custom_tool.raw is True
    # timeout / requests_per_second land on private attributes.
    assert custom_tool._timeout == 60
    assert custom_tool._requests_per_second == 0.5
    assert custom_tool.n_results == 5
|
||||
|
||||
|
||||
# Headers
|
||||
|
||||
|
||||
def test_default_headers():
    """Default headers carry the API key and a JSON accept value."""
    default_headers = BraveWebSearchTool().headers
    assert default_headers["x-subscription-token"] == "test-api-key"
    assert default_headers["accept"] == "application/json"
|
||||
|
||||
|
||||
def test_set_headers_merges_and_normalizes():
    """set_headers lower-cases the new key and keeps the existing token header."""
    subject = BraveWebSearchTool()
    subject.set_headers({"Cache-Control": "no-cache"})

    merged = subject.headers
    assert merged["cache-control"] == "no-cache"
    assert merged["x-subscription-token"] == "test-api-key"
|
||||
|
||||
|
||||
def test_set_headers_returns_self_for_chaining():
    """set_headers returns the same tool instance it was called on."""
    subject = BraveWebSearchTool()
    returned = subject.set_headers({"Cache-Control": "no-cache"})
    assert returned is subject
|
||||
|
||||
|
||||
def test_invalid_header_value_raises():
    """An Accept value of "text/xml" is rejected with ValueError."""
    subject = BraveImageSearchTool()
    bad_headers = {"Accept": "text/xml"}
    with pytest.raises(ValueError, match="Invalid headers"):
        subject.set_headers(bad_headers)
|
||||
|
||||
|
||||
# Endpoint & Schema Wiring
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "tool_cls, expected_url, expected_params, expected_headers",
    [
        (
            BraveWebSearchTool,
            "https://api.search.brave.com/res/v1/web/search",
            WebSearchParams,
            WebSearchHeaders,
        ),
        (
            BraveImageSearchTool,
            "https://api.search.brave.com/res/v1/images/search",
            ImageSearchParams,
            ImageSearchHeaders,
        ),
        (
            BraveNewsSearchTool,
            "https://api.search.brave.com/res/v1/news/search",
            NewsSearchParams,
            NewsSearchHeaders,
        ),
        (
            BraveVideoSearchTool,
            "https://api.search.brave.com/res/v1/videos/search",
            VideoSearchParams,
            VideoSearchHeaders,
        ),
        (
            BraveLLMContextTool,
            "https://api.search.brave.com/res/v1/llm/context",
            LLMContextParams,
            LLMContextHeaders,
        ),
        (
            BraveLocalPOIsTool,
            "https://api.search.brave.com/res/v1/local/pois",
            LocalPOIsParams,
            LocalPOIsHeaders,
        ),
        (
            BraveLocalPOIsDescriptionTool,
            "https://api.search.brave.com/res/v1/local/descriptions",
            LocalPOIsDescriptionParams,
            LocalPOIsDescriptionHeaders,
        ),
    ],
)
def test_tool_wiring(tool_cls, expected_url, expected_params, expected_headers):
    """Each tool class points at its endpoint URL and its own schema classes."""
    instance = tool_cls()

    assert instance.search_url == expected_url
    # Identity, not equality: the exact schema class must be wired in.
    assert instance.args_schema is expected_params
    assert instance.header_schema is expected_headers
|
||||
|
||||
|
||||
# Payload Refinement (e.g., `query` -> `q`, `count` fallback, param pass-through)
|
||||
|
||||
|
||||
def test_web_refine_request_payload_passes_all_params(web_tool):
    """Web params survive both refinement steps, with ``query`` renamed to ``q``."""
    call_args = {
        "query": "test",
        "country": "US",
        "search_lang": "en",
        "count": 5,
        "offset": 2,
        "safesearch": "moderate",
        "freshness": "pw",
    }
    refined_params = web_tool._refine_request_payload(
        web_tool._common_payload_refinement(call_args)
    )

    assert "query" not in refined_params
    expected = {
        "q": "test",
        "count": 5,
        "country": "US",
        "search_lang": "en",
        "offset": 2,
        "safesearch": "moderate",
        "freshness": "pw",
    }
    for key, value in expected.items():
        assert refined_params[key] == value
|
||||
|
||||
|
||||
def test_image_refine_request_payload_passes_all_params(image_tool):
    """Image params survive both refinement steps, with ``query`` renamed to ``q``."""
    call_args = {
        "query": "cat photos",
        "country": "US",
        "search_lang": "en",
        "safesearch": "strict",
        "count": 50,
        "spellcheck": True,
    }
    refined_params = image_tool._refine_request_payload(
        image_tool._common_payload_refinement(call_args)
    )

    assert "query" not in refined_params
    expected = {
        "q": "cat photos",
        "country": "US",
        "safesearch": "strict",
        "count": 50,
    }
    for key, value in expected.items():
        assert refined_params[key] == value
    # Identity check: the boolean must stay a real ``True``.
    assert refined_params["spellcheck"] is True
|
||||
|
||||
|
||||
def test_news_refine_request_payload_passes_all_params(news_tool):
    """News params survive both refinement steps, with ``query`` renamed to ``q``."""
    call_args = {
        "query": "breaking news",
        "country": "US",
        "count": 10,
        "offset": 1,
        "freshness": "pd",
        "extra_snippets": True,
    }
    refined_params = news_tool._refine_request_payload(
        news_tool._common_payload_refinement(call_args)
    )

    assert "query" not in refined_params
    expected = {
        "q": "breaking news",
        "country": "US",
        "offset": 1,
        "freshness": "pd",
    }
    for key, value in expected.items():
        assert refined_params[key] == value
    # Identity check: the boolean must stay a real ``True``.
    assert refined_params["extra_snippets"] is True
|
||||
|
||||
|
||||
def test_video_refine_request_payload_passes_all_params(video_tool):
    """Video params survive both refinement steps, with ``query`` renamed to ``q``."""
    call_args = {
        "query": "tutorial",
        "country": "US",
        "count": 25,
        "offset": 0,
        "safesearch": "strict",
        "freshness": "pm",
    }
    refined_params = video_tool._refine_request_payload(
        video_tool._common_payload_refinement(call_args)
    )

    assert "query" not in refined_params
    expected = {
        "q": "tutorial",
        "country": "US",
        "offset": 0,  # a zero offset must be kept
        "freshness": "pm",
    }
    for key, value in expected.items():
        assert refined_params[key] == value
|
||||
|
||||
|
||||
def test_legacy_constructor_params_flow_into_query_params():
    """Constructor ``n_results`` / ``country`` fill ``count`` / ``country``
    when the call itself supplies neither."""
    legacy_tool = BraveWebSearchTool(n_results=3, country="BR")

    refined = legacy_tool._common_payload_refinement({"query": "test"})

    assert (refined["count"], refined["country"]) == (3, "BR")
|
||||
|
||||
|
||||
def test_legacy_constructor_params_do_not_override_explicit_query_params():
    """Call-time ``count`` / ``country`` win over the constructor values."""
    legacy_tool = BraveWebSearchTool(n_results=3, country="BR")
    explicit = {"query": "test", "count": 10, "country": "US"}

    refined = legacy_tool._common_payload_refinement(explicit)

    assert (refined["count"], refined["country"]) == (10, "US")
|
||||
|
||||
|
||||
def test_refine_request_payload_passes_multiple_goggles_as_multiple_params(web_tool):
    """A list of goggles comes back from refinement as the same list."""
    payload = {"query": "test", "goggles": ["goggle1", "goggle2"]}

    refined = web_tool._refine_request_payload(payload)

    assert refined["goggles"] == ["goggle1", "goggle2"]
|
||||
|
||||
|
||||
# Null-like / empty value stripping
|
||||
#
|
||||
# crewAI's ensure_all_properties_required (pydantic_schema_utils.py) marks
|
||||
# every schema property as required for OpenAI strict-mode compatibility.
|
||||
# Because optional Brave API parameters look required to the LLM, it fills
|
||||
# them with placeholder junk — None, "", "null", or []. The test below
|
||||
# verifies that _common_payload_refinement strips these from optional fields.
|
||||
|
||||
|
||||
def test_common_refinement_strips_null_like_values(web_tool):
    """_common_payload_refinement drops optional keys with None / '' / 'null' / []."""
    noisy_args = {
        "query": "test",
        "country": "US",
        "search_lang": "",
        "freshness": "null",
        "count": 5,
        "goggles": [],
    }
    params = web_tool._common_payload_refinement(noisy_args)

    # Real values are kept.
    assert params["q"] == "test"
    assert params["country"] == "US"
    assert params["count"] == 5

    # Placeholder values are removed.
    for dropped_key in ("search_lang", "freshness", "goggles"):
        assert dropped_key not in params
|
||||
|
||||
|
||||
# End-to-End _run() with Mocked HTTP Response
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_web_search_end_to_end(mock_get, web_tool):
    """In raw mode ``_run`` sends ``q`` plus the token header and returns the JSON body."""
    payload = {"web": {"results": [{"title": "R", "url": "http://r.co"}]}}
    mock_get.return_value = _mock_response(json_data=payload)
    web_tool.raw = True

    outcome = web_tool._run(query="test")

    mock_get.assert_called_once()
    sent = mock_get.call_args.kwargs
    assert sent["params"]["q"] == "test"
    assert sent["headers"]["x-subscription-token"] == "test-api-key"
    assert outcome == payload
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
def test_image_search_end_to_end(mock_get, image_tool):
    """In raw mode the image tool returns the mocked JSON body unchanged."""
    payload = {"results": [{"url": "http://img.co/a.jpg"}]}
    mock_get.return_value = _mock_response(json_data=payload)
    image_tool.raw = True

    outcome = image_tool._run(query="cats")

    assert outcome == payload
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_news_search_end_to_end(mock_get, news_tool):
|
||||
news_tool.raw = True
|
||||
data = {"results": [{"title": "News", "url": "http://n.co"}]}
|
||||
mock_get.return_value = _mock_response(json_data=data)
|
||||
|
||||
assert news_tool._run(query="headlines") == data
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_video_search_end_to_end(mock_get, video_tool):
|
||||
video_tool.raw = True
|
||||
data = {"results": [{"title": "Vid", "url": "http://v.co"}]}
|
||||
mock_get.return_value = _mock_response(json_data=data)
|
||||
|
||||
assert video_tool._run(query="python tutorial") == data
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_raw_false_calls_refine_response(mock_get, web_tool):
|
||||
"""With raw=False (the default), _refine_response transforms the API response."""
|
||||
api_response = {
|
||||
"web": {
|
||||
"results": [
|
||||
{
|
||||
"title": "Test Title",
|
||||
"url": "http://test.com",
|
||||
"description": "Test Description",
|
||||
"title": "CrewAI",
|
||||
"url": "https://crewai.com",
|
||||
"description": "AI agent framework",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
mock_get.return_value = _mock_response(json_data=api_response)
|
||||
|
||||
result = brave_tool.run(query="test")
|
||||
data = json.loads(result)
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
assert data[0]["title"] == "Test Title"
|
||||
assert data[0]["url"] == "http://test.com"
|
||||
assert web_tool.raw is False
|
||||
result = web_tool._run(query="crewai")
|
||||
|
||||
# The web tool's _refine_response extracts and reshapes results.
|
||||
# The key assertion: we should NOT get back the raw API envelope.
|
||||
assert result != api_response
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_brave_tool(mock_get):
|
||||
mock_response = {
|
||||
"web": {
|
||||
"results": [
|
||||
{
|
||||
"title": "Brave Browser",
|
||||
"url": "https://brave.com",
|
||||
"description": "Brave Browser description",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
|
||||
tool = BraveSearchTool(n_results=2)
|
||||
result = tool.run(query="Brave Browser")
|
||||
assert result is not None
|
||||
|
||||
# Parse JSON so we can examine the structure
|
||||
data = json.loads(result)
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
|
||||
# First item should have expected fields: title, url, and description
|
||||
first = data[0]
|
||||
assert "title" in first
|
||||
assert first["title"] == "Brave Browser"
|
||||
assert "url" in first
|
||||
assert first["url"] == "https://brave.com"
|
||||
assert "description" in first
|
||||
assert first["description"] == "Brave Browser description"
|
||||
# Backward Compatibility & Legacy Parameter Support
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_brave_tool()
|
||||
test_brave_tool_initialization()
|
||||
# test_brave_tool_search(brave_tool)
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_positional_query_argument(mock_get, web_tool):
|
||||
"""tool.run('my query') works as a positional argument."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
web_tool._run("positional test")
|
||||
|
||||
assert mock_get.call_args.kwargs["params"]["q"] == "positional test"
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_search_query_backward_compat(mock_get, web_tool):
|
||||
"""The legacy 'search_query' param is mapped to 'query'."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
web_tool._run(search_query="legacy test")
|
||||
|
||||
assert mock_get.call_args.kwargs["params"]["q"] == "legacy test"
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base._save_results_to_file")
|
||||
def test_save_file_called_when_enabled(mock_save, mock_get):
|
||||
mock_get.return_value = _mock_response(json_data={"results": []})
|
||||
|
||||
tool = BraveWebSearchTool(save_file=True)
|
||||
tool._run(query="test")
|
||||
|
||||
mock_save.assert_called_once()
|
||||
|
||||
|
||||
# Error Handling
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_connection_error_raises_runtime_error(mock_get, web_tool):
|
||||
mock_get.side_effect = requests_lib.exceptions.ConnectionError("refused")
|
||||
with pytest.raises(RuntimeError, match="Brave Search API connection failed"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_timeout_raises_runtime_error(mock_get, web_tool):
|
||||
mock_get.side_effect = requests_lib.exceptions.Timeout("timed out")
|
||||
with pytest.raises(RuntimeError, match="timed out"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_invalid_params_raises_value_error(mock_get, web_tool):
|
||||
"""count=999 exceeds WebSearchParams.count le=20."""
|
||||
with pytest.raises(ValueError, match="Invalid parameters"):
|
||||
web_tool._run(query="test", count=999)
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_4xx_error_raises_with_api_detail(mock_get, web_tool):
|
||||
"""A 422 with a structured error body includes code and detail in the message."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=422,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "abc-123",
|
||||
"status": 422,
|
||||
"code": "OPTION_NOT_IN_PLAN",
|
||||
"detail": "extra_snippets requires a Pro plan",
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="OPTION_NOT_IN_PLAN") as exc_info:
|
||||
web_tool._run(query="test")
|
||||
assert "extra_snippets requires a Pro plan" in str(exc_info.value)
|
||||
assert "HTTP 422" in str(exc_info.value)
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_auth_error_raises_immediately(mock_get, web_tool):
|
||||
"""A 401 with SUBSCRIPTION_TOKEN_INVALID is not retried."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=401,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "xyz",
|
||||
"status": 401,
|
||||
"code": "SUBSCRIPTION_TOKEN_INVALID",
|
||||
"detail": "The subscription token is invalid",
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="SUBSCRIPTION_TOKEN_INVALID"):
|
||||
web_tool._run(query="test")
|
||||
# Should NOT have retried — only one call.
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_quota_limited_429_raises_immediately(mock_get, web_tool):
|
||||
"""A 429 with QUOTA_LIMITED is NOT retried — quota exhaustion is terminal."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=429,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "ql-1",
|
||||
"status": 429,
|
||||
"code": "QUOTA_LIMITED",
|
||||
"detail": "Monthly quota exceeded",
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="QUOTA_LIMITED") as exc_info:
|
||||
web_tool._run(query="test")
|
||||
assert "Monthly quota exceeded" in str(exc_info.value)
|
||||
# Terminal — only one HTTP call, no retries.
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_usage_limit_exceeded_429_raises_immediately(mock_get, web_tool):
|
||||
"""USAGE_LIMIT_EXCEEDED is also non-retryable, just like QUOTA_LIMITED."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=429,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "ule-1",
|
||||
"status": 429,
|
||||
"code": "USAGE_LIMIT_EXCEEDED",
|
||||
}
|
||||
},
|
||||
text="usage limit exceeded",
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="USAGE_LIMIT_EXCEEDED"):
|
||||
web_tool._run(query="test")
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_error_body_is_fully_included_in_message(mock_get, web_tool):
|
||||
"""The full JSON error body is included in the RuntimeError message."""
|
||||
mock_get.return_value = _mock_response(
|
||||
status_code=429,
|
||||
json_data={
|
||||
"error": {
|
||||
"id": "x",
|
||||
"status": 429,
|
||||
"code": "QUOTA_LIMITED",
|
||||
"detail": "Exceeded",
|
||||
"meta": {"plan": "free", "limit": 1000},
|
||||
}
|
||||
},
|
||||
)
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
web_tool._run(query="test")
|
||||
msg = str(exc_info.value)
|
||||
assert "HTTP 429" in msg
|
||||
assert "QUOTA_LIMITED" in msg
|
||||
assert "free" in msg
|
||||
assert "1000" in msg
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_error_without_json_body_falls_back_to_text(mock_get, web_tool):
|
||||
"""When the error response isn't valid JSON, resp.text is used as the detail."""
|
||||
resp = _mock_response(status_code=500, text="Internal Server Error")
|
||||
resp.json.side_effect = ValueError("No JSON")
|
||||
mock_get.return_value = resp
|
||||
|
||||
with pytest.raises(RuntimeError, match="Internal Server Error"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
def test_invalid_json_on_success_raises_runtime_error(mock_get, web_tool):
|
||||
"""A 200 OK with a non-JSON body raises RuntimeError."""
|
||||
resp = _mock_response(status_code=200)
|
||||
resp.json.side_effect = ValueError("Expecting value")
|
||||
mock_get.return_value = resp
|
||||
|
||||
with pytest.raises(RuntimeError, match="invalid JSON"):
|
||||
web_tool._run(query="test")
|
||||
|
||||
|
||||
# Rate Limiting
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_sleeps_when_too_fast(mock_time, mock_get, web_tool):
|
||||
"""Back-to-back calls within the interval trigger a sleep."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
# Simulate: last request was at t=100, "now" is t=100.2 (only 0.2s elapsed).
|
||||
# With default 1 req/s the min interval is 1.0s, so it should sleep ~0.8s.
|
||||
mock_time.time.return_value = 100.2
|
||||
web_tool._last_request_time = 100.0
|
||||
|
||||
web_tool._run(query="test")
|
||||
|
||||
mock_time.sleep.assert_called_once()
|
||||
sleep_duration = mock_time.sleep.call_args[0][0]
|
||||
assert 0.7 < sleep_duration < 0.9 # approximately 0.8s
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_skips_sleep_when_enough_time_passed(mock_time, mock_get, web_tool):
|
||||
"""No sleep when the elapsed time already exceeds the interval."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
# Last request was at t=100, "now" is t=102 (2s elapsed > 1s interval).
|
||||
mock_time.time.return_value = 102.0
|
||||
web_tool._last_request_time = 100.0
|
||||
|
||||
web_tool._run(query="test")
|
||||
|
||||
mock_time.sleep.assert_not_called()
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_disabled_when_zero(mock_time, mock_get, web_tool):
|
||||
"""requests_per_second=0 disables rate limiting entirely."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
web_tool._last_request_time = 100.0
|
||||
mock_time.time.return_value = 100.0 # same instant
|
||||
|
||||
web_tool._run(query="test")
|
||||
|
||||
mock_time.sleep.assert_not_called()
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_rate_limit_per_instance_independent(mock_time, mock_get, web_tool, image_tool):
|
||||
"""Each instance has its own rate-limit clock; a request on one does not delay the other."""
|
||||
mock_get.return_value = _mock_response(json_data={})
|
||||
|
||||
# Web tool fires at t=100 (its clock goes 0 -> 100).
|
||||
mock_time.time.return_value = 100.0
|
||||
web_tool._run(query="test")
|
||||
|
||||
# Image tool fires at t=100.3. Its clock is still 0 (separate instance), so
|
||||
# next_allowed = 1.0 and 100.3 > 1.0 — no sleep. Total process rate can be sum of instance limits.
|
||||
mock_time.time.return_value = 100.3
|
||||
image_tool._run(query="cats")
|
||||
|
||||
mock_time.sleep.assert_not_called()
|
||||
|
||||
|
||||
# Retry Behavior
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_429_rate_limited_retries_then_succeeds(mock_time, mock_get, web_tool):
|
||||
"""A transient RATE_LIMITED 429 is retried; success on the second attempt."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_429 = _mock_response(
|
||||
status_code=429,
|
||||
json_data={"error": {"id": "r", "status": 429, "code": "RATE_LIMITED"}},
|
||||
headers={"Retry-After": "2"},
|
||||
)
|
||||
resp_200 = _mock_response(status_code=200, json_data={"web": {"results": []}})
|
||||
mock_get.side_effect = [resp_429, resp_200]
|
||||
|
||||
web_tool.raw = True
|
||||
result = web_tool._run(query="test")
|
||||
|
||||
assert result == {"web": {"results": []}}
|
||||
assert mock_get.call_count == 2
|
||||
# Slept for the Retry-After value.
|
||||
retry_sleeps = [c for c in mock_time.sleep.call_args_list if c[0][0] == 2.0]
|
||||
assert len(retry_sleeps) == 1
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_5xx_is_retried(mock_time, mock_get, web_tool):
|
||||
"""A 502 server error is retried; success on the second attempt."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_502 = _mock_response(status_code=502, text="Bad Gateway")
|
||||
resp_502.json.side_effect = ValueError("no json")
|
||||
resp_200 = _mock_response(status_code=200, json_data={"web": {"results": []}})
|
||||
mock_get.side_effect = [resp_502, resp_200]
|
||||
|
||||
web_tool.raw = True
|
||||
result = web_tool._run(query="test")
|
||||
|
||||
assert result == {"web": {"results": []}}
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_429_rate_limited_exhausts_retries(mock_time, mock_get, web_tool):
|
||||
"""Persistent RATE_LIMITED 429s exhaust retries and raise RuntimeError."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_429 = _mock_response(
|
||||
status_code=429,
|
||||
json_data={"error": {"id": "r", "status": 429, "code": "RATE_LIMITED"}},
|
||||
)
|
||||
mock_get.return_value = resp_429
|
||||
|
||||
with pytest.raises(RuntimeError, match="RATE_LIMITED"):
|
||||
web_tool._run(query="test")
|
||||
# 3 attempts (default _max_retries).
|
||||
assert mock_get.call_count == 3
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.requests.get")
|
||||
@patch("crewai_tools.tools.brave_search_tool.base.time")
|
||||
def test_retry_uses_exponential_backoff_when_no_retry_after(
|
||||
mock_time, mock_get, web_tool
|
||||
):
|
||||
"""Without Retry-After, backoff is 2^attempt (1s, 2s, ...)."""
|
||||
mock_time.time.return_value = 200.0
|
||||
|
||||
resp_503 = _mock_response(status_code=503, text="Service Unavailable")
|
||||
resp_503.json.side_effect = ValueError("no json")
|
||||
resp_200 = _mock_response(status_code=200, json_data={"ok": True})
|
||||
mock_get.side_effect = [resp_503, resp_503, resp_200]
|
||||
|
||||
web_tool.raw = True
|
||||
web_tool._run(query="test")
|
||||
|
||||
# Two retries: attempt 0 → sleep(1.0), attempt 1 → sleep(2.0).
|
||||
retry_sleeps = [c[0][0] for c in mock_time.sleep.call_args_list]
|
||||
assert 1.0 in retry_sleeps
|
||||
assert 2.0 in retry_sleeps
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -53,7 +53,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.10.1",
|
||||
"crewai-tools==1.10.2a1",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
|
||||
@@ -41,7 +41,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.10.1"
|
||||
__version__ = "1.10.2a1"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
|
||||
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
|
||||
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$"
|
||||
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#[\w-]+)?$"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -30,12 +30,9 @@ class CrewAgentExecutorMixin:
|
||||
memory = getattr(self.agent, "memory", None) or (
|
||||
getattr(self.crew, "_memory", None) if self.crew else None
|
||||
)
|
||||
if memory is None or not self.task or getattr(memory, "_read_only", False):
|
||||
if memory is None or not self.task or memory.read_only:
|
||||
return
|
||||
if (
|
||||
f"Action: {sanitize_tool_name('Delegate work to coworker')}"
|
||||
in output.text
|
||||
):
|
||||
if f"Action: {sanitize_tool_name('Delegate work to coworker')}" in output.text:
|
||||
return
|
||||
try:
|
||||
raw = (
|
||||
@@ -48,6 +45,4 @@ class CrewAgentExecutorMixin:
|
||||
if extracted:
|
||||
memory.remember_many(extracted, agent_role=self.agent.role)
|
||||
except Exception as e:
|
||||
self.agent._logger.log(
|
||||
"error", f"Failed to save to memory: {e}"
|
||||
)
|
||||
self.agent._logger.log("error", f"Failed to save to memory: {e}")
|
||||
|
||||
@@ -9,6 +9,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
@@ -755,6 +756,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = {
|
||||
pool.submit(
|
||||
contextvars.copy_context().run,
|
||||
self._execute_single_native_tool_call,
|
||||
call_id=call_id,
|
||||
func_name=func_name,
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.1"
|
||||
"crewai[tools]==1.10.2a1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.1"
|
||||
"crewai[tools]==1.10.2a1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.10.1"
|
||||
"crewai[tools]==1.10.2a1"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
import json
|
||||
@@ -797,7 +798,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
max_workers = min(8, len(runnable_tool_calls))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
future_to_idx = {
|
||||
pool.submit(self._execute_single_native_tool_call, tool_call): idx
|
||||
pool.submit(contextvars.copy_context().run, self._execute_single_native_tool_call, tool_call): idx
|
||||
for idx, tool_call in enumerate(runnable_tool_calls)
|
||||
}
|
||||
ordered_results: list[dict[str, Any] | None] = [None] * len(
|
||||
|
||||
@@ -497,6 +497,50 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._list)
|
||||
|
||||
def index(self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None) -> int: # type: ignore[override]
|
||||
if stop is None:
|
||||
return self._list.index(value, start)
|
||||
return self._list.index(value, start, stop)
|
||||
|
||||
def count(self, value: T) -> int:
|
||||
return self._list.count(value)
|
||||
|
||||
def sort(self, *, key: Any = None, reverse: bool = False) -> None:
|
||||
with self._lock:
|
||||
self._list.sort(key=key, reverse=reverse)
|
||||
|
||||
def reverse(self) -> None:
|
||||
with self._lock:
|
||||
self._list.reverse()
|
||||
|
||||
def copy(self) -> list[T]:
|
||||
return self._list.copy()
|
||||
|
||||
def __add__(self, other: list[T]) -> list[T]:
|
||||
return self._list + other
|
||||
|
||||
def __radd__(self, other: list[T]) -> list[T]:
|
||||
return other + self._list
|
||||
|
||||
def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]:
|
||||
with self._lock:
|
||||
self._list += list(other)
|
||||
return self
|
||||
|
||||
def __mul__(self, n: SupportsIndex) -> list[T]:
|
||||
return self._list * n
|
||||
|
||||
def __rmul__(self, n: SupportsIndex) -> list[T]:
|
||||
return self._list * n
|
||||
|
||||
def __imul__(self, n: SupportsIndex) -> LockedListProxy[T]:
|
||||
with self._lock:
|
||||
self._list *= n
|
||||
return self
|
||||
|
||||
def __reversed__(self) -> Iterator[T]:
|
||||
return reversed(self._list)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare based on the underlying list contents."""
|
||||
if isinstance(other, LockedListProxy):
|
||||
@@ -579,6 +623,23 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._dict)
|
||||
|
||||
def copy(self) -> dict[str, T]:
|
||||
return self._dict.copy()
|
||||
|
||||
def __or__(self, other: dict[str, T]) -> dict[str, T]:
|
||||
return self._dict | other
|
||||
|
||||
def __ror__(self, other: dict[str, T]) -> dict[str, T]:
|
||||
return other | self._dict
|
||||
|
||||
def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]:
|
||||
with self._lock:
|
||||
self._dict |= other
|
||||
return self
|
||||
|
||||
def __reversed__(self) -> Iterator[str]:
|
||||
return reversed(self._dict)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare based on the underlying dict contents."""
|
||||
if isinstance(other, LockedDictProxy):
|
||||
@@ -620,6 +681,10 @@ class StateProxy(Generic[T]):
|
||||
if name in ("_proxy_state", "_proxy_lock"):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
if isinstance(value, LockedListProxy):
|
||||
value = value._list
|
||||
elif isinstance(value, LockedDictProxy):
|
||||
value = value._dict
|
||||
with object.__getattribute__(self, "_proxy_lock"):
|
||||
setattr(object.__getattribute__(self, "_proxy_state"), name, value)
|
||||
|
||||
|
||||
@@ -408,7 +408,7 @@ def human_feedback(
|
||||
emit=list(emit) if emit else None,
|
||||
default_outcome=default_outcome,
|
||||
metadata=metadata or {},
|
||||
llm=llm if isinstance(llm, str) else None,
|
||||
llm=llm if isinstance(llm, str) else getattr(llm, "model", None),
|
||||
)
|
||||
|
||||
# Determine effective provider:
|
||||
|
||||
@@ -72,7 +72,8 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
|
||||
def init_db(self) -> None:
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
# Main state table
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -136,7 +137,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
@@ -163,7 +164,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT state_json
|
||||
@@ -213,7 +214,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
self.save_state(flow_uuid, context.method_name, state_data)
|
||||
|
||||
# Save pending feedback context
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
# Use INSERT OR REPLACE to handle re-triggering feedback on same flow
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -248,7 +249,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
# Import here to avoid circular imports
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT state_json, context_json
|
||||
@@ -272,7 +273,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
DELETE FROM pending_feedback
|
||||
|
||||
@@ -600,7 +600,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
def _save_to_memory(self, output_text: str) -> None:
|
||||
"""Extract discrete memories from the run and remember each. No-op if _memory is None or read-only."""
|
||||
if self._memory is None or getattr(self._memory, "_read_only", False):
|
||||
if self._memory is None or self._memory.read_only:
|
||||
return
|
||||
input_str = self._get_last_user_content() or "User request"
|
||||
try:
|
||||
|
||||
@@ -22,7 +22,12 @@ if TYPE_CHECKING:
|
||||
|
||||
try:
|
||||
from anthropic import Anthropic, AsyncAnthropic, transform_schema
|
||||
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
|
||||
from anthropic.types import (
|
||||
Message,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolUseBlock,
|
||||
)
|
||||
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
|
||||
import httpx
|
||||
except ImportError:
|
||||
@@ -31,6 +36,11 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
|
||||
TOOL_SEARCH_TOOL_TYPES: Final[tuple[str, ...]] = (
|
||||
"tool_search_tool_regex_20251119",
|
||||
"tool_search_tool_bm25_20251119",
|
||||
)
|
||||
|
||||
ANTHROPIC_FILES_API_BETA: Final = "files-api-2025-04-14"
|
||||
ANTHROPIC_STRUCTURED_OUTPUTS_BETA: Final = "structured-outputs-2025-11-13"
|
||||
|
||||
@@ -117,6 +127,22 @@ class AnthropicThinkingConfig(BaseModel):
|
||||
budget_tokens: int | None = None
|
||||
|
||||
|
||||
class AnthropicToolSearchConfig(BaseModel):
|
||||
"""Configuration for Anthropic's server-side tool search.
|
||||
|
||||
When enabled, tools marked with defer_loading=True are not loaded into
|
||||
context immediately. Instead, Claude uses the tool search tool to
|
||||
dynamically discover and load relevant tools on-demand.
|
||||
|
||||
Attributes:
|
||||
type: The tool search variant to use.
|
||||
- "regex": Claude constructs regex patterns to search tool names/descriptions.
|
||||
- "bm25": Claude uses natural language queries to search tools.
|
||||
"""
|
||||
|
||||
type: Literal["regex", "bm25"] = "bm25"
|
||||
|
||||
|
||||
class AnthropicCompletion(BaseLLM):
|
||||
"""Anthropic native completion implementation.
|
||||
|
||||
@@ -140,6 +166,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
tool_search: AnthropicToolSearchConfig | bool | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
@@ -159,6 +186,10 @@ class AnthropicCompletion(BaseLLM):
|
||||
interceptor: HTTP interceptor for modifying requests/responses at transport level.
|
||||
response_format: Pydantic model for structured output. When provided, responses
|
||||
will be validated against this model schema.
|
||||
tool_search: Enable Anthropic's server-side tool search. When True, uses "bm25"
|
||||
variant by default. Pass an AnthropicToolSearchConfig to choose "regex" or
|
||||
"bm25". When enabled, tools are automatically marked with defer_loading=True
|
||||
and a tool search tool is injected into the tools list.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -190,6 +221,13 @@ class AnthropicCompletion(BaseLLM):
|
||||
self.thinking = thinking
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
self.response_format = response_format
|
||||
# Tool search config
|
||||
if tool_search is True:
|
||||
self.tool_search = AnthropicToolSearchConfig()
|
||||
elif isinstance(tool_search, AnthropicToolSearchConfig):
|
||||
self.tool_search = tool_search
|
||||
else:
|
||||
self.tool_search = None
|
||||
# Model-specific settings
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = True
|
||||
@@ -432,10 +470,23 @@ class AnthropicCompletion(BaseLLM):
|
||||
# Handle tools for Claude 3+
|
||||
if tools and self.supports_tools:
|
||||
converted_tools = self._convert_tools_for_interference(tools)
|
||||
|
||||
# When tool_search is enabled and there are 2+ regular tools,
|
||||
# inject the search tool and mark regular tools with defer_loading.
|
||||
# With only 1 tool there's nothing to search — skip tool search
|
||||
# entirely so the normal forced tool_choice optimisation still works.
|
||||
regular_tools = [
|
||||
t
|
||||
for t in converted_tools
|
||||
if t.get("type", "") not in TOOL_SEARCH_TOOL_TYPES
|
||||
]
|
||||
if self.tool_search is not None and len(regular_tools) >= 2:
|
||||
converted_tools = self._apply_tool_search(converted_tools)
|
||||
|
||||
params["tools"] = converted_tools
|
||||
|
||||
if available_functions and len(converted_tools) == 1:
|
||||
tool_name = converted_tools[0].get("name")
|
||||
if available_functions and len(regular_tools) == 1:
|
||||
tool_name = regular_tools[0].get("name")
|
||||
if tool_name and tool_name in available_functions:
|
||||
params["tool_choice"] = {"type": "tool", "name": tool_name}
|
||||
|
||||
@@ -454,6 +505,12 @@ class AnthropicCompletion(BaseLLM):
|
||||
anthropic_tools = []
|
||||
|
||||
for tool in tools:
|
||||
# Pass through tool search tool definitions unchanged
|
||||
tool_type = tool.get("type", "")
|
||||
if tool_type in TOOL_SEARCH_TOOL_TYPES:
|
||||
anthropic_tools.append(tool)
|
||||
continue
|
||||
|
||||
if "input_schema" in tool and "name" in tool and "description" in tool:
|
||||
anthropic_tools.append(tool)
|
||||
continue
|
||||
@@ -466,15 +523,15 @@ class AnthropicCompletion(BaseLLM):
|
||||
logging.error(f"Error converting tool to Anthropic format: {e}")
|
||||
raise e
|
||||
|
||||
anthropic_tool = {
|
||||
anthropic_tool: dict[str, Any] = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
}
|
||||
|
||||
if parameters and isinstance(parameters, dict):
|
||||
anthropic_tool["input_schema"] = parameters # type: ignore[assignment]
|
||||
anthropic_tool["input_schema"] = parameters
|
||||
else:
|
||||
anthropic_tool["input_schema"] = { # type: ignore[assignment]
|
||||
anthropic_tool["input_schema"] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
@@ -484,6 +541,55 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
return anthropic_tools
|
||||
|
||||
def _apply_tool_search(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Add the tool search tool and defer loading of every regular tool.

    Does nothing when ``self.tool_search`` is ``None``. Otherwise a search
    tool definition (regex or bm25, chosen by ``self.tool_search.type``) is
    placed at the front of the list unless the caller already supplied one,
    and each regular tool that does not set ``defer_loading`` itself is
    copied with ``defer_loading=True``.

    Args:
        tools: Converted tool definitions in Anthropic format.

    Returns:
        The same ``tools`` object when tool search is disabled; otherwise a
        new list. Search tool entries and tools that already carry a
        ``defer_loading`` key are passed through untouched; the remaining
        tools are shallow copies, so the input dicts are never mutated.
    """
    search_config = self.tool_search
    if search_config is None:
        return tools

    def _is_search_tool(entry: dict[str, Any]) -> bool:
        # A tool counts as a search tool purely by its "type" identifier.
        return entry.get("type", "") in TOOL_SEARCH_TOOL_TYPES

    prepared: list[dict[str, Any]] = []

    # Only inject a search tool if the user did not pass one manually.
    if not any(_is_search_tool(entry) for entry in tools):
        variant = search_config.type
        # Config variant -> versioned API type identifier.
        api_types = {
            "regex": "tool_search_tool_regex_20251119",
            "bm25": "tool_search_tool_bm25_20251119",
        }
        prepared.append(
            {
                "type": api_types[variant],
                # Naming convention: tool_search_tool_{variant}
                "name": f"tool_search_tool_{variant}",
            }
        )

    for entry in tools:
        if _is_search_tool(entry) or "defer_loading" in entry:
            # Search tools and explicit defer_loading choices stay as given.
            prepared.append(entry)
        else:
            prepared.append({**entry, "defer_loading": True})

    return prepared
|
||||
|
||||
def _extract_thinking_block(
|
||||
self, content_block: Any
|
||||
) -> ThinkingBlock | dict[str, Any] | None:
|
||||
|
||||
@@ -1781,6 +1781,7 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
converse_messages: list[LLMMessage] = []
|
||||
system_message: str | None = None
|
||||
pending_tool_results: list[dict[str, Any]] = []
|
||||
|
||||
for message in formatted_messages:
|
||||
role = message.get("role")
|
||||
@@ -1794,53 +1795,56 @@ class BedrockCompletion(BaseLLM):
|
||||
system_message += f"\n\n{content}"
|
||||
else:
|
||||
system_message = cast(str, content)
|
||||
elif role == "assistant" and tool_calls:
|
||||
# Convert OpenAI-style tool_calls to Bedrock toolUse format
|
||||
bedrock_content = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
tool_use_block = {
|
||||
"toolUse": {
|
||||
"toolUseId": tc.get("id", f"call_{id(tc)}"),
|
||||
"name": func.get("name", ""),
|
||||
"input": func.get("arguments", {})
|
||||
if isinstance(func.get("arguments"), dict)
|
||||
else json.loads(func.get("arguments", "{}") or "{}"),
|
||||
}
|
||||
}
|
||||
bedrock_content.append(tool_use_block)
|
||||
converse_messages.append(
|
||||
{"role": "assistant", "content": bedrock_content}
|
||||
)
|
||||
elif role == "tool":
|
||||
if not tool_call_id:
|
||||
raise ValueError("Tool message missing required tool_call_id")
|
||||
converse_messages.append(
|
||||
pending_tool_results.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": tool_call_id,
|
||||
"content": [
|
||||
{"text": str(content) if content else ""}
|
||||
],
|
||||
}
|
||||
}
|
||||
],
|
||||
"toolResult": {
|
||||
"toolUseId": tool_call_id,
|
||||
"content": [{"text": str(content) if content else ""}],
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Convert to Converse API format with proper content structure
|
||||
if isinstance(content, list):
|
||||
# Already formatted as multimodal content blocks
|
||||
converse_messages.append({"role": role, "content": content})
|
||||
else:
|
||||
# String content - wrap in text block
|
||||
text_content = content if content else ""
|
||||
if pending_tool_results:
|
||||
converse_messages.append(
|
||||
{"role": role, "content": [{"text": text_content}]}
|
||||
{"role": "user", "content": pending_tool_results}
|
||||
)
|
||||
pending_tool_results = []
|
||||
|
||||
if role == "assistant" and tool_calls:
|
||||
# Convert OpenAI-style tool_calls to Bedrock toolUse format
|
||||
bedrock_content = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
tool_use_block = {
|
||||
"toolUse": {
|
||||
"toolUseId": tc.get("id", f"call_{id(tc)}"),
|
||||
"name": func.get("name", ""),
|
||||
"input": func.get("arguments", {})
|
||||
if isinstance(func.get("arguments"), dict)
|
||||
else json.loads(func.get("arguments", "{}") or "{}"),
|
||||
}
|
||||
}
|
||||
bedrock_content.append(tool_use_block)
|
||||
converse_messages.append(
|
||||
{"role": "assistant", "content": bedrock_content}
|
||||
)
|
||||
else:
|
||||
# Convert to Converse API format with proper content structure
|
||||
if isinstance(content, list):
|
||||
# Already formatted as multimodal content blocks
|
||||
converse_messages.append({"role": role, "content": content})
|
||||
else:
|
||||
# String content - wrap in text block
|
||||
text_content = content if content else ""
|
||||
converse_messages.append(
|
||||
{"role": role, "content": [{"text": text_content}]}
|
||||
)
|
||||
|
||||
if pending_tool_results:
|
||||
converse_messages.append({"role": "user", "content": pending_tool_results})
|
||||
|
||||
# CRITICAL: Handle model-specific conversation requirements
|
||||
# Cohere and some other models require conversation to end with user message
|
||||
|
||||
@@ -22,6 +22,7 @@ from crewai.mcp.config import (
|
||||
MCPServerSSE,
|
||||
MCPServerStdio,
|
||||
)
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
@@ -74,10 +75,9 @@ class MCPToolResolver:
|
||||
elif isinstance(mcp_config, str):
|
||||
amp_refs.append(self._parse_amp_ref(mcp_config))
|
||||
else:
|
||||
tools, client = self._resolve_native(mcp_config)
|
||||
tools, clients = self._resolve_native(mcp_config)
|
||||
all_tools.extend(tools)
|
||||
if client:
|
||||
self._clients.append(client)
|
||||
self._clients.extend(clients)
|
||||
|
||||
if amp_refs:
|
||||
tools, clients = self._resolve_amp(amp_refs)
|
||||
@@ -131,7 +131,7 @@ class MCPToolResolver:
|
||||
all_tools: list[BaseTool] = []
|
||||
all_clients: list[Any] = []
|
||||
|
||||
resolved_cache: dict[str, tuple[list[BaseTool], Any | None]] = {}
|
||||
resolved_cache: dict[str, tuple[list[BaseTool], list[Any]]] = {}
|
||||
|
||||
for slug in unique_slugs:
|
||||
config_dict = amp_configs_map.get(slug)
|
||||
@@ -149,10 +149,9 @@ class MCPToolResolver:
|
||||
mcp_server_config = self._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
try:
|
||||
tools, client = self._resolve_native(mcp_server_config)
|
||||
resolved_cache[slug] = (tools, client)
|
||||
if client:
|
||||
all_clients.append(client)
|
||||
tools, clients = self._resolve_native(mcp_server_config)
|
||||
resolved_cache[slug] = (tools, clients)
|
||||
all_clients.extend(clients)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -170,8 +169,9 @@ class MCPToolResolver:
|
||||
|
||||
slug_tools, _ = cached
|
||||
if specific_tool:
|
||||
sanitized = sanitize_tool_name(specific_tool)
|
||||
all_tools.extend(
|
||||
t for t in slug_tools if t.name.endswith(f"_{specific_tool}")
|
||||
t for t in slug_tools if t.name.endswith(f"_{sanitized}")
|
||||
)
|
||||
else:
|
||||
all_tools.extend(slug_tools)
|
||||
@@ -198,7 +198,6 @@ class MCPToolResolver:
|
||||
|
||||
plus_api = PlusAPI(api_key=get_platform_integration_token())
|
||||
response = plus_api.get_mcp_configs(slugs)
|
||||
|
||||
if response.status_code == 200:
|
||||
configs: dict[str, dict[str, Any]] = response.json().get("configs", {})
|
||||
return configs
|
||||
@@ -218,6 +217,7 @@ class MCPToolResolver:
|
||||
|
||||
def _resolve_external(self, mcp_ref: str) -> list[BaseTool]:
|
||||
"""Resolve an HTTPS MCP server URL into tools."""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||
|
||||
if "#" in mcp_ref:
|
||||
@@ -227,6 +227,7 @@ class MCPToolResolver:
|
||||
|
||||
server_params = {"url": server_url}
|
||||
server_name = self._extract_server_name(server_url)
|
||||
sanitized_specific_tool = sanitize_tool_name(specific_tool) if specific_tool else None
|
||||
|
||||
try:
|
||||
tool_schemas = self._get_mcp_tool_schemas(server_params)
|
||||
@@ -239,7 +240,7 @@ class MCPToolResolver:
|
||||
|
||||
tools = []
|
||||
for tool_name, schema in tool_schemas.items():
|
||||
if specific_tool and tool_name != specific_tool:
|
||||
if sanitized_specific_tool and tool_name != sanitized_specific_tool:
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -271,14 +272,16 @@ class MCPToolResolver:
|
||||
)
|
||||
return []
|
||||
|
||||
def _resolve_native(
|
||||
self, mcp_config: MCPServerConfig
|
||||
) -> tuple[list[BaseTool], Any | None]:
|
||||
"""Resolve an ``MCPServerConfig`` into tools, returning the client for cleanup."""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
@staticmethod
|
||||
def _create_transport(
|
||||
mcp_config: MCPServerConfig,
|
||||
) -> tuple[StdioTransport | HTTPTransport | SSETransport, str]:
|
||||
"""Create a fresh transport instance from an MCP server config.
|
||||
|
||||
transport: StdioTransport | HTTPTransport | SSETransport
|
||||
Returns a ``(transport, server_name)`` tuple. Each call produces an
|
||||
independent transport so that parallel tool executions never share
|
||||
state.
|
||||
"""
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
@@ -292,38 +295,54 @@ class MCPToolResolver:
|
||||
headers=mcp_config.headers,
|
||||
streamable=mcp_config.streamable,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
server_name = MCPToolResolver._extract_server_name(mcp_config.url)
|
||||
elif isinstance(mcp_config, MCPServerSSE):
|
||||
transport = SSETransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
server_name = MCPToolResolver._extract_server_name(mcp_config.url)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")
|
||||
return transport, server_name
|
||||
|
||||
client = MCPClient(
|
||||
transport=transport,
|
||||
def _resolve_native(
|
||||
self, mcp_config: MCPServerConfig
|
||||
) -> tuple[list[BaseTool], list[Any]]:
|
||||
"""Resolve an ``MCPServerConfig`` into tools.
|
||||
|
||||
Returns ``(tools, clients)`` where *clients* is always empty for
|
||||
native tools (clients are now created on-demand per invocation).
|
||||
A ``client_factory`` closure is passed to each ``MCPNativeTool`` so
|
||||
every call -- even concurrent calls to the *same* tool -- gets its
|
||||
own ``MCPClient`` + transport with no shared mutable state.
|
||||
"""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
|
||||
discovery_transport, server_name = self._create_transport(mcp_config)
|
||||
discovery_client = MCPClient(
|
||||
transport=discovery_transport,
|
||||
cache_tools_list=mcp_config.cache_tools_list,
|
||||
)
|
||||
|
||||
async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
|
||||
try:
|
||||
if not client.connected:
|
||||
await client.connect()
|
||||
if not discovery_client.connected:
|
||||
await discovery_client.connect()
|
||||
|
||||
tools_list = await client.list_tools()
|
||||
tools_list = await discovery_client.list_tools()
|
||||
|
||||
try:
|
||||
await client.disconnect()
|
||||
await discovery_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during disconnect: {e}")
|
||||
|
||||
return tools_list
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
await client.disconnect()
|
||||
if discovery_client.connected:
|
||||
await discovery_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError(
|
||||
f"Error during setup client and list tools: {e}"
|
||||
@@ -376,6 +395,13 @@ class MCPToolResolver:
|
||||
filtered_tools.append(tool)
|
||||
tools_list = filtered_tools
|
||||
|
||||
def _client_factory() -> MCPClient:
|
||||
transport, _ = self._create_transport(mcp_config)
|
||||
return MCPClient(
|
||||
transport=transport,
|
||||
cache_tools_list=mcp_config.cache_tools_list,
|
||||
)
|
||||
|
||||
tools = []
|
||||
for tool_def in tools_list:
|
||||
tool_name = tool_def.get("name", "")
|
||||
@@ -396,7 +422,7 @@ class MCPToolResolver:
|
||||
|
||||
try:
|
||||
native_tool = MCPNativeTool(
|
||||
mcp_client=client,
|
||||
client_factory=_client_factory,
|
||||
tool_name=tool_name,
|
||||
tool_schema=tool_schema,
|
||||
server_name=server_name,
|
||||
@@ -407,10 +433,10 @@ class MCPToolResolver:
|
||||
self._logger.log("error", f"Failed to create native MCP tool: {e}")
|
||||
continue
|
||||
|
||||
return cast(list[BaseTool], tools), client
|
||||
return cast(list[BaseTool], tools), []
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
asyncio.run(client.disconnect())
|
||||
if discovery_client.connected:
|
||||
asyncio.run(discovery_client.disconnect())
|
||||
|
||||
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||
|
||||
|
||||
@@ -3,11 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
|
||||
from crewai.memory.types import (
|
||||
_RECALL_OVERSAMPLE_FACTOR,
|
||||
@@ -15,22 +13,38 @@ from crewai.memory.types import (
|
||||
MemoryRecord,
|
||||
ScopeInfo,
|
||||
)
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
|
||||
class MemoryScope:
|
||||
class MemoryScope(BaseModel):
|
||||
"""View of Memory restricted to a root path. All operations are scoped under that path."""
|
||||
|
||||
def __init__(self, memory: Memory, root_path: str) -> None:
|
||||
"""Initialize scope.
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
Args:
|
||||
memory: The underlying Memory instance.
|
||||
root_path: Root path for this scope (e.g. /agent/1).
|
||||
"""
|
||||
self._memory = memory
|
||||
self._root = root_path.rstrip("/") or ""
|
||||
if self._root and not self._root.startswith("/"):
|
||||
self._root = "/" + self._root
|
||||
root_path: str = Field(default="/")
|
||||
|
||||
_memory: Memory = PrivateAttr()
|
||||
_root: str = PrivateAttr()
|
||||
|
||||
@model_validator(mode="wrap")
@classmethod
def _accept_memory(cls, data: Any, handler: Any) -> MemoryScope:
    """Extract the memory dependency and normalize the root path.

    ``memory`` is not a declared field, so it is removed from the input
    before the inner validation handler runs and is stored on the private
    ``_memory`` attribute afterwards. The normalized root (no trailing
    slash, leading slash added when non-empty) is stored on ``_root``.

    Args:
        data: Either an existing ``MemoryScope`` (returned unchanged) or a
            mapping holding ``memory`` plus the model fields.
        handler: The inner validation callable supplied for wrap-mode
            validators.

    Returns:
        The validated instance with ``_memory`` and ``_root`` populated.

    Raises:
        KeyError: If ``data`` has no ``memory`` entry.
    """
    if isinstance(data, MemoryScope):
        return data

    # Work on a shallow copy so that popping "memory" does not strip the key
    # from a dict the caller still owns (validating the same dict twice
    # would otherwise fail with KeyError on the second call).
    payload = dict(data) if isinstance(data, dict) else data
    memory = payload.pop("memory")

    instance: MemoryScope = handler(payload)
    instance._memory = memory

    root = instance.root_path.rstrip("/") or ""
    if root and not root.startswith("/"):
        root = "/" + root
    instance._root = root
    return instance
|
||||
|
||||
@property
def read_only(self) -> bool:
    """Report whether the wrapped memory rejects writes.

    Returns:
        The ``read_only`` attribute of the underlying memory object.
    """
    backing_memory = self._memory
    return backing_memory.read_only
|
||||
|
||||
def _scope_path(self, scope: str | None) -> str:
|
||||
if not scope or scope == "/":
|
||||
@@ -52,7 +66,7 @@ class MemoryScope:
|
||||
importance: float | None = None,
|
||||
source: str | None = None,
|
||||
private: bool = False,
|
||||
) -> MemoryRecord:
|
||||
) -> MemoryRecord | None:
|
||||
"""Remember content; scope is relative to this scope's root."""
|
||||
path = self._scope_path(scope)
|
||||
return self._memory.remember(
|
||||
@@ -71,7 +85,7 @@ class MemoryScope:
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
depth: str = "deep",
|
||||
depth: Literal["shallow", "deep"] = "deep",
|
||||
source: str | None = None,
|
||||
include_private: bool = False,
|
||||
) -> list[MemoryMatch]:
|
||||
@@ -138,34 +152,34 @@ class MemoryScope:
|
||||
"""Return a narrower scope under this scope."""
|
||||
child = path.strip("/")
|
||||
if not child:
|
||||
return MemoryScope(self._memory, self._root or "/")
|
||||
return MemoryScope(memory=self._memory, root_path=self._root or "/")
|
||||
base = self._root.rstrip("/") or ""
|
||||
new_root = f"{base}/{child}" if base else f"/{child}"
|
||||
return MemoryScope(self._memory, new_root)
|
||||
return MemoryScope(memory=self._memory, root_path=new_root)
|
||||
|
||||
|
||||
class MemorySlice:
|
||||
class MemorySlice(BaseModel):
|
||||
"""View over multiple scopes: recall searches all, remember is a no-op when read_only."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memory: Memory,
|
||||
scopes: list[str],
|
||||
categories: list[str] | None = None,
|
||||
read_only: bool = True,
|
||||
) -> None:
|
||||
"""Initialize slice.
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
Args:
|
||||
memory: The underlying Memory instance.
|
||||
scopes: List of scope paths to include.
|
||||
categories: Optional category filter for recall.
|
||||
read_only: If True, remember() is a silent no-op.
|
||||
"""
|
||||
self._memory = memory
|
||||
self._scopes = [s.rstrip("/") or "/" for s in scopes]
|
||||
self._categories = categories
|
||||
self._read_only = read_only
|
||||
scopes: list[str] = Field(default_factory=list)
|
||||
categories: list[str] | None = Field(default=None)
|
||||
read_only: bool = Field(default=True)
|
||||
|
||||
_memory: Memory = PrivateAttr()
|
||||
|
||||
@model_validator(mode="wrap")
@classmethod
def _accept_memory(cls, data: Any, handler: Any) -> MemorySlice:
    """Extract the memory dependency and normalize scopes before validation.

    ``memory`` is not a declared field, so it is removed from the input
    before the inner validation handler runs and is stored on the private
    ``_memory`` attribute afterwards. Each scope path has its trailing
    slashes stripped; a path that becomes empty is replaced by ``"/"``.

    Args:
        data: Either an existing ``MemorySlice`` (returned unchanged) or a
            mapping holding ``memory`` plus the model fields.
        handler: The inner validation callable supplied for wrap-mode
            validators.

    Returns:
        The validated instance with ``_memory`` populated.

    Raises:
        KeyError: If ``data`` has no ``memory`` entry.
    """
    if isinstance(data, MemorySlice):
        return data

    # Work on a shallow copy: both the pop() and the "scopes" rewrite below
    # would otherwise modify a dict the caller still owns.
    payload = dict(data) if isinstance(data, dict) else data
    memory = payload.pop("memory")
    payload["scopes"] = [s.rstrip("/") or "/" for s in payload.get("scopes", [])]

    instance: MemorySlice = handler(payload)
    instance._memory = memory
    return instance
|
||||
|
||||
def remember(
|
||||
self,
|
||||
@@ -178,7 +192,7 @@ class MemorySlice:
|
||||
private: bool = False,
|
||||
) -> MemoryRecord | None:
|
||||
"""Remember into an explicit scope. No-op when read_only=True."""
|
||||
if self._read_only:
|
||||
if self.read_only:
|
||||
return None
|
||||
return self._memory.remember(
|
||||
content,
|
||||
@@ -196,14 +210,14 @@ class MemorySlice:
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
depth: str = "deep",
|
||||
depth: Literal["shallow", "deep"] = "deep",
|
||||
source: str | None = None,
|
||||
include_private: bool = False,
|
||||
) -> list[MemoryMatch]:
|
||||
"""Recall across all slice scopes; results merged and re-ranked."""
|
||||
cats = categories or self._categories
|
||||
cats = categories or self.categories
|
||||
all_matches: list[MemoryMatch] = []
|
||||
for sc in self._scopes:
|
||||
for sc in self.scopes:
|
||||
matches = self._memory.recall(
|
||||
query,
|
||||
scope=sc,
|
||||
@@ -231,7 +245,7 @@ class MemorySlice:
|
||||
def list_scopes(self, path: str = "/") -> list[str]:
|
||||
"""List scopes across all slice roots."""
|
||||
out: list[str] = []
|
||||
for sc in self._scopes:
|
||||
for sc in self.scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
out.extend(self._memory.list_scopes(full))
|
||||
return sorted(set(out))
|
||||
@@ -243,15 +257,23 @@ class MemorySlice:
|
||||
oldest: datetime | None = None
|
||||
newest: datetime | None = None
|
||||
children: list[str] = []
|
||||
for sc in self._scopes:
|
||||
for sc in self.scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
inf = self._memory.info(full)
|
||||
total_records += inf.record_count
|
||||
all_categories.update(inf.categories)
|
||||
if inf.oldest_record:
|
||||
oldest = inf.oldest_record if oldest is None else min(oldest, inf.oldest_record)
|
||||
oldest = (
|
||||
inf.oldest_record
|
||||
if oldest is None
|
||||
else min(oldest, inf.oldest_record)
|
||||
)
|
||||
if inf.newest_record:
|
||||
newest = inf.newest_record if newest is None else max(newest, inf.newest_record)
|
||||
newest = (
|
||||
inf.newest_record
|
||||
if newest is None
|
||||
else max(newest, inf.newest_record)
|
||||
)
|
||||
children.extend(inf.child_scopes)
|
||||
return ScopeInfo(
|
||||
path=path,
|
||||
@@ -265,7 +287,7 @@ class MemorySlice:
|
||||
def list_categories(self, path: str | None = None) -> dict[str, int]:
|
||||
"""Categories and counts across slice scopes."""
|
||||
counts: dict[str, int] = {}
|
||||
for sc in self._scopes:
|
||||
for sc in self.scopes:
|
||||
full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc
|
||||
for k, v in self._memory.list_categories(full).items():
|
||||
counts[k] = counts.get(k, 0) + v
|
||||
|
||||
@@ -38,7 +38,8 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If database initialization fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -82,7 +83,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -125,7 +126,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If updating the task output fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
|
||||
@@ -166,7 +167,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT *
|
||||
@@ -205,7 +206,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import AbstractContextManager
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
@@ -14,6 +15,7 @@ from typing import Any, ClassVar
|
||||
import lancedb
|
||||
|
||||
from crewai.memory.types import MemoryRecord, ScopeInfo
|
||||
from crewai.utilities.lock_store import lock as store_lock
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@@ -90,6 +92,7 @@ class LanceDBStorage:
|
||||
# Raise it proactively so scans on large tables never hit OS error 24.
|
||||
try:
|
||||
import resource
|
||||
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
if soft < 4096:
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
|
||||
@@ -99,7 +102,8 @@ class LanceDBStorage:
|
||||
self._compact_every = compact_every
|
||||
self._save_count = 0
|
||||
|
||||
# Get or create a shared write lock for this database path.
|
||||
self._lock_name = f"lancedb:{self._path.resolve()}"
|
||||
|
||||
resolved = str(self._path.resolve())
|
||||
with LanceDBStorage._path_locks_guard:
|
||||
if resolved not in LanceDBStorage._path_locks:
|
||||
@@ -110,10 +114,13 @@ class LanceDBStorage:
|
||||
# If no table exists yet, defer creation until the first save so the
|
||||
# dimension can be auto-detected from the embedder's actual output.
|
||||
try:
|
||||
self._table: lancedb.table.Table | None = self._db.open_table(self._table_name)
|
||||
self._table: lancedb.table.Table | None = self._db.open_table(
|
||||
self._table_name
|
||||
)
|
||||
self._vector_dim: int = self._infer_dim_from_table(self._table)
|
||||
# Best-effort: create the scope index if it doesn't exist yet.
|
||||
self._ensure_scope_index()
|
||||
with self._file_lock():
|
||||
self._ensure_scope_index()
|
||||
# Compact in the background if the table has accumulated many
|
||||
# fragments from previous runs (each save() creates one).
|
||||
self._compact_if_needed()
|
||||
@@ -124,7 +131,8 @@ class LanceDBStorage:
|
||||
# Explicit dim provided: create the table immediately if it doesn't exist.
|
||||
if self._table is None and vector_dim is not None:
|
||||
self._vector_dim = vector_dim
|
||||
self._table = self._create_table(vector_dim)
|
||||
with self._file_lock():
|
||||
self._table = self._create_table(vector_dim)
|
||||
|
||||
@property
|
||||
def write_lock(self) -> threading.RLock:
|
||||
@@ -149,18 +157,14 @@ class LanceDBStorage:
|
||||
break
|
||||
return DEFAULT_VECTOR_DIM
|
||||
|
||||
def _retry_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Execute a table operation with retry on LanceDB commit conflicts.
|
||||
def _file_lock(self) -> AbstractContextManager[None]:
|
||||
"""Return a cross-process lock for serialising writes."""
|
||||
return store_lock(self._lock_name)
|
||||
|
||||
Args:
|
||||
op: Method name on the table object (e.g. "add", "delete").
|
||||
*args, **kwargs: Passed to the table method.
|
||||
def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Execute a single table write with retry on commit conflicts.
|
||||
|
||||
LanceDB uses optimistic concurrency: if two transactions overlap,
|
||||
the second to commit fails with an ``OSError`` containing
|
||||
"Commit conflict". This helper retries with exponential backoff,
|
||||
refreshing the table reference before each retry so the retried
|
||||
call uses the latest committed version (not a stale reference).
|
||||
Caller must already hold the cross-process file lock.
|
||||
"""
|
||||
delay = _RETRY_BASE_DELAY
|
||||
for attempt in range(_MAX_RETRIES + 1):
|
||||
@@ -171,20 +175,24 @@ class LanceDBStorage:
|
||||
raise
|
||||
_logger.debug(
|
||||
"LanceDB commit conflict on %s (attempt %d/%d), retrying in %.1fs",
|
||||
op, attempt + 1, _MAX_RETRIES, delay,
|
||||
op,
|
||||
attempt + 1,
|
||||
_MAX_RETRIES,
|
||||
delay,
|
||||
)
|
||||
# Refresh table to pick up the latest version before retrying.
|
||||
# The next getattr(self._table, op) will use the fresh table.
|
||||
try:
|
||||
self._table = self._db.open_table(self._table_name)
|
||||
except Exception: # noqa: S110
|
||||
pass # table refresh is best-effort
|
||||
pass
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
return None # unreachable, but satisfies type checker
|
||||
|
||||
def _create_table(self, vector_dim: int) -> lancedb.table.Table:
|
||||
"""Create a new table with the given vector dimension."""
|
||||
"""Create a new table with the given vector dimension.
|
||||
|
||||
Caller must already hold the cross-process file lock.
|
||||
"""
|
||||
placeholder = [
|
||||
{
|
||||
"id": "__schema_placeholder__",
|
||||
@@ -200,8 +208,12 @@ class LanceDBStorage:
|
||||
"vector": [0.0] * vector_dim,
|
||||
}
|
||||
]
|
||||
table = self._db.create_table(self._table_name, placeholder)
|
||||
table.delete("id = '__schema_placeholder__'")
|
||||
try:
|
||||
table = self._db.create_table(self._table_name, placeholder)
|
||||
except ValueError:
|
||||
table = self._db.open_table(self._table_name)
|
||||
else:
|
||||
table.delete("id = '__schema_placeholder__'")
|
||||
return table
|
||||
|
||||
def _ensure_scope_index(self) -> None:
|
||||
@@ -248,9 +260,9 @@ class LanceDBStorage:
|
||||
"""Run ``table.optimize()`` in a background thread, absorbing errors."""
|
||||
try:
|
||||
if self._table is not None:
|
||||
self._table.optimize()
|
||||
# Refresh the scope index so new fragments are covered.
|
||||
self._ensure_scope_index()
|
||||
with self._file_lock():
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
except Exception:
|
||||
_logger.debug("LanceDB background compaction failed", exc_info=True)
|
||||
|
||||
@@ -280,7 +292,9 @@ class LanceDBStorage:
|
||||
"last_accessed": record.last_accessed.isoformat(),
|
||||
"source": record.source or "",
|
||||
"private": record.private,
|
||||
"vector": record.embedding if record.embedding else [0.0] * self._vector_dim,
|
||||
"vector": record.embedding
|
||||
if record.embedding
|
||||
else [0.0] * self._vector_dim,
|
||||
}
|
||||
|
||||
def _row_to_record(self, row: dict[str, Any]) -> MemoryRecord:
|
||||
@@ -296,7 +310,9 @@ class LanceDBStorage:
|
||||
id=str(row["id"]),
|
||||
content=str(row["content"]),
|
||||
scope=str(row["scope"]),
|
||||
categories=json.loads(row["categories_str"]) if row.get("categories_str") else [],
|
||||
categories=json.loads(row["categories_str"])
|
||||
if row.get("categories_str")
|
||||
else [],
|
||||
metadata=json.loads(row["metadata_str"]) if row.get("metadata_str") else {},
|
||||
importance=float(row.get("importance", 0.5)),
|
||||
created_at=_parse_dt(row.get("created_at")),
|
||||
@@ -316,16 +332,15 @@ class LanceDBStorage:
|
||||
dim = len(r.embedding)
|
||||
break
|
||||
is_new_table = self._table is None
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
self._ensure_table(vector_dim=dim)
|
||||
rows = [self._record_to_row(r) for r in records]
|
||||
for r in rows:
|
||||
if r["vector"] is None or len(r["vector"]) != self._vector_dim:
|
||||
r["vector"] = [0.0] * self._vector_dim
|
||||
self._retry_write("add", rows)
|
||||
# Create the scope index on the first save so it covers the initial dataset.
|
||||
if is_new_table:
|
||||
self._ensure_scope_index()
|
||||
self._do_write("add", rows)
|
||||
if is_new_table:
|
||||
self._ensure_scope_index()
|
||||
# Auto-compact every N saves so fragment files don't pile up.
|
||||
self._save_count += 1
|
||||
if self._compact_every > 0 and self._save_count % self._compact_every == 0:
|
||||
@@ -333,14 +348,14 @@ class LanceDBStorage:
|
||||
|
||||
def update(self, record: MemoryRecord) -> None:
|
||||
"""Update a record by ID. Preserves created_at, updates last_accessed."""
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
self._ensure_table()
|
||||
safe_id = str(record.id).replace("'", "''")
|
||||
self._retry_write("delete", f"id = '{safe_id}'")
|
||||
self._do_write("delete", f"id = '{safe_id}'")
|
||||
row = self._record_to_row(record)
|
||||
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
|
||||
row["vector"] = [0.0] * self._vector_dim
|
||||
self._retry_write("add", [row])
|
||||
self._do_write("add", [row])
|
||||
|
||||
def touch_records(self, record_ids: list[str]) -> None:
|
||||
"""Update last_accessed to now for the given record IDs.
|
||||
@@ -354,11 +369,11 @@ class LanceDBStorage:
|
||||
"""
|
||||
if not record_ids or self._table is None:
|
||||
return
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
now = datetime.utcnow().isoformat()
|
||||
safe_ids = [str(rid).replace("'", "''") for rid in record_ids]
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
|
||||
self._retry_write(
|
||||
self._do_write(
|
||||
"update",
|
||||
where=f"id IN ({ids_expr})",
|
||||
values={"last_accessed": now},
|
||||
@@ -390,13 +405,17 @@ class LanceDBStorage:
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
like_val = prefix + "%"
|
||||
query = query.where(f"scope LIKE '{like_val}'")
|
||||
results = query.limit(limit * 3 if (categories or metadata_filter) else limit).to_list()
|
||||
results = query.limit(
|
||||
limit * 3 if (categories or metadata_filter) else limit
|
||||
).to_list()
|
||||
out: list[tuple[MemoryRecord, float]] = []
|
||||
for row in results:
|
||||
record = self._row_to_record(row)
|
||||
if categories and not any(c in record.categories for c in categories):
|
||||
continue
|
||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
||||
if metadata_filter and not all(
|
||||
record.metadata.get(k) == v for k, v in metadata_filter.items()
|
||||
):
|
||||
continue
|
||||
distance = row.get("_distance", 0.0)
|
||||
score = 1.0 / (1.0 + float(distance)) if distance is not None else 1.0
|
||||
@@ -416,20 +435,24 @@ class LanceDBStorage:
|
||||
) -> int:
|
||||
if self._table is None:
|
||||
return 0
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
if record_ids and not (categories or metadata_filter):
|
||||
before = self._table.count_rows()
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
|
||||
self._retry_write("delete", f"id IN ({ids_expr})")
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
if categories or metadata_filter:
|
||||
rows = self._scan_rows(scope_prefix)
|
||||
to_delete: list[str] = []
|
||||
for row in rows:
|
||||
record = self._row_to_record(row)
|
||||
if categories and not any(c in record.categories for c in categories):
|
||||
if categories and not any(
|
||||
c in record.categories for c in categories
|
||||
):
|
||||
continue
|
||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
||||
if metadata_filter and not all(
|
||||
record.metadata.get(k) == v for k, v in metadata_filter.items()
|
||||
):
|
||||
continue
|
||||
if older_than and record.created_at >= older_than:
|
||||
continue
|
||||
@@ -438,7 +461,7 @@ class LanceDBStorage:
|
||||
return 0
|
||||
before = self._table.count_rows()
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
|
||||
self._retry_write("delete", f"id IN ({ids_expr})")
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
conditions = []
|
||||
if scope_prefix is not None and scope_prefix.strip("/"):
|
||||
@@ -450,11 +473,11 @@ class LanceDBStorage:
|
||||
conditions.append(f"created_at < '{older_than.isoformat()}'")
|
||||
if not conditions:
|
||||
before = self._table.count_rows()
|
||||
self._retry_write("delete", "id != ''")
|
||||
self._do_write("delete", "id != ''")
|
||||
return before - self._table.count_rows()
|
||||
where_expr = " AND ".join(conditions)
|
||||
before = self._table.count_rows()
|
||||
self._retry_write("delete", where_expr)
|
||||
self._do_write("delete", where_expr)
|
||||
return before - self._table.count_rows()
|
||||
|
||||
def _scan_rows(
|
||||
@@ -528,7 +551,7 @@ class LanceDBStorage:
|
||||
for row in rows:
|
||||
sc = str(row.get("scope", ""))
|
||||
if child_prefix and sc.startswith(child_prefix):
|
||||
rest = sc[len(child_prefix):]
|
||||
rest = sc[len(child_prefix) :]
|
||||
first_component = rest.split("/", 1)[0]
|
||||
if first_component:
|
||||
children.add(child_prefix + first_component)
|
||||
@@ -539,7 +562,11 @@ class LanceDBStorage:
|
||||
pass
|
||||
created = row.get("created_at")
|
||||
if created:
|
||||
dt = datetime.fromisoformat(str(created).replace("Z", "+00:00")) if isinstance(created, str) else created
|
||||
dt = (
|
||||
datetime.fromisoformat(str(created).replace("Z", "+00:00"))
|
||||
if isinstance(created, str)
|
||||
else created
|
||||
)
|
||||
if isinstance(dt, datetime):
|
||||
if oldest is None or dt < oldest:
|
||||
oldest = dt
|
||||
@@ -562,7 +589,7 @@ class LanceDBStorage:
|
||||
for row in rows:
|
||||
sc = str(row.get("scope", ""))
|
||||
if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
|
||||
rest = sc[len(prefix):]
|
||||
rest = sc[len(prefix) :]
|
||||
first_component = rest.split("/", 1)[0]
|
||||
if first_component:
|
||||
children.add(prefix + first_component)
|
||||
@@ -590,17 +617,19 @@ class LanceDBStorage:
|
||||
return info.record_count
|
||||
|
||||
def reset(self, scope_prefix: str | None = None) -> None:
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
if self._table is not None:
|
||||
self._db.drop_table(self._table_name)
|
||||
self._table = None
|
||||
# Dimension is preserved; table will be recreated on next save.
|
||||
return
|
||||
if self._table is None:
|
||||
return
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
if prefix:
|
||||
self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uFFFF'")
|
||||
with self._write_lock, self._file_lock():
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
if self._table is not None:
|
||||
self._db.drop_table(self._table_name)
|
||||
self._table = None
|
||||
return
|
||||
if self._table is None:
|
||||
return
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
if prefix:
|
||||
self._do_write(
|
||||
"delete", f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'"
|
||||
)
|
||||
|
||||
def optimize(self) -> None:
|
||||
"""Compact the table synchronously and refresh the scope index.
|
||||
@@ -614,8 +643,9 @@ class LanceDBStorage:
|
||||
"""
|
||||
if self._table is None:
|
||||
return
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
with self._write_lock, self._file_lock():
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
|
||||
async def asave(self, records: list[MemoryRecord]) -> None:
|
||||
self.save(records)
|
||||
|
||||
@@ -6,7 +6,9 @@ from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PlainValidator, PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
@@ -39,13 +41,18 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
def _passthrough(v: Any) -> Any:
    """Return ``v`` unchanged.

    Meant to be wrapped in a pydantic ``PlainValidator`` so that a field
    accepts any value as-is, bypassing strict union discrimination.

    Args:
        v: The raw value supplied for the field.

    Returns:
        The very same object that was passed in.
    """
    return v
|
||||
|
||||
|
||||
def _default_embedder() -> OpenAIEmbeddingFunction:
    """Build the default OpenAI embedder for memory.

    Returns:
        Whatever ``build_embedder`` produces for an ``"openai"`` provider
        spec with an empty ``config`` mapping.
    """
    spec: OpenAIProviderSpec = {"provider": "openai", "config": {}}
    return build_embedder(spec)
|
||||
|
||||
|
||||
class Memory:
|
||||
class Memory(BaseModel):
|
||||
"""Unified memory: standalone, LLM-analyzed, with intelligent recall flow.
|
||||
|
||||
Works without agent/crew. Uses LLM to infer scope, categories, importance on save.
|
||||
@@ -53,116 +60,119 @@ class Memory:
|
||||
pluggable storage (LanceDB default).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM | str = "gpt-4o-mini",
|
||||
storage: StorageBackend | str = "lancedb",
|
||||
embedder: Any = None,
|
||||
# -- Scoring weights --
|
||||
# These three weights control how recall results are ranked.
|
||||
# The composite score is: semantic_weight * similarity + recency_weight * decay + importance_weight * importance.
|
||||
# They should sum to ~1.0 for intuitive scoring.
|
||||
recency_weight: float = 0.3,
|
||||
semantic_weight: float = 0.5,
|
||||
importance_weight: float = 0.2,
|
||||
# How quickly old memories lose relevance. The recency score halves every
|
||||
# N days (exponential decay). Lower = faster forgetting; higher = longer relevance.
|
||||
recency_half_life_days: int = 30,
|
||||
# -- Consolidation --
|
||||
# When remembering new content, if an existing record has similarity >= this
|
||||
# threshold, the LLM is asked to merge/update/delete. Set to 1.0 to disable.
|
||||
consolidation_threshold: float = 0.85,
|
||||
# Max existing records to compare against when checking for consolidation.
|
||||
consolidation_limit: int = 5,
|
||||
# -- Save defaults --
|
||||
# Importance assigned to new memories when no explicit value is given and
|
||||
# the LLM analysis path is skipped (all fields provided by the caller).
|
||||
default_importance: float = 0.5,
|
||||
# -- Recall depth control --
|
||||
# These thresholds govern the RecallFlow router that decides between
|
||||
# returning results immediately ("synthesize") vs. doing an extra
|
||||
# LLM-driven exploration round ("explore_deeper").
|
||||
# confidence >= confidence_threshold_high => always synthesize
|
||||
# confidence < confidence_threshold_low => explore deeper (if budget > 0)
|
||||
# complex query + confidence < complex_query_threshold => explore deeper
|
||||
confidence_threshold_high: float = 0.8,
|
||||
confidence_threshold_low: float = 0.5,
|
||||
complex_query_threshold: float = 0.7,
|
||||
# How many LLM-driven exploration rounds the RecallFlow is allowed to run.
|
||||
# 0 = always shallow (vector search only); higher = more thorough but slower.
|
||||
exploration_budget: int = 1,
|
||||
# Queries shorter than this skip LLM analysis (saving ~1-3s).
|
||||
# Longer queries (full task descriptions) benefit from LLM distillation.
|
||||
query_analysis_threshold: int = 200,
|
||||
# When True, all write operations (remember, remember_many) are silently
|
||||
# skipped. Useful for sharing a read-only view of memory across agents
|
||||
# without any of them persisting new memories.
|
||||
read_only: bool = False,
|
||||
) -> None:
|
||||
"""Initialize Memory.
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
Args:
|
||||
llm: LLM for analysis (model name or BaseLLM instance).
|
||||
storage: Backend: "lancedb" or a StorageBackend instance.
|
||||
embedder: Embedding callable, provider config dict, or None (default OpenAI).
|
||||
recency_weight: Weight for recency in the composite relevance score.
|
||||
semantic_weight: Weight for semantic similarity in the composite relevance score.
|
||||
importance_weight: Weight for importance in the composite relevance score.
|
||||
recency_half_life_days: Recency score halves every N days (exponential decay).
|
||||
consolidation_threshold: Similarity above which consolidation is triggered on save.
|
||||
consolidation_limit: Max existing records to compare during consolidation.
|
||||
default_importance: Default importance when not provided or inferred.
|
||||
confidence_threshold_high: Recall confidence above which results are returned directly.
|
||||
confidence_threshold_low: Recall confidence below which deeper exploration is triggered.
|
||||
complex_query_threshold: For complex queries, explore deeper below this confidence.
|
||||
exploration_budget: Number of LLM-driven exploration rounds during deep recall.
|
||||
query_analysis_threshold: Queries shorter than this skip LLM analysis during deep recall.
|
||||
read_only: If True, remember() and remember_many() are silent no-ops.
|
||||
"""
|
||||
self._read_only = read_only
|
||||
llm: Annotated[BaseLLM | str, PlainValidator(_passthrough)] = Field(
|
||||
default="gpt-4o-mini",
|
||||
description="LLM for analysis (model name or BaseLLM instance).",
|
||||
)
|
||||
storage: Annotated[StorageBackend | str, PlainValidator(_passthrough)] = Field(
|
||||
default="lancedb",
|
||||
description="Storage backend instance or path string.",
|
||||
)
|
||||
embedder: Any = Field(
|
||||
default=None,
|
||||
description="Embedding callable, provider config dict, or None for default OpenAI.",
|
||||
)
|
||||
recency_weight: float = Field(
|
||||
default=0.3,
|
||||
description="Weight for recency in the composite relevance score.",
|
||||
)
|
||||
semantic_weight: float = Field(
|
||||
default=0.5,
|
||||
description="Weight for semantic similarity in the composite relevance score.",
|
||||
)
|
||||
importance_weight: float = Field(
|
||||
default=0.2,
|
||||
description="Weight for importance in the composite relevance score.",
|
||||
)
|
||||
recency_half_life_days: int = Field(
|
||||
default=30,
|
||||
description="Recency score halves every N days (exponential decay).",
|
||||
)
|
||||
consolidation_threshold: float = Field(
|
||||
default=0.85,
|
||||
description="Similarity above which consolidation is triggered on save.",
|
||||
)
|
||||
consolidation_limit: int = Field(
|
||||
default=5,
|
||||
description="Max existing records to compare during consolidation.",
|
||||
)
|
||||
default_importance: float = Field(
|
||||
default=0.5,
|
||||
description="Default importance when not provided or inferred.",
|
||||
)
|
||||
confidence_threshold_high: float = Field(
|
||||
default=0.8,
|
||||
description="Recall confidence above which results are returned directly.",
|
||||
)
|
||||
confidence_threshold_low: float = Field(
|
||||
default=0.5,
|
||||
description="Recall confidence below which deeper exploration is triggered.",
|
||||
)
|
||||
complex_query_threshold: float = Field(
|
||||
default=0.7,
|
||||
description="For complex queries, explore deeper below this confidence.",
|
||||
)
|
||||
exploration_budget: int = Field(
|
||||
default=1,
|
||||
description="Number of LLM-driven exploration rounds during deep recall.",
|
||||
)
|
||||
query_analysis_threshold: int = Field(
|
||||
default=200,
|
||||
description="Queries shorter than this skip LLM analysis during deep recall.",
|
||||
)
|
||||
read_only: bool = Field(
|
||||
default=False,
|
||||
description="If True, remember() and remember_many() are silent no-ops.",
|
||||
)
|
||||
|
||||
_config: MemoryConfig = PrivateAttr()
|
||||
_llm_instance: BaseLLM | None = PrivateAttr(default=None)
|
||||
_embedder_instance: Any = PrivateAttr(default=None)
|
||||
_storage: StorageBackend = PrivateAttr()
|
||||
_save_pool: ThreadPoolExecutor = PrivateAttr(
|
||||
default_factory=lambda: ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="memory-save"
|
||||
)
|
||||
)
|
||||
_pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list)
|
||||
_pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Initialize runtime state from field values."""
|
||||
self._config = MemoryConfig(
|
||||
recency_weight=recency_weight,
|
||||
semantic_weight=semantic_weight,
|
||||
importance_weight=importance_weight,
|
||||
recency_half_life_days=recency_half_life_days,
|
||||
consolidation_threshold=consolidation_threshold,
|
||||
consolidation_limit=consolidation_limit,
|
||||
default_importance=default_importance,
|
||||
confidence_threshold_high=confidence_threshold_high,
|
||||
confidence_threshold_low=confidence_threshold_low,
|
||||
complex_query_threshold=complex_query_threshold,
|
||||
exploration_budget=exploration_budget,
|
||||
query_analysis_threshold=query_analysis_threshold,
|
||||
recency_weight=self.recency_weight,
|
||||
semantic_weight=self.semantic_weight,
|
||||
importance_weight=self.importance_weight,
|
||||
recency_half_life_days=self.recency_half_life_days,
|
||||
consolidation_threshold=self.consolidation_threshold,
|
||||
consolidation_limit=self.consolidation_limit,
|
||||
default_importance=self.default_importance,
|
||||
confidence_threshold_high=self.confidence_threshold_high,
|
||||
confidence_threshold_low=self.confidence_threshold_low,
|
||||
complex_query_threshold=self.complex_query_threshold,
|
||||
exploration_budget=self.exploration_budget,
|
||||
query_analysis_threshold=self.query_analysis_threshold,
|
||||
)
|
||||
|
||||
# Store raw config for lazy initialization. LLM and embedder are only
|
||||
# built on first access so that Memory() never fails at construction
|
||||
# time (e.g. when auto-created by Flow without an API key set).
|
||||
self._llm_config: BaseLLM | str = llm
|
||||
self._llm_instance: BaseLLM | None = None if isinstance(llm, str) else llm
|
||||
self._embedder_config: Any = embedder
|
||||
self._embedder_instance: Any = (
|
||||
embedder
|
||||
if (embedder is not None and not isinstance(embedder, dict))
|
||||
self._llm_instance = None if isinstance(self.llm, str) else self.llm
|
||||
self._embedder_instance = (
|
||||
self.embedder
|
||||
if (self.embedder is not None and not isinstance(self.embedder, dict))
|
||||
else None
|
||||
)
|
||||
|
||||
if isinstance(storage, str):
|
||||
if isinstance(self.storage, str):
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
self._storage = LanceDBStorage() if storage == "lancedb" else LanceDBStorage(path=storage)
|
||||
self._storage = (
|
||||
LanceDBStorage()
|
||||
if self.storage == "lancedb"
|
||||
else LanceDBStorage(path=self.storage)
|
||||
)
|
||||
else:
|
||||
self._storage = storage
|
||||
|
||||
# Background save queue. max_workers=1 serializes saves to avoid
|
||||
# concurrent storage mutations (two saves finding the same similar
|
||||
# record and both trying to update/delete it). Within each save,
|
||||
# the parallel LLM calls still run on their own thread pool.
|
||||
self._save_pool = ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="memory-save"
|
||||
)
|
||||
self._pending_saves: list[Future[Any]] = []
|
||||
self._pending_lock = threading.Lock()
|
||||
self._storage = self.storage
|
||||
|
||||
_MEMORY_DOCS_URL = "https://docs.crewai.com/concepts/memory"
|
||||
|
||||
@@ -173,11 +183,7 @@ class Memory:
|
||||
from crewai.llm import LLM
|
||||
|
||||
try:
|
||||
model_name = (
|
||||
self._llm_config
|
||||
if isinstance(self._llm_config, str)
|
||||
else str(self._llm_config)
|
||||
)
|
||||
model_name = self.llm if isinstance(self.llm, str) else str(self.llm)
|
||||
self._llm_instance = LLM(model=model_name)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
@@ -197,8 +203,8 @@ class Memory:
|
||||
"""Lazy embedder initialization -- only created when first needed."""
|
||||
if self._embedder_instance is None:
|
||||
try:
|
||||
if isinstance(self._embedder_config, dict):
|
||||
self._embedder_instance = build_embedder(self._embedder_config)
|
||||
if isinstance(self.embedder, dict):
|
||||
self._embedder_instance = build_embedder(self.embedder)
|
||||
else:
|
||||
self._embedder_instance = _default_embedder()
|
||||
except Exception as e:
|
||||
@@ -356,7 +362,7 @@ class Memory:
|
||||
Raises:
|
||||
Exception: On save failure (events emitted).
|
||||
"""
|
||||
if self._read_only:
|
||||
if self.read_only:
|
||||
return None
|
||||
_source_type = "unified_memory"
|
||||
try:
|
||||
@@ -444,7 +450,7 @@ class Memory:
|
||||
Returns:
|
||||
Empty list (records are not available until the background save completes).
|
||||
"""
|
||||
if not contents or self._read_only:
|
||||
if not contents or self.read_only:
|
||||
return []
|
||||
|
||||
self._submit_save(
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
"""Factory functions for creating ChromaDB clients."""
|
||||
|
||||
from hashlib import md5
|
||||
import os
|
||||
|
||||
from chromadb import PersistentClient
|
||||
import portalocker
|
||||
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.utilities.lock_store import lock
|
||||
|
||||
|
||||
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
@@ -25,10 +24,8 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
|
||||
persist_dir = config.settings.persist_directory
|
||||
os.makedirs(persist_dir, exist_ok=True)
|
||||
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
|
||||
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
|
||||
|
||||
with portalocker.Lock(lockfile):
|
||||
with lock(f"chromadb:{persist_dir}"):
|
||||
client = PersistentClient(
|
||||
path=persist_dir,
|
||||
settings=config.settings,
|
||||
|
||||
@@ -1,29 +1,30 @@
|
||||
"""Native MCP tool wrapper for CrewAI agents.
|
||||
|
||||
This module provides a tool wrapper that reuses existing MCP client sessions
|
||||
for better performance and connection management.
|
||||
This module provides a tool wrapper that creates a fresh MCP client for every
|
||||
invocation, ensuring safe parallel execution even when the same tool is called
|
||||
concurrently by the executor.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class MCPNativeTool(BaseTool):
|
||||
"""Native MCP tool that reuses client sessions.
|
||||
"""Native MCP tool that creates a fresh client per invocation.
|
||||
|
||||
This tool wrapper is used when agents connect to MCP servers using
|
||||
structured configurations. It reuses existing client sessions for
|
||||
better performance and proper connection lifecycle management.
|
||||
|
||||
Unlike MCPToolWrapper which connects on-demand, this tool uses
|
||||
a shared MCP client instance that maintains a persistent connection.
|
||||
A ``client_factory`` callable produces an independent ``MCPClient`` +
|
||||
transport for every ``_run_async`` call. This guarantees that parallel
|
||||
invocations -- whether of the *same* tool or *different* tools from the
|
||||
same server -- never share mutable connection state (which would cause
|
||||
anyio cancel-scope errors).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_client: Any,
|
||||
client_factory: Callable[[], Any],
|
||||
tool_name: str,
|
||||
tool_schema: dict[str, Any],
|
||||
server_name: str,
|
||||
@@ -32,19 +33,16 @@ class MCPNativeTool(BaseTool):
|
||||
"""Initialize native MCP tool.
|
||||
|
||||
Args:
|
||||
mcp_client: MCPClient instance with active session.
|
||||
client_factory: Zero-arg callable that returns a new MCPClient.
|
||||
tool_name: Name of the tool (may be prefixed).
|
||||
tool_schema: Schema information for the tool.
|
||||
server_name: Name of the MCP server for prefixing.
|
||||
original_tool_name: Original name of the tool on the MCP server.
|
||||
"""
|
||||
# Create tool name with server prefix to avoid conflicts
|
||||
prefixed_name = f"{server_name}_{tool_name}"
|
||||
|
||||
# Handle args_schema properly - BaseTool expects a BaseModel subclass
|
||||
args_schema = tool_schema.get("args_schema")
|
||||
|
||||
# Only pass args_schema if it's provided
|
||||
kwargs = {
|
||||
"name": prefixed_name,
|
||||
"description": tool_schema.get(
|
||||
@@ -57,16 +55,9 @@ class MCPNativeTool(BaseTool):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set instance attributes after super().__init__
|
||||
self._mcp_client = mcp_client
|
||||
self._client_factory = client_factory
|
||||
self._original_tool_name = original_tool_name or tool_name
|
||||
self._server_name = server_name
|
||||
# self._logger = logging.getLogger(__name__)
|
||||
|
||||
@property
|
||||
def mcp_client(self) -> Any:
|
||||
"""Get the MCP client instance."""
|
||||
return self._mcp_client
|
||||
|
||||
@property
|
||||
def original_tool_name(self) -> str:
|
||||
@@ -108,51 +99,26 @@ class MCPNativeTool(BaseTool):
|
||||
async def _run_async(self, **kwargs) -> str:
|
||||
"""Async implementation of tool execution.
|
||||
|
||||
A fresh ``MCPClient`` is created for every invocation so that
|
||||
concurrent calls never share transport or session state.
|
||||
|
||||
Args:
|
||||
**kwargs: Arguments to pass to the MCP tool.
|
||||
|
||||
Returns:
|
||||
Result from the MCP tool execution.
|
||||
"""
|
||||
# Note: Since we use asyncio.run() which creates a new event loop each time,
|
||||
# Always reconnect on-demand because asyncio.run() creates new event loops per call
|
||||
# All MCP transport context managers (stdio, streamablehttp_client, sse_client)
|
||||
# use anyio.create_task_group() which can't span different event loops
|
||||
if self._mcp_client.connected:
|
||||
await self._mcp_client.disconnect()
|
||||
|
||||
await self._mcp_client.connect()
|
||||
client = self._client_factory()
|
||||
await client.connect()
|
||||
|
||||
try:
|
||||
result = await self._mcp_client.call_tool(self.original_tool_name, kwargs)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"not connected" in error_str
|
||||
or "connection" in error_str
|
||||
or "send" in error_str
|
||||
):
|
||||
await self._mcp_client.disconnect()
|
||||
await self._mcp_client.connect()
|
||||
# Retry the call
|
||||
result = await self._mcp_client.call_tool(
|
||||
self.original_tool_name, kwargs
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
result = await client.call_tool(self.original_tool_name, kwargs)
|
||||
finally:
|
||||
# Always disconnect after tool call to ensure clean context manager lifecycle
|
||||
# This prevents "exit cancel scope in different task" errors
|
||||
# All transport context managers must be exited in the same event loop they were entered
|
||||
await self._mcp_client.disconnect()
|
||||
await client.disconnect()
|
||||
|
||||
# Extract result content
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
|
||||
# Handle various result formats
|
||||
if hasattr(result, "content") and result.content:
|
||||
if isinstance(result.content, list) and len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
|
||||
@@ -121,7 +121,7 @@ def create_memory_tools(memory: Any) -> list[BaseTool]:
|
||||
description=i18n.tools("recall_memory"),
|
||||
),
|
||||
]
|
||||
if not getattr(memory, "_read_only", False):
|
||||
if not memory.read_only:
|
||||
tools.append(
|
||||
RememberTool(
|
||||
memory=memory,
|
||||
|
||||
61
lib/crewai/src/crewai/utilities/lock_store.py
Normal file
61
lib/crewai/src/crewai/utilities/lock_store.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Centralised lock factory.
|
||||
|
||||
If ``REDIS_URL`` is set, locks are distributed via ``portalocker.RedisLock``. Otherwise, falls
|
||||
back to the standard ``portalocker.Lock``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
from hashlib import md5
|
||||
import os
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
import portalocker
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import redis
|
||||
|
||||
|
||||
# Read once at import time; ``None`` when the environment variable is unset.
_REDIS_URL: str | None = os.environ.get("REDIS_URL")

# Default ``timeout`` (in seconds) used by ``lock`` when the caller gives none.
_DEFAULT_TIMEOUT: Final[int] = 120
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _redis_connection() -> redis.Redis:
    """Return a cached Redis connection, creating one on first call.

    The ``lru_cache(maxsize=1)`` decorator means the connection object is
    built at most once per process and reused afterwards.

    Returns:
        The client built by ``Redis.from_url`` from ``REDIS_URL``.

    Raises:
        ValueError: If ``REDIS_URL`` is unset or empty.
    """
    # Imported lazily so that ``redis`` is only required when it is used.
    from redis import Redis

    # Truthiness (not ``is None``) so an empty value is treated as "not set",
    # the same way ``lock`` decides whether to use Redis at all.
    if not _REDIS_URL:
        raise ValueError("REDIS_URL environment variable is not set")
    return Redis.from_url(_REDIS_URL)
|
||||
|
||||
|
||||
@contextmanager
def lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]:
    """Acquire a named lock, yielding while it is held.

    When ``REDIS_URL`` is set the lock is a ``portalocker.RedisLock`` on a
    hashed channel name; otherwise it is a ``portalocker.Lock`` on a file in
    the system temporary directory.

    Args:
        name: A human-readable lock name (e.g. ``"chromadb_init"``).
            Automatically namespaced to avoid collisions.
        timeout: Maximum seconds to wait for the lock before raising.
    """
    # Hash the name so arbitrary text (paths, spaces, ...) yields a fixed,
    # safe identifier; md5 is used for naming only, not for security.
    channel = f"crewai:{md5(name.encode(), usedforsecurity=False).hexdigest()}"

    if _REDIS_URL:
        with portalocker.RedisLock(
            channel=channel,
            connection=_redis_connection(),
            timeout=timeout,
        ):
            yield
    else:
        lock_file = f"{channel}.lock"
        if os.name == "nt":
            # ':' is not a legal file-name character on Windows (on NTFS it
            # addresses an alternate data stream), so swap it out there.
            # Other platforms keep the original name unchanged.
            lock_file = lock_file.replace(":", "_")
        lock_path = os.path.join(tempfile.gettempdir(), lock_file)
        with portalocker.Lock(lock_path, timeout=timeout):
            yield
|
||||
@@ -2353,3 +2353,68 @@ def test_agent_without_apps_no_platform_tools():
|
||||
|
||||
tools = crew._prepare_tools(agent, task, [])
|
||||
assert tools == []
|
||||
|
||||
|
||||
def test_agent_mcps_accepts_slug_with_specific_tool():
    """Agent(mcps=["notion#get_page"]) must pass validation (_SLUG_RE)."""
    agent = Agent(
        role="MCP Agent",
        goal="Test MCP validation",
        backstory="Test agent",
        mcps=["notion#get_page"],
    )
    assert agent.mcps == ["notion#get_page"]
|
||||
|
||||
|
||||
def test_agent_mcps_accepts_slug_with_hyphenated_tool():
    """A ``slug#tool`` reference whose tool name contains a hyphen is kept as given."""
    agent = Agent(
        role="MCP Agent",
        goal="Test MCP validation",
        backstory="Test agent",
        mcps=["notion#get-page"],
    )
    assert agent.mcps == ["notion#get-page"]
|
||||
|
||||
|
||||
def test_agent_mcps_accepts_multiple_hash_refs():
    """Several ``slug#tool`` references in one list are all accepted."""
    agent = Agent(
        role="MCP Agent",
        goal="Test MCP validation",
        backstory="Test agent",
        mcps=["notion#get_page", "notion#search", "github#list_repos"],
    )
    assert len(agent.mcps) == 3
|
||||
|
||||
|
||||
def test_agent_mcps_accepts_mixed_ref_types():
    """A ``slug#tool`` ref, a bare slug and an HTTPS URL can be mixed in one list."""
    agent = Agent(
        role="MCP Agent",
        goal="Test MCP validation",
        backstory="Test agent",
        mcps=[
            "notion#get_page",
            "notion",
            "https://mcp.example.com/api",
        ],
    )
    assert len(agent.mcps) == 3
|
||||
|
||||
|
||||
def test_agent_mcps_rejects_hash_without_slug():
    """A reference that starts with ``#`` (no slug) raises ``ValueError``."""
    with pytest.raises(ValueError, match="Invalid MCP reference"):
        Agent(
            role="MCP Agent",
            goal="Test MCP validation",
            backstory="Test agent",
            mcps=["#get_page"],
        )
|
||||
|
||||
|
||||
def test_agent_mcps_accepts_legacy_prefix_with_tool():
    """The ``crewai-amp:`` prefixed form with a ``#tool`` suffix is kept as given."""
    agent = Agent(
        role="MCP Agent",
        goal="Test MCP validation",
        backstory="Test agent",
        mcps=["crewai-amp:notion#get_page"],
    )
    assert agent.mcps == ["crewai-amp:notion#get_page"]
|
||||
|
||||
@@ -1136,7 +1136,7 @@ def test_lite_agent_memory_instance_recall_and_save_called():
|
||||
successful_requests=1,
|
||||
)
|
||||
mock_memory = Mock()
|
||||
mock_memory._read_only = False
|
||||
mock_memory.read_only = False
|
||||
mock_memory.recall.return_value = []
|
||||
mock_memory.extract_memories.return_value = ["Fact one.", "Fact two."]
|
||||
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"What is the weather
|
||||
in Tokyo?"}],"model":"claude-sonnet-4-5","stream":false,"tools":[{"type":"tool_search_tool_bm25_20251119","name":"tool_search_tool_bm25"},{"name":"get_weather","description":"Get
|
||||
current weather conditions for a specified location","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_weather"}},"required":["input"]},"defer_loading":true},{"name":"search_files","description":"Search
|
||||
through files in the workspace by name or content","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for search_files"}},"required":["input"]},"defer_loading":true},{"name":"read_database","description":"Read
|
||||
records from a database table with optional filtering","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_database"}},"required":["input"]},"defer_loading":true},{"name":"write_database","description":"Write
|
||||
or update records in a database table","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for write_database"}},"required":["input"]},"defer_loading":true},{"name":"send_email","description":"Send
|
||||
an email message to one or more recipients","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for send_email"}},"required":["input"]},"defer_loading":true},{"name":"read_email","description":"Read
|
||||
emails from inbox with filtering options","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_email"}},"required":["input"]},"defer_loading":true},{"name":"create_ticket","description":"Create
|
||||
a new support ticket in the ticketing system","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_ticket"}},"required":["input"]},"defer_loading":true},{"name":"update_ticket","description":"Update
|
||||
an existing support ticket status or description","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for update_ticket"}},"required":["input"]},"defer_loading":true},{"name":"list_users","description":"List
|
||||
all users in the system with optional filters","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for list_users"}},"required":["input"]},"defer_loading":true},{"name":"get_user_profile","description":"Get
|
||||
detailed profile information for a specific user","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_user_profile"}},"required":["input"]},"defer_loading":true},{"name":"deploy_service","description":"Deploy
|
||||
a service to the specified environment","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for deploy_service"}},"required":["input"]},"defer_loading":true},{"name":"rollback_service","description":"Rollback
|
||||
a service deployment to a previous version","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for rollback_service"}},"required":["input"]},"defer_loading":true},{"name":"get_service_logs","description":"Get
|
||||
service logs filtered by time range and severity","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_service_logs"}},"required":["input"]},"defer_loading":true},{"name":"run_sql_query","description":"Run
|
||||
a read-only SQL query against the analytics database","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for run_sql_query"}},"required":["input"]},"defer_loading":true},{"name":"create_dashboard","description":"Create
|
||||
a new monitoring dashboard with widgets","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_dashboard"}},"required":["input"]},"defer_loading":true}]}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
anthropic-version:
|
||||
- '2023-06-01'
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '3952'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.anthropic.com
|
||||
x-api-key:
|
||||
- X-API-KEY-XXX
|
||||
x-stainless-arch:
|
||||
- X-STAINLESS-ARCH-XXX
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 0.73.0
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
x-stainless-timeout:
|
||||
- NOT_GIVEN
|
||||
method: POST
|
||||
uri: https://api.anthropic.com/v1/messages
|
||||
response:
|
||||
body:
|
||||
string: '{"model":"claude-sonnet-4-5-20250929","id":"msg_01DAGCoL6C12u6yAgR1UqNAs","type":"message","role":"assistant","content":[{"type":"text","text":"I''ll
|
||||
search for a weather-related tool to help you get the weather information
|
||||
for Tokyo."},{"type":"server_tool_use","id":"srvtoolu_0176qgHeeBpSygYAnUzKHCfh","name":"tool_search_tool_bm25","input":{"query":"weather
|
||||
Tokyo current conditions forecast"},"caller":{"type":"direct"}},{"type":"tool_search_tool_result","tool_use_id":"srvtoolu_0176qgHeeBpSygYAnUzKHCfh","content":{"type":"tool_search_tool_search_result","tool_references":[{"type":"tool_reference","tool_name":"get_weather"}]}},{"type":"text","text":"Great!
|
||||
I found a weather tool. Let me get the current weather conditions for Tokyo."},{"type":"tool_use","id":"toolu_01R3FavQLuTrwNvEk9gMaViK","name":"get_weather","input":{"input":"Tokyo"},"caller":{"type":"direct"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":1566,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":155,"service_tier":"standard","inference_geo":"not_available","server_tool_use":{"web_search_requests":0,"web_fetch_requests":0}}}'
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Security-Policy:
|
||||
- CSP-FILTERED
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Sun, 08 Mar 2026 21:04:12 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Robots-Tag:
|
||||
- none
|
||||
anthropic-organization-id:
|
||||
- ANTHROPIC-ORGANIZATION-ID-XXX
|
||||
anthropic-ratelimit-input-tokens-limit:
|
||||
- ANTHROPIC-RATELIMIT-INPUT-TOKENS-LIMIT-XXX
|
||||
anthropic-ratelimit-input-tokens-remaining:
|
||||
- ANTHROPIC-RATELIMIT-INPUT-TOKENS-REMAINING-XXX
|
||||
anthropic-ratelimit-input-tokens-reset:
|
||||
- ANTHROPIC-RATELIMIT-INPUT-TOKENS-RESET-XXX
|
||||
anthropic-ratelimit-output-tokens-limit:
|
||||
- ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-LIMIT-XXX
|
||||
anthropic-ratelimit-output-tokens-remaining:
|
||||
- ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-REMAINING-XXX
|
||||
anthropic-ratelimit-output-tokens-reset:
|
||||
- ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-RESET-XXX
|
||||
anthropic-ratelimit-requests-limit:
|
||||
- '20000'
|
||||
anthropic-ratelimit-requests-remaining:
|
||||
- '19999'
|
||||
anthropic-ratelimit-requests-reset:
|
||||
- '2026-03-08T21:04:07Z'
|
||||
anthropic-ratelimit-tokens-limit:
|
||||
- ANTHROPIC-RATELIMIT-TOKENS-LIMIT-XXX
|
||||
anthropic-ratelimit-tokens-remaining:
|
||||
- ANTHROPIC-RATELIMIT-TOKENS-REMAINING-XXX
|
||||
anthropic-ratelimit-tokens-reset:
|
||||
- ANTHROPIC-RATELIMIT-TOKENS-RESET-XXX
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
request-id:
|
||||
- REQUEST-ID-XXX
|
||||
strict-transport-security:
|
||||
- STS-XXX
|
||||
vary:
|
||||
- Accept-Encoding
|
||||
x-envoy-upstream-service-time:
|
||||
- '4330'
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,112 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"What is the weather
|
||||
in Tokyo?"}],"model":"claude-sonnet-4-5","stream":false,"tools":[{"name":"get_weather","description":"Get
|
||||
current weather conditions for a specified location","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_weather"}},"required":["input"]}},{"name":"search_files","description":"Search
|
||||
through files in the workspace by name or content","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for search_files"}},"required":["input"]}},{"name":"read_database","description":"Read
|
||||
records from a database table with optional filtering","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_database"}},"required":["input"]}},{"name":"write_database","description":"Write
|
||||
or update records in a database table","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for write_database"}},"required":["input"]}},{"name":"send_email","description":"Send
|
||||
an email message to one or more recipients","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for send_email"}},"required":["input"]}},{"name":"read_email","description":"Read
|
||||
emails from inbox with filtering options","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_email"}},"required":["input"]}},{"name":"create_ticket","description":"Create
|
||||
a new support ticket in the ticketing system","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_ticket"}},"required":["input"]}},{"name":"update_ticket","description":"Update
|
||||
an existing support ticket status or description","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for update_ticket"}},"required":["input"]}},{"name":"list_users","description":"List
|
||||
all users in the system with optional filters","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for list_users"}},"required":["input"]}},{"name":"get_user_profile","description":"Get
|
||||
detailed profile information for a specific user","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_user_profile"}},"required":["input"]}},{"name":"deploy_service","description":"Deploy
|
||||
a service to the specified environment","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for deploy_service"}},"required":["input"]}},{"name":"rollback_service","description":"Rollback
|
||||
a service deployment to a previous version","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for rollback_service"}},"required":["input"]}},{"name":"get_service_logs","description":"Get
|
||||
service logs filtered by time range and severity","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_service_logs"}},"required":["input"]}},{"name":"run_sql_query","description":"Run
|
||||
a read-only SQL query against the analytics database","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for run_sql_query"}},"required":["input"]}},{"name":"create_dashboard","description":"Create
|
||||
a new monitoring dashboard with widgets","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_dashboard"}},"required":["input"]}}]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
anthropic-version:
|
||||
- '2023-06-01'
|
||||
connection:
|
||||
- keep-alive
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.anthropic.com
|
||||
method: POST
|
||||
uri: https://api.anthropic.com/v1/messages
|
||||
response:
|
||||
body:
|
||||
string: '{"model":"claude-sonnet-4-5-20250929","id":"msg_01NoSearch001","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_01NoSearch001","name":"get_weather","input":{"input":"Tokyo"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":1943,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":54,"service_tier":"standard"}}'
|
||||
headers:
|
||||
Content-Type:
|
||||
- application/json
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"What is the weather
|
||||
in Tokyo?"}],"model":"claude-sonnet-4-5","stream":false,"tools":[{"type":"tool_search_tool_bm25_20251119","name":"tool_search_tool_bm25"},{"name":"get_weather","description":"Get
|
||||
current weather conditions for a specified location","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_weather"}},"required":["input"]},"defer_loading":true},{"name":"search_files","description":"Search
|
||||
through files in the workspace by name or content","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for search_files"}},"required":["input"]},"defer_loading":true},{"name":"read_database","description":"Read
|
||||
records from a database table with optional filtering","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_database"}},"required":["input"]},"defer_loading":true},{"name":"write_database","description":"Write
|
||||
or update records in a database table","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for write_database"}},"required":["input"]},"defer_loading":true},{"name":"send_email","description":"Send
|
||||
an email message to one or more recipients","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for send_email"}},"required":["input"]},"defer_loading":true},{"name":"read_email","description":"Read
|
||||
emails from inbox with filtering options","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for read_email"}},"required":["input"]},"defer_loading":true},{"name":"create_ticket","description":"Create
|
||||
a new support ticket in the ticketing system","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_ticket"}},"required":["input"]},"defer_loading":true},{"name":"update_ticket","description":"Update
|
||||
an existing support ticket status or description","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for update_ticket"}},"required":["input"]},"defer_loading":true},{"name":"list_users","description":"List
|
||||
all users in the system with optional filters","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for list_users"}},"required":["input"]},"defer_loading":true},{"name":"get_user_profile","description":"Get
|
||||
detailed profile information for a specific user","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_user_profile"}},"required":["input"]},"defer_loading":true},{"name":"deploy_service","description":"Deploy
|
||||
a service to the specified environment","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for deploy_service"}},"required":["input"]},"defer_loading":true},{"name":"rollback_service","description":"Rollback
|
||||
a service deployment to a previous version","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for rollback_service"}},"required":["input"]},"defer_loading":true},{"name":"get_service_logs","description":"Get
|
||||
service logs filtered by time range and severity","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for get_service_logs"}},"required":["input"]},"defer_loading":true},{"name":"run_sql_query","description":"Run
|
||||
a read-only SQL query against the analytics database","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for run_sql_query"}},"required":["input"]},"defer_loading":true},{"name":"create_dashboard","description":"Create
|
||||
a new monitoring dashboard with widgets","input_schema":{"type":"object","properties":{"input":{"type":"string","description":"Input
|
||||
for create_dashboard"}},"required":["input"]},"defer_loading":true}]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
anthropic-version:
|
||||
- '2023-06-01'
|
||||
connection:
|
||||
- keep-alive
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.anthropic.com
|
||||
method: POST
|
||||
uri: https://api.anthropic.com/v1/messages
|
||||
response:
|
||||
body:
|
||||
string: '{"model":"claude-sonnet-4-5-20250929","id":"msg_01WithSearch001","type":"message","role":"assistant","content":[{"type":"text","text":"I''ll search for a weather tool."},{"type":"server_tool_use","id":"srvtoolu_01Search001","name":"tool_search_tool_bm25","input":{"query":"weather conditions"},"caller":{"type":"direct"}},{"type":"tool_search_tool_result","tool_use_id":"srvtoolu_01Search001","content":{"type":"tool_search_tool_search_result","tool_references":[{"type":"tool_reference","tool_name":"get_weather"}]}},{"type":"text","text":"Found it. Let me get the weather for Tokyo."},{"type":"tool_use","id":"toolu_01WithSearch001","name":"get_weather","input":{"input":"Tokyo"},"caller":{"type":"direct"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":1566,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":155,"service_tier":"standard"}}'
|
||||
headers:
|
||||
Content-Type:
|
||||
- application/json
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -1121,3 +1121,345 @@ def test_anthropic_cached_prompt_tokens_with_tools():
|
||||
assert usage.successful_requests == 2
|
||||
# The second call should have cached prompt tokens
|
||||
assert usage.cached_prompt_tokens > 0
|
||||
|
||||
|
||||
# ---- Tool Search Tool Tests ----
|
||||
|
||||
|
||||
def test_tool_search_true_injects_bm25_and_defer_loading():
|
||||
"""tool_search=True should inject bm25 tool search and defer all tools."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
|
||||
crewai_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculator",
|
||||
"description": "Perform math calculations",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"expression": {"type": "string"}},
|
||||
"required": ["expression"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_anthropic(
|
||||
[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
params = llm._prepare_completion_params(
|
||||
formatted_messages, system_message, crewai_tools
|
||||
)
|
||||
|
||||
tools = params["tools"]
|
||||
# Should have 3 tools: tool_search + 2 regular
|
||||
assert len(tools) == 3
|
||||
|
||||
# First tool should be the bm25 tool search tool
|
||||
assert tools[0]["type"] == "tool_search_tool_bm25_20251119"
|
||||
assert tools[0]["name"] == "tool_search_tool_bm25"
|
||||
assert "input_schema" not in tools[0]
|
||||
|
||||
# All regular tools should have defer_loading=True
|
||||
for t in tools[1:]:
|
||||
assert t.get("defer_loading") is True, f"Tool {t['name']} missing defer_loading"
|
||||
|
||||
|
||||
def test_tool_search_regex_config():
|
||||
"""tool_search with regex config should use regex variant."""
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicToolSearchConfig
|
||||
|
||||
config = AnthropicToolSearchConfig(type="regex")
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=config)
|
||||
|
||||
crewai_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_a",
|
||||
"description": "First tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_b",
|
||||
"description": "Second tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_anthropic(
|
||||
[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
params = llm._prepare_completion_params(
|
||||
formatted_messages, system_message, crewai_tools
|
||||
)
|
||||
|
||||
tools = params["tools"]
|
||||
assert tools[0]["type"] == "tool_search_tool_regex_20251119"
|
||||
assert tools[0]["name"] == "tool_search_tool_regex"
|
||||
|
||||
|
||||
def test_tool_search_disabled_by_default():
|
||||
"""tool_search=None (default) should NOT inject anything."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5")
|
||||
|
||||
crewai_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_anthropic(
|
||||
[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
params = llm._prepare_completion_params(
|
||||
formatted_messages, system_message, crewai_tools
|
||||
)
|
||||
|
||||
tools = params["tools"]
|
||||
assert len(tools) == 1
|
||||
for t in tools:
|
||||
assert t.get("type", "") not in (
|
||||
"tool_search_tool_bm25_20251119",
|
||||
"tool_search_tool_regex_20251119",
|
||||
)
|
||||
assert "defer_loading" not in t
|
||||
|
||||
|
||||
def test_tool_search_no_duplicate_when_manually_provided():
|
||||
"""If user passes a tool search tool manually, don't inject a duplicate."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
|
||||
# User manually includes a tool search tool
|
||||
tools_with_search = [
|
||||
{"type": "tool_search_tool_regex_20251119", "name": "tool_search_tool_regex"},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_anthropic(
|
||||
[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
params = llm._prepare_completion_params(
|
||||
formatted_messages, system_message, tools_with_search
|
||||
)
|
||||
|
||||
tools = params["tools"]
|
||||
search_tools = [
|
||||
t for t in tools
|
||||
if t.get("type", "").startswith("tool_search_tool")
|
||||
]
|
||||
# Should only have 1 tool search tool (the user's manual one)
|
||||
assert len(search_tools) == 1
|
||||
assert search_tools[0]["type"] == "tool_search_tool_regex_20251119"
|
||||
|
||||
|
||||
def test_tool_search_passthrough_preserves_tool_search_type():
|
||||
"""_convert_tools_for_interference should pass through tool search tools unchanged."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5")
|
||||
|
||||
tools = [
|
||||
{"type": "tool_search_tool_regex_20251119", "name": "tool_search_tool_regex"},
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
converted = llm._convert_tools_for_interference(tools)
|
||||
assert len(converted) == 2
|
||||
# Tool search tool should be passed through exactly
|
||||
assert converted[0] == {
|
||||
"type": "tool_search_tool_regex_20251119",
|
||||
"name": "tool_search_tool_regex",
|
||||
}
|
||||
# Regular tool should be preserved
|
||||
assert converted[1]["name"] == "get_weather"
|
||||
assert "input_schema" in converted[1]
|
||||
|
||||
|
||||
def test_tool_search_single_tool_skips_search_and_forces_choice():
|
||||
"""With only 1 tool, tool_search is skipped (nothing to search) and the
|
||||
normal forced tool_choice optimisation still applies."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
|
||||
crewai_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_anthropic(
|
||||
[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
params = llm._prepare_completion_params(
|
||||
formatted_messages,
|
||||
system_message,
|
||||
crewai_tools,
|
||||
available_functions={"test_tool": lambda q: "result"},
|
||||
)
|
||||
|
||||
# Single tool — tool_search skipped, tool_choice forced as normal
|
||||
assert "tool_choice" in params
|
||||
assert params["tool_choice"]["name"] == "test_tool"
|
||||
|
||||
# No tool search tool should be injected
|
||||
tool_types = [t.get("type", "") for t in params["tools"]]
|
||||
for ts_type in ("tool_search_tool_bm25_20251119", "tool_search_tool_regex_20251119"):
|
||||
assert ts_type not in tool_types
|
||||
|
||||
# No defer_loading on the single tool
|
||||
assert "defer_loading" not in params["tools"][0]
|
||||
|
||||
|
||||
def test_tool_search_via_llm_class():
|
||||
"""Verify tool_search param passes through LLM class correctly."""
|
||||
from crewai.llms.providers.anthropic.completion import (
|
||||
AnthropicCompletion,
|
||||
AnthropicToolSearchConfig,
|
||||
)
|
||||
|
||||
# Test with True
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
assert llm.tool_search is not None
|
||||
assert llm.tool_search.type == "bm25"
|
||||
|
||||
# Test with config
|
||||
llm2 = LLM(
|
||||
model="anthropic/claude-sonnet-4-5",
|
||||
tool_search=AnthropicToolSearchConfig(type="regex"),
|
||||
)
|
||||
assert llm2.tool_search is not None
|
||||
assert llm2.tool_search.type == "regex"
|
||||
|
||||
# Test without (default)
|
||||
llm3 = LLM(model="anthropic/claude-sonnet-4-5")
|
||||
assert llm3.tool_search is None
|
||||
|
||||
|
||||
# Many tools shared by the VCR tests below
|
||||
_MANY_TOOLS = [
|
||||
{
|
||||
"name": name,
|
||||
"description": desc,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"input": {"type": "string", "description": f"Input for {name}"}},
|
||||
"required": ["input"],
|
||||
},
|
||||
}
|
||||
for name, desc in [
|
||||
("get_weather", "Get current weather conditions for a specified location"),
|
||||
("search_files", "Search through files in the workspace by name or content"),
|
||||
("read_database", "Read records from a database table with optional filtering"),
|
||||
("write_database", "Write or update records in a database table"),
|
||||
("send_email", "Send an email message to one or more recipients"),
|
||||
("read_email", "Read emails from inbox with filtering options"),
|
||||
("create_ticket", "Create a new support ticket in the ticketing system"),
|
||||
("update_ticket", "Update an existing support ticket status or description"),
|
||||
("list_users", "List all users in the system with optional filters"),
|
||||
("get_user_profile", "Get detailed profile information for a specific user"),
|
||||
("deploy_service", "Deploy a service to the specified environment"),
|
||||
("rollback_service", "Rollback a service deployment to a previous version"),
|
||||
("get_service_logs", "Get service logs filtered by time range and severity"),
|
||||
("run_sql_query", "Run a read-only SQL query against the analytics database"),
|
||||
("create_dashboard", "Create a new monitoring dashboard with widgets"),
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_tool_search_discovers_and_calls_tool():
|
||||
"""Tool search should discover the right tool and return a tool_use block."""
|
||||
llm = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
|
||||
result = llm.call(
|
||||
"What is the weather in Tokyo?",
|
||||
tools=_MANY_TOOLS,
|
||||
)
|
||||
|
||||
# Should return tool_use blocks (list) since no available_functions provided
|
||||
assert isinstance(result, list)
|
||||
assert len(result) >= 1
|
||||
# The discovered tool should be get_weather
|
||||
tool_names = [getattr(block, "name", None) for block in result]
|
||||
assert "get_weather" in tool_names
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_tool_search_saves_input_tokens():
|
||||
"""Tool search with deferred loading should use fewer input tokens than loading all tools."""
|
||||
# Call WITHOUT tool search — all 15 tools loaded upfront
|
||||
llm_no_search = LLM(model="anthropic/claude-sonnet-4-5")
|
||||
llm_no_search.call("What is the weather in Tokyo?", tools=_MANY_TOOLS)
|
||||
usage_no_search = llm_no_search.get_token_usage_summary()
|
||||
|
||||
# Call WITH tool search — tools deferred
|
||||
llm_search = LLM(model="anthropic/claude-sonnet-4-5", tool_search=True)
|
||||
llm_search.call("What is the weather in Tokyo?", tools=_MANY_TOOLS)
|
||||
usage_search = llm_search.get_token_usage_summary()
|
||||
|
||||
# Tool search should use fewer input tokens
|
||||
assert usage_search.prompt_tokens < usage_no_search.prompt_tokens, (
|
||||
f"Expected tool_search ({usage_search.prompt_tokens}) to use fewer input tokens "
|
||||
f"than no search ({usage_no_search.prompt_tokens})"
|
||||
)
|
||||
|
||||
@@ -967,3 +967,211 @@ def test_bedrock_agent_kickoff_structured_output_with_tools():
|
||||
assert result.pydantic.result == 42, f"Expected result 42 but got {result.pydantic.result}"
|
||||
assert result.pydantic.operation, "Operation should not be empty"
|
||||
assert result.pydantic.explanation, "Explanation should not be empty"
|
||||
|
||||
|
||||
def test_bedrock_groups_three_tool_results():
    """Consecutive tool results should be grouped into one Bedrock user message."""
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    def _tool_call(call_id, fn_name, raw_arguments):
        # One entry of the assistant message's "tool_calls" list.
        return {
            "id": call_id,
            "type": "function",
            "function": {"name": fn_name, "arguments": raw_arguments},
        }

    def _tool_reply(call_id, text):
        # One "tool" role message answering the call with the given id.
        return {"role": "tool", "tool_call_id": call_id, "content": text}

    conversation = [
        {"role": "user", "content": "Use all three tools, then continue."},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                _tool_call("tool-1", "lookup_weather", '{"location": "New York"}'),
                _tool_call("tool-2", "lookup_news", '{"topic": "AI"}'),
                _tool_call("tool-3", "lookup_stock", '{"ticker": "AMZN"}'),
            ],
        },
        _tool_reply("tool-1", "72F and sunny"),
        _tool_reply("tool-2", "AI news summary"),
        _tool_reply("tool-3", "AMZN up 1.2%"),
    ]

    converse_messages, system_prompt = llm._format_messages_for_converse(conversation)

    assert system_prompt is None

    # user prompt -> assistant tool calls -> ONE user message with all results
    roles = [entry["role"] for entry in converse_messages]
    assert roles == ["user", "assistant", "user"]
    assert len(converse_messages[1]["content"]) == 3

    result_blocks = converse_messages[2]["content"]
    assert len(result_blocks) == 3

    # Order of the results must follow the order of the tool messages.
    result_ids = [part["toolResult"]["toolUseId"] for part in result_blocks]
    assert result_ids == ["tool-1", "tool-2", "tool-3"]

    result_texts = [part["toolResult"]["content"][0]["text"] for part in result_blocks]
    assert result_texts == ["72F and sunny", "AI news summary", "AMZN up 1.2%"]
|
||||
|
||||
|
||||
def test_bedrock_parallel_tool_results_grouped():
    """Regression test for issue #4749.

    Bedrock requires that the results of all parallel tool calls made by one
    assistant message arrive together in a single user message. Emitting one
    user message per tool result used to fail with:
        ValidationException: Expected toolResult blocks at messages.2.content
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    add_call = {
        "id": "call_add",
        "type": "function",
        "function": {"name": "add_tool", "arguments": '{"a": 25, "b": 17}'},
    }
    mul_call = {
        "id": "call_mul",
        "type": "function",
        "function": {"name": "multiply_tool", "arguments": '{"a": 10, "b": 5}'},
    }
    conversation = [
        {"role": "user", "content": "Calculate 25 + 17 AND 10 * 5"},
        {"role": "assistant", "content": "", "tool_calls": [add_call, mul_call]},
        {"role": "tool", "tool_call_id": "call_add", "content": "42"},
        {"role": "tool", "tool_call_id": "call_mul", "content": "50"},
    ]

    converse_msgs, _ = llm._format_messages_for_converse(conversation)

    # Collect every user message that carries at least one toolResult block.
    grouped = []
    for msg in converse_msgs:
        if msg.get("role") != "user":
            continue
        if any("toolResult" in part for part in msg.get("content", [])):
            grouped.append(msg)

    # Exactly one such message is allowed (not one per tool result).
    assert len(grouped) == 1, (
        f"Expected 1 grouped tool-result message, got {len(grouped)}. "
        "Bedrock requires all parallel tool results in a single user message."
    )

    # ...and it has to hold both results.
    result_blocks = grouped[0]["content"]
    assert len(result_blocks) == 2, (
        f"Expected 2 toolResult blocks in grouped message, got {len(result_blocks)}"
    )

    # The ids must be the ones issued by the assistant (order-insensitive).
    seen_ids = {part["toolResult"]["toolUseId"] for part in result_blocks}
    assert seen_ids == {"call_add", "call_mul"}
|
||||
|
||||
|
||||
def test_bedrock_single_tool_result_still_works():
    """Ensure single tool call still produces a single-block user message."""
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    only_call = {
        "id": "call_single",
        "type": "function",
        "function": {"name": "add_tool", "arguments": '{"a": 1, "b": 2}'},
    }
    conversation = [
        {"role": "user", "content": "Add 1 + 2"},
        {"role": "assistant", "content": "", "tool_calls": [only_call]},
        {"role": "tool", "tool_call_id": "call_single", "content": "3"},
    ]

    converse_msgs, _ = llm._format_messages_for_converse(conversation)

    # User messages that contain at least one toolResult block.
    with_results = []
    for msg in converse_msgs:
        if msg.get("role") != "user":
            continue
        if any("toolResult" in part for part in msg.get("content", [])):
            with_results.append(msg)

    assert len(with_results) == 1
    sole_message = with_results[0]
    assert len(sole_message["content"]) == 1
    assert sole_message["content"][0]["toolResult"]["toolUseId"] == "call_single"
|
||||
|
||||
|
||||
def test_bedrock_tool_results_not_merged_across_assistant_messages():
    """Tool results from different assistant turns must NOT be merged."""
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    def _assistant_turn(call_id, fn_name):
        # Assistant message issuing a single argument-less tool call.
        return {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": call_id,
                    "type": "function",
                    "function": {"name": fn_name, "arguments": "{}"},
                },
            ],
        }

    conversation = [
        {"role": "user", "content": "First task"},
        _assistant_turn("call_a", "tool_a"),
        {"role": "tool", "tool_call_id": "call_a", "content": "result_a"},
        {"role": "assistant", "content": "Now doing second task"},
        {"role": "user", "content": "Second task"},
        _assistant_turn("call_b", "tool_b"),
        {"role": "tool", "tool_call_id": "call_b", "content": "result_b"},
    ]

    converse_msgs, _ = llm._format_messages_for_converse(conversation)

    # User messages that contain at least one toolResult block.
    with_results = []
    for msg in converse_msgs:
        if msg.get("role") != "user":
            continue
        if any("toolResult" in part for part in msg.get("content", [])):
            with_results.append(msg)

    # One tool-result message per assistant turn, in conversation order.
    assert len(with_results) == 2, (
        "Tool results from different assistant turns must remain separate"
    )
    first_turn, second_turn = with_results
    assert first_turn["content"][0]["toolResult"]["toolUseId"] == "call_a"
    assert second_turn["content"][0]["toolResult"]["toolUseId"] == "call_b"
|
||||
|
||||
@@ -268,6 +268,54 @@ class TestGetMCPToolsAmpIntegration:
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "mcp_notion_so_sse_search"
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
def test_tool_filter_with_hyphenated_hash_syntax(
    self, mock_fetch, mock_client_class, agent
):
    """notion#get-page must match the tool whose sanitized name is get_page."""
    mock_fetch.return_value = {
        "notion": {
            "type": "sse",
            "url": "https://mcp.notion.so/sse",
            "headers": {"Authorization": "Bearer token"},
        },
    }

    # The server advertises "get-page"; its sanitized name is "get_page".
    get_page_definition = {
        "name": "get_page",
        "original_name": "get-page",
        "description": "Get a page",
        "inputSchema": {},
    }
    search_definition = {
        "name": "search",
        "original_name": "search",
        "description": "Search tool",
        "inputSchema": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query"}
            },
            "required": ["query"],
        },
    }

    fake_client = AsyncMock()
    fake_client.configure_mock(
        list_tools=AsyncMock(return_value=[get_page_definition, search_definition]),
        connected=False,
        connect=AsyncMock(),
        disconnect=AsyncMock(),
    )
    mock_client_class.return_value = fake_client

    resolved = agent.get_mcp_tools(["notion#get-page"])

    # Only the slug (text before '#') is used to fetch the config.
    mock_fetch.assert_called_once_with(["notion"])
    # The '#get-page' filter keeps exactly the get_page tool.
    assert len(resolved) == 1
    assert resolved[0].name.endswith("_get_page")
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_deduplicates_slugs(
|
||||
@@ -371,3 +419,87 @@ class TestGetMCPToolsAmpIntegration:
|
||||
mock_external.assert_called_once_with("https://external.mcp.com/api")
|
||||
# 2 from notion + 1 from external + 2 from http_config
|
||||
assert len(tools) == 5
|
||||
|
||||
|
||||
class TestResolveExternalToolFilter:
    """Tests for _resolve_external with #tool-name filtering."""

    @staticmethod
    def _schema(description):
        """Build a fresh schema entry, as returned per tool by the mocked lookup."""
        return {"description": description, "args_schema": None}

    @pytest.fixture
    def agent(self):
        return Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
        )

    @pytest.fixture
    def resolver(self, agent):
        return MCPToolResolver(agent=agent, logger=agent._logger)

    @patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
    def test_filters_hyphenated_tool_name(self, mock_schemas, resolver):
        """https://...#get-page must match the sanitized key get_page in schemas."""
        mock_schemas.return_value = {
            "get_page": self._schema("Get a page"),
            "search": self._schema("Search tool"),
        }

        resolved = resolver._resolve_external("https://mcp.example.com/api#get-page")

        assert len(resolved) == 1
        assert "get_page" in resolved[0].name

    @patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
    def test_filters_underscored_tool_name(self, mock_schemas, resolver):
        """https://...#get_page must also match the sanitized key get_page."""
        mock_schemas.return_value = {
            "get_page": self._schema("Get a page"),
            "search": self._schema("Search tool"),
        }

        resolved = resolver._resolve_external("https://mcp.example.com/api#get_page")

        assert len(resolved) == 1
        assert "get_page" in resolved[0].name

    @patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
    def test_returns_all_tools_without_hash(self, mock_schemas, resolver):
        """A URL without a '#' fragment yields every tool from the schemas."""
        mock_schemas.return_value = {
            "get_page": self._schema("Get a page"),
            "search": self._schema("Search tool"),
        }

        resolved = resolver._resolve_external("https://mcp.example.com/api")

        assert len(resolved) == 2

    @patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
    def test_returns_empty_for_nonexistent_tool(self, mock_schemas, resolver):
        """A '#' fragment naming no known tool yields an empty result."""
        mock_schemas.return_value = {
            "search": self._schema("Search tool"),
        }

        resolved = resolver._resolve_external("https://mcp.example.com/api#nonexistent")

        assert len(resolved) == 0
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -30,6 +31,17 @@ def mock_tool_definitions():
|
||||
]
|
||||
|
||||
|
||||
def _make_mock_client(tool_definitions):
    """Create a mock MCPClient that returns *tool_definitions*.

    The mock starts out disconnected, exposes awaitable ``connect`` /
    ``disconnect`` / ``list_tools`` / ``call_tool`` attributes, and every
    ``call_tool`` await resolves to ``"test result"``.
    """
    fake = AsyncMock()
    fake.configure_mock(
        list_tools=AsyncMock(return_value=tool_definitions),
        connected=False,
        connect=AsyncMock(),
        disconnect=AsyncMock(),
        call_tool=AsyncMock(return_value="test result"),
    )
    return fake
|
||||
|
||||
|
||||
def test_agent_with_stdio_mcp_config(mock_tool_definitions):
|
||||
"""Test agent setup with MCPServerStdio configuration."""
|
||||
stdio_config = MCPServerStdio(
|
||||
@@ -45,14 +57,8 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions):
|
||||
mcps=[stdio_config],
|
||||
)
|
||||
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False # Will trigger connect
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
||||
|
||||
tools = agent.get_mcp_tools([stdio_config])
|
||||
|
||||
@@ -60,8 +66,7 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions):
|
||||
assert all(isinstance(tool, BaseTool) for tool in tools)
|
||||
|
||||
mock_client_class.assert_called_once()
|
||||
call_args = mock_client_class.call_args
|
||||
transport = call_args.kwargs["transport"]
|
||||
transport = mock_client_class.call_args.kwargs["transport"]
|
||||
assert transport.command == "python"
|
||||
assert transport.args == ["server.py"]
|
||||
assert transport.env == {"API_KEY": "test_key"}
|
||||
@@ -83,12 +88,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions):
|
||||
)
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False # Will trigger connect
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
||||
|
||||
tools = agent.get_mcp_tools([http_config])
|
||||
|
||||
@@ -96,8 +96,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions):
|
||||
assert all(isinstance(tool, BaseTool) for tool in tools)
|
||||
|
||||
mock_client_class.assert_called_once()
|
||||
call_args = mock_client_class.call_args
|
||||
transport = call_args.kwargs["transport"]
|
||||
transport = mock_client_class.call_args.kwargs["transport"]
|
||||
assert transport.url == "https://api.example.com/mcp"
|
||||
assert transport.headers == {"Authorization": "Bearer test_token"}
|
||||
assert transport.streamable is True
|
||||
@@ -118,12 +117,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions):
|
||||
)
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
||||
|
||||
tools = agent.get_mcp_tools([sse_config])
|
||||
|
||||
@@ -131,8 +125,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions):
|
||||
assert all(isinstance(tool, BaseTool) for tool in tools)
|
||||
|
||||
mock_client_class.assert_called_once()
|
||||
call_args = mock_client_class.call_args
|
||||
transport = call_args.kwargs["transport"]
|
||||
transport = mock_client_class.call_args.kwargs["transport"]
|
||||
assert transport.url == "https://api.example.com/mcp/sse"
|
||||
assert transport.headers == {"Authorization": "Bearer test_token"}
|
||||
|
||||
@@ -142,13 +135,7 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
|
||||
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value="test result")
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
@@ -160,12 +147,12 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
|
||||
tools = agent.get_mcp_tools([http_config])
|
||||
assert len(tools) == 2
|
||||
|
||||
|
||||
tool = tools[0]
|
||||
result = tool.run(query="test query")
|
||||
|
||||
assert result == "test result"
|
||||
mock_client.call_tool.assert_called()
|
||||
# 1 discovery + 1 for the run() invocation
|
||||
assert mock_client_class.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -174,13 +161,7 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
|
||||
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
||||
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value="test result")
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
@@ -192,9 +173,129 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
|
||||
tools = agent.get_mcp_tools([http_config])
|
||||
assert len(tools) == 2
|
||||
|
||||
|
||||
tool = tools[0]
|
||||
result = tool.run(query="test query")
|
||||
|
||||
assert result == "test result"
|
||||
mock_client.call_tool.assert_called()
|
||||
assert mock_client_class.call_count == 2
|
||||
|
||||
|
||||
def test_each_invocation_gets_fresh_client(mock_tool_definitions):
    """Every tool.run() must create its own MCPClient (no shared state)."""
    http_config = MCPServerHTTP(url="https://api.example.com/mcp")

    # Every client the patched MCPClient constructor hands out, in creation order.
    created: list = []

    def _client_factory(**_kwargs):
        fresh = _make_mock_client(mock_tool_definitions)
        created.append(fresh)
        return fresh

    with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_client_factory):
        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            mcps=[http_config],
        )

        tools = agent.get_mcp_tools([http_config])
        assert len(tools) == 2
        # Discovery alone accounts for exactly one client.
        assert len(created) == 1

        # Running the same tool twice in a row must add two more clients,
        # and those two must be distinct objects.
        first_tool = tools[0]
        first_tool.run(query="q1")
        first_tool.run(query="q2")
        assert len(created) == 3
        assert created[1] is not created[2]
|
||||
|
||||
|
||||
def test_parallel_mcp_tool_execution_same_tool(mock_tool_definitions):
    """Parallel calls to the *same* tool must not interfere."""
    http_config = MCPServerHTTP(url="https://api.example.com/mcp")

    # Names passed to call_tool, appended by whichever worker gets there first.
    invoked: list[str] = []

    async def _slow_call_tool(name, args):
        invoked.append(name)
        # Short pause so the two invocations are in flight at the same time.
        await asyncio.sleep(0.05)
        return f"result-{name}"

    def _client_factory(**_kwargs):
        fresh = AsyncMock()
        fresh.configure_mock(
            list_tools=AsyncMock(return_value=mock_tool_definitions),
            connected=False,
            connect=AsyncMock(),
            disconnect=AsyncMock(),
            call_tool=AsyncMock(side_effect=_slow_call_tool),
        )
        return fresh

    with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_client_factory):
        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            mcps=[http_config],
        )

        tools = agent.get_mcp_tools([http_config])
        assert len(tools) >= 1
        shared_tool = tools[0]

        # Call the SAME tool concurrently -- the exact scenario from the bug
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            pending = [pool.submit(shared_tool.run, query=q) for q in ("q1", "q2")]
            results = [
                done.result() for done in concurrent.futures.as_completed(pending)
            ]

        assert len(results) == 2
        assert all("result-" in outcome for outcome in results)
        assert len(invoked) == 2
|
||||
|
||||
|
||||
def test_parallel_mcp_tool_execution_different_tools(mock_tool_definitions):
    """Parallel calls to different tools from the same server must not interfere."""
    http_config = MCPServerHTTP(url="https://api.example.com/mcp")

    # Names passed to call_tool, appended by whichever worker gets there first.
    invoked: list[str] = []

    async def _slow_call_tool(name, args):
        invoked.append(name)
        # Short pause so the two invocations are in flight at the same time.
        await asyncio.sleep(0.05)
        return f"result-{name}"

    def _client_factory(**_kwargs):
        fresh = AsyncMock()
        fresh.configure_mock(
            list_tools=AsyncMock(return_value=mock_tool_definitions),
            connected=False,
            connect=AsyncMock(),
            disconnect=AsyncMock(),
            call_tool=AsyncMock(side_effect=_slow_call_tool),
        )
        return fresh

    with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_client_factory):
        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            mcps=[http_config],
        )

        tools = agent.get_mcp_tools([http_config])
        assert len(tools) == 2

        # One worker per tool: tools[0] gets "q1", tools[1] gets "q2".
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            pending = [
                pool.submit(tool.run, query=q)
                for tool, q in zip(tools, ("q1", "q2"))
            ]
            results = [
                done.result() for done in concurrent.futures.as_completed(pending)
            ]

        assert len(results) == 2
        assert all("result-" in outcome for outcome in results)
        assert len(invoked) == 2
|
||||
|
||||
13
lib/crewai/tests/memory/test_concurrent_storage.py
Normal file
13
lib/crewai/tests/memory/test_concurrent_storage.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Stress tests for concurrent multi-process storage access.

Simulates the Airflow pattern: N worker processes each writing to the
same storage directory simultaneously. Verifies no LockException and
data integrity after all writes complete.

Uses temp files for IPC instead of multiprocessing.Manager (which uses
sockets blocked by pytest_recording).
"""

import pytest

# Module-level mark: pytest applies it to every test collected from this file.
pytestmark = pytest.mark.skip(
    reason="Multiprocessing tests incompatible with xdist --import-mode=importlib"
)
|
||||
@@ -172,8 +172,8 @@ def test_memory_scope_slice(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
sc = mem.scope("/agent/1")
|
||||
assert sc._root in ("/agent/1", "/agent/1/")
|
||||
sl = mem.slice(["/a", "/b"], read_only=True)
|
||||
assert sl._read_only is True
|
||||
assert "/a" in sl._scopes and "/b" in sl._scopes
|
||||
assert sl.read_only is True
|
||||
assert "/a" in sl.scopes and "/b" in sl.scopes
|
||||
|
||||
|
||||
def test_memory_list_scopes_info_tree(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
@@ -198,7 +198,7 @@ def test_memory_scope_remember_recall(tmp_path: Path, mock_embedder: MagicMock)
|
||||
from crewai.memory.memory_scope import MemoryScope
|
||||
|
||||
mem = Memory(storage=str(tmp_path / "db5"), llm=MagicMock(), embedder=mock_embedder)
|
||||
scope = MemoryScope(mem, "/crew/1")
|
||||
scope = MemoryScope(memory=mem, root_path="/crew/1")
|
||||
scope.remember("Scoped note", scope="/", categories=[], importance=0.5, metadata={})
|
||||
results = scope.recall("note", limit=5, depth="shallow")
|
||||
assert len(results) >= 1
|
||||
@@ -213,7 +213,7 @@ def test_memory_slice_recall(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
|
||||
mem = Memory(storage=str(tmp_path / "db6"), llm=MagicMock(), embedder=mock_embedder)
|
||||
mem.remember("In scope A", scope="/a", categories=[], importance=0.5, metadata={})
|
||||
sl = MemorySlice(mem, ["/a"], read_only=True)
|
||||
sl = MemorySlice(memory=mem, scopes=["/a"], read_only=True)
|
||||
matches = sl.recall("scope", limit=5, depth="shallow")
|
||||
assert isinstance(matches, list)
|
||||
|
||||
@@ -223,7 +223,7 @@ def test_memory_slice_remember_is_noop_when_read_only(tmp_path: Path, mock_embed
|
||||
from crewai.memory.memory_scope import MemorySlice
|
||||
|
||||
mem = Memory(storage=str(tmp_path / "db7"), llm=MagicMock(), embedder=mock_embedder)
|
||||
sl = MemorySlice(mem, ["/a"], read_only=True)
|
||||
sl = MemorySlice(memory=mem, scopes=["/a"], read_only=True)
|
||||
result = sl.remember("x", scope="/a")
|
||||
assert result is None
|
||||
assert mem.list_records() == []
|
||||
@@ -319,7 +319,7 @@ def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:
|
||||
from crewai.agents.parser import AgentFinish
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory._read_only = False
|
||||
mock_memory.read_only = False
|
||||
mock_memory.extract_memories.return_value = ["Fact A.", "Fact B."]
|
||||
|
||||
mock_agent = MagicMock()
|
||||
@@ -360,7 +360,7 @@ def test_executor_save_to_memory_skips_delegation_output() -> None:
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory._read_only = False
|
||||
mock_memory.read_only = False
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.memory = mock_memory
|
||||
mock_agent._logger = MagicMock()
|
||||
@@ -393,7 +393,7 @@ def test_memory_scope_extract_memories_delegates() -> None:
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory.extract_memories.return_value = ["Scoped fact."]
|
||||
scope = MemoryScope(mock_memory, "/agent/1")
|
||||
scope = MemoryScope(memory=mock_memory, root_path="/agent/1")
|
||||
result = scope.extract_memories("Some content")
|
||||
mock_memory.extract_memories.assert_called_once_with("Some content")
|
||||
assert result == ["Scoped fact."]
|
||||
@@ -405,7 +405,7 @@ def test_memory_slice_extract_memories_delegates() -> None:
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory.extract_memories.return_value = ["Sliced fact."]
|
||||
sl = MemorySlice(mock_memory, ["/a", "/b"], read_only=True)
|
||||
sl = MemorySlice(memory=mock_memory, scopes=["/a", "/b"], read_only=True)
|
||||
result = sl.extract_memories("Some content")
|
||||
mock_memory.extract_memories.assert_called_once_with("Some content")
|
||||
assert result == ["Sliced fact."]
|
||||
@@ -670,10 +670,10 @@ def test_agent_kickoff_memory_recall_and_save(tmp_path: Path, mock_embedder: Mag
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Mock recall to verify it's called, but return real results
|
||||
with patch.object(mem, "recall", wraps=mem.recall) as recall_mock, \
|
||||
patch.object(mem, "extract_memories", return_value=["PostgreSQL is used."]) as extract_mock, \
|
||||
patch.object(mem, "remember_many", wraps=mem.remember_many) as remember_many_mock:
|
||||
# Patch on the class to avoid Pydantic BaseModel __delattr__ restriction
|
||||
with patch.object(Memory, "recall", wraps=mem.recall) as recall_mock, \
|
||||
patch.object(Memory, "extract_memories", return_value=["PostgreSQL is used."]) as extract_mock, \
|
||||
patch.object(Memory, "remember_many", wraps=mem.remember_many) as remember_many_mock:
|
||||
result = agent.kickoff("What database do we use?")
|
||||
|
||||
assert result is not None
|
||||
|
||||
@@ -971,6 +971,128 @@ class TestCollapseToOutcomeJsonParsing:
|
||||
assert mock_llm.call.call_count == 2
|
||||
|
||||
|
||||
class TestLLMObjectPreservedInContext:
|
||||
"""Tests that BaseLLM objects have their model string preserved in PendingFeedbackContext."""
|
||||
|
||||
@patch("crewai.flow.flow.crewai_event_bus.emit")
|
||||
def test_basellm_object_model_string_survives_roundtrip(self, mock_emit: MagicMock) -> None:
|
||||
"""Test that when llm is a BaseLLM object, its model string is stored in context
|
||||
so that outcome collapsing works after async pause/resume.
|
||||
|
||||
This is the exact bug: locally the sync path keeps the LLM object in memory,
|
||||
but in production the async path serializes the context and the LLM object was
|
||||
discarded (stored as None), causing resume to skip classification and always
|
||||
fall back to emit[0].
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
# Create a mock BaseLLM object (not a string)
|
||||
mock_llm_obj = MagicMock()
|
||||
mock_llm_obj.model = "gemini/gemini-2.0-flash"
|
||||
|
||||
class PausingProvider:
|
||||
def __init__(self, persistence: SQLiteFlowPersistence):
|
||||
self.persistence = persistence
|
||||
self.captured_context: PendingFeedbackContext | None = None
|
||||
|
||||
def request_feedback(
|
||||
self, context: PendingFeedbackContext, flow: Flow
|
||||
) -> str:
|
||||
self.captured_context = context
|
||||
self.persistence.save_pending_feedback(
|
||||
flow_uuid=context.flow_id,
|
||||
context=context,
|
||||
state_data=flow.state if isinstance(flow.state, dict) else flow.state.model_dump(),
|
||||
)
|
||||
raise HumanFeedbackPending(context=context)
|
||||
|
||||
provider = PausingProvider(persistence)
|
||||
|
||||
class TestFlow(Flow):
|
||||
result_path: str = ""
|
||||
|
||||
@start()
|
||||
@human_feedback(
|
||||
message="Approve?",
|
||||
emit=["needs_changes", "approved"],
|
||||
llm=mock_llm_obj,
|
||||
default_outcome="approved",
|
||||
provider=provider,
|
||||
)
|
||||
def review(self):
|
||||
return "content for review"
|
||||
|
||||
@listen("approved")
|
||||
def handle_approved(self):
|
||||
self.result_path = "approved"
|
||||
return "Approved!"
|
||||
|
||||
@listen("needs_changes")
|
||||
def handle_changes(self):
|
||||
self.result_path = "needs_changes"
|
||||
return "Changes needed"
|
||||
|
||||
# Phase 1: Start flow (should pause)
|
||||
flow1 = TestFlow(persistence=persistence)
|
||||
result = flow1.kickoff()
|
||||
assert isinstance(result, HumanFeedbackPending)
|
||||
|
||||
# Verify the context stored the model STRING, not None
|
||||
assert provider.captured_context is not None
|
||||
assert provider.captured_context.llm == "gemini/gemini-2.0-flash"
|
||||
|
||||
# Verify it survives persistence roundtrip
|
||||
flow_id = result.context.flow_id
|
||||
loaded = persistence.load_pending_feedback(flow_id)
|
||||
assert loaded is not None
|
||||
_, loaded_context = loaded
|
||||
assert loaded_context.llm == "gemini/gemini-2.0-flash"
|
||||
|
||||
# Phase 2: Resume with positive feedback - should use LLM to classify
|
||||
flow2 = TestFlow.from_pending(flow_id, persistence)
|
||||
assert flow2._pending_feedback_context is not None
|
||||
assert flow2._pending_feedback_context.llm == "gemini/gemini-2.0-flash"
|
||||
|
||||
# Mock _collapse_to_outcome to verify it gets called (not skipped)
|
||||
with patch.object(flow2, "_collapse_to_outcome", return_value="approved") as mock_collapse:
|
||||
flow2.resume("this looks good, proceed!")
|
||||
|
||||
# The key assertion: _collapse_to_outcome was called (not skipped due to llm=None)
|
||||
mock_collapse.assert_called_once_with(
|
||||
feedback="this looks good, proceed!",
|
||||
outcomes=["needs_changes", "approved"],
|
||||
llm="gemini/gemini-2.0-flash",
|
||||
)
|
||||
assert flow2.last_human_feedback.outcome == "approved"
|
||||
assert flow2.result_path == "approved"
|
||||
|
||||
def test_string_llm_still_works(self) -> None:
    """A plain-string ``llm`` survives a ``to_dict``/``from_dict`` round trip."""
    model_name = "gpt-4o-mini"
    original = PendingFeedbackContext(
        flow_id="str-llm-test",
        flow_class="test.Flow",
        method_name="review",
        method_output="output",
        message="Review:",
        emit=["approved", "rejected"],
        llm=model_name,
    )

    # Serialize and immediately rebuild; the llm field must come back unchanged.
    round_tripped = PendingFeedbackContext.from_dict(original.to_dict())

    assert round_tripped.llm == model_name
|
||||
|
||||
def test_none_llm_when_no_model_attr(self) -> None:
|
||||
"""Test that llm is None when object has no model attribute."""
|
||||
mock_obj = MagicMock(spec=[]) # No attributes
|
||||
|
||||
# Simulate what the decorator does
|
||||
llm_value = mock_obj if isinstance(mock_obj, str) else getattr(mock_obj, "model", None)
|
||||
assert llm_value is None
|
||||
|
||||
|
||||
class TestAsyncHumanFeedbackEdgeCases:
|
||||
"""Edge case tests for async human feedback."""
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ from crewai.flow import Flow, start
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.process import Process
|
||||
from crewai.project import CrewBase, agent, before_kickoff, crew, task
|
||||
from crewai.task import Task
|
||||
@@ -2618,9 +2618,9 @@ def test_memory_remember_called_after_task():
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
crew._memory, "extract_memories", wraps=crew._memory.extract_memories
|
||||
Memory, "extract_memories", wraps=crew._memory.extract_memories
|
||||
) as extract_mock, patch.object(
|
||||
crew._memory, "remember", wraps=crew._memory.remember
|
||||
Memory, "remember", wraps=crew._memory.remember
|
||||
) as remember_mock:
|
||||
crew.kickoff()
|
||||
|
||||
@@ -4773,13 +4773,13 @@ def test_memory_remember_receives_task_content():
|
||||
# Mock extract_memories to return fake memories and capture the raw input.
|
||||
# No wraps= needed -- the test only checks what args it receives, not the output.
|
||||
patch.object(
|
||||
crew._memory, "extract_memories", return_value=["Fake memory."]
|
||||
Memory, "extract_memories", return_value=["Fake memory."]
|
||||
) as extract_mock,
|
||||
# Mock recall to avoid LLM calls for query analysis (not in cassette).
|
||||
patch.object(crew._memory, "recall", return_value=[]),
|
||||
patch.object(Memory, "recall", return_value=[]),
|
||||
# Mock remember_many to prevent the background save from triggering
|
||||
# LLM calls (field resolution) that aren't in the cassette.
|
||||
patch.object(crew._memory, "remember_many", return_value=[]),
|
||||
patch.object(Memory, "remember_many", return_value=[]),
|
||||
):
|
||||
crew.kickoff()
|
||||
|
||||
|
||||
@@ -1893,3 +1893,163 @@ def test_or_condition_self_listen_fires_once():
|
||||
flow = OrSelfListenFlow()
|
||||
flow.kickoff()
|
||||
assert call_count == 1
|
||||
|
||||
# State model with a single list field; the list-proxy tests in this module
# read and mutate it through ``flow.state.items``.
class ListState(BaseModel):
    # Starts empty; _ListFlow.populate assigns the sample values.
    items: list = []
|
||||
|
||||
|
||||
# State model with a single dict field; the dict-proxy tests in this module
# read and mutate it through ``flow.state.data``.
class DictState(BaseModel):
    # Starts empty; _DictFlow.populate assigns the sample mapping.
    data: dict = {}
|
||||
|
||||
|
||||
class _ListFlow(Flow[ListState]):
    """Minimal flow whose only step fills ``state.items`` with sample integers."""

    @start()
    def populate(self):
        # Unsorted values containing a duplicate (1) so that index/count/sort
        # assertions in the tests are meaningful.
        sample = [3, 1, 4, 1, 5, 9, 2, 6]
        self.state.items = sample
|
||||
|
||||
|
||||
class _DictFlow(Flow[DictState]):
    """Minimal flow whose only step fills ``state.data`` with a sample mapping."""

    @start()
    def populate(self):
        # Insertion order a, b, c matters for the reversed() test.
        sample = dict(a=1, b=2, c=3)
        self.state.data = sample
|
||||
|
||||
|
||||
def _make_list_flow():
    """Return a ``_ListFlow`` instance that has already been kicked off."""
    started = _ListFlow()
    started.kickoff()
    return started
|
||||
|
||||
|
||||
def _make_dict_flow():
    """Return a ``_DictFlow`` instance that has already been kicked off."""
    started = _DictFlow()
    started.kickoff()
    return started
|
||||
|
||||
|
||||
def test_locked_list_proxy_index():
    """``index`` on the list state works with and without a start offset."""
    list_flow = _make_list_flow()

    position_of_four = list_flow.state.items.index(4)
    assert position_of_four == 2

    # Searching for 1 from offset 2 must skip the first occurrence at index 1.
    second_one = list_flow.state.items.index(1, 2)
    assert second_one == 3
|
||||
|
||||
|
||||
def test_locked_list_proxy_index_missing_raises():
    """``index`` raises ``ValueError`` for a value that is not in the list."""
    list_flow = _make_list_flow()
    absent_value = 999
    with pytest.raises(ValueError):
        list_flow.state.items.index(absent_value)
|
||||
|
||||
|
||||
def test_locked_list_proxy_count():
    """``count`` reports 2 for the duplicated value and 0 for a missing one."""
    list_flow = _make_list_flow()

    duplicates = list_flow.state.items.count(1)
    assert duplicates == 2

    missing = list_flow.state.items.count(999)
    assert missing == 0
|
||||
|
||||
|
||||
def test_locked_list_proxy_sort():
    """In-place ``sort`` orders the list state ascending."""
    list_flow = _make_list_flow()
    list_flow.state.items.sort()

    ascending = [1, 1, 2, 3, 4, 5, 6, 9]
    assert list(list_flow.state.items) == ascending
|
||||
|
||||
|
||||
def test_locked_list_proxy_sort_reverse():
    """``sort(reverse=True)`` orders the list state descending."""
    list_flow = _make_list_flow()
    list_flow.state.items.sort(reverse=True)

    descending = [9, 6, 5, 4, 3, 2, 1, 1]
    assert list(list_flow.state.items) == descending
|
||||
|
||||
|
||||
def test_locked_list_proxy_sort_key():
    """``sort`` honours a ``key`` callable (negation gives descending order)."""
    list_flow = _make_list_flow()
    list_flow.state.items.sort(key=lambda value: -value)

    descending = [9, 6, 5, 4, 3, 2, 1, 1]
    assert list(list_flow.state.items) == descending
|
||||
|
||||
|
||||
def test_locked_list_proxy_reverse():
    """In-place ``reverse`` flips the order of the list state."""
    list_flow = _make_list_flow()

    # Snapshot as a plain list before mutating so the expectation is independent.
    before = list(list_flow.state.items)
    list_flow.state.items.reverse()
    after = list(list_flow.state.items)

    assert after == before[::-1]
|
||||
|
||||
|
||||
def test_locked_list_proxy_copy():
    """``copy`` returns an independent plain ``list`` with the same contents."""
    list_flow = _make_list_flow()
    snapshot = list_flow.state.items.copy()

    assert snapshot == [3, 1, 4, 1, 5, 9, 2, 6]
    assert isinstance(snapshot, list)

    # Mutating the copy must not leak back into the flow state.
    snapshot.append(999)
    assert 999 not in list_flow.state.items
|
||||
|
||||
|
||||
def test_locked_list_proxy_add():
    """``proxy + list`` concatenates without mutating the flow state."""
    list_flow = _make_list_flow()
    combined = list_flow.state.items + [10, 11]

    expected = [3, 1, 4, 1, 5, 9, 2, 6, 10, 11]
    assert combined == expected
    # The state itself still holds the original eight elements.
    assert len(list_flow.state.items) == 8
|
||||
|
||||
|
||||
def test_locked_list_proxy_radd():
    """``list + proxy`` works with the proxy as the right-hand operand."""
    list_flow = _make_list_flow()
    prefixed = [0] + list_flow.state.items

    assert prefixed[0] == 0
    assert len(prefixed) == 9
|
||||
|
||||
|
||||
def test_locked_list_proxy_iadd():
    """``+=`` extends the list state and leaves it usable afterwards."""
    list_flow = _make_list_flow()

    list_flow.state.items += [10]
    assert 10 in list_flow.state.items

    # Verify no deadlock: mutations must still work after +=
    list_flow.state.items.append(99)
    assert 99 in list_flow.state.items
|
||||
|
||||
|
||||
def test_locked_list_proxy_mul():
    """``proxy * 2`` yields a sequence twice as long as the state list."""
    list_flow = _make_list_flow()
    doubled = list_flow.state.items * 2
    assert len(doubled) == 16
|
||||
|
||||
|
||||
def test_locked_list_proxy_rmul():
    """``2 * proxy`` works with the proxy as the right-hand operand."""
    list_flow = _make_list_flow()
    doubled = 2 * list_flow.state.items
    assert len(doubled) == 16
|
||||
|
||||
|
||||
def test_locked_list_proxy_reversed():
    """``reversed()`` on the list state iterates the elements back to front."""
    list_flow = _make_list_flow()

    # Plain-list snapshot used to build the expected order.
    forward = list(list_flow.state.items)
    backward = list(reversed(list_flow.state.items))

    assert backward == forward[::-1]
|
||||
|
||||
|
||||
def test_locked_dict_proxy_copy():
    """``copy`` returns an independent plain ``dict`` with the same contents."""
    dict_flow = _make_dict_flow()
    snapshot = dict_flow.state.data.copy()

    assert snapshot == {"a": 1, "b": 2, "c": 3}
    assert isinstance(snapshot, dict)

    # Mutating the copy must not leak back into the flow state.
    snapshot["z"] = 99
    assert "z" not in dict_flow.state.data
|
||||
|
||||
|
||||
def test_locked_dict_proxy_or():
    """``proxy | dict`` merges without mutating the flow state."""
    dict_flow = _make_dict_flow()
    merged = dict_flow.state.data | {"d": 4}

    expected = {"a": 1, "b": 2, "c": 3, "d": 4}
    assert merged == expected
    # The new key exists only in the merged result.
    assert "d" not in dict_flow.state.data
|
||||
|
||||
|
||||
def test_locked_dict_proxy_ror():
    """``dict | proxy`` works with the proxy as the right-hand operand."""
    dict_flow = _make_dict_flow()
    merged = {"z": 0} | dict_flow.state.data

    expected = {"z": 0, "a": 1, "b": 2, "c": 3}
    assert merged == expected
|
||||
|
||||
|
||||
def test_locked_dict_proxy_ior():
    """``|=`` updates the dict state and leaves it usable afterwards."""
    dict_flow = _make_dict_flow()

    dict_flow.state.data |= {"d": 4}
    assert dict_flow.state.data["d"] == 4

    # Verify no deadlock: mutations must still work after |=
    dict_flow.state.data["e"] = 5
    assert dict_flow.state.data["e"] == 5
|
||||
|
||||
|
||||
def test_locked_dict_proxy_reversed():
    """``reversed()`` on the dict state yields keys in reverse insertion order."""
    dict_flow = _make_dict_flow()
    keys_backward = list(reversed(dict_flow.state.data))
    assert keys_backward == ["c", "b", "a"]
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""CrewAI development tools."""
|
||||
|
||||
__version__ = "1.10.1"
|
||||
__version__ = "1.10.2a1"
|
||||
|
||||
Reference in New Issue
Block a user