From d168b8e24554e37a706d0af18c4b82af483fd442 Mon Sep 17 00:00:00 2001 From: siddas27 Date: Sat, 30 Nov 2024 21:36:28 -0600 Subject: [PATCH] add error handling --- .../tools/brave_search_tool/__init__.py | 0 .../brave_search_tool/brave_search_tool.py | 90 ++++++++++++------- tests/tools/brave_search_tool_test.py | 37 ++++++++ 3 files changed, 96 insertions(+), 31 deletions(-) create mode 100644 src/crewai_tools/tools/brave_search_tool/__init__.py diff --git a/src/crewai_tools/tools/brave_search_tool/__init__.py b/src/crewai_tools/tools/brave_search_tool/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py b/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py index 54f546f1e..6a8818d75 100644 --- a/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py +++ b/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py @@ -25,6 +25,18 @@ class BraveSearchToolSchema(BaseModel): class BraveSearchTool(BaseTool): + """ + BraveSearchTool - A tool for performing web searches using the Brave Search API. + + This module provides functionality to search the internet using Brave's Search API, + supporting customizable result counts and country-specific searches. + + Dependencies: + - requests + - pydantic + - python-dotenv (for API key management) + """ + name: str = "Search the internet" description: str = ( "A tool that can be used to search the internet with a search_query." @@ -35,48 +47,64 @@ class BraveSearchTool(BaseTool): n_results: int = 10 save_file: bool = False + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if "BRAVE_API_KEY" not in os.environ: + raise ValueError( + "BRAVE_API_KEY environment variable is required for BraveSearchTool" + ) + def _run( self, **kwargs: Any, ) -> Any: - search_query = kwargs.get("search_query") or kwargs.get("query") - save_file = kwargs.get("save_file", self.save_file) - n_results = kwargs.get("n_results", self.n_results) + try: + search_query = kwargs.get("search_query") or kwargs.get("query") + if not search_query: + raise ValueError("Search query is required") - payload = {"q": search_query, "count": n_results} + save_file = kwargs.get("save_file", self.save_file) + n_results = kwargs.get("n_results", self.n_results) - if self.country != "": - payload["country"] = self.country + payload = {"q": search_query, "count": n_results} - headers = { - "X-Subscription-Token": os.environ["BRAVE_API_KEY"], - "Accept": "application/json", - } + if self.country != "": + payload["country"] = self.country - response = requests.get(self.search_url, headers=headers, params=payload) - results = response.json() + headers = { + "X-Subscription-Token": os.environ["BRAVE_API_KEY"], + "Accept": "application/json", + } - if "web" in results: - results = results["web"]["results"] - string = [] - for result in results: - try: - string.append( - "\n".join( - [ - f"Title: {result['title']}", - f"Link: {result['url']}", - f"Snippet: {result['description']}", - "---", - ] + response = requests.get(self.search_url, headers=headers, params=payload) + response.raise_for_status() # Handle non-200 responses + results = response.json() + + if "web" in results: + results = results["web"]["results"] + string = [] + for result in results: + try: + string.append( + "\n".join( + [ + f"Title: {result['title']}", + f"Link: {result['url']}", + f"Snippet: {result['description']}", + "---", + ] + ) ) - ) - except KeyError: - continue + except KeyError: + continue content = "\n".join(string) - if save_file: - _save_results_to_file(content) + except requests.RequestException as e: + return f"Error performing search: {str(e)}" + except KeyError as e: + return f"Error parsing search results: {str(e)}" + if save_file: + _save_results_to_file(content) return f"\nSearch results: {content}\n" else: - return results + return content diff --git a/tests/tools/brave_search_tool_test.py b/tests/tools/brave_search_tool_test.py index 16c1bcb92..969bd48fe 100644 --- a/tests/tools/brave_search_tool_test.py +++ b/tests/tools/brave_search_tool_test.py @@ -1,6 +1,41 @@ +from unittest.mock import patch + +import pytest + from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool +@pytest.fixture +def brave_tool(): + return BraveSearchTool(n_results=2) + + +def test_brave_tool_initialization(): + tool = BraveSearchTool() + assert tool.n_results == 10 + assert tool.save_file is False + + +@patch("requests.get") +def test_brave_tool_search(mock_get, brave_tool): + mock_response = { + "web": { + "results": [ + { + "title": "Test Title", + "url": "http://test.com", + "description": "Test Description", + } + ] + } + } + mock_get.return_value.json.return_value = mock_response + + result = brave_tool.run(search_query="test") + assert "Test Title" in result + assert "http://test.com" in result + + def test_brave_tool(): tool = BraveSearchTool( n_results=2, @@ -11,3 +46,5 @@ def test_brave_tool(): if __name__ == "__main__": test_brave_tool() + test_brave_tool_initialization() + # test_brave_tool_search(brave_tool)