add error handling

This commit is contained in:
siddas27
2024-11-30 21:36:28 -06:00
parent 6c242ef3bb
commit d168b8e245
3 changed files with 96 additions and 31 deletions

View File

@@ -25,6 +25,18 @@ class BraveSearchToolSchema(BaseModel):
class BraveSearchTool(BaseTool):
"""
BraveSearchTool - A tool for performing web searches using the Brave Search API.
This module provides functionality to search the internet using Brave's Search API,
supporting customizable result counts and country-specific searches.
Dependencies:
- requests
- pydantic
- python-dotenv (for API key management)
"""
name: str = "Search the internet"
description: str = (
"A tool that can be used to search the internet with a search_query."
@@ -35,48 +47,64 @@ class BraveSearchTool(BaseTool):
n_results: int = 10
save_file: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if "BRAVE_API_KEY" not in os.environ:
raise ValueError(
"BRAVE_API_KEY environment variable is required for BraveSearchTool"
)
def _run(
self,
**kwargs: Any,
) -> Any:
search_query = kwargs.get("search_query") or kwargs.get("query")
save_file = kwargs.get("save_file", self.save_file)
n_results = kwargs.get("n_results", self.n_results)
try:
search_query = kwargs.get("search_query") or kwargs.get("query")
if not search_query:
raise ValueError("Search query is required")
payload = {"q": search_query, "count": n_results}
save_file = kwargs.get("save_file", self.save_file)
n_results = kwargs.get("n_results", self.n_results)
if self.country != "":
payload["country"] = self.country
payload = {"q": search_query, "count": n_results}
headers = {
"X-Subscription-Token": os.environ["BRAVE_API_KEY"],
"Accept": "application/json",
}
if self.country != "":
payload["country"] = self.country
response = requests.get(self.search_url, headers=headers, params=payload)
results = response.json()
headers = {
"X-Subscription-Token": os.environ["BRAVE_API_KEY"],
"Accept": "application/json",
}
if "web" in results:
results = results["web"]["results"]
string = []
for result in results:
try:
string.append(
"\n".join(
[
f"Title: {result['title']}",
f"Link: {result['url']}",
f"Snippet: {result['description']}",
"---",
]
response = requests.get(self.search_url, headers=headers, params=payload)
response.raise_for_status() # Handle non-200 responses
results = response.json()
if "web" in results:
results = results["web"]["results"]
string = []
for result in results:
try:
string.append(
"\n".join(
[
f"Title: {result['title']}",
f"Link: {result['url']}",
f"Snippet: {result['description']}",
"---",
]
)
)
)
except KeyError:
continue
except KeyError:
continue
content = "\n".join(string)
if save_file:
_save_results_to_file(content)
except requests.RequestException as e:
return f"Error performing search: {str(e)}"
except KeyError as e:
return f"Error parsing search results: {str(e)}"
if save_file:
_save_results_to_file(content)
return f"\nSearch results: {content}\n"
else:
return results
return content

View File

@@ -1,6 +1,41 @@
from unittest.mock import patch
import pytest
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
@pytest.fixture
def brave_tool():
return BraveSearchTool(n_results=2)
def test_brave_tool_initialization():
tool = BraveSearchTool()
assert tool.n_results == 10
assert tool.save_file is False
@patch("requests.get")
def test_brave_tool_search(mock_get, brave_tool):
mock_response = {
"web": {
"results": [
{
"title": "Test Title",
"url": "http://test.com",
"description": "Test Description",
}
]
}
}
mock_get.return_value.json.return_value = mock_response
result = brave_tool.run(search_query="test")
assert "Test Title" in result
assert "http://test.com" in result
def test_brave_tool():
tool = BraveSearchTool(
n_results=2,
@@ -11,3 +46,5 @@ def test_brave_tool():
if __name__ == "__main__":
test_brave_tool()
test_brave_tool_initialization()
# test_brave_tool_search(brave_tool)