crewAI/crewai_tools/tools/arxiv_paper_tool/arxiv_paper_tool.py
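
"""CrewAI tool that fetches paper metadata from the arXiv API based on a search
query and optionally downloads the matching PDFs to a local directory."""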

import logging
import re
import time
import urllib.error
import urllib.parse
import urllib.request
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import ClassVar, List, Optional, Type

from pydantic import BaseModel, Field

from crewai.tools import BaseTool, EnvVar

logger = logging.getLogger(__file__)

class ArxivToolInput(BaseModel):
    """Input schema for the Arxiv paper tool."""

    search_query: str = Field(..., description="Search query for Arxiv, e.g., 'transformer neural network'")
    max_results: int = Field(5, ge=1, le=100, description="Max results to fetch; must be between 1 and 100")


class ArxivPaperTool(BaseTool):
    """Fetches paper metadata from the arXiv API and optionally downloads the PDFs."""

    BASE_API_URL: ClassVar[str] = "http://export.arxiv.org/api/query"
    SLEEP_DURATION: ClassVar[int] = 1
    SUMMARY_TRUNCATE_LENGTH: ClassVar[int] = 300
    ATOM_NAMESPACE: ClassVar[str] = "{http://www.w3.org/2005/Atom}"
    REQUEST_TIMEOUT: ClassVar[int] = 10

    name: str = "Arxiv Paper Fetcher and Downloader"
    description: str = "Fetches metadata from Arxiv based on a search query and optionally downloads PDFs."
    args_schema: Type[BaseModel] = ArxivToolInput
    model_config = {"extra": "allow"}
    package_dependencies: List[str] = ["pydantic"]
    env_vars: List[EnvVar] = []

    def __init__(self, download_pdfs=False, save_dir="./arxiv_pdfs", use_title_as_filename=False):
        super().__init__()
        self.download_pdfs = download_pdfs
        self.save_dir = save_dir
        self.use_title_as_filename = use_title_as_filename

    def _run(self, search_query: str, max_results: int = 5) -> str:
        """Validate inputs, fetch matching papers, optionally download PDFs, and return formatted results."""
        try:
            args = ArxivToolInput(search_query=search_query, max_results=max_results)
            logger.info(
                f"Running Arxiv tool: query='{args.search_query}', max_results={args.max_results}, "
                f"download_pdfs={self.download_pdfs}, save_dir='{self.save_dir}', "
                f"use_title_as_filename={self.use_title_as_filename}"
            )
            papers = self.fetch_arxiv_data(args.search_query, args.max_results)

            if self.download_pdfs:
                save_dir = self._validate_save_path(self.save_dir)
                for paper in papers:
                    if paper['pdf_url']:
                        if self.use_title_as_filename:
                            # Strip characters that are unsafe in filenames.
                            safe_title = re.sub(r'[\\/*?:"<>|]', "_", paper['title']).strip()
                            filename_base = safe_title or paper['arxiv_id']
                        else:
                            filename_base = paper['arxiv_id']
                        filename = f"{filename_base[:500]}.pdf"
                        save_path = Path(save_dir) / filename
                        self.download_pdf(paper['pdf_url'], save_path)
                        time.sleep(self.SLEEP_DURATION)

            results = [self._format_paper_result(p) for p in papers]
            # Separate formatted entries with a dashed divider line.
            return ("\n\n" + "-" * 80 + "\n\n").join(results)
        except Exception as e:
            logger.error(f"ArxivTool Error: {str(e)}")
            return f"Failed to fetch or download Arxiv papers: {str(e)}"

    def fetch_arxiv_data(self, search_query: str, max_results: int) -> List[dict]:
        """Query the arXiv Atom API and parse each entry into a dict of metadata."""
        api_url = f"{self.BASE_API_URL}?search_query={urllib.parse.quote(search_query)}&start=0&max_results={max_results}"
        logger.info(f"Fetching data from Arxiv API: {api_url}")
        try:
            with urllib.request.urlopen(api_url, timeout=self.REQUEST_TIMEOUT) as response:
                if response.status != 200:
                    raise Exception(f"HTTP {response.status}: {response.reason}")
                data = response.read().decode('utf-8')
        except urllib.error.URLError as e:
            logger.error(f"Error fetching data from Arxiv: {e}")
            raise

        root = ET.fromstring(data)
        papers = []
        for entry in root.findall(self.ATOM_NAMESPACE + "entry"):
            raw_id = self._get_element_text(entry, "id")
            arxiv_id = raw_id.split('/')[-1].replace('.', '_') if raw_id else "unknown"
            title = self._get_element_text(entry, "title") or "No Title"
            summary = self._get_element_text(entry, "summary") or "No Summary"
            published = self._get_element_text(entry, "published") or "No Publish Date"
            authors = [
                self._get_element_text(author, "name") or "Unknown"
                for author in entry.findall(self.ATOM_NAMESPACE + "author")
            ]
            pdf_url = self._extract_pdf_url(entry)
            papers.append({
                "arxiv_id": arxiv_id,
                "title": title,
                "summary": summary,
                "authors": authors,
                "published_date": published,
                "pdf_url": pdf_url,
            })
        return papers

    @staticmethod
    def _get_element_text(entry: ET.Element, element_name: str) -> Optional[str]:
        elem = entry.find(f'{ArxivPaperTool.ATOM_NAMESPACE}{element_name}')
        return elem.text.strip() if elem is not None and elem.text else None

    def _extract_pdf_url(self, entry: ET.Element) -> Optional[str]:
        for link in entry.findall(self.ATOM_NAMESPACE + "link"):
            if link.attrib.get('title', '').lower() == 'pdf':
                return link.attrib.get('href')
        for link in entry.findall(self.ATOM_NAMESPACE + "link"):
            href = link.attrib.get('href')
            if href and 'pdf' in href:
                return href
        return None

    def _format_paper_result(self, paper: dict) -> str:
        summary = (
            paper['summary'][:self.SUMMARY_TRUNCATE_LENGTH] + '...'
            if len(paper['summary']) > self.SUMMARY_TRUNCATE_LENGTH
            else paper['summary']
        )
        authors_str = ', '.join(paper['authors'])
        return (f"Title: {paper['title']}\n"
                f"Authors: {authors_str}\n"
                f"Published: {paper['published_date']}\n"
                f"PDF: {paper['pdf_url'] or 'N/A'}\n"
                f"Summary: {summary}")

    @staticmethod
    def _validate_save_path(path: str) -> Path:
        save_path = Path(path).resolve()
        save_path.mkdir(parents=True, exist_ok=True)
        return save_path

    def download_pdf(self, pdf_url: str, save_path: Path):
        try:
            logger.info(f"Downloading PDF from {pdf_url} to {save_path}")
            urllib.request.urlretrieve(pdf_url, str(save_path))
            logger.info(f"PDF saved: {save_path}")
        except urllib.error.URLError as e:
            logger.error(f"Network error occurred while downloading {pdf_url}: {e}")
            raise
        except OSError as e:
            logger.error(f"File save error for {save_path}: {e}")
            raise
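
# --- Illustrative usage sketch (not part of the original module) ---
# Shows one way the tool might be exercised directly. The query string,
# max_results value, and save_dir below are example values, not defaults
# defined elsewhere. `run()` is assumed to be the public dispatch method
# provided by CrewAI's BaseTool, which forwards keyword arguments to the
# `_run()` method defined above.
if __name__ == "__main__":
    tool = ArxivPaperTool(download_pdfs=False, save_dir="./arxiv_pdfs")
    print(tool.run(search_query="transformer neural network", max_results=3))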