Merge pull request #198 from crewAIInc/add/updated-exa

Support the latest version of the exa_py SDK.
This commit is contained in:
Lorenze Jay
2025-01-31 09:40:32 -08:00
committed by GitHub
3 changed files with 86 additions and 67 deletions

View File

@@ -6,7 +6,7 @@ This tool is designed to perform a semantic search for a specified query from a
## Installation
To incorporate this tool into your project, follow the installation instructions below:
```shell
pip install 'crewai[tools]'
uv add crewai[tools] exa_py
```
## Example
@@ -16,7 +16,7 @@ The following example demonstrates how to initialize the tool and execute a sear
from crewai_tools import EXASearchTool
# Initialize the tool for internet searching capabilities
tool = EXASearchTool()
tool = EXASearchTool(api_key="your_api_key")
```
## Steps to Get Started

View File

@@ -1,47 +0,0 @@
from typing import Type
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
class EXABaseToolToolSchema(BaseModel):
    """Input schema for EXABaseTool: a single mandatory query string."""

    # Required — pydantic's `...` sentinel marks the field as mandatory.
    search_query: str = Field(
        ...,
        description="Mandatory search query you want to use to search the internet",
    )
class EXABaseTool(BaseTool):
    """Base tool for querying the Exa search HTTP API.

    Holds the endpoint URL and default JSON headers, and provides a
    helper to render raw result dicts into a human-readable string.
    """

    name: str = "Search the internet"
    description: str = (
        "A tool that can be used to search the internet from a search_query"
    )
    args_schema: Type[BaseModel] = EXABaseToolToolSchema
    search_url: str = "https://api.exa.ai/search"
    # NOTE(review): annotated `int` but defaults to None — Optional[int]
    # looks more accurate; confirm the pydantic version tolerates this.
    n_results: int = None
    headers: dict = {
        "accept": "application/json",
        "content-type": "application/json",
    }

    def _parse_results(self, results):
        """Format raw Exa result dicts into one readable summary string.

        Results missing any expected key are silently skipped.
        """
        formatted = []
        for item in results:
            try:
                entry = "\n".join(
                    (
                        f"Title: {item['title']}",
                        f"Score: {item['score']}",
                        f"Url: {item['url']}",
                        f"ID: {item['id']}",
                        "---",
                    )
                )
            except KeyError:
                # Malformed result — drop it rather than fail the whole parse.
                continue
            formatted.append(entry)
        content = "\n".join(formatted)
        return f"\nSearch results: {content}\n"

View File

@@ -1,30 +1,96 @@
import os
from typing import Any
from typing import Any, Optional, Type
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
import requests
try:
from exa_py import Exa
from .exa_base_tool import EXABaseTool
EXA_INSTALLED = True
except ImportError:
Exa = Any
EXA_INSTALLED = False
class EXASearchTool(EXABaseTool):
class EXABaseToolSchema(BaseModel):
    """Input schema for EXASearchTool: query plus optional search filters."""

    # Required — pydantic's `...` sentinel marks the field as mandatory.
    search_query: str = Field(
        ...,
        description="Mandatory search query you want to use to search the internet",
    )
    # Optional date-range and domain filters forwarded to the Exa API.
    start_published_date: Optional[str] = Field(
        None, description="Start date for the search"
    )
    end_published_date: Optional[str] = Field(
        None, description="End date for the search"
    )
    include_domains: Optional[list[str]] = Field(
        None, description="List of domains to include in the search"
    )
class EXASearchTool(BaseTool):
    """Search the internet through the Exa API using the `exa_py` client.

    Requires an Exa API key. Optionally fetches page contents (and
    summaries) alongside the raw search hits via
    ``search_and_contents``.
    """

    model_config = {"arbitrary_types_allowed": True}
    name: str = "EXASearchTool"
    description: str = "Search the internet using Exa"
    args_schema: Type[BaseModel] = EXABaseToolSchema
    client: Optional["Exa"] = None  # set in __init__ once exa_py is available
    content: Optional[bool] = False  # fetch page contents when True
    summary: Optional[bool] = False  # request summaries (used with content)
    type: Optional[str] = "auto"  # Exa search type passed through to the API

    def __init__(
        self,
        api_key: str,
        content: Optional[bool] = False,
        summary: Optional[bool] = False,
        type: Optional[str] = "auto",
        **kwargs,
    ):
        """Create the tool and an Exa client bound to *api_key*.

        Raises:
            ImportError: if `exa_py` is missing and the user declines to
                install it.
        """
        super().__init__(
            **kwargs,
        )
        if not EXA_INSTALLED:
            import click

            if click.confirm(
                "You are missing the 'exa_py' package. Would you like to install it?"
            ):
                import subprocess

                subprocess.run(["uv", "add", "exa_py"], check=True)
            else:
                # Fixed: the original raised the confirm-prompt text as the error.
                raise ImportError(
                    "The 'exa_py' package is required. Install it with: uv add exa_py"
                )
            # Fixed: the module-level fallback bound `Exa` to `Any`, which is
            # not a usable client class — re-import now that it is installed.
            from exa_py import Exa as exa_cls
        else:
            exa_cls = Exa
        self.client = exa_cls(api_key=api_key)
        self.content = content
        self.summary = summary
        self.type = type

    def _run(
        self,
        search_query: str,
        start_published_date: Optional[str] = None,
        end_published_date: Optional[str] = None,
        include_domains: Optional[list[str]] = None,
    ) -> Any:
        """Run the search and return the raw Exa response object.

        Args:
            search_query: text to search for (required).
            start_published_date / end_published_date: optional ISO date
                bounds forwarded to the Exa API.
            include_domains: optional domain allow-list.

        Raises:
            ValueError: if the client was never initialized.
        """
        if self.client is None:
            raise ValueError("Client not initialized")
        search_params: dict[str, Any] = {
            "type": self.type,
        }
        # Only forward filters the caller actually supplied.
        if start_published_date:
            search_params["start_published_date"] = start_published_date
        if end_published_date:
            search_params["end_published_date"] = end_published_date
        if include_domains:
            search_params["include_domains"] = include_domains
        if self.content:
            results = self.client.search_and_contents(
                search_query, summary=self.summary, **search_params
            )
        else:
            results = self.client.search(search_query, **search_params)
        return results