mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
fix(serper-dev): restore search localization parameters
- Re-add country (gl), location, and locale (hl) parameters to SerperDevTool class - Update payload construction in _make_api_request to include localization params - Add schema validation for localization parameters - Update documentation and examples to demonstrate parameter usage These parameters were accidentally removed in the previous enhancement PR and are crucial for: - Getting region-specific search results (via country/gl) - Targeting searches to specific cities (via location) - Getting results in specific languages (via locale/hl) BREAKING CHANGE: None - This restores previously available functionality
This commit is contained in:
@@ -26,7 +26,10 @@ from crewai_tools import SerperDevTool
|
|||||||
tool = SerperDevTool(
|
tool = SerperDevTool(
|
||||||
n_results=10, # Optional: Number of results to return (default: 10)
|
n_results=10, # Optional: Number of results to return (default: 10)
|
||||||
save_file=False, # Optional: Save results to file (default: False)
|
save_file=False, # Optional: Save results to file (default: False)
|
||||||
search_type="search" # Optional: Type of search - "search" or "news" (default: "search")
|
search_type="search", # Optional: Type of search - "search" or "news" (default: "search")
|
||||||
|
country="us", # Optional: Country for search (default: "")
|
||||||
|
location="New York", # Optional: Location for search (default: "")
|
||||||
|
locale="en-US" # Optional: Locale for search (default: "")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute a search
|
# Execute a search
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Type
|
from typing import Any, Type, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
@@ -45,6 +45,9 @@ class SerperDevTool(BaseTool):
|
|||||||
n_results: int = 10
|
n_results: int = 10
|
||||||
save_file: bool = False
|
save_file: bool = False
|
||||||
search_type: str = "search"
|
search_type: str = "search"
|
||||||
|
country: Optional[str] = ""
|
||||||
|
location: Optional[str] = ""
|
||||||
|
locale: Optional[str] = ""
|
||||||
|
|
||||||
def _get_search_url(self, search_type: str) -> str:
|
def _get_search_url(self, search_type: str) -> str:
|
||||||
"""Get the appropriate endpoint URL based on search type."""
|
"""Get the appropriate endpoint URL based on search type."""
|
||||||
@@ -146,11 +149,20 @@ class SerperDevTool(BaseTool):
|
|||||||
def _make_api_request(self, search_query: str, search_type: str) -> dict:
|
def _make_api_request(self, search_query: str, search_type: str) -> dict:
|
||||||
"""Make API request to Serper."""
|
"""Make API request to Serper."""
|
||||||
search_url = self._get_search_url(search_type)
|
search_url = self._get_search_url(search_type)
|
||||||
payload = json.dumps({"q": search_query, "num": self.n_results})
|
payload = {"q": search_query, "num": self.n_results}
|
||||||
|
|
||||||
|
if self.country != "":
|
||||||
|
payload["gl"] = self.country
|
||||||
|
if self.location != "":
|
||||||
|
payload["location"] = self.location
|
||||||
|
if self.locale != "":
|
||||||
|
payload["hl"] = self.locale
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"X-API-KEY": os.environ["SERPER_API_KEY"],
|
"X-API-KEY": os.environ["SERPER_API_KEY"],
|
||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
}
|
}
|
||||||
|
payload = json.dumps(payload)
|
||||||
|
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
|
|||||||
151
tests/tools/serper_dev_tool_test.py
Normal file
151
tests/tools/serper_dev_tool_test.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
import pytest
|
||||||
|
from crewai_tools.tools.serper_dev_tool.serper_dev_tool import SerperDevTool
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_serper_api_key():
|
||||||
|
with patch.dict(os.environ, {"SERPER_API_KEY": "test_key"}):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def serper_tool():
|
||||||
|
return SerperDevTool(n_results=2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_serper_tool_initialization():
|
||||||
|
tool = SerperDevTool()
|
||||||
|
assert tool.n_results == 10
|
||||||
|
assert tool.save_file is False
|
||||||
|
assert tool.search_type == "search"
|
||||||
|
assert tool.country == ""
|
||||||
|
assert tool.location == ""
|
||||||
|
assert tool.locale == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_serper_tool_custom_initialization():
|
||||||
|
tool = SerperDevTool(
|
||||||
|
n_results=5,
|
||||||
|
save_file=True,
|
||||||
|
search_type="news",
|
||||||
|
country="US",
|
||||||
|
location="New York",
|
||||||
|
locale="en"
|
||||||
|
)
|
||||||
|
assert tool.n_results == 5
|
||||||
|
assert tool.save_file is True
|
||||||
|
assert tool.search_type == "news"
|
||||||
|
assert tool.country == "US"
|
||||||
|
assert tool.location == "New York"
|
||||||
|
assert tool.locale == "en"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_serper_tool_search(mock_post):
|
||||||
|
tool = SerperDevTool(n_results=2)
|
||||||
|
mock_response = {
|
||||||
|
"searchParameters": {
|
||||||
|
"q": "test query",
|
||||||
|
"type": "search"
|
||||||
|
},
|
||||||
|
"organic": [
|
||||||
|
{
|
||||||
|
"title": "Test Title 1",
|
||||||
|
"link": "http://test1.com",
|
||||||
|
"snippet": "Test Description 1",
|
||||||
|
"position": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Test Title 2",
|
||||||
|
"link": "http://test2.com",
|
||||||
|
"snippet": "Test Description 2",
|
||||||
|
"position": 2
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"peopleAlsoAsk": [
|
||||||
|
{
|
||||||
|
"question": "Test Question",
|
||||||
|
"snippet": "Test Answer",
|
||||||
|
"title": "Test Source",
|
||||||
|
"link": "http://test.com"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_post.return_value.json.return_value = mock_response
|
||||||
|
mock_post.return_value.status_code = 200
|
||||||
|
|
||||||
|
result = tool.run(search_query="test query")
|
||||||
|
|
||||||
|
assert "searchParameters" in result
|
||||||
|
assert result["searchParameters"]["q"] == "test query"
|
||||||
|
assert len(result["organic"]) == 2
|
||||||
|
assert result["organic"][0]["title"] == "Test Title 1"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_serper_tool_news_search(mock_post):
|
||||||
|
tool = SerperDevTool(n_results=2, search_type="news")
|
||||||
|
mock_response = {
|
||||||
|
"searchParameters": {
|
||||||
|
"q": "test news",
|
||||||
|
"type": "news"
|
||||||
|
},
|
||||||
|
"news": [
|
||||||
|
{
|
||||||
|
"title": "News Title 1",
|
||||||
|
"link": "http://news1.com",
|
||||||
|
"snippet": "News Description 1",
|
||||||
|
"date": "2024-01-01",
|
||||||
|
"source": "News Source 1",
|
||||||
|
"imageUrl": "http://image1.com"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_post.return_value.json.return_value = mock_response
|
||||||
|
mock_post.return_value.status_code = 200
|
||||||
|
|
||||||
|
result = tool.run(search_query="test news")
|
||||||
|
|
||||||
|
assert "news" in result
|
||||||
|
assert len(result["news"]) == 1
|
||||||
|
assert result["news"][0]["title"] == "News Title 1"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_serper_tool_with_location_params(mock_post):
|
||||||
|
tool = SerperDevTool(
|
||||||
|
n_results=2,
|
||||||
|
country="US",
|
||||||
|
location="New York",
|
||||||
|
locale="en"
|
||||||
|
)
|
||||||
|
|
||||||
|
tool.run(search_query="test")
|
||||||
|
|
||||||
|
called_payload = mock_post.call_args.kwargs["json"]
|
||||||
|
assert called_payload["gl"] == "US"
|
||||||
|
assert called_payload["location"] == "New York"
|
||||||
|
assert called_payload["hl"] == "en"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_search_type():
|
||||||
|
tool = SerperDevTool()
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
tool.run(search_query="test", search_type="invalid")
|
||||||
|
assert "Invalid search type" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_api_error_handling(mock_post):
|
||||||
|
tool = SerperDevTool()
|
||||||
|
mock_post.side_effect = Exception("API Error")
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
tool.run(search_query="test")
|
||||||
|
assert "API Error" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
Reference in New Issue
Block a user