diff --git a/lib/crewai-tools/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py b/lib/crewai-tools/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py index e13f3823c..415810c1b 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py @@ -1,12 +1,17 @@ from datetime import datetime +import json import os import time -from typing import Any, ClassVar +from typing import Annotated, Any, ClassVar, Literal from crewai.tools import BaseTool, EnvVar +from dotenv import load_dotenv from pydantic import BaseModel, Field +from pydantic.types import StringConstraints import requests +load_dotenv() + def _save_results_to_file(content: str) -> None: """Saves the search results to a file.""" @@ -15,37 +20,72 @@ def _save_results_to_file(content: str) -> None: file.write(content) -class BraveSearchToolSchema(BaseModel): - """Input for BraveSearchTool.""" +FreshnessPreset = Literal["pd", "pw", "pm", "py"] +FreshnessRange = Annotated[ + str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$") +] +Freshness = FreshnessPreset | FreshnessRange +SafeSearch = Literal["off", "moderate", "strict"] - search_query: str = Field( - ..., description="Mandatory search query you want to use to search the internet" + +class BraveSearchToolSchema(BaseModel): + """Input for BraveSearchTool""" + + query: str = Field(..., description="Search query to perform") + country: str | None = Field( + default=None, + description="Country code for geo-targeting (e.g., 'US', 'BR').", + ) + search_language: str | None = Field( + default=None, + description="Language code for the search results (e.g., 'en', 'es').", + ) + count: int | None = Field( + default=None, + description="The maximum number of results to return. Actual number may be less.", + ) + offset: int | None = Field( + default=None, description="Skip the first N result sets/pages. Max is 9." + ) + safesearch: SafeSearch | None = Field( + default=None, + description="Filter out explicit content. Options: off/moderate/strict", + ) + spellcheck: bool | None = Field( + default=None, + description="Attempt to correct spelling errors in the search query.", + ) + freshness: Freshness | None = Field( + default=None, + description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD", + ) + text_decorations: bool | None = Field( + default=None, + description="Include markup to highlight search terms in the results.", + ) + extra_snippets: bool | None = Field( + default=None, + description="Include up to 5 text snippets for each page if possible.", + ) + operators: bool | None = Field( + default=None, + description="Whether to apply search operators (e.g., site:example.com).", ) +# TODO: Extend support to additional endpoints (e.g., /images, /news, etc.) class BraveSearchTool(BaseTool): - """BraveSearchTool - A tool for performing web searches using the Brave Search API. + """A tool that performs web searches using the Brave Search API.""" - This module provides functionality to search the internet using Brave's Search API, - supporting customizable result counts and country-specific searches. - - Dependencies: - - requests - - pydantic - - python-dotenv (for API key management) - """ - - name: str = "Brave Web Search the internet" + name: str = "Brave Search" description: str = ( - "A tool that can be used to search the internet with a search_query." + "A tool that performs web searches using the Brave Search API. " + "Results are returned as structured JSON data." ) args_schema: type[BaseModel] = BraveSearchToolSchema search_url: str = "https://api.search.brave.com/res/v1/web/search" - country: str | None = "" n_results: int = 10 save_file: bool = False - _last_request_time: ClassVar[float] = 0 - _min_request_interval: ClassVar[float] = 1.0 # seconds env_vars: list[EnvVar] = Field( default_factory=lambda: [ EnvVar( @@ -55,6 +95,9 @@ class BraveSearchTool(BaseTool): ), ] ) + # Rate limiting parameters + _last_request_time: ClassVar[float] = 0 + _min_request_interval: ClassVar[float] = 1.0 # seconds def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -73,19 +116,64 @@ class BraveSearchTool(BaseTool): self._min_request_interval - (current_time - self._last_request_time) ) BraveSearchTool._last_request_time = time.time() + + # Construct and send the request try: - search_query = kwargs.get("search_query") or kwargs.get("query") - if not search_query: - raise ValueError("Search query is required") + # Maintain both "search_query" and "query" for backwards compatibility + query = kwargs.get("search_query") or kwargs.get("query") + if not query: + raise ValueError("Query is required") + + payload = {"q": query} + + if country := kwargs.get("country"): + payload["country"] = country + + if search_language := kwargs.get("search_language"): + payload["search_language"] = search_language + + # Fallback to deprecated n_results parameter if no count is provided + count = kwargs.get("count") + if count is not None: + payload["count"] = count + else: + payload["count"] = self.n_results + + # Offset may be 0, so avoid truthiness check + offset = kwargs.get("offset") + if offset is not None: + payload["offset"] = offset + + if safesearch := kwargs.get("safesearch"): + payload["safesearch"] = safesearch save_file = kwargs.get("save_file", self.save_file) - n_results = kwargs.get("n_results", self.n_results) + if freshness := kwargs.get("freshness"): + payload["freshness"] = freshness - payload = {"q": search_query, "count": n_results} + # Boolean parameters + spellcheck = kwargs.get("spellcheck") + if spellcheck is not None: + payload["spellcheck"] = spellcheck - if self.country != "": - payload["country"] = self.country + text_decorations = kwargs.get("text_decorations") + if text_decorations is not None: + payload["text_decorations"] = text_decorations + extra_snippets = kwargs.get("extra_snippets") + if extra_snippets is not None: + payload["extra_snippets"] = extra_snippets + + operators = kwargs.get("operators") + if operators is not None: + payload["operators"] = operators + + # Limit the result types to "web" since there is presently no + # handling of other types like "discussions", "faq", "infobox", + # "news", "videos", or "locations". + payload["result_filter"] = "web" + + # Setup Request Headers headers = { "X-Subscription-Token": os.environ["BRAVE_API_KEY"], "Accept": "application/json", @@ -97,25 +185,32 @@ class BraveSearchTool(BaseTool): response.raise_for_status() # Handle non-200 responses results = response.json() + # TODO: Handle other result types like "discussions", "faq", etc. + web_results_items = [] if "web" in results: - results = results["web"]["results"] - string = [] - for result in results: - try: - string.append( - "\n".join( - [ - f"Title: {result['title']}", - f"Link: {result['url']}", - f"Snippet: {result['description']}", - "---", - ] - ) - ) - except KeyError: # noqa: PERF203 - continue + web_results = results["web"]["results"] - content = "\n".join(string) + for result in web_results: + url = result.get("url") + title = result.get("title") + # If, for whatever reason, this entry does not have a title + # or url, skip it. + if not url or not title: + continue + item = { + "url": url, + "title": title, + } + description = result.get("description") + if description: + item["description"] = description + snippets = result.get("extra_snippets") + if snippets: + item["snippets"] = snippets + + web_results_items.append(item) + + content = json.dumps(web_results_items) except requests.RequestException as e: return f"Error performing search: {e!s}" except KeyError as e: diff --git a/lib/crewai-tools/tests/tools/brave_search_tool_test.py b/lib/crewai-tools/tests/tools/brave_search_tool_test.py index c1c32d830..361086abe 100644 --- a/lib/crewai-tools/tests/tools/brave_search_tool_test.py +++ b/lib/crewai-tools/tests/tools/brave_search_tool_test.py @@ -1,8 +1,10 @@ +import json from unittest.mock import patch -from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool import pytest +from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool + @pytest.fixture def brave_tool(): @@ -30,16 +32,43 @@ def test_brave_tool_search(mock_get, brave_tool): } mock_get.return_value.json.return_value = mock_response - result = brave_tool.run(search_query="test") + result = brave_tool.run(query="test") assert "Test Title" in result assert "http://test.com" in result -def test_brave_tool(): - tool = BraveSearchTool( - n_results=2, - ) - tool.run(search_query="ChatGPT") +@patch("requests.get") +def test_brave_tool(mock_get): + mock_response = { + "web": { + "results": [ + { + "title": "Brave Browser", + "url": "https://brave.com", + "description": "Brave Browser description", + } + ] + } + } + mock_get.return_value.json.return_value = mock_response + + tool = BraveSearchTool(n_results=2) + result = tool.run(query="Brave Browser") + assert result is not None + + # Parse JSON so we can examine the structure + data = json.loads(result) + assert isinstance(data, list) + assert len(data) >= 1 + + # First item should have expected fields: title, url, and description + first = data[0] + assert "title" in first + assert first["title"] == "Brave Browser" + assert "url" in first + assert first["url"] == "https://brave.com" + assert "description" in first + assert first["description"] == "Brave Browser description" if __name__ == "__main__":