mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Merge pull request #197 from beowolx/bugfix/serperdev-missing-country
fix(serper-dev): restore search localization parameters
This commit is contained in:
@@ -26,7 +26,10 @@ from crewai_tools import SerperDevTool
|
||||
tool = SerperDevTool(
|
||||
n_results=10, # Optional: Number of results to return (default: 10)
|
||||
save_file=False, # Optional: Save results to file (default: False)
|
||||
search_type="search" # Optional: Type of search - "search" or "news" (default: "search")
|
||||
search_type="search", # Optional: Type of search - "search" or "news" (default: "search")
|
||||
country="us", # Optional: Country for search (default: "")
|
||||
location="New York", # Optional: Location for search (default: "")
|
||||
locale="en-US" # Optional: Locale for search (default: "")
|
||||
)
|
||||
|
||||
# Execute a search
|
||||
|
||||
@@ -2,7 +2,7 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Type
|
||||
from typing import Any, Type, Optional
|
||||
|
||||
import requests
|
||||
from crewai.tools import BaseTool
|
||||
@@ -45,6 +45,9 @@ class SerperDevTool(BaseTool):
|
||||
n_results: int = 10
|
||||
save_file: bool = False
|
||||
search_type: str = "search"
|
||||
country: Optional[str] = ""
|
||||
location: Optional[str] = ""
|
||||
locale: Optional[str] = ""
|
||||
|
||||
def _get_search_url(self, search_type: str) -> str:
|
||||
"""Get the appropriate endpoint URL based on search type."""
|
||||
@@ -146,11 +149,20 @@ class SerperDevTool(BaseTool):
|
||||
def _make_api_request(self, search_query: str, search_type: str) -> dict:
|
||||
"""Make API request to Serper."""
|
||||
search_url = self._get_search_url(search_type)
|
||||
payload = json.dumps({"q": search_query, "num": self.n_results})
|
||||
payload = {"q": search_query, "num": self.n_results}
|
||||
|
||||
if self.country != "":
|
||||
payload["gl"] = self.country
|
||||
if self.location != "":
|
||||
payload["location"] = self.location
|
||||
if self.locale != "":
|
||||
payload["hl"] = self.locale
|
||||
|
||||
headers = {
|
||||
"X-API-KEY": os.environ["SERPER_API_KEY"],
|
||||
"content-type": "application/json",
|
||||
}
|
||||
payload = json.dumps(payload)
|
||||
|
||||
response = None
|
||||
try:
|
||||
|
||||
151
tests/tools/serper_dev_tool_test.py
Normal file
151
tests/tools/serper_dev_tool_test.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
from crewai_tools.tools.serper_dev_tool.serper_dev_tool import SerperDevTool
|
||||
import os
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_serper_api_key():
|
||||
with patch.dict(os.environ, {"SERPER_API_KEY": "test_key"}):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def serper_tool():
|
||||
return SerperDevTool(n_results=2)
|
||||
|
||||
|
||||
def test_serper_tool_initialization():
|
||||
tool = SerperDevTool()
|
||||
assert tool.n_results == 10
|
||||
assert tool.save_file is False
|
||||
assert tool.search_type == "search"
|
||||
assert tool.country == ""
|
||||
assert tool.location == ""
|
||||
assert tool.locale == ""
|
||||
|
||||
|
||||
def test_serper_tool_custom_initialization():
|
||||
tool = SerperDevTool(
|
||||
n_results=5,
|
||||
save_file=True,
|
||||
search_type="news",
|
||||
country="US",
|
||||
location="New York",
|
||||
locale="en"
|
||||
)
|
||||
assert tool.n_results == 5
|
||||
assert tool.save_file is True
|
||||
assert tool.search_type == "news"
|
||||
assert tool.country == "US"
|
||||
assert tool.location == "New York"
|
||||
assert tool.locale == "en"
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_serper_tool_search(mock_post):
|
||||
tool = SerperDevTool(n_results=2)
|
||||
mock_response = {
|
||||
"searchParameters": {
|
||||
"q": "test query",
|
||||
"type": "search"
|
||||
},
|
||||
"organic": [
|
||||
{
|
||||
"title": "Test Title 1",
|
||||
"link": "http://test1.com",
|
||||
"snippet": "Test Description 1",
|
||||
"position": 1
|
||||
},
|
||||
{
|
||||
"title": "Test Title 2",
|
||||
"link": "http://test2.com",
|
||||
"snippet": "Test Description 2",
|
||||
"position": 2
|
||||
}
|
||||
],
|
||||
"peopleAlsoAsk": [
|
||||
{
|
||||
"question": "Test Question",
|
||||
"snippet": "Test Answer",
|
||||
"title": "Test Source",
|
||||
"link": "http://test.com"
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_post.return_value.json.return_value = mock_response
|
||||
mock_post.return_value.status_code = 200
|
||||
|
||||
result = tool.run(search_query="test query")
|
||||
|
||||
assert "searchParameters" in result
|
||||
assert result["searchParameters"]["q"] == "test query"
|
||||
assert len(result["organic"]) == 2
|
||||
assert result["organic"][0]["title"] == "Test Title 1"
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_serper_tool_news_search(mock_post):
|
||||
tool = SerperDevTool(n_results=2, search_type="news")
|
||||
mock_response = {
|
||||
"searchParameters": {
|
||||
"q": "test news",
|
||||
"type": "news"
|
||||
},
|
||||
"news": [
|
||||
{
|
||||
"title": "News Title 1",
|
||||
"link": "http://news1.com",
|
||||
"snippet": "News Description 1",
|
||||
"date": "2024-01-01",
|
||||
"source": "News Source 1",
|
||||
"imageUrl": "http://image1.com"
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_post.return_value.json.return_value = mock_response
|
||||
mock_post.return_value.status_code = 200
|
||||
|
||||
result = tool.run(search_query="test news")
|
||||
|
||||
assert "news" in result
|
||||
assert len(result["news"]) == 1
|
||||
assert result["news"][0]["title"] == "News Title 1"
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_serper_tool_with_location_params(mock_post):
|
||||
tool = SerperDevTool(
|
||||
n_results=2,
|
||||
country="US",
|
||||
location="New York",
|
||||
locale="en"
|
||||
)
|
||||
|
||||
tool.run(search_query="test")
|
||||
|
||||
called_payload = mock_post.call_args.kwargs["json"]
|
||||
assert called_payload["gl"] == "US"
|
||||
assert called_payload["location"] == "New York"
|
||||
assert called_payload["hl"] == "en"
|
||||
|
||||
|
||||
def test_invalid_search_type():
|
||||
tool = SerperDevTool()
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
tool.run(search_query="test", search_type="invalid")
|
||||
assert "Invalid search type" in str(exc_info.value)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_api_error_handling(mock_post):
|
||||
tool = SerperDevTool()
|
||||
mock_post.side_effect = Exception("API Error")
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
tool.run(search_query="test")
|
||||
assert "API Error" in str(exc_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user