Merge pull request #197 from beowolx/bugfix/serperdev-missing-country

fix(serper-dev): restore search localization parameters
This commit is contained in:
João Moura
2025-03-21 14:54:50 -03:00
committed by GitHub
3 changed files with 169 additions and 3 deletions

View File

@@ -26,7 +26,10 @@ from crewai_tools import SerperDevTool
tool = SerperDevTool(
n_results=10, # Optional: Number of results to return (default: 10)
save_file=False, # Optional: Save results to file (default: False)
search_type="search" # Optional: Type of search - "search" or "news" (default: "search")
search_type="search", # Optional: Type of search - "search" or "news" (default: "search")
country="us", # Optional: Country for search (default: "")
location="New York", # Optional: Location for search (default: "")
locale="en-US" # Optional: Locale for search (default: "")
)
# Execute a search

View File

@@ -2,7 +2,7 @@ import datetime
import json
import logging
import os
from typing import Any, Type
from typing import Any, Type, Optional
import requests
from crewai.tools import BaseTool
@@ -45,6 +45,9 @@ class SerperDevTool(BaseTool):
n_results: int = 10
save_file: bool = False
search_type: str = "search"
country: Optional[str] = ""
location: Optional[str] = ""
locale: Optional[str] = ""
def _get_search_url(self, search_type: str) -> str:
"""Get the appropriate endpoint URL based on search type."""
@@ -146,11 +149,20 @@ class SerperDevTool(BaseTool):
def _make_api_request(self, search_query: str, search_type: str) -> dict:
"""Make API request to Serper."""
search_url = self._get_search_url(search_type)
payload = json.dumps({"q": search_query, "num": self.n_results})
payload = {"q": search_query, "num": self.n_results}
if self.country != "":
payload["gl"] = self.country
if self.location != "":
payload["location"] = self.location
if self.locale != "":
payload["hl"] = self.locale
headers = {
"X-API-KEY": os.environ["SERPER_API_KEY"],
"content-type": "application/json",
}
payload = json.dumps(payload)
response = None
try:

View File

@@ -0,0 +1,151 @@
from unittest.mock import patch
import pytest
from crewai_tools.tools.serper_dev_tool.serper_dev_tool import SerperDevTool
import os
@pytest.fixture(autouse=True)
def mock_serper_api_key():
with patch.dict(os.environ, {"SERPER_API_KEY": "test_key"}):
yield
@pytest.fixture
def serper_tool():
return SerperDevTool(n_results=2)
def test_serper_tool_initialization():
tool = SerperDevTool()
assert tool.n_results == 10
assert tool.save_file is False
assert tool.search_type == "search"
assert tool.country == ""
assert tool.location == ""
assert tool.locale == ""
def test_serper_tool_custom_initialization():
tool = SerperDevTool(
n_results=5,
save_file=True,
search_type="news",
country="US",
location="New York",
locale="en"
)
assert tool.n_results == 5
assert tool.save_file is True
assert tool.search_type == "news"
assert tool.country == "US"
assert tool.location == "New York"
assert tool.locale == "en"
@patch("requests.post")
def test_serper_tool_search(mock_post):
tool = SerperDevTool(n_results=2)
mock_response = {
"searchParameters": {
"q": "test query",
"type": "search"
},
"organic": [
{
"title": "Test Title 1",
"link": "http://test1.com",
"snippet": "Test Description 1",
"position": 1
},
{
"title": "Test Title 2",
"link": "http://test2.com",
"snippet": "Test Description 2",
"position": 2
}
],
"peopleAlsoAsk": [
{
"question": "Test Question",
"snippet": "Test Answer",
"title": "Test Source",
"link": "http://test.com"
}
]
}
mock_post.return_value.json.return_value = mock_response
mock_post.return_value.status_code = 200
result = tool.run(search_query="test query")
assert "searchParameters" in result
assert result["searchParameters"]["q"] == "test query"
assert len(result["organic"]) == 2
assert result["organic"][0]["title"] == "Test Title 1"
@patch("requests.post")
def test_serper_tool_news_search(mock_post):
tool = SerperDevTool(n_results=2, search_type="news")
mock_response = {
"searchParameters": {
"q": "test news",
"type": "news"
},
"news": [
{
"title": "News Title 1",
"link": "http://news1.com",
"snippet": "News Description 1",
"date": "2024-01-01",
"source": "News Source 1",
"imageUrl": "http://image1.com"
}
]
}
mock_post.return_value.json.return_value = mock_response
mock_post.return_value.status_code = 200
result = tool.run(search_query="test news")
assert "news" in result
assert len(result["news"]) == 1
assert result["news"][0]["title"] == "News Title 1"
@patch("requests.post")
def test_serper_tool_with_location_params(mock_post):
tool = SerperDevTool(
n_results=2,
country="US",
location="New York",
locale="en"
)
tool.run(search_query="test")
called_payload = mock_post.call_args.kwargs["json"]
assert called_payload["gl"] == "US"
assert called_payload["location"] == "New York"
assert called_payload["hl"] == "en"
def test_invalid_search_type():
tool = SerperDevTool()
with pytest.raises(ValueError) as exc_info:
tool.run(search_query="test", search_type="invalid")
assert "Invalid search type" in str(exc_info.value)
@patch("requests.post")
def test_api_error_handling(mock_post):
tool = SerperDevTool()
mock_post.side_effect = Exception("API Error")
with pytest.raises(Exception) as exc_info:
tool.run(search_query="test")
assert "API Error" in str(exc_info.value)
if __name__ == "__main__":
pytest.main([__file__])