From 3808f98c14738b86dae67748b23694965d22fcf6 Mon Sep 17 00:00:00 2001 From: Luis Cardoso Date: Tue, 28 Jan 2025 10:11:46 +0100 Subject: [PATCH] fix(serper-dev): restore search localization parameters - Re-add country (gl), location, and locale (hl) parameters to SerperDevTool class - Update payload construction in _make_api_request to include localization params - Add schema validation for localization parameters - Update documentation and examples to demonstrate parameter usage These parameters were accidentally removed in the previous enhancement PR and are crucial for: - Getting region-specific search results (via country/gl) - Targeting searches to specific cities (via location) - Getting results in specific languages (via locale/hl) BREAKING CHANGE: None - This restores previously available functionality --- .../tools/serper_dev_tool/README.md | 5 +- .../tools/serper_dev_tool/serper_dev_tool.py | 16 +- tests/tools/serper_dev_tool_test.py | 151 ++++++++++++++++++ 3 files changed, 169 insertions(+), 3 deletions(-) create mode 100644 tests/tools/serper_dev_tool_test.py diff --git a/src/crewai_tools/tools/serper_dev_tool/README.md b/src/crewai_tools/tools/serper_dev_tool/README.md index 0beb9f2ab..06f1abd56 100644 --- a/src/crewai_tools/tools/serper_dev_tool/README.md +++ b/src/crewai_tools/tools/serper_dev_tool/README.md @@ -26,7 +26,10 @@ from crewai_tools import SerperDevTool tool = SerperDevTool( n_results=10, # Optional: Number of results to return (default: 10) save_file=False, # Optional: Save results to file (default: False) - search_type="search" # Optional: Type of search - "search" or "news" (default: "search") + search_type="search", # Optional: Type of search - "search" or "news" (default: "search") + country="us", # Optional: Country for search (default: "") + location="New York", # Optional: Location for search (default: "") + locale="en-US" # Optional: Locale for search (default: "") ) # Execute a search diff --git a/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py b/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py index 2db347190..629016189 100644 --- a/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py +++ b/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py @@ -2,7 +2,7 @@ import datetime import json import logging import os -from typing import Any, Type +from typing import Any, Type, Optional import requests from crewai.tools import BaseTool @@ -45,6 +45,9 @@ class SerperDevTool(BaseTool): n_results: int = 10 save_file: bool = False search_type: str = "search" + country: Optional[str] = "" + location: Optional[str] = "" + locale: Optional[str] = "" def _get_search_url(self, search_type: str) -> str: """Get the appropriate endpoint URL based on search type.""" @@ -146,11 +149,20 @@ class SerperDevTool(BaseTool): def _make_api_request(self, search_query: str, search_type: str) -> dict: """Make API request to Serper.""" search_url = self._get_search_url(search_type) - payload = json.dumps({"q": search_query, "num": self.n_results}) + payload = {"q": search_query, "num": self.n_results} + + if self.country != "": + payload["gl"] = self.country + if self.location != "": + payload["location"] = self.location + if self.locale != "": + payload["hl"] = self.locale + headers = { "X-API-KEY": os.environ["SERPER_API_KEY"], "content-type": "application/json", } + payload = json.dumps(payload) response = None try: diff --git a/tests/tools/serper_dev_tool_test.py b/tests/tools/serper_dev_tool_test.py new file mode 100644 index 000000000..d02f0606e --- /dev/null +++ b/tests/tools/serper_dev_tool_test.py @@ -0,0 +1,151 @@ +from unittest.mock import patch +import pytest +from crewai_tools.tools.serper_dev_tool.serper_dev_tool import SerperDevTool +import os + + +@pytest.fixture(autouse=True) +def mock_serper_api_key(): + with patch.dict(os.environ, {"SERPER_API_KEY": "test_key"}): + yield + + +@pytest.fixture +def serper_tool(): + return SerperDevTool(n_results=2) + + +def test_serper_tool_initialization(): + tool = SerperDevTool() + assert tool.n_results == 10 + assert tool.save_file is False + assert tool.search_type == "search" + assert tool.country == "" + assert tool.location == "" + assert tool.locale == "" + + +def test_serper_tool_custom_initialization(): + tool = SerperDevTool( + n_results=5, + save_file=True, + search_type="news", + country="US", + location="New York", + locale="en" + ) + assert tool.n_results == 5 + assert tool.save_file is True + assert tool.search_type == "news" + assert tool.country == "US" + assert tool.location == "New York" + assert tool.locale == "en" + + +@patch("requests.post") +def test_serper_tool_search(mock_post): + tool = SerperDevTool(n_results=2) + mock_response = { + "searchParameters": { + "q": "test query", + "type": "search" + }, + "organic": [ + { + "title": "Test Title 1", + "link": "http://test1.com", + "snippet": "Test Description 1", + "position": 1 + }, + { + "title": "Test Title 2", + "link": "http://test2.com", + "snippet": "Test Description 2", + "position": 2 + } + ], + "peopleAlsoAsk": [ + { + "question": "Test Question", + "snippet": "Test Answer", + "title": "Test Source", + "link": "http://test.com" + } + ] + } + mock_post.return_value.json.return_value = mock_response + mock_post.return_value.status_code = 200 + + result = tool.run(search_query="test query") + + assert "searchParameters" in result + assert result["searchParameters"]["q"] == "test query" + assert len(result["organic"]) == 2 + assert result["organic"][0]["title"] == "Test Title 1" + + +@patch("requests.post") +def test_serper_tool_news_search(mock_post): + tool = SerperDevTool(n_results=2, search_type="news") + mock_response = { + "searchParameters": { + "q": "test news", + "type": "news" + }, + "news": [ + { + "title": "News Title 1", + "link": "http://news1.com", + "snippet": "News Description 1", + "date": "2024-01-01", + "source": "News Source 1", + "imageUrl": "http://image1.com" + } + ] + } + mock_post.return_value.json.return_value = mock_response + mock_post.return_value.status_code = 200 + + result = tool.run(search_query="test news") + + assert "news" in result + assert len(result["news"]) == 1 + assert result["news"][0]["title"] == "News Title 1" + + +@patch("requests.post") +def test_serper_tool_with_location_params(mock_post): + tool = SerperDevTool( + n_results=2, + country="US", + location="New York", + locale="en" + ) + + tool.run(search_query="test") + + called_payload = mock_post.call_args.kwargs["json"] + assert called_payload["gl"] == "US" + assert called_payload["location"] == "New York" + assert called_payload["hl"] == "en" + + +def test_invalid_search_type(): + tool = SerperDevTool() + with pytest.raises(ValueError) as exc_info: + tool.run(search_query="test", search_type="invalid") + assert "Invalid search type" in str(exc_info.value) + + +@patch("requests.post") +def test_api_error_handling(mock_post): + tool = SerperDevTool() + mock_post.side_effect = Exception("API Error") + + with pytest.raises(Exception) as exc_info: + tool.run(search_query="test") + assert "API Error" in str(exc_info.value) + + +if __name__ == "__main__": + pytest.main([__file__])