latest version of exa

This commit is contained in:
Lorenze Jay
2025-01-30 15:09:47 -08:00
parent 199044f866
commit 90cdb48db0
2 changed files with 55 additions and 65 deletions

View File

@@ -1,47 +0,0 @@
from typing import Type
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
class EXABaseToolToolSchema(BaseModel):
"""Input for EXABaseTool."""
search_query: str = Field(
..., description="Mandatory search query you want to use to search the internet"
)
class EXABaseTool(BaseTool):
name: str = "Search the internet"
description: str = (
"A tool that can be used to search the internet from a search_query"
)
args_schema: Type[BaseModel] = EXABaseToolToolSchema
search_url: str = "https://api.exa.ai/search"
n_results: int = None
headers: dict = {
"accept": "application/json",
"content-type": "application/json",
}
def _parse_results(self, results):
string = []
for result in results:
try:
string.append(
"\n".join(
[
f"Title: {result['title']}",
f"Score: {result['score']}",
f"Url: {result['url']}",
f"ID: {result['id']}",
"---",
]
)
)
except KeyError:
continue
content = "\n".join(string)
return f"\nSearch results: {content}\n"

View File

@@ -1,30 +1,67 @@
import os
from typing import Any
from typing import Any, Optional, Type
from pydantic import BaseModel, Field
import requests
try:
from exa_py import Exa
from .exa_base_tool import EXABaseTool
EXA_INSTALLED = True
except ImportError:
Exa = Any
EXA_INSTALLED = False
class EXASearchTool(EXABaseTool):
class EXABaseToolToolSchema(BaseModel):
search_query: str = Field(
..., description="Mandatory search query you want to use to search the internet"
)
class EXASearchTool:
args_schema: Type[BaseModel] = EXABaseToolToolSchema
client: Optional["Exa"] = Field(default=None, description="Exa search client")
def __init__(
self,
api_key: str,
content: bool = False,
highlights: bool = False,
type: str = "keyword",
use_autoprompt: bool = True,
):
if not EXA_INSTALLED:
raise ImportError("`exa-py` package not found, please run `uv add exa-py`")
self.client = Exa(api_key=api_key)
self.content = content
self.highlights = highlights
self.type = type
self.use_autoprompt = use_autoprompt
def _run(
self,
**kwargs: Any,
search_query: str,
start_published_date: Optional[str] = None,
end_published_date: Optional[str] = None,
include_domains: Optional[list[str]] = None,
) -> Any:
search_query = kwargs.get("search_query")
if search_query is None:
search_query = kwargs.get("query")
if self.client is None:
raise ValueError("Client not initialized")
payload = {
"query": search_query,
"type": "magic",
search_params = {
"use_autoprompt": self.use_autoprompt,
"type": self.type,
}
headers = self.headers.copy()
headers["x-api-key"] = os.environ["EXA_API_KEY"]
if start_published_date:
search_params["start_published_date"] = start_published_date
if end_published_date:
search_params["end_published_date"] = end_published_date
if include_domains:
search_params["include_domains"] = include_domains
response = requests.post(self.search_url, json=payload, headers=headers)
results = response.json()
if "results" in results:
results = super()._parse_results(results["results"])
if self.content:
results = self.client.search_and_contents(
search_query, highlights=self.highlights, **search_params
)
else:
results = self.client.search(search_query, **search_params)
return results