mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-10 00:28:31 +00:00

Adding Arxiv Paper tool (#310)

* arxiv_paper_tool.py
* Updating as per the review
* Update __init__.py
* Update __init__.py
* Update arxiv_paper_tool.py
* added test cases
* Create README.md
* Create Examples.md
* Update Examples.md
* Updated logger
* Updated with package_dependencies, env_vars
@@ -10,6 +10,7 @@ from .aws import (
 from .tools import (
     AIMindTool,
     ApifyActorsTool,
+    ArxivPaperTool,
     BraveSearchTool,
     BrowserbaseLoadTool,
     CodeDocsSearchTool,
@@ -1,5 +1,6 @@
 from .ai_mind_tool.ai_mind_tool import AIMindTool
 from .apify_actors_tool.apify_actors_tool import ApifyActorsTool
+from .arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
 from .brave_search_tool.brave_search_tool import BraveSearchTool
 from .browserbase_load_tool.browserbase_load_tool import BrowserbaseLoadTool
 from .code_docs_search_tool.code_docs_search_tool import CodeDocsSearchTool
80
src/crewai_tools/tools/arxiv_paper_tool/Examples.md
Normal file
@@ -0,0 +1,80 @@
### Example 1: Fetching Research Papers from arXiv with CrewAI

This example demonstrates how to build a simple CrewAI workflow that automatically searches for and downloads academic papers from [arXiv.org](https://arxiv.org). The setup uses:

* A custom `ArxivPaperTool` to fetch metadata and download PDFs
* A single `Agent` tasked with locating relevant papers based on a given research topic
* A `Task` to define the data retrieval and download process
* A sequential `Crew` to orchestrate execution

The downloaded PDFs are saved to a local directory (`./DOWNLOADS`). Filenames are optionally based on sanitized paper titles, ensuring compatibility with your operating system.

> The saved PDFs can be further used in **downstream tasks** (a hedged sketch of one such follow-up appears after the example below), such as:
>
> * **RAG (Retrieval-Augmented Generation)**
> * **Summarization**
> * **Citation extraction**
> * **Embedding-based search or analysis**

---

```python
from crewai import Agent, Task, Crew, Process, LLM
from crewai_tools import ArxivPaperTool


llm = LLM(
    model="ollama/llama3.1",
    base_url="http://localhost:11434",
    temperature=0.1
)

topic = "Crew AI"
max_results = 3
save_dir = "./DOWNLOADS"
use_title_as_filename = True

tool = ArxivPaperTool(
    download_pdfs=True,
    save_dir=save_dir,
    use_title_as_filename=True
)
tool.result_as_answer = True  # Required, otherwise the agent's LLM may rephrase the output instead of returning the tool result verbatim

arxiv_paper_fetch = Agent(
    role="Arxiv Data Fetcher",
    goal=(
        f"Retrieve relevant papers from arXiv on the research topic '{topic}'. "
        f"Download at most {max_results} papers as PDFs to {save_dir}, "
        f"using the paper title as the filename: {use_title_as_filename}."
    ),
    backstory="An expert in scientific data retrieval, skilled in extracting academic content from arXiv.",
    # tools=[ArxivPaperTool()],
    llm=llm,
    verbose=True,
    allow_delegation=False
)

fetch_task = Task(
    description=(
        f"Search arXiv for the topic '{topic}' and fetch up to {max_results} papers. "
        f"Download PDFs for analysis and store them at {save_dir}."
    ),
    expected_output="PDFs saved to disk for downstream agents.",
    agent=arxiv_paper_fetch,
    tools=[tool],  # Use the actual tool instance here
)

pdf_qa_crew = Crew(
    agents=[arxiv_paper_fetch],
    tasks=[fetch_task],
    process=Process.sequential,
    verbose=True,
)

result = pdf_qa_crew.kickoff()

print(f"\n🤖 Answer:\n\n{result.raw}\n")
```
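The blockquote above mentions downstream uses for the saved PDFs. As a minimal, hedged sketch of one such follow-up step (not part of this PR), a second agent could index one of the downloaded files with `PDFSearchTool` from `crewai_tools`; the PDF path below is hypothetical and depends on what the fetch step actually saved:

```python
from crewai import Agent, Task, Crew, Process
from crewai_tools import PDFSearchTool

# Hypothetical path: pick any file the fetch step saved under ./DOWNLOADS.
pdf_tool = PDFSearchTool(pdf="./DOWNLOADS/example_paper.pdf")

paper_analyst = Agent(
    role="Paper Analyst",
    goal="Answer questions about a downloaded arXiv paper.",
    backstory="A careful reader of machine learning papers.",
    tools=[pdf_tool],
    llm=llm,  # reuse the LLM configured in the example above
    verbose=True,
)

qa_task = Task(
    description="Summarize the main contribution of the downloaded paper.",
    expected_output="A short summary grounded in the PDF's contents.",
    agent=paper_analyst,
)

summary_crew = Crew(
    agents=[paper_analyst],
    tasks=[qa_task],
    process=Process.sequential,
)

print(summary_crew.kickoff().raw)
```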
142
src/crewai_tools/tools/arxiv_paper_tool/README.md
Normal file
@@ -0,0 +1,142 @@
# 📚 ArxivPaperTool

The **ArxivPaperTool** is a utility for fetching metadata and optionally downloading PDFs of academic papers from the [arXiv](https://arxiv.org) platform using its public API. It supports configurable queries, batch retrieval, PDF downloading, and clean formatting for summaries and metadata. This tool is particularly useful for researchers, students, academic agents, and AI tools performing automated literature reviews.

---

## Description

This tool:

* Accepts a **search query** and retrieves a list of papers from arXiv.
* Allows configuration of the **maximum number of results** to fetch.
* Optionally downloads the **PDFs** of the matched papers.
* Lets you specify whether to name PDF files using the **arXiv ID** or **paper title**.
* Saves downloaded files into a **custom or default directory**.
* Returns structured summaries of all fetched papers including metadata.

---
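Under the hood, the tool queries the public arXiv Atom API. A minimal sketch of the equivalent raw request (mirroring the tool's own `fetch_arxiv_data`, with an illustrative search term) looks like this:

```python
import urllib.parse
import urllib.request

# Same endpoint and parameters the tool builds internally.
query = urllib.parse.quote("transformer neural network")
url = f"http://export.arxiv.org/api/query?search_query={query}&start=0&max_results=5"

with urllib.request.urlopen(url, timeout=10) as response:
    atom_feed = response.read().decode("utf-8")  # Atom XML describing the matching papers

print(atom_feed[:500])
```

---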
## Arguments

| Argument                | Type   | Required | Description                                                                        |
| ----------------------- | ------ | -------- | ---------------------------------------------------------------------------------- |
| `search_query`          | `str`  | ✅        | Search query string (e.g., `"transformer neural network"`).                        |
| `max_results`           | `int`  | ✅        | Number of results to fetch (between 1 and 100).                                    |
| `download_pdfs`         | `bool` | ❌        | Whether to download the corresponding PDFs. Defaults to `False`.                   |
| `save_dir`              | `str`  | ❌        | Directory to save PDFs (created if it doesn’t exist). Defaults to `./arxiv_pdfs`.  |
| `use_title_as_filename` | `bool` | ❌        | Use the paper title as the filename (sanitized). Defaults to `False`.              |

`search_query` and `max_results` are run-time arguments validated by the tool's input schema, while `download_pdfs`, `save_dir`, and `use_title_as_filename` are set when the tool is constructed.

---

## 📄 `ArxivPaperTool` Usage Examples

This document shows how to use the `ArxivPaperTool` to fetch research paper metadata from arXiv and optionally download PDFs.

### 🔧 Tool Initialization

```python
from crewai_tools import ArxivPaperTool
```

---

### Example 1: Fetch Metadata Only (No Downloads)

```python
tool = ArxivPaperTool()
result = tool._run(
    search_query="deep learning",
    max_results=1
)
print(result)
```

---

### Example 2: Fetch and Download PDFs (arXiv ID as Filename)

```python
tool = ArxivPaperTool(download_pdfs=True)
result = tool._run(
    search_query="transformer models",
    max_results=2
)
print(result)
```

---

### Example 3: Download PDFs into a Custom Directory

```python
tool = ArxivPaperTool(
    download_pdfs=True,
    save_dir="./my_papers"
)
result = tool._run(
    search_query="graph neural networks",
    max_results=2
)
print(result)
```

---

### Example 4: Use Paper Titles as Filenames

```python
tool = ArxivPaperTool(
    download_pdfs=True,
    use_title_as_filename=True
)
result = tool._run(
    search_query="vision transformers",
    max_results=1
)
print(result)
```

---

### Example 5: All Options Combined

```python
tool = ArxivPaperTool(
    download_pdfs=True,
    save_dir="./downloads",
    use_title_as_filename=True
)
result = tool._run(
    search_query="stable diffusion",
    max_results=3
)
print(result)
```

---

### Run via `__main__`

The module can also be run directly:

```python
if __name__ == "__main__":
    tool = ArxivPaperTool(
        download_pdfs=True,
        save_dir="./downloads2",
        use_title_as_filename=False
    )
    result = tool._run(
        search_query="deep learning",
        max_results=1
    )
    print(result)
```

---
152
src/crewai_tools/tools/arxiv_paper_tool/arxiv_paper_tool.py
Normal file
@@ -0,0 +1,152 @@
import re
import time
import urllib.request
import urllib.parse
import urllib.error
import xml.etree.ElementTree as ET
from typing import Type, List, Optional, ClassVar
from pydantic import BaseModel, Field
from crewai.tools import BaseTool, EnvVar
import logging
from pathlib import Path

logger = logging.getLogger(__file__)


class ArxivToolInput(BaseModel):
    search_query: str = Field(..., description="Search query for Arxiv, e.g., 'transformer neural network'")
    max_results: int = Field(5, ge=1, le=100, description="Max results to fetch; must be between 1 and 100")


class ArxivPaperTool(BaseTool):
    BASE_API_URL: ClassVar[str] = "http://export.arxiv.org/api/query"
    SLEEP_DURATION: ClassVar[int] = 1
    SUMMARY_TRUNCATE_LENGTH: ClassVar[int] = 300
    ATOM_NAMESPACE: ClassVar[str] = "{http://www.w3.org/2005/Atom}"
    REQUEST_TIMEOUT: ClassVar[int] = 10
    name: str = "Arxiv Paper Fetcher and Downloader"
    description: str = "Fetches metadata from Arxiv based on a search query and optionally downloads PDFs."
    args_schema: Type[BaseModel] = ArxivToolInput
    model_config = {"extra": "allow"}
    package_dependencies: List[str] = ["pydantic"]
    env_vars: List[EnvVar] = []

    def __init__(self, download_pdfs=False, save_dir="./arxiv_pdfs", use_title_as_filename=False):
        super().__init__()
        self.download_pdfs = download_pdfs
        self.save_dir = save_dir
        self.use_title_as_filename = use_title_as_filename

    def _run(self, search_query: str, max_results: int = 5) -> str:
        try:
            # Validate run-time arguments against the input schema.
            args = ArxivToolInput(search_query=search_query, max_results=max_results)
            logger.info(f"Running Arxiv tool: query='{args.search_query}', max_results={args.max_results}, "
                        f"download_pdfs={self.download_pdfs}, save_dir='{self.save_dir}', "
                        f"use_title_as_filename={self.use_title_as_filename}")

            papers = self.fetch_arxiv_data(args.search_query, args.max_results)

            if self.download_pdfs:
                save_dir = self._validate_save_path(self.save_dir)
                for paper in papers:
                    if paper['pdf_url']:
                        if self.use_title_as_filename:
                            # Strip characters that are not valid in filenames on common operating systems.
                            safe_title = re.sub(r'[\\/*?:"<>|]', "_", paper['title']).strip()
                            filename_base = safe_title or paper['arxiv_id']
                        else:
                            filename_base = paper['arxiv_id']
                        filename = f"{filename_base[:500]}.pdf"
                        save_path = Path(save_dir) / filename

                        self.download_pdf(paper['pdf_url'], save_path)
                        # Pause briefly between downloads to be polite to the arXiv servers.
                        time.sleep(self.SLEEP_DURATION)

            results = [self._format_paper_result(p) for p in papers]
            # Join the formatted papers with a dashed separator line.
            return ("\n\n" + "-" * 80 + "\n\n").join(results)

        except Exception as e:
            logger.error(f"ArxivTool Error: {str(e)}")
            return f"Failed to fetch or download Arxiv papers: {str(e)}"

    def fetch_arxiv_data(self, search_query: str, max_results: int) -> List[dict]:
        api_url = f"{self.BASE_API_URL}?search_query={urllib.parse.quote(search_query)}&start=0&max_results={max_results}"
        logger.info(f"Fetching data from Arxiv API: {api_url}")

        try:
            with urllib.request.urlopen(api_url, timeout=self.REQUEST_TIMEOUT) as response:
                if response.status != 200:
                    raise Exception(f"HTTP {response.status}: {response.reason}")
                data = response.read().decode('utf-8')
        except urllib.error.URLError as e:
            logger.error(f"Error fetching data from Arxiv: {e}")
            raise

        root = ET.fromstring(data)
        papers = []

        for entry in root.findall(self.ATOM_NAMESPACE + "entry"):
            raw_id = self._get_element_text(entry, "id")
            arxiv_id = raw_id.split('/')[-1].replace('.', '_') if raw_id else "unknown"

            title = self._get_element_text(entry, "title") or "No Title"
            summary = self._get_element_text(entry, "summary") or "No Summary"
            published = self._get_element_text(entry, "published") or "No Publish Date"
            authors = [
                self._get_element_text(author, "name") or "Unknown"
                for author in entry.findall(self.ATOM_NAMESPACE + "author")
            ]

            pdf_url = self._extract_pdf_url(entry)

            papers.append({
                "arxiv_id": arxiv_id,
                "title": title,
                "summary": summary,
                "authors": authors,
                "published_date": published,
                "pdf_url": pdf_url
            })

        return papers

    @staticmethod
    def _get_element_text(entry: ET.Element, element_name: str) -> Optional[str]:
        elem = entry.find(f'{ArxivPaperTool.ATOM_NAMESPACE}{element_name}')
        return elem.text.strip() if elem is not None and elem.text else None

    def _extract_pdf_url(self, entry: ET.Element) -> Optional[str]:
        # Prefer the link explicitly titled "pdf"; otherwise fall back to any href containing "pdf".
        for link in entry.findall(self.ATOM_NAMESPACE + "link"):
            if link.attrib.get('title', '').lower() == 'pdf':
                return link.attrib.get('href')
        for link in entry.findall(self.ATOM_NAMESPACE + "link"):
            href = link.attrib.get('href')
            if href and 'pdf' in href:
                return href
        return None

    def _format_paper_result(self, paper: dict) -> str:
        summary = (paper['summary'][:self.SUMMARY_TRUNCATE_LENGTH] + '...') \
            if len(paper['summary']) > self.SUMMARY_TRUNCATE_LENGTH else paper['summary']
        authors_str = ', '.join(paper['authors'])
        return (f"Title: {paper['title']}\n"
                f"Authors: {authors_str}\n"
                f"Published: {paper['published_date']}\n"
                f"PDF: {paper['pdf_url'] or 'N/A'}\n"
                f"Summary: {summary}")

    @staticmethod
    def _validate_save_path(path: str) -> Path:
        save_path = Path(path).resolve()
        save_path.mkdir(parents=True, exist_ok=True)
        return save_path

    def download_pdf(self, pdf_url: str, save_path: str):
        try:
            logger.info(f"Downloading PDF from {pdf_url} to {save_path}")
            urllib.request.urlretrieve(pdf_url, str(save_path))
            logger.info(f"PDF saved: {save_path}")
        except urllib.error.URLError as e:
            logger.error(f"Network error occurred while downloading {pdf_url}: {e}")
            raise
        except OSError as e:
            logger.error(f"File save error for {save_path}: {e}")
            raise
113
src/crewai_tools/tools/arxiv_paper_tool/arxiv_paper_tool_test.py
Normal file
@@ -0,0 +1,113 @@
import pytest
import urllib.error
from unittest.mock import patch, MagicMock, mock_open
from pathlib import Path
import xml.etree.ElementTree as ET
from crewai_tools.tools.arxiv_paper_tool import ArxivPaperTool


@pytest.fixture
def tool():
    return ArxivPaperTool(download_pdfs=False)


def mock_arxiv_response():
    return '''<?xml version="1.0" encoding="UTF-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <entry>
    <id>http://arxiv.org/abs/1234.5678</id>
    <title>Sample Paper</title>
    <summary>This is a summary of the sample paper.</summary>
    <published>2022-01-01T00:00:00Z</published>
    <author><name>John Doe</name></author>
    <link title="pdf" href="http://arxiv.org/pdf/1234.5678.pdf"/>
  </entry>
</feed>'''


@patch("urllib.request.urlopen")
def test_fetch_arxiv_data(mock_urlopen, tool):
    mock_response = MagicMock()
    mock_response.status = 200
    mock_response.read.return_value = mock_arxiv_response().encode("utf-8")
    mock_urlopen.return_value.__enter__.return_value = mock_response

    results = tool.fetch_arxiv_data("transformer", 1)
    assert isinstance(results, list)
    assert results[0]['title'] == "Sample Paper"


@patch("urllib.request.urlopen", side_effect=urllib.error.URLError("Timeout"))
def test_fetch_arxiv_data_network_error(mock_urlopen, tool):
    with pytest.raises(urllib.error.URLError):
        tool.fetch_arxiv_data("transformer", 1)


@patch("urllib.request.urlretrieve")
def test_download_pdf_success(mock_urlretrieve):
    tool = ArxivPaperTool()
    tool.download_pdf("http://arxiv.org/pdf/1234.5678.pdf", Path("test.pdf"))
    mock_urlretrieve.assert_called_once()


@patch("urllib.request.urlretrieve", side_effect=OSError("Permission denied"))
def test_download_pdf_oserror(mock_urlretrieve):
    tool = ArxivPaperTool()
    with pytest.raises(OSError):
        tool.download_pdf("http://arxiv.org/pdf/1234.5678.pdf", Path("/restricted/test.pdf"))


@patch("urllib.request.urlopen")
@patch("urllib.request.urlretrieve")
def test_run_with_download(mock_urlretrieve, mock_urlopen):
    mock_response = MagicMock()
    mock_response.status = 200
    mock_response.read.return_value = mock_arxiv_response().encode("utf-8")
    mock_urlopen.return_value.__enter__.return_value = mock_response

    tool = ArxivPaperTool(download_pdfs=True)
    output = tool._run("transformer", 1)
    assert "Title: Sample Paper" in output
    mock_urlretrieve.assert_called_once()


@patch("urllib.request.urlopen")
def test_run_no_download(mock_urlopen):
    mock_response = MagicMock()
    mock_response.status = 200
    mock_response.read.return_value = mock_arxiv_response().encode("utf-8")
    mock_urlopen.return_value.__enter__.return_value = mock_response

    tool = ArxivPaperTool(download_pdfs=False)
    result = tool._run("transformer", 1)
    assert "Title: Sample Paper" in result


@patch("pathlib.Path.mkdir")
def test_validate_save_path_creates_directory(mock_mkdir):
    path = ArxivPaperTool._validate_save_path("new_folder")
    mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
    assert isinstance(path, Path)


@patch("urllib.request.urlopen")
def test_run_handles_exception(mock_urlopen):
    mock_urlopen.side_effect = Exception("API failure")
    tool = ArxivPaperTool()
    result = tool._run("transformer", 1)
    assert "Failed to fetch or download Arxiv papers" in result


@patch("urllib.request.urlopen")
def test_invalid_xml_response(mock_urlopen, tool):
    mock_response = MagicMock()
    mock_response.read.return_value = b"<invalid><xml>"
    mock_response.status = 200
    mock_urlopen.return_value.__enter__.return_value = mock_response

    with pytest.raises(ET.ParseError):
        tool.fetch_arxiv_data("quantum", 1)


@patch.object(ArxivPaperTool, "fetch_arxiv_data")
def test_run_with_max_results(mock_fetch, tool):
    mock_fetch.return_value = [{
        "arxiv_id": f"test_{i}",
        "title": f"Title {i}",
        "summary": "Summary",
        "authors": ["Author"],
        "published_date": "2023-01-01",
        "pdf_url": None
    } for i in range(100)]

    result = tool._run(search_query="test", max_results=100)
    assert result.count("Title:") == 100
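These tests are self-contained (network access is mocked with `unittest.mock.patch`), so assuming `pytest` is installed they can be run directly, e.g. `pytest src/crewai_tools/tools/arxiv_paper_tool/arxiv_paper_tool_test.py`.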