Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-16 12:28:30 +00:00)

Compare commits

9 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 0fe9352149 |  |
|  | 548170e989 |  |
|  | 417a4e3d91 |  |
|  | 68dce92003 |  |
|  | 289b90f00a |  |
|  | c591c1ac87 |  |
|  | 86f0dfc2d7 |  |
|  | 74b5c88834 |  |
|  | 13e5ec711d |  |
4  .github/workflows/codeql.yml  (vendored)

@@ -15,11 +15,11 @@ on:
   push:
     branches: [ "main" ]
     paths-ignore:
-      - "src/crewai/cli/templates/**"
+      - "lib/crewai/src/crewai/cli/templates/**"
   pull_request:
     branches: [ "main" ]
     paths-ignore:
-      - "src/crewai/cli/templates/**"
+      - "lib/crewai/src/crewai/cli/templates/**"

 jobs:
   analyze:
8  .github/workflows/linter.yml  (vendored)

@@ -52,10 +52,10 @@ jobs:
       - name: Run Ruff on Changed Files
         if: ${{ steps.changed-files.outputs.files != '' }}
         run: |
           echo "${{ steps.changed-files.outputs.files }}" \
             | tr ' ' '\n' \
             | grep -v 'src/crewai/cli/templates/' \
             | xargs -I{} uv run ruff check "{}"

       - name: Save uv caches
         if: steps.cache-restore.outputs.cache-hit != 'true'
71  .github/workflows/publish.yml  (vendored, Normal file)

@@ -0,0 +1,71 @@
name: Publish to PyPI

on:
  release:
    types: [ published ]
  workflow_dispatch:

jobs:
  build:
    if: github.event.release.prerelease == true
    name: Build packages
    runs-on: ubuntu-latest
    permissions:
      contents: read
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Install uv
        uses: astral-sh/setup-uv@v4

      - name: Build packages
        run: |
          uv build --all-packages
          rm dist/.gitignore

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: dist
          path: dist/

  publish:
    if: github.event.release.prerelease == true
    name: Publish to PyPI
    needs: build
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/p/crewai
    permissions:
      id-token: write
      contents: read
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v6
        with:
          version: "0.8.4"
          python-version: "3.12"
          enable-cache: false

      - name: Download artifacts
        uses: actions/download-artifact@v4
        with:
          name: dist
          path: dist

      - name: Publish to PyPI
        env:
          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          for package in dist/*; do
            echo "Publishing $package"
            uv publish "$package"
          done
31  .github/workflows/tests.yml  (vendored)

@@ -8,6 +8,14 @@ permissions:
 env:
   OPENAI_API_KEY: fake-api-key
   PYTHONUNBUFFERED: 1
+  BRAVE_API_KEY: fake-brave-key
+  SNOWFLAKE_USER: fake-snowflake-user
+  SNOWFLAKE_PASSWORD: fake-snowflake-password
+  SNOWFLAKE_ACCOUNT: fake-snowflake-account
+  SNOWFLAKE_WAREHOUSE: fake-snowflake-warehouse
+  SNOWFLAKE_DATABASE: fake-snowflake-database
+  SNOWFLAKE_SCHEMA: fake-snowflake-schema
+  EMBEDCHAIN_DB_URI: sqlite:///test.db

 jobs:
   tests:

@@ -56,13 +64,13 @@ jobs:
       - name: Run tests (group ${{ matrix.group }} of 8)
         run: |
           PYTHON_VERSION_SAFE=$(echo "${{ matrix.python-version }}" | tr '.' '_')
-          DURATION_FILE=".test_durations_py${PYTHON_VERSION_SAFE}"
+          DURATION_FILE="../../.test_durations_py${PYTHON_VERSION_SAFE}"

           # Temporarily always skip cached durations to fix test splitting
           # When durations don't match, pytest-split runs duplicate tests instead of splitting
           echo "Using even test splitting (duration cache disabled until fix merged)"
           DURATIONS_ARG=""

           # Original logic (disabled temporarily):
           # if [ ! -f "$DURATION_FILE" ]; then
           # echo "No cached durations found, tests will be split evenly"

@@ -74,8 +82,8 @@ jobs:
           # echo "No test changes detected, using cached test durations for optimal splitting"
           # DURATIONS_ARG="--durations-path=${DURATION_FILE}"
           # fi

-          uv run pytest \
+          cd lib/crewai && uv run pytest \
             --block-network \
             --timeout=30 \
             -vv \

@@ -86,6 +94,19 @@ jobs:
             -n auto \
             --maxfail=3

+      - name: Run tool tests (group ${{ matrix.group }} of 8)
+        run: |
+          cd lib/crewai-tools && uv run pytest \
+            --block-network \
+            --timeout=30 \
+            -vv \
+            --splits 8 \
+            --group ${{ matrix.group }} \
+            --durations=10 \
+            -n auto \
+            --maxfail=3
+
+
       - name: Save uv caches
         if: steps.cache-restore.outputs.cache-hit != 'true'
         uses: actions/cache/save@v4
1  .gitignore  (vendored)

@@ -2,7 +2,6 @@
 .pytest_cache
 __pycache__
 dist/
-lib/
 .env
 assets/*
 .idea
@@ -6,14 +6,16 @@ repos:
       entry: uv run ruff check
       language: system
       types: [python]
+      exclude: ^lib/crewai/
     - id: ruff-format
       name: ruff-format
       entry: uv run ruff format
       language: system
       types: [python]
+      exclude: ^lib/crewai/
     - id: mypy
       name: mypy
       entry: uv run mypy
       language: system
       types: [python]
-      exclude: ^tests/
+      exclude: ^lib/crewai/
335  lib/crewai-tools/BUILDING_TOOLS.md  (Normal file)

@@ -0,0 +1,335 @@
## Building CrewAI Tools

This guide shows you how to build high‑quality CrewAI tools that match the patterns in this repository and are ready to be merged. It focuses on: architecture, conventions, environment variables, dependencies, testing, documentation, and a complete example.

### Who this is for
- Contributors creating new tools under `crewai_tools/tools/*`
- Maintainers reviewing PRs for consistency and DX

---

## Quick‑start checklist
1. Create a new folder under `crewai_tools/tools/<your_tool_name>/` with a `README.md` and a `<your_tool_name>.py`.
2. Implement a class that ends with `Tool` and subclasses `BaseTool` (or `RagTool` when appropriate).
3. Define a Pydantic `args_schema` with explicit field descriptions and validation.
4. Declare `env_vars` and `package_dependencies` in the class when needed.
5. Lazily initialize clients in `__init__` or `_run` and handle missing credentials with clear errors.
6. Implement `_run(...) -> str | dict` and, if needed, `_arun(...)`.
7. Add tests under `tests/tools/` (unit, no real network calls; mock or record safely).
8. Add a concise tool `README.md` with usage and required env vars.
9. If you add optional dependencies, register them in `pyproject.toml` under `[project.optional-dependencies]` and reference that extra in your tool docs.
10. Run `uv run pytest` and `pre-commit run -a` locally; ensure green.

---

## Tool anatomy and conventions

### BaseTool pattern
All tools follow this structure:

```python
from typing import Any, List, Optional, Type

import os
from pydantic import BaseModel, Field
from crewai.tools import BaseTool, EnvVar


class MyToolInput(BaseModel):
    """Input schema for MyTool."""
    query: str = Field(..., description="Your input description here")
    limit: int = Field(5, ge=1, le=50, description="Max items to return")


class MyTool(BaseTool):
    name: str = "My Tool"
    description: str = "Explain succinctly what this tool does and when to use it."
    args_schema: Type[BaseModel] = MyToolInput

    # Only include when applicable
    env_vars: List[EnvVar] = [
        EnvVar(name="MY_API_KEY", description="API key for My service", required=True),
    ]
    package_dependencies: List[str] = ["my-sdk"]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Lazy import to keep base install light
        try:
            import my_sdk  # noqa: F401
        except Exception as exc:
            raise ImportError(
                "Missing optional dependency 'my-sdk'. Install with: \n"
                " uv add crewai-tools --extra my-sdk\n"
                "or\n"
                " pip install my-sdk\n"
            ) from exc

        if "MY_API_KEY" not in os.environ:
            raise ValueError("Environment variable MY_API_KEY is required for MyTool")

    def _run(self, query: str, limit: int = 5, **_: Any) -> str:
        """Synchronous execution. Return a concise string or JSON string."""
        # Implement your logic here; do not print. Return the content.
        # Handle errors gracefully, return clear messages.
        return f"Processed {query} with limit={limit}"

    async def _arun(self, *args: Any, **kwargs: Any) -> str:
        """Optional async counterpart if your client supports it."""
        # Prefer delegating to _run when the client is thread-safe
        return self._run(*args, **kwargs)
```

Key points:
- Class name must end with `Tool` to be auto‑discovered by our tooling.
- Use `args_schema` for inputs; always include `description` and validation.
- Validate env vars early and fail with actionable errors.
- Keep outputs deterministic and compact; favor `str` (possibly JSON‑encoded) or small dicts converted to strings.
- Avoid printing; return the final string.

### Error handling
- Wrap network and I/O with try/except and return a helpful message; a sketch follows this list. See `BraveSearchTool` and others for patterns.
- Validate required inputs and environment configuration with clear messages.
- Keep exceptions user‑friendly; do not leak stack traces.
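
A minimal sketch of this pattern, assuming a hypothetical `https://api.example.com/search` endpoint (the Weather example at the end of this guide shows the same idea in full):

```python
import requests
from crewai.tools import BaseTool


class ExampleSearchTool(BaseTool):
    name: str = "Example Search"
    description: str = "Illustrative only: shows the error-handling pattern."

    def _run(self, query: str) -> str:
        try:
            resp = requests.get(
                "https://api.example.com/search",  # hypothetical endpoint
                params={"q": query},
                timeout=10,
            )
            resp.raise_for_status()
            return resp.text
        except requests.Timeout:
            return "The search service timed out. Please try again later."
        except requests.HTTPError as exc:
            return f"Search service error: {exc.response.status_code}"
        except Exception as exc:
            # Keep messages short and user-friendly; do not leak stack traces.
            return f"Unexpected error while searching: {exc}"
```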

### Rate limiting and retries
- If the upstream API enforces request pacing, implement minimal rate limiting (see `BraveSearchTool`, and the sketch after this list).
- Consider idempotency and backoff for transient errors where appropriate.
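
A minimal sketch of request pacing, assuming a fixed minimum interval between calls; the helper name and interval are illustrative, and `BraveSearchTool` in this repository is the reference implementation:

```python
import time


class SimpleRateLimiter:
    """Illustrative helper: enforce a minimum delay between consecutive requests."""

    def __init__(self, min_interval_seconds: float = 1.0) -> None:
        self.min_interval_seconds = min_interval_seconds
        self._last_call: float = 0.0

    def wait(self) -> None:
        # Sleep just long enough to keep at least min_interval_seconds between calls.
        elapsed = time.monotonic() - self._last_call
        if elapsed < self.min_interval_seconds:
            time.sleep(self.min_interval_seconds - elapsed)
        self._last_call = time.monotonic()


# Inside a tool's _run, call limiter.wait() before each outbound request.
```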

### Async support
- Implement `_arun` only if your library has a true async client or your sync calls are thread‑safe.
- Otherwise, delegate `_arun` to `_run` as in multiple existing tools.

### Returning values
- Return a string (or JSON string) that’s ready to display in an agent transcript.
- If returning structured data, keep it small and human‑readable. Use stable keys and ordering.

---

## RAG tools and adapters

If your tool is a knowledge source, consider extending `RagTool` and/or creating an adapter.

- `RagTool` exposes `add(...)` and a `query(question: str) -> str` contract through an `Adapter`.
- See `crewai_tools/tools/rag/rag_tool.py` and adapters like `embedchain_adapter.py` and `lancedb_adapter.py`.

Minimal adapter example:

```python
from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.rag.rag_tool import Adapter, RagTool


class MemoryAdapter(Adapter):
    store: list[str] = []

    def add(self, text: str, **_: Any) -> None:
        self.store.append(text)

    def query(self, question: str) -> str:
        # naive demo: return all text containing any word from the question
        tokens = set(question.lower().split())
        hits = [t for t in self.store if tokens & set(t.lower().split())]
        return "\n".join(hits) if hits else "No relevant content found."


class MemoryRagTool(RagTool):
    name: str = "In‑memory RAG"
    description: str = "Toy RAG that stores text in memory and returns matches."
    adapter: Adapter = MemoryAdapter()
```

When using external vector DBs (MongoDB, Qdrant, Weaviate), study the existing tools to follow indexing, embedding, and query configuration patterns closely.

---

## Toolkits (multiple related tools)

Some integrations expose a toolkit (a group of tools) rather than a single class. See Bedrock `browser_toolkit.py` and `code_interpreter_toolkit.py`.

Guidelines:
- Provide small, focused `BaseTool` classes for each operation (e.g., `navigate`, `click`, `extract_text`).
- Offer a helper `create_<name>_toolkit(...) -> Tuple[ToolkitClass, List[BaseTool]]` to create tools and manage resources (see the sketch after this list).
- If you open external resources (browsers, interpreters), support cleanup methods and optionally context manager usage.
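
A minimal sketch of that helper shape, with hypothetical `Browser*` names and a hypothetical session object (the Bedrock toolkits in this repository are the reference implementations):

```python
from typing import List, Tuple

from crewai.tools import BaseTool


class BrowserSession:
    """Hypothetical shared resource owned by the toolkit."""

    def close(self) -> None:
        ...  # release the underlying browser here


class BrowserNavigateTool(BaseTool):
    name: str = "Browser Navigate"
    description: str = "Navigate the shared browser session to a URL."

    def _run(self, url: str) -> str:
        return f"Navigated to {url}"


def create_browser_toolkit() -> Tuple[BrowserSession, List[BaseTool]]:
    """Create the shared session plus one focused tool per operation."""
    session = BrowserSession()
    tools: List[BaseTool] = [BrowserNavigateTool()]
    return session, tools


# Callers own cleanup, e.g. call session.close() in a finally block.
```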

---

## Environment variables and dependencies

### env_vars
- Declare as `env_vars: List[EnvVar]` with `name`, `description`, `required`, and optional `default`.
- Validate presence in `__init__` or on first `_run` call.

### Dependencies
- List runtime packages in `package_dependencies` on the class.
- If they are genuinely optional, add an extra under `[project.optional-dependencies]` in `pyproject.toml` (e.g., `tavily-python`, `serpapi`, `scrapfly-sdk`).
- Use lazy imports to avoid hard deps for users who don’t need the tool.

---

## Testing

Place tests under `tests/tools/` and follow these rules:
- Do not hit real external services in CI. Use mocks, fakes, or recorded fixtures where allowed.
- Validate input validation, env var handling, error messages, and happy path output formatting.
- Keep tests fast and deterministic.

Example skeleton (`tests/tools/my_tool_test.py`):

```python
import os
import pytest
from crewai_tools.tools.my_tool.my_tool import MyTool


def test_requires_env_var(monkeypatch):
    monkeypatch.delenv("MY_API_KEY", raising=False)
    with pytest.raises(ValueError):
        MyTool()


def test_happy_path(monkeypatch):
    monkeypatch.setenv("MY_API_KEY", "test")
    tool = MyTool()
    result = tool.run(query="hello", limit=2)
    assert "hello" in result
```

Run locally:

```bash
uv run pytest
pre-commit run -a
```

---

## Documentation

Each tool must include a `README.md` in its folder with:
- What it does and when to use it
- Required env vars and optional extras (with install snippet)
- Minimal usage example

Update the root `README.md` only if the tool introduces a new category or notable capability.

---

## Discovery and specs

Our internal tooling discovers classes whose names end with `Tool`. Keep your class exported from the module path under `crewai_tools/tools/...` to be picked up by scripts like `generate_tool_specs.py`.

---

## Full example: “Weather Search Tool”

This example demonstrates: `args_schema`, `env_vars`, `package_dependencies`, lazy imports, validation, and robust error handling.

```python
# file: crewai_tools/tools/weather_tool/weather_tool.py
from typing import Any, List, Optional, Type
import os
import requests
from pydantic import BaseModel, Field
from crewai.tools import BaseTool, EnvVar


class WeatherToolInput(BaseModel):
    """Input schema for WeatherTool."""
    city: str = Field(..., description="City name, e.g., 'Berlin'")
    country: Optional[str] = Field(None, description="ISO country code, e.g., 'DE'")
    units: str = Field(
        default="metric",
        description="Units system: 'metric' or 'imperial'",
        pattern=r"^(metric|imperial)$",
    )


class WeatherTool(BaseTool):
    name: str = "Weather Search"
    description: str = (
        "Look up current weather for a city using a public weather API."
    )
    args_schema: Type[BaseModel] = WeatherToolInput

    env_vars: List[EnvVar] = [
        EnvVar(
            name="WEATHER_API_KEY",
            description="API key for the weather service",
            required=True,
        ),
    ]
    package_dependencies: List[str] = ["requests"]

    base_url: str = "https://api.openweathermap.org/data/2.5/weather"

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        if "WEATHER_API_KEY" not in os.environ:
            raise ValueError("WEATHER_API_KEY is required for WeatherTool")

    def _run(self, city: str, country: Optional[str] = None, units: str = "metric") -> str:
        try:
            q = f"{city},{country}" if country else city
            params = {
                "q": q,
                "units": units,
                "appid": os.environ["WEATHER_API_KEY"],
            }
            resp = requests.get(self.base_url, params=params, timeout=10)
            resp.raise_for_status()
            data = resp.json()

            main = data.get("weather", [{}])[0].get("main", "Unknown")
            desc = data.get("weather", [{}])[0].get("description", "")
            temp = data.get("main", {}).get("temp")
            feels = data.get("main", {}).get("feels_like")
            city_name = data.get("name", city)

            return (
                f"Weather in {city_name}: {main} ({desc}). "
                f"Temperature: {temp}°, feels like {feels}°."
            )
        except requests.Timeout:
            return "Weather service timed out. Please try again later."
        except requests.HTTPError as e:
            return f"Weather service error: {e.response.status_code} {e.response.text[:120]}"
        except Exception as e:
            return f"Unexpected error fetching weather: {e}"
```

Folder layout:

```
crewai_tools/tools/weather_tool/
├─ weather_tool.py
└─ README.md
```

And `README.md` should document env vars and usage.

---

## PR checklist
- [ ] Tool lives under `crewai_tools/tools/<name>/`
- [ ] Class ends with `Tool` and subclasses `BaseTool` (or `RagTool`)
- [ ] Precise `args_schema` with descriptions and validation
- [ ] `env_vars` declared (if any) and validated
- [ ] `package_dependencies` and optional extras added in `pyproject.toml` (if any)
- [ ] Clear error handling; no prints
- [ ] Unit tests added (`tests/tools/`), fast and deterministic
- [ ] Tool `README.md` with usage and env vars
- [ ] `pre-commit` and `pytest` pass locally

---

## Tips for great DX
- Keep responses short and useful—agents quote your tool output directly.
- Validate early; fail fast with actionable guidance.
- Prefer lazy imports; minimize default install surface.
- Mirror patterns from similar tools in this repo for a consistent developer experience.

Happy building!
229  lib/crewai-tools/README.md  (Normal file)

@@ -0,0 +1,229 @@
<div align="center">



<div align="left">

# CrewAI Tools

Empower your CrewAI agents with powerful, customizable tools to elevate their capabilities and tackle sophisticated, real-world tasks.

CrewAI Tools provide the essential functionality to extend your agents, helping you rapidly enhance your automations with reliable, ready-to-use tools or custom-built solutions tailored precisely to your needs.

---

## Quick Links

[Homepage](https://www.crewai.com/) | [Documentation](https://docs.crewai.com/) | [Examples](https://github.com/crewAIInc/crewAI-examples) | [Community](https://community.crewai.com/)

---

## Available Tools

CrewAI provides an extensive collection of powerful tools ready to enhance your agents:

- **File Management**: `FileReadTool`, `FileWriteTool`
- **Web Scraping**: `ScrapeWebsiteTool`, `SeleniumScrapingTool`
- **Database Integrations**: `PGSearchTool`, `MySQLSearchTool`
- **Vector Database Integrations**: `MongoDBVectorSearchTool`, `QdrantVectorSearchTool`, `WeaviateVectorSearchTool`
- **API Integrations**: `SerperApiTool`, `EXASearchTool`
- **AI-powered Tools**: `DallETool`, `VisionTool`, `StagehandTool`

And many more robust tools to simplify your agent integrations. A short usage sketch follows below.
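
As a brief sketch of how any of these tools attaches to an agent (the role, goal, and file name here are illustrative, and `FileReadTool` is used only as an example):

```python
from crewai import Agent, Crew, Task
from crewai_tools import FileReadTool

file_tool = FileReadTool()

researcher = Agent(
    role="Researcher",                       # illustrative role
    goal="Summarize the provided document",  # illustrative goal
    backstory="Reads local files and reports their contents.",
    tools=[file_tool],
)

task = Task(
    description="Read report.txt and summarize it in three bullet points.",
    expected_output="A three-bullet summary.",
    agent=researcher,
)

crew = Crew(agents=[researcher], tasks=[task])
result = crew.kickoff()
```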

---

## Creating Custom Tools

CrewAI offers two straightforward approaches to creating custom tools:

### Subclassing `BaseTool`

Define your tool by subclassing:

```python
from crewai.tools import BaseTool

class MyCustomTool(BaseTool):
    name: str = "Tool Name"
    description: str = "Detailed description here."

    def _run(self, *args, **kwargs):
        # Your tool logic here
        return "Tool output"
```

### Using the `tool` Decorator

Quickly create lightweight tools using decorators:

```python
from crewai import tool

@tool("Tool Name")
def my_custom_function(input):
    # Tool logic here
    output = f"Processed: {input}"
    return output
```

---

## CrewAI Tools and MCP

CrewAI Tools supports the Model Context Protocol (MCP). This gives you access to thousands of tools from the hundreds of MCP servers built by the community.

Before you start using MCP with CrewAI Tools, you need to install the `mcp` extra dependencies:

```bash
pip install crewai-tools[mcp]
# or
uv add crewai-tools --extra mcp
```

To quickly get started with MCP in CrewAI, you have two options:

### Option 1: Fully managed connection

In this scenario we use a context manager (`with` statement) to start and stop the connection with the MCP server.
This is done in the background, and you only interact with the CrewAI tools corresponding to the MCP server's tools.

For an STDIO based MCP server:

```python
from mcp import StdioServerParameters
from crewai_tools import MCPServerAdapter

serverparams = StdioServerParameters(
    command="uvx",
    args=["--quiet", "pubmedmcp@0.1.3"],
    env={"UV_PYTHON": "3.12", **os.environ},
)

with MCPServerAdapter(serverparams) as tools:
    # tools is now a list of CrewAI Tools matching 1:1 with the MCP server's tools
    agent = Agent(..., tools=tools)
    task = Task(...)
    crew = Crew(..., agents=[agent], tasks=[task])
    crew.kickoff(...)
```

For an SSE based MCP server:

```python
serverparams = {"url": "http://localhost:8000/sse"}
with MCPServerAdapter(serverparams) as tools:
    # tools is now a list of CrewAI Tools matching 1:1 with the MCP server's tools
    agent = Agent(..., tools=tools)
    task = Task(...)
    crew = Crew(..., agents=[agent], tasks=[task])
    crew.kickoff(...)
```

### Option 2: More control over the MCP connection

If you need more control over the MCP connection, you can instantiate the MCPServerAdapter into an `mcp_server_adapter` object, which can be used to manage the connection with the MCP server and access the available tools.

**Important**: in this case you need to call `mcp_server_adapter.stop()` to make sure the connection is correctly stopped. We recommend using a `try ... finally` block so that `.stop()` is called even in case of errors.

Here is the same example for an STDIO MCP server:

```python
from mcp import StdioServerParameters
from crewai_tools import MCPServerAdapter

serverparams = StdioServerParameters(
    command="uvx",
    args=["--quiet", "pubmedmcp@0.1.3"],
    env={"UV_PYTHON": "3.12", **os.environ},
)

try:
    mcp_server_adapter = MCPServerAdapter(serverparams)
    tools = mcp_server_adapter.tools
    # tools is now a list of CrewAI Tools matching 1:1 with the MCP server's tools
    agent = Agent(..., tools=tools)
    task = Task(...)
    crew = Crew(..., agents=[agent], tasks=[task])
    crew.kickoff(...)

# ** important ** don't forget to stop the connection
finally:
    mcp_server_adapter.stop()
```

And finally, the same thing but for an SSE MCP server:

```python
from mcp import StdioServerParameters
from crewai_tools import MCPServerAdapter

serverparams = {"url": "http://localhost:8000/sse"}

try:
    mcp_server_adapter = MCPServerAdapter(serverparams)
    tools = mcp_server_adapter.tools
    # tools is now a list of CrewAI Tools matching 1:1 with the MCP server's tools
    agent = Agent(..., tools=tools)
    task = Task(...)
    crew = Crew(..., agents=[agent], tasks=[task])
    crew.kickoff(...)

# ** important ** don't forget to stop the connection
finally:
    mcp_server_adapter.stop()
```

### Considerations & Limitations

#### Staying Safe with MCP

Always make sure that you trust the MCP server before using it. Using an STDIO server will execute code on your machine. Using SSE is not a silver bullet either: a malicious MCP server can still attempt many kinds of injection into your application.

#### Limitations

* At this time we only support tools from MCP servers, not other types of primitives such as prompts or resources.
* We only return the first text output returned by the MCP server tool, using `.content[0].text`.

---

## Why Use CrewAI Tools?

- **Simplicity & Flexibility**: Easy-to-use yet powerful enough for complex workflows.
- **Rapid Integration**: Seamlessly incorporate external services, APIs, and databases.
- **Enterprise Ready**: Built for stability, performance, and consistent results.

---

## Contribution Guidelines

We welcome contributions from the community!

1. Fork and clone the repository.
2. Create a new branch (`git checkout -b feature/my-feature`).
3. Commit your changes (`git commit -m 'Add my feature'`).
4. Push your branch (`git push origin feature/my-feature`).
5. Open a pull request.

---

## Developer Quickstart

```shell
pip install crewai[tools]
```

### Development Setup

- Install dependencies: `uv sync`
- Run tests: `uv run pytest`
- Run static type checking: `uv run pyright`
- Set up pre-commit hooks: `pre-commit install`

---

## Support and Community

Join our rapidly growing community and receive real-time support:

- [Discourse](https://community.crewai.com/)
- [Open an Issue](https://github.com/crewAIInc/crewAI/issues)

Build smarter, faster, and more powerful AI solutions—powered by CrewAI Tools.
155  lib/crewai-tools/generate_tool_specs.py  (Normal file)

@@ -0,0 +1,155 @@
#!/usr/bin/env python3

import inspect
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

from crewai.tools.base_tool import BaseTool, EnvVar
from crewai_tools import tools
from pydantic.json_schema import GenerateJsonSchema
from pydantic_core import PydanticOmit


class SchemaGenerator(GenerateJsonSchema):
    def handle_invalid_for_json_schema(self, schema, error_info):
        raise PydanticOmit


class ToolSpecExtractor:
    def __init__(self) -> None:
        self.tools_spec: List[Dict[str, Any]] = []
        self.processed_tools: set[str] = set()

    def extract_all_tools(self) -> List[Dict[str, Any]]:
        for name in dir(tools):
            if name.endswith("Tool") and name not in self.processed_tools:
                obj = getattr(tools, name, None)
                if inspect.isclass(obj):
                    self.extract_tool_info(obj)
                    self.processed_tools.add(name)
        return self.tools_spec

    def extract_tool_info(self, tool_class: BaseTool) -> None:
        try:
            core_schema = tool_class.__pydantic_core_schema__
            if not core_schema:
                return

            schema = self._unwrap_schema(core_schema)
            fields = schema.get("schema", {}).get("fields", {})

            tool_info = {
                "name": tool_class.__name__,
                "humanized_name": self._extract_field_default(
                    fields.get("name"), fallback=tool_class.__name__
                ),
                "description": self._extract_field_default(
                    fields.get("description")
                ).strip(),
                "run_params_schema": self._extract_params(fields.get("args_schema")),
                "init_params_schema": self._extract_init_params(tool_class),
                "env_vars": self._extract_env_vars(fields.get("env_vars")),
                "package_dependencies": self._extract_field_default(
                    fields.get("package_dependencies"), fallback=[]
                ),
            }

            self.tools_spec.append(tool_info)

        except Exception as e:
            print(f"Error extracting {tool_class.__name__}: {e}")

    def _unwrap_schema(self, schema: Dict) -> Dict:
        while (
            schema.get("type") in {"function-after", "default"} and "schema" in schema
        ):
            schema = schema["schema"]
        return schema

    def _extract_field_default(self, field: Optional[Dict], fallback: str = "") -> str:
        if not field:
            return fallback

        schema = field.get("schema", {})
        default = schema.get("default")
        return default if isinstance(default, (list, str, int)) else fallback

    def _extract_params(
        self, args_schema_field: Optional[Dict]
    ) -> List[Dict[str, str]]:
        if not args_schema_field:
            return {}

        args_schema_class = args_schema_field.get("schema", {}).get("default")
        if not (
            inspect.isclass(args_schema_class)
            and hasattr(args_schema_class, "__pydantic_core_schema__")
        ):
            return {}

        try:
            return args_schema_class.model_json_schema(
                schema_generator=SchemaGenerator, mode="validation"
            )
        except Exception as e:
            print(f"Error extracting params from {args_schema_class}: {e}")
            return {}

    def _extract_env_vars(self, env_vars_field: Optional[Dict]) -> List[Dict[str, str]]:
        if not env_vars_field:
            return []

        env_vars = []
        for env_var in env_vars_field.get("schema", {}).get("default", []):
            if isinstance(env_var, EnvVar):
                env_vars.append(
                    {
                        "name": env_var.name,
                        "description": env_var.description,
                        "required": env_var.required,
                        "default": env_var.default,
                    }
                )
        return env_vars

    def _extract_init_params(self, tool_class: BaseTool) -> dict:
        ignored_init_params = [
            "name",
            "description",
            "env_vars",
            "args_schema",
            "description_updated",
            "cache_function",
            "result_as_answer",
            "max_usage_count",
            "current_usage_count",
            "package_dependencies",
        ]

        json_schema = tool_class.model_json_schema(
            schema_generator=SchemaGenerator, mode="serialization"
        )

        properties = {}
        for key, value in json_schema["properties"].items():
            if key not in ignored_init_params:
                properties[key] = value

        json_schema["properties"] = properties
        return json_schema

    def save_to_json(self, output_path: str) -> None:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump({"tools": self.tools_spec}, f, indent=2, sort_keys=True)
        print(f"Saved tool specs to {output_path}")


if __name__ == "__main__":
    output_file = Path(__file__).parent / "tool.specs.json"
    extractor = ToolSpecExtractor()

    specs = extractor.extract_all_tools()
    extractor.save_to_json(str(output_file))

    print(f"Extracted {len(specs)} tool classes.")
153  lib/crewai-tools/pyproject.toml  (Normal file)

@@ -0,0 +1,153 @@
[project]
name = "crewai-tools"
dynamic = ["version"]
description = "Set of tools for the crewAI framework"
readme = "README.md"
authors = [
    { name = "João Moura", email = "joaomdmoura@gmail.com" },
]
requires-python = ">=3.10, <3.14"
dependencies = [
    "lancedb>=0.5.4",
    "pytube>=15.0.0",
    "requests>=2.32.0",
    "docker>=7.1.0",
    "crewai==1.0.0a1",
    "lancedb>=0.5.4",
    "tiktoken>=0.8.0",
    "stagehand>=0.4.1",
    "beautifulsoup4>=4.13.4",
    "pypdf>=5.9.0",
    "python-docx>=1.2.0",
    "youtube-transcript-api>=1.2.2",
]


[project.urls]
Homepage = "https://crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"
Documentation = "https://docs.crewai.com"


[project.optional-dependencies]
scrapfly-sdk = [
    "scrapfly-sdk>=0.8.19",
]
sqlalchemy = [
    "sqlalchemy>=2.0.35",
]
multion = [
    "multion>=1.1.0",
]
firecrawl-py = [
    "firecrawl-py>=1.8.0",
]
composio-core = [
    "composio-core>=0.6.11.post1",
]
browserbase = [
    "browserbase>=1.0.5",
]
weaviate-client = [
    "weaviate-client>=4.10.2",
]
patronus = [
    "patronus>=0.0.16",
]
serpapi = [
    "serpapi>=0.1.5",
]
beautifulsoup4 = [
    "beautifulsoup4>=4.12.3",
]
selenium = [
    "selenium>=4.27.1",
]
spider-client = [
    "spider-client>=0.1.25",
]
scrapegraph-py = [
    "scrapegraph-py>=1.9.0",
]
linkup-sdk = [
    "linkup-sdk>=0.2.2",
]
tavily-python = [
    "tavily-python>=0.5.4",
]
hyperbrowser = [
    "hyperbrowser>=0.18.0",
]
snowflake = [
    "cryptography>=43.0.3",
    "snowflake-connector-python>=3.12.4",
    "snowflake-sqlalchemy>=1.7.3",
]
singlestore = [
    "singlestoredb>=1.12.4",
    "SQLAlchemy>=2.0.40",
]
exa-py = [
    "exa-py>=1.8.7",
]
qdrant-client = [
    "qdrant-client>=1.12.1",
]
apify = [
    "langchain-apify>=0.1.2,<1.0.0",
]

databricks-sdk = [
    "databricks-sdk>=0.46.0",
]
couchbase = [
    "couchbase>=4.3.5",
]
mcp = [
    "mcp>=1.6.0",
    "mcpadapt>=0.1.9",
]
stagehand = [
    "stagehand>=0.4.1",
]
github = [
    "gitpython==3.1.38",
    "PyGithub==1.59.1",
]
rag = [
    "python-docx>=1.1.0",
    "lxml>=5.3.0,<5.4.0",  # Pin to avoid etree import issues in 5.4.0
]
xml = [
    "unstructured[local-inference, all-docs]>=0.17.2"
]
oxylabs = [
    "oxylabs==2.0.0"
]
mongodb = [
    "pymongo>=4.13"
]
mysql = [
    "pymysql>=1.1.1"
]
postgresql = [
    "psycopg2-binary>=2.9.10"
]
bedrock = [
    "beautifulsoup4>=4.13.4",
    "bedrock-agentcore>=0.1.0",
    "playwright>=1.52.0",
    "nest-asyncio>=1.6.0",
]
contextual = [
    "contextual-client>=0.1.0",
    "nest-asyncio>=1.6.0",
]


[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.version]
path = "src/crewai_tools/__init__.py"
102  lib/crewai-tools/src/crewai_tools/__init__.py  (Normal file)

@@ -0,0 +1,102 @@
# ruff: noqa: F401
from .adapters.enterprise_adapter import EnterpriseActionTool
from .adapters.mcp_adapter import MCPServerAdapter
from .adapters.zapier_adapter import ZapierActionTool
from .aws import (
    BedrockInvokeAgentTool,
    BedrockKBRetrieverTool,
    S3ReaderTool,
    S3WriterTool,
)
from .tools import (
    AIMindTool,
    ApifyActorsTool,
    ArxivPaperTool,
    BraveSearchTool,
    BrightDataDatasetTool,
    BrightDataSearchTool,
    BrightDataWebUnlockerTool,
    BrowserbaseLoadTool,
    CSVSearchTool,
    CodeDocsSearchTool,
    CodeInterpreterTool,
    ComposioTool,
    ContextualAICreateAgentTool,
    ContextualAIParseTool,
    ContextualAIQueryTool,
    ContextualAIRerankTool,
    CouchbaseFTSVectorSearchTool,
    CrewaiEnterpriseTools,
    CrewaiPlatformTools,
    DOCXSearchTool,
    DallETool,
    DatabricksQueryTool,
    DirectoryReadTool,
    DirectorySearchTool,
    EXASearchTool,
    FileCompressorTool,
    FileReadTool,
    FileWriterTool,
    FirecrawlCrawlWebsiteTool,
    FirecrawlScrapeWebsiteTool,
    FirecrawlSearchTool,
    GenerateCrewaiAutomationTool,
    GithubSearchTool,
    HyperbrowserLoadTool,
    InvokeCrewAIAutomationTool,
    JSONSearchTool,
    LinkupSearchTool,
    LlamaIndexTool,
    MDXSearchTool,
    MongoDBVectorSearchConfig,
    MongoDBVectorSearchTool,
    MultiOnTool,
    MySQLSearchTool,
    NL2SQLTool,
    OCRTool,
    OxylabsAmazonProductScraperTool,
    OxylabsAmazonSearchScraperTool,
    OxylabsGoogleSearchScraperTool,
    OxylabsUniversalScraperTool,
    PDFSearchTool,
    PGSearchTool,
    ParallelSearchTool,
    PatronusEvalTool,
    PatronusLocalEvaluatorTool,
    PatronusPredefinedCriteriaEvalTool,
    QdrantVectorSearchTool,
    RagTool,
    ScrapeElementFromWebsiteTool,
    ScrapeWebsiteTool,
    ScrapegraphScrapeTool,
    ScrapegraphScrapeToolSchema,
    ScrapflyScrapeWebsiteTool,
    SeleniumScrapingTool,
    SerpApiGoogleSearchTool,
    SerpApiGoogleShoppingTool,
    SerperDevTool,
    SerperScrapeWebsiteTool,
    SerplyJobSearchTool,
    SerplyNewsSearchTool,
    SerplyScholarSearchTool,
    SerplyWebSearchTool,
    SerplyWebpageToMarkdownTool,
    SingleStoreSearchTool,
    SnowflakeConfig,
    SnowflakeSearchTool,
    SpiderTool,
    StagehandTool,
    TXTSearchTool,
    TavilyExtractorTool,
    TavilySearchTool,
    VisionTool,
    WeaviateVectorSearchTool,
    WebsiteSearchTool,
    XMLSearchTool,
    YoutubeChannelSearchTool,
    YoutubeVideoSearchTool,
    ZapierActionTools,
)


__version__ = "1.0.0a1"
268  lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py  (Normal file)

@@ -0,0 +1,268 @@
"""Adapter for CrewAI's native RAG system."""

import hashlib
from pathlib import Path
from typing import Any, TypeAlias, TypedDict

from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
from crewai_tools.tools.rag.rag_tool import Adapter
from pydantic import PrivateAttr
from typing_extensions import Unpack


ContentItem: TypeAlias = str | Path | dict[str, Any]


class AddDocumentParams(TypedDict, total=False):
    """Parameters for adding documents to the RAG system."""

    data_type: DataType
    metadata: dict[str, Any]
    website: str
    url: str
    file_path: str | Path
    github_url: str
    youtube_url: str
    directory_path: str | Path


class CrewAIRagAdapter(Adapter):
    """Adapter that uses CrewAI's native RAG system.

    Supports custom vector database configuration through the config parameter.
    """

    collection_name: str = "default"
    summarize: bool = False
    similarity_threshold: float = 0.6
    limit: int = 5
    config: RagConfigType | None = None
    _client: BaseClient | None = PrivateAttr(default=None)

    def model_post_init(self, __context: Any) -> None:
        """Initialize the CrewAI RAG client after model initialization."""
        if self.config is not None:
            self._client = create_client(self.config)
        else:
            self._client = get_rag_client()
        self._client.get_or_create_collection(collection_name=self.collection_name)

    def query(
        self,
        question: str,
        similarity_threshold: float | None = None,
        limit: int | None = None,
    ) -> str:
        """Query the knowledge base with a question.

        Args:
            question: The question to ask
            similarity_threshold: Minimum similarity score for results (default: 0.6)
            limit: Maximum number of results to return (default: 5)

        Returns:
            Relevant content from the knowledge base
        """
        search_limit = limit if limit is not None else self.limit
        search_threshold = (
            similarity_threshold
            if similarity_threshold is not None
            else self.similarity_threshold
        )

        results: list[SearchResult] = self._client.search(
            collection_name=self.collection_name,
            query=question,
            limit=search_limit,
            score_threshold=search_threshold,
        )

        if not results:
            return "No relevant content found."

        contents: list[str] = []
        for result in results:
            content: str = result.get("content", "")
            if content:
                contents.append(content)

        return "\n\n".join(contents)

    def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
        """Add content to the knowledge base.

        This method handles various input types and converts them to documents
        for the vector database. It supports the data_type parameter for
        compatibility with existing tools.

        Args:
            *args: Content items to add (strings, paths, or document dicts)
            **kwargs: Additional parameters including data_type, metadata, etc.
        """
        import os

        from crewai_tools.rag.base_loader import LoaderResult
        from crewai_tools.rag.data_types import DataType, DataTypes
        from crewai_tools.rag.source_content import SourceContent

        documents: list[BaseRecord] = []
        data_type: DataType | None = kwargs.get("data_type")
        base_metadata: dict[str, Any] = kwargs.get("metadata", {})

        for arg in args:
            source_ref: str
            if isinstance(arg, dict):
                source_ref = str(arg.get("source", arg.get("content", "")))
            else:
                source_ref = str(arg)

            if not data_type:
                data_type = DataTypes.from_content(source_ref)

            if data_type == DataType.DIRECTORY:
                if not os.path.isdir(source_ref):
                    raise ValueError(f"Directory does not exist: {source_ref}")

                # Define binary and non-text file extensions to skip
                binary_extensions = {
                    ".pyc",
                    ".pyo",
                    ".png",
                    ".jpg",
                    ".jpeg",
                    ".gif",
                    ".bmp",
                    ".ico",
                    ".svg",
                    ".webp",
                    ".pdf",
                    ".zip",
                    ".tar",
                    ".gz",
                    ".bz2",
                    ".7z",
                    ".rar",
                    ".exe",
                    ".dll",
                    ".so",
                    ".dylib",
                    ".bin",
                    ".dat",
                    ".db",
                    ".sqlite",
                    ".class",
                    ".jar",
                    ".war",
                    ".ear",
                }

                for root, dirs, files in os.walk(source_ref):
                    dirs[:] = [d for d in dirs if not d.startswith(".")]

                    for filename in files:
                        if filename.startswith("."):
                            continue

                        # Skip binary files based on extension
                        file_ext = os.path.splitext(filename)[1].lower()
                        if file_ext in binary_extensions:
                            continue

                        # Skip __pycache__ directories
                        if "__pycache__" in root:
                            continue

                        file_path: str = os.path.join(root, filename)
                        try:
                            file_data_type: DataType = DataTypes.from_content(file_path)
                            file_loader = file_data_type.get_loader()
                            file_chunker = file_data_type.get_chunker()

                            file_source = SourceContent(file_path)
                            file_result: LoaderResult = file_loader.load(file_source)

                            file_chunks = file_chunker.chunk(file_result.content)

                            for chunk_idx, file_chunk in enumerate(file_chunks):
                                file_metadata: dict[str, Any] = base_metadata.copy()
                                file_metadata.update(file_result.metadata)
                                file_metadata["data_type"] = str(file_data_type)
                                file_metadata["file_path"] = file_path
                                file_metadata["chunk_index"] = chunk_idx
                                file_metadata["total_chunks"] = len(file_chunks)

                                if isinstance(arg, dict):
                                    file_metadata.update(arg.get("metadata", {}))

                                chunk_id = hashlib.sha256(
                                    f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
                                ).hexdigest()

                                documents.append(
                                    {
                                        "doc_id": chunk_id,
                                        "content": file_chunk,
                                        "metadata": sanitize_metadata_for_chromadb(
                                            file_metadata
                                        ),
                                    }
                                )
                        except Exception:
                            # Silently skip files that can't be processed
                            continue
            else:
                metadata: dict[str, Any] = base_metadata.copy()

                if data_type in [
                    DataType.PDF_FILE,
                    DataType.TEXT_FILE,
                    DataType.DOCX,
                    DataType.CSV,
                    DataType.JSON,
                    DataType.XML,
                    DataType.MDX,
                ]:
                    if not os.path.isfile(source_ref):
                        raise FileNotFoundError(f"File does not exist: {source_ref}")

                loader = data_type.get_loader()
                chunker = data_type.get_chunker()

                source_content = SourceContent(source_ref)
                loader_result: LoaderResult = loader.load(source_content)

                chunks = chunker.chunk(loader_result.content)

                for i, chunk in enumerate(chunks):
                    chunk_metadata: dict[str, Any] = metadata.copy()
                    chunk_metadata.update(loader_result.metadata)
                    chunk_metadata["data_type"] = str(data_type)
                    chunk_metadata["chunk_index"] = i
                    chunk_metadata["total_chunks"] = len(chunks)
                    chunk_metadata["source"] = source_ref

                    if isinstance(arg, dict):
                        chunk_metadata.update(arg.get("metadata", {}))

                    chunk_id = hashlib.sha256(
                        f"{loader_result.doc_id}_{i}_{chunk}".encode()
                    ).hexdigest()

                    documents.append(
                        {
                            "doc_id": chunk_id,
                            "content": chunk,
                            "metadata": sanitize_metadata_for_chromadb(chunk_metadata),
                        }
                    )

        if documents:
            self._client.add_documents(
                collection_name=self.collection_name, documents=documents
            )
434  lib/crewai-tools/src/crewai_tools/adapters/enterprise_adapter.py  Normal file
@@ -0,0 +1,434 @@
import json
import os
import re
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast, get_origin
import warnings

from crewai.tools import BaseTool
from pydantic import Field, create_model
import requests


def get_enterprise_api_base_url() -> str:
    """Get the enterprise API base URL from environment or use default."""
    base_url = os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com")
    return f"{base_url}/crewai_plus/api/v1/integrations"


ENTERPRISE_API_BASE_URL = get_enterprise_api_base_url()


class EnterpriseActionTool(BaseTool):
    """A tool that executes a specific enterprise action."""

    enterprise_action_token: str = Field(
        default="", description="The enterprise action token"
    )
    action_name: str = Field(default="", description="The name of the action")
    action_schema: Dict[str, Any] = Field(
        default={}, description="The schema of the action"
    )
    enterprise_api_base_url: str = Field(
        default=ENTERPRISE_API_BASE_URL, description="The base API URL"
    )

    def __init__(
        self,
        name: str,
        description: str,
        enterprise_action_token: str,
        action_name: str,
        action_schema: Dict[str, Any],
        enterprise_api_base_url: Optional[str] = None,
    ):
        self._model_registry = {}
        self._base_name = self._sanitize_name(name)

        schema_props, required = self._extract_schema_info(action_schema)

        # Define field definitions for the model
        field_definitions = {}
        for param_name, param_details in schema_props.items():
            param_desc = param_details.get("description", "")
            is_required = param_name in required

            try:
                field_type = self._process_schema_type(
                    param_details, self._sanitize_name(param_name).title()
                )
            except Exception as e:
                print(f"Warning: Could not process schema for {param_name}: {e}")
                field_type = str

            # Create field definition based on requirement
            field_definitions[param_name] = self._create_field_definition(
                field_type, is_required, param_desc
            )

        # Create the model
        if field_definitions:
            try:
                args_schema = create_model(
                    f"{self._base_name}Schema", **field_definitions
                )
            except Exception as e:
                print(f"Warning: Could not create main schema model: {e}")
                args_schema = create_model(
                    f"{self._base_name}Schema",
                    input_text=(str, Field(description="Input for the action")),
                )
        else:
            # Fallback for empty schema
            args_schema = create_model(
                f"{self._base_name}Schema",
                input_text=(str, Field(description="Input for the action")),
            )

        super().__init__(name=name, description=description, args_schema=args_schema)
        self.enterprise_action_token = enterprise_action_token
        self.action_name = action_name
        self.action_schema = action_schema
        self.enterprise_api_base_url = (
            enterprise_api_base_url or get_enterprise_api_base_url()
        )

    def _sanitize_name(self, name: str) -> str:
        """Sanitize names to create proper Python class names."""
        sanitized = re.sub(r"[^a-zA-Z0-9_]", "", name)
        parts = sanitized.split("_")
        return "".join(word.capitalize() for word in parts if word)

    def _extract_schema_info(
        self, action_schema: Dict[str, Any]
    ) -> tuple[Dict[str, Any], List[str]]:
        """Extract schema properties and required fields from action schema."""
        schema_props = (
            action_schema.get("function", {})
            .get("parameters", {})
            .get("properties", {})
        )
        required = (
            action_schema.get("function", {}).get("parameters", {}).get("required", [])
        )
        return schema_props, required

    def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> Type[Any]:
        """Process a JSON schema and return appropriate Python type."""
        if "anyOf" in schema:
            any_of_types = schema["anyOf"]
            is_nullable = any(t.get("type") == "null" for t in any_of_types)
            non_null_types = [t for t in any_of_types if t.get("type") != "null"]

            if non_null_types:
                base_type = self._process_schema_type(non_null_types[0], type_name)
                return Optional[base_type] if is_nullable else base_type
            return cast(Type[Any], Optional[str])

        if "oneOf" in schema:
            return self._process_schema_type(schema["oneOf"][0], type_name)

        if "allOf" in schema:
            return self._process_schema_type(schema["allOf"][0], type_name)

        json_type = schema.get("type", "string")

        if "enum" in schema:
            enum_values = schema["enum"]
            if not enum_values:
                return self._map_json_type_to_python(json_type)
            return Literal[tuple(enum_values)]  # type: ignore[return-value]

        if json_type == "array":
            items_schema = schema.get("items", {"type": "string"})
            item_type = self._process_schema_type(items_schema, f"{type_name}Item")
            return List[item_type]

        if json_type == "object":
            return self._create_nested_model(schema, type_name)

        return self._map_json_type_to_python(json_type)

    def _create_nested_model(
        self, schema: Dict[str, Any], model_name: str
    ) -> Type[Any]:
        """Create a nested Pydantic model for complex objects."""
        full_model_name = f"{self._base_name}{model_name}"

        if full_model_name in self._model_registry:
            return self._model_registry[full_model_name]

        properties = schema.get("properties", {})
        required_fields = schema.get("required", [])

        if not properties:
            return dict

        field_definitions = {}
        for prop_name, prop_schema in properties.items():
            prop_desc = prop_schema.get("description", "")
            is_required = prop_name in required_fields

            try:
                prop_type = self._process_schema_type(
                    prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}"
                )
            except Exception as e:
                print(f"Warning: Could not process schema for {prop_name}: {e}")
                prop_type = str

            field_definitions[prop_name] = self._create_field_definition(
                prop_type, is_required, prop_desc
            )

        try:
            nested_model = create_model(full_model_name, **field_definitions)
            self._model_registry[full_model_name] = nested_model
            return nested_model
        except Exception as e:
            print(f"Warning: Could not create nested model {full_model_name}: {e}")
            return dict

    def _create_field_definition(
        self, field_type: Type[Any], is_required: bool, description: str
    ) -> tuple:
        """Create Pydantic field definition based on type and requirement."""
        if is_required:
            return (field_type, Field(description=description))
        if get_origin(field_type) is Union:
            return (field_type, Field(default=None, description=description))
        return (
            Optional[field_type],
            Field(default=None, description=description),
        )

    def _map_json_type_to_python(self, json_type: str) -> Type[Any]:
        """Map basic JSON schema types to Python types."""
        type_mapping = {
            "string": str,
            "integer": int,
            "number": float,
            "boolean": bool,
            "array": list,
            "object": dict,
            "null": type(None),
        }
        return type_mapping.get(json_type, str)

    def _get_required_nullable_fields(self) -> List[str]:
        """Get a list of required nullable fields from the action schema."""
        schema_props, required = self._extract_schema_info(self.action_schema)

        required_nullable_fields = []
        for param_name in required:
            param_details = schema_props.get(param_name, {})
            if self._is_nullable_type(param_details):
                required_nullable_fields.append(param_name)

        return required_nullable_fields

    def _is_nullable_type(self, schema: Dict[str, Any]) -> bool:
        """Check if a schema represents a nullable type."""
        if "anyOf" in schema:
            return any(t.get("type") == "null" for t in schema["anyOf"])
        return schema.get("type") == "null"

    def _run(self, **kwargs) -> str:
        """Execute the specific enterprise action with validated parameters."""
        try:
            cleaned_kwargs = {}
            for key, value in kwargs.items():
                if value is not None:
                    cleaned_kwargs[key] = value

            required_nullable_fields = self._get_required_nullable_fields()

            for field_name in required_nullable_fields:
                if field_name not in cleaned_kwargs:
                    cleaned_kwargs[field_name] = None

            api_url = (
                f"{self.enterprise_api_base_url}/actions/{self.action_name}/execute"
            )
            headers = {
                "Authorization": f"Bearer {self.enterprise_action_token}",
                "Content-Type": "application/json",
            }
            payload = cleaned_kwargs

            response = requests.post(
                url=api_url, headers=headers, json=payload, timeout=60
            )

            data = response.json()
            if not response.ok:
                error_message = data.get("error", {}).get("message", json.dumps(data))
                return f"API request failed: {error_message}"

            return json.dumps(data, indent=2)

        except Exception as e:
            return f"Error executing action {self.action_name}: {e!s}"


class EnterpriseActionKitToolAdapter:
    """Adapter that creates BaseTool instances for enterprise actions."""

    def __init__(
        self,
        enterprise_action_token: str,
        enterprise_api_base_url: Optional[str] = None,
    ):
        """Initialize the adapter with an enterprise action token."""
        self._set_enterprise_action_token(enterprise_action_token)
        self._actions_schema = {}
        self._tools = None
        self.enterprise_api_base_url = (
            enterprise_api_base_url or get_enterprise_api_base_url()
        )

    def tools(self) -> List[BaseTool]:
        """Get the list of tools created from enterprise actions."""
        if self._tools is None:
            self._fetch_actions()
            self._create_tools()
        return self._tools or []

    def _fetch_actions(self):
        """Fetch available actions from the API."""
        try:
            actions_url = f"{self.enterprise_api_base_url}/actions"
            headers = {"Authorization": f"Bearer {self.enterprise_action_token}"}

            response = requests.get(actions_url, headers=headers, timeout=30)
            response.raise_for_status()

            raw_data = response.json()
            if "actions" not in raw_data:
                print(f"Unexpected API response structure: {raw_data}")
                return

            parsed_schema = {}
            action_categories = raw_data["actions"]

            for integration_type, action_list in action_categories.items():
                if isinstance(action_list, list):
                    for action in action_list:
                        action_name = action.get("name")
                        if action_name:
                            action_schema = {
                                "function": {
                                    "name": action_name,
                                    "description": action.get(
                                        "description", f"Execute {action_name}"
                                    ),
                                    "parameters": action.get("parameters", {}),
                                }
                            }
                            parsed_schema[action_name] = action_schema

            self._actions_schema = parsed_schema

        except Exception as e:
            print(f"Error fetching actions: {e}")
            import traceback

            traceback.print_exc()

    def _generate_detailed_description(
        self, schema: Dict[str, Any], indent: int = 0
    ) -> List[str]:
        """Generate detailed description for nested schema structures."""
        descriptions = []
        indent_str = " " * indent

        schema_type = schema.get("type", "string")

        if schema_type == "object":
            properties = schema.get("properties", {})
            required_fields = schema.get("required", [])

            if properties:
                descriptions.append(f"{indent_str}Object with properties:")
                for prop_name, prop_schema in properties.items():
                    prop_desc = prop_schema.get("description", "")
                    is_required = prop_name in required_fields
                    req_str = " (required)" if is_required else " (optional)"
                    descriptions.append(
                        f"{indent_str} - {prop_name}: {prop_desc}{req_str}"
                    )

                    if prop_schema.get("type") == "object":
                        descriptions.extend(
                            self._generate_detailed_description(prop_schema, indent + 2)
                        )
                    elif prop_schema.get("type") == "array":
                        items_schema = prop_schema.get("items", {})
                        if items_schema.get("type") == "object":
                            descriptions.append(f"{indent_str} Array of objects:")
                            descriptions.extend(
                                self._generate_detailed_description(
                                    items_schema, indent + 3
                                )
                            )
                        elif "enum" in items_schema:
                            descriptions.append(
                                f"{indent_str} Array of enum values: {items_schema['enum']}"
                            )
                    elif "enum" in prop_schema:
                        descriptions.append(
                            f"{indent_str} Enum values: {prop_schema['enum']}"
                        )

        return descriptions

    def _create_tools(self):
        """Create BaseTool instances for each action."""
        tools = []

        for action_name, action_schema in self._actions_schema.items():
            function_details = action_schema.get("function", {})
            description = function_details.get("description", f"Execute {action_name}")

            parameters = function_details.get("parameters", {})
            param_descriptions = []

            if parameters.get("properties"):
                param_descriptions.append("\nDetailed Parameter Structure:")
                param_descriptions.extend(
                    self._generate_detailed_description(parameters)
                )

            full_description = description + "\n".join(param_descriptions)

            tool = EnterpriseActionTool(
                name=action_name.lower().replace(" ", "_"),
                description=full_description,
                action_name=action_name,
                action_schema=action_schema,
                enterprise_action_token=self.enterprise_action_token,
                enterprise_api_base_url=self.enterprise_api_base_url,
            )

            tools.append(tool)

        self._tools = tools

    def _set_enterprise_action_token(self, enterprise_action_token: Optional[str]):
        if enterprise_action_token and not enterprise_action_token.startswith("PK_"):
            warnings.warn(
                "Legacy token detected, please consider using the new Enterprise Action Auth token. Check out our docs for more information https://docs.crewai.com/en/enterprise/features/integrations.",
                DeprecationWarning,
                stacklevel=2,
            )

        token = enterprise_action_token or os.environ.get(
            "CREWAI_ENTERPRISE_TOOLS_TOKEN"
        )

        self.enterprise_action_token = token

    def __enter__(self):
        return self.tools()

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass
@@ -0,0 +1,54 @@
from pathlib import Path
from typing import Any, Callable

from crewai_tools.tools.rag.rag_tool import Adapter
from lancedb import DBConnection as LanceDBConnection, connect as lancedb_connect
from lancedb.table import Table as LanceDBTable
from openai import Client as OpenAIClient
from pydantic import Field, PrivateAttr


def _default_embedding_function():
    client = OpenAIClient()

    def _embedding_function(input):
        rs = client.embeddings.create(input=input, model="text-embedding-ada-002")
        return [record.embedding for record in rs.data]

    return _embedding_function


class LanceDBAdapter(Adapter):
    uri: str | Path
    table_name: str
    embedding_function: Callable = Field(default_factory=_default_embedding_function)
    top_k: int = 3
    vector_column_name: str = "vector"
    text_column_name: str = "text"

    _db: LanceDBConnection = PrivateAttr()
    _table: LanceDBTable = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        self._db = lancedb_connect(self.uri)
        self._table = self._db.open_table(self.table_name)

        super().model_post_init(__context)

    def query(self, question: str) -> str:
        query = self.embedding_function([question])[0]
        results = (
            self._table.search(query, vector_column_name=self.vector_column_name)
            .limit(self.top_k)
            .select([self.text_column_name])
            .to_list()
        )
        values = [result[self.text_column_name] for result in results]
        return "\n".join(values)

    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self._table.add(*args, **kwargs)
167  lib/crewai-tools/src/crewai_tools/adapters/mcp_adapter.py  Normal file
@@ -0,0 +1,167 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from crewai.tools import BaseTool
from crewai_tools.adapters.tool_collection import ToolCollection


"""
MCPServer for CrewAI.
"""
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from mcp import StdioServerParameters
    from mcpadapt.core import MCPAdapt
    from mcpadapt.crewai_adapter import CrewAIAdapter


try:
    from mcp import StdioServerParameters
    from mcpadapt.core import MCPAdapt
    from mcpadapt.crewai_adapter import CrewAIAdapter

    MCP_AVAILABLE = True
except ImportError:
    MCP_AVAILABLE = False


class MCPServerAdapter:
    """Manages the lifecycle of an MCP server and makes its tools available to CrewAI.

    Note: tools can only be accessed after the server has been started with the
    `start()` method.

    Attributes:
        tools: The CrewAI tools available from the MCP server.

    Usage:
        # context manager + stdio
        with MCPServerAdapter(...) as tools:
            # tools is now available

        # context manager + sse
        with MCPServerAdapter({"url": "http://localhost:8000/sse"}) as tools:
            # tools is now available

        # context manager with filtered tools
        with MCPServerAdapter(..., "tool1", "tool2") as filtered_tools:
            # only tool1 and tool2 are available

        # context manager with custom connect timeout (60 seconds)
        with MCPServerAdapter(..., connect_timeout=60) as tools:
            # tools is now available with longer timeout

        # manually stop mcp server
        try:
            mcp_server = MCPServerAdapter(...)
            tools = mcp_server.tools  # all tools

            # or with filtered tools and custom timeout
            mcp_server = MCPServerAdapter(..., "tool1", "tool2", connect_timeout=45)
            filtered_tools = mcp_server.tools  # only tool1 and tool2
            ...
        finally:
            mcp_server.stop()

        # Best practice is to ensure cleanup is done after use.
        mcp_server.stop()  # run after crew().kickoff()
    """

    def __init__(
        self,
        serverparams: StdioServerParameters | dict[str, Any],
        *tool_names: str,
        connect_timeout: int = 30,
    ):
        """Initialize the MCP Server

        Args:
            serverparams: The parameters for the MCP server; it supports either a
                `StdioServerParameters` or a `dict`, respectively for STDIO and SSE.
            *tool_names: Optional names of tools to filter. If provided, only tools with
                matching names will be available.
            connect_timeout: Connection timeout in seconds to the MCP server (default is 30s).
        """

        super().__init__()
        self._adapter = None
        self._tools = None
        self._tool_names = list(tool_names) if tool_names else None

        if not MCP_AVAILABLE:
            import click

            if click.confirm(
                "You are missing the 'mcp' package. Would you like to install it?"
            ):
                import subprocess

                try:
                    subprocess.run(["uv", "add", "mcp crewai-tools[mcp]"], check=True)

                except subprocess.CalledProcessError:
                    raise ImportError("Failed to install mcp package")
            else:
                raise ImportError(
                    "`mcp` package not found, please run `uv add crewai-tools[mcp]`"
                )

        try:
            self._serverparams = serverparams
            self._adapter = MCPAdapt(
                self._serverparams, CrewAIAdapter(), connect_timeout
            )
            self.start()

        except Exception as e:
            if self._adapter is not None:
                try:
                    self.stop()
                except Exception as stop_e:
                    logger.error(f"Error during stop cleanup: {stop_e}")
            raise RuntimeError(f"Failed to initialize MCP Adapter: {e}") from e

    def start(self):
        """Start the MCP server and initialize the tools."""
        self._tools = self._adapter.__enter__()

    def stop(self):
        """Stop the MCP server"""
        self._adapter.__exit__(None, None, None)

    @property
    def tools(self) -> ToolCollection[BaseTool]:
        """The CrewAI tools available from the MCP server.

        Raises:
            ValueError: If the MCP server is not started.

        Returns:
            The CrewAI tools available from the MCP server.
        """
        if self._tools is None:
            raise ValueError(
                "MCP server not started, run `mcp_server.start()` first before accessing `tools`"
            )

        tools_collection = ToolCollection(self._tools)
        if self._tool_names:
            return tools_collection.filter_by_names(self._tool_names)
        return tools_collection

    def __enter__(self):
        """
        Enter the context manager. Note that `__init__()` already starts the MCP server.
        So tools should already be available.
        """
        return self.tools

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the context manager."""
        return self._adapter.__exit__(exc_type, exc_value, traceback)
38  lib/crewai-tools/src/crewai_tools/adapters/rag_adapter.py  Normal file
@@ -0,0 +1,38 @@
from typing import Any, Optional

from crewai_tools.rag.core import RAG
from crewai_tools.tools.rag.rag_tool import Adapter


class RAGAdapter(Adapter):
    def __init__(
        self,
        collection_name: str = "crewai_knowledge_base",
        persist_directory: Optional[str] = None,
        embedding_model: str = "text-embedding-3-small",
        top_k: int = 5,
        embedding_api_key: Optional[str] = None,
        **embedding_kwargs,
    ):
        super().__init__()

        # Prepare embedding configuration
        embedding_config = {"api_key": embedding_api_key, **embedding_kwargs}

        self._adapter = RAG(
            collection_name=collection_name,
            persist_directory=persist_directory,
            embedding_model=embedding_model,
            top_k=top_k,
            embedding_config=embedding_config,
        )

    def query(self, question: str) -> str:
        return self._adapter.query(question)

    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self._adapter.add(*args, **kwargs)
@@ -0,0 +1,77 @@
from typing import Callable, Dict, Generic, List, Optional, TypeVar, Union

from crewai.tools import BaseTool


T = TypeVar("T", bound=BaseTool)


class ToolCollection(list, Generic[T]):
    """
    A collection of tools that can be accessed by index or name

    This class extends the built-in list to provide dictionary-like
    access to tools based on their name property.

    Usage:
        tools = ToolCollection(list_of_tools)
        # Access by index (regular list behavior)
        first_tool = tools[0]
        # Access by name (new functionality)
        search_tool = tools["search"]
    """

    def __init__(self, tools: Optional[List[T]] = None):
        super().__init__(tools or [])
        self._name_cache: Dict[str, T] = {}
        self._build_name_cache()

    def _build_name_cache(self) -> None:
        self._name_cache = {tool.name.lower(): tool for tool in self}

    def __getitem__(self, key: Union[int, str]) -> T:
        if isinstance(key, str):
            return self._name_cache[key.lower()]
        return super().__getitem__(key)

    def append(self, tool: T) -> None:
        super().append(tool)
        self._name_cache[tool.name.lower()] = tool

    def extend(self, tools: List[T]) -> None:
        super().extend(tools)
        self._build_name_cache()

    def insert(self, index: int, tool: T) -> None:
        super().insert(index, tool)
        self._name_cache[tool.name.lower()] = tool

    def remove(self, tool: T) -> None:
        super().remove(tool)
        if tool.name.lower() in self._name_cache:
            del self._name_cache[tool.name.lower()]

    def pop(self, index: int = -1) -> T:
        tool = super().pop(index)
        if tool.name.lower() in self._name_cache:
            del self._name_cache[tool.name.lower()]
        return tool

    def filter_by_names(self, names: Optional[List[str]] = None) -> "ToolCollection[T]":
        if names is None:
            return self

        return ToolCollection(
            [
                tool
                for name in names
                if (tool := self._name_cache.get(name.lower())) is not None
            ]
        )

    def filter_where(self, func: Callable[[T], bool]) -> "ToolCollection[T]":
        return ToolCollection([tool for tool in self if func(tool)])

    def clear(self) -> None:
        super().clear()
        self._name_cache.clear()
123  lib/crewai-tools/src/crewai_tools/adapters/zapier_adapter.py  Normal file
@@ -0,0 +1,123 @@
import logging
import os
from typing import List

from crewai.tools import BaseTool
from pydantic import Field, create_model
import requests


ACTIONS_URL = "https://actions.zapier.com/api/v2/ai-actions"

logger = logging.getLogger(__name__)


class ZapierActionTool(BaseTool):
    """
    A tool that wraps a Zapier action
    """

    name: str = Field(description="Tool name")
    description: str = Field(description="Tool description")
    action_id: str = Field(description="Zapier action ID")
    api_key: str = Field(description="Zapier API key")

    def _run(self, **kwargs) -> str:
        """Execute the Zapier action"""
        headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

        instructions = kwargs.pop(
            "instructions", "Execute this action with the provided parameters"
        )

        if not kwargs:
            action_params = {"instructions": instructions, "params": {}}
        else:
            formatted_params = {}
            for key, value in kwargs.items():
                formatted_params[key] = {
                    "value": value,
                    "mode": "guess",
                }
            action_params = {"instructions": instructions, "params": formatted_params}

        execute_url = f"{ACTIONS_URL}/{self.action_id}/execute/"
        response = requests.request(
            "POST", execute_url, headers=headers, json=action_params
        )

        response.raise_for_status()

        return response.json()


class ZapierActionsAdapter:
    """
    Adapter for Zapier Actions
    """

    api_key: str

    def __init__(self, api_key: str = None):
        self.api_key = api_key or os.getenv("ZAPIER_API_KEY")
        if not self.api_key:
            logger.error("Zapier Actions API key is required")
            raise ValueError("Zapier Actions API key is required")

    def get_zapier_actions(self):
        headers = {
            "x-api-key": self.api_key,
        }
        response = requests.request("GET", ACTIONS_URL, headers=headers)
        response.raise_for_status()

        response_json = response.json()
        return response_json

    def tools(self) -> List[BaseTool]:
        """Convert Zapier actions to BaseTool instances"""
        actions_response = self.get_zapier_actions()
        tools = []

        for action in actions_response.get("results", []):
            tool_name = (
                action["meta"]["action_label"]
                .replace(" ", "_")
                .replace(":", "")
                .lower()
            )

            params = action.get("params", {})
            args_fields = {}

            args_fields["instructions"] = (
                str,
                Field(description="Instructions for how to execute this action"),
            )

            for param_name, param_info in params.items():
                field_type = (
                    str  # Default to string, could be enhanced based on param_info
                )
                field_description = (
                    param_info.get("description", "")
                    if isinstance(param_info, dict)
                    else ""
                )
                args_fields[param_name] = (
                    field_type,
                    Field(description=field_description),
                )

            args_schema = create_model(f"{tool_name.title()}Schema", **args_fields)

            tool = ZapierActionTool(
                name=tool_name,
                description=action["description"],
                action_id=action["id"],
                api_key=self.api_key,
                args_schema=args_schema,
            )
            tools.append(tool)

        return tools
17  lib/crewai-tools/src/crewai_tools/aws/__init__.py  Normal file
@@ -0,0 +1,17 @@
from .bedrock import (
    BedrockInvokeAgentTool,
    BedrockKBRetrieverTool,
    create_browser_toolkit,
    create_code_interpreter_toolkit,
)
from .s3 import S3ReaderTool, S3WriterTool


__all__ = [
    "BedrockInvokeAgentTool",
    "BedrockKBRetrieverTool",
    "S3ReaderTool",
    "S3WriterTool",
    "create_browser_toolkit",
    "create_code_interpreter_toolkit",
]
12  lib/crewai-tools/src/crewai_tools/aws/bedrock/__init__.py  Normal file
@@ -0,0 +1,12 @@
from .agents.invoke_agent_tool import BedrockInvokeAgentTool
from .browser import create_browser_toolkit
from .code_interpreter import create_code_interpreter_toolkit
from .knowledge_base.retriever_tool import BedrockKBRetrieverTool


__all__ = [
    "BedrockInvokeAgentTool",
    "BedrockKBRetrieverTool",
    "create_browser_toolkit",
    "create_code_interpreter_toolkit",
]
181  lib/crewai-tools/src/crewai_tools/aws/bedrock/agents/README.md  Normal file
@@ -0,0 +1,181 @@
# BedrockInvokeAgentTool

The `BedrockInvokeAgentTool` enables CrewAI agents to invoke Amazon Bedrock Agents and leverage their capabilities within your workflows.

## Installation

```bash
pip install 'crewai[tools]'
```

## Requirements

- AWS credentials configured (either through environment variables or AWS CLI)
- `boto3` and `python-dotenv` packages
- Access to Amazon Bedrock Agents

## Usage

Here's how to use the tool with a CrewAI agent:

```python
from crewai import Agent, Task, Crew
from crewai_tools.aws.bedrock.agents.invoke_agent_tool import BedrockInvokeAgentTool

# Initialize the tool
agent_tool = BedrockInvokeAgentTool(
    agent_id="your-agent-id",
    agent_alias_id="your-agent-alias-id"
)

# Create a CrewAI agent that uses the tool
aws_expert = Agent(
    role='AWS Service Expert',
    goal='Help users understand AWS services and quotas',
    backstory='I am an expert in AWS services and can provide detailed information about them.',
    tools=[agent_tool],
    verbose=True
)

# Create a task for the agent
quota_task = Task(
    description="Find out the current service quotas for EC2 in us-west-2 and explain any recent changes.",
    agent=aws_expert
)

# Create a crew with the agent
crew = Crew(
    agents=[aws_expert],
    tasks=[quota_task],
    verbose=2
)

# Run the crew
result = crew.kickoff()
print(result)
```

## Tool Arguments

| Argument | Type | Required | Default | Description |
|----------|------|----------|---------|-------------|
| agent_id | str | Yes | None | The unique identifier of the Bedrock agent |
| agent_alias_id | str | Yes | None | The unique identifier of the agent alias |
| session_id | str | No | timestamp | The unique identifier of the session |
| enable_trace | bool | No | False | Whether to enable trace for debugging |
| end_session | bool | No | False | Whether to end the session after invocation |
| description | str | No | None | Custom description for the tool |

## Environment Variables

```bash
BEDROCK_AGENT_ID=your-agent-id # Alternative to passing agent_id
BEDROCK_AGENT_ALIAS_ID=your-agent-alias-id # Alternative to passing agent_alias_id
AWS_REGION=your-aws-region # Defaults to us-west-2
AWS_ACCESS_KEY_ID=your-access-key # Required for AWS authentication
AWS_SECRET_ACCESS_KEY=your-secret-key # Required for AWS authentication
```

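When these variables are set, the tool can also be constructed without passing the IDs explicitly. The snippet below is a minimal sketch assuming `BEDROCK_AGENT_ID` and `BEDROCK_AGENT_ALIAS_ID` are already exported in your shell; the optional `enable_trace` and `description` arguments from the table above are included purely for illustration.

```python
from crewai_tools.aws.bedrock.agents.invoke_agent_tool import BedrockInvokeAgentTool

# agent_id / agent_alias_id are omitted here and fall back to the
# BEDROCK_AGENT_ID / BEDROCK_AGENT_ALIAS_ID environment variables.
agent_tool = BedrockInvokeAgentTool(
    enable_trace=True,  # surface trace output while debugging
    description="Answers questions about AWS services and quotas",
)
```
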
## Advanced Usage

### Multi-Agent Workflow with Session Management

```python
from crewai import Agent, Task, Crew, Process
from crewai_tools.aws.bedrock.agents.invoke_agent_tool import BedrockInvokeAgentTool

# Initialize tools with session management
initial_tool = BedrockInvokeAgentTool(
    agent_id="your-agent-id",
    agent_alias_id="your-agent-alias-id",
    session_id="custom-session-id"
)

followup_tool = BedrockInvokeAgentTool(
    agent_id="your-agent-id",
    agent_alias_id="your-agent-alias-id",
    session_id="custom-session-id"
)

final_tool = BedrockInvokeAgentTool(
    agent_id="your-agent-id",
    agent_alias_id="your-agent-alias-id",
    session_id="custom-session-id",
    end_session=True
)

# Create agents for different stages
researcher = Agent(
    role='AWS Service Researcher',
    goal='Gather information about AWS services',
    backstory='I am specialized in finding detailed AWS service information.',
    tools=[initial_tool]
)

analyst = Agent(
    role='Service Compatibility Analyst',
    goal='Analyze service compatibility and requirements',
    backstory='I analyze AWS services for compatibility and integration possibilities.',
    tools=[followup_tool]
)

summarizer = Agent(
    role='Technical Documentation Writer',
    goal='Create clear technical summaries',
    backstory='I specialize in creating clear, concise technical documentation.',
    tools=[final_tool]
)

# Create tasks
research_task = Task(
    description="Find all available AWS services in us-west-2 region.",
    agent=researcher
)

analysis_task = Task(
    description="Analyze which services support IPv6 and their implementation requirements.",
    agent=analyst
)

summary_task = Task(
    description="Create a summary of IPv6-compatible services and their key features.",
    agent=summarizer
)

# Create a crew with the agents and tasks
crew = Crew(
    agents=[researcher, analyst, summarizer],
    tasks=[research_task, analysis_task, summary_task],
    process=Process.sequential,
    verbose=2
)

# Run the crew
result = crew.kickoff()
```

## Use Cases

### Hybrid Multi-Agent Collaborations
- Create workflows where CrewAI agents collaborate with managed Bedrock agents running as services in AWS
- Enable scenarios where sensitive data processing happens within your AWS environment while other agents operate externally
- Bridge on-premises CrewAI agents with cloud-based Bedrock agents for distributed intelligence workflows

### Data Sovereignty and Compliance
- Keep data-sensitive agentic workflows within your AWS environment while allowing external CrewAI agents to orchestrate tasks
- Maintain compliance with data residency requirements by processing sensitive information only within your AWS account
- Enable secure multi-agent collaborations where some agents cannot access your organization's private data

### Seamless AWS Service Integration
- Access any AWS service through Amazon Bedrock Actions without writing complex integration code
- Enable CrewAI agents to interact with AWS services through natural language requests
- Leverage pre-built Bedrock agent capabilities to interact with AWS services like Bedrock Knowledge Bases, Lambda, and more

### Scalable Hybrid Agent Architectures
- Offload computationally intensive tasks to managed Bedrock agents while lightweight tasks run in CrewAI
- Scale agent processing by distributing workloads between local CrewAI agents and cloud-based Bedrock agents

### Cross-Organizational Agent Collaboration
- Enable secure collaboration between your organization's CrewAI agents and partner organizations' Bedrock agents
- Create workflows where external expertise from Bedrock agents can be incorporated without exposing sensitive data
- Build agent ecosystems that span organizational boundaries while maintaining security and data control
@@ -0,0 +1,4 @@
from .invoke_agent_tool import BedrockInvokeAgentTool


__all__ = ["BedrockInvokeAgentTool"]
@@ -0,0 +1,183 @@
from datetime import datetime, timezone
import json
import os
import time
from typing import List, Optional, Type

from crewai.tools import BaseTool
from dotenv import load_dotenv
from pydantic import BaseModel, Field

from ..exceptions import BedrockAgentError, BedrockValidationError


# Load environment variables from .env file
load_dotenv()


class BedrockInvokeAgentToolInput(BaseModel):
    """Input schema for BedrockInvokeAgentTool."""

    query: str = Field(..., description="The query to send to the agent")


class BedrockInvokeAgentTool(BaseTool):
    name: str = "Bedrock Agent Invoke Tool"
    description: str = "An agent responsible for policy analysis."
    args_schema: Type[BaseModel] = BedrockInvokeAgentToolInput
    agent_id: str = None
    agent_alias_id: str = None
    session_id: str = None
    enable_trace: bool = False
    end_session: bool = False
    package_dependencies: List[str] = ["boto3"]

    def __init__(
        self,
        agent_id: str = None,
        agent_alias_id: str = None,
        session_id: str = None,
        enable_trace: bool = False,
        end_session: bool = False,
        description: Optional[str] = None,
        **kwargs,
    ):
        """Initialize the BedrockInvokeAgentTool with agent configuration.

        Args:
            agent_id (str): The unique identifier of the Bedrock agent
            agent_alias_id (str): The unique identifier of the agent alias
            session_id (str): The unique identifier of the session
            enable_trace (bool): Whether to enable trace for the agent invocation
            end_session (bool): Whether to end the session with the agent
            description (Optional[str]): Custom description for the tool
        """
        super().__init__(**kwargs)

        # Get values from environment variables if not provided
        self.agent_id = agent_id or os.getenv("BEDROCK_AGENT_ID")
        self.agent_alias_id = agent_alias_id or os.getenv("BEDROCK_AGENT_ALIAS_ID")
        self.session_id = session_id or str(
            int(time.time())
        )  # Use timestamp as session ID if not provided
        self.enable_trace = enable_trace
        self.end_session = end_session

        # Update the description if provided
        if description:
            self.description = description

        # Validate parameters
        self._validate_parameters()

    def _validate_parameters(self):
        """Validate the parameters according to AWS API requirements."""
        try:
            # Validate agent_id
            if not self.agent_id:
                raise BedrockValidationError("agent_id cannot be empty")
            if not isinstance(self.agent_id, str):
                raise BedrockValidationError("agent_id must be a string")

            # Validate agent_alias_id
            if not self.agent_alias_id:
                raise BedrockValidationError("agent_alias_id cannot be empty")
            if not isinstance(self.agent_alias_id, str):
                raise BedrockValidationError("agent_alias_id must be a string")

            # Validate session_id if provided
            if self.session_id and not isinstance(self.session_id, str):
                raise BedrockValidationError("session_id must be a string")

        except BedrockValidationError as e:
            raise BedrockValidationError(f"Parameter validation failed: {e!s}")

    def _run(self, query: str) -> str:
        try:
            import boto3
            from botocore.exceptions import ClientError
        except ImportError:
            raise ImportError("`boto3` package not found, please run `uv add boto3`")

        try:
            # Initialize the Bedrock Agent Runtime client
            bedrock_agent = boto3.client(
                "bedrock-agent-runtime",
                region_name=os.getenv(
                    "AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2")
                ),
            )

            # Format the prompt with current time
            current_utc = datetime.now(timezone.utc)
            prompt = f"""
            The current time is: {current_utc}

            Below is the user's query or task. Complete it and answer it concisely and to the point:
            {query}
            """

            # Invoke the agent
            response = bedrock_agent.invoke_agent(
                agentId=self.agent_id,
                agentAliasId=self.agent_alias_id,
                sessionId=self.session_id,
                inputText=prompt,
                enableTrace=self.enable_trace,
                endSession=self.end_session,
            )

            # Process the response
            completion = ""

            # Check if response contains a completion field
            if "completion" in response:
                # Process streaming response format
                for event in response.get("completion", []):
                    if "chunk" in event and "bytes" in event["chunk"]:
                        chunk_bytes = event["chunk"]["bytes"]
                        if isinstance(chunk_bytes, (bytes, bytearray)):
                            completion += chunk_bytes.decode("utf-8")
                        else:
                            completion += str(chunk_bytes)

            # If no completion found in streaming format, try direct format
            if not completion and "chunk" in response and "bytes" in response["chunk"]:
                chunk_bytes = response["chunk"]["bytes"]
                if isinstance(chunk_bytes, (bytes, bytearray)):
                    completion = chunk_bytes.decode("utf-8")
                else:
                    completion = str(chunk_bytes)

            # If still no completion, return debug info
            if not completion:
                debug_info = {
                    "error": "Could not extract completion from response",
                    "response_keys": list(response.keys()),
                }

                # Add more debug info
                if "chunk" in response:
                    debug_info["chunk_keys"] = list(response["chunk"].keys())

                raise BedrockAgentError(
                    f"Failed to extract completion: {json.dumps(debug_info, indent=2)}"
                )

            return completion

        except ClientError as e:
            error_code = "Unknown"
            error_message = str(e)

            # Try to extract error code if available
            if hasattr(e, "response") and "Error" in e.response:
                error_code = e.response["Error"].get("Code", "Unknown")
                error_message = e.response["Error"].get("Message", str(e))

            raise BedrockAgentError(f"Error ({error_code}): {error_message}")
        except BedrockAgentError:
            # Re-raise BedrockAgentError exceptions
            raise
        except Exception as e:
            raise BedrockAgentError(f"Unexpected error: {e!s}")
158  lib/crewai-tools/src/crewai_tools/aws/bedrock/browser/README.md  Normal file
@@ -0,0 +1,158 @@
# AWS Bedrock Browser Tools

This toolkit provides a set of tools for interacting with web browsers through AWS Bedrock Browser. It enables your CrewAI agents to navigate websites, extract content, click elements, and more.

## Features

- Navigate to URLs and browse the web
- Extract text and hyperlinks from pages
- Click on elements using CSS selectors
- Navigate back through browser history
- Get information about the current webpage
- Multiple browser sessions with thread-based isolation

## Installation

Ensure you have the necessary dependencies:

```bash
uv add crewai-tools bedrock-agentcore beautifulsoup4 playwright nest-asyncio
```

## Usage

### Basic Usage

```python
from crewai import Agent, Task, Crew, LLM
from crewai_tools.aws.bedrock.browser import create_browser_toolkit

# Create the browser toolkit
toolkit, browser_tools = create_browser_toolkit(region="us-west-2")

# Create the Bedrock LLM
llm = LLM(
    model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    region_name="us-west-2",
)

# Create a CrewAI agent that uses the browser tools
research_agent = Agent(
    role="Web Researcher",
    goal="Research and summarize web content",
    backstory="You're an expert at finding information online.",
    tools=browser_tools,
    llm=llm
)

# Create a task for the agent
research_task = Task(
    description="Navigate to https://example.com and extract all text content. Summarize the main points.",
    expected_output="A list of bullet points containing the most important information on https://example.com. Plus, a description of the tool calls used, and actions performed to get to the page.",
    agent=research_agent
)

# Create and run the crew
crew = Crew(
    agents=[research_agent],
    tasks=[research_task]
)
result = crew.kickoff()

print(f"\n***Final result:***\n\n{result}")

# Clean up browser resources when done
toolkit.sync_cleanup()
```

### Available Tools

The toolkit provides the following tools:

1. `navigate_browser` - Navigate to a URL
2. `click_element` - Click on an element using CSS selectors
3. `extract_text` - Extract all text from the current webpage
4. `extract_hyperlinks` - Extract all hyperlinks from the current webpage
5. `get_elements` - Get elements matching a CSS selector
6. `navigate_back` - Navigate to the previous page
7. `current_webpage` - Get information about the current webpage

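If an agent should only see a subset of these tools, the toolkit can hand them out by name. The sketch below is a minimal example that assumes the toolkit was created as in the basic usage above; it relies on `get_tools_by_name()`, the same helper used in the advanced example that follows, which returns a mapping keyed by the tool names listed above.

```python
from crewai_tools.aws.bedrock.browser import create_browser_toolkit

toolkit, browser_tools = create_browser_toolkit(region="us-west-2")

# Map of tool name -> tool instance, keyed by the names listed above
tools_by_name = toolkit.get_tools_by_name()

# Give an agent only the read-oriented tools, leaving navigation to another agent
extraction_tools = [
    tools_by_name["extract_text"],
    tools_by_name["extract_hyperlinks"],
]
```
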
### Advanced Usage (with async)

```python
import asyncio
from crewai import Agent, Task, Crew, LLM
from crewai_tools.aws.bedrock.browser import create_browser_toolkit

async def main():

    # Create the browser toolkit with specific AWS region
    toolkit, browser_tools = create_browser_toolkit(region="us-west-2")
    tools_by_name = toolkit.get_tools_by_name()

    # Create the Bedrock LLM
    llm = LLM(
        model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
        region_name="us-west-2",
    )

    # Create agents with specific tools
    navigator_agent = Agent(
        role="Navigator",
        goal="Find specific information across websites",
        backstory="You navigate through websites to locate information.",
        tools=[
            tools_by_name["navigate_browser"],
            tools_by_name["click_element"],
            tools_by_name["navigate_back"]
        ],
        llm=llm
    )

    content_agent = Agent(
        role="Content Extractor",
        goal="Extract and analyze webpage content",
        backstory="You extract and analyze content from webpages.",
        tools=[
            tools_by_name["extract_text"],
            tools_by_name["extract_hyperlinks"],
            tools_by_name["get_elements"]
        ],
        llm=llm
    )

    # Create tasks for the agents
    navigation_task = Task(
        description="Navigate to https://example.com, then click on the 'More information...' link.",
        expected_output="The status of the tool calls for this task.",
        agent=navigator_agent,
    )

    extraction_task = Task(
        description="Extract all text from the current page and summarize it.",
        expected_output="The summary of the page, and a description of the tool calls used, and actions performed to get to the page.",
        agent=content_agent,
    )

    # Create and run the crew
    crew = Crew(
        agents=[navigator_agent, content_agent],
        tasks=[navigation_task, extraction_task]
    )

    result = await crew.kickoff_async()

    # Clean up browser resources when done
    toolkit.sync_cleanup()

    return result

if __name__ == "__main__":
    result = asyncio.run(main())
    print(f"\n***Final result:***\n\n{result}")
```

## Requirements

- AWS account with access to Bedrock AgentCore API
- Properly configured AWS credentials
@@ -0,0 +1,4 @@
from .browser_toolkit import BrowserToolkit, create_browser_toolkit


__all__ = ["BrowserToolkit", "create_browser_toolkit"]
@@ -0,0 +1,263 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Dict, Tuple


if TYPE_CHECKING:
    from bedrock_agentcore.tools.browser_client import BrowserClient
    from playwright.async_api import Browser as AsyncBrowser
    from playwright.sync_api import Browser as SyncBrowser

logger = logging.getLogger(__name__)


class BrowserSessionManager:
    """
    Manages browser sessions for different threads.

    This class maintains separate browser sessions for different threads,
    enabling concurrent usage of browsers in multi-threaded environments.
    Browsers are created lazily only when needed by tools.
    """

    def __init__(self, region: str = "us-west-2"):
        """
        Initialize the browser session manager.

        Args:
            region: AWS region for browser client
        """
        self.region = region
        self._async_sessions: Dict[str, Tuple[BrowserClient, AsyncBrowser]] = {}
        self._sync_sessions: Dict[str, Tuple[BrowserClient, SyncBrowser]] = {}

    async def get_async_browser(self, thread_id: str) -> AsyncBrowser:
        """
        Get or create an async browser for the specified thread.

        Args:
            thread_id: Unique identifier for the thread requesting the browser

        Returns:
            An async browser instance specific to the thread
        """
        if thread_id in self._async_sessions:
            return self._async_sessions[thread_id][1]

        return await self._create_async_browser_session(thread_id)

    def get_sync_browser(self, thread_id: str) -> SyncBrowser:
        """
        Get or create a sync browser for the specified thread.

        Args:
            thread_id: Unique identifier for the thread requesting the browser

        Returns:
            A sync browser instance specific to the thread
        """
        if thread_id in self._sync_sessions:
            return self._sync_sessions[thread_id][1]

        return self._create_sync_browser_session(thread_id)

    async def _create_async_browser_session(self, thread_id: str) -> AsyncBrowser:
        """
        Create a new async browser session for the specified thread.

        Args:
            thread_id: Unique identifier for the thread

        Returns:
            The newly created async browser instance

        Raises:
            Exception: If browser session creation fails
        """
        from bedrock_agentcore.tools.browser_client import BrowserClient

        browser_client = BrowserClient(region=self.region)

        try:
            # Start browser session
            browser_client.start()

            # Get WebSocket connection info
            ws_url, headers = browser_client.generate_ws_headers()

            logger.info(
                f"Connecting to async WebSocket endpoint for thread {thread_id}: {ws_url}"
            )

            from playwright.async_api import async_playwright

            # Connect to browser using Playwright
            playwright = await async_playwright().start()
            browser = await playwright.chromium.connect_over_cdp(
                endpoint_url=ws_url, headers=headers, timeout=30000
            )
            logger.info(
                f"Successfully connected to async browser for thread {thread_id}"
            )

            # Store session resources
            self._async_sessions[thread_id] = (browser_client, browser)

            return browser

        except Exception as e:
            logger.error(
                f"Failed to create async browser session for thread {thread_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up resources if session creation fails
|
||||||
|
if browser_client:
|
||||||
|
try:
|
||||||
|
browser_client.stop()
|
||||||
|
except Exception as cleanup_error:
|
||||||
|
logger.warning(f"Error cleaning up browser client: {cleanup_error}")
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _create_sync_browser_session(self, thread_id: str) -> SyncBrowser:
|
||||||
|
"""
|
||||||
|
Create a new sync browser session for the specified thread.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: Unique identifier for the thread
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The newly created sync browser instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If browser session creation fails
|
||||||
|
"""
|
||||||
|
from bedrock_agentcore.tools.browser_client import BrowserClient
|
||||||
|
|
||||||
|
browser_client = BrowserClient(region=self.region)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Start browser session
|
||||||
|
browser_client.start()
|
||||||
|
|
||||||
|
# Get WebSocket connection info
|
||||||
|
ws_url, headers = browser_client.generate_ws_headers()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Connecting to sync WebSocket endpoint for thread {thread_id}: {ws_url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from playwright.sync_api import sync_playwright
|
||||||
|
|
||||||
|
# Connect to browser using Playwright
|
||||||
|
playwright = sync_playwright().start()
|
||||||
|
browser = playwright.chromium.connect_over_cdp(
|
||||||
|
endpoint_url=ws_url, headers=headers, timeout=30000
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Successfully connected to sync browser for thread {thread_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store session resources
|
||||||
|
self._sync_sessions[thread_id] = (browser_client, browser)
|
||||||
|
|
||||||
|
return browser
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to create sync browser session for thread {thread_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up resources if session creation fails
|
||||||
|
if browser_client:
|
||||||
|
try:
|
||||||
|
browser_client.stop()
|
||||||
|
except Exception as cleanup_error:
|
||||||
|
logger.warning(f"Error cleaning up browser client: {cleanup_error}")
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def close_async_browser(self, thread_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Close the async browser session for the specified thread.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: Unique identifier for the thread
|
||||||
|
"""
|
||||||
|
if thread_id not in self._async_sessions:
|
||||||
|
logger.warning(f"No async browser session found for thread {thread_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
browser_client, browser = self._async_sessions[thread_id]
|
||||||
|
|
||||||
|
# Close browser
|
||||||
|
if browser:
|
||||||
|
try:
|
||||||
|
await browser.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error closing async browser for thread {thread_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop browser client
|
||||||
|
if browser_client:
|
||||||
|
try:
|
||||||
|
browser_client.stop()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error stopping browser client for thread {thread_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove session from dictionary
|
||||||
|
del self._async_sessions[thread_id]
|
||||||
|
logger.info(f"Async browser session cleaned up for thread {thread_id}")
|
||||||
|
|
||||||
|
def close_sync_browser(self, thread_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Close the sync browser session for the specified thread.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: Unique identifier for the thread
|
||||||
|
"""
|
||||||
|
if thread_id not in self._sync_sessions:
|
||||||
|
logger.warning(f"No sync browser session found for thread {thread_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
browser_client, browser = self._sync_sessions[thread_id]
|
||||||
|
|
||||||
|
# Close browser
|
||||||
|
if browser:
|
||||||
|
try:
|
||||||
|
browser.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error closing sync browser for thread {thread_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop browser client
|
||||||
|
if browser_client:
|
||||||
|
try:
|
||||||
|
browser_client.stop()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error stopping browser client for thread {thread_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove session from dictionary
|
||||||
|
del self._sync_sessions[thread_id]
|
||||||
|
logger.info(f"Sync browser session cleaned up for thread {thread_id}")
|
||||||
|
|
||||||
|
async def close_all_browsers(self) -> None:
|
||||||
|
"""Close all browser sessions."""
|
||||||
|
# Close all async browsers
|
||||||
|
async_thread_ids = list(self._async_sessions.keys())
|
||||||
|
for thread_id in async_thread_ids:
|
||||||
|
await self.close_async_browser(thread_id)
|
||||||
|
|
||||||
|
# Close all sync browsers
|
||||||
|
sync_thread_ids = list(self._sync_sessions.keys())
|
||||||
|
for thread_id in sync_thread_ids:
|
||||||
|
self.close_sync_browser(thread_id)
|
||||||
|
|
||||||
|
logger.info("All browser sessions closed")
|
||||||
@@ -0,0 +1,616 @@
|
|||||||
|
"""Toolkit for navigating web with AWS browser."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Tuple, Type
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from .browser_session_manager import BrowserSessionManager
|
||||||
|
from .utils import aget_current_page, get_current_page
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Input schemas
|
||||||
|
class NavigateToolInput(BaseModel):
|
||||||
|
"""Input for NavigateTool."""
|
||||||
|
|
||||||
|
url: str = Field(description="URL to navigate to")
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the browser session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClickToolInput(BaseModel):
|
||||||
|
"""Input for ClickTool."""
|
||||||
|
|
||||||
|
selector: str = Field(description="CSS selector for the element to click on")
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the browser session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GetElementsToolInput(BaseModel):
|
||||||
|
"""Input for GetElementsTool."""
|
||||||
|
|
||||||
|
selector: str = Field(description="CSS selector for elements to get")
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the browser session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractTextToolInput(BaseModel):
|
||||||
|
"""Input for ExtractTextTool."""
|
||||||
|
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the browser session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractHyperlinksToolInput(BaseModel):
|
||||||
|
"""Input for ExtractHyperlinksTool."""
|
||||||
|
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the browser session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NavigateBackToolInput(BaseModel):
|
||||||
|
"""Input for NavigateBackTool."""
|
||||||
|
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the browser session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CurrentWebPageToolInput(BaseModel):
|
||||||
|
"""Input for CurrentWebPageTool."""
|
||||||
|
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the browser session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Base tool class
|
||||||
|
class BrowserBaseTool(BaseTool):
|
||||||
|
"""Base class for browser tools."""
|
||||||
|
|
||||||
|
def __init__(self, session_manager: BrowserSessionManager):
|
||||||
|
"""Initialize with a session manager."""
|
||||||
|
super().__init__()
|
||||||
|
self._session_manager = session_manager
|
||||||
|
|
||||||
|
if self._is_in_asyncio_loop() and hasattr(self, "_arun"):
|
||||||
|
self._original_run = self._run
|
||||||
|
|
||||||
|
# Override _run to use _arun when in an asyncio loop
|
||||||
|
def patched_run(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
import nest_asyncio
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
nest_asyncio.apply(loop)
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self._arun(*args, **kwargs)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error in patched _run: {e!s}"
|
||||||
|
|
||||||
|
self._run = patched_run
|
||||||
|
|
||||||
|
async def get_async_page(self, thread_id: str) -> Any:
|
||||||
|
"""Get or create a page for the specified thread."""
|
||||||
|
browser = await self._session_manager.get_async_browser(thread_id)
|
||||||
|
page = await aget_current_page(browser)
|
||||||
|
return page
|
||||||
|
|
||||||
|
def get_sync_page(self, thread_id: str) -> Any:
|
||||||
|
"""Get or create a page for the specified thread."""
|
||||||
|
browser = self._session_manager.get_sync_browser(thread_id)
|
||||||
|
page = get_current_page(browser)
|
||||||
|
return page
|
||||||
|
|
||||||
|
def _is_in_asyncio_loop(self) -> bool:
|
||||||
|
"""Check if we're currently in an asyncio event loop."""
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.is_running()
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Tool classes
|
||||||
|
class NavigateTool(BrowserBaseTool):
|
||||||
|
"""Tool for navigating a browser to a URL."""
|
||||||
|
|
||||||
|
name: str = "navigate_browser"
|
||||||
|
description: str = "Navigate a browser to the specified URL"
|
||||||
|
args_schema: Type[BaseModel] = NavigateToolInput
|
||||||
|
|
||||||
|
def _run(self, url: str, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the sync tool."""
|
||||||
|
try:
|
||||||
|
# Get page for this thread
|
||||||
|
page = self.get_sync_page(thread_id)
|
||||||
|
|
||||||
|
# Validate URL scheme
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
if parsed_url.scheme not in ("http", "https"):
|
||||||
|
raise ValueError("URL scheme must be 'http' or 'https'")
|
||||||
|
|
||||||
|
# Navigate to URL
|
||||||
|
response = page.goto(url)
|
||||||
|
status = response.status if response else "unknown"
|
||||||
|
return f"Navigating to {url} returned status code {status}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error navigating to {url}: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, url: str, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the async tool."""
|
||||||
|
try:
|
||||||
|
# Get page for this thread
|
||||||
|
page = await self.get_async_page(thread_id)
|
||||||
|
|
||||||
|
# Validate URL scheme
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
if parsed_url.scheme not in ("http", "https"):
|
||||||
|
raise ValueError("URL scheme must be 'http' or 'https'")
|
||||||
|
|
||||||
|
# Navigate to URL
|
||||||
|
response = await page.goto(url)
|
||||||
|
status = response.status if response else "unknown"
|
||||||
|
return f"Navigating to {url} returned status code {status}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error navigating to {url}: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
|
class ClickTool(BrowserBaseTool):
|
||||||
|
"""Tool for clicking on an element with the given CSS selector."""
|
||||||
|
|
||||||
|
name: str = "click_element"
|
||||||
|
description: str = "Click on an element with the given CSS selector"
|
||||||
|
args_schema: Type[BaseModel] = ClickToolInput
|
||||||
|
|
||||||
|
visible_only: bool = True
|
||||||
|
"""Whether to consider only visible elements."""
|
||||||
|
playwright_strict: bool = False
|
||||||
|
"""Whether to employ Playwright's strict mode when clicking on elements."""
|
||||||
|
playwright_timeout: float = 1_000
|
||||||
|
"""Timeout (in ms) for Playwright to wait for element to be ready."""
|
||||||
|
|
||||||
|
def _selector_effective(self, selector: str) -> str:
|
||||||
|
if not self.visible_only:
|
||||||
|
return selector
|
||||||
|
return f"{selector} >> visible=1"
|
||||||
|
|
||||||
|
def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the sync tool."""
|
||||||
|
try:
|
||||||
|
# Get the current page
|
||||||
|
page = self.get_sync_page(thread_id)
|
||||||
|
|
||||||
|
# Click on the element
|
||||||
|
selector_effective = self._selector_effective(selector=selector)
|
||||||
|
from playwright.sync_api import TimeoutError as PlaywrightTimeoutError
|
||||||
|
|
||||||
|
try:
|
||||||
|
page.click(
|
||||||
|
selector_effective,
|
||||||
|
strict=self.playwright_strict,
|
||||||
|
timeout=self.playwright_timeout,
|
||||||
|
)
|
||||||
|
except PlaywrightTimeoutError:
|
||||||
|
return f"Unable to click on element '{selector}'"
|
||||||
|
except Exception as click_error:
|
||||||
|
return f"Unable to click on element '{selector}': {click_error!s}"
|
||||||
|
|
||||||
|
return f"Clicked element '{selector}'"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error clicking on element: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, selector: str, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the async tool."""
|
||||||
|
try:
|
||||||
|
# Get the current page
|
||||||
|
page = await self.get_async_page(thread_id)
|
||||||
|
|
||||||
|
# Click on the element
|
||||||
|
selector_effective = self._selector_effective(selector=selector)
|
||||||
|
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
||||||
|
|
||||||
|
try:
|
||||||
|
await page.click(
|
||||||
|
selector_effective,
|
||||||
|
strict=self.playwright_strict,
|
||||||
|
timeout=self.playwright_timeout,
|
||||||
|
)
|
||||||
|
except PlaywrightTimeoutError:
|
||||||
|
return f"Unable to click on element '{selector}'"
|
||||||
|
except Exception as click_error:
|
||||||
|
return f"Unable to click on element '{selector}': {click_error!s}"
|
||||||
|
|
||||||
|
return f"Clicked element '{selector}'"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error clicking on element: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
|
class NavigateBackTool(BrowserBaseTool):
|
||||||
|
"""Tool for navigating back in browser history."""
|
||||||
|
|
||||||
|
name: str = "navigate_back"
|
||||||
|
description: str = "Navigate back to the previous page"
|
||||||
|
args_schema: Type[BaseModel] = NavigateBackToolInput
|
||||||
|
|
||||||
|
def _run(self, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the sync tool."""
|
||||||
|
try:
|
||||||
|
# Get the current page
|
||||||
|
page = self.get_sync_page(thread_id)
|
||||||
|
|
||||||
|
# Navigate back
|
||||||
|
try:
|
||||||
|
page.go_back()
|
||||||
|
return "Navigated back to the previous page"
|
||||||
|
except Exception as nav_error:
|
||||||
|
return f"Unable to navigate back: {nav_error!s}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error navigating back: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the async tool."""
|
||||||
|
try:
|
||||||
|
# Get the current page
|
||||||
|
page = await self.get_async_page(thread_id)
|
||||||
|
|
||||||
|
# Navigate back
|
||||||
|
try:
|
||||||
|
await page.go_back()
|
||||||
|
return "Navigated back to the previous page"
|
||||||
|
except Exception as nav_error:
|
||||||
|
return f"Unable to navigate back: {nav_error!s}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error navigating back: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractTextTool(BrowserBaseTool):
|
||||||
|
"""Tool for extracting text from a webpage."""
|
||||||
|
|
||||||
|
name: str = "extract_text"
|
||||||
|
description: str = "Extract all the text on the current webpage"
|
||||||
|
args_schema: Type[BaseModel] = ExtractTextToolInput
|
||||||
|
|
||||||
|
def _run(self, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the sync tool."""
|
||||||
|
try:
|
||||||
|
# Import BeautifulSoup
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
except ImportError:
|
||||||
|
return (
|
||||||
|
"The 'beautifulsoup4' package is required to use this tool."
|
||||||
|
" Please install it with 'pip install beautifulsoup4'."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the current page
|
||||||
|
page = self.get_sync_page(thread_id)
|
||||||
|
|
||||||
|
# Extract text
|
||||||
|
content = page.content()
|
||||||
|
soup = BeautifulSoup(content, "html.parser")
|
||||||
|
return soup.get_text(separator="\n").strip()
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error extracting text: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the async tool."""
|
||||||
|
try:
|
||||||
|
# Import BeautifulSoup
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
except ImportError:
|
||||||
|
return (
|
||||||
|
"The 'beautifulsoup4' package is required to use this tool."
|
||||||
|
" Please install it with 'pip install beautifulsoup4'."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the current page
|
||||||
|
page = await self.get_async_page(thread_id)
|
||||||
|
|
||||||
|
# Extract text
|
||||||
|
content = await page.content()
|
||||||
|
soup = BeautifulSoup(content, "html.parser")
|
||||||
|
return soup.get_text(separator="\n").strip()
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error extracting text: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractHyperlinksTool(BrowserBaseTool):
|
||||||
|
"""Tool for extracting hyperlinks from a webpage."""
|
||||||
|
|
||||||
|
name: str = "extract_hyperlinks"
|
||||||
|
description: str = "Extract all hyperlinks on the current webpage"
|
||||||
|
args_schema: Type[BaseModel] = ExtractHyperlinksToolInput
|
||||||
|
|
||||||
|
def _run(self, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the sync tool."""
|
||||||
|
try:
|
||||||
|
# Import BeautifulSoup
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
except ImportError:
|
||||||
|
return (
|
||||||
|
"The 'beautifulsoup4' package is required to use this tool."
|
||||||
|
" Please install it with 'pip install beautifulsoup4'."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the current page
|
||||||
|
page = self.get_sync_page(thread_id)
|
||||||
|
|
||||||
|
# Extract hyperlinks
|
||||||
|
content = page.content()
|
||||||
|
soup = BeautifulSoup(content, "html.parser")
|
||||||
|
links = []
|
||||||
|
for link in soup.find_all("a", href=True):
|
||||||
|
text = link.get_text().strip()
|
||||||
|
href = link["href"]
|
||||||
|
if href.startswith("http") or href.startswith("https"):
|
||||||
|
links.append({"text": text, "url": href})
|
||||||
|
|
||||||
|
if not links:
|
||||||
|
return "No hyperlinks found on the current page."
|
||||||
|
|
||||||
|
return json.dumps(links, indent=2)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error extracting hyperlinks: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the async tool."""
|
||||||
|
try:
|
||||||
|
# Import BeautifulSoup
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
except ImportError:
|
||||||
|
return (
|
||||||
|
"The 'beautifulsoup4' package is required to use this tool."
|
||||||
|
" Please install it with 'pip install beautifulsoup4'."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the current page
|
||||||
|
page = await self.get_async_page(thread_id)
|
||||||
|
|
||||||
|
# Extract hyperlinks
|
||||||
|
content = await page.content()
|
||||||
|
soup = BeautifulSoup(content, "html.parser")
|
||||||
|
links = []
|
||||||
|
for link in soup.find_all("a", href=True):
|
||||||
|
text = link.get_text().strip()
|
||||||
|
href = link["href"]
|
||||||
|
if href.startswith("http") or href.startswith("https"):
|
||||||
|
links.append({"text": text, "url": href})
|
||||||
|
|
||||||
|
if not links:
|
||||||
|
return "No hyperlinks found on the current page."
|
||||||
|
|
||||||
|
return json.dumps(links, indent=2)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error extracting hyperlinks: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
|
class GetElementsTool(BrowserBaseTool):
|
||||||
|
"""Tool for getting elements from a webpage."""
|
||||||
|
|
||||||
|
name: str = "get_elements"
|
||||||
|
description: str = "Get elements from the webpage using a CSS selector"
|
||||||
|
args_schema: Type[BaseModel] = GetElementsToolInput
|
||||||
|
|
||||||
|
def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the sync tool."""
|
||||||
|
try:
|
||||||
|
# Get the current page
|
||||||
|
page = self.get_sync_page(thread_id)
|
||||||
|
|
||||||
|
# Get elements
|
||||||
|
elements = page.query_selector_all(selector)
|
||||||
|
if not elements:
|
||||||
|
return f"No elements found with selector '{selector}'"
|
||||||
|
|
||||||
|
elements_text = []
|
||||||
|
for i, element in enumerate(elements):
|
||||||
|
text = element.text_content()
|
||||||
|
elements_text.append(f"Element {i + 1}: {text.strip()}")
|
||||||
|
|
||||||
|
return "\n".join(elements_text)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error getting elements: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, selector: str, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the async tool."""
|
||||||
|
try:
|
||||||
|
# Get the current page
|
||||||
|
page = await self.get_async_page(thread_id)
|
||||||
|
|
||||||
|
# Get elements
|
||||||
|
elements = await page.query_selector_all(selector)
|
||||||
|
if not elements:
|
||||||
|
return f"No elements found with selector '{selector}'"
|
||||||
|
|
||||||
|
elements_text = []
|
||||||
|
for i, element in enumerate(elements):
|
||||||
|
text = await element.text_content()
|
||||||
|
elements_text.append(f"Element {i + 1}: {text.strip()}")
|
||||||
|
|
||||||
|
return "\n".join(elements_text)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error getting elements: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
|
class CurrentWebPageTool(BrowserBaseTool):
|
||||||
|
"""Tool for getting information about the current webpage."""
|
||||||
|
|
||||||
|
name: str = "current_webpage"
|
||||||
|
description: str = "Get information about the current webpage"
|
||||||
|
args_schema: Type[BaseModel] = CurrentWebPageToolInput
|
||||||
|
|
||||||
|
def _run(self, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the sync tool."""
|
||||||
|
try:
|
||||||
|
# Get the current page
|
||||||
|
page = self.get_sync_page(thread_id)
|
||||||
|
|
||||||
|
# Get information
|
||||||
|
url = page.url
|
||||||
|
title = page.title()
|
||||||
|
return f"URL: {url}\nTitle: {title}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error getting current webpage info: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, thread_id: str = "default", **kwargs) -> str:
|
||||||
|
"""Use the async tool."""
|
||||||
|
try:
|
||||||
|
# Get the current page
|
||||||
|
page = await self.get_async_page(thread_id)
|
||||||
|
|
||||||
|
# Get information
|
||||||
|
url = page.url
|
||||||
|
title = await page.title()
|
||||||
|
return f"URL: {url}\nTitle: {title}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error getting current webpage info: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
|
class BrowserToolkit:
|
||||||
|
"""Toolkit for navigating web with AWS Bedrock browser.
|
||||||
|
|
||||||
|
This toolkit provides a set of tools for working with a remote browser
|
||||||
|
and supports multiple threads by maintaining separate browser sessions
|
||||||
|
for each thread ID. Browsers are created lazily only when needed.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from crewai import Agent, Task, Crew
|
||||||
|
from crewai_tools.aws.bedrock.browser import create_browser_toolkit
|
||||||
|
|
||||||
|
# Create the browser toolkit
|
||||||
|
toolkit, browser_tools = create_browser_toolkit(region="us-west-2")
|
||||||
|
|
||||||
|
# Create a CrewAI agent that uses the browser tools
|
||||||
|
research_agent = Agent(
|
||||||
|
role="Web Researcher",
|
||||||
|
goal="Research and summarize web content",
|
||||||
|
backstory="You're an expert at finding information online.",
|
||||||
|
tools=browser_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a task for the agent
|
||||||
|
research_task = Task(
|
||||||
|
description="Navigate to https://example.com and extract all text content. Summarize the main points.",
|
||||||
|
agent=research_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and run the crew
|
||||||
|
crew = Crew(agents=[research_agent], tasks=[research_task])
|
||||||
|
result = crew.kickoff()
|
||||||
|
|
||||||
|
# Clean up browser resources when done
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(toolkit.cleanup())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, region: str = "us-west-2"):
|
||||||
|
"""
|
||||||
|
Initialize the toolkit
|
||||||
|
|
||||||
|
Args:
|
||||||
|
region: AWS region for the browser client
|
||||||
|
"""
|
||||||
|
self.region = region
|
||||||
|
self.session_manager = BrowserSessionManager(region=region)
|
||||||
|
self.tools: List[BaseTool] = []
|
||||||
|
self._nest_current_loop()
|
||||||
|
self._setup_tools()
|
||||||
|
|
||||||
|
def _nest_current_loop(self):
|
||||||
|
"""Apply nest_asyncio if we're in an asyncio loop."""
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_running():
|
||||||
|
try:
|
||||||
|
import nest_asyncio
|
||||||
|
|
||||||
|
nest_asyncio.apply(loop)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to apply nest_asyncio: {e!s}")
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _setup_tools(self) -> None:
|
||||||
|
"""Initialize tools without creating any browsers."""
|
||||||
|
self.tools = [
|
||||||
|
NavigateTool(session_manager=self.session_manager),
|
||||||
|
ClickTool(session_manager=self.session_manager),
|
||||||
|
NavigateBackTool(session_manager=self.session_manager),
|
||||||
|
ExtractTextTool(session_manager=self.session_manager),
|
||||||
|
ExtractHyperlinksTool(session_manager=self.session_manager),
|
||||||
|
GetElementsTool(session_manager=self.session_manager),
|
||||||
|
CurrentWebPageTool(session_manager=self.session_manager),
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_tools(self) -> List[BaseTool]:
|
||||||
|
"""
|
||||||
|
Get the list of browser tools
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of CrewAI tools
|
||||||
|
"""
|
||||||
|
return self.tools
|
||||||
|
|
||||||
|
def get_tools_by_name(self) -> Dict[str, BaseTool]:
|
||||||
|
"""
|
||||||
|
Get a dictionary of tools mapped by their names
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of {tool_name: tool}
|
||||||
|
"""
|
||||||
|
return {tool.name: tool for tool in self.tools}
|
||||||
|
|
||||||
|
async def cleanup(self) -> None:
|
||||||
|
"""Clean up all browser sessions asynchronously"""
|
||||||
|
await self.session_manager.close_all_browsers()
|
||||||
|
logger.info("All browser sessions cleaned up")
|
||||||
|
|
||||||
|
def sync_cleanup(self) -> None:
|
||||||
|
"""Clean up all browser sessions from synchronous code"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_running():
|
||||||
|
asyncio.create_task(self.cleanup())
|
||||||
|
else:
|
||||||
|
loop.run_until_complete(self.cleanup())
|
||||||
|
except RuntimeError:
|
||||||
|
asyncio.run(self.cleanup())
|
||||||
|
|
||||||
|
|
||||||
|
def create_browser_toolkit(
|
||||||
|
region: str = "us-west-2",
|
||||||
|
) -> Tuple[BrowserToolkit, List[BaseTool]]:
|
||||||
|
"""
|
||||||
|
Create a BrowserToolkit
|
||||||
|
|
||||||
|
Args:
|
||||||
|
region: AWS region for browser client
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (toolkit, tools)
|
||||||
|
"""
|
||||||
|
toolkit = BrowserToolkit(region=region)
|
||||||
|
tools = toolkit.get_tools()
|
||||||
|
return toolkit, tools
|
||||||
@@ -0,0 +1,42 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Union


if TYPE_CHECKING:
    from playwright.async_api import Browser as AsyncBrowser, Page as AsyncPage
    from playwright.sync_api import Browser as SyncBrowser, Page as SyncPage


async def aget_current_page(browser: Union[AsyncBrowser, Any]) -> AsyncPage:
    """
    Asynchronously get the current page of the browser.

    Args:
        browser: The browser (AsyncBrowser) to get the current page from.

    Returns:
        AsyncPage: The current page.
    """
    if not browser.contexts:
        context = await browser.new_context()
        return await context.new_page()
    context = browser.contexts[0]
    if not context.pages:
        return await context.new_page()
    return context.pages[-1]


def get_current_page(browser: Union[SyncBrowser, Any]) -> SyncPage:
    """
    Get the current page of the browser.

    Args:
        browser: The browser to get the current page from.

    Returns:
        SyncPage: The current page.
    """
    if not browser.contexts:
        context = browser.new_context()
        return context.new_page()
    context = browser.contexts[0]
    if not context.pages:
        return context.new_page()
    return context.pages[-1]
@@ -0,0 +1,217 @@
# AWS Bedrock Code Interpreter Tools

This toolkit provides a set of tools for interacting with the AWS Bedrock Code Interpreter environment. It enables your CrewAI agents to execute code, run shell commands, manage files, and perform computational tasks in a secure, isolated environment.

## Features

- Execute code in various languages (primarily Python)
- Run shell commands in the environment
- Read, write, list, and delete files
- Manage long-running tasks asynchronously
- Multiple code interpreter sessions with thread-based isolation

## Installation

Ensure you have the necessary dependencies:

```bash
uv add crewai-tools bedrock-agentcore
```

## Usage

### Basic Usage

```python
from crewai import Agent, Task, Crew, LLM
from crewai_tools.aws import create_code_interpreter_toolkit

# Create the code interpreter toolkit
toolkit, code_tools = create_code_interpreter_toolkit(region="us-west-2")

# Create the Bedrock LLM
llm = LLM(
    model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    region_name="us-west-2",
)

# Create a CrewAI agent that uses the code interpreter tools
developer_agent = Agent(
    role="Python Developer",
    goal="Create and execute Python code to solve problems.",
    backstory="You're a skilled Python developer with expertise in data analysis.",
    tools=code_tools,
    llm=llm
)

# Create a task for the agent
coding_task = Task(
    description="Write a Python function that calculates the factorial of a number and test it. Do not use any imports from outside the Python standard library.",
    expected_output="The Python function created, and the test results.",
    agent=developer_agent
)

# Create and run the crew
crew = Crew(
    agents=[developer_agent],
    tasks=[coding_task]
)
result = crew.kickoff()

print(f"\n***Final result:***\n\n{result}")

# Clean up resources when done
import asyncio
asyncio.run(toolkit.cleanup())
```

### Available Tools

The toolkit provides the following tools (the asynchronous task tools are sketched right after this list):

1. `execute_code` - Run code in various languages (primarily Python)
2. `execute_command` - Run shell commands in the environment
3. `read_files` - Read content of files in the environment
4. `list_files` - List files in directories
5. `delete_files` - Remove files from the environment
6. `write_files` - Create or update files
7. `start_command_execution` - Start long-running commands asynchronously
8. `get_task` - Check status of async tasks
9. `stop_task` - Stop running tasks
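
The asynchronous task tools are not exercised by the examples in this README, so here is a minimal sketch of how they fit together: start a long-running command, poll its status, and stop it if it is no longer needed. It invokes the tools directly by name and assumes crewai's `BaseTool.run()` forwards keyword arguments to the tool implementation; the task ID below is an illustrative placeholder, since the real identifier comes back as text in the tool's response.

```python
import asyncio
from crewai_tools.aws import create_code_interpreter_toolkit

toolkit, _ = create_code_interpreter_toolkit(region="us-west-2")
tools = toolkit.get_tools_by_name()

# Kick off a long-running command without blocking; the returned text
# includes a task identifier for later status checks.
start_output = tools["start_command_execution"].run(
    command="sleep 30 && echo done",
    thread_id="async-demo",
)
print(start_output)

# Illustrative placeholder: copy the real task ID out of start_output.
task_id = "<task-id-from-start_output>"

# Poll the task status...
print(tools["get_task"].run(task_id=task_id, thread_id="async-demo"))

# ...and stop it early if the result is no longer needed.
print(tools["stop_task"].run(task_id=task_id, thread_id="async-demo"))

# Clean up the session when finished.
asyncio.run(toolkit.cleanup())
```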

### Advanced Usage

```python
from crewai import Agent, Task, Crew, LLM
from crewai_tools.aws import create_code_interpreter_toolkit

# Create the code interpreter toolkit
toolkit, code_tools = create_code_interpreter_toolkit(region="us-west-2")
tools_by_name = toolkit.get_tools_by_name()

# Create the Bedrock LLM
llm = LLM(
    model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    region_name="us-west-2",
)

# Create agents with specific tools
code_agent = Agent(
    role="Code Developer",
    goal="Write and execute code",
    backstory="You write and test code to solve complex problems.",
    tools=[
        # Use specific tools by name
        tools_by_name["execute_code"],
        tools_by_name["execute_command"],
        tools_by_name["read_files"],
        tools_by_name["write_files"]
    ],
    llm=llm
)

file_agent = Agent(
    role="File Manager",
    goal="Manage files in the environment",
    backstory="You help organize and manage files in the code environment.",
    tools=[
        # Use specific tools by name
        tools_by_name["list_files"],
        tools_by_name["read_files"],
        tools_by_name["write_files"],
        tools_by_name["delete_files"]
    ],
    llm=llm
)

# Create tasks for the agents
coding_task = Task(
    description="Write a Python script to analyze data from a CSV file. Do not use any imports from outside the Python standard library.",
    expected_output="The Python function created.",
    agent=code_agent
)

file_task = Task(
    description="Organize the created files into separate directories.",
    # expected_output is required by crewai's Task model
    expected_output="A summary of how the files were organized.",
    agent=file_agent
)

# Create and run the crew
crew = Crew(
    agents=[code_agent, file_agent],
    tasks=[coding_task, file_task]
)
result = crew.kickoff()

print(f"\n***Final result:***\n\n{result}")

# Clean up code interpreter resources when done
import asyncio
asyncio.run(toolkit.cleanup())
```

### Example: Data Analysis with Python

```python
from crewai import Agent, Task, Crew, LLM
from crewai_tools.aws import create_code_interpreter_toolkit

# Create toolkit and tools
toolkit, code_tools = create_code_interpreter_toolkit(region="us-west-2")

# Create the Bedrock LLM
llm = LLM(
    model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    region_name="us-west-2",
)

# Create a data analyst agent
analyst_agent = Agent(
    role="Data Analyst",
    goal="Analyze data using Python",
    backstory="You're an expert data analyst who uses Python for data processing.",
    tools=code_tools,
    llm=llm
)

# Create a task for the agent
analysis_task = Task(
    description="""
    For all of the below, do not use any imports from outside the Python standard library.
    1. Create a sample dataset with random data
    2. Perform statistical analysis on the dataset
    3. Generate visualizations of the results
    4. Save the results and visualizations to files
    """,
    # expected_output is required by crewai's Task model
    expected_output="The analysis results and the paths of the files that were saved.",
    agent=analyst_agent
)

# Create and run the crew
crew = Crew(
    agents=[analyst_agent],
    tasks=[analysis_task]
)
result = crew.kickoff()

print(f"\n***Final result:***\n\n{result}")

# Clean up resources
import asyncio
asyncio.run(toolkit.cleanup())
```

## Resource Cleanup

Always clean up code interpreter resources when done to prevent resource leaks:

```python
import asyncio

# Clean up all code interpreter sessions
asyncio.run(toolkit.cleanup())
```
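
If the crew run can raise, it may be safer to guarantee cleanup with `try`/`finally`. A minimal sketch, assuming the same `toolkit` and `crew` objects as in the examples above:

```python
import asyncio

try:
    result = crew.kickoff()
    print(result)
finally:
    # Runs even if kickoff() raises, so the remote sessions are not leaked.
    asyncio.run(toolkit.cleanup())
```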

## Requirements

- AWS account with access to Bedrock AgentCore API
- Properly configured AWS credentials
@@ -0,0 +1,7 @@
from .code_interpreter_toolkit import (
    CodeInterpreterToolkit,
    create_code_interpreter_toolkit,
)


__all__ = ["CodeInterpreterToolkit", "create_code_interpreter_toolkit"]
@@ -0,0 +1,630 @@
|
|||||||
|
"""Toolkit for working with AWS Bedrock Code Interpreter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from bedrock_agentcore.tools.code_interpreter_client import CodeInterpreter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_output_from_stream(response):
|
||||||
|
"""
|
||||||
|
Extract output from code interpreter response stream
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: Response from code interpreter execution
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Extracted output as string
|
||||||
|
"""
|
||||||
|
output = []
|
||||||
|
for event in response["stream"]:
|
||||||
|
if "result" in event:
|
||||||
|
result = event["result"]
|
||||||
|
for content_item in result["content"]:
|
||||||
|
if content_item["type"] == "text":
|
||||||
|
output.append(content_item["text"])
|
||||||
|
if content_item["type"] == "resource":
|
||||||
|
resource = content_item["resource"]
|
||||||
|
if "text" in resource:
|
||||||
|
file_path = resource["uri"].replace("file://", "")
|
||||||
|
file_content = resource["text"]
|
||||||
|
output.append(f"==== File: {file_path} ====\n{file_content}\n")
|
||||||
|
else:
|
||||||
|
output.append(json.dumps(resource))
|
||||||
|
|
||||||
|
return "\n".join(output)
|
||||||
|
|
||||||
|
|
||||||
|
# Input schemas
|
||||||
|
class ExecuteCodeInput(BaseModel):
|
||||||
|
"""Input for ExecuteCode."""
|
||||||
|
|
||||||
|
code: str = Field(description="The code to execute")
|
||||||
|
language: str = Field(
|
||||||
|
default="python", description="The programming language of the code"
|
||||||
|
)
|
||||||
|
clear_context: bool = Field(
|
||||||
|
default=False, description="Whether to clear execution context"
|
||||||
|
)
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the code interpreter session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExecuteCommandInput(BaseModel):
|
||||||
|
"""Input for ExecuteCommand."""
|
||||||
|
|
||||||
|
command: str = Field(description="The command to execute")
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the code interpreter session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFilesInput(BaseModel):
|
||||||
|
"""Input for ReadFiles."""
|
||||||
|
|
||||||
|
paths: List[str] = Field(description="List of file paths to read")
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the code interpreter session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ListFilesInput(BaseModel):
|
||||||
|
"""Input for ListFiles."""
|
||||||
|
|
||||||
|
directory_path: str = Field(default="", description="Path to the directory to list")
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the code interpreter session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteFilesInput(BaseModel):
|
||||||
|
"""Input for DeleteFiles."""
|
||||||
|
|
||||||
|
paths: List[str] = Field(description="List of file paths to delete")
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the code interpreter session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteFilesInput(BaseModel):
|
||||||
|
"""Input for WriteFiles."""
|
||||||
|
|
||||||
|
files: List[Dict[str, str]] = Field(
|
||||||
|
description="List of dictionaries with path and text fields"
|
||||||
|
)
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the code interpreter session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StartCommandInput(BaseModel):
|
||||||
|
"""Input for StartCommand."""
|
||||||
|
|
||||||
|
command: str = Field(description="The command to execute asynchronously")
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the code interpreter session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GetTaskInput(BaseModel):
|
||||||
|
"""Input for GetTask."""
|
||||||
|
|
||||||
|
task_id: str = Field(description="The ID of the task to check")
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the code interpreter session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StopTaskInput(BaseModel):
|
||||||
|
"""Input for StopTask."""
|
||||||
|
|
||||||
|
task_id: str = Field(description="The ID of the task to stop")
|
||||||
|
thread_id: str = Field(
|
||||||
|
default="default", description="Thread ID for the code interpreter session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Tool classes
|
||||||
|
class ExecuteCodeTool(BaseTool):
|
||||||
|
"""Tool for executing code in various languages."""
|
||||||
|
|
||||||
|
name: str = "execute_code"
|
||||||
|
description: str = "Execute code in various languages (primarily Python)"
|
||||||
|
args_schema: Type[BaseModel] = ExecuteCodeInput
|
||||||
|
toolkit: Any = Field(default=None, exclude=True)
|
||||||
|
|
||||||
|
def __init__(self, toolkit):
|
||||||
|
super().__init__()
|
||||||
|
self.toolkit = toolkit
|
||||||
|
|
||||||
|
def _run(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
language: str = "python",
|
||||||
|
clear_context: bool = False,
|
||||||
|
thread_id: str = "default",
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
# Get or create code interpreter
|
||||||
|
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||||
|
thread_id=thread_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute code
|
||||||
|
response = code_interpreter.invoke(
|
||||||
|
method="executeCode",
|
||||||
|
params={
|
||||||
|
"code": code,
|
||||||
|
"language": language,
|
||||||
|
"clearContext": clear_context,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return extract_output_from_stream(response)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error executing code: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
language: str = "python",
|
||||||
|
clear_context: bool = False,
|
||||||
|
thread_id: str = "default",
|
||||||
|
) -> str:
|
||||||
|
# Use _run as we're working with a synchronous API that's thread-safe
|
||||||
|
return self._run(
|
||||||
|
code=code,
|
||||||
|
language=language,
|
||||||
|
clear_context=clear_context,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExecuteCommandTool(BaseTool):
|
||||||
|
"""Tool for running shell commands in the code interpreter environment."""
|
||||||
|
|
||||||
|
name: str = "execute_command"
|
||||||
|
description: str = "Run shell commands in the code interpreter environment"
|
||||||
|
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||||
|
toolkit: Any = Field(default=None, exclude=True)
|
||||||
|
|
||||||
|
def __init__(self, toolkit):
|
||||||
|
super().__init__()
|
||||||
|
self.toolkit = toolkit
|
||||||
|
|
||||||
|
def _run(self, command: str, thread_id: str = "default") -> str:
|
||||||
|
try:
|
||||||
|
# Get or create code interpreter
|
||||||
|
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||||
|
thread_id=thread_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute command
|
||||||
|
response = code_interpreter.invoke(
|
||||||
|
method="executeCommand", params={"command": command}
|
||||||
|
)
|
||||||
|
|
||||||
|
return extract_output_from_stream(response)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error executing command: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, command: str, thread_id: str = "default") -> str:
|
||||||
|
# Use _run as we're working with a synchronous API that's thread-safe
|
||||||
|
return self._run(command=command, thread_id=thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFilesTool(BaseTool):
|
||||||
|
"""Tool for reading content of files in the environment."""
|
||||||
|
|
||||||
|
name: str = "read_files"
|
||||||
|
description: str = "Read content of files in the environment"
|
||||||
|
args_schema: Type[BaseModel] = ReadFilesInput
|
||||||
|
toolkit: Any = Field(default=None, exclude=True)
|
||||||
|
|
||||||
|
def __init__(self, toolkit):
|
||||||
|
super().__init__()
|
||||||
|
self.toolkit = toolkit
|
||||||
|
|
||||||
|
def _run(self, paths: List[str], thread_id: str = "default") -> str:
|
||||||
|
try:
|
||||||
|
# Get or create code interpreter
|
||||||
|
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||||
|
thread_id=thread_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read files
|
||||||
|
response = code_interpreter.invoke(
|
||||||
|
method="readFiles", params={"paths": paths}
|
||||||
|
)
|
||||||
|
|
||||||
|
return extract_output_from_stream(response)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error reading files: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, paths: List[str], thread_id: str = "default") -> str:
|
||||||
|
# Use _run as we're working with a synchronous API that's thread-safe
|
||||||
|
return self._run(paths=paths, thread_id=thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
class ListFilesTool(BaseTool):
|
||||||
|
"""Tool for listing files in directories in the environment."""
|
||||||
|
|
||||||
|
name: str = "list_files"
|
||||||
|
description: str = "List files in directories in the environment"
|
||||||
|
args_schema: Type[BaseModel] = ListFilesInput
|
||||||
|
toolkit: Any = Field(default=None, exclude=True)
|
||||||
|
|
||||||
|
def __init__(self, toolkit):
|
||||||
|
super().__init__()
|
||||||
|
self.toolkit = toolkit
|
||||||
|
|
||||||
|
def _run(self, directory_path: str = "", thread_id: str = "default") -> str:
|
||||||
|
try:
|
||||||
|
# Get or create code interpreter
|
||||||
|
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||||
|
thread_id=thread_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# List files
|
||||||
|
response = code_interpreter.invoke(
|
||||||
|
method="listFiles", params={"directoryPath": directory_path}
|
||||||
|
)
|
||||||
|
|
||||||
|
return extract_output_from_stream(response)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error listing files: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, directory_path: str = "", thread_id: str = "default") -> str:
|
||||||
|
# Use _run as we're working with a synchronous API that's thread-safe
|
||||||
|
return self._run(directory_path=directory_path, thread_id=thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteFilesTool(BaseTool):
|
||||||
|
"""Tool for removing files from the environment."""
|
||||||
|
|
||||||
|
name: str = "delete_files"
|
||||||
|
description: str = "Remove files from the environment"
|
||||||
|
args_schema: Type[BaseModel] = DeleteFilesInput
|
||||||
|
toolkit: Any = Field(default=None, exclude=True)
|
||||||
|
|
||||||
|
def __init__(self, toolkit):
|
||||||
|
super().__init__()
|
||||||
|
self.toolkit = toolkit
|
||||||
|
|
||||||
|
def _run(self, paths: List[str], thread_id: str = "default") -> str:
|
||||||
|
try:
|
||||||
|
# Get or create code interpreter
|
||||||
|
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||||
|
thread_id=thread_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove files
|
||||||
|
response = code_interpreter.invoke(
|
||||||
|
method="removeFiles", params={"paths": paths}
|
||||||
|
)
|
||||||
|
|
||||||
|
return extract_output_from_stream(response)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error deleting files: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, paths: List[str], thread_id: str = "default") -> str:
|
||||||
|
# Use _run as we're working with a synchronous API that's thread-safe
|
||||||
|
return self._run(paths=paths, thread_id=thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteFilesTool(BaseTool):
|
||||||
|
"""Tool for creating or updating files in the environment."""
|
||||||
|
|
||||||
|
name: str = "write_files"
|
||||||
|
description: str = "Create or update files in the environment"
|
||||||
|
args_schema: Type[BaseModel] = WriteFilesInput
|
||||||
|
toolkit: Any = Field(default=None, exclude=True)
|
||||||
|
|
||||||
|
def __init__(self, toolkit):
|
||||||
|
super().__init__()
|
||||||
|
self.toolkit = toolkit
|
||||||
|
|
||||||
|
def _run(self, files: List[Dict[str, str]], thread_id: str = "default") -> str:
|
||||||
|
try:
|
||||||
|
# Get or create code interpreter
|
||||||
|
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||||
|
thread_id=thread_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Write files
|
||||||
|
response = code_interpreter.invoke(
|
||||||
|
method="writeFiles", params={"content": files}
|
||||||
|
)
|
||||||
|
|
||||||
|
return extract_output_from_stream(response)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error writing files: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(
|
||||||
|
self, files: List[Dict[str, str]], thread_id: str = "default"
|
||||||
|
) -> str:
|
||||||
|
# Use _run as we're working with a synchronous API that's thread-safe
|
||||||
|
return self._run(files=files, thread_id=thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
class StartCommandTool(BaseTool):
|
||||||
|
"""Tool for starting long-running commands asynchronously."""
|
||||||
|
|
||||||
|
name: str = "start_command_execution"
|
||||||
|
description: str = "Start long-running commands asynchronously"
|
||||||
|
args_schema: Type[BaseModel] = StartCommandInput
|
||||||
|
toolkit: Any = Field(default=None, exclude=True)
|
||||||
|
|
||||||
|
def __init__(self, toolkit):
|
||||||
|
super().__init__()
|
||||||
|
self.toolkit = toolkit
|
||||||
|
|
||||||
|
def _run(self, command: str, thread_id: str = "default") -> str:
|
||||||
|
try:
|
||||||
|
# Get or create code interpreter
|
||||||
|
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||||
|
thread_id=thread_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start command execution
|
||||||
|
response = code_interpreter.invoke(
|
||||||
|
method="startCommandExecution", params={"command": command}
|
||||||
|
)
|
||||||
|
|
||||||
|
return extract_output_from_stream(response)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error starting command: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, command: str, thread_id: str = "default") -> str:
|
||||||
|
# Use _run as we're working with a synchronous API that's thread-safe
|
||||||
|
return self._run(command=command, thread_id=thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
class GetTaskTool(BaseTool):
|
||||||
|
"""Tool for checking status of async tasks."""
|
||||||
|
|
||||||
|
name: str = "get_task"
|
||||||
|
description: str = "Check status of async tasks"
|
||||||
|
args_schema: Type[BaseModel] = GetTaskInput
|
||||||
|
toolkit: Any = Field(default=None, exclude=True)
|
||||||
|
|
||||||
|
def __init__(self, toolkit):
|
||||||
|
super().__init__()
|
||||||
|
self.toolkit = toolkit
|
||||||
|
|
||||||
|
def _run(self, task_id: str, thread_id: str = "default") -> str:
|
||||||
|
try:
|
||||||
|
# Get or create code interpreter
|
||||||
|
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||||
|
thread_id=thread_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get task status
|
||||||
|
response = code_interpreter.invoke(
|
||||||
|
method="getTask", params={"taskId": task_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
return extract_output_from_stream(response)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error getting task status: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, task_id: str, thread_id: str = "default") -> str:
|
||||||
|
# Use _run as we're working with a synchronous API that's thread-safe
|
||||||
|
return self._run(task_id=task_id, thread_id=thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
class StopTaskTool(BaseTool):
|
||||||
|
"""Tool for stopping running tasks."""
|
||||||
|
|
||||||
|
name: str = "stop_task"
|
||||||
|
description: str = "Stop running tasks"
|
||||||
|
args_schema: Type[BaseModel] = StopTaskInput
|
||||||
|
toolkit: Any = Field(default=None, exclude=True)
|
||||||
|
|
||||||
|
def __init__(self, toolkit):
|
||||||
|
super().__init__()
|
||||||
|
self.toolkit = toolkit
|
||||||
|
|
||||||
|
def _run(self, task_id: str, thread_id: str = "default") -> str:
|
||||||
|
try:
|
||||||
|
# Get or create code interpreter
|
||||||
|
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||||
|
thread_id=thread_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop task
|
||||||
|
response = code_interpreter.invoke(
|
||||||
|
method="stopTask", params={"taskId": task_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
return extract_output_from_stream(response)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error stopping task: {e!s}"
|
||||||
|
|
||||||
|
async def _arun(self, task_id: str, thread_id: str = "default") -> str:
|
||||||
|
# Use _run as we're working with a synchronous API that's thread-safe
|
||||||
|
return self._run(task_id=task_id, thread_id=thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
class CodeInterpreterToolkit:
|
||||||
|
"""Toolkit for working with AWS Bedrock code interpreter environment.
|
||||||
|
|
||||||
|
This toolkit provides a set of tools for working with a remote code interpreter environment:
|
||||||
|
|
||||||
|
* execute_code - Run code in various languages (primarily Python)
|
||||||
|
* execute_command - Run shell commands
|
||||||
|
* read_files - Read content of files in the environment
|
||||||
|
* list_files - List files in directories
|
||||||
|
* delete_files - Remove files from the environment
|
||||||
|
* write_files - Create or update files
|
||||||
|
* start_command_execution - Start long-running commands asynchronously
|
||||||
|
* get_task - Check status of async tasks
|
||||||
|
* stop_task - Stop running tasks
|
||||||
|
|
||||||
|
The toolkit lazily initializes the code interpreter session on first use.
|
||||||
|
It supports multiple threads by maintaining separate code interpreter sessions for each thread ID.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from crewai import Agent, Task, Crew
|
||||||
|
from crewai_tools.aws.bedrock.code_interpreter import (
|
||||||
|
create_code_interpreter_toolkit,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the code interpreter toolkit
|
||||||
|
toolkit, code_tools = create_code_interpreter_toolkit(region="us-west-2")
|
||||||
|
|
||||||
|
# Create a CrewAI agent that uses the code interpreter tools
|
||||||
|
developer_agent = Agent(
|
||||||
|
role="Python Developer",
|
||||||
|
goal="Create and execute Python code to solve problems",
|
||||||
|
backstory="You're a skilled Python developer with expertise in data analysis.",
|
||||||
|
tools=code_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a task for the agent
|
||||||
|
coding_task = Task(
|
||||||
|
description="Write a Python function that calculates the factorial of a number and test it.",
|
||||||
|
agent=developer_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and run the crew
|
||||||
|
crew = Crew(agents=[developer_agent], tasks=[coding_task])
|
||||||
|
result = crew.kickoff()
|
||||||
|
|
||||||
|
# Clean up resources when done
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(toolkit.cleanup())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, region: str = "us-west-2"):
|
||||||
|
"""
|
||||||
|
Initialize the toolkit
|
||||||
|
|
||||||
|
Args:
|
||||||
|
region: AWS region for the code interpreter
|
||||||
|
"""
|
||||||
|
self.region = region
|
||||||
|
self._code_interpreters: Dict[str, CodeInterpreter] = {}
|
||||||
|
self.tools: List[BaseTool] = []
|
||||||
|
self._setup_tools()
|
||||||
|
|
||||||
|
def _setup_tools(self) -> None:
|
||||||
|
"""Initialize tools without creating any code interpreter sessions."""
|
||||||
|
self.tools = [
|
||||||
|
ExecuteCodeTool(self),
|
||||||
|
ExecuteCommandTool(self),
|
||||||
|
ReadFilesTool(self),
|
||||||
|
ListFilesTool(self),
|
||||||
|
DeleteFilesTool(self),
|
||||||
|
WriteFilesTool(self),
|
||||||
|
StartCommandTool(self),
|
||||||
|
GetTaskTool(self),
|
||||||
|
StopTaskTool(self),
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_or_create_interpreter(self, thread_id: str = "default") -> CodeInterpreter:
|
||||||
|
"""Get or create a code interpreter for the specified thread.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: Thread ID for the code interpreter session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CodeInterpreter instance
|
||||||
|
"""
|
||||||
|
if thread_id in self._code_interpreters:
|
||||||
|
return self._code_interpreters[thread_id]
|
||||||
|
|
||||||
|
# Create a new code interpreter for this thread
|
||||||
|
from bedrock_agentcore.tools.code_interpreter_client import CodeInterpreter
|
||||||
|
|
||||||
|
code_interpreter = CodeInterpreter(region=self.region)
|
||||||
|
code_interpreter.start()
|
||||||
|
logger.info(
|
||||||
|
f"Started code interpreter with session_id:{code_interpreter.session_id} for thread:{thread_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store the interpreter
|
||||||
|
self._code_interpreters[thread_id] = code_interpreter
|
||||||
|
return code_interpreter
|
||||||
|
|
||||||
|
def get_tools(self) -> List[BaseTool]:
|
||||||
|
"""
|
||||||
|
Get the list of code interpreter tools
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of CrewAI tools
|
||||||
|
"""
|
||||||
|
return self.tools
|
||||||
|
|
||||||
|
def get_tools_by_name(self) -> Dict[str, BaseTool]:
|
||||||
|
"""
|
||||||
|
Get a dictionary of tools mapped by their names
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of {tool_name: tool}
|
||||||
|
"""
|
||||||
|
return {tool.name: tool for tool in self.tools}
|
||||||
|
|
||||||
|
async def cleanup(self, thread_id: Optional[str] = None) -> None:
|
||||||
|
"""Clean up resources
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: Optional thread ID to clean up. If None, cleans up all sessions.
|
||||||
|
"""
|
||||||
|
if thread_id:
|
||||||
|
# Clean up a specific thread's session
|
||||||
|
if thread_id in self._code_interpreters:
|
||||||
|
try:
|
||||||
|
self._code_interpreters[thread_id].stop()
|
||||||
|
del self._code_interpreters[thread_id]
|
||||||
|
logger.info(
|
||||||
|
f"Code interpreter session for thread {thread_id} cleaned up"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error stopping code interpreter for thread {thread_id}: {e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Clean up all sessions
|
||||||
|
thread_ids = list(self._code_interpreters.keys())
|
||||||
|
for tid in thread_ids:
|
||||||
|
try:
|
||||||
|
self._code_interpreters[tid].stop()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error stopping code interpreter for thread {tid}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._code_interpreters = {}
|
||||||
|
logger.info("All code interpreter sessions cleaned up")
|
||||||
|
|
||||||
|
|
||||||
|
def create_code_interpreter_toolkit(
|
||||||
|
region: str = "us-west-2",
|
||||||
|
) -> Tuple[CodeInterpreterToolkit, List[BaseTool]]:
|
||||||
|
"""
|
||||||
|
Create a CodeInterpreterToolkit
|
||||||
|
|
||||||
|
Args:
|
||||||
|
region: AWS region for code interpreter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (toolkit, tools)
|
||||||
|
"""
|
||||||
|
toolkit = CodeInterpreterToolkit(region=region)
|
||||||
|
tools = toolkit.get_tools()
|
||||||
|
return toolkit, tools
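
Since each `thread_id` maps to its own interpreter session and `cleanup()` accepts an optional thread ID, callers can release individual sessions without tearing down the rest. A minimal sketch, assuming nothing beyond the APIs above (the thread IDs are illustrative):

```python
# Illustrative sketch: one interpreter session per thread_id, cleaned up selectively.
import asyncio

from crewai_tools.aws.bedrock.code_interpreter import create_code_interpreter_toolkit

toolkit, tools = create_code_interpreter_toolkit(region="us-west-2")

# Sessions are created lazily, one per thread_id, the first time a tool runs.
tools_by_name = toolkit.get_tools_by_name()
print(sorted(tools_by_name))  # execute_code, execute_command, read_files, ...

# Release a single thread's session without touching the others...
asyncio.run(toolkit.cleanup(thread_id="scratch"))

# ...or tear everything down when the crew is finished.
asyncio.run(toolkit.cleanup())
```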
|
||||||
17
lib/crewai-tools/src/crewai_tools/aws/bedrock/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Custom exceptions for AWS Bedrock integration."""
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockError(Exception):
|
||||||
|
"""Base exception for Bedrock-related errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockAgentError(BedrockError):
|
||||||
|
"""Exception raised for errors in the Bedrock Agent operations."""
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockKnowledgeBaseError(BedrockError):
|
||||||
|
"""Exception raised for errors in the Bedrock Knowledge Base operations."""
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockValidationError(BedrockError):
|
||||||
|
"""Exception raised for validation errors in Bedrock operations."""
|
||||||
@@ -0,0 +1,159 @@
|
|||||||
|
# BedrockKBRetrieverTool

The `BedrockKBRetrieverTool` enables CrewAI agents to retrieve information from Amazon Bedrock Knowledge Bases using natural language queries.

## Installation

```bash
pip install 'crewai[tools]'
```

## Requirements

- AWS credentials configured (either through environment variables or AWS CLI)
- `boto3` and `python-dotenv` packages
- Access to an Amazon Bedrock Knowledge Base

## Usage

Here's how to use the tool with a CrewAI agent:

```python
from crewai import Agent, Task, Crew
from crewai_tools.aws.bedrock.knowledge_base.retriever_tool import BedrockKBRetrieverTool

# Initialize the tool
kb_tool = BedrockKBRetrieverTool(
    knowledge_base_id="your-kb-id",
    number_of_results=5
)

# Create a CrewAI agent that uses the tool
researcher = Agent(
    role='Knowledge Base Researcher',
    goal='Find information about company policies',
    backstory='I am a researcher specialized in retrieving and analyzing company documentation.',
    tools=[kb_tool],
    verbose=True
)

# Create a task for the agent
research_task = Task(
    description="Find our company's remote work policy and summarize the key points.",
    agent=researcher
)

# Create a crew with the agent
crew = Crew(
    agents=[researcher],
    tasks=[research_task],
    verbose=True
)

# Run the crew
result = crew.kickoff()
print(result)
```
|
||||||
|
|
||||||
|
## Tool Arguments

| Argument | Type | Required | Default | Description |
|----------|------|----------|---------|-------------|
| knowledge_base_id | str | Yes | None | The unique identifier of the knowledge base (1-10 alphanumeric characters) |
| number_of_results | int | No | 5 | Maximum number of results to return |
| retrieval_configuration | dict | No | None | Custom configurations for the knowledge base query |
| guardrail_configuration | dict | No | None | Content filtering settings |
| next_token | str | No | None | Token for pagination |
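
`next_token` comes back in the tool's JSON output as `nextToken`; feeding it into a second tool instance retrieves the next page. A hedged sketch using only the arguments above (it assumes `run` forwards keyword arguments to the tool's `_run`, and the knowledge base ID is a placeholder):

```python
import json

kb_tool = BedrockKBRetrieverTool(knowledge_base_id="your-kb-id", number_of_results=5)
first_page = json.loads(kb_tool.run(query="security policies"))

# If more results are available, request the next page with the returned token.
if "nextToken" in first_page:
    next_page_tool = BedrockKBRetrieverTool(
        knowledge_base_id="your-kb-id",
        number_of_results=5,
        next_token=first_page["nextToken"],
    )
    second_page = json.loads(next_page_tool.run(query="security policies"))
```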
|
||||||
|
|
||||||
|
## Environment Variables

```bash
BEDROCK_KB_ID=your-knowledge-base-id   # Alternative to passing knowledge_base_id
AWS_REGION=your-aws-region             # Defaults to us-east-1
AWS_ACCESS_KEY_ID=your-access-key      # Required for AWS authentication
AWS_SECRET_ACCESS_KEY=your-secret-key  # Required for AWS authentication
```
|
||||||
|
|
||||||
|
## Response Format

The tool returns results in JSON format:

```json
{
  "results": [
    {
      "content": "Retrieved text content",
      "content_type": "text",
      "source_type": "S3",
      "source_uri": "s3://bucket/document.pdf",
      "score": 0.95,
      "metadata": {
        "additional": "metadata"
      }
    }
  ],
  "nextToken": "pagination-token",
  "guardrailAction": "NONE"
}
```
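
Because the tool returns this JSON as a string, the calling code can parse it directly. A small sketch of consuming the fields shown above (reusing `kb_tool` from the Usage section, and assuming `run` forwards keyword arguments to `_run`):

```python
import json

raw = kb_tool.run(query="vacation policy")
data = json.loads(raw)

for result in data.get("results", []):
    print(f"{result['source_type']}: {result['source_uri']} (score={result.get('score')})")

if data.get("guardrailAction", "NONE") != "NONE":
    print("The configured guardrail intervened on this response.")
```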
|
||||||
|
|
||||||
|
## Advanced Usage

### Custom Retrieval Configuration

```python
kb_tool = BedrockKBRetrieverTool(
    knowledge_base_id="your-kb-id",
    retrieval_configuration={
        "vectorSearchConfiguration": {
            "numberOfResults": 10,
            "overrideSearchType": "HYBRID"
        }
    }
)

policy_expert = Agent(
    role='Policy Expert',
    goal='Analyze company policies in detail',
    backstory='I am an expert in corporate policy analysis with deep knowledge of regulatory requirements.',
    tools=[kb_tool]
)
```
|
||||||
|
|
||||||
|
## Supported Data Sources

- Amazon S3
- Confluence
- Salesforce
- SharePoint
- Web pages
- Custom document locations
- Amazon Kendra
- SQL databases
|
||||||
|
|
||||||
|
## Use Cases

### Enterprise Knowledge Integration
- Enable CrewAI agents to access your organization's proprietary knowledge without exposing sensitive data
- Allow agents to make decisions based on your company's specific policies, procedures, and documentation
- Create agents that can answer questions based on your internal documentation while maintaining data security

### Specialized Domain Knowledge
- Connect CrewAI agents to domain-specific knowledge bases (legal, medical, technical) without retraining models
- Leverage existing knowledge repositories that are already maintained in your AWS environment
- Combine CrewAI's reasoning with domain-specific information from your knowledge bases

### Data-Driven Decision Making
- Ground CrewAI agent responses in your actual company data rather than general knowledge
- Ensure agents provide recommendations based on your specific business context and documentation
- Reduce hallucinations by retrieving factual information from your knowledge bases

### Scalable Information Access
- Access terabytes of organizational knowledge without embedding it all into your models
- Dynamically query only the relevant information needed for specific tasks
- Leverage AWS's scalable infrastructure to handle large knowledge bases efficiently

### Compliance and Governance
- Ensure CrewAI agents provide responses that align with your company's approved documentation
- Create auditable trails of information sources used by your agents
- Maintain control over what information sources your agents can access
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
from .retriever_tool import BedrockKBRetrieverTool
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["BedrockKBRetrieverTool"]
|
||||||
@@ -0,0 +1,262 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ..exceptions import BedrockKnowledgeBaseError, BedrockValidationError
|
||||||
|
|
||||||
|
|
||||||
|
# Load environment variables from .env file
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockKBRetrieverToolInput(BaseModel):
|
||||||
|
"""Input schema for BedrockKBRetrieverTool."""
|
||||||
|
|
||||||
|
query: str = Field(
|
||||||
|
..., description="The query to retrieve information from the knowledge base"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockKBRetrieverTool(BaseTool):
|
||||||
|
name: str = "Bedrock Knowledge Base Retriever Tool"
|
||||||
|
description: str = (
|
||||||
|
"Retrieves information from an Amazon Bedrock Knowledge Base given a query"
|
||||||
|
)
|
||||||
|
args_schema: Type[BaseModel] = BedrockKBRetrieverToolInput
|
||||||
|
knowledge_base_id: Optional[str] = None
|
||||||
|
number_of_results: Optional[int] = 5
|
||||||
|
retrieval_configuration: Optional[Dict[str, Any]] = None
|
||||||
|
guardrail_configuration: Optional[Dict[str, Any]] = None
|
||||||
|
next_token: Optional[str] = None
|
||||||
|
package_dependencies: List[str] = ["boto3"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
knowledge_base_id: Optional[str] = None,
|
||||||
|
number_of_results: Optional[int] = 5,
|
||||||
|
retrieval_configuration: Optional[Dict[str, Any]] = None,
|
||||||
|
guardrail_configuration: Optional[Dict[str, Any]] = None,
|
||||||
|
next_token: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Initialize the BedrockKBRetrieverTool with knowledge base configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
knowledge_base_id (str): The unique identifier of the knowledge base to query
|
||||||
|
number_of_results (Optional[int], optional): The maximum number of results to return. Defaults to 5.
|
||||||
|
retrieval_configuration (Optional[Dict[str, Any]], optional): Configurations for the knowledge base query and retrieval process. Defaults to None.
|
||||||
|
guardrail_configuration (Optional[Dict[str, Any]], optional): Guardrail settings. Defaults to None.
|
||||||
|
next_token (Optional[str], optional): Token for retrieving the next batch of results. Defaults to None.
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
# Get knowledge_base_id from environment variable if not provided
|
||||||
|
self.knowledge_base_id = knowledge_base_id or os.getenv("BEDROCK_KB_ID")
|
||||||
|
self.number_of_results = number_of_results
|
||||||
|
self.guardrail_configuration = guardrail_configuration
|
||||||
|
self.next_token = next_token
|
||||||
|
|
||||||
|
# Initialize retrieval_configuration with provided parameters or use the one provided
|
||||||
|
if retrieval_configuration is None:
|
||||||
|
self.retrieval_configuration = self._build_retrieval_configuration()
|
||||||
|
else:
|
||||||
|
self.retrieval_configuration = retrieval_configuration
|
||||||
|
|
||||||
|
# Validate parameters
|
||||||
|
self._validate_parameters()
|
||||||
|
|
||||||
|
# Update the description to include the knowledge base details
|
||||||
|
self.description = f"Retrieves information from Amazon Bedrock Knowledge Base '{self.knowledge_base_id}' given a query"
|
||||||
|
|
||||||
|
def _build_retrieval_configuration(self) -> Dict[str, Any]:
|
||||||
|
"""Build the retrieval configuration based on provided parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: The constructed retrieval configuration
|
||||||
|
"""
|
||||||
|
vector_search_config = {}
|
||||||
|
|
||||||
|
# Add number of results if provided
|
||||||
|
if self.number_of_results is not None:
|
||||||
|
vector_search_config["numberOfResults"] = self.number_of_results
|
||||||
|
|
||||||
|
return {"vectorSearchConfiguration": vector_search_config}
|
||||||
|
|
||||||
|
def _validate_parameters(self):
|
||||||
|
"""Validate the parameters according to AWS API requirements."""
|
||||||
|
try:
|
||||||
|
# Validate knowledge_base_id
|
||||||
|
if not self.knowledge_base_id:
|
||||||
|
raise BedrockValidationError("knowledge_base_id cannot be empty")
|
||||||
|
if not isinstance(self.knowledge_base_id, str):
|
||||||
|
raise BedrockValidationError("knowledge_base_id must be a string")
|
||||||
|
if len(self.knowledge_base_id) > 10:
|
||||||
|
raise BedrockValidationError(
|
||||||
|
"knowledge_base_id must be 10 characters or less"
|
||||||
|
)
|
||||||
|
if not all(c.isalnum() for c in self.knowledge_base_id):
|
||||||
|
raise BedrockValidationError(
|
||||||
|
"knowledge_base_id must contain only alphanumeric characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate next_token if provided
|
||||||
|
if self.next_token:
|
||||||
|
if not isinstance(self.next_token, str):
|
||||||
|
raise BedrockValidationError("next_token must be a string")
|
||||||
|
if len(self.next_token) < 1 or len(self.next_token) > 2048:
|
||||||
|
raise BedrockValidationError(
|
||||||
|
"next_token must be between 1 and 2048 characters"
|
||||||
|
)
|
||||||
|
if " " in self.next_token:
|
||||||
|
raise BedrockValidationError("next_token cannot contain spaces")
|
||||||
|
|
||||||
|
# Validate number_of_results if provided
|
||||||
|
if self.number_of_results is not None:
|
||||||
|
if not isinstance(self.number_of_results, int):
|
||||||
|
raise BedrockValidationError("number_of_results must be an integer")
|
||||||
|
if self.number_of_results < 1:
|
||||||
|
raise BedrockValidationError(
|
||||||
|
"number_of_results must be greater than 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
except BedrockValidationError as e:
|
||||||
|
raise BedrockValidationError(f"Parameter validation failed: {e!s}")
|
||||||
|
|
||||||
|
def _process_retrieval_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Process a single retrieval result from Bedrock Knowledge Base.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result (Dict[str, Any]): Raw result from Bedrock Knowledge Base
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Processed result with standardized format
|
||||||
|
"""
|
||||||
|
# Extract content
|
||||||
|
content_obj = result.get("content", {})
|
||||||
|
content = content_obj.get("text", "")
|
||||||
|
content_type = content_obj.get("type", "text")
|
||||||
|
|
||||||
|
# Extract location information
|
||||||
|
location = result.get("location", {})
|
||||||
|
location_type = location.get("type", "unknown")
|
||||||
|
source_uri = None
|
||||||
|
|
||||||
|
# Map for location types and their URI fields
|
||||||
|
location_mapping = {
|
||||||
|
"s3Location": {"field": "uri", "type": "S3"},
|
||||||
|
"confluenceLocation": {"field": "url", "type": "Confluence"},
|
||||||
|
"salesforceLocation": {"field": "url", "type": "Salesforce"},
|
||||||
|
"sharePointLocation": {"field": "url", "type": "SharePoint"},
|
||||||
|
"webLocation": {"field": "url", "type": "Web"},
|
||||||
|
"customDocumentLocation": {"field": "id", "type": "CustomDocument"},
|
||||||
|
"kendraDocumentLocation": {"field": "uri", "type": "KendraDocument"},
|
||||||
|
"sqlLocation": {"field": "query", "type": "SQL"},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Extract the URI based on location type
|
||||||
|
for loc_key, config in location_mapping.items():
|
||||||
|
if loc_key in location:
|
||||||
|
source_uri = location[loc_key].get(config["field"])
|
||||||
|
if not location_type or location_type == "unknown":
|
||||||
|
location_type = config["type"]
|
||||||
|
break
|
||||||
|
|
||||||
|
# Create result object
|
||||||
|
result_object = {
|
||||||
|
"content": content,
|
||||||
|
"content_type": content_type,
|
||||||
|
"source_type": location_type,
|
||||||
|
"source_uri": source_uri,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional fields if available
|
||||||
|
if "score" in result:
|
||||||
|
result_object["score"] = result["score"]
|
||||||
|
|
||||||
|
if "metadata" in result:
|
||||||
|
result_object["metadata"] = result["metadata"]
|
||||||
|
|
||||||
|
# Handle byte content if present
|
||||||
|
if "byteContent" in content_obj:
|
||||||
|
result_object["byte_content"] = content_obj["byteContent"]
|
||||||
|
|
||||||
|
# Handle row content if present
|
||||||
|
if "row" in content_obj:
|
||||||
|
result_object["row_content"] = content_obj["row"]
|
||||||
|
|
||||||
|
return result_object
|
||||||
|
|
||||||
|
def _run(self, query: str) -> str:
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("`boto3` package not found, please run `uv add boto3`")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize the Bedrock Agent Runtime client
|
||||||
|
bedrock_agent_runtime = boto3.client(
|
||||||
|
"bedrock-agent-runtime",
|
||||||
|
region_name=os.getenv(
|
||||||
|
"AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-east-1")
|
||||||
|
),
|
||||||
|
# AWS SDK will automatically use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from environment
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare the request parameters
|
||||||
|
retrieve_params = {
|
||||||
|
"knowledgeBaseId": self.knowledge_base_id,
|
||||||
|
"retrievalQuery": {"text": query},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional parameters if provided
|
||||||
|
if self.retrieval_configuration:
|
||||||
|
retrieve_params["retrievalConfiguration"] = self.retrieval_configuration
|
||||||
|
|
||||||
|
if self.guardrail_configuration:
|
||||||
|
retrieve_params["guardrailConfiguration"] = self.guardrail_configuration
|
||||||
|
|
||||||
|
if self.next_token:
|
||||||
|
retrieve_params["nextToken"] = self.next_token
|
||||||
|
|
||||||
|
# Make the retrieve API call
|
||||||
|
response = bedrock_agent_runtime.retrieve(**retrieve_params)
|
||||||
|
|
||||||
|
# Process the response
|
||||||
|
results = []
|
||||||
|
for result in response.get("retrievalResults", []):
|
||||||
|
processed_result = self._process_retrieval_result(result)
|
||||||
|
results.append(processed_result)
|
||||||
|
|
||||||
|
# Build the response object
|
||||||
|
response_object = {}
|
||||||
|
if results:
|
||||||
|
response_object["results"] = results
|
||||||
|
else:
|
||||||
|
response_object["message"] = "No results found for the given query."
|
||||||
|
|
||||||
|
if "nextToken" in response:
|
||||||
|
response_object["nextToken"] = response["nextToken"]
|
||||||
|
|
||||||
|
if "guardrailAction" in response:
|
||||||
|
response_object["guardrailAction"] = response["guardrailAction"]
|
||||||
|
|
||||||
|
# Return the results as a JSON string
|
||||||
|
return json.dumps(response_object, indent=2)
|
||||||
|
|
||||||
|
except ClientError as e:
|
||||||
|
error_code = "Unknown"
|
||||||
|
error_message = str(e)
|
||||||
|
|
||||||
|
# Try to extract error code if available
|
||||||
|
if hasattr(e, "response") and "Error" in e.response:
|
||||||
|
error_code = e.response["Error"].get("Code", "Unknown")
|
||||||
|
error_message = e.response["Error"].get("Message", str(e))
|
||||||
|
|
||||||
|
raise BedrockKnowledgeBaseError(f"Error ({error_code}): {error_message}")
|
||||||
|
except Exception as e:
|
||||||
|
raise BedrockKnowledgeBaseError(f"Unexpected error: {e!s}")
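
Since `_run` wraps AWS failures in `BedrockKnowledgeBaseError` and constructor validation raises `BedrockValidationError`, direct callers can handle both explicitly. A hedged sketch (the knowledge base ID and query are placeholders, and `run` is assumed to forward keyword arguments to `_run`):

```python
# Illustrative sketch: invoking the retriever directly and handling its exceptions.
import json

from crewai_tools.aws.bedrock.exceptions import (
    BedrockKnowledgeBaseError,
    BedrockValidationError,
)
from crewai_tools.aws.bedrock.knowledge_base.retriever_tool import BedrockKBRetrieverTool

try:
    kb_tool = BedrockKBRetrieverTool(knowledge_base_id="ABC123XYZ0")  # placeholder ID
    payload = json.loads(kb_tool.run(query="What is the remote work policy?"))
    for item in payload.get("results", []):
        print(item["source_uri"], item.get("score"))
except BedrockValidationError as e:
    print(f"Bad tool configuration: {e}")
except BedrockKnowledgeBaseError as e:
    print(f"Retrieval failed: {e}")
```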
|
||||||
52
lib/crewai-tools/src/crewai_tools/aws/s3/README.md
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# AWS S3 Tools

## Description

These tools provide a way to interact with Amazon S3, a cloud storage service.

## Installation

Install the crewai_tools package:

```shell
pip install 'crewai[tools]'
```

## AWS Connectivity

The tools use `boto3` to connect to AWS S3.
You can configure your environment to use AWS IAM roles; see the [AWS IAM Roles documentation](https://docs.aws.amazon.com/sdk-for-python/v1/developer-guide/iam-roles.html#creating-an-iam-role).

Set the following environment variables:

- `CREW_AWS_REGION`
- `CREW_AWS_ACCESS_KEY_ID`
- `CREW_AWS_SEC_ACCESS_KEY`

## Usage

To use the AWS S3 tools in your CrewAI agents, import the necessary tools and include them in your agent's configuration:

```python
from crewai_tools.aws.s3 import S3ReaderTool, S3WriterTool

# For reading from S3
@agent
def file_retriever(self) -> Agent:
    return Agent(
        config=self.agents_config['file_retriever'],
        verbose=True,
        tools=[S3ReaderTool()]
    )

# For writing to S3
@agent
def file_uploader(self) -> Agent:
    return Agent(
        config=self.agents_config['file_uploader'],
        verbose=True,
        tools=[S3WriterTool()]
    )
```

These tools can be used to read from and write to S3 buckets within your CrewAI workflows. Make sure you have properly configured your AWS credentials as mentioned in the AWS Connectivity section above.
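
Outside of an agent, the tools can also be exercised directly, which is handy for smoke-testing credentials. A minimal sketch (bucket and key names are placeholders, and `run` is assumed to forward keyword arguments to each tool's `_run`):

```python
from crewai_tools.aws.s3 import S3ReaderTool, S3WriterTool

writer = S3WriterTool()
print(writer.run(file_path="s3://my-bucket/notes/hello.txt", content="hello from crewai"))

reader = S3ReaderTool()
print(reader.run(file_path="s3://my-bucket/notes/hello.txt"))
```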
|
||||||
2
lib/crewai-tools/src/crewai_tools/aws/s3/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .reader_tool import S3ReaderTool
|
||||||
|
from .writer_tool import S3WriterTool
|
||||||
49
lib/crewai-tools/src/crewai_tools/aws/s3/reader_tool.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import os
|
||||||
|
from typing import List, Type
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class S3ReaderToolInput(BaseModel):
|
||||||
|
"""Input schema for S3ReaderTool."""
|
||||||
|
|
||||||
|
file_path: str = Field(
|
||||||
|
..., description="S3 file path (e.g., 's3://bucket-name/file-name')"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class S3ReaderTool(BaseTool):
|
||||||
|
name: str = "S3 Reader Tool"
|
||||||
|
description: str = "Reads a file from Amazon S3 given an S3 file path"
|
||||||
|
args_schema: Type[BaseModel] = S3ReaderToolInput
|
||||||
|
package_dependencies: List[str] = ["boto3"]
|
||||||
|
|
||||||
|
def _run(self, file_path: str) -> str:
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("`boto3` package not found, please run `uv add boto3`")
|
||||||
|
|
||||||
|
try:
|
||||||
|
bucket_name, object_key = self._parse_s3_path(file_path)
|
||||||
|
|
||||||
|
s3 = boto3.client(
|
||||||
|
"s3",
|
||||||
|
region_name=os.getenv("CREW_AWS_REGION", "us-east-1"),
|
||||||
|
aws_access_key_id=os.getenv("CREW_AWS_ACCESS_KEY_ID"),
|
||||||
|
aws_secret_access_key=os.getenv("CREW_AWS_SEC_ACCESS_KEY"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read file content from S3
|
||||||
|
response = s3.get_object(Bucket=bucket_name, Key=object_key)
|
||||||
|
file_content = response["Body"].read().decode("utf-8")
|
||||||
|
|
||||||
|
return file_content
|
||||||
|
except ClientError as e:
|
||||||
|
return f"Error reading file from S3: {e!s}"
|
||||||
|
|
||||||
|
def _parse_s3_path(self, file_path: str) -> tuple:
|
||||||
|
parts = file_path.replace("s3://", "").split("/", 1)
|
||||||
|
return parts[0], parts[1]
|
||||||
49
lib/crewai-tools/src/crewai_tools/aws/s3/writer_tool.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import os
|
||||||
|
from typing import List, Type
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class S3WriterToolInput(BaseModel):
|
||||||
|
"""Input schema for S3WriterTool."""
|
||||||
|
|
||||||
|
file_path: str = Field(
|
||||||
|
..., description="S3 file path (e.g., 's3://bucket-name/file-name')"
|
||||||
|
)
|
||||||
|
content: str = Field(..., description="Content to write to the file")
|
||||||
|
|
||||||
|
|
||||||
|
class S3WriterTool(BaseTool):
|
||||||
|
name: str = "S3 Writer Tool"
|
||||||
|
description: str = "Writes content to a file in Amazon S3 given an S3 file path"
|
||||||
|
args_schema: Type[BaseModel] = S3WriterToolInput
|
||||||
|
package_dependencies: List[str] = ["boto3"]
|
||||||
|
|
||||||
|
def _run(self, file_path: str, content: str) -> str:
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("`boto3` package not found, please run `uv add boto3`")
|
||||||
|
|
||||||
|
try:
|
||||||
|
bucket_name, object_key = self._parse_s3_path(file_path)
|
||||||
|
|
||||||
|
s3 = boto3.client(
|
||||||
|
"s3",
|
||||||
|
region_name=os.getenv("CREW_AWS_REGION", "us-east-1"),
|
||||||
|
aws_access_key_id=os.getenv("CREW_AWS_ACCESS_KEY_ID"),
|
||||||
|
aws_secret_access_key=os.getenv("CREW_AWS_SEC_ACCESS_KEY"),
|
||||||
|
)
|
||||||
|
|
||||||
|
s3.put_object(
|
||||||
|
Bucket=bucket_name, Key=object_key, Body=content.encode("utf-8")
|
||||||
|
)
|
||||||
|
return f"Successfully wrote content to {file_path}"
|
||||||
|
except ClientError as e:
|
||||||
|
return f"Error writing file to S3: {e!s}"
|
||||||
|
|
||||||
|
def _parse_s3_path(self, file_path: str) -> tuple:
|
||||||
|
parts = file_path.replace("s3://", "").split("/", 1)
|
||||||
|
return parts[0], parts[1]
|
||||||
131
lib/crewai-tools/src/crewai_tools/printer.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Utility for colored console output."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class Printer:
|
||||||
|
"""Handles colored console output formatting."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def print(content: str, color: Optional[str] = None) -> None:
|
||||||
|
"""Prints content with optional color formatting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed.
|
||||||
|
color: Optional color name to format the output. If provided,
|
||||||
|
must match one of the _print_* methods available in this class.
|
||||||
|
If not provided or if the color is not supported, prints without
|
||||||
|
formatting.
|
||||||
|
"""
|
||||||
|
if hasattr(Printer, f"_print_{color}"):
|
||||||
|
getattr(Printer, f"_print_{color}")(content)
|
||||||
|
else:
|
||||||
|
print(content)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_bold_purple(content: str) -> None:
|
||||||
|
"""Prints content in bold purple color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in bold purple.
|
||||||
|
"""
|
||||||
|
print(f"\033[1m\033[95m {content}\033[00m")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_bold_green(content: str) -> None:
|
||||||
|
"""Prints content in bold green color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in bold green.
|
||||||
|
"""
|
||||||
|
print(f"\033[1m\033[92m {content}\033[00m")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_purple(content: str) -> None:
|
||||||
|
"""Prints content in purple color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in purple.
|
||||||
|
"""
|
||||||
|
print(f"\033[95m {content}\033[00m")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_red(content: str) -> None:
|
||||||
|
"""Prints content in red color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in red.
|
||||||
|
"""
|
||||||
|
print(f"\033[91m {content}\033[00m")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_bold_blue(content: str) -> None:
|
||||||
|
"""Prints content in bold blue color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in bold blue.
|
||||||
|
"""
|
||||||
|
print(f"\033[1m\033[94m {content}\033[00m")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_yellow(content: str) -> None:
|
||||||
|
"""Prints content in yellow color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in yellow.
|
||||||
|
"""
|
||||||
|
print(f"\033[93m {content}\033[00m")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_bold_yellow(content: str) -> None:
|
||||||
|
"""Prints content in bold yellow color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in bold yellow.
|
||||||
|
"""
|
||||||
|
print(f"\033[1m\033[93m {content}\033[00m")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_cyan(content: str) -> None:
|
||||||
|
"""Prints content in cyan color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in cyan.
|
||||||
|
"""
|
||||||
|
print(f"\033[96m {content}\033[00m")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_bold_cyan(content: str) -> None:
|
||||||
|
"""Prints content in bold cyan color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in bold cyan.
|
||||||
|
"""
|
||||||
|
print(f"\033[1m\033[96m {content}\033[00m")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_magenta(content: str) -> None:
|
||||||
|
"""Prints content in magenta color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in magenta.
|
||||||
|
"""
|
||||||
|
print(f"\033[35m {content}\033[00m")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_bold_magenta(content: str) -> None:
|
||||||
|
"""Prints content in bold magenta color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in bold magenta.
|
||||||
|
"""
|
||||||
|
print(f"\033[1m\033[35m {content}\033[00m")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_green(content: str) -> None:
|
||||||
|
"""Prints content in green color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The string to be printed in green.
|
||||||
|
"""
|
||||||
|
print(f"\033[32m {content}\033[00m")
|
||||||
9
lib/crewai-tools/src/crewai_tools/rag/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from crewai_tools.rag.core import RAG, EmbeddingService
|
||||||
|
from crewai_tools.rag.data_types import DataType
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RAG",
|
||||||
|
"DataType",
|
||||||
|
"EmbeddingService",
|
||||||
|
]
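
The package exposes the ChromaDB-backed `RAG` adapter and the `DataType` enum defined later in this diff; a hedged sketch of wiring them together (embeddings go through LiteLLM, so a provider key such as `OPENAI_API_KEY` is assumed to be configured in the environment):

```python
from crewai_tools.rag import RAG, DataType

store = RAG(
    collection_name="product_docs",
    persist_directory="./chroma",  # omit to keep the collection in memory
    embedding_model="text-embedding-3-small",
)

# Content is loaded, chunked, embedded, and upserted into ChromaDB.
store.add("Our refund window is 30 days.", data_type=DataType.TEXT)
store.add(
    "Refunds over $500 require manager approval.",
    data_type=DataType.TEXT,
    metadata={"team": "finance"},
)
```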
|
||||||
41
lib/crewai-tools/src/crewai_tools/rag/base_loader.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from crewai_tools.rag.misc import compute_sha256
|
||||||
|
from crewai_tools.rag.source_content import SourceContent
|
||||||
|
|
||||||
|
|
||||||
|
class LoaderResult(BaseModel):
|
||||||
|
content: str = Field(description="The text content of the source")
|
||||||
|
source: str = Field(description="The source of the content", default="unknown")
|
||||||
|
metadata: Dict[str, Any] = Field(
|
||||||
|
description="The metadata of the source", default_factory=dict
|
||||||
|
)
|
||||||
|
doc_id: str = Field(description="The id of the document")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLoader(ABC):
|
||||||
|
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||||
|
self.config = config or {}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load(self, content: SourceContent, **kwargs) -> LoaderResult: ...
|
||||||
|
|
||||||
|
def generate_doc_id(
|
||||||
|
self, source_ref: str | None = None, content: str | None = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate a unique document id based on the source reference and content.
|
||||||
|
If the source reference is not provided, the content is used as the source reference.
|
||||||
|
If the content is not provided, the source reference is used as the content.
|
||||||
|
If both are provided, they are concatenated and hashed together.
|
||||||
|
|
||||||
|
Both are optional because the TEXT content type does not have a source reference. In this case, the content is used as the source reference.
|
||||||
|
"""
|
||||||
|
|
||||||
|
source_ref = source_ref or ""
|
||||||
|
content = content or ""
|
||||||
|
|
||||||
|
return compute_sha256(source_ref + content)
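
A loader only has to turn a `SourceContent` into a `LoaderResult`. A hedged sketch of a minimal custom loader built on the pieces above (it assumes `SourceContent.source_ref` holds the original reference, as it is used elsewhere in this package; the loader itself is a toy example):

```python
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class UppercaseLoader(BaseLoader):
    """Toy loader: upper-cases the raw text and derives the doc id from it."""

    def load(self, content: SourceContent, **kwargs) -> LoaderResult:
        text = content.source_ref.upper()
        return LoaderResult(
            content=text,
            source=content.source_ref,
            metadata={"loader": "uppercase"},
            doc_id=self.generate_doc_id(source_ref=content.source_ref, content=text),
        )
```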
|
||||||
20
lib/crewai-tools/src/crewai_tools/rag/chunkers/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||||
|
from crewai_tools.rag.chunkers.default_chunker import DefaultChunker
|
||||||
|
from crewai_tools.rag.chunkers.structured_chunker import (
|
||||||
|
CsvChunker,
|
||||||
|
JsonChunker,
|
||||||
|
XmlChunker,
|
||||||
|
)
|
||||||
|
from crewai_tools.rag.chunkers.text_chunker import DocxChunker, MdxChunker, TextChunker
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseChunker",
|
||||||
|
"CsvChunker",
|
||||||
|
"DefaultChunker",
|
||||||
|
"DocxChunker",
|
||||||
|
"JsonChunker",
|
||||||
|
"MdxChunker",
|
||||||
|
"TextChunker",
|
||||||
|
"XmlChunker",
|
||||||
|
]
|
||||||
181
lib/crewai-tools/src/crewai_tools/rag/chunkers/base_chunker.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
import re
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class RecursiveCharacterTextSplitter:
|
||||||
|
"""
|
||||||
|
A text splitter that recursively splits text based on a hierarchy of separators.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 4000,
|
||||||
|
chunk_overlap: int = 200,
|
||||||
|
separators: Optional[List[str]] = None,
|
||||||
|
keep_separator: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the RecursiveCharacterTextSplitter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_size: Maximum size of each chunk
|
||||||
|
chunk_overlap: Number of characters to overlap between chunks
|
||||||
|
separators: List of separators to use for splitting (in order of preference)
|
||||||
|
keep_separator: Whether to keep the separator in the split text
|
||||||
|
"""
|
||||||
|
if chunk_overlap >= chunk_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Chunk overlap ({chunk_overlap}) cannot be >= chunk size ({chunk_size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._chunk_size = chunk_size
|
||||||
|
self._chunk_overlap = chunk_overlap
|
||||||
|
self._keep_separator = keep_separator
|
||||||
|
|
||||||
|
self._separators = separators or [
|
||||||
|
"\n\n",
|
||||||
|
"\n",
|
||||||
|
" ",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
|
||||||
|
def split_text(self, text: str) -> List[str]:
|
||||||
|
return self._split_text(text, self._separators)
|
||||||
|
|
||||||
|
def _split_text(self, text: str, separators: List[str]) -> List[str]:
|
||||||
|
separator = separators[-1]
|
||||||
|
new_separators = []
|
||||||
|
|
||||||
|
for i, sep in enumerate(separators):
|
||||||
|
if sep == "":
|
||||||
|
separator = sep
|
||||||
|
break
|
||||||
|
if re.search(re.escape(sep), text):
|
||||||
|
separator = sep
|
||||||
|
new_separators = separators[i + 1 :]
|
||||||
|
break
|
||||||
|
|
||||||
|
splits = self._split_text_with_separator(text, separator)
|
||||||
|
|
||||||
|
good_splits = []
|
||||||
|
|
||||||
|
for split in splits:
|
||||||
|
if len(split) < self._chunk_size:
|
||||||
|
good_splits.append(split)
|
||||||
|
else:
|
||||||
|
if new_separators:
|
||||||
|
other_info = self._split_text(split, new_separators)
|
||||||
|
good_splits.extend(other_info)
|
||||||
|
else:
|
||||||
|
good_splits.extend(self._split_by_characters(split))
|
||||||
|
|
||||||
|
return self._merge_splits(good_splits, separator)
|
||||||
|
|
||||||
|
def _split_text_with_separator(self, text: str, separator: str) -> List[str]:
|
||||||
|
if separator == "":
|
||||||
|
return list(text)
|
||||||
|
|
||||||
|
if self._keep_separator and separator in text:
|
||||||
|
parts = text.split(separator)
|
||||||
|
splits = []
|
||||||
|
|
||||||
|
for i, part in enumerate(parts):
|
||||||
|
if i == 0:
|
||||||
|
splits.append(part)
|
||||||
|
elif i == len(parts) - 1:
|
||||||
|
if part:
|
||||||
|
splits.append(separator + part)
|
||||||
|
else:
|
||||||
|
if part:
|
||||||
|
splits.append(separator + part)
|
||||||
|
else:
|
||||||
|
if splits:
|
||||||
|
splits[-1] += separator
|
||||||
|
|
||||||
|
return [s for s in splits if s]
|
||||||
|
return text.split(separator)
|
||||||
|
|
||||||
|
def _split_by_characters(self, text: str) -> List[str]:
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, len(text), self._chunk_size):
|
||||||
|
chunks.append(text[i : i + self._chunk_size])
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
|
||||||
|
"""Merge splits into chunks with proper overlap."""
|
||||||
|
docs = []
|
||||||
|
current_doc = []
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
for split in splits:
|
||||||
|
split_len = len(split)
|
||||||
|
|
||||||
|
if total + split_len > self._chunk_size and current_doc:
|
||||||
|
if separator == "":
|
||||||
|
doc = "".join(current_doc)
|
||||||
|
else:
|
||||||
|
if self._keep_separator and separator == " ":
|
||||||
|
doc = "".join(current_doc)
|
||||||
|
else:
|
||||||
|
doc = separator.join(current_doc)
|
||||||
|
|
||||||
|
if doc:
|
||||||
|
docs.append(doc)
|
||||||
|
|
||||||
|
# Handle overlap by keeping some of the previous content
|
||||||
|
while total > self._chunk_overlap and len(current_doc) > 1:
|
||||||
|
removed = current_doc.pop(0)
|
||||||
|
total -= len(removed)
|
||||||
|
if separator != "":
|
||||||
|
total -= len(separator)
|
||||||
|
|
||||||
|
current_doc.append(split)
|
||||||
|
total += split_len
|
||||||
|
if separator != "" and len(current_doc) > 1:
|
||||||
|
total += len(separator)
|
||||||
|
|
||||||
|
if current_doc:
|
||||||
|
if separator == "":
|
||||||
|
doc = "".join(current_doc)
|
||||||
|
else:
|
||||||
|
if self._keep_separator and separator == " ":
|
||||||
|
doc = "".join(current_doc)
|
||||||
|
else:
|
||||||
|
doc = separator.join(current_doc)
|
||||||
|
|
||||||
|
if doc:
|
||||||
|
docs.append(doc)
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
class BaseChunker:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 1000,
|
||||||
|
chunk_overlap: int = 200,
|
||||||
|
separators: Optional[List[str]] = None,
|
||||||
|
keep_separator: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the Chunker
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_size: Maximum size of each chunk
|
||||||
|
chunk_overlap: Number of characters to overlap between chunks
|
||||||
|
separators: List of separators to use for splitting
|
||||||
|
keep_separator: Whether to keep separators in the chunks
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
separators=separators,
|
||||||
|
keep_separator=keep_separator,
|
||||||
|
)
|
||||||
|
|
||||||
|
def chunk(self, text: str) -> List[str]:
|
||||||
|
if not text or not text.strip():
|
||||||
|
return []
|
||||||
|
|
||||||
|
return self._splitter.split_text(text)
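
The splitter walks the separator hierarchy and then merges the pieces back together with the configured overlap. A small sketch of the behaviour (the sizes are chosen only to make the overlap visible):

```python
from crewai_tools.rag.chunkers.base_chunker import BaseChunker

chunker = BaseChunker(chunk_size=80, chunk_overlap=20)
text = (
    "Paragraph one explains the refund policy in detail.\n\n"
    "Paragraph two explains the escalation path.\n\n"
    "Paragraph three lists the contact addresses."
)

# Each chunk stays under 80 characters; adjacent chunks share up to 20 characters.
for i, chunk in enumerate(chunker.chunk(text)):
    print(i, len(chunk), repr(chunk[:40]))
```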
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultChunker(BaseChunker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 2000,
|
||||||
|
chunk_overlap: int = 20,
|
||||||
|
separators: Optional[List[str]] = None,
|
||||||
|
keep_separator: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||||
|
|
||||||
|
|
||||||
|
class CsvChunker(BaseChunker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 1200,
|
||||||
|
chunk_overlap: int = 100,
|
||||||
|
separators: Optional[List[str]] = None,
|
||||||
|
keep_separator: bool = True,
|
||||||
|
):
|
||||||
|
if separators is None:
|
||||||
|
separators = [
|
||||||
|
"\nRow ", # Row boundaries (from CSVLoader format)
|
||||||
|
"\n", # Line breaks
|
||||||
|
" | ", # Column separators
|
||||||
|
", ", # Comma separators
|
||||||
|
" ", # Word breaks
|
||||||
|
"", # Character level
|
||||||
|
]
|
||||||
|
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||||
|
|
||||||
|
|
||||||
|
class JsonChunker(BaseChunker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 2000,
|
||||||
|
chunk_overlap: int = 200,
|
||||||
|
separators: Optional[List[str]] = None,
|
||||||
|
keep_separator: bool = True,
|
||||||
|
):
|
||||||
|
if separators is None:
|
||||||
|
separators = [
|
||||||
|
"\n\n", # Object/array boundaries
|
||||||
|
"\n", # Line breaks
|
||||||
|
"},", # Object endings
|
||||||
|
"],", # Array endings
|
||||||
|
", ", # Property separators
|
||||||
|
": ", # Key-value separators
|
||||||
|
" ", # Word breaks
|
||||||
|
"", # Character level
|
||||||
|
]
|
||||||
|
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||||
|
|
||||||
|
|
||||||
|
class XmlChunker(BaseChunker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 2500,
|
||||||
|
chunk_overlap: int = 250,
|
||||||
|
separators: Optional[List[str]] = None,
|
||||||
|
keep_separator: bool = True,
|
||||||
|
):
|
||||||
|
if separators is None:
|
||||||
|
separators = [
|
||||||
|
"\n\n", # Element boundaries
|
||||||
|
"\n", # Line breaks
|
||||||
|
">", # Tag endings
|
||||||
|
". ", # Sentence endings (for text content)
|
||||||
|
"! ", # Exclamation endings
|
||||||
|
"? ", # Question endings
|
||||||
|
", ", # Comma separators
|
||||||
|
" ", # Word breaks
|
||||||
|
"", # Character level
|
||||||
|
]
|
||||||
|
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||||
|
|
||||||
|
|
||||||
|
class TextChunker(BaseChunker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 1500,
|
||||||
|
chunk_overlap: int = 150,
|
||||||
|
separators: Optional[List[str]] = None,
|
||||||
|
keep_separator: bool = True,
|
||||||
|
):
|
||||||
|
if separators is None:
|
||||||
|
separators = [
|
||||||
|
"\n\n\n", # Multiple line breaks (sections)
|
||||||
|
"\n\n", # Paragraph breaks
|
||||||
|
"\n", # Line breaks
|
||||||
|
". ", # Sentence endings
|
||||||
|
"! ", # Exclamation endings
|
||||||
|
"? ", # Question endings
|
||||||
|
"; ", # Semicolon breaks
|
||||||
|
", ", # Comma breaks
|
||||||
|
" ", # Word breaks
|
||||||
|
"", # Character level
|
||||||
|
]
|
||||||
|
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||||
|
|
||||||
|
|
||||||
|
class DocxChunker(BaseChunker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 2500,
|
||||||
|
chunk_overlap: int = 250,
|
||||||
|
separators: Optional[List[str]] = None,
|
||||||
|
keep_separator: bool = True,
|
||||||
|
):
|
||||||
|
if separators is None:
|
||||||
|
separators = [
|
||||||
|
"\n\n\n", # Multiple line breaks (major sections)
|
||||||
|
"\n\n", # Paragraph breaks
|
||||||
|
"\n", # Line breaks
|
||||||
|
". ", # Sentence endings
|
||||||
|
"! ", # Exclamation endings
|
||||||
|
"? ", # Question endings
|
||||||
|
"; ", # Semicolon breaks
|
||||||
|
", ", # Comma breaks
|
||||||
|
" ", # Word breaks
|
||||||
|
"", # Character level
|
||||||
|
]
|
||||||
|
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||||
|
|
||||||
|
|
||||||
|
class MdxChunker(BaseChunker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 3000,
|
||||||
|
chunk_overlap: int = 300,
|
||||||
|
separators: Optional[List[str]] = None,
|
||||||
|
keep_separator: bool = True,
|
||||||
|
):
|
||||||
|
if separators is None:
|
||||||
|
separators = [
|
||||||
|
"\n## ", # H2 headers (major sections)
|
||||||
|
"\n### ", # H3 headers (subsections)
|
||||||
|
"\n#### ", # H4 headers (sub-subsections)
|
||||||
|
"\n\n", # Paragraph breaks
|
||||||
|
"\n```", # Code block boundaries
|
||||||
|
"\n", # Line breaks
|
||||||
|
". ", # Sentence endings
|
||||||
|
"! ", # Exclamation endings
|
||||||
|
"? ", # Question endings
|
||||||
|
"; ", # Semicolon breaks
|
||||||
|
", ", # Comma breaks
|
||||||
|
" ", # Word breaks
|
||||||
|
"", # Character level
|
||||||
|
]
|
||||||
|
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||||
|
|
||||||
|
|
||||||
|
class WebsiteChunker(BaseChunker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 2500,
|
||||||
|
chunk_overlap: int = 250,
|
||||||
|
separators: Optional[List[str]] = None,
|
||||||
|
keep_separator: bool = True,
|
||||||
|
):
|
||||||
|
if separators is None:
|
||||||
|
separators = [
|
||||||
|
"\n\n\n", # Major section breaks
|
||||||
|
"\n\n", # Paragraph breaks
|
||||||
|
"\n", # Line breaks
|
||||||
|
". ", # Sentence endings
|
||||||
|
"! ", # Exclamation endings
|
||||||
|
"? ", # Question endings
|
||||||
|
"; ", # Semicolon breaks
|
||||||
|
", ", # Comma breaks
|
||||||
|
" ", # Word breaks
|
||||||
|
"", # Character level
|
||||||
|
]
|
||||||
|
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||||
252
lib/crewai-tools/src/crewai_tools/rag/core.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
import litellm
|
||||||
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
|
from crewai_tools.rag.base_loader import BaseLoader
|
||||||
|
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||||
|
from crewai_tools.rag.data_types import DataType
|
||||||
|
from crewai_tools.rag.misc import compute_sha256
|
||||||
|
from crewai_tools.rag.source_content import SourceContent
|
||||||
|
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingService:
|
||||||
|
def __init__(self, model: str = "text-embedding-3-small", **kwargs):
|
||||||
|
self.model = model
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def embed_text(self, text: str) -> List[float]:
|
||||||
|
try:
|
||||||
|
response = litellm.embedding(model=self.model, input=[text], **self.kwargs)
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating embedding: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = litellm.embedding(model=self.model, input=texts, **self.kwargs)
|
||||||
|
return [data["embedding"] for data in response.data]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating batch embeddings: {e}")
|
||||||
|
raise
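
`EmbeddingService` is a thin wrapper over `litellm.embedding`, so any model/provider pair LiteLLM supports can be dropped in. A hedged sketch (it assumes the matching provider key, e.g. `OPENAI_API_KEY` for the model below, is set in the environment, since LiteLLM reads credentials from there):

```python
from crewai_tools.rag import EmbeddingService

service = EmbeddingService(model="text-embedding-3-small")

vectors = service.embed_batch(["refund policy", "escalation path"])
print(len(vectors), len(vectors[0]))  # 2 embeddings, one per input string

single = service.embed_text("contact addresses")
print(len(single))
```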
|
||||||
|
|
||||||
|
|
||||||
|
class Document(BaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
|
content: str
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
data_type: DataType = DataType.TEXT
|
||||||
|
source: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RAG(Adapter):
|
||||||
|
collection_name: str = "crewai_knowledge_base"
|
||||||
|
persist_directory: Optional[str] = None
|
||||||
|
embedding_model: str = "text-embedding-3-large"
|
||||||
|
summarize: bool = False
|
||||||
|
top_k: int = 5
|
||||||
|
embedding_config: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
_client: Any = PrivateAttr()
|
||||||
|
_collection: Any = PrivateAttr()
|
||||||
|
_embedding_service: EmbeddingService = PrivateAttr()
|
||||||
|
|
||||||
|
def model_post_init(self, __context: Any) -> None:
|
||||||
|
try:
|
||||||
|
if self.persist_directory:
|
||||||
|
self._client = chromadb.PersistentClient(path=self.persist_directory)
|
||||||
|
else:
|
||||||
|
self._client = chromadb.Client()
|
||||||
|
|
||||||
|
self._collection = self._client.get_or_create_collection(
|
||||||
|
name=self.collection_name,
|
||||||
|
metadata={
|
||||||
|
"hnsw:space": "cosine",
|
||||||
|
"description": "CrewAI Knowledge Base",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._embedding_service = EmbeddingService(
|
||||||
|
model=self.embedding_model, **self.embedding_config
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize ChromaDB: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
super().model_post_init(__context)
|
||||||
|
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
content: str | Path,
|
||||||
|
data_type: Optional[Union[str, DataType]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
loader: Optional[BaseLoader] = None,
|
||||||
|
chunker: Optional[BaseChunker] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
source_content = SourceContent(content)
|
||||||
|
|
||||||
|
data_type = self._get_data_type(data_type=data_type, content=source_content)
|
||||||
|
|
||||||
|
if not loader:
|
||||||
|
loader = data_type.get_loader()
|
||||||
|
|
||||||
|
if not chunker:
|
||||||
|
chunker = data_type.get_chunker()
|
||||||
|
|
||||||
|
loader_result = loader.load(source_content)
|
||||||
|
doc_id = loader_result.doc_id
|
||||||
|
|
||||||
|
existing_doc = self._collection.get(
|
||||||
|
where={"source": source_content.source_ref}, limit=1
|
||||||
|
)
|
||||||
|
existing_doc_id = (
|
||||||
|
existing_doc and existing_doc["metadatas"][0]["doc_id"]
|
||||||
|
if existing_doc["metadatas"]
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_doc_id == doc_id:
|
||||||
|
logger.warning(
|
||||||
|
f"Document with source {loader_result.source} already exists"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Document with same source ref does exists but the content has changed, deleting the oldest reference
|
||||||
|
if existing_doc_id and existing_doc_id != loader_result.doc_id:
|
||||||
|
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
|
||||||
|
self._collection.delete(where={"doc_id": existing_doc_id})
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
|
||||||
|
chunks = chunker.chunk(loader_result.content)
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
doc_metadata = (metadata or {}).copy()
|
||||||
|
doc_metadata["chunk_index"] = i
|
||||||
|
documents.append(
|
||||||
|
Document(
|
||||||
|
id=compute_sha256(chunk),
|
||||||
|
content=chunk,
|
||||||
|
metadata=doc_metadata,
|
||||||
|
data_type=data_type,
|
||||||
|
source=loader_result.source,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
logger.warning("No documents to add")
|
||||||
|
return
|
||||||
|
|
||||||
|
contents = [doc.content for doc in documents]
|
||||||
|
try:
|
||||||
|
embeddings = self._embedding_service.embed_batch(contents)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to generate embeddings: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
ids = [doc.id for doc in documents]
|
||||||
|
metadatas = []
|
||||||
|
|
||||||
|
for doc in documents:
|
||||||
|
doc_metadata = doc.metadata.copy()
|
||||||
|
doc_metadata.update(
|
||||||
|
{
|
||||||
|
"data_type": doc.data_type.value,
|
||||||
|
"source": doc.source,
|
||||||
|
"doc_id": doc_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
metadatas.append(doc_metadata)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._collection.add(
|
||||||
|
ids=ids,
|
||||||
|
embeddings=embeddings,
|
||||||
|
documents=contents,
|
||||||
|
metadatas=metadatas,
|
||||||
|
)
|
||||||
|
logger.info(f"Added {len(documents)} documents to knowledge base")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||||
|
|
||||||
|
def query(self, question: str, where: Optional[Dict[str, Any]] = None) -> str:
|
||||||
|
try:
|
||||||
|
question_embedding = self._embedding_service.embed_text(question)
|
||||||
|
|
||||||
|
results = self._collection.query(
|
||||||
|
query_embeddings=[question_embedding],
|
||||||
|
n_results=self.top_k,
|
||||||
|
where=where,
|
||||||
|
include=["documents", "metadatas", "distances"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not results
|
||||||
|
or not results.get("documents")
|
||||||
|
or not results["documents"][0]
|
||||||
|
):
|
||||||
|
return "No relevant content found."
|
||||||
|
|
||||||
|
documents = results["documents"][0]
|
||||||
|
metadatas = results.get("metadatas", [None])[0] or []
|
||||||
|
distances = results.get("distances", [None])[0] or []
|
||||||
|
|
||||||
|
# Return sources with relevance scores
|
||||||
|
formatted_results = []
|
||||||
|
for i, doc in enumerate(documents):
|
||||||
|
metadata = metadatas[i] if i < len(metadatas) else {}
|
||||||
|
distance = distances[i] if i < len(distances) else 1.0
|
||||||
|
source = metadata.get("source", "unknown") if metadata else "unknown"
|
||||||
|
score = (
|
||||||
|
1 - distance if distance is not None else 0
|
||||||
|
) # Convert distance to similarity
|
||||||
|
formatted_results.append(
|
||||||
|
f"[Source: {source}, Relevance: {score:.3f}]\n{doc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n\n".join(formatted_results)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Query failed: {e}")
|
||||||
|
return f"Error querying knowledge base: {e}"
|
||||||
|
|
||||||
|
def delete_collection(self) -> None:
|
||||||
|
try:
|
||||||
|
self._client.delete_collection(self.collection_name)
|
||||||
|
logger.info(f"Deleted collection: {self.collection_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete collection: {e}")
|
||||||
|
|
||||||
|
def get_collection_info(self) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
count = self._collection.count()
|
||||||
|
return {
|
||||||
|
"name": self.collection_name,
|
||||||
|
"count": count,
|
||||||
|
"embedding_model": self.embedding_model,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get collection info: {e}")
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
def _get_data_type(
|
||||||
|
self, content: SourceContent, data_type: str | DataType | None = None
|
||||||
|
) -> DataType:
|
||||||
|
try:
|
||||||
|
if isinstance(data_type, str):
|
||||||
|
return DataType(data_type)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return content.data_type
|
||||||
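A minimal usage sketch for the RAG adapter defined above, assuming litellm is configured with credentials for the chosen embedding model; the paths, URL, and question are placeholders:

    from crewai_tools.rag.core import RAG

    rag = RAG(collection_name="project_docs", persist_directory="./rag_store", top_k=3)
    rag.add("https://docs.crewai.com/quickstart")        # routed to the docs-site loader by data type
    rag.add("notes.txt", metadata={"team": "platform"})  # local file, loader and chunker inferred
    print(rag.query("How do I configure an agent?"))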
161
lib/crewai-tools/src/crewai_tools/rag/data_types.py
Normal file
161
lib/crewai-tools/src/crewai_tools/rag/data_types.py
Normal file
@@ -0,0 +1,161 @@
from enum import Enum
import os
from pathlib import Path
from urllib.parse import urlparse

from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker


class DataType(str, Enum):
    PDF_FILE = "pdf_file"
    TEXT_FILE = "text_file"
    CSV = "csv"
    JSON = "json"
    XML = "xml"
    DOCX = "docx"
    MDX = "mdx"

    # Database types
    MYSQL = "mysql"
    POSTGRES = "postgres"

    # Repository types
    GITHUB = "github"
    DIRECTORY = "directory"

    # Web types
    WEBSITE = "website"
    DOCS_SITE = "docs_site"
    YOUTUBE_VIDEO = "youtube_video"
    YOUTUBE_CHANNEL = "youtube_channel"

    # Raw types
    TEXT = "text"

    def get_chunker(self) -> BaseChunker:
        from importlib import import_module

        chunkers = {
            DataType.PDF_FILE: ("text_chunker", "TextChunker"),
            DataType.TEXT_FILE: ("text_chunker", "TextChunker"),
            DataType.TEXT: ("text_chunker", "TextChunker"),
            DataType.DOCX: ("text_chunker", "DocxChunker"),
            DataType.MDX: ("text_chunker", "MdxChunker"),
            # Structured formats
            DataType.CSV: ("structured_chunker", "CsvChunker"),
            DataType.JSON: ("structured_chunker", "JsonChunker"),
            DataType.XML: ("structured_chunker", "XmlChunker"),
            DataType.WEBSITE: ("web_chunker", "WebsiteChunker"),
            DataType.DIRECTORY: ("text_chunker", "TextChunker"),
            DataType.YOUTUBE_VIDEO: ("text_chunker", "TextChunker"),
            DataType.YOUTUBE_CHANNEL: ("text_chunker", "TextChunker"),
            DataType.GITHUB: ("text_chunker", "TextChunker"),
            DataType.DOCS_SITE: ("text_chunker", "TextChunker"),
            DataType.MYSQL: ("text_chunker", "TextChunker"),
            DataType.POSTGRES: ("text_chunker", "TextChunker"),
        }

        if self not in chunkers:
            raise ValueError(f"No chunker defined for {self}")
        module_name, class_name = chunkers[self]
        module_path = f"crewai_tools.rag.chunkers.{module_name}"

        try:
            module = import_module(module_path)
            return getattr(module, class_name)()
        except Exception as e:
            raise ValueError(f"Error loading chunker for {self}: {e}")

    def get_loader(self) -> BaseLoader:
        from importlib import import_module

        loaders = {
            DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
            DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
            DataType.TEXT: ("text_loader", "TextLoader"),
            DataType.XML: ("xml_loader", "XMLLoader"),
            DataType.WEBSITE: ("webpage_loader", "WebPageLoader"),
            DataType.MDX: ("mdx_loader", "MDXLoader"),
            DataType.JSON: ("json_loader", "JSONLoader"),
            DataType.DOCX: ("docx_loader", "DOCXLoader"),
            DataType.CSV: ("csv_loader", "CSVLoader"),
            DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"),
            DataType.YOUTUBE_VIDEO: ("youtube_video_loader", "YoutubeVideoLoader"),
            DataType.YOUTUBE_CHANNEL: (
                "youtube_channel_loader",
                "YoutubeChannelLoader",
            ),
            DataType.GITHUB: ("github_loader", "GithubLoader"),
            DataType.DOCS_SITE: ("docs_site_loader", "DocsSiteLoader"),
            DataType.MYSQL: ("mysql_loader", "MySQLLoader"),
            DataType.POSTGRES: ("postgres_loader", "PostgresLoader"),
        }

        if self not in loaders:
            raise ValueError(f"No loader defined for {self}")
        module_name, class_name = loaders[self]
        module_path = f"crewai_tools.rag.loaders.{module_name}"
        try:
            module = import_module(module_path)
            return getattr(module, class_name)()
        except Exception as e:
            raise ValueError(f"Error loading loader for {self}: {e}")


class DataTypes:
    @staticmethod
    def from_content(content: str | Path | None = None) -> DataType:
        if content is None:
            return DataType.TEXT

        if isinstance(content, Path):
            content = str(content)

        is_url = False
        if isinstance(content, str):
            try:
                url = urlparse(content)
                is_url = (url.scheme and url.netloc) or url.scheme == "file"
            except Exception:
                pass

        def get_file_type(path: str) -> DataType | None:
            mapping = {
                ".pdf": DataType.PDF_FILE,
                ".csv": DataType.CSV,
                ".mdx": DataType.MDX,
                ".md": DataType.MDX,
                ".docx": DataType.DOCX,
                ".json": DataType.JSON,
                ".xml": DataType.XML,
                ".txt": DataType.TEXT_FILE,
            }
            for ext, dtype in mapping.items():
                if path.endswith(ext):
                    return dtype
            return None

        if is_url:
            dtype = get_file_type(url.path)
            if dtype:
                return dtype

            if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
                return DataType.DOCS_SITE
            if "github.com" in url.netloc:
                return DataType.GITHUB

            return DataType.WEBSITE

        if os.path.isfile(content):
            dtype = get_file_type(content)
            if dtype:
                return dtype

            if os.path.exists(content):
                return DataType.TEXT_FILE
        elif os.path.isdir(content):
            return DataType.DIRECTORY

        return DataType.TEXT
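For reference, a few illustrative calls showing how from_content routes input according to the rules above (the URLs are placeholders):

    from crewai_tools.rag.data_types import DataTypes, DataType

    DataTypes.from_content("https://github.com/crewAIInc/crewAI")  # DataType.GITHUB
    DataTypes.from_content("https://docs.example.com/guide")       # DataType.DOCS_SITE ("docs" in the host)
    DataTypes.from_content("https://example.com/report.pdf")       # DataType.PDF_FILE (extension checked first)
    DataTypes.from_content("plain text passed directly")           # DataType.TEXT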
27
lib/crewai-tools/src/crewai_tools/rag/loaders/__init__.py
Normal file
27
lib/crewai-tools/src/crewai_tools/rag/loaders/__init__.py
Normal file
@@ -0,0 +1,27 @@
from crewai_tools.rag.loaders.csv_loader import CSVLoader
from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
from crewai_tools.rag.loaders.docx_loader import DOCXLoader
from crewai_tools.rag.loaders.json_loader import JSONLoader
from crewai_tools.rag.loaders.mdx_loader import MDXLoader
from crewai_tools.rag.loaders.pdf_loader import PDFLoader
from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.loaders.xml_loader import XMLLoader
from crewai_tools.rag.loaders.youtube_channel_loader import YoutubeChannelLoader
from crewai_tools.rag.loaders.youtube_video_loader import YoutubeVideoLoader


__all__ = [
    "CSVLoader",
    "DOCXLoader",
    "DirectoryLoader",
    "JSONLoader",
    "MDXLoader",
    "PDFLoader",
    "TextFileLoader",
    "TextLoader",
    "WebPageLoader",
    "XMLLoader",
    "YoutubeChannelLoader",
    "YoutubeVideoLoader",
]
74
lib/crewai-tools/src/crewai_tools/rag/loaders/csv_loader.py
Normal file
74
lib/crewai-tools/src/crewai_tools/rag/loaders/csv_loader.py
Normal file
@@ -0,0 +1,74 @@
import csv
from io import StringIO

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class CSVLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref

        content_str = source_content.source
        if source_content.is_url():
            content_str = self._load_from_url(content_str, kwargs)
        elif source_content.path_exists():
            content_str = self._load_from_file(content_str)

        return self._parse_csv(content_str, source_ref)

    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get(
            "headers",
            {
                "Accept": "text/csv, application/csv, text/plain",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools CSVLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching CSV from URL {url}: {e!s}")

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
            return file.read()

    def _parse_csv(self, content: str, source_ref: str) -> LoaderResult:
        try:
            csv_reader = csv.DictReader(StringIO(content))

            text_parts = []
            headers = csv_reader.fieldnames

            if headers:
                text_parts.append("Headers: " + " | ".join(headers))
                text_parts.append("-" * 50)

            for row_num, row in enumerate(csv_reader, 1):
                row_text = " | ".join([f"{k}: {v}" for k, v in row.items() if v])
                text_parts.append(f"Row {row_num}: {row_text}")

            text = "\n".join(text_parts)

            metadata = {
                "format": "csv",
                "columns": headers,
                "rows": len(text_parts) - 2 if headers else 0,
            }

        except Exception as e:
            text = content
            metadata = {"format": "csv", "parse_error": str(e)}

        return LoaderResult(
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
        )
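A hedged sketch of calling the CSV loader directly; the file name is a placeholder and must exist locally for the file branch to be taken:

    from crewai_tools.rag.loaders.csv_loader import CSVLoader
    from crewai_tools.rag.source_content import SourceContent

    result = CSVLoader().load(SourceContent("customers.csv"))
    print(result.metadata["columns"])       # header row parsed by DictReader
    print(result.content.splitlines()[0])   # "Headers: ..." summary line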
@@ -0,0 +1,167 @@
import os
from pathlib import Path
from typing import List

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class DirectoryLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        """
        Load and process all files from a directory recursively.

        Args:
            source: Directory path or URL to a directory listing
            **kwargs: Additional options:
                - recursive: bool (default True) - Whether to search recursively
                - include_extensions: list - Only include files with these extensions
                - exclude_extensions: list - Exclude files with these extensions
                - max_files: int - Maximum number of files to process
        """
        source_ref = source_content.source_ref

        if source_content.is_url():
            raise ValueError(
                "URL directory loading is not supported. Please provide a local directory path."
            )

        if not os.path.exists(source_ref):
            raise FileNotFoundError(f"Directory does not exist: {source_ref}")

        if not os.path.isdir(source_ref):
            raise ValueError(f"Path is not a directory: {source_ref}")

        return self._process_directory(source_ref, kwargs)

    def _process_directory(self, dir_path: str, kwargs: dict) -> LoaderResult:
        recursive = kwargs.get("recursive", True)
        include_extensions = kwargs.get("include_extensions", None)
        exclude_extensions = kwargs.get("exclude_extensions", None)
        max_files = kwargs.get("max_files", None)

        files = self._find_files(
            dir_path, recursive, include_extensions, exclude_extensions
        )

        if max_files and len(files) > max_files:
            files = files[:max_files]

        all_contents = []
        processed_files = []
        errors = []

        for file_path in files:
            try:
                result = self._process_single_file(file_path)
                if result:
                    all_contents.append(f"=== File: {file_path} ===\n{result.content}")
                    processed_files.append(
                        {
                            "path": file_path,
                            "metadata": result.metadata,
                            "source": result.source,
                        }
                    )
            except Exception as e:
                error_msg = f"Error processing {file_path}: {e!s}"
                errors.append(error_msg)
                all_contents.append(f"=== File: {file_path} (ERROR) ===\n{error_msg}")

        combined_content = "\n\n".join(all_contents)

        metadata = {
            "format": "directory",
            "directory_path": dir_path,
            "total_files": len(files),
            "processed_files": len(processed_files),
            "errors": len(errors),
            "file_details": processed_files,
            "error_details": errors,
        }

        return LoaderResult(
            content=combined_content,
            source=dir_path,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=dir_path, content=combined_content),
        )

    def _find_files(
        self,
        dir_path: str,
        recursive: bool,
        include_ext: List[str] | None = None,
        exclude_ext: List[str] | None = None,
    ) -> List[str]:
        """Find all files in directory matching criteria."""
        files = []

        if recursive:
            for root, dirs, filenames in os.walk(dir_path):
                dirs[:] = [d for d in dirs if not d.startswith(".")]

                for filename in filenames:
                    if self._should_include_file(filename, include_ext, exclude_ext):
                        files.append(os.path.join(root, filename))
        else:
            try:
                for item in os.listdir(dir_path):
                    item_path = os.path.join(dir_path, item)
                    if os.path.isfile(item_path) and self._should_include_file(
                        item, include_ext, exclude_ext
                    ):
                        files.append(item_path)
            except PermissionError:
                pass

        return sorted(files)

    def _should_include_file(
        self,
        filename: str,
        include_ext: List[str] | None = None,
        exclude_ext: List[str] | None = None,
    ) -> bool:
        """Determine if a file should be included based on criteria."""
        if filename.startswith("."):
            return False

        _, ext = os.path.splitext(filename.lower())

        if include_ext:
            if ext not in [
                e.lower() if e.startswith(".") else f".{e.lower()}" for e in include_ext
            ]:
                return False

        if exclude_ext:
            if ext in [
                e.lower() if e.startswith(".") else f".{e.lower()}" for e in exclude_ext
            ]:
                return False

        return True

    def _process_single_file(self, file_path: str) -> LoaderResult:
        from crewai_tools.rag.data_types import DataTypes

        data_type = DataTypes.from_content(Path(file_path))

        loader = data_type.get_loader()

        result = loader.load(SourceContent(file_path))

        if result.metadata is None:
            result.metadata = {}

        result.metadata.update(
            {
                "file_path": file_path,
                "file_size": os.path.getsize(file_path),
                "data_type": str(data_type),
                "loader_type": loader.__class__.__name__,
            }
        )

        return result
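A hedged usage sketch of the directory loader's filtering options above (the directory path is illustrative):

    from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
    from crewai_tools.rag.source_content import SourceContent

    loader = DirectoryLoader()
    result = loader.load(
        SourceContent("./knowledge"),        # local directory; URLs are rejected
        recursive=True,
        include_extensions=[".md", "txt"],   # extensions accepted with or without the dot
        max_files=50,
    )
    print(result.metadata["processed_files"], "files loaded")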
@@ -0,0 +1,106 @@
"""Documentation site loader."""

from urllib.parse import urljoin, urlparse

from bs4 import BeautifulSoup
import requests

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class DocsSiteLoader(BaseLoader):
    """Loader for documentation websites."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load content from a documentation site.

        Args:
            source: Documentation site URL
            **kwargs: Additional arguments

        Returns:
            LoaderResult with documentation content
        """
        docs_url = source.source

        try:
            response = requests.get(docs_url, timeout=30)
            response.raise_for_status()
        except requests.RequestException as e:
            raise ValueError(f"Unable to fetch documentation from {docs_url}: {e}")

        soup = BeautifulSoup(response.text, "html.parser")

        for script in soup(["script", "style"]):
            script.decompose()

        title = soup.find("title")
        title_text = title.get_text(strip=True) if title else "Documentation"

        main_content = None
        for selector in [
            "main",
            "article",
            '[role="main"]',
            ".content",
            "#content",
            ".documentation",
        ]:
            main_content = soup.select_one(selector)
            if main_content:
                break

        if not main_content:
            main_content = soup.find("body")

        if not main_content:
            raise ValueError(
                f"Unable to extract content from documentation site: {docs_url}"
            )

        text_parts = [f"Title: {title_text}", ""]

        headings = main_content.find_all(["h1", "h2", "h3"])
        if headings:
            text_parts.append("Table of Contents:")
            for heading in headings[:15]:
                level = int(heading.name[1])
                indent = " " * (level - 1)
                text_parts.append(f"{indent}- {heading.get_text(strip=True)}")
            text_parts.append("")

        text = main_content.get_text(separator="\n", strip=True)
        lines = [line.strip() for line in text.split("\n") if line.strip()]
        text_parts.extend(lines)

        nav_links = []
        for nav_selector in ["nav", ".sidebar", ".toc", ".navigation"]:
            nav = soup.select_one(nav_selector)
            if nav:
                links = nav.find_all("a", href=True)
                for link in links[:20]:
                    href = link["href"]
                    if not href.startswith(("http://", "https://", "mailto:", "#")):
                        full_url = urljoin(docs_url, href)
                        nav_links.append(f"- {link.get_text(strip=True)}: {full_url}")

        if nav_links:
            text_parts.append("")
            text_parts.append("Related documentation pages:")
            text_parts.extend(nav_links[:10])

        content = "\n".join(text_parts)

        if len(content) > 100000:
            content = content[:100000] + "\n\n[Content truncated...]"

        return LoaderResult(
            content=content,
            metadata={
                "source": docs_url,
                "title": title_text,
                "domain": urlparse(docs_url).netloc,
            },
            doc_id=self.generate_doc_id(source_ref=docs_url, content=content),
        )
81
lib/crewai-tools/src/crewai_tools/rag/loaders/docx_loader.py
Normal file
81
lib/crewai-tools/src/crewai_tools/rag/loaders/docx_loader.py
Normal file
@@ -0,0 +1,81 @@
import os
import tempfile

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class DOCXLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        try:
            from docx import Document as DocxDocument
        except ImportError:
            raise ImportError(
                "python-docx is required for DOCX loading. Install with: 'uv pip install python-docx' or pip install crewai-tools[rag]"
            )

        source_ref = source_content.source_ref

        if source_content.is_url():
            temp_file = self._download_from_url(source_ref, kwargs)
            try:
                return self._load_from_file(temp_file, source_ref, DocxDocument)
            finally:
                os.unlink(temp_file)
        elif source_content.path_exists():
            return self._load_from_file(source_ref, source_ref, DocxDocument)
        else:
            raise ValueError(
                f"Source must be a valid file path or URL, got: {source_content.source}"
            )

    def _download_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get(
            "headers",
            {
                "Accept": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools DOCXLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()

            # Create temporary file to save the DOCX content
            with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as temp_file:
                temp_file.write(response.content)
                return temp_file.name
        except Exception as e:
            raise ValueError(f"Error fetching DOCX from URL {url}: {e!s}")

    def _load_from_file(
        self, file_path: str, source_ref: str, DocxDocument
    ) -> LoaderResult:
        try:
            doc = DocxDocument(file_path)

            text_parts = []
            for paragraph in doc.paragraphs:
                if paragraph.text.strip():
                    text_parts.append(paragraph.text)

            content = "\n".join(text_parts)

            metadata = {
                "format": "docx",
                "paragraphs": len(doc.paragraphs),
                "tables": len(doc.tables),
            }

            return LoaderResult(
                content=content,
                source=source_ref,
                metadata=metadata,
                doc_id=self.generate_doc_id(source_ref=source_ref, content=content),
            )

        except Exception as e:
            raise ValueError(f"Error loading DOCX file: {e!s}")
112
lib/crewai-tools/src/crewai_tools/rag/loaders/github_loader.py
Normal file
112
lib/crewai-tools/src/crewai_tools/rag/loaders/github_loader.py
Normal file
@@ -0,0 +1,112 @@
"""GitHub repository content loader."""

from github import Github, GithubException

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class GithubLoader(BaseLoader):
    """Loader for GitHub repository content."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load content from a GitHub repository.

        Args:
            source: GitHub repository URL
            **kwargs: Additional arguments including gh_token and content_types

        Returns:
            LoaderResult with repository content
        """
        metadata = kwargs.get("metadata", {})
        gh_token = metadata.get("gh_token")
        content_types = metadata.get("content_types", ["code", "repo"])

        repo_url = source.source
        if not repo_url.startswith("https://github.com/"):
            raise ValueError(f"Invalid GitHub URL: {repo_url}")

        parts = repo_url.replace("https://github.com/", "").strip("/").split("/")
        if len(parts) < 2:
            raise ValueError(f"Invalid GitHub repository URL: {repo_url}")

        repo_name = f"{parts[0]}/{parts[1]}"

        g = Github(gh_token) if gh_token else Github()

        try:
            repo = g.get_repo(repo_name)
        except GithubException as e:
            raise ValueError(f"Unable to access repository {repo_name}: {e}")

        all_content = []

        if "repo" in content_types:
            all_content.append(f"Repository: {repo.full_name}")
            all_content.append(f"Description: {repo.description or 'No description'}")
            all_content.append(f"Language: {repo.language or 'Not specified'}")
            all_content.append(f"Stars: {repo.stargazers_count}")
            all_content.append(f"Forks: {repo.forks_count}")
            all_content.append("")

        if "code" in content_types:
            try:
                readme = repo.get_readme()
                all_content.append("README:")
                all_content.append(
                    readme.decoded_content.decode("utf-8", errors="ignore")
                )
                all_content.append("")
            except GithubException:
                pass

            try:
                contents = repo.get_contents("")
                if isinstance(contents, list):
                    all_content.append("Repository structure:")
                    for content_file in contents[:20]:
                        all_content.append(
                            f"- {content_file.path} ({content_file.type})"
                        )
                    all_content.append("")
            except GithubException:
                pass

        if "pr" in content_types:
            prs = repo.get_pulls(state="open")
            pr_list = list(prs[:5])
            if pr_list:
                all_content.append("Recent Pull Requests:")
                for pr in pr_list:
                    all_content.append(f"- PR #{pr.number}: {pr.title}")
                    if pr.body:
                        body_preview = pr.body[:200].replace("\n", " ")
                        all_content.append(f" {body_preview}")
                all_content.append("")

        if "issue" in content_types:
            issues = repo.get_issues(state="open")
            issue_list = [i for i in list(issues[:10]) if not i.pull_request][:5]
            if issue_list:
                all_content.append("Recent Issues:")
                for issue in issue_list:
                    all_content.append(f"- Issue #{issue.number}: {issue.title}")
                    if issue.body:
                        body_preview = issue.body[:200].replace("\n", " ")
                        all_content.append(f" {body_preview}")
                all_content.append("")

        if not all_content:
            raise ValueError(f"No content could be loaded from repository: {repo_url}")

        content = "\n".join(all_content)
        return LoaderResult(
            content=content,
            metadata={
                "source": repo_url,
                "repo": repo_name,
                "content_types": content_types,
            },
            doc_id=self.generate_doc_id(source_ref=repo_url, content=content),
        )
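A sketch of how the GitHub loader is driven through the metadata kwarg (the token and repository URL are placeholders; network access and PyGithub are required):

    from crewai_tools.rag.loaders.github_loader import GithubLoader
    from crewai_tools.rag.source_content import SourceContent

    result = GithubLoader().load(
        SourceContent("https://github.com/crewAIInc/crewAI"),
        metadata={"gh_token": "ghp_xxx", "content_types": ["repo", "code", "issue"]},
    )
    print(result.metadata["repo"])  # "crewAIInc/crewAI"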
78
lib/crewai-tools/src/crewai_tools/rag/loaders/json_loader.py
Normal file
78
lib/crewai-tools/src/crewai_tools/rag/loaders/json_loader.py
Normal file
@@ -0,0 +1,78 @@
import json

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class JSONLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
        content = source_content.source

        if source_content.is_url():
            content = self._load_from_url(source_ref, kwargs)
        elif source_content.path_exists():
            content = self._load_from_file(source_ref)

        return self._parse_json(content, source_ref)

    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get(
            "headers",
            {
                "Accept": "application/json",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools JSONLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return (
                response.text
                if not self._is_json_response(response)
                else json.dumps(response.json(), indent=2)
            )
        except Exception as e:
            raise ValueError(f"Error fetching JSON from URL {url}: {e!s}")

    def _is_json_response(self, response) -> bool:
        try:
            response.json()
            return True
        except ValueError:
            return False

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
            return file.read()

    def _parse_json(self, content: str, source_ref: str) -> LoaderResult:
        try:
            data = json.loads(content)
            if isinstance(data, dict):
                text = "\n".join(
                    f"{k}: {json.dumps(v, indent=0)}" for k, v in data.items()
                )
            elif isinstance(data, list):
                text = "\n".join(json.dumps(item, indent=0) for item in data)
            else:
                text = json.dumps(data, indent=0)

            metadata = {
                "format": "json",
                "type": type(data).__name__,
                "size": len(data) if isinstance(data, (list, dict)) else 1,
            }
        except json.JSONDecodeError as e:
            text = content
            metadata = {"format": "json", "parse_error": str(e)}

        return LoaderResult(
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
        )
67
lib/crewai-tools/src/crewai_tools/rag/loaders/mdx_loader.py
Normal file
67
lib/crewai-tools/src/crewai_tools/rag/loaders/mdx_loader.py
Normal file
@@ -0,0 +1,67 @@
import re

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class MDXLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
        content = source_content.source

        if source_content.is_url():
            content = self._load_from_url(source_ref, kwargs)
        elif source_content.path_exists():
            content = self._load_from_file(source_ref)

        return self._parse_mdx(content, source_ref)

    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get(
            "headers",
            {
                "Accept": "text/markdown, text/x-markdown, text/plain",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools MDXLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching MDX from URL {url}: {e!s}")

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
            return file.read()

    def _parse_mdx(self, content: str, source_ref: str) -> LoaderResult:
        cleaned_content = content

        # Remove import statements
        cleaned_content = re.sub(
            r"^import\s+.*?\n", "", cleaned_content, flags=re.MULTILINE
        )

        # Remove export statements
        cleaned_content = re.sub(
            r"^export\s+.*?(?:\n|$)", "", cleaned_content, flags=re.MULTILINE
        )

        # Remove JSX tags (simple approach)
        cleaned_content = re.sub(r"<[^>]+>", "", cleaned_content)

        # Clean up extra whitespace
        cleaned_content = re.sub(r"\n\s*\n\s*\n", "\n\n", cleaned_content)
        cleaned_content = cleaned_content.strip()

        metadata = {"format": "mdx"}
        return LoaderResult(
            content=cleaned_content,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=cleaned_content),
        )
100
lib/crewai-tools/src/crewai_tools/rag/loaders/mysql_loader.py
Normal file
100
lib/crewai-tools/src/crewai_tools/rag/loaders/mysql_loader.py
Normal file
@@ -0,0 +1,100 @@
"""MySQL database loader."""

from urllib.parse import urlparse

import pymysql

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class MySQLLoader(BaseLoader):
    """Loader for MySQL database content."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load content from a MySQL database table.

        Args:
            source: SQL query (e.g., "SELECT * FROM table_name")
            **kwargs: Additional arguments including db_uri

        Returns:
            LoaderResult with database content
        """
        metadata = kwargs.get("metadata", {})
        db_uri = metadata.get("db_uri")

        if not db_uri:
            raise ValueError("Database URI is required for MySQL loader")

        query = source.source

        parsed = urlparse(db_uri)
        if parsed.scheme not in ["mysql", "mysql+pymysql"]:
            raise ValueError(f"Invalid MySQL URI scheme: {parsed.scheme}")

        connection_params = {
            "host": parsed.hostname or "localhost",
            "port": parsed.port or 3306,
            "user": parsed.username,
            "password": parsed.password,
            "database": parsed.path.lstrip("/") if parsed.path else None,
            "charset": "utf8mb4",
            "cursorclass": pymysql.cursors.DictCursor,
        }

        if not connection_params["database"]:
            raise ValueError("Database name is required in the URI")

        try:
            connection = pymysql.connect(**connection_params)
            try:
                with connection.cursor() as cursor:
                    cursor.execute(query)
                    rows = cursor.fetchall()

                    if not rows:
                        content = "No data found in the table"
                        return LoaderResult(
                            content=content,
                            metadata={"source": query, "row_count": 0},
                            doc_id=self.generate_doc_id(
                                source_ref=query, content=content
                            ),
                        )

                    text_parts = []

                    columns = list(rows[0].keys())
                    text_parts.append(f"Columns: {', '.join(columns)}")
                    text_parts.append(f"Total rows: {len(rows)}")
                    text_parts.append("")

                    for i, row in enumerate(rows, 1):
                        text_parts.append(f"Row {i}:")
                        for col, val in row.items():
                            if val is not None:
                                text_parts.append(f" {col}: {val}")
                        text_parts.append("")

                    content = "\n".join(text_parts)

                    if len(content) > 100000:
                        content = content[:100000] + "\n\n[Content truncated...]"

                    return LoaderResult(
                        content=content,
                        metadata={
                            "source": query,
                            "database": connection_params["database"],
                            "row_count": len(rows),
                            "columns": columns,
                        },
                        doc_id=self.generate_doc_id(source_ref=query, content=content),
                    )
            finally:
                connection.close()
        except pymysql.Error as e:
            raise ValueError(f"MySQL database error: {e}")
        except Exception as e:
            raise ValueError(f"Failed to load data from MySQL: {e}")
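The MySQL loader takes the SQL query as the source and the connection URI through the metadata kwarg; a minimal sketch with placeholder credentials and table names:

    from crewai_tools.rag.loaders.mysql_loader import MySQLLoader
    from crewai_tools.rag.source_content import SourceContent

    result = MySQLLoader().load(
        SourceContent("SELECT id, title FROM articles LIMIT 10"),
        metadata={"db_uri": "mysql://user:password@localhost:3306/blog"},
    )
    print(result.metadata["row_count"])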
71
lib/crewai-tools/src/crewai_tools/rag/loaders/pdf_loader.py
Normal file
71
lib/crewai-tools/src/crewai_tools/rag/loaders/pdf_loader.py
Normal file
@@ -0,0 +1,71 @@
"""PDF loader for extracting text from PDF files."""

import os
from pathlib import Path
from typing import Any

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class PDFLoader(BaseLoader):
    """Loader for PDF files."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load and extract text from a PDF file.

        Args:
            source: The source content containing the PDF file path

        Returns:
            LoaderResult with extracted text content

        Raises:
            FileNotFoundError: If the PDF file doesn't exist
            ImportError: If required PDF libraries aren't installed
        """
        try:
            import pypdf
        except ImportError:
            try:
                import PyPDF2 as pypdf
            except ImportError:
                raise ImportError(
                    "PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
                )

        file_path = source.source

        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"PDF file not found: {file_path}")

        text_content = []
        metadata: dict[str, Any] = {
            "source": str(file_path),
            "file_name": Path(file_path).name,
            "file_type": "pdf",
        }

        try:
            with open(file_path, "rb") as file:
                pdf_reader = pypdf.PdfReader(file)
                metadata["num_pages"] = len(pdf_reader.pages)

                for page_num, page in enumerate(pdf_reader.pages, 1):
                    page_text = page.extract_text()
                    if page_text.strip():
                        text_content.append(f"Page {page_num}:\n{page_text}")
        except Exception as e:
            raise ValueError(f"Error reading PDF file {file_path}: {e!s}")

        if not text_content:
            content = f"[PDF file with no extractable text: {Path(file_path).name}]"
        else:
            content = "\n\n".join(text_content)

        return LoaderResult(
            content=content,
            source=str(file_path),
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=str(file_path), content=content),
        )
100
lib/crewai-tools/src/crewai_tools/rag/loaders/postgres_loader.py
Normal file
100
lib/crewai-tools/src/crewai_tools/rag/loaders/postgres_loader.py
Normal file
@@ -0,0 +1,100 @@
"""PostgreSQL database loader."""

from urllib.parse import urlparse

import psycopg2
from psycopg2.extras import RealDictCursor

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class PostgresLoader(BaseLoader):
    """Loader for PostgreSQL database content."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load content from a PostgreSQL database table.

        Args:
            source: SQL query (e.g., "SELECT * FROM table_name")
            **kwargs: Additional arguments including db_uri

        Returns:
            LoaderResult with database content
        """
        metadata = kwargs.get("metadata", {})
        db_uri = metadata.get("db_uri")

        if not db_uri:
            raise ValueError("Database URI is required for PostgreSQL loader")

        query = source.source

        parsed = urlparse(db_uri)
        if parsed.scheme not in ["postgresql", "postgres", "postgresql+psycopg2"]:
            raise ValueError(f"Invalid PostgreSQL URI scheme: {parsed.scheme}")

        connection_params = {
            "host": parsed.hostname or "localhost",
            "port": parsed.port or 5432,
            "user": parsed.username,
            "password": parsed.password,
            "database": parsed.path.lstrip("/") if parsed.path else None,
            "cursor_factory": RealDictCursor,
        }

        if not connection_params["database"]:
            raise ValueError("Database name is required in the URI")

        try:
            connection = psycopg2.connect(**connection_params)
            try:
                with connection.cursor() as cursor:
                    cursor.execute(query)
                    rows = cursor.fetchall()

                    if not rows:
                        content = "No data found in the table"
                        return LoaderResult(
                            content=content,
                            metadata={"source": query, "row_count": 0},
                            doc_id=self.generate_doc_id(
                                source_ref=query, content=content
                            ),
                        )

                    text_parts = []

                    columns = list(rows[0].keys())
                    text_parts.append(f"Columns: {', '.join(columns)}")
                    text_parts.append(f"Total rows: {len(rows)}")
                    text_parts.append("")

                    for i, row in enumerate(rows, 1):
                        text_parts.append(f"Row {i}:")
                        for col, val in row.items():
                            if val is not None:
                                text_parts.append(f" {col}: {val}")
                        text_parts.append("")

                    content = "\n".join(text_parts)

                    if len(content) > 100000:
                        content = content[:100000] + "\n\n[Content truncated...]"

                    return LoaderResult(
                        content=content,
                        metadata={
                            "source": query,
                            "database": connection_params["database"],
                            "row_count": len(rows),
                            "columns": columns,
                        },
                        doc_id=self.generate_doc_id(source_ref=query, content=content),
                    )
            finally:
                connection.close()
        except psycopg2.Error as e:
            raise ValueError(f"PostgreSQL database error: {e}")
        except Exception as e:
            raise ValueError(f"Failed to load data from PostgreSQL: {e}")
29
lib/crewai-tools/src/crewai_tools/rag/loaders/text_loader.py
Normal file
29
lib/crewai-tools/src/crewai_tools/rag/loaders/text_loader.py
Normal file
@@ -0,0 +1,29 @@
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class TextFileLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
        if not source_content.path_exists():
            raise FileNotFoundError(
                f"The following file does not exist: {source_content.source}"
            )

        with open(source_content.source, "r", encoding="utf-8") as file:
            content = file.read()

        return LoaderResult(
            content=content,
            source=source_ref,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=content),
        )


class TextLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        return LoaderResult(
            content=source_content.source,
            source=source_content.source_ref,
            doc_id=self.generate_doc_id(content=source_content.source),
        )
@@ -0,0 +1,54 @@
import re

from bs4 import BeautifulSoup
import requests

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class WebPageLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        url = source_content.source
        headers = kwargs.get(
            "headers",
            {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
                "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
                "Accept-Language": "en-US,en;q=0.9",
            },
        )

        try:
            response = requests.get(url, timeout=15, headers=headers)
            response.encoding = response.apparent_encoding

            soup = BeautifulSoup(response.text, "html.parser")

            for script in soup(["script", "style"]):
                script.decompose()

            text = soup.get_text(" ")
            text = re.sub("[ \t]+", " ", text)
            text = re.sub("\\s+\n\\s+", "\n", text)
            text = text.strip()

            title = (
                soup.title.string.strip() if soup.title and soup.title.string else ""
            )
            metadata = {
                "url": url,
                "title": title,
                "status_code": response.status_code,
                "content_type": response.headers.get("content-type", ""),
            }

            return LoaderResult(
                content=text,
                source=url,
                metadata=metadata,
                doc_id=self.generate_doc_id(source_ref=url, content=text),
            )

        except Exception as e:
            raise ValueError(f"Error loading webpage {url}: {e!s}")
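A short usage sketch for the web page loader above (the URL is illustrative and requires network access):

    from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
    from crewai_tools.rag.source_content import SourceContent

    result = WebPageLoader().load(SourceContent("https://example.com"))
    print(result.metadata["title"], result.metadata["status_code"])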
64
lib/crewai-tools/src/crewai_tools/rag/loaders/xml_loader.py
Normal file
64
lib/crewai-tools/src/crewai_tools/rag/loaders/xml_loader.py
Normal file
@@ -0,0 +1,64 @@
import xml.etree.ElementTree as ET

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class XMLLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
        content = source_content.source

        if source_content.is_url():
            content = self._load_from_url(source_ref, kwargs)
        elif source_content.path_exists():
            content = self._load_from_file(source_ref)

        return self._parse_xml(content, source_ref)

    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get(
            "headers",
            {
                "Accept": "application/xml, text/xml, text/plain",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools XMLLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching XML from URL {url}: {e!s}")

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
            return file.read()

    def _parse_xml(self, content: str, source_ref: str) -> LoaderResult:
        try:
            if content.strip().startswith("<"):
                root = ET.fromstring(content)
            else:
                root = ET.parse(source_ref).getroot()

            text_parts = []
            for text_content in root.itertext():
                if text_content and text_content.strip():
                    text_parts.append(text_content.strip())

            text = "\n".join(text_parts)
            metadata = {"format": "xml", "root_tag": root.tag}
        except ET.ParseError as e:
            text = content
            metadata = {"format": "xml", "parse_error": str(e)}

        return LoaderResult(
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
        )
@@ -0,0 +1,162 @@
|
|||||||
|
"""YouTube channel loader for extracting content from YouTube channels."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||||
|
from crewai_tools.rag.source_content import SourceContent
|
||||||
|
|
||||||
|
|
||||||
|
class YoutubeChannelLoader(BaseLoader):
|
||||||
|
"""Loader for YouTube channels."""
|
||||||
|
|
||||||
|
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||||
|
"""Load and extract content from a YouTube channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: The source content containing the YouTube channel URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LoaderResult with channel content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If required YouTube libraries aren't installed
|
||||||
|
ValueError: If the URL is not a valid YouTube channel URL
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from pytube import Channel
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"YouTube channel support requires pytube. Install with: uv add pytube"
|
||||||
|
)
|
||||||
|
|
||||||
|
channel_url = source.source
|
||||||
|
|
||||||
|
if not any(
|
||||||
|
pattern in channel_url
|
||||||
|
for pattern in [
|
||||||
|
"youtube.com/channel/",
|
||||||
|
"youtube.com/c/",
|
||||||
|
"youtube.com/@",
|
||||||
|
"youtube.com/user/",
|
||||||
|
]
|
||||||
|
):
|
||||||
|
raise ValueError(f"Invalid YouTube channel URL: {channel_url}")
|
||||||
|
|
||||||
|
metadata: dict[str, Any] = {
|
||||||
|
"source": channel_url,
|
||||||
|
"data_type": "youtube_channel",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
channel = Channel(channel_url)
|
||||||
|
|
||||||
|
metadata["channel_name"] = channel.channel_name
|
||||||
|
metadata["channel_id"] = channel.channel_id
|
||||||
|
|
||||||
|
max_videos = kwargs.get("max_videos", 10)
|
||||||
|
video_urls = list(channel.video_urls)[:max_videos]
|
||||||
|
metadata["num_videos_loaded"] = len(video_urls)
|
||||||
|
metadata["total_videos"] = len(list(channel.video_urls))
|
||||||
|
|
||||||
|
content_parts = [
|
||||||
|
f"YouTube Channel: {channel.channel_name}",
|
||||||
|
f"Channel ID: {channel.channel_id}",
|
||||||
|
f"Total Videos: {metadata['total_videos']}",
|
||||||
|
f"Videos Loaded: {metadata['num_videos_loaded']}",
|
||||||
|
"\n--- Video Summaries ---\n",
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pytube import YouTube
|
||||||
|
from youtube_transcript_api import YouTubeTranscriptApi
|
||||||
|
|
||||||
|
for i, video_url in enumerate(video_urls, 1):
|
||||||
|
try:
|
||||||
|
video_id = self._extract_video_id(video_url)
|
||||||
|
if not video_id:
|
||||||
|
continue
|
||||||
|
yt = YouTube(video_url)
|
||||||
|
title = yt.title or f"Video {i}"
|
||||||
|
description = (
|
||||||
|
yt.description[:200] if yt.description else "No description"
|
||||||
|
)
|
||||||
|
|
||||||
|
content_parts.append(f"\n{i}. {title}")
|
||||||
|
content_parts.append(f" URL: {video_url}")
|
||||||
|
content_parts.append(f" Description: {description}...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
api = YouTubeTranscriptApi()
|
||||||
|
transcript_list = api.list(video_id)
|
||||||
|
transcript = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
transcript = transcript_list.find_transcript(["en"])
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
transcript = (
|
||||||
|
transcript_list.find_generated_transcript(
|
||||||
|
["en"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
transcript = next(iter(transcript_list), None)
|
||||||
|
|
||||||
|
if transcript:
|
||||||
|
transcript_data = transcript.fetch()
|
||||||
|
text_parts = []
|
||||||
|
char_count = 0
|
||||||
|
for entry in transcript_data:
|
||||||
|
text = (
|
||||||
|
entry.text.strip()
|
||||||
|
if hasattr(entry, "text")
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
if text:
|
||||||
|
text_parts.append(text)
|
||||||
|
char_count += len(text)
|
||||||
|
if char_count > 500:
|
||||||
|
break
|
||||||
|
|
||||||
|
if text_parts:
|
||||||
|
preview = " ".join(text_parts)[:500]
|
||||||
|
content_parts.append(
|
||||||
|
f" Transcript Preview: {preview}..."
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
content_parts.append(" Transcript: Not available")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
content_parts.append(f"\n{i}. Error loading video: {e!s}")
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
for i, video_url in enumerate(video_urls, 1):
|
||||||
|
content_parts.append(f"\n{i}. {video_url}")
|
||||||
|
|
||||||
|
content = "\n".join(content_parts)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unable to load YouTube channel {channel_url}: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return LoaderResult(
|
||||||
|
content=content,
|
||||||
|
source=channel_url,
|
||||||
|
metadata=metadata,
|
||||||
|
doc_id=self.generate_doc_id(source_ref=channel_url, content=content),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_video_id(self, url: str) -> str | None:
|
||||||
|
"""Extract video ID from YouTube URL."""
|
||||||
|
patterns = [
|
||||||
|
r"(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)",
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern in patterns:
|
||||||
|
match = re.search(pattern, url)
|
||||||
|
if match:
|
||||||
|
return match.group(1)
|
||||||
|
|
||||||
|
return None
|
||||||
@@ -0,0 +1,134 @@
|
|||||||
|
"""YouTube video loader for extracting transcripts from YouTube videos."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||||
|
from crewai_tools.rag.source_content import SourceContent
|
||||||
|
|
||||||
|
|
||||||
|
class YoutubeVideoLoader(BaseLoader):
|
||||||
|
"""Loader for YouTube videos."""
|
||||||
|
|
||||||
|
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||||
|
"""Load and extract transcript from a YouTube video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: The source content containing the YouTube URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LoaderResult with transcript content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If required YouTube libraries aren't installed
|
||||||
|
ValueError: If the URL is not a valid YouTube video URL
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from youtube_transcript_api import YouTubeTranscriptApi
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"YouTube support requires youtube-transcript-api. "
|
||||||
|
"Install with: uv add youtube-transcript-api"
|
||||||
|
)
|
||||||
|
|
||||||
|
video_url = source.source
|
||||||
|
video_id = self._extract_video_id(video_url)
|
||||||
|
|
||||||
|
if not video_id:
|
||||||
|
raise ValueError(f"Invalid YouTube URL: {video_url}")
|
||||||
|
|
||||||
|
metadata: dict[str, Any] = {
|
||||||
|
"source": video_url,
|
||||||
|
"video_id": video_id,
|
||||||
|
"data_type": "youtube_video",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
api = YouTubeTranscriptApi()
|
||||||
|
transcript_list = api.list(video_id)
|
||||||
|
|
||||||
|
transcript = None
|
||||||
|
try:
|
||||||
|
transcript = transcript_list.find_transcript(["en"])
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
transcript = transcript_list.find_generated_transcript(["en"])
|
||||||
|
except Exception:
|
||||||
|
transcript = next(iter(transcript_list))
|
||||||
|
|
||||||
|
if transcript:
|
||||||
|
metadata["language"] = transcript.language
|
||||||
|
metadata["is_generated"] = transcript.is_generated
|
||||||
|
|
||||||
|
transcript_data = transcript.fetch()
|
||||||
|
|
||||||
|
text_content = []
|
||||||
|
for entry in transcript_data:
|
||||||
|
text = entry.text.strip() if hasattr(entry, "text") else ""
|
||||||
|
if text:
|
||||||
|
text_content.append(text)
|
||||||
|
|
||||||
|
content = " ".join(text_content)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pytube import YouTube
|
||||||
|
|
||||||
|
yt = YouTube(video_url)
|
||||||
|
metadata["title"] = yt.title
|
||||||
|
metadata["author"] = yt.author
|
||||||
|
metadata["length_seconds"] = yt.length
|
||||||
|
metadata["description"] = (
|
||||||
|
yt.description[:500] if yt.description else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if yt.title:
|
||||||
|
content = f"Title: {yt.title}\n\nAuthor: {yt.author or 'Unknown'}\n\nTranscript:\n{content}"
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"No transcript available for YouTube video: {video_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unable to extract transcript from YouTube video {video_id}: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return LoaderResult(
|
||||||
|
content=content,
|
||||||
|
source=video_url,
|
||||||
|
metadata=metadata,
|
||||||
|
doc_id=self.generate_doc_id(source_ref=video_url, content=content),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_video_id(self, url: str) -> str | None:
|
||||||
|
"""Extract video ID from various YouTube URL formats."""
|
||||||
|
patterns = [
|
||||||
|
r"(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)",
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern in patterns:
|
||||||
|
match = re.search(pattern, url)
|
||||||
|
if match:
|
||||||
|
return match.group(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
hostname = parsed.hostname
|
||||||
|
if hostname:
|
||||||
|
hostname_lower = hostname.lower()
|
||||||
|
# Allow youtube.com and any subdomain of youtube.com, plus youtu.be shortener
|
||||||
|
if (
|
||||||
|
hostname_lower == "youtube.com"
|
||||||
|
or hostname_lower.endswith(".youtube.com")
|
||||||
|
or hostname_lower == "youtu.be"
|
||||||
|
):
|
||||||
|
query_params = parse_qs(parsed.query)
|
||||||
|
if "v" in query_params:
|
||||||
|
return query_params["v"][0]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
31
lib/crewai-tools/src/crewai_tools/rag/misc.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import hashlib
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def compute_sha256(content: str) -> str:
|
||||||
|
return hashlib.sha256(content.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_metadata_for_chromadb(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Sanitize metadata to ensure ChromaDB compatibility.
|
||||||
|
|
||||||
|
ChromaDB only accepts str, int, float, or bool values in metadata.
|
||||||
|
This function converts other types to strings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata: Dictionary of metadata to sanitize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized metadata dictionary with only ChromaDB-compatible types
|
||||||
|
"""
|
||||||
|
sanitized = {}
|
||||||
|
for key, value in metadata.items():
|
||||||
|
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||||
|
sanitized[key] = value
|
||||||
|
elif isinstance(value, (list, tuple)):
|
||||||
|
# Convert lists/tuples to pipe-separated strings
|
||||||
|
sanitized[key] = " | ".join(str(v) for v in value)
|
||||||
|
else:
|
||||||
|
# Convert other types to string
|
||||||
|
sanitized[key] = str(value)
|
||||||
|
return sanitized
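# Example (illustrative): lists are joined with " | ", other values fall back to str().
# sanitize_metadata_for_chromadb({"tags": ["a", "b"], "extra": {"k": 1}})
# -> {"tags": "a | b", "extra": "{'k': 1}"}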
|
||||||
47
lib/crewai-tools/src/crewai_tools/rag/source_content.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
from functools import cached_property
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from crewai_tools.rag.misc import compute_sha256
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai_tools.rag.data_types import DataType
|
||||||
|
|
||||||
|
|
||||||
|
class SourceContent:
|
||||||
|
def __init__(self, source: str | Path):
|
||||||
|
self.source = str(source)
|
||||||
|
|
||||||
|
def is_url(self) -> bool:
|
||||||
|
if not isinstance(self.source, str):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
parsed_url = urlparse(self.source)
|
||||||
|
return bool(parsed_url.scheme and parsed_url.netloc)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def path_exists(self) -> bool:
|
||||||
|
return os.path.exists(self.source)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def data_type(self) -> "DataType":
|
||||||
|
from crewai_tools.rag.data_types import DataTypes
|
||||||
|
|
||||||
|
return DataTypes.from_content(self.source)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def source_ref(self) -> str:
|
||||||
|
""" "
|
||||||
|
Returns the source reference for the content.
|
||||||
|
If the content is a URL or a local file, returns the source.
|
||||||
|
Otherwise, returns the hash of the content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.is_url() or self.path_exists():
|
||||||
|
return self.source
|
||||||
|
|
||||||
|
return compute_sha256(self.source)
|
||||||
127
lib/crewai-tools/src/crewai_tools/tools/__init__.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
from .ai_mind_tool.ai_mind_tool import AIMindTool
|
||||||
|
from .apify_actors_tool.apify_actors_tool import ApifyActorsTool
|
||||||
|
from .arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
|
||||||
|
from .brave_search_tool.brave_search_tool import BraveSearchTool
|
||||||
|
from .brightdata_tool import (
|
||||||
|
BrightDataDatasetTool,
|
||||||
|
BrightDataSearchTool,
|
||||||
|
BrightDataWebUnlockerTool,
|
||||||
|
)
|
||||||
|
from .browserbase_load_tool.browserbase_load_tool import BrowserbaseLoadTool
|
||||||
|
from .code_docs_search_tool.code_docs_search_tool import CodeDocsSearchTool
|
||||||
|
from .code_interpreter_tool.code_interpreter_tool import CodeInterpreterTool
|
||||||
|
from .composio_tool.composio_tool import ComposioTool
|
||||||
|
from .contextualai_create_agent_tool.contextual_create_agent_tool import (
|
||||||
|
ContextualAICreateAgentTool,
|
||||||
|
)
|
||||||
|
from .contextualai_parse_tool.contextual_parse_tool import ContextualAIParseTool
|
||||||
|
from .contextualai_query_tool.contextual_query_tool import ContextualAIQueryTool
|
||||||
|
from .contextualai_rerank_tool.contextual_rerank_tool import ContextualAIRerankTool
|
||||||
|
from .couchbase_tool.couchbase_tool import CouchbaseFTSVectorSearchTool
|
||||||
|
from .crewai_enterprise_tools.crewai_enterprise_tools import CrewaiEnterpriseTools
|
||||||
|
from .crewai_platform_tools.crewai_platform_tools import CrewaiPlatformTools
|
||||||
|
from .csv_search_tool.csv_search_tool import CSVSearchTool
|
||||||
|
from .dalle_tool.dalle_tool import DallETool
|
||||||
|
from .databricks_query_tool.databricks_query_tool import DatabricksQueryTool
|
||||||
|
from .directory_read_tool.directory_read_tool import DirectoryReadTool
|
||||||
|
from .directory_search_tool.directory_search_tool import DirectorySearchTool
|
||||||
|
from .docx_search_tool.docx_search_tool import DOCXSearchTool
|
||||||
|
from .exa_tools.exa_search_tool import EXASearchTool
|
||||||
|
from .file_read_tool.file_read_tool import FileReadTool
|
||||||
|
from .file_writer_tool.file_writer_tool import FileWriterTool
|
||||||
|
from .files_compressor_tool.files_compressor_tool import FileCompressorTool
|
||||||
|
from .firecrawl_crawl_website_tool.firecrawl_crawl_website_tool import (
|
||||||
|
FirecrawlCrawlWebsiteTool,
|
||||||
|
)
|
||||||
|
from .firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
|
||||||
|
FirecrawlScrapeWebsiteTool,
|
||||||
|
)
|
||||||
|
from .firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool
|
||||||
|
from .generate_crewai_automation_tool.generate_crewai_automation_tool import (
|
||||||
|
GenerateCrewaiAutomationTool,
|
||||||
|
)
|
||||||
|
from .github_search_tool.github_search_tool import GithubSearchTool
|
||||||
|
from .hyperbrowser_load_tool.hyperbrowser_load_tool import HyperbrowserLoadTool
|
||||||
|
from .invoke_crewai_automation_tool.invoke_crewai_automation_tool import (
|
||||||
|
InvokeCrewAIAutomationTool,
|
||||||
|
)
|
||||||
|
from .json_search_tool.json_search_tool import JSONSearchTool
|
||||||
|
from .linkup.linkup_search_tool import LinkupSearchTool
|
||||||
|
from .llamaindex_tool.llamaindex_tool import LlamaIndexTool
|
||||||
|
from .mdx_search_tool.mdx_search_tool import MDXSearchTool
|
||||||
|
from .mongodb_vector_search_tool import (
|
||||||
|
MongoDBToolSchema,
|
||||||
|
MongoDBVectorSearchConfig,
|
||||||
|
MongoDBVectorSearchTool,
|
||||||
|
)
|
||||||
|
from .multion_tool.multion_tool import MultiOnTool
|
||||||
|
from .mysql_search_tool.mysql_search_tool import MySQLSearchTool
|
||||||
|
from .nl2sql.nl2sql_tool import NL2SQLTool
|
||||||
|
from .ocr_tool.ocr_tool import OCRTool
|
||||||
|
from .oxylabs_amazon_product_scraper_tool.oxylabs_amazon_product_scraper_tool import (
|
||||||
|
OxylabsAmazonProductScraperTool,
|
||||||
|
)
|
||||||
|
from .oxylabs_amazon_search_scraper_tool.oxylabs_amazon_search_scraper_tool import (
|
||||||
|
OxylabsAmazonSearchScraperTool,
|
||||||
|
)
|
||||||
|
from .oxylabs_google_search_scraper_tool.oxylabs_google_search_scraper_tool import (
|
||||||
|
OxylabsGoogleSearchScraperTool,
|
||||||
|
)
|
||||||
|
from .oxylabs_universal_scraper_tool.oxylabs_universal_scraper_tool import (
|
||||||
|
OxylabsUniversalScraperTool,
|
||||||
|
)
|
||||||
|
from .parallel_tools import (
|
||||||
|
ParallelSearchTool,
|
||||||
|
)
|
||||||
|
from .patronus_eval_tool import (
|
||||||
|
PatronusEvalTool,
|
||||||
|
PatronusLocalEvaluatorTool,
|
||||||
|
PatronusPredefinedCriteriaEvalTool,
|
||||||
|
)
|
||||||
|
from .pdf_search_tool.pdf_search_tool import PDFSearchTool
|
||||||
|
from .pg_search_tool.pg_search_tool import PGSearchTool
|
||||||
|
from .qdrant_vector_search_tool.qdrant_search_tool import QdrantVectorSearchTool
|
||||||
|
from .rag.rag_tool import RagTool
|
||||||
|
from .scrape_element_from_website.scrape_element_from_website import (
|
||||||
|
ScrapeElementFromWebsiteTool,
|
||||||
|
)
|
||||||
|
from .scrape_website_tool.scrape_website_tool import ScrapeWebsiteTool
|
||||||
|
from .scrapegraph_scrape_tool.scrapegraph_scrape_tool import (
|
||||||
|
ScrapegraphScrapeTool,
|
||||||
|
ScrapegraphScrapeToolSchema,
|
||||||
|
)
|
||||||
|
from .scrapfly_scrape_website_tool.scrapfly_scrape_website_tool import (
|
||||||
|
ScrapflyScrapeWebsiteTool,
|
||||||
|
)
|
||||||
|
from .selenium_scraping_tool.selenium_scraping_tool import SeleniumScrapingTool
|
||||||
|
from .serpapi_tool.serpapi_google_search_tool import SerpApiGoogleSearchTool
|
||||||
|
from .serpapi_tool.serpapi_google_shopping_tool import SerpApiGoogleShoppingTool
|
||||||
|
from .serper_dev_tool.serper_dev_tool import SerperDevTool
|
||||||
|
from .serper_scrape_website_tool.serper_scrape_website_tool import (
|
||||||
|
SerperScrapeWebsiteTool,
|
||||||
|
)
|
||||||
|
from .serply_api_tool.serply_job_search_tool import SerplyJobSearchTool
|
||||||
|
from .serply_api_tool.serply_news_search_tool import SerplyNewsSearchTool
|
||||||
|
from .serply_api_tool.serply_scholar_search_tool import SerplyScholarSearchTool
|
||||||
|
from .serply_api_tool.serply_web_search_tool import SerplyWebSearchTool
|
||||||
|
from .serply_api_tool.serply_webpage_to_markdown_tool import SerplyWebpageToMarkdownTool
|
||||||
|
from .singlestore_search_tool import SingleStoreSearchTool
|
||||||
|
from .snowflake_search_tool import (
|
||||||
|
SnowflakeConfig,
|
||||||
|
SnowflakeSearchTool,
|
||||||
|
SnowflakeSearchToolInput,
|
||||||
|
)
|
||||||
|
from .spider_tool.spider_tool import SpiderTool
|
||||||
|
from .stagehand_tool.stagehand_tool import StagehandTool
|
||||||
|
from .tavily_extractor_tool.tavily_extractor_tool import TavilyExtractorTool
|
||||||
|
from .tavily_search_tool.tavily_search_tool import TavilySearchTool
|
||||||
|
from .txt_search_tool.txt_search_tool import TXTSearchTool
|
||||||
|
from .vision_tool.vision_tool import VisionTool
|
||||||
|
from .weaviate_tool.vector_search import WeaviateVectorSearchTool
|
||||||
|
from .website_search.website_search_tool import WebsiteSearchTool
|
||||||
|
from .xml_search_tool.xml_search_tool import XMLSearchTool
|
||||||
|
from .youtube_channel_search_tool.youtube_channel_search_tool import (
|
||||||
|
YoutubeChannelSearchTool,
|
||||||
|
)
|
||||||
|
from .youtube_video_search_tool.youtube_video_search_tool import YoutubeVideoSearchTool
|
||||||
|
from .zapier_action_tool.zapier_action_tool import ZapierActionTools
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
# AIMind Tool
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
[Minds](https://mindsdb.com/minds) are AI systems provided by [MindsDB](https://mindsdb.com/) that work similarly to large language models (LLMs) but go beyond by answering any question from any data.
|
||||||
|
|
||||||
|
This is accomplished by selecting the most relevant data for an answer using parametric search, understanding the meaning and providing responses within the correct context through semantic search, and finally, delivering precise answers by analyzing data and using machine learning (ML) models.
|
||||||
|
|
||||||
|
The `AIMindTool` can be used to query data sources in natural language by simply configuring their connection parameters.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
1. Install the `crewai[tools]` package:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
pip install 'crewai[tools]'
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Install the Minds SDK:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
pip install minds-sdk
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Sign up for a Minds account [here](https://mdb.ai/register) and obtain an API key.
|
||||||
|
|
||||||
|
4. Set the Minds API key in an environment variable named `MINDS_API_KEY`.
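For example, the key can be exported in the shell before running your application (placeholder value shown):

```shell
export MINDS_API_KEY='your-minds-api-key'
```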
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai_tools import AIMindTool
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize the AIMindTool.
|
||||||
|
aimind_tool = AIMindTool(
|
||||||
|
datasources=[
|
||||||
|
{
|
||||||
|
"description": "house sales data",
|
||||||
|
"engine": "postgres",
|
||||||
|
"connection_data": {
|
||||||
|
"user": "demo_user",
|
||||||
|
"password": "demo_password",
|
||||||
|
"host": "samples.mindsdb.com",
|
||||||
|
"port": 5432,
|
||||||
|
"database": "demo",
|
||||||
|
"schema": "demo_data"
|
||||||
|
},
|
||||||
|
"tables": ["house_sales"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
aimind_tool.run("How many 3 bedroom houses were sold in 2008?")
|
||||||
|
```
|
||||||
|
|
||||||
|
The `datasources` parameter is a list of dictionaries, each containing the following keys:
|
||||||
|
|
||||||
|
- `description`: A description of the data contained in the datasource.
|
||||||
|
- `engine`: The engine (or type) of the datasource. Find a list of supported engines in the link below.
|
||||||
|
- `connection_data`: A dictionary containing the connection parameters for the datasource. Find a list of connection parameters for each engine in the link below.
|
||||||
|
- `tables`: A list of tables that the data source will use. This is optional and can be omitted if all tables in the data source are to be used.
|
||||||
|
|
||||||
|
A list of supported data sources and their connection parameters can be found [here](https://docs.mdb.ai/docs/data_sources).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import Agent
|
||||||
|
from crewai.project import agent
|
||||||
|
|
||||||
|
|
||||||
|
# Define an agent with the AIMindTool.
|
||||||
|
@agent
|
||||||
|
def researcher(self) -> Agent:
|
||||||
|
return Agent(
|
||||||
|
config=self.agents_config["researcher"],
|
||||||
|
allow_delegation=False,
|
||||||
|
tools=[aimind_tool]
|
||||||
|
)
|
||||||
|
```
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool, EnvVar
|
||||||
|
from openai import OpenAI
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class AIMindToolConstants:
|
||||||
|
MINDS_API_BASE_URL = "https://mdb.ai/"
|
||||||
|
MIND_NAME_PREFIX = "crwai_mind_"
|
||||||
|
DATASOURCE_NAME_PREFIX = "crwai_ds_"
|
||||||
|
|
||||||
|
|
||||||
|
class AIMindToolInputSchema(BaseModel):
|
||||||
|
"""Input for AIMind Tool."""
|
||||||
|
|
||||||
|
query: str = Field(description="Question in natural language to ask the AI-Mind")
|
||||||
|
|
||||||
|
|
||||||
|
class AIMindTool(BaseTool):
|
||||||
|
name: str = "AIMind Tool"
|
||||||
|
description: str = (
|
||||||
|
"A wrapper around [AI-Minds](https://mindsdb.com/minds). "
|
||||||
|
"Useful for when you need answers to questions from your data, stored in "
|
||||||
|
"data sources including PostgreSQL, MySQL, MariaDB, ClickHouse, Snowflake "
|
||||||
|
"and Google BigQuery. "
|
||||||
|
"Input should be a question in natural language."
|
||||||
|
)
|
||||||
|
args_schema: Type[BaseModel] = AIMindToolInputSchema
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
datasources: Optional[List[Dict[str, Any]]] = None
|
||||||
|
mind_name: Optional[str] = None
|
||||||
|
package_dependencies: List[str] = ["minds-sdk"]
|
||||||
|
env_vars: List[EnvVar] = [
|
||||||
|
EnvVar(name="MINDS_API_KEY", description="API key for AI-Minds", required=True),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, api_key: Optional[str] = None, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.api_key = api_key or os.getenv("MINDS_API_KEY")
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError(
|
||||||
|
"API key must be provided either through constructor or MINDS_API_KEY environment variable"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from minds.client import Client # type: ignore
|
||||||
|
from minds.datasources import DatabaseConfig # type: ignore
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"`minds_sdk` package not found, please run `pip install minds-sdk`"
|
||||||
|
)
|
||||||
|
|
||||||
|
minds_client = Client(api_key=self.api_key)
|
||||||
|
|
||||||
|
# Convert the datasources to DatabaseConfig objects.
|
||||||
|
datasources = []
|
||||||
|
for datasource in self.datasources:
|
||||||
|
config = DatabaseConfig(
|
||||||
|
name=f"{AIMindToolConstants.DATASOURCE_NAME_PREFIX}_{secrets.token_hex(5)}",
|
||||||
|
engine=datasource["engine"],
|
||||||
|
description=datasource["description"],
|
||||||
|
connection_data=datasource["connection_data"],
|
||||||
|
tables=datasource["tables"],
|
||||||
|
)
|
||||||
|
datasources.append(config)
|
||||||
|
|
||||||
|
# Generate a random name for the Mind.
|
||||||
|
name = f"{AIMindToolConstants.MIND_NAME_PREFIX}_{secrets.token_hex(5)}"
|
||||||
|
|
||||||
|
mind = minds_client.minds.create(
|
||||||
|
name=name, datasources=datasources, replace=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mind_name = mind.name
|
||||||
|
|
||||||
|
def _run(self, query: str):
|
||||||
|
# Run the query on the AI-Mind.
|
||||||
|
# The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.
|
||||||
|
openai_client = OpenAI(
|
||||||
|
base_url=AIMindToolConstants.MINDS_API_BASE_URL, api_key=self.api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
completion = openai_client.chat.completions.create(
|
||||||
|
model=self.mind_name,
|
||||||
|
messages=[{"role": "user", "content": query}],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return completion.choices[0].message.content
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
# ApifyActorsTool
|
||||||
|
|
||||||
|
Integrate [Apify Actors](https://apify.com/actors) into your CrewAI workflows.
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
The `ApifyActorsTool` connects [Apify Actors](https://apify.com/actors), cloud-based programs for web scraping and automation, to your CrewAI workflows.
|
||||||
|
Use any of the 4,000+ Actors on [Apify Store](https://apify.com/store) for use cases such as extracting data from social media, search engines, online maps, e-commerce sites, travel portals, or general websites.
|
||||||
|
|
||||||
|
For details, see the [Apify CrewAI integration](https://docs.apify.com/platform/integrations/crewai) in Apify documentation.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
To use `ApifyActorsTool`, install the necessary packages and set up your Apify API token. Follow the [Apify API documentation](https://docs.apify.com/platform/integrations/api) for steps to obtain the token.
|
||||||
|
|
||||||
|
### Steps
|
||||||
|
|
||||||
|
1. **Install dependencies**
|
||||||
|
Install `crewai[tools]` and `langchain-apify`:
|
||||||
|
```bash
|
||||||
|
pip install 'crewai[tools]' langchain-apify
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Set your API token**
|
||||||
|
Export the token as an environment variable:
|
||||||
|
```bash
|
||||||
|
export APIFY_API_TOKEN='your-api-token-here'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage example
|
||||||
|
|
||||||
|
Use the `ApifyActorsTool` manually to run the [RAG Web Browser Actor](https://apify.com/apify/rag-web-browser) to perform a web search:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai_tools import ApifyActorsTool
|
||||||
|
|
||||||
|
# Initialize the tool with an Apify Actor
|
||||||
|
tool = ApifyActorsTool(actor_name="apify/rag-web-browser")
|
||||||
|
|
||||||
|
# Run the tool with input parameters
|
||||||
|
results = tool.run(run_input={"query": "What is CrewAI?", "maxResults": 5})
|
||||||
|
|
||||||
|
# Process the results
|
||||||
|
for result in results:
|
||||||
|
print(f"URL: {result['metadata']['url']}")
|
||||||
|
print(f"Content: {result.get('markdown', 'N/A')[:100]}...")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Expected output
|
||||||
|
|
||||||
|
Here is the output from running the code above:
|
||||||
|
|
||||||
|
```text
|
||||||
|
URL: https://www.example.com/crewai-intro
|
||||||
|
Content: CrewAI is a framework for building AI-powered workflows...
|
||||||
|
URL: https://docs.crewai.com/
|
||||||
|
Content: Official documentation for CrewAI...
|
||||||
|
```
|
||||||
|
|
||||||
|
The `ApifyActorsTool` automatically fetches the Actor definition and input schema from Apify using the provided `actor_name` and then constructs the tool description and argument schema. This means you need to specify only a valid `actor_name`, and the tool handles the rest when used with agents—no need to specify the `run_input`. Here's how it works:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import Agent
|
||||||
|
from crewai_tools import ApifyActorsTool
|
||||||
|
|
||||||
|
rag_browser = ApifyActorsTool(actor_name="apify/rag-web-browser")
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Research Analyst",
|
||||||
|
goal="Find and summarize information about specific topics",
|
||||||
|
backstory="You are an experienced researcher with attention to detail",
|
||||||
|
tools=[rag_browser],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can run other Actors from [Apify Store](https://apify.com/store) simply by changing the `actor_name` and, when using it manually, adjusting the `run_input` based on the Actor input schema.
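For instance, here is a sketch of switching to a different Actor; the Actor name and `run_input` fields below are illustrative, so check the chosen Actor's input schema for its actual parameters:

```python
from crewai_tools import ApifyActorsTool

# Illustrative sketch: the run_input keys must match the chosen Actor's input schema.
crawler = ApifyActorsTool(actor_name="apify/website-content-crawler")
pages = crawler.run(run_input={"startUrls": [{"url": "https://docs.crewai.com"}]})
for page in pages:
    print(page.get("url"), str(page.get("text", ""))[:100])
```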
|
||||||
|
|
||||||
|
For an example of usage with agents, see the [CrewAI Actor template](https://apify.com/templates/python-crewai).
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
The `ApifyActorsTool` requires these inputs to work:
|
||||||
|
|
||||||
|
- **`actor_name`**
|
||||||
|
The ID of the Apify Actor to run, e.g., `"apify/rag-web-browser"`. Browse all Actors on [Apify Store](https://apify.com/store).
|
||||||
|
- **`run_input`**
|
||||||
|
A dictionary of input parameters for the Actor when running the tool manually.
|
||||||
|
- For example, for the `apify/rag-web-browser` Actor: `{"query": "search term", "maxResults": 5}`
|
||||||
|
- See the Actor's [input schema](https://apify.com/apify/rag-web-browser/input-schema) for the list of input parameters.
|
||||||
|
|
||||||
|
## Resources
|
||||||
|
|
||||||
|
- **[Apify](https://apify.com/)**: Explore the Apify platform.
|
||||||
|
- **[How to build an AI agent on Apify](https://blog.apify.com/how-to-build-an-ai-agent/)** - A complete step-by-step guide to creating, publishing, and monetizing AI agents on the Apify platform.
|
||||||
|
- **[RAG Web Browser Actor](https://apify.com/apify/rag-web-browser)**: A popular Actor for web search for LLMs.
|
||||||
|
- **[CrewAI Integration Guide](https://docs.apify.com/platform/integrations/crewai)**: Follow the official guide for integrating Apify and CrewAI.
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool, EnvVar
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_apify import ApifyActorsTool as _ApifyActorsTool
|
||||||
|
|
||||||
|
|
||||||
|
class ApifyActorsTool(BaseTool):
|
||||||
|
env_vars: List[EnvVar] = [
|
||||||
|
EnvVar(
|
||||||
|
name="APIFY_API_TOKEN",
|
||||||
|
description="API token for Apify platform access",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
"""Tool that runs Apify Actors.
|
||||||
|
|
||||||
|
To use, you should have the environment variable `APIFY_API_TOKEN` set
|
||||||
|
with your API key.
|
||||||
|
|
||||||
|
For details, see https://docs.apify.com/platform/integrations/crewai
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actor_name (str): The name of the Apify Actor to run.
|
||||||
|
*args: Variable length argument list passed to BaseTool.
|
||||||
|
**kwargs: Arbitrary keyword arguments passed to BaseTool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: Results from the Actor execution.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `APIFY_API_TOKEN` is not set or if the tool is not initialized.
|
||||||
|
ImportError: If `langchain_apify` package is not installed.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
from crewai_tools import ApifyActorsTool
|
||||||
|
|
||||||
|
tool = ApifyActorsTool(actor_name="apify/rag-web-browser")
|
||||||
|
|
||||||
|
results = tool.run(run_input={"query": "What is CrewAI?", "maxResults": 5})
|
||||||
|
for result in results:
|
||||||
|
print(f"URL: {result['metadata']['url']}")
|
||||||
|
print(f"Content: {result.get('markdown', 'N/A')[:100]}...")
|
||||||
|
"""
|
||||||
|
actor_tool: "_ApifyActorsTool" = Field(description="Apify Actor Tool")
|
||||||
|
package_dependencies: List[str] = ["langchain-apify"]
|
||||||
|
|
||||||
|
def __init__(self, actor_name: str, *args: Any, **kwargs: Any) -> None:
|
||||||
|
if not os.environ.get("APIFY_API_TOKEN"):
|
||||||
|
msg = (
|
||||||
|
"APIFY_API_TOKEN environment variable is not set. "
|
||||||
|
"Please set it to your API key, to learn how to get it, "
|
||||||
|
"see https://docs.apify.com/platform/integrations/api"
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_apify import ApifyActorsTool as _ApifyActorsTool
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import langchain_apify python package. "
|
||||||
|
"Please install it with `pip install langchain-apify` or `uv add langchain-apify`."
|
||||||
|
)
|
||||||
|
actor_tool = _ApifyActorsTool(actor_name)
|
||||||
|
|
||||||
|
kwargs.update(
|
||||||
|
{
|
||||||
|
"name": actor_tool.name,
|
||||||
|
"description": actor_tool.description,
|
||||||
|
"args_schema": actor_tool.args_schema,
|
||||||
|
"actor_tool": actor_tool,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def _run(self, run_input: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
|
"""Run the Actor tool with the given input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: Results from the Actor execution.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If 'actor_tool' is not initialized.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self.actor_tool._run(run_input)
|
||||||
|
except Exception as e:
|
||||||
|
msg = (
|
||||||
|
f"Failed to run ApifyActorsTool {self.name}. "
|
||||||
|
"Please check your Apify account Actor run logs for more details."
|
||||||
|
f"Error: {e}"
|
||||||
|
)
|
||||||
|
raise RuntimeError(msg) from e
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
### Example 1: Fetching Research Papers from arXiv with CrewAI
|
||||||
|
|
||||||
|
This example demonstrates how to build a simple CrewAI workflow that automatically searches for and downloads academic papers from [arXiv.org](https://arxiv.org). The setup uses:
|
||||||
|
|
||||||
|
* A custom `ArxivPaperTool` to fetch metadata and download PDFs
|
||||||
|
* A single `Agent` tasked with locating relevant papers based on a given research topic
|
||||||
|
* A `Task` to define the data retrieval and download process
|
||||||
|
* A sequential `Crew` to orchestrate execution
|
||||||
|
|
||||||
|
The downloaded PDFs are saved to a local directory (`./DOWNLOADS`). Filenames are optionally based on sanitized paper titles, ensuring compatibility with your operating system.
|
||||||
|
|
||||||
|
> The saved PDFs can be further used in **downstream tasks**, such as:
|
||||||
|
>
|
||||||
|
> * **RAG (Retrieval-Augmented Generation)**
|
||||||
|
> * **Summarization**
|
||||||
|
> * **Citation extraction**
|
||||||
|
> * **Embedding-based search or analysis**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import Agent, Task, Crew, Process, LLM
|
||||||
|
from crewai_tools import ArxivPaperTool
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="ollama/llama3.1",
|
||||||
|
base_url="http://localhost:11434",
|
||||||
|
temperature=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
topic = "Crew AI"
|
||||||
|
max_results = 3
|
||||||
|
save_dir = "./DOWNLOADS"
|
||||||
|
use_title_as_filename = True
|
||||||
|
|
||||||
|
tool = ArxivPaperTool(
|
||||||
|
download_pdfs=True,
|
||||||
|
save_dir=save_dir,
|
||||||
|
use_title_as_filename=True
|
||||||
|
)
|
||||||
|
tool.result_as_answer = True  # Required so the tool's output is returned directly as the task result
|
||||||
|
|
||||||
|
|
||||||
|
arxiv_paper_fetch = Agent(
|
||||||
|
role="Arxiv Data Fetcher",
|
||||||
|
goal=f"Retrieve relevant papers from arXiv based on a research topic {topic} and maximum number of papers to be downloaded is{max_results},try to use title as filename {use_title_as_filename} and download PDFs to {save_dir},",
|
||||||
|
backstory="An expert in scientific data retrieval, skilled in extracting academic content from arXiv.",
|
||||||
|
# tools=[ArxivPaperTool()],
|
||||||
|
llm=llm,
|
||||||
|
verbose=True,
|
||||||
|
allow_delegation=False
|
||||||
|
)
|
||||||
|
fetch_task = Task(
|
||||||
|
description=(
|
||||||
|
f"Search arXiv for the topic '{topic}' and fetch up to {max_results} papers. "
|
||||||
|
f"Download PDFs for analysis and store them at {save_dir}."
|
||||||
|
),
|
||||||
|
expected_output="PDFs saved to disk for downstream agents.",
|
||||||
|
agent=arxiv_paper_fetch,
|
||||||
|
tools=[tool], # Use the actual tool instance here
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
pdf_qa_crew = Crew(
|
||||||
|
agents=[arxiv_paper_fetch],
|
||||||
|
tasks=[fetch_task],
|
||||||
|
process=Process.sequential,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
result = pdf_qa_crew.kickoff()
|
||||||
|
|
||||||
|
print(f"\n🤖 Answer:\n\n{result.raw}\n")
|
||||||
|
```
|
||||||
@@ -0,0 +1,142 @@
|
|||||||
|
# ArxivPaperTool
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
The **ArxivPaperTool** is a utility for fetching metadata and optionally downloading PDFs of academic papers from the [arXiv](https://arxiv.org) platform using its public API. It supports configurable queries, batch retrieval, PDF downloading, and clean formatting for summaries and metadata. This tool is particularly useful for researchers, students, academic agents, and AI tools performing automated literature reviews.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
This tool:
|
||||||
|
|
||||||
|
* Accepts a **search query** and retrieves a list of papers from arXiv.
|
||||||
|
* Allows configuration of the **maximum number of results** to fetch.
|
||||||
|
* Optionally downloads the **PDFs** of the matched papers.
|
||||||
|
* Lets you specify whether to name PDF files using the **arXiv ID** or **paper title**.
|
||||||
|
* Saves downloaded files into a **custom or default directory**.
|
||||||
|
* Returns structured summaries of all fetched papers including metadata.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Arguments
|
||||||
|
|
||||||
|
| Argument | Type | Required | Description |
|
||||||
|
| ----------------------- | ------ | -------- | --------------------------------------------------------------------------------- |
|
||||||
|
| `search_query` | `str` | ✅ | Search query string (e.g., `"transformer neural network"`). |
|
||||||
|
| `max_results` | `int` | ✅ | Number of results to fetch (between 1 and 100). |
|
||||||
|
| `download_pdfs` | `bool` | ❌ | Whether to download the corresponding PDFs. Defaults to `False`. |
|
||||||
|
| `save_dir` | `str` | ❌ | Directory to save PDFs (created if it doesn’t exist). Defaults to `./arxiv_pdfs`. |
|
||||||
|
| `use_title_as_filename` | `bool` | ❌ | Use the paper title as the filename (sanitized). Defaults to `False`. |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📄 `ArxivPaperTool` Usage Examples
|
||||||
|
|
||||||
|
This document shows how to use the `ArxivPaperTool` to fetch research paper metadata from arXiv and optionally download PDFs.
|
||||||
|
|
||||||
|
### 🔧 Tool Initialization
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai_tools import ArxivPaperTool
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Example 1: Fetch Metadata Only (No Downloads)
|
||||||
|
|
||||||
|
```python
|
||||||
|
tool = ArxivPaperTool()
|
||||||
|
result = tool._run(
|
||||||
|
search_query="deep learning",
|
||||||
|
max_results=1
|
||||||
|
)
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Example 2: Fetch and Download PDFs (arXiv ID as Filename)
|
||||||
|
|
||||||
|
```python
|
||||||
|
tool = ArxivPaperTool(download_pdfs=True)
|
||||||
|
result = tool._run(
|
||||||
|
search_query="transformer models",
|
||||||
|
max_results=2
|
||||||
|
)
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Example 3: Download PDFs into a Custom Directory
|
||||||
|
|
||||||
|
```python
|
||||||
|
tool = ArxivPaperTool(
|
||||||
|
download_pdfs=True,
|
||||||
|
save_dir="./my_papers"
|
||||||
|
)
|
||||||
|
result = tool._run(
|
||||||
|
search_query="graph neural networks",
|
||||||
|
max_results=2
|
||||||
|
)
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Example 4: Use Paper Titles as Filenames
|
||||||
|
|
||||||
|
```python
|
||||||
|
tool = ArxivPaperTool(
|
||||||
|
download_pdfs=True,
|
||||||
|
use_title_as_filename=True
|
||||||
|
)
|
||||||
|
result = tool._run(
|
||||||
|
search_query="vision transformers",
|
||||||
|
max_results=1
|
||||||
|
)
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Example 5: All Options Combined
|
||||||
|
|
||||||
|
```python
|
||||||
|
tool = ArxivPaperTool(
|
||||||
|
download_pdfs=True,
|
||||||
|
save_dir="./downloads",
|
||||||
|
use_title_as_filename=True
|
||||||
|
)
|
||||||
|
result = tool._run(
|
||||||
|
search_query="stable diffusion",
|
||||||
|
max_results=3
|
||||||
|
)
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Run via `__main__`
|
||||||
|
|
||||||
|
Your file can also include:
|
||||||
|
|
||||||
|
```python
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tool = ArxivPaperTool(
|
||||||
|
download_pdfs=True,
|
||||||
|
save_dir="./downloads2",
|
||||||
|
use_title_as_filename=False
|
||||||
|
)
|
||||||
|
result = tool._run(
|
||||||
|
search_query="deep learning",
|
||||||
|
max_results=1
|
||||||
|
)
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,174 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from typing import ClassVar, List, Optional, Type
|
||||||
|
import urllib.error
|
||||||
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool, EnvVar
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
class ArxivToolInput(BaseModel):
|
||||||
|
search_query: str = Field(
|
||||||
|
..., description="Search query for Arxiv, e.g., 'transformer neural network'"
|
||||||
|
)
|
||||||
|
max_results: int = Field(
|
||||||
|
5, ge=1, le=100, description="Max results to fetch; must be between 1 and 100"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ArxivPaperTool(BaseTool):
|
||||||
|
BASE_API_URL: ClassVar[str] = "http://export.arxiv.org/api/query"
|
||||||
|
SLEEP_DURATION: ClassVar[int] = 1
|
||||||
|
SUMMARY_TRUNCATE_LENGTH: ClassVar[int] = 300
|
||||||
|
ATOM_NAMESPACE: ClassVar[str] = "{http://www.w3.org/2005/Atom}"
|
||||||
|
REQUEST_TIMEOUT: ClassVar[int] = 10
|
||||||
|
name: str = "Arxiv Paper Fetcher and Downloader"
|
||||||
|
description: str = "Fetches metadata from Arxiv based on a search query and optionally downloads PDFs."
|
||||||
|
args_schema: Type[BaseModel] = ArxivToolInput
|
||||||
|
model_config = {"extra": "allow"}
|
||||||
|
package_dependencies: List[str] = ["pydantic"]
|
||||||
|
env_vars: List[EnvVar] = []
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, download_pdfs=False, save_dir="./arxiv_pdfs", use_title_as_filename=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.download_pdfs = download_pdfs
|
||||||
|
self.save_dir = save_dir
|
||||||
|
self.use_title_as_filename = use_title_as_filename
|
||||||
|
|
||||||
|
def _run(self, search_query: str, max_results: int = 5) -> str:
|
||||||
|
try:
|
||||||
|
args = ArxivToolInput(search_query=search_query, max_results=max_results)
|
||||||
|
logger.info(
|
||||||
|
f"Running Arxiv tool: query='{args.search_query}', max_results={args.max_results}, "
|
||||||
|
f"download_pdfs={self.download_pdfs}, save_dir='{self.save_dir}', "
|
||||||
|
f"use_title_as_filename={self.use_title_as_filename}"
|
||||||
|
)
|
||||||
|
|
||||||
|
papers = self.fetch_arxiv_data(args.search_query, args.max_results)
|
||||||
|
|
||||||
|
if self.download_pdfs:
|
||||||
|
save_dir = self._validate_save_path(self.save_dir)
|
||||||
|
for paper in papers:
|
||||||
|
if paper["pdf_url"]:
|
||||||
|
if self.use_title_as_filename:
|
||||||
|
safe_title = re.sub(
|
||||||
|
r'[\\/*?:"<>|]', "_", paper["title"]
|
||||||
|
).strip()
|
||||||
|
filename_base = safe_title or paper["arxiv_id"]
|
||||||
|
else:
|
||||||
|
filename_base = paper["arxiv_id"]
|
||||||
|
filename = f"{filename_base[:500]}.pdf"
|
||||||
|
save_path = Path(save_dir) / filename
|
||||||
|
|
||||||
|
self.download_pdf(paper["pdf_url"], save_path)
|
||||||
|
time.sleep(self.SLEEP_DURATION)
|
||||||
|
|
||||||
|
results = [self._format_paper_result(p) for p in papers]
|
||||||
|
return "\n\n" + "-" * 80 + "\n\n".join(results)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"ArxivTool Error: {e!s}")
|
||||||
|
return f"Failed to fetch or download Arxiv papers: {e!s}"
|
||||||
|
|
||||||
|
def fetch_arxiv_data(self, search_query: str, max_results: int) -> List[dict]:
|
||||||
|
api_url = f"{self.BASE_API_URL}?search_query={urllib.parse.quote(search_query)}&start=0&max_results={max_results}"
|
||||||
|
logger.info(f"Fetching data from Arxiv API: {api_url}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(
|
||||||
|
api_url, timeout=self.REQUEST_TIMEOUT
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise Exception(f"HTTP {response.status}: {response.reason}")
|
||||||
|
data = response.read().decode("utf-8")
|
||||||
|
except urllib.error.URLError as e:
|
||||||
|
logger.error(f"Error fetching data from Arxiv: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
root = ET.fromstring(data)
|
||||||
|
papers = []
|
||||||
|
|
||||||
|
for entry in root.findall(self.ATOM_NAMESPACE + "entry"):
|
||||||
|
raw_id = self._get_element_text(entry, "id")
|
||||||
|
arxiv_id = raw_id.split("/")[-1].replace(".", "_") if raw_id else "unknown"
|
||||||
|
|
||||||
|
title = self._get_element_text(entry, "title") or "No Title"
|
||||||
|
summary = self._get_element_text(entry, "summary") or "No Summary"
|
||||||
|
published = self._get_element_text(entry, "published") or "No Publish Date"
|
||||||
|
authors = [
|
||||||
|
self._get_element_text(author, "name") or "Unknown"
|
||||||
|
for author in entry.findall(self.ATOM_NAMESPACE + "author")
|
||||||
|
]
|
||||||
|
|
||||||
|
pdf_url = self._extract_pdf_url(entry)
|
||||||
|
|
||||||
|
papers.append(
|
||||||
|
{
|
||||||
|
"arxiv_id": arxiv_id,
|
||||||
|
"title": title,
|
||||||
|
"summary": summary,
|
||||||
|
"authors": authors,
|
||||||
|
"published_date": published,
|
||||||
|
"pdf_url": pdf_url,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return papers
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_element_text(entry: ET.Element, element_name: str) -> Optional[str]:
|
||||||
|
elem = entry.find(f"{ArxivPaperTool.ATOM_NAMESPACE}{element_name}")
|
||||||
|
return elem.text.strip() if elem is not None and elem.text else None
|
||||||
|
|
||||||
|
def _extract_pdf_url(self, entry: ET.Element) -> Optional[str]:
|
||||||
|
for link in entry.findall(self.ATOM_NAMESPACE + "link"):
|
||||||
|
if link.attrib.get("title", "").lower() == "pdf":
|
||||||
|
return link.attrib.get("href")
|
||||||
|
for link in entry.findall(self.ATOM_NAMESPACE + "link"):
|
||||||
|
href = link.attrib.get("href")
|
||||||
|
if href and "pdf" in href:
|
||||||
|
return href
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _format_paper_result(self, paper: dict) -> str:
|
||||||
|
summary = (
|
||||||
|
(paper["summary"][: self.SUMMARY_TRUNCATE_LENGTH] + "...")
|
||||||
|
if len(paper["summary"]) > self.SUMMARY_TRUNCATE_LENGTH
|
||||||
|
else paper["summary"]
|
||||||
|
)
|
||||||
|
authors_str = ", ".join(paper["authors"])
|
||||||
|
return (
|
||||||
|
f"Title: {paper['title']}\n"
|
||||||
|
f"Authors: {authors_str}\n"
|
||||||
|
f"Published: {paper['published_date']}\n"
|
||||||
|
f"PDF: {paper['pdf_url'] or 'N/A'}\n"
|
||||||
|
f"Summary: {summary}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_save_path(path: str) -> Path:
|
||||||
|
save_path = Path(path).resolve()
|
||||||
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
return save_path
|
||||||
|
|
||||||
|
def download_pdf(self, pdf_url: str, save_path: str):
|
||||||
|
try:
|
||||||
|
logger.info(f"Downloading PDF from {pdf_url} to {save_path}")
|
||||||
|
urllib.request.urlretrieve(pdf_url, str(save_path))
|
||||||
|
logger.info(f"PDF saved: {save_path}")
|
||||||
|
except urllib.error.URLError as e:
|
||||||
|
logger.error(f"Network error occurred while downloading {pdf_url}: {e}")
|
||||||
|
raise
|
||||||
|
except OSError as e:
|
||||||
|
logger.error(f"File save error for {save_path}: {e}")
|
||||||
|
raise
|
||||||
@@ -0,0 +1,130 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
import urllib.error
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
|
from crewai_tools import ArxivPaperTool
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tool():
|
||||||
|
return ArxivPaperTool(download_pdfs=False)
|
||||||
|
|
||||||
|
|
||||||
|
def mock_arxiv_response():
|
||||||
|
return """<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<feed xmlns="http://www.w3.org/2005/Atom">
|
||||||
|
<entry>
|
||||||
|
<id>http://arxiv.org/abs/1234.5678</id>
|
||||||
|
<title>Sample Paper</title>
|
||||||
|
<summary>This is a summary of the sample paper.</summary>
|
||||||
|
<published>2022-01-01T00:00:00Z</published>
|
||||||
|
<author><name>John Doe</name></author>
|
||||||
|
<link title="pdf" href="http://arxiv.org/pdf/1234.5678.pdf"/>
|
||||||
|
</entry>
|
||||||
|
</feed>"""
|
||||||
|
|
||||||
|
|
||||||
|
@patch("urllib.request.urlopen")
|
||||||
|
def test_fetch_arxiv_data(mock_urlopen, tool):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.read.return_value = mock_arxiv_response().encode("utf-8")
|
||||||
|
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||||
|
|
||||||
|
results = tool.fetch_arxiv_data("transformer", 1)
|
||||||
|
assert isinstance(results, list)
|
||||||
|
assert results[0]["title"] == "Sample Paper"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("urllib.request.urlopen", side_effect=urllib.error.URLError("Timeout"))
|
||||||
|
def test_fetch_arxiv_data_network_error(mock_urlopen, tool):
|
||||||
|
with pytest.raises(urllib.error.URLError):
|
||||||
|
tool.fetch_arxiv_data("transformer", 1)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("urllib.request.urlretrieve")
|
||||||
|
def test_download_pdf_success(mock_urlretrieve):
|
||||||
|
tool = ArxivPaperTool()
|
||||||
|
tool.download_pdf("http://arxiv.org/pdf/1234.5678.pdf", Path("test.pdf"))
|
||||||
|
mock_urlretrieve.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("urllib.request.urlretrieve", side_effect=OSError("Permission denied"))
|
||||||
|
def test_download_pdf_oserror(mock_urlretrieve):
|
||||||
|
tool = ArxivPaperTool()
|
||||||
|
with pytest.raises(OSError):
|
||||||
|
tool.download_pdf(
|
||||||
|
"http://arxiv.org/pdf/1234.5678.pdf", Path("/restricted/test.pdf")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("urllib.request.urlopen")
|
||||||
|
@patch("urllib.request.urlretrieve")
|
||||||
|
def test_run_with_download(mock_urlretrieve, mock_urlopen):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.read.return_value = mock_arxiv_response().encode("utf-8")
|
||||||
|
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||||
|
|
||||||
|
tool = ArxivPaperTool(download_pdfs=True)
|
||||||
|
output = tool._run("transformer", 1)
|
||||||
|
assert "Title: Sample Paper" in output
|
||||||
|
mock_urlretrieve.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("urllib.request.urlopen")
|
||||||
|
def test_run_no_download(mock_urlopen):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.read.return_value = mock_arxiv_response().encode("utf-8")
|
||||||
|
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||||
|
|
||||||
|
tool = ArxivPaperTool(download_pdfs=False)
|
||||||
|
result = tool._run("transformer", 1)
|
||||||
|
assert "Title: Sample Paper" in result
|
||||||
|
|
||||||
|
|
||||||
|
@patch("pathlib.Path.mkdir")
|
||||||
|
def test_validate_save_path_creates_directory(mock_mkdir):
|
||||||
|
path = ArxivPaperTool._validate_save_path("new_folder")
|
||||||
|
mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
|
||||||
|
assert isinstance(path, Path)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("urllib.request.urlopen")
|
||||||
|
def test_run_handles_exception(mock_urlopen):
|
||||||
|
mock_urlopen.side_effect = Exception("API failure")
|
||||||
|
tool = ArxivPaperTool()
|
||||||
|
result = tool._run("transformer", 1)
|
||||||
|
assert "Failed to fetch or download Arxiv papers" in result
|
||||||
|
|
||||||
|
|
||||||
|
@patch("urllib.request.urlopen")
|
||||||
|
def test_invalid_xml_response(mock_urlopen, tool):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.read.return_value = b"<invalid><xml>"
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||||
|
|
||||||
|
with pytest.raises(ET.ParseError):
|
||||||
|
tool.fetch_arxiv_data("quantum", 1)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.object(ArxivPaperTool, "fetch_arxiv_data")
|
||||||
|
def test_run_with_max_results(mock_fetch, tool):
|
||||||
|
mock_fetch.return_value = [
|
||||||
|
{
|
||||||
|
"arxiv_id": f"test_{i}",
|
||||||
|
"title": f"Title {i}",
|
||||||
|
"summary": "Summary",
|
||||||
|
"authors": ["Author"],
|
||||||
|
"published_date": "2023-01-01",
|
||||||
|
"pdf_url": None,
|
||||||
|
}
|
||||||
|
for i in range(100)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = tool._run(search_query="test", max_results=100)
|
||||||
|
assert result.count("Title:") == 100
|
||||||
@@ -0,0 +1,30 @@
# BraveSearchTool Documentation

## Description
This tool performs a web search for a specified query across the internet. It uses the Brave Web Search API, a REST API that queries Brave Search and returns web search results. The sections below describe how to construct requests to the Brave Web Search API, including parameters and headers, and how to work with the JSON response.

## Installation
To incorporate this tool into your project, follow the installation instructions below:
```shell
pip install 'crewai[tools]'
```

## Example
The following example demonstrates how to initialize the tool and execute a search with a given query:

```python
from crewai_tools import BraveSearchTool

# Initialize the tool for internet searching capabilities
tool = BraveSearchTool()
```
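
Once the `BRAVE_API_KEY` environment variable is set (see the steps below), the tool can be invoked directly with a query. The following is a minimal sketch; `search_query` matches the tool's input schema and `n_results` is the optional result-count attribute from the implementation:

```python
import os

from crewai_tools import BraveSearchTool

# Assumes BRAVE_API_KEY is already exported in the environment
assert "BRAVE_API_KEY" in os.environ

# Ask for up to 5 results and print the formatted text output
tool = BraveSearchTool(n_results=5)
results = tool.run(search_query="latest advances in retrieval augmented generation")
print(results)
```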

## Steps to Get Started
To effectively use the `BraveSearchTool`, follow these steps:

1. **Package Installation**: Confirm that the `crewai[tools]` package is installed in your Python environment.
2. **API Key Acquisition**: Acquire an API key [here](https://api.search.brave.com/app/keys).
3. **Environment Configuration**: Store your obtained API key in an environment variable named `BRAVE_API_KEY` to facilitate its use by the tool.

## Conclusion
By integrating the `BraveSearchTool` into Python projects, users gain the ability to conduct real-time, relevant searches across the internet directly from their applications. Following the setup and usage guidelines provided makes incorporating this tool into projects streamlined and straightforward.
@@ -0,0 +1,122 @@
|
|||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Any, ClassVar, List, Optional, Type
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool, EnvVar
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def _save_results_to_file(content: str) -> None:
|
||||||
|
"""Saves the search results to a file."""
|
||||||
|
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
|
||||||
|
with open(filename, "w") as file:
|
||||||
|
file.write(content)
|
||||||
|
print(f"Results saved to {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
class BraveSearchToolSchema(BaseModel):
|
||||||
|
"""Input for BraveSearchTool."""
|
||||||
|
|
||||||
|
search_query: str = Field(
|
||||||
|
..., description="Mandatory search query you want to use to search the internet"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BraveSearchTool(BaseTool):
|
||||||
|
"""
|
||||||
|
BraveSearchTool - A tool for performing web searches using the Brave Search API.
|
||||||
|
|
||||||
|
This module provides functionality to search the internet using Brave's Search API,
|
||||||
|
supporting customizable result counts and country-specific searches.
|
||||||
|
|
||||||
|
Dependencies:
|
||||||
|
- requests
|
||||||
|
- pydantic
|
||||||
|
- python-dotenv (for API key management)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "Brave Web Search the internet"
|
||||||
|
description: str = (
|
||||||
|
"A tool that can be used to search the internet with a search_query."
|
||||||
|
)
|
||||||
|
args_schema: Type[BaseModel] = BraveSearchToolSchema
|
||||||
|
search_url: str = "https://api.search.brave.com/res/v1/web/search"
|
||||||
|
country: Optional[str] = ""
|
||||||
|
n_results: int = 10
|
||||||
|
save_file: bool = False
|
||||||
|
_last_request_time: ClassVar[float] = 0
|
||||||
|
_min_request_interval: ClassVar[float] = 1.0 # seconds
|
||||||
|
env_vars: List[EnvVar] = [
|
||||||
|
EnvVar(
|
||||||
|
name="BRAVE_API_KEY", description="API key for Brave Search", required=True
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
if "BRAVE_API_KEY" not in os.environ:
|
||||||
|
raise ValueError(
|
||||||
|
"BRAVE_API_KEY environment variable is required for BraveSearchTool"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run(
|
||||||
|
self,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
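        # Simple class-level rate limiting: ensure consecutive requests are at least _min_request_interval seconds apart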
|
||||||
|
current_time = time.time()
|
||||||
|
if (current_time - self._last_request_time) < self._min_request_interval:
|
||||||
|
time.sleep(
|
||||||
|
self._min_request_interval - (current_time - self._last_request_time)
|
||||||
|
)
|
||||||
|
BraveSearchTool._last_request_time = time.time()
|
||||||
|
try:
|
||||||
|
search_query = kwargs.get("search_query") or kwargs.get("query")
|
||||||
|
if not search_query:
|
||||||
|
raise ValueError("Search query is required")
|
||||||
|
|
||||||
|
save_file = kwargs.get("save_file", self.save_file)
|
||||||
|
n_results = kwargs.get("n_results", self.n_results)
|
||||||
|
|
||||||
|
payload = {"q": search_query, "count": n_results}
|
||||||
|
|
||||||
|
if self.country != "":
|
||||||
|
payload["country"] = self.country
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"X-Subscription-Token": os.environ["BRAVE_API_KEY"],
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.get(self.search_url, headers=headers, params=payload)
|
||||||
|
response.raise_for_status() # Handle non-200 responses
|
||||||
|
results = response.json()
|
||||||
|
|
||||||
|
if "web" in results:
|
||||||
|
results = results["web"]["results"]
|
||||||
|
string = []
|
||||||
|
for result in results:
|
||||||
|
try:
|
||||||
|
string.append(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
f"Title: {result['title']}",
|
||||||
|
f"Link: {result['url']}",
|
||||||
|
f"Snippet: {result['description']}",
|
||||||
|
"---",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except KeyError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
content = "\n".join(string)
|
||||||
|
except requests.RequestException as e:
|
||||||
|
return f"Error performing search: {e!s}"
|
||||||
|
except KeyError as e:
|
||||||
|
return f"Error parsing search results: {e!s}"
|
||||||
|
if save_file:
|
||||||
|
_save_results_to_file(content)
|
||||||
|
return f"\nSearch results: {content}\n"
|
||||||
|
return content
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
# BrightData Tools Documentation

## Description

A comprehensive suite of CrewAI tools that leverage Bright Data's powerful infrastructure for web scraping, data extraction, and search operations. These tools provide three distinct capabilities:

- **BrightDataDatasetTool**: Extract structured data from popular data feeds (Amazon, LinkedIn, Instagram, etc.) using pre-built datasets
- **BrightDataSearchTool**: Perform web searches across multiple search engines with geo-targeting and device simulation
- **BrightDataWebUnlockerTool**: Scrape any website content while bypassing bot protection mechanisms

## Installation

To incorporate these tools into your project, follow the installation instructions below:

```shell
pip install 'crewai[tools]' aiohttp requests
```

## Examples

### Dataset Tool - Extract Amazon Product Data
```python
from crewai_tools import BrightDataDatasetTool

# Initialize with specific dataset and URL
tool = BrightDataDatasetTool(
    dataset_type="amazon_product",
    url="https://www.amazon.com/dp/B08QB1QMJ5/"
)
result = tool.run()
```

### Search Tool - Perform Web Search
```python
from crewai_tools import BrightDataSearchTool

# Initialize with search query
tool = BrightDataSearchTool(
    query="latest AI trends 2025",
    search_engine="google",
    country="us"
)
result = tool.run()
```

### Web Unlocker Tool - Scrape Website Content
```python
from crewai_tools import BrightDataWebUnlockerTool

# Initialize with target URL
tool = BrightDataWebUnlockerTool(
    url="https://example.com",
    data_format="markdown"
)
result = tool.run()
```
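
### Using a Tool with a CrewAI Agent

Beyond calling `tool.run()` directly, any of these tools can be handed to an agent. The following is a minimal sketch; it assumes the standard `crewai` `Agent`/`Task`/`Crew` API and that the environment variables described in the next section are already set:

```python
from crewai import Agent, Crew, Task
from crewai_tools import BrightDataWebUnlockerTool

# Give an agent the ability to scrape pages through Bright Data's Web Unlocker
scraper = BrightDataWebUnlockerTool(url="https://example.com", data_format="markdown")

researcher = Agent(
    role="Web Researcher",
    goal="Summarize the content of the target page",
    backstory="An analyst who reads pages and extracts the key points.",
    tools=[scraper],
)

task = Task(
    description="Scrape the configured page and summarize it in three bullet points.",
    expected_output="Three concise bullet points.",
    agent=researcher,
)

crew = Crew(agents=[researcher], tasks=[task])
result = crew.kickoff()
print(result)
```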

## Steps to Get Started

To effectively use the BrightData Tools, follow these steps:

1. **Package Installation**: Confirm that the `crewai[tools]` package is installed in your Python environment.

2. **API Key Acquisition**: Register for a Bright Data account at `https://brightdata.com/` and obtain your API credentials from your account settings.

3. **Environment Configuration**: Set up the required environment variables:
   ```bash
   export BRIGHT_DATA_API_KEY="your_api_key_here"
   export BRIGHT_DATA_ZONE="your_zone_here"
   ```

4. **Tool Selection**: Choose the appropriate tool based on your needs:
   - Use **DatasetTool** for structured data from supported platforms
   - Use **SearchTool** for web search operations
   - Use **WebUnlockerTool** for general website scraping

## Conclusion

By integrating BrightData Tools into your CrewAI agents, you gain access to enterprise-grade web scraping and data extraction capabilities. These tools handle complex challenges like bot protection, geo-restrictions, and data parsing, allowing you to focus on building your applications rather than managing scraping infrastructure.
@@ -0,0 +1,6 @@
|
|||||||
|
from .brightdata_dataset import BrightDataDatasetTool
from .brightdata_serp import BrightDataSearchTool
from .brightdata_unlocker import BrightDataWebUnlockerTool


__all__ = ["BrightDataDatasetTool", "BrightDataSearchTool", "BrightDataWebUnlockerTool"]
@@ -0,0 +1,595 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, Optional, Type
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class BrightDataConfig(BaseModel):
|
||||||
|
API_URL: str = "https://api.brightdata.com"
|
||||||
|
DEFAULT_TIMEOUT: int = 600
|
||||||
|
DEFAULT_POLLING_INTERVAL: int = 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls):
|
||||||
|
return cls(
|
||||||
|
API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com"),
|
||||||
|
DEFAULT_TIMEOUT=int(os.environ.get("BRIGHTDATA_DEFAULT_TIMEOUT", "600")),
|
||||||
|
DEFAULT_POLLING_INTERVAL=int(
|
||||||
|
os.environ.get("BRIGHTDATA_DEFAULT_POLLING_INTERVAL", "1")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BrightDataDatasetToolException(Exception):
|
||||||
|
"""Exception raised for custom error in the application."""
|
||||||
|
|
||||||
|
def __init__(self, message, error_code):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(message)
|
||||||
|
self.error_code = error_code
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.message} (Error Code: {self.error_code})"
|
||||||
|
|
||||||
|
|
||||||
|
class BrightDataDatasetToolSchema(BaseModel):
|
||||||
|
"""
|
||||||
|
Schema for validating input parameters for the BrightDataDatasetTool.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
dataset_type (str): Required Bright Data Dataset Type used to specify which dataset to access.
|
||||||
|
format (str): Response format (json by default). Multiple formats exist - json, ndjson, jsonl, csv
|
||||||
|
url (str): The URL from which structured data needs to be extracted.
|
||||||
|
zipcode (Optional[str]): An optional ZIP code to narrow down the data geographically.
|
||||||
|
additional_params (Optional[Dict]): Extra parameters for the Bright Data API call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_type: str = Field(..., description="The Bright Data Dataset Type")
|
||||||
|
format: Optional[str] = Field(
|
||||||
|
default="json", description="Response format (json by default)"
|
||||||
|
)
|
||||||
|
url: str = Field(..., description="The URL to extract data from")
|
||||||
|
zipcode: Optional[str] = Field(default=None, description="Optional zipcode")
|
||||||
|
additional_params: Optional[Dict[str, Any]] = Field(
|
||||||
|
default=None, description="Additional params if any"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
config = BrightDataConfig.from_env()
|
||||||
|
|
||||||
|
BRIGHTDATA_API_URL = config.API_URL
|
||||||
|
timeout = config.DEFAULT_TIMEOUT
|
||||||
|
|
||||||
|
datasets = [
|
||||||
|
{
|
||||||
|
"id": "amazon_product",
|
||||||
|
"dataset_id": "gd_l7q7dkf244hwjntr0",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured amazon product data.",
|
||||||
|
"Requires a valid product URL with /dp/ in it.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "amazon_product_reviews",
|
||||||
|
"dataset_id": "gd_le8e811kzy4ggddlq",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured amazon product review data.",
|
||||||
|
"Requires a valid product URL with /dp/ in it.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "amazon_product_search",
|
||||||
|
"dataset_id": "gd_lwdb4vjm1ehb499uxs",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured amazon product search data.",
|
||||||
|
"Requires a valid search keyword and amazon domain URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["keyword", "url", "pages_to_search"],
|
||||||
|
"defaults": {"pages_to_search": "1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "walmart_product",
|
||||||
|
"dataset_id": "gd_l95fol7l1ru6rlo116",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured walmart product data.",
|
||||||
|
"Requires a valid product URL with /ip/ in it.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "walmart_seller",
|
||||||
|
"dataset_id": "gd_m7ke48w81ocyu4hhz0",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured walmart seller data.",
|
||||||
|
"Requires a valid walmart seller URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "ebay_product",
|
||||||
|
"dataset_id": "gd_ltr9mjt81n0zzdk1fb",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured ebay product data.",
|
||||||
|
"Requires a valid ebay product URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "homedepot_products",
|
||||||
|
"dataset_id": "gd_lmusivh019i7g97q2n",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured homedepot product data.",
|
||||||
|
"Requires a valid homedepot product URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "zara_products",
|
||||||
|
"dataset_id": "gd_lct4vafw1tgx27d4o0",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured zara product data.",
|
||||||
|
"Requires a valid zara product URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "etsy_products",
|
||||||
|
"dataset_id": "gd_ltppk0jdv1jqz25mz",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured etsy product data.",
|
||||||
|
"Requires a valid etsy product URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "bestbuy_products",
|
||||||
|
"dataset_id": "gd_ltre1jqe1jfr7cccf",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured bestbuy product data.",
|
||||||
|
"Requires a valid bestbuy product URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "linkedin_person_profile",
|
||||||
|
"dataset_id": "gd_l1viktl72bvl7bjuj0",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured linkedin people profile data.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "linkedin_company_profile",
|
||||||
|
"dataset_id": "gd_l1vikfnt1wgvvqz95w",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured linkedin company profile data",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "linkedin_job_listings",
|
||||||
|
"dataset_id": "gd_lpfll7v5hcqtkxl6l",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured linkedin job listings data",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "linkedin_posts",
|
||||||
|
"dataset_id": "gd_lyy3tktm25m4avu764",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured linkedin posts data",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "linkedin_people_search",
|
||||||
|
"dataset_id": "gd_m8d03he47z8nwb5xc",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured linkedin people search data",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url", "first_name", "last_name"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "crunchbase_company",
|
||||||
|
"dataset_id": "gd_l1vijqt9jfj7olije",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured crunchbase company data",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "zoominfo_company_profile",
|
||||||
|
"dataset_id": "gd_m0ci4a4ivx3j5l6nx",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured ZoomInfo company profile data.",
|
||||||
|
"Requires a valid ZoomInfo company URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "instagram_profiles",
|
||||||
|
"dataset_id": "gd_l1vikfch901nx3by4",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured Instagram profile data.",
|
||||||
|
"Requires a valid Instagram URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "instagram_posts",
|
||||||
|
"dataset_id": "gd_lk5ns7kz21pck8jpis",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured Instagram post data.",
|
||||||
|
"Requires a valid Instagram URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "instagram_reels",
|
||||||
|
"dataset_id": "gd_lyclm20il4r5helnj",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured Instagram reel data.",
|
||||||
|
"Requires a valid Instagram URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "instagram_comments",
|
||||||
|
"dataset_id": "gd_ltppn085pokosxh13",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured Instagram comments data.",
|
||||||
|
"Requires a valid Instagram URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "facebook_posts",
|
||||||
|
"dataset_id": "gd_lyclm1571iy3mv57zw",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured Facebook post data.",
|
||||||
|
"Requires a valid Facebook post URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "facebook_marketplace_listings",
|
||||||
|
"dataset_id": "gd_lvt9iwuh6fbcwmx1a",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured Facebook marketplace listing data.",
|
||||||
|
"Requires a valid Facebook marketplace listing URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "facebook_company_reviews",
|
||||||
|
"dataset_id": "gd_m0dtqpiu1mbcyc2g86",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured Facebook company reviews data.",
|
||||||
|
"Requires a valid Facebook company URL and number of reviews.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url", "num_of_reviews"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "facebook_events",
|
||||||
|
"dataset_id": "gd_m14sd0to1jz48ppm51",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured Facebook events data.",
|
||||||
|
"Requires a valid Facebook event URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "tiktok_profiles",
|
||||||
|
"dataset_id": "gd_l1villgoiiidt09ci",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured Tiktok profiles data.",
|
||||||
|
"Requires a valid Tiktok profile URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "tiktok_posts",
|
||||||
|
"dataset_id": "gd_lu702nij2f790tmv9h",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured Tiktok post data.",
|
||||||
|
"Requires a valid Tiktok post URL.",
|
||||||
|
"This can be a cache lookup, so it can be more reliable than scraping",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "tiktok_shop",
|
||||||
|
"dataset_id": "gd_m45m1u911dsa4274pi",
|
||||||
|
"description": "\n".join(
|
||||||
|
[
|
||||||
|
"Quickly read structured Tiktok shop data.",
|
||||||
|
"Requires a valid Tiktok shop product URL.",
|
||||||
|
"This can be a cache lookup...",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"inputs": ["url"],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BrightDataDatasetTool(BaseTool):
|
||||||
|
"""
|
||||||
|
CrewAI-compatible tool for scraping structured data using Bright Data Datasets.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name (str): Tool name displayed in the CrewAI environment.
|
||||||
|
description (str): Tool description shown to agents or users.
|
||||||
|
args_schema (Type[BaseModel]): Pydantic schema for validating input arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "Bright Data Dataset Tool"
|
||||||
|
description: str = "Scrapes structured data using Bright Data Dataset API from a URL and optional input parameters"
|
||||||
|
args_schema: Type[BaseModel] = BrightDataDatasetToolSchema
|
||||||
|
dataset_type: Optional[str] = None
|
||||||
|
url: Optional[str] = None
|
||||||
|
format: str = "json"
|
||||||
|
zipcode: Optional[str] = None
|
||||||
|
additional_params: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_type: str = None,
|
||||||
|
url: str = None,
|
||||||
|
format: str = "json",
|
||||||
|
zipcode: str = None,
|
||||||
|
additional_params: Dict[str, Any] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dataset_type = dataset_type
|
||||||
|
self.url = url
|
||||||
|
self.format = format
|
||||||
|
self.zipcode = zipcode
|
||||||
|
self.additional_params = additional_params
|
||||||
|
|
||||||
|
def filter_dataset_by_id(self, target_id):
|
||||||
|
return [dataset for dataset in datasets if dataset["id"] == target_id]
|
||||||
|
|
||||||
|
async def get_dataset_data_async(
|
||||||
|
self,
|
||||||
|
dataset_type: str,
|
||||||
|
output_format: str,
|
||||||
|
url: str,
|
||||||
|
zipcode: Optional[str] = None,
|
||||||
|
additional_params: Optional[Dict[str, Any]] = None,
|
||||||
|
polling_interval: int = 1,
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
Asynchronously trigger and poll Bright Data dataset scraping.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_type (str): Bright Data Dataset Type.
|
||||||
|
url (str): Target URL to scrape.
|
||||||
|
zipcode (Optional[str]): Optional ZIP code for geo-specific data.
|
||||||
|
additional_params (Optional[Dict]): Extra API parameters.
|
||||||
|
polling_interval (int): Time interval in seconds between polling attempts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Structured dataset result from Bright Data.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If any API step fails or the job fails.
|
||||||
|
TimeoutError: If polling times out before job completion.
|
||||||
|
"""
|
||||||
|
request_data = {"url": url}
|
||||||
|
if zipcode is not None:
|
||||||
|
request_data["zipcode"] = zipcode
|
||||||
|
|
||||||
|
# Set additional parameters dynamically depending upon the dataset that is being requested
|
||||||
|
if additional_params:
|
||||||
|
request_data.update(additional_params)
|
||||||
|
|
||||||
|
api_key = os.getenv("BRIGHT_DATA_API_KEY")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset_id = ""
|
||||||
|
dataset = self.filter_dataset_by_id(dataset_type)
|
||||||
|
|
||||||
|
if len(dataset) == 1:
|
||||||
|
dataset_id = dataset[0]["dataset_id"]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unable to find the dataset for {dataset_type}. Please make sure to pass a valid one"
|
||||||
|
)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Step 1: Trigger job
|
||||||
|
async with session.post(
|
||||||
|
f"{BRIGHTDATA_API_URL}/datasets/v3/trigger",
|
||||||
|
params={"dataset_id": dataset_id, "include_errors": "true"},
|
||||||
|
json=[request_data],
|
||||||
|
headers=headers,
|
||||||
|
) as trigger_response:
|
||||||
|
if trigger_response.status != 200:
|
||||||
|
raise BrightDataDatasetToolException(
|
||||||
|
f"Trigger failed: {await trigger_response.text()}",
|
||||||
|
trigger_response.status,
|
||||||
|
)
|
||||||
|
trigger_data = await trigger_response.json()
|
||||||
|
print(trigger_data)
|
||||||
|
snapshot_id = trigger_data.get("snapshot_id")
|
||||||
|
|
||||||
|
# Step 2: Poll for completion
|
||||||
|
elapsed = 0
|
||||||
|
while elapsed < timeout:
|
||||||
|
await asyncio.sleep(polling_interval)
|
||||||
|
elapsed += polling_interval
|
||||||
|
|
||||||
|
async with session.get(
|
||||||
|
f"{BRIGHTDATA_API_URL}/datasets/v3/progress/{snapshot_id}",
|
||||||
|
headers=headers,
|
||||||
|
) as status_response:
|
||||||
|
if status_response.status != 200:
|
||||||
|
raise BrightDataDatasetToolException(
|
||||||
|
f"Status check failed: {await status_response.text()}",
|
||||||
|
status_response.status,
|
||||||
|
)
|
||||||
|
status_data = await status_response.json()
|
||||||
|
if status_data.get("status") == "ready":
|
||||||
|
print("Job is ready")
|
||||||
|
break
|
||||||
|
if status_data.get("status") == "error":
|
||||||
|
raise BrightDataDatasetToolException(
|
||||||
|
f"Job failed: {status_data}", 0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise TimeoutError("Polling timed out before job completed.")
|
||||||
|
|
||||||
|
# Step 3: Retrieve result
|
||||||
|
async with session.get(
|
||||||
|
f"{BRIGHTDATA_API_URL}/datasets/v3/snapshot/{snapshot_id}",
|
||||||
|
params={"format": output_format},
|
||||||
|
headers=headers,
|
||||||
|
) as snapshot_response:
|
||||||
|
if snapshot_response.status != 200:
|
||||||
|
raise BrightDataDatasetToolException(
|
||||||
|
f"Result fetch failed: {await snapshot_response.text()}",
|
||||||
|
snapshot_response.status,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await snapshot_response.text()
|
||||||
|
|
||||||
|
def _run(
|
||||||
|
self,
|
||||||
|
url: str = None,
|
||||||
|
dataset_type: str = None,
|
||||||
|
format: str = None,
|
||||||
|
zipcode: str = None,
|
||||||
|
additional_params: Dict[str, Any] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
dataset_type = dataset_type or self.dataset_type
|
||||||
|
output_format = format or self.format
|
||||||
|
url = url or self.url
|
||||||
|
zipcode = zipcode or self.zipcode
|
||||||
|
additional_params = additional_params or self.additional_params
|
||||||
|
|
||||||
|
if not dataset_type:
|
||||||
|
raise ValueError(
|
||||||
|
"dataset_type is required either in constructor or method call"
|
||||||
|
)
|
||||||
|
if not url:
|
||||||
|
raise ValueError("url is required either in constructor or method call")
|
||||||
|
|
||||||
|
valid_output_formats = {"json", "ndjson", "jsonl", "csv"}
|
||||||
|
if output_format not in valid_output_formats:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported output format: {output_format}. Must be one of {', '.join(valid_output_formats)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key = os.getenv("BRIGHT_DATA_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("BRIGHT_DATA_API_KEY environment variable is required.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return asyncio.run(
|
||||||
|
self.get_dataset_data_async(
|
||||||
|
dataset_type=dataset_type,
|
||||||
|
output_format=output_format,
|
||||||
|
url=url,
|
||||||
|
zipcode=zipcode,
|
||||||
|
additional_params=additional_params,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except TimeoutError as e:
|
||||||
|
return f"Timeout Exception occured in method : get_dataset_data_async. Details - {e!s}"
|
||||||
|
except BrightDataDatasetToolException as e:
|
||||||
|
return (
|
||||||
|
f"Exception occured in method : get_dataset_data_async. Details - {e!s}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Bright Data API error: {e!s}"
|
||||||
@@ -0,0 +1,232 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Optional, Type
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class BrightDataConfig(BaseModel):
|
||||||
|
API_URL: str = "https://api.brightdata.com/request"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls):
|
||||||
|
return cls(
|
||||||
|
API_URL=os.environ.get(
|
||||||
|
"BRIGHTDATA_API_URL", "https://api.brightdata.com/request"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BrightDataSearchToolSchema(BaseModel):
|
||||||
|
"""
|
||||||
|
Schema that defines the input arguments for the BrightDataSearchToolSchema.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
query (str): The search query to be executed (e.g., "latest AI news").
|
||||||
|
search_engine (Optional[str]): The search engine to use ("google", "bing", "yandex"). Default is "google".
|
||||||
|
country (Optional[str]): Two-letter country code for geo-targeting (e.g., "us", "in"). Default is "us".
|
||||||
|
language (Optional[str]): Language code for search results (e.g., "en", "es"). Default is "en".
|
||||||
|
search_type (Optional[str]): Type of search, such as "isch" (images), "nws" (news), "jobs", etc.
|
||||||
|
device_type (Optional[str]): Device type to simulate ("desktop", "mobile", "ios", "android"). Default is "desktop".
|
||||||
|
parse_results (Optional[bool]): If True, results will be returned in structured JSON. If False, raw HTML. Default is True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
query: str = Field(..., description="Search query to perform")
|
||||||
|
search_engine: Optional[str] = Field(
|
||||||
|
default="google",
|
||||||
|
description="Search engine domain (e.g., 'google', 'bing', 'yandex')",
|
||||||
|
)
|
||||||
|
country: Optional[str] = Field(
|
||||||
|
default="us",
|
||||||
|
description="Two-letter country code for geo-targeting (e.g., 'us', 'gb')",
|
||||||
|
)
|
||||||
|
language: Optional[str] = Field(
|
||||||
|
default="en",
|
||||||
|
description="Language code (e.g., 'en', 'es') used in the query URL",
|
||||||
|
)
|
||||||
|
search_type: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Type of search (e.g., 'isch' for images, 'nws' for news)",
|
||||||
|
)
|
||||||
|
device_type: Optional[str] = Field(
|
||||||
|
default="desktop",
|
||||||
|
description="Device type to simulate (e.g., 'mobile', 'desktop', 'ios')",
|
||||||
|
)
|
||||||
|
parse_results: Optional[bool] = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether to parse and return JSON (True) or raw HTML/text (False)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BrightDataSearchTool(BaseTool):
|
||||||
|
"""
|
||||||
|
A web search tool that utilizes Bright Data's SERP API to perform queries and return either structured results
|
||||||
|
or raw page content from search engines like Google or Bing.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name (str): Tool name used by the agent.
|
||||||
|
description (str): A brief explanation of what the tool does.
|
||||||
|
args_schema (Type[BaseModel]): Schema class for validating tool arguments.
|
||||||
|
base_url (str): The Bright Data API endpoint used for making the POST request.
|
||||||
|
api_key (str): Bright Data API key loaded from the environment variable 'BRIGHT_DATA_API_KEY'.
|
||||||
|
zone (str): Zone identifier from Bright Data, loaded from the environment variable 'BRIGHT_DATA_ZONE'.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If API key or zone environment variables are not set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "Bright Data SERP Search"
|
||||||
|
description: str = "Tool to perform web search using Bright Data SERP API."
|
||||||
|
args_schema: Type[BaseModel] = BrightDataSearchToolSchema
|
||||||
|
_config = BrightDataConfig.from_env()
|
||||||
|
base_url: str = ""
|
||||||
|
api_key: str = ""
|
||||||
|
zone: str = ""
|
||||||
|
query: Optional[str] = None
|
||||||
|
search_engine: str = "google"
|
||||||
|
country: str = "us"
|
||||||
|
language: str = "en"
|
||||||
|
search_type: Optional[str] = None
|
||||||
|
device_type: str = "desktop"
|
||||||
|
parse_results: bool = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
query: str = None,
|
||||||
|
search_engine: str = "google",
|
||||||
|
country: str = "us",
|
||||||
|
language: str = "en",
|
||||||
|
search_type: str = None,
|
||||||
|
device_type: str = "desktop",
|
||||||
|
parse_results: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.base_url = self._config.API_URL
|
||||||
|
self.query = query
|
||||||
|
self.search_engine = search_engine
|
||||||
|
self.country = country
|
||||||
|
self.language = language
|
||||||
|
self.search_type = search_type
|
||||||
|
self.device_type = device_type
|
||||||
|
self.parse_results = parse_results
|
||||||
|
|
||||||
|
self.api_key = os.getenv("BRIGHT_DATA_API_KEY")
|
||||||
|
self.zone = os.getenv("BRIGHT_DATA_ZONE")
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError("BRIGHT_DATA_API_KEY environment variable is required.")
|
||||||
|
if not self.zone:
|
||||||
|
raise ValueError("BRIGHT_DATA_ZONE environment variable is required.")
|
||||||
|
|
||||||
|
def get_search_url(self, engine: str, query: str):
|
||||||
|
if engine == "yandex":
|
||||||
|
return f"https://yandex.com/search/?text=${query}"
|
||||||
|
if engine == "bing":
|
||||||
|
return f"https://www.bing.com/search?q=${query}"
|
||||||
|
return f"https://www.google.com/search?q=${query}"
|
||||||
|
|
||||||
|
def _run(
|
||||||
|
self,
|
||||||
|
query: str = None,
|
||||||
|
search_engine: str = None,
|
||||||
|
country: str = None,
|
||||||
|
language: str = None,
|
||||||
|
search_type: str = None,
|
||||||
|
device_type: str = None,
|
||||||
|
parse_results: bool = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Executes a search query using Bright Data SERP API and returns results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The search query string (URL encoded internally).
|
||||||
|
search_engine (str): The search engine to use (default: "google").
|
||||||
|
country (str): Country code for geotargeting (default: "us").
|
||||||
|
language (str): Language code for the query (default: "en").
|
||||||
|
search_type (str): Optional type of search such as "nws", "isch", "jobs".
|
||||||
|
device_type (str): Optional device type to simulate (e.g., "mobile", "ios", "desktop").
|
||||||
|
parse_results (bool): If True, returns structured data; else raw page (default: True).
|
||||||
|
results_count (str or int): Number of search results to fetch (default: "10").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict or str: Parsed JSON data from Bright Data if available, otherwise error message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
query = query or self.query
|
||||||
|
search_engine = search_engine or self.search_engine
|
||||||
|
country = country or self.country
|
||||||
|
language = language or self.language
|
||||||
|
search_type = search_type or self.search_type
|
||||||
|
device_type = device_type or self.device_type
|
||||||
|
parse_results = (
|
||||||
|
parse_results if parse_results is not None else self.parse_results
|
||||||
|
)
|
||||||
|
results_count = kwargs.get("results_count", "10")
|
||||||
|
|
||||||
|
# Validate required parameters
|
||||||
|
if not query:
|
||||||
|
raise ValueError("query is required either in constructor or method call")
|
||||||
|
|
||||||
|
# Build the search URL
|
||||||
|
query = urllib.parse.quote(query)
|
||||||
|
url = self.get_search_url(search_engine, query)
|
||||||
|
|
||||||
|
# Add parameters to the URL
|
||||||
|
params = []
|
||||||
|
|
||||||
|
if country:
|
||||||
|
params.append(f"gl={country}")
|
||||||
|
|
||||||
|
if language:
|
||||||
|
params.append(f"hl={language}")
|
||||||
|
|
||||||
|
if results_count:
|
||||||
|
params.append(f"num={results_count}")
|
||||||
|
|
||||||
|
if parse_results:
|
||||||
|
params.append("brd_json=1")
|
||||||
|
|
||||||
|
if search_type:
|
||||||
|
if search_type == "jobs":
|
||||||
|
params.append("ibp=htl;jobs")
|
||||||
|
else:
|
||||||
|
params.append(f"tbm={search_type}")
|
||||||
|
|
||||||
|
if device_type:
|
||||||
|
if device_type == "mobile":
|
||||||
|
params.append("brd_mobile=1")
|
||||||
|
elif device_type == "ios":
|
||||||
|
params.append("brd_mobile=ios")
|
||||||
|
elif device_type == "android":
|
||||||
|
params.append("brd_mobile=android")
|
||||||
|
|
||||||
|
# Combine parameters with the URL
|
||||||
|
if params:
|
||||||
|
url += "&" + "&".join(params)
|
||||||
|
|
||||||
|
# Set up the API request parameters
|
||||||
|
request_params = {"zone": self.zone, "url": url, "format": "raw"}
|
||||||
|
|
||||||
|
request_params = {k: v for k, v in request_params.items() if v is not None}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url, json=request_params, headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Status code: {response.status_code}")
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
return response.text
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
return f"Error performing BrightData search: {e!s}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error fetching results: {e!s}"
|
||||||
@@ -0,0 +1,134 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Optional, Type
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class BrightDataConfig(BaseModel):
|
||||||
|
API_URL: str = "https://api.brightdata.com/request"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls):
|
||||||
|
return cls(
|
||||||
|
API_URL=os.environ.get(
|
||||||
|
"BRIGHTDATA_API_URL", "https://api.brightdata.com/request"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BrightDataUnlockerToolSchema(BaseModel):
|
||||||
|
"""
|
||||||
|
Pydantic schema for input parameters used by the BrightDataWebUnlockerTool.
|
||||||
|
|
||||||
|
This schema defines the structure and validation for parameters passed when performing
|
||||||
|
a web scraping request using Bright Data's Web Unlocker.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
url (str): The target URL to scrape.
|
||||||
|
format (Optional[str]): Format of the response returned by Bright Data. Default 'raw' format.
|
||||||
|
data_format (Optional[str]): Response data format (html by default). markdown is one more option.
|
||||||
|
"""
|
||||||
|
|
||||||
|
url: str = Field(..., description="URL to perform the web scraping")
|
||||||
|
format: Optional[str] = Field(
|
||||||
|
default="raw", description="Response format (raw is standard)"
|
||||||
|
)
|
||||||
|
data_format: Optional[str] = Field(
|
||||||
|
default="markdown", description="Response data format (html by default)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BrightDataWebUnlockerTool(BaseTool):
|
||||||
|
"""
|
||||||
|
A tool for performing web scraping using the Bright Data Web Unlocker API.
|
||||||
|
|
||||||
|
This tool allows automated and programmatic access to web pages by routing requests
|
||||||
|
through Bright Data's unlocking and proxy infrastructure, which can bypass bot
|
||||||
|
protection mechanisms like CAPTCHA, geo-restrictions, and anti-bot detection.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name (str): Name of the tool.
|
||||||
|
description (str): Description of what the tool does.
|
||||||
|
args_schema (Type[BaseModel]): Pydantic model schema for expected input arguments.
|
||||||
|
base_url (str): Base URL of the Bright Data Web Unlocker API.
|
||||||
|
api_key (str): Bright Data API key (must be set in the BRIGHT_DATA_API_KEY environment variable).
|
||||||
|
zone (str): Bright Data zone identifier (must be set in the BRIGHT_DATA_ZONE environment variable).
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
_run(**kwargs: Any) -> Any:
|
||||||
|
Sends a scraping request to Bright Data's Web Unlocker API and returns the result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "Bright Data Web Unlocker Scraping"
|
||||||
|
description: str = "Tool to perform web scraping using Bright Data Web Unlocker"
|
||||||
|
args_schema: Type[BaseModel] = BrightDataUnlockerToolSchema
|
||||||
|
_config = BrightDataConfig.from_env()
|
||||||
|
base_url: str = ""
|
||||||
|
api_key: str = ""
|
||||||
|
zone: str = ""
|
||||||
|
url: Optional[str] = None
|
||||||
|
format: str = "raw"
|
||||||
|
data_format: str = "markdown"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, url: str = None, format: str = "raw", data_format: str = "markdown"
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.base_url = self._config.API_URL
|
||||||
|
self.url = url
|
||||||
|
self.format = format
|
||||||
|
self.data_format = data_format
|
||||||
|
|
||||||
|
self.api_key = os.getenv("BRIGHT_DATA_API_KEY")
|
||||||
|
self.zone = os.getenv("BRIGHT_DATA_ZONE")
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError("BRIGHT_DATA_API_KEY environment variable is required.")
|
||||||
|
if not self.zone:
|
||||||
|
raise ValueError("BRIGHT_DATA_ZONE environment variable is required.")
|
||||||
|
|
||||||
|
def _run(
|
||||||
|
self,
|
||||||
|
url: str = None,
|
||||||
|
format: str = None,
|
||||||
|
data_format: str = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
url = url or self.url
|
||||||
|
format = format or self.format
|
||||||
|
data_format = data_format or self.data_format
|
||||||
|
|
||||||
|
if not url:
|
||||||
|
raise ValueError("url is required either in constructor or method call")
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"url": url,
|
||||||
|
"zone": self.zone,
|
||||||
|
"format": format,
|
||||||
|
}
|
||||||
|
valid_data_formats = {"html", "markdown"}
|
||||||
|
if data_format not in valid_data_formats:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported data format: {data_format}. Must be one of {', '.join(valid_data_formats)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if data_format == "markdown":
|
||||||
|
payload["data_format"] = "markdown"
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(self.base_url, json=payload, headers=headers)
|
||||||
|
print(f"Status Code: {response.status_code}")
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
return response.text
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
return f"HTTP Error performing BrightData Web Unlocker Scrape: {e}\nResponse: {getattr(e.response, 'text', '')}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error fetching results: {e!s}"
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
# BrowserbaseLoadTool

## Description

[Browserbase](https://browserbase.com) is a developer platform to reliably run, manage, and monitor headless browsers.

Power your AI data retrievals with:
- [Serverless Infrastructure](https://docs.browserbase.com/under-the-hood) providing reliable browsers to extract data from complex UIs
- [Stealth Mode](https://docs.browserbase.com/features/stealth-mode) with included fingerprinting tactics and automatic captcha solving
- [Session Debugger](https://docs.browserbase.com/features/sessions) to inspect your Browser Session with networks timeline and logs
- [Live Debug](https://docs.browserbase.com/guides/session-debug-connection/browser-remote-control) to quickly debug your automation

## Installation

- Get an API key and Project ID from [browserbase.com](https://browserbase.com) and set them in environment variables (`BROWSERBASE_API_KEY`, `BROWSERBASE_PROJECT_ID`).
- Install the [Browserbase SDK](http://github.com/browserbase/python-sdk) along with the `crewai[tools]` package:

```
pip install browserbase 'crewai[tools]'
```

## Example

Utilize the BrowserbaseLoadTool as follows to allow your agent to load websites:

```python
from crewai_tools import BrowserbaseLoadTool

tool = BrowserbaseLoadTool()
```

## Arguments

- `api_key` Optional. Browserbase API key. Default is `BROWSERBASE_API_KEY` env variable.
- `project_id` Optional. Browserbase Project ID. Default is `BROWSERBASE_PROJECT_ID` env variable.
- `text_content` Retrieve only text content. Default is `False`.
- `session_id` Optional. Provide an existing Session ID.
- `proxy` Optional. Enable/Disable Proxies.
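
For example, constructor arguments can be combined with a run-time URL. This is a minimal sketch; `url` is the only run-time input defined by the tool's schema, and the environment variables above are assumed to be set:

```python
from crewai_tools import BrowserbaseLoadTool

# Assumes BROWSERBASE_API_KEY and BROWSERBASE_PROJECT_ID are set in the environment
tool = BrowserbaseLoadTool(text_content=True)

# Load a page and return only its text content
content = tool.run(url="https://example.com")
print(content)
```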
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, List, Optional, Type
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool, EnvVar
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class BrowserbaseLoadToolSchema(BaseModel):
|
||||||
|
url: str = Field(description="Website URL")
|
||||||
|
|
||||||
|
|
||||||
|
class BrowserbaseLoadTool(BaseTool):
|
||||||
|
name: str = "Browserbase web load tool"
|
||||||
|
description: str = "Load webpages url in a headless browser using Browserbase and return the contents"
|
||||||
|
args_schema: Type[BaseModel] = BrowserbaseLoadToolSchema
|
||||||
|
api_key: Optional[str] = os.getenv("BROWSERBASE_API_KEY")
|
||||||
|
project_id: Optional[str] = os.getenv("BROWSERBASE_PROJECT_ID")
|
||||||
|
text_content: Optional[bool] = False
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
proxy: Optional[bool] = None
|
||||||
|
browserbase: Optional[Any] = None
|
||||||
|
package_dependencies: List[str] = ["browserbase"]
|
||||||
|
env_vars: List[EnvVar] = [
|
||||||
|
EnvVar(
|
||||||
|
name="BROWSERBASE_API_KEY",
|
||||||
|
description="API key for Browserbase services",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
EnvVar(
|
||||||
|
name="BROWSERBASE_PROJECT_ID",
|
||||||
|
description="Project ID for Browserbase services",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
project_id: Optional[str] = None,
|
||||||
|
text_content: Optional[bool] = False,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
proxy: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
if not self.api_key:
|
||||||
|
raise EnvironmentError(
|
||||||
|
"BROWSERBASE_API_KEY environment variable is required for initialization"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from browserbase import Browserbase # type: ignore
|
||||||
|
except ImportError:
|
||||||
|
import click
|
||||||
|
|
||||||
|
if click.confirm(
|
||||||
|
"`browserbase` package not found, would you like to install it?"
|
||||||
|
):
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
subprocess.run(["uv", "add", "browserbase"], check=True)
|
||||||
|
from browserbase import Browserbase # type: ignore
|
||||||
|
else:
|
||||||
|
raise ImportError(
|
||||||
|
"`browserbase` package not found, please run `uv add browserbase`"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.browserbase = Browserbase(api_key=self.api_key)
|
||||||
|
self.text_content = text_content
|
||||||
|
self.session_id = session_id
|
||||||
|
self.proxy = proxy
|
||||||
|
|
||||||
|
def _run(self, url: str):
|
||||||
|
return self.browserbase.load_url(
|
||||||
|
url, self.text_content, self.session_id, self.proxy
|
||||||
|
)
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
# CodeDocsSearchTool

## Description
The CodeDocsSearchTool is a powerful RAG (Retrieval-Augmented Generation) tool designed for semantic searches within code documentation. It enables users to efficiently find specific information or topics within code documentation. By providing a `docs_url` during initialization, the tool narrows down the search to that particular documentation site. Alternatively, without a specific `docs_url`, it searches across a wide array of code documentation known or discovered throughout its execution, making it versatile for various documentation search needs.

## Installation
To start using the CodeDocsSearchTool, first install the crewai_tools package via pip:
```shell
pip install 'crewai[tools]'
```

## Example
Utilize the CodeDocsSearchTool as follows to conduct searches within code documentation:
```python
from crewai_tools import CodeDocsSearchTool

# To search any code documentation content if the URL is known or discovered during its execution:
tool = CodeDocsSearchTool()

# OR

# To specifically focus your search on a given documentation site by providing its URL:
tool = CodeDocsSearchTool(docs_url='https://docs.example.com/reference')
```
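
A query can then be run against the ingested documentation. The following is a minimal sketch; `search_query` (and the optional `docs_url`) mirror the arguments of the tool's `_run` method shown in the implementation below:

```python
from crewai_tools import CodeDocsSearchTool

# Focus the search on one documentation site, then ask a question against it
tool = CodeDocsSearchTool(docs_url='https://docs.example.com/reference')
result = tool.run(search_query='How to use search tool')
print(result)
```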

Note: Substitute 'https://docs.example.com/reference' with your target documentation URL and 'How to use search tool' with the search query relevant to your needs.

## Arguments
- `docs_url`: Optional. Specifies the URL of the code documentation to be searched. Providing this during the tool's initialization focuses the search on the specified documentation content.

## Custom model and embeddings

By default, the tool uses OpenAI for both embeddings and summarization. To customize the model, you can use a config dictionary as follows:

```python
tool = CodeDocsSearchTool(
    config=dict(
        llm=dict(
            provider="ollama", # or google, openai, anthropic, llama2, ...
            config=dict(
                model="llama2",
                # temperature=0.5,
                # top_p=1,
                # stream=true,
            ),
        ),
        embedder=dict(
            provider="google",
            config=dict(
                model="models/embedding-001",
                task_type="retrieval_document",
                # title="Embeddings",
            ),
        ),
    )
)
```
@@ -0,0 +1,53 @@
|
|||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
|
from crewai_tools.rag.data_types import DataType
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
|
class FixedCodeDocsSearchToolSchema(BaseModel):
|
||||||
|
"""Input for CodeDocsSearchTool."""
|
||||||
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the Code Docs content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CodeDocsSearchToolSchema(FixedCodeDocsSearchToolSchema):
|
||||||
|
"""Input for CodeDocsSearchTool."""
|
||||||
|
|
||||||
|
docs_url: str = Field(..., description="Mandatory docs_url path you want to search")
|
||||||
|
|
||||||
|
|
||||||
|
class CodeDocsSearchTool(RagTool):
|
||||||
|
name: str = "Search a Code Docs content"
|
||||||
|
description: str = (
|
||||||
|
"A tool that can be used to semantic search a query from a Code Docs content."
|
||||||
|
)
|
||||||
|
args_schema: Type[BaseModel] = CodeDocsSearchToolSchema
|
||||||
|
|
||||||
|
def __init__(self, docs_url: Optional[str] = None, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
if docs_url is not None:
|
||||||
|
self.add(docs_url)
|
||||||
|
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
|
||||||
|
self.args_schema = FixedCodeDocsSearchToolSchema
|
||||||
|
self._generate_description()
|
||||||
|
|
||||||
|
def add(self, docs_url: str) -> None:
|
||||||
|
super().add(docs_url, data_type=DataType.DOCS_SITE)
|
||||||
|
|
||||||
|
def _run(
|
||||||
|
self,
|
||||||
|
search_query: str,
|
||||||
|
docs_url: Optional[str] = None,
|
||||||
|
similarity_threshold: float | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> str:
|
||||||
|
if docs_url is not None:
|
||||||
|
self.add(docs_url)
|
||||||
|
return super()._run(
|
||||||
|
query=search_query, similarity_threshold=similarity_threshold, limit=limit
|
||||||
|
)
|
||||||
@@ -0,0 +1,6 @@
FROM python:3.12-alpine

RUN pip install requests beautifulsoup4

# Set the working directory
WORKDIR /workspace
@@ -0,0 +1,53 @@
# CodeInterpreterTool

## Description
This tool gives the Agent the ability to run Python 3 code that the Agent itself has generated. The code is executed in a sandboxed environment, so it is safe to run any code.

It is incredibly useful since it allows the Agent to generate code, run it in the same environment, get the result, and use it to make decisions.

## Requirements

- Docker

## Installation
Install the crewai_tools package

```shell
pip install 'crewai[tools]'
```

## Example

Remember that when using this tool, the code must be generated by the Agent itself and must be Python 3. The first run will take some time because the Docker image needs to be built.

```python
from crewai_tools import CodeInterpreterTool

Agent(
    ...
    tools=[CodeInterpreterTool()],
)
```

Or, if you need to pass your own Dockerfile, do this:

```python
from crewai_tools import CodeInterpreterTool

Agent(
    ...
    tools=[CodeInterpreterTool(user_dockerfile_path="<Dockerfile_path>")],
)
```

If the tool cannot connect to the Docker daemon automatically (especially on macOS), you can set the Docker host manually:

```python
from crewai_tools import CodeInterpreterTool

Agent(
    ...
    tools=[CodeInterpreterTool(user_docker_base_url="<Docker Host Base Url>",
                               user_dockerfile_path="<Dockerfile_path>")],
)
```
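For context, a minimal end-to-end sketch of the tool inside a crew (the agent role, backstory, and task wording are illustrative placeholders):

```python
from crewai import Agent, Task, Crew
from crewai_tools import CodeInterpreterTool

coder = Agent(
    role="Python Programmer",
    goal="Write and execute Python code to answer questions",
    backstory="You solve problems by writing small Python programs and running them.",
    tools=[CodeInterpreterTool()],
)

task = Task(
    description="Compute the first 10 Fibonacci numbers and print them.",
    expected_output="The list of the first 10 Fibonacci numbers.",
    agent=coder,
)

crew = Crew(agents=[coder], tasks=[task])
print(crew.kickoff())
```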
@@ -0,0 +1,369 @@
"""Code Interpreter Tool for executing Python code in isolated environments.

This module provides a tool for executing Python code either in a Docker container for
safe isolation or directly in a restricted sandbox. It includes mechanisms for blocking
potentially unsafe operations and importing restricted modules.
"""

import importlib.util
import os
from types import ModuleType
from typing import Any, Dict, List, Optional, Type

from crewai.tools import BaseTool
from crewai_tools.printer import Printer
from docker import DockerClient, from_env as docker_from_env
from docker.errors import ImageNotFound, NotFound
from docker.models.containers import Container
from pydantic import BaseModel, Field


class CodeInterpreterSchema(BaseModel):
    """Schema for defining inputs to the CodeInterpreterTool.

    This schema defines the required parameters for code execution,
    including the code to run and any libraries that need to be installed.
    """

    code: str = Field(
        ...,
        description="Python3 code used to be interpreted in the Docker container. ALWAYS PRINT the final result and the output of the code",
    )

    libraries_used: List[str] = Field(
        ...,
        description="List of libraries used in the code with proper installing names separated by commas. Example: numpy,pandas,beautifulsoup4",
    )


class SandboxPython:
    """A restricted Python execution environment for running code safely.

    This class provides methods to safely execute Python code by restricting access to
    potentially dangerous modules and built-in functions. It creates a sandboxed
    environment where harmful operations are blocked.
    """

    BLOCKED_MODULES = {
        "os",
        "sys",
        "subprocess",
        "shutil",
        "importlib",
        "inspect",
        "tempfile",
        "sysconfig",
        "builtins",
    }

    UNSAFE_BUILTINS = {
        "exec",
        "eval",
        "open",
        "compile",
        "input",
        "globals",
        "locals",
        "vars",
        "help",
        "dir",
    }

    @staticmethod
    def restricted_import(
        name: str,
        custom_globals: Optional[Dict[str, Any]] = None,
        custom_locals: Optional[Dict[str, Any]] = None,
        fromlist: Optional[List[str]] = None,
        level: int = 0,
    ) -> ModuleType:
        """A restricted import function that blocks importing of unsafe modules.

        Args:
            name: The name of the module to import.
            custom_globals: Global namespace to use.
            custom_locals: Local namespace to use.
            fromlist: List of items to import from the module.
            level: The level value passed to __import__.

        Returns:
            The imported module if allowed.

        Raises:
            ImportError: If the module is in the blocked modules list.
        """
        if name in SandboxPython.BLOCKED_MODULES:
            raise ImportError(f"Importing '{name}' is not allowed.")
        return __import__(name, custom_globals, custom_locals, fromlist or (), level)

    @staticmethod
    def safe_builtins() -> Dict[str, Any]:
        """Creates a dictionary of built-in functions with unsafe ones removed.

        Returns:
            A dictionary of safe built-in functions and objects.
        """
        import builtins

        safe_builtins = {
            k: v
            for k, v in builtins.__dict__.items()
            if k not in SandboxPython.UNSAFE_BUILTINS
        }
        safe_builtins["__import__"] = SandboxPython.restricted_import
        return safe_builtins

    @staticmethod
    def exec(code: str, locals: Dict[str, Any]) -> None:
        """Executes Python code in a restricted environment.

        Args:
            code: The Python code to execute as a string.
            locals: A dictionary that will be used for local variable storage.
        """
        exec(code, {"__builtins__": SandboxPython.safe_builtins()}, locals)


class CodeInterpreterTool(BaseTool):
    """A tool for executing Python code in isolated environments.

    This tool provides functionality to run Python code either in a Docker container
    for safe isolation or directly in a restricted sandbox. It can handle installing
    Python packages and executing arbitrary Python code.
    """

    name: str = "Code Interpreter"
    description: str = "Interprets Python3 code strings with a final print statement."
    args_schema: Type[BaseModel] = CodeInterpreterSchema
    default_image_tag: str = "code-interpreter:latest"
    code: Optional[str] = None
    user_dockerfile_path: Optional[str] = None
    user_docker_base_url: Optional[str] = None
    unsafe_mode: bool = False

    @staticmethod
    def _get_installed_package_path() -> str:
        """Gets the installation path of the crewai_tools package.

        Returns:
            The directory path where the package is installed.
        """
        spec = importlib.util.find_spec("crewai_tools")
        return os.path.dirname(spec.origin)

    def _verify_docker_image(self) -> None:
        """Verifies if the Docker image is available or builds it if necessary.

        Checks if the required Docker image exists. If not, builds it using either a
        user-provided Dockerfile or the default one included with the package.

        Raises:
            FileNotFoundError: If the Dockerfile cannot be found.
        """

        client = (
            docker_from_env()
            if self.user_docker_base_url is None
            else DockerClient(base_url=self.user_docker_base_url)
        )

        try:
            client.images.get(self.default_image_tag)

        except ImageNotFound:
            if self.user_dockerfile_path and os.path.exists(self.user_dockerfile_path):
                dockerfile_path = self.user_dockerfile_path
            else:
                package_path = self._get_installed_package_path()
                dockerfile_path = os.path.join(
                    package_path, "tools/code_interpreter_tool"
                )
                if not os.path.exists(dockerfile_path):
                    raise FileNotFoundError(
                        f"Dockerfile not found in {dockerfile_path}"
                    )

            client.images.build(
                path=dockerfile_path,
                tag=self.default_image_tag,
                rm=True,
            )

    def _run(self, **kwargs) -> str:
        """Runs the code interpreter tool with the provided arguments.

        Args:
            **kwargs: Keyword arguments that should include 'code' and 'libraries_used'.

        Returns:
            The output of the executed code as a string.
        """
        code = kwargs.get("code", self.code)
        libraries_used = kwargs.get("libraries_used", [])

        if self.unsafe_mode:
            return self.run_code_unsafe(code, libraries_used)
        return self.run_code_safety(code, libraries_used)

    def _install_libraries(self, container: Container, libraries: List[str]) -> None:
        """Installs required Python libraries in the Docker container.

        Args:
            container: The Docker container where libraries will be installed.
            libraries: A list of library names to install using pip.
        """
        for library in libraries:
            container.exec_run(["pip", "install", library])

    def _init_docker_container(self) -> Container:
        """Initializes and returns a Docker container for code execution.

        Stops and removes any existing container with the same name before creating
        a new one. Maps the current working directory to /workspace in the container.

        Returns:
            A Docker container object ready for code execution.
        """
        container_name = "code-interpreter"
        client = docker_from_env()
        current_path = os.getcwd()

        # Check if the container is already running
        try:
            existing_container = client.containers.get(container_name)
            existing_container.stop()
            existing_container.remove()
        except NotFound:
            pass  # Container does not exist, no need to remove

        return client.containers.run(
            self.default_image_tag,
            detach=True,
            tty=True,
            working_dir="/workspace",
            name=container_name,
            volumes={current_path: {"bind": "/workspace", "mode": "rw"}},  # type: ignore
        )

    def _check_docker_available(self) -> bool:
        """Checks if Docker is available and running on the system.

        Attempts to run the 'docker info' command to verify Docker availability.
        Prints appropriate messages if Docker is not installed or not running.

        Returns:
            True if Docker is available and running, False otherwise.
        """
        import subprocess

        try:
            subprocess.run(
                ["docker", "info"],
                check=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                timeout=1,
            )
            return True
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
            Printer.print(
                "Docker is installed but not running or inaccessible.",
                color="bold_purple",
            )
            return False
        except FileNotFoundError:
            Printer.print("Docker is not installed", color="bold_purple")
            return False

    def run_code_safety(self, code: str, libraries_used: List[str]) -> str:
        """Runs code in the safest available environment.

        Attempts to run code in Docker if available, falls back to a restricted
        sandbox if Docker is not available.

        Args:
            code: The Python code to execute as a string.
            libraries_used: A list of Python library names to install before execution.

        Returns:
            The output of the executed code as a string.
        """
        if self._check_docker_available():
            return self.run_code_in_docker(code, libraries_used)
        return self.run_code_in_restricted_sandbox(code)

    def run_code_in_docker(self, code: str, libraries_used: List[str]) -> str:
        """Runs Python code in a Docker container for safe isolation.

        Creates a Docker container, installs the required libraries, executes the code,
        and then cleans up by stopping and removing the container.

        Args:
            code: The Python code to execute as a string.
            libraries_used: A list of Python library names to install before execution.

        Returns:
            The output of the executed code as a string, or an error message if execution failed.
        """
        Printer.print("Running code in Docker environment", color="bold_blue")
        self._verify_docker_image()
        container = self._init_docker_container()
        self._install_libraries(container, libraries_used)

        exec_result = container.exec_run(["python3", "-c", code])

        container.stop()
        container.remove()

        if exec_result.exit_code != 0:
            return f"Something went wrong while running the code: \n{exec_result.output.decode('utf-8')}"
        return exec_result.output.decode("utf-8")

    def run_code_in_restricted_sandbox(self, code: str) -> str:
        """Runs Python code in a restricted sandbox environment.

        Executes the code with restricted access to potentially dangerous modules and
        built-in functions for basic safety when Docker is not available.

        Args:
            code: The Python code to execute as a string.

        Returns:
            The value of the 'result' variable from the executed code,
            or an error message if execution failed.
        """
        Printer.print("Running code in restricted sandbox", color="yellow")
        exec_locals = {}
        try:
            SandboxPython.exec(code=code, locals=exec_locals)
            return exec_locals.get("result", "No result variable found.")
        except Exception as e:
            return f"An error occurred: {e!s}"

    def run_code_unsafe(self, code: str, libraries_used: List[str]) -> str:
        """Runs code directly on the host machine without any safety restrictions.

        WARNING: This mode is unsafe and should only be used in trusted environments
        with code from trusted sources.

        Args:
            code: The Python code to execute as a string.
            libraries_used: A list of Python library names to install before execution.

        Returns:
            The value of the 'result' variable from the executed code,
            or an error message if execution failed.
        """

        Printer.print("WARNING: Running code in unsafe mode", color="bold_magenta")
        # Install libraries on the host machine
        for library in libraries_used:
            os.system(f"pip install {library}")

        # Execute the code
        try:
            exec_locals = {}
            exec(code, {}, exec_locals)
            return exec_locals.get("result", "No result variable found.")
        except Exception as e:
            return f"An error occurred: {e!s}"
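As a rough illustration of the sandbox fallback above (not part of the tool's public interface, and assuming the module lives at `crewai_tools.tools.code_interpreter_tool.code_interpreter_tool`), `SandboxPython` removes blocked modules and unsafe builtins before calling `exec`:

```python
# Hypothetical direct use of the restricted sandbox, for illustration only.
from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import SandboxPython

safe_locals = {}
SandboxPython.exec(code="result = sum(range(10))", locals=safe_locals)
print(safe_locals["result"])  # 45

try:
    SandboxPython.exec(code="import os", locals={})
except ImportError as e:
    print(e)  # Importing 'os' is not allowed.
```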
@@ -0,0 +1,72 @@
# ComposioTool Documentation

## Description

This tool is a wrapper around the Composio toolset and gives your agent access to a wide variety of tools from the Composio SDK.

## Installation

To incorporate this tool into your project, follow the installation instructions below:

```shell
pip install composio-core
pip install 'crewai[tools]'
```

After the installation is complete, either run `composio login` or export your Composio API key as `COMPOSIO_API_KEY`.

## Example

The following example demonstrates how to initialize the tool and execute a GitHub action:

1. Initialize the toolset

```python
from composio import Action, App
from crewai_tools import ComposioTool
from crewai import Agent, Task


tools = [ComposioTool.from_action(action=Action.GITHUB_ACTIVITY_STAR_REPO_FOR_AUTHENTICATED_USER)]
```

If you don't know which action you want to use, use `from_app` with the `tags` filter to get relevant actions:

```python
tools = ComposioTool.from_app(App.GITHUB, tags=["important"])
```

or use `use_case` to search for relevant actions:

```python
tools = ComposioTool.from_app(App.GITHUB, use_case="Star a github repository")
```

2. Define the agent

```python
crewai_agent = Agent(
    role="Github Agent",
    goal="You take action on Github using Github APIs",
    backstory=(
        "You are an AI agent that is responsible for taking actions on Github "
        "on the user's behalf. You need to take action on Github using Github APIs"
    ),
    verbose=True,
    tools=tools,
)
```

3. Execute the task

```python
task = Task(
    description="Star a repo ComposioHQ/composio on GitHub",
    agent=crewai_agent,
    expected_output="if the star happened",
)

task.execute()
```

* A more detailed list of tools can be found [here](https://app.composio.dev)
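Depending on your crewAI version, `task.execute()` may not be available; in that case the same example can be run through a crew. A minimal sketch, reusing `crewai_agent` and `task` from the snippets above:

```python
from crewai import Crew

crew = Crew(agents=[crewai_agent], tasks=[task])
crew.kickoff()
```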
@@ -0,0 +1,128 @@
"""
Composio tools wrapper.
"""

import typing as t

from crewai.tools import BaseTool, EnvVar
import typing_extensions as te


class ComposioTool(BaseTool):
    """Wrapper for composio tools."""

    composio_action: t.Callable
    env_vars: t.List[EnvVar] = [
        EnvVar(
            name="COMPOSIO_API_KEY",
            description="API key for Composio services",
            required=True,
        ),
    ]

    def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Run the composio action with given arguments."""
        return self.composio_action(*args, **kwargs)

    @staticmethod
    def _check_connected_account(tool: t.Any, toolset: t.Any) -> None:
        """Check if connected account is required and if required it exists or not."""
        from composio import Action
        from composio.client.collections import ConnectedAccountModel

        tool = t.cast(Action, tool)
        if tool.no_auth:
            return

        connections = t.cast(
            t.List[ConnectedAccountModel],
            toolset.client.connected_accounts.get(),
        )
        if tool.app not in [connection.appUniqueId for connection in connections]:
            raise RuntimeError(
                f"No connected account found for app `{tool.app}`; "
                f"Run `composio add {tool.app}` to fix this"
            )

    @classmethod
    def from_action(
        cls,
        action: t.Any,
        **kwargs: t.Any,
    ) -> te.Self:
        """Wrap a composio tool as crewAI tool."""

        from composio import Action, ComposioToolSet
        from composio.constants import DEFAULT_ENTITY_ID
        from composio.utils.shared import json_schema_to_model

        toolset = ComposioToolSet()
        if not isinstance(action, Action):
            action = Action(action)

        action = t.cast(Action, action)
        cls._check_connected_account(
            tool=action,
            toolset=toolset,
        )

        (action_schema,) = toolset.get_action_schemas(actions=[action])
        schema = action_schema.model_dump(exclude_none=True)
        entity_id = kwargs.pop("entity_id", DEFAULT_ENTITY_ID)

        def function(**kwargs: t.Any) -> t.Dict:
            """Wrapper function for composio action."""
            return toolset.execute_action(
                action=Action(schema["name"]),
                params=kwargs,
                entity_id=entity_id,
            )

        function.__name__ = schema["name"]
        function.__doc__ = schema["description"]

        return cls(
            name=schema["name"],
            description=schema["description"],
            args_schema=json_schema_to_model(
                action_schema.parameters.model_dump(
                    exclude_none=True,
                )
            ),
            composio_action=function,
            **kwargs,
        )

    @classmethod
    def from_app(
        cls,
        *apps: t.Any,
        tags: t.Optional[t.List[str]] = None,
        use_case: t.Optional[str] = None,
        **kwargs: t.Any,
    ) -> t.List[te.Self]:
        """Create toolset from an app."""
        if len(apps) == 0:
            raise ValueError("You need to provide at least one app name")

        if use_case is None and tags is None:
            raise ValueError("Both `use_case` and `tags` cannot be `None`")

        if use_case is not None and tags is not None:
            raise ValueError(
                "Cannot use both `use_case` and `tags` to filter the actions"
            )

        from composio import ComposioToolSet

        toolset = ComposioToolSet()
        if use_case is not None:
            return [
                cls.from_action(action=action, **kwargs)
                for action in toolset.find_actions_by_use_case(*apps, use_case=use_case)
            ]

        return [
            cls.from_action(action=action, **kwargs)
            for action in toolset.find_actions_by_tags(*apps, tags=tags)
        ]
@@ -0,0 +1,58 @@
# ContextualAICreateAgentTool

## Description
This tool integrates Contextual AI's enterprise-grade RAG agents with CrewAI, enabling you to create a new Contextual RAG agent. It uploads your documents to create a datastore and returns the Contextual agent ID and datastore ID.

## Installation
To incorporate this tool into your project, follow the installation instructions below:

```
pip install 'crewai[tools]' contextual-client
```

**Note**: You'll need a Contextual AI API key. Sign up at [app.contextual.ai](https://app.contextual.ai) to get your free API key.

## Example

```python
from crewai_tools import ContextualAICreateAgentTool

# Initialize the tool
tool = ContextualAICreateAgentTool(api_key="your_api_key_here")

# Create agent with documents
result = tool._run(
    agent_name="Financial Analysis Agent",
    agent_description="Agent for analyzing financial documents",
    datastore_name="Financial Reports",
    document_paths=["/path/to/report1.pdf", "/path/to/report2.pdf"],
)
print(result)
```

## Parameters
- `api_key`: Your Contextual AI API key
- `agent_name`: Name for the new agent
- `agent_description`: Description of the agent's purpose
- `datastore_name`: Name for the document datastore
- `document_paths`: List of file paths to upload

Example result:

```
Successfully created agent 'Research Analyst' with ID: {created_agent_ID} and datastore ID: {created_datastore_ID}. Uploaded 5 documents.
```

You can use `ContextualAIQueryTool` with the returned IDs to query the knowledge base and retrieve relevant information from your documents.
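The tool can also be handed to an agent rather than called directly; a minimal sketch (the role, goal, backstory, and environment variable name below are illustrative placeholders):

```python
import os

from crewai import Agent
from crewai_tools import ContextualAICreateAgentTool

# Assumes the API key is exported in an environment variable of your choosing.
create_agent_tool = ContextualAICreateAgentTool(api_key=os.environ["CONTEXTUAL_API_KEY"])

setup_agent = Agent(
    role="RAG Setup Assistant",
    goal="Create Contextual AI agents and datastores for new document collections",
    backstory="You prepare retrieval infrastructure before analysis work begins.",
    tools=[create_agent_tool],
)
```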
## Key Features
- **Complete Pipeline Setup**: Creates datastore, uploads documents, and configures agent in one operation
- **Document Processing**: Leverages Contextual AI's powerful parser to ingest complex PDFs and documents
- **Vector Storage**: Use Contextual AI's datastore for large document collections

## Use Cases
- Set up new RAG agents from scratch with complete automation
- Upload and organize document collections into structured datastores
- Create specialized domain agents for legal, financial, technical, or research workflows

For more detailed information about Contextual AI's capabilities, visit the [official documentation](https://docs.contextual.ai).
@@ -0,0 +1,79 @@
from typing import Any, List, Type

from crewai.tools import BaseTool
from pydantic import BaseModel, Field


class ContextualAICreateAgentSchema(BaseModel):
    """Schema for contextual create agent tool."""

    agent_name: str = Field(..., description="Name for the new agent")
    agent_description: str = Field(..., description="Description for the new agent")
    datastore_name: str = Field(..., description="Name for the new datastore")
    document_paths: List[str] = Field(..., description="List of file paths to upload")


class ContextualAICreateAgentTool(BaseTool):
    """Tool to create Contextual AI RAG agents with documents."""

    name: str = "Contextual AI Create Agent Tool"
    description: str = (
        "Create a new Contextual AI RAG agent with documents and datastore"
    )
    args_schema: Type[BaseModel] = ContextualAICreateAgentSchema

    api_key: str
    contextual_client: Any = None
    package_dependencies: List[str] = ["contextual-client"]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        try:
            from contextual import ContextualAI

            self.contextual_client = ContextualAI(api_key=self.api_key)
        except ImportError:
            raise ImportError(
                "contextual-client package is required. Install it with: pip install contextual-client"
            )

    def _run(
        self,
        agent_name: str,
        agent_description: str,
        datastore_name: str,
        document_paths: List[str],
    ) -> str:
        """Create a complete RAG pipeline with documents."""
        try:
            import os

            # Create datastore
            datastore = self.contextual_client.datastores.create(name=datastore_name)
            datastore_id = datastore.id

            # Upload documents
            document_ids = []
            for doc_path in document_paths:
                if not os.path.exists(doc_path):
                    raise FileNotFoundError(f"Document not found: {doc_path}")

                with open(doc_path, "rb") as f:
                    ingestion_result = (
                        self.contextual_client.datastores.documents.ingest(
                            datastore_id, file=f
                        )
                    )
                    document_ids.append(ingestion_result.id)

            # Create agent
            agent = self.contextual_client.agents.create(
                name=agent_name,
                description=agent_description,
                datastore_ids=[datastore_id],
            )

            return f"Successfully created agent '{agent_name}' with ID: {agent.id} and datastore ID: {datastore_id}. Uploaded {len(document_ids)} documents."

        except Exception as e:
            return f"Failed to create agent with documents: {e!s}"
@@ -0,0 +1,68 @@
# ContextualAIParseTool

## Description
This tool is designed to integrate Contextual AI's enterprise-grade document parsing capabilities with CrewAI, enabling you to leverage advanced AI-powered document understanding for complex layouts, tables, and figures. Use this tool to extract structured content from your documents using Contextual AI's powerful document parser.

## Installation
To incorporate this tool into your project, follow the installation instructions below:

```
pip install 'crewai[tools]' contextual-client
```

**Note**: You'll need a Contextual AI API key. Sign up at [app.contextual.ai](https://app.contextual.ai) to get your free API key.

## Example

```python
from crewai_tools import ContextualAIParseTool

tool = ContextualAIParseTool(api_key="your_api_key_here")

result = tool._run(
    file_path="/path/to/document.pdf",
    parse_mode="standard",
    page_range="0-5",
    output_types=["markdown-per-page"]
)
print(result)
```

The result will show the parsed contents of your document. For example:

```
{
  "file_name": "attention_is_all_you_need.pdf",
  "status": "completed",
  "pages": [
    {
      "index": 0,
      "markdown": "Provided proper attribution ...
    },
    {
      "index": 1,
      "markdown": "## 1 Introduction ...
    },
    ...
  ]
}
```

## Parameters
- `api_key`: Your Contextual AI API key
- `file_path`: Path to document to parse
- `parse_mode`: Parsing mode (default: "standard")
- `figure_caption_mode`: Figure caption handling (default: "concise")
- `enable_document_hierarchy`: Enable hierarchy detection (default: True)
- `page_range`: Pages to parse (e.g., "0-5", None for all)
- `output_types`: Output formats (default: ["markdown-per-page"])

## Key Features
- **Advanced Document Understanding**: Handles complex PDF layouts, tables, and multi-column documents
- **Figure and Table Extraction**: Intelligent extraction of figures, charts, and tabular data
- **Page Range Selection**: Parse specific pages or entire documents

## Use Cases
- Extract structured content from complex PDFs and research papers
- Parse financial reports, legal documents, and technical manuals
- Convert documents to markdown for further processing in RAG pipelines

For more detailed information about Contextual AI's capabilities, visit the [official documentation](https://docs.contextual.ai).
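Since the tool returns a JSON string, downstream code typically parses it; a small sketch that collects the per-page markdown from a result shaped like the example above (assuming the `markdown-per-page` output type):

```python
import json

from crewai_tools import ContextualAIParseTool

tool = ContextualAIParseTool(api_key="your_api_key_here")
raw = tool._run(file_path="/path/to/document.pdf", output_types=["markdown-per-page"])

# json.loads raises a JSONDecodeError if the tool returned a plain error message instead of JSON.
parsed = json.loads(raw)
markdown_pages = [page["markdown"] for page in parsed.get("pages", [])]
print(f"Parsed {len(markdown_pages)} pages")
```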
@@ -0,0 +1,103 @@
from typing import List, Optional, Type

from crewai.tools import BaseTool
from pydantic import BaseModel, Field


class ContextualAIParseSchema(BaseModel):
    """Schema for contextual parse tool."""

    file_path: str = Field(..., description="Path to the document to parse")
    parse_mode: str = Field(default="standard", description="Parsing mode")
    figure_caption_mode: str = Field(
        default="concise", description="Figure caption mode"
    )
    enable_document_hierarchy: bool = Field(
        default=True, description="Enable document hierarchy"
    )
    page_range: Optional[str] = Field(
        default=None, description="Page range to parse (e.g., '0-5')"
    )
    output_types: List[str] = Field(
        default=["markdown-per-page"], description="List of output types"
    )


class ContextualAIParseTool(BaseTool):
    """Tool to parse documents using Contextual AI's parser."""

    name: str = "Contextual AI Document Parser"
    description: str = "Parse documents using Contextual AI's advanced document parser"
    args_schema: Type[BaseModel] = ContextualAIParseSchema

    api_key: str
    package_dependencies: List[str] = ["contextual-client"]

    def _run(
        self,
        file_path: str,
        parse_mode: str = "standard",
        figure_caption_mode: str = "concise",
        enable_document_hierarchy: bool = True,
        page_range: Optional[str] = None,
        output_types: List[str] = ["markdown-per-page"],
    ) -> str:
        """Parse a document using Contextual AI's parser."""
        try:
            import json
            import os
            from time import sleep

            import requests

            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Document not found: {file_path}")

            base_url = "https://api.contextual.ai/v1"
            headers = {
                "accept": "application/json",
                "authorization": f"Bearer {self.api_key}",
            }

            # Submit parse job
            url = f"{base_url}/parse"
            config = {
                "parse_mode": parse_mode,
                "figure_caption_mode": figure_caption_mode,
                "enable_document_hierarchy": enable_document_hierarchy,
            }

            if page_range:
                config["page_range"] = page_range

            with open(file_path, "rb") as fp:
                file = {"raw_file": fp}
                result = requests.post(url, headers=headers, data=config, files=file)
                response = json.loads(result.text)
                job_id = response["job_id"]

            # Monitor job status
            status_url = f"{base_url}/parse/jobs/{job_id}/status"
            while True:
                result = requests.get(status_url, headers=headers)
                parse_response = json.loads(result.text)["status"]

                if parse_response == "completed":
                    break
                if parse_response == "failed":
                    raise RuntimeError("Document parsing failed")

                sleep(5)

            # Get parse results
            results_url = f"{base_url}/parse/jobs/{job_id}/results"
            result = requests.get(
                results_url,
                headers=headers,
                params={"output_types": ",".join(output_types)},
            )

            return json.dumps(json.loads(result.text), indent=2)

        except Exception as e:
            return f"Failed to parse document: {e!s}"
Some files were not shown because too many files have changed in this diff.