mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fixing serper tool
This commit is contained in:
@@ -39,8 +39,8 @@ class ScrapeWebsiteTool(BaseTool):
|
|||||||
self.cookies = {cookies["name"]: os.getenv(cookies["value"])}
|
self.cookies = {cookies["name"]: os.getenv(cookies["value"])}
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
website_url = kwargs.get('website_url', self.website_url)
|
website_url = kwargs.get('website_url', self.website_url)
|
||||||
page = requests.get(
|
page = requests.get(
|
||||||
@@ -49,9 +49,11 @@ class ScrapeWebsiteTool(BaseTool):
|
|||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
cookies=self.cookies if self.cookies else {}
|
cookies=self.cookies if self.cookies else {}
|
||||||
)
|
)
|
||||||
parsed = BeautifulSoup(page.content, "html.parser")
|
|
||||||
|
page.encoding = page.apparent_encoding
|
||||||
|
parsed = BeautifulSoup(page.text, "html.parser")
|
||||||
|
|
||||||
text = parsed.get_text()
|
text = parsed.get_text()
|
||||||
text = '\n'.join([i for i in text.split('\n') if i.strip() != ''])
|
text = '\n'.join([i for i in text.split('\n') if i.strip() != ''])
|
||||||
text = ' '.join([i for i in text.split(' ') if i.strip() != ''])
|
text = ' '.join([i for i in text.split(' ') if i.strip() != ''])
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import datetime
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
@@ -7,11 +8,11 @@ from pydantic.v1 import BaseModel, Field
|
|||||||
from crewai_tools.tools.base_tool import BaseTool
|
from crewai_tools.tools.base_tool import BaseTool
|
||||||
|
|
||||||
def _save_results_to_file(content: str) -> None:
|
def _save_results_to_file(content: str) -> None:
|
||||||
"""Saves the search results to a file."""
|
"""Saves the search results to a file."""
|
||||||
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
|
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
|
||||||
with open(filename, 'w') as file:
|
with open(filename, 'w') as file:
|
||||||
file.write(content)
|
file.write(content)
|
||||||
print(f"Results saved to {filename}")
|
print(f"Results saved to {filename}")
|
||||||
|
|
||||||
|
|
||||||
class SerperDevToolSchema(BaseModel):
|
class SerperDevToolSchema(BaseModel):
|
||||||
@@ -23,11 +24,11 @@ class SerperDevTool(BaseTool):
|
|||||||
description: str = "A tool that can be used to search the internet with a search_query."
|
description: str = "A tool that can be used to search the internet with a search_query."
|
||||||
args_schema: Type[BaseModel] = SerperDevToolSchema
|
args_schema: Type[BaseModel] = SerperDevToolSchema
|
||||||
search_url: str = "https://google.serper.dev/search"
|
search_url: str = "https://google.serper.dev/search"
|
||||||
country: Optional[str] = None
|
country: Optional[str] = ''
|
||||||
location: Optional[str] = None
|
location: Optional[str] = ''
|
||||||
locale: Optional[str] = None
|
locale: Optional[str] = ''
|
||||||
n_results: int = Field(default=10, description="Number of search results to return")
|
n_results: int = 10
|
||||||
save_file: bool = Field(default=False, description="Flag to determine whether to save the results to a file")
|
save_file: bool = False
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
@@ -39,18 +40,24 @@ class SerperDevTool(BaseTool):
|
|||||||
n_results = kwargs.get('n_results', self.n_results)
|
n_results = kwargs.get('n_results', self.n_results)
|
||||||
|
|
||||||
payload = { "q": search_query, "num": n_results }
|
payload = { "q": search_query, "num": n_results }
|
||||||
payload["gl"] = self.country if self.country
|
|
||||||
payload["location"] = self.country if self.location
|
if self.country != '':
|
||||||
payload["hl"] = self.country if self.locale
|
payload["gl"] = self.country
|
||||||
|
if self.location != '':
|
||||||
|
payload["location"] = self.location
|
||||||
|
if self.locale != '':
|
||||||
|
payload["hl"] = self.locale
|
||||||
|
|
||||||
payload = json.dumps(payload)
|
payload = json.dumps(payload)
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
'X-API-KEY': os.environ['SERPER_API_KEY'],
|
'X-API-KEY': os.environ['SERPER_API_KEY'],
|
||||||
'content-type': 'application/json'
|
'content-type': 'application/json'
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.request("POST", self.search_url, headers=headers, data=payload)
|
response = requests.request("POST", self.search_url, headers=headers, data=payload)
|
||||||
results = response.json()
|
results = response.json()
|
||||||
|
|
||||||
if 'organic' in results:
|
if 'organic' in results:
|
||||||
results = results['organic'][:self.n_results]
|
results = results['organic'][:self.n_results]
|
||||||
string = []
|
string = []
|
||||||
@@ -67,7 +74,7 @@ class SerperDevTool(BaseTool):
|
|||||||
|
|
||||||
content = '\n'.join(string)
|
content = '\n'.join(string)
|
||||||
if save_file:
|
if save_file:
|
||||||
_save_results_to_file(content)
|
_save_results_to_file(content)
|
||||||
return f"\nSearch results: {content}\n"
|
return f"\nSearch results: {content}\n"
|
||||||
else:
|
else:
|
||||||
return results
|
return results
|
||||||
|
|||||||
Reference in New Issue
Block a user