diff --git a/src/crewai_tools/tools/exa_tools/exa_base_tool.py b/src/crewai_tools/tools/exa_tools/exa_base_tool.py deleted file mode 100644 index 295b283ad..000000000 --- a/src/crewai_tools/tools/exa_tools/exa_base_tool.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Type - -from crewai.tools import BaseTool -from pydantic import BaseModel, Field - - -class EXABaseToolToolSchema(BaseModel): - """Input for EXABaseTool.""" - - search_query: str = Field( - ..., description="Mandatory search query you want to use to search the internet" - ) - - -class EXABaseTool(BaseTool): - name: str = "Search the internet" - description: str = ( - "A tool that can be used to search the internet from a search_query" - ) - args_schema: Type[BaseModel] = EXABaseToolToolSchema - search_url: str = "https://api.exa.ai/search" - n_results: int = None - headers: dict = { - "accept": "application/json", - "content-type": "application/json", - } - - def _parse_results(self, results): - string = [] - for result in results: - try: - string.append( - "\n".join( - [ - f"Title: {result['title']}", - f"Score: {result['score']}", - f"Url: {result['url']}", - f"ID: {result['id']}", - "---", - ] - ) - ) - except KeyError: - continue - - content = "\n".join(string) - return f"\nSearch results: {content}\n" diff --git a/src/crewai_tools/tools/exa_tools/exa_search_tool.py b/src/crewai_tools/tools/exa_tools/exa_search_tool.py index 6724c2417..6681e8d1b 100644 --- a/src/crewai_tools/tools/exa_tools/exa_search_tool.py +++ b/src/crewai_tools/tools/exa_tools/exa_search_tool.py @@ -1,30 +1,67 @@ -import os -from typing import Any +from typing import Any, Optional, Type +from pydantic import BaseModel, Field -import requests +try: + from exa_py import Exa -from .exa_base_tool import EXABaseTool + EXA_INSTALLED = True +except ImportError: + Exa = Any + EXA_INSTALLED = False -class EXASearchTool(EXABaseTool): +class EXABaseToolToolSchema(BaseModel): + search_query: str = Field( + ..., description="Mandatory search query you want to use to search the internet" + ) + + +class EXASearchTool: + args_schema: Type[BaseModel] = EXABaseToolToolSchema + client: Optional["Exa"] = Field(default=None, description="Exa search client") + + def __init__( + self, + api_key: str, + content: bool = False, + highlights: bool = False, + type: str = "keyword", + use_autoprompt: bool = True, + ): + if not EXA_INSTALLED: + raise ImportError("`exa-py` package not found, please run `uv add exa-py`") + self.client = Exa(api_key=api_key) + self.content = content + self.highlights = highlights + self.type = type + self.use_autoprompt = use_autoprompt + def _run( self, - **kwargs: Any, + search_query: str, + start_published_date: Optional[str] = None, + end_published_date: Optional[str] = None, + include_domains: Optional[list[str]] = None, ) -> Any: - search_query = kwargs.get("search_query") - if search_query is None: - search_query = kwargs.get("query") + if self.client is None: + raise ValueError("Client not initialized") - payload = { - "query": search_query, - "type": "magic", + search_params = { + "use_autoprompt": self.use_autoprompt, + "type": self.type, } - headers = self.headers.copy() - headers["x-api-key"] = os.environ["EXA_API_KEY"] + if start_published_date: + search_params["start_published_date"] = start_published_date + if end_published_date: + search_params["end_published_date"] = end_published_date + if include_domains: + search_params["include_domains"] = include_domains - response = requests.post(self.search_url, json=payload, headers=headers) - results = response.json() - if "results" in results: - results = super()._parse_results(results["results"]) + if self.content: + results = self.client.search_and_contents( + search_query, highlights=self.highlights, **search_params + ) + else: + results = self.client.search(search_query, **search_params) return results