feat: merge latest changes from crewAI-tools main into packages/tools

- Merged upstream changes from crewAI-tools main branch
- Resolved conflicts due to monorepo structure (crewai_tools -> src/crewai_tools)
- Removed deprecated embedchain adapters
- Added new RAG loaders and crewai_rag_adapter
- Consolidated dependencies in pyproject.toml

Fixed critical linting issues (sketched below):
- Added ClassVar annotations for mutable class attributes
- Added timeouts to requests calls (30s default)
- Fixed exception handling with proper 'from' clauses
- Added noqa comments for public API functions (backward compatibility)
- Updated ruff config to ignore expected patterns:
  - F401 in __init__ files (intentional re-exports)
  - S101 in test files (assertions are expected)
  - S607 for subprocess calls (uv/pip commands are safe)

Remaining issues are from upstream code and will be addressed in separate PRs.
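
For reference, a minimal sketch of the lint patterns applied in this merge (illustrative only; the tool class, method names, and URL are hypothetical):

from typing import ClassVar

import requests
from pydantic import BaseModel


class ExampleTool(BaseModel):
    # ClassVar keeps the mutable default as a class-level constant (RUF012).
    package_dependencies: ClassVar[list[str]] = ["boto3"]

    def fetch(self, url: str) -> str:
        # requests calls now always pass an explicit timeout (S113).
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.text

    def load_backend(self) -> None:
        try:
            import boto3  # noqa: F401
        except ImportError as e:
            # Re-raises inside an except block chain their cause with 'from' (B904).
            raise ImportError("`boto3` package not found") from e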
Greyson LaLonde
2025-09-19 00:08:27 -04:00
156 changed files with 4530 additions and 2718 deletions

View File

@@ -9,12 +9,18 @@ authors = [
 requires-python = ">=3.10,<3.14"
 dependencies = [
     "crewai-core",
+    "click>=8.1.8",
     "lancedb>=0.5.4",
     "pytube>=15.0.0",
     "requests>=2.31.0",
     "docker>=7.1.0",
     "tiktoken>=0.8.0",
     "stagehand>=0.4.1",
+    "portalocker==2.7.0",
+    "beautifulsoup4>=4.13.4",
+    "pypdf>=5.9.0",
+    "python-docx>=1.2.0",
+    "youtube-transcript-api>=1.2.2",
 ]

 [project.urls]
@@ -24,9 +30,6 @@ Documentation = "https://docs.crewai.com"
 [project.optional-dependencies]
-embedchain = [
-    "embedchain>=0.1.114",
-]
 scrapfly-sdk = [
     "scrapfly-sdk>=0.8.19",
 ]
@@ -124,6 +127,12 @@ oxylabs = [
 mongodb = [
     "pymongo>=4.13"
 ]
+mysql = [
+    "pymysql>=1.1.1"
+]
+postgresql = [
+    "psycopg2-binary>=2.9.10"
+]
 bedrock = [
     "beautifulsoup4>=4.13.4",
     "bedrock-agentcore>=0.1.0",
@@ -135,6 +144,9 @@ contextual = [
     "nest-asyncio>=1.6.0",
 ]

+[tool.hatch.metadata]
+allow-direct-references = true
+
 [tool.pytest.ini_options]
 testpaths = ["tests"]
@@ -149,3 +161,12 @@ build-backend = "hatchling.build"
 [tool.hatch.build.targets.wheel]
 packages = ["src/crewai_tools"]
+
+[dependency-groups]
+dev = [
+    "pytest-asyncio>=0.25.2",
+    "pytest>=8.0.0",
+    "pytest-recording>=0.13.3",
+    "mypy>=1.18.1",
+    "ruff>=0.13.0",
+]

View File

@@ -59,6 +59,7 @@ from .tools import (
     OxylabsAmazonSearchScraperTool,
     OxylabsGoogleSearchScraperTool,
     OxylabsUniversalScraperTool,
+    ParallelSearchTool,
     PatronusEvalTool,
     PatronusLocalEvaluatorTool,
     PatronusPredefinedCriteriaEvalTool,
@@ -96,5 +97,4 @@ from .tools import (
     YoutubeChannelSearchTool,
     YoutubeVideoSearchTool,
     ZapierActionTools,
-    ParallelSearchTool,
 )

View File

@@ -0,0 +1,267 @@
"""Adapter for CrewAI's native RAG system."""

import hashlib
from pathlib import Path
from typing import Any, TypeAlias, TypedDict

from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
from crewai_tools.tools.rag.rag_tool import Adapter
from pydantic import PrivateAttr
from typing_extensions import Unpack

ContentItem: TypeAlias = str | Path | dict[str, Any]


class AddDocumentParams(TypedDict, total=False):
    """Parameters for adding documents to the RAG system."""

    data_type: DataType
    metadata: dict[str, Any]
    website: str
    url: str
    file_path: str | Path
    github_url: str
    youtube_url: str
    directory_path: str | Path


class CrewAIRagAdapter(Adapter):
    """Adapter that uses CrewAI's native RAG system.

    Supports custom vector database configuration through the config parameter.
    """

    collection_name: str = "default"
    summarize: bool = False
    similarity_threshold: float = 0.6
    limit: int = 5
    config: RagConfigType | None = None
    _client: BaseClient | None = PrivateAttr(default=None)

    def model_post_init(self, __context: Any) -> None:
        """Initialize the CrewAI RAG client after model initialization."""
        if self.config is not None:
            self._client = create_client(self.config)
        else:
            self._client = get_rag_client()
        self._client.get_or_create_collection(collection_name=self.collection_name)

    def query(
        self,
        question: str,
        similarity_threshold: float | None = None,
        limit: int | None = None,
    ) -> str:
        """Query the knowledge base with a question.

        Args:
            question: The question to ask
            similarity_threshold: Minimum similarity score for results (default: 0.6)
            limit: Maximum number of results to return (default: 5)

        Returns:
            Relevant content from the knowledge base
        """
        search_limit = limit if limit is not None else self.limit
        search_threshold = (
            similarity_threshold
            if similarity_threshold is not None
            else self.similarity_threshold
        )
        results: list[SearchResult] = self._client.search(
            collection_name=self.collection_name,
            query=question,
            limit=search_limit,
            score_threshold=search_threshold,
        )
        if not results:
            return "No relevant content found."
        contents: list[str] = []
        for result in results:
            content: str = result.get("content", "")
            if content:
                contents.append(content)
        return "\n\n".join(contents)

    def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
        """Add content to the knowledge base.

        This method handles various input types and converts them to documents
        for the vector database. It supports the data_type parameter for
        compatibility with existing tools.

        Args:
            *args: Content items to add (strings, paths, or document dicts)
            **kwargs: Additional parameters including data_type, metadata, etc.
        """
        import os

        from crewai_tools.rag.base_loader import LoaderResult
        from crewai_tools.rag.data_types import DataType, DataTypes
        from crewai_tools.rag.source_content import SourceContent

        documents: list[BaseRecord] = []
        data_type: DataType | None = kwargs.get("data_type")
        base_metadata: dict[str, Any] = kwargs.get("metadata", {})
        for arg in args:
            source_ref: str
            if isinstance(arg, dict):
                source_ref = str(arg.get("source", arg.get("content", "")))
            else:
                source_ref = str(arg)
            if not data_type:
                data_type = DataTypes.from_content(source_ref)
            if data_type == DataType.DIRECTORY:
                if not os.path.isdir(source_ref):
                    raise ValueError(f"Directory does not exist: {source_ref}")
                # Define binary and non-text file extensions to skip
                binary_extensions = {
                    ".pyc",
                    ".pyo",
                    ".png",
                    ".jpg",
                    ".jpeg",
                    ".gif",
                    ".bmp",
                    ".ico",
                    ".svg",
                    ".webp",
                    ".pdf",
                    ".zip",
                    ".tar",
                    ".gz",
                    ".bz2",
                    ".7z",
                    ".rar",
                    ".exe",
                    ".dll",
                    ".so",
                    ".dylib",
                    ".bin",
                    ".dat",
                    ".db",
                    ".sqlite",
                    ".class",
                    ".jar",
                    ".war",
                    ".ear",
                }
                for root, dirs, files in os.walk(source_ref):
                    dirs[:] = [d for d in dirs if not d.startswith(".")]
                    for filename in files:
                        if filename.startswith("."):
                            continue
                        # Skip binary files based on extension
                        file_ext = os.path.splitext(filename)[1].lower()
                        if file_ext in binary_extensions:
                            continue
                        # Skip __pycache__ directories
                        if "__pycache__" in root:
                            continue
                        file_path: str = os.path.join(root, filename)
                        try:
                            file_data_type: DataType = DataTypes.from_content(file_path)
                            file_loader = file_data_type.get_loader()
                            file_chunker = file_data_type.get_chunker()
                            file_source = SourceContent(file_path)
                            file_result: LoaderResult = file_loader.load(file_source)
                            file_chunks = file_chunker.chunk(file_result.content)
                            for chunk_idx, file_chunk in enumerate(file_chunks):
                                file_metadata: dict[str, Any] = base_metadata.copy()
                                file_metadata.update(file_result.metadata)
                                file_metadata["data_type"] = str(file_data_type)
                                file_metadata["file_path"] = file_path
                                file_metadata["chunk_index"] = chunk_idx
                                file_metadata["total_chunks"] = len(file_chunks)
                                if isinstance(arg, dict):
                                    file_metadata.update(arg.get("metadata", {}))
                                chunk_id = hashlib.sha256(
                                    f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
                                ).hexdigest()
                                documents.append(
                                    {
                                        "doc_id": chunk_id,
                                        "content": file_chunk,
                                        "metadata": sanitize_metadata_for_chromadb(
                                            file_metadata
                                        ),
                                    }
                                )
                        except Exception:
                            # Silently skip files that can't be processed
                            continue
            else:
                metadata: dict[str, Any] = base_metadata.copy()
                if data_type in [
                    DataType.PDF_FILE,
                    DataType.TEXT_FILE,
                    DataType.DOCX,
                    DataType.CSV,
                    DataType.JSON,
                    DataType.XML,
                    DataType.MDX,
                ]:
                    if not os.path.isfile(source_ref):
                        raise FileNotFoundError(f"File does not exist: {source_ref}")
                loader = data_type.get_loader()
                chunker = data_type.get_chunker()
                source_content = SourceContent(source_ref)
                loader_result: LoaderResult = loader.load(source_content)
                chunks = chunker.chunk(loader_result.content)
                for i, chunk in enumerate(chunks):
                    chunk_metadata: dict[str, Any] = metadata.copy()
                    chunk_metadata.update(loader_result.metadata)
                    chunk_metadata["data_type"] = str(data_type)
                    chunk_metadata["chunk_index"] = i
                    chunk_metadata["total_chunks"] = len(chunks)
                    chunk_metadata["source"] = source_ref
                    if isinstance(arg, dict):
                        chunk_metadata.update(arg.get("metadata", {}))
                    chunk_id = hashlib.sha256(
                        f"{loader_result.doc_id}_{i}_{chunk}".encode()
                    ).hexdigest()
                    documents.append(
                        {
                            "doc_id": chunk_id,
                            "content": chunk,
                            "metadata": sanitize_metadata_for_chromadb(chunk_metadata),
                        }
                    )
        if documents:
            self._client.add_documents(
                collection_name=self.collection_name, documents=documents
            )
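
A hedged usage sketch of the new adapter (the import path and file name are assumptions based on this commit's layout):

from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter

# Ingest one file; data_type is inferred via DataTypes.from_content.
adapter = CrewAIRagAdapter(collection_name="docs")
adapter.add("README.md")

# A per-call limit overrides the configured default of 5.
print(adapter.query("What does this project do?", limit=3))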

View File

@@ -1,34 +0,0 @@
from typing import Any

from crewai_tools.tools.rag.rag_tool import Adapter

try:
    from embedchain import App

    EMBEDCHAIN_AVAILABLE = True
except ImportError:
    EMBEDCHAIN_AVAILABLE = False


class EmbedchainAdapter(Adapter):
    embedchain_app: Any  # Will be App when embedchain is available
    summarize: bool = False

    def __init__(self, **data):
        if not EMBEDCHAIN_AVAILABLE:
            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
        super().__init__(**data)

    def query(self, question: str) -> str:
        result, sources = self.embedchain_app.query(
            question, citations=True, dry_run=(not self.summarize)
        )
        if self.summarize:
            return result
        return "\n\n".join([source[0] for source in sources])

    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self.embedchain_app.add(*args, **kwargs)

View File

@@ -1,11 +1,12 @@
-import os
 import json
-import requests
-import warnings
-from typing import List, Any, Dict, Literal, Optional, Union, get_origin, Type, cast
-from pydantic import Field, create_model
-from crewai.tools import BaseTool
+import os
 import re
+import warnings
+from typing import Any, Literal, Optional, Union, cast, get_origin
+
+import requests
+from crewai.tools import BaseTool
+from pydantic import Field, create_model


 def get_enterprise_api_base_url() -> str:
@@ -13,6 +14,7 @@ def get_enterprise_api_base_url() -> str:
     base_url = os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com")
     return f"{base_url}/crewai_plus/api/v1/integrations"

+
 ENTERPRISE_API_BASE_URL = get_enterprise_api_base_url()
@@ -23,7 +25,7 @@ class EnterpriseActionTool(BaseTool):
         default="", description="The enterprise action token"
     )
     action_name: str = Field(default="", description="The name of the action")
-    action_schema: Dict[str, Any] = Field(
+    action_schema: dict[str, Any] = Field(
         default={}, description="The schema of the action"
     )
     enterprise_api_base_url: str = Field(
@@ -36,8 +38,8 @@ class EnterpriseActionTool(BaseTool):
         description: str,
         enterprise_action_token: str,
         action_name: str,
-        action_schema: Dict[str, Any],
-        enterprise_api_base_url: Optional[str] = None,
+        action_schema: dict[str, Any],
+        enterprise_api_base_url: str | None = None,
     ):
         self._model_registry = {}
         self._base_name = self._sanitize_name(name)
@@ -86,7 +88,9 @@ class EnterpriseActionTool(BaseTool):
         self.enterprise_action_token = enterprise_action_token
         self.action_name = action_name
         self.action_schema = action_schema
-        self.enterprise_api_base_url = enterprise_api_base_url or get_enterprise_api_base_url()
+        self.enterprise_api_base_url = (
+            enterprise_api_base_url or get_enterprise_api_base_url()
+        )

     def _sanitize_name(self, name: str) -> str:
         """Sanitize names to create proper Python class names."""
@@ -95,8 +99,8 @@ class EnterpriseActionTool(BaseTool):
         return "".join(word.capitalize() for word in parts if word)

     def _extract_schema_info(
-        self, action_schema: Dict[str, Any]
-    ) -> tuple[Dict[str, Any], List[str]]:
+        self, action_schema: dict[str, Any]
+    ) -> tuple[dict[str, Any], list[str]]:
         """Extract schema properties and required fields from action schema."""
         schema_props = (
             action_schema.get("function", {})
@@ -108,7 +112,7 @@ class EnterpriseActionTool(BaseTool):
         )
         return schema_props, required

-    def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> Type[Any]:
+    def _process_schema_type(self, schema: dict[str, Any], type_name: str) -> type[Any]:
         """Process a JSON schema and return appropriate Python type."""
         if "anyOf" in schema:
             any_of_types = schema["anyOf"]
@@ -118,7 +122,7 @@ class EnterpriseActionTool(BaseTool):
             if non_null_types:
                 base_type = self._process_schema_type(non_null_types[0], type_name)
                 return Optional[base_type] if is_nullable else base_type
-            return cast(Type[Any], Optional[str])
+            return cast(type[Any], Optional[str])

         if "oneOf" in schema:
             return self._process_schema_type(schema["oneOf"][0], type_name)
@@ -137,14 +141,16 @@ class EnterpriseActionTool(BaseTool):
         if json_type == "array":
             items_schema = schema.get("items", {"type": "string"})
             item_type = self._process_schema_type(items_schema, f"{type_name}Item")
-            return List[item_type]
+            return list[item_type]

         if json_type == "object":
             return self._create_nested_model(schema, type_name)

         return self._map_json_type_to_python(json_type)

-    def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> Type[Any]:
+    def _create_nested_model(
+        self, schema: dict[str, Any], model_name: str
+    ) -> type[Any]:
         """Create a nested Pydantic model for complex objects."""
         full_model_name = f"{self._base_name}{model_name}"
@@ -183,21 +189,19 @@ class EnterpriseActionTool(BaseTool):
         return dict

     def _create_field_definition(
-        self, field_type: Type[Any], is_required: bool, description: str
+        self, field_type: type[Any], is_required: bool, description: str
     ) -> tuple:
         """Create Pydantic field definition based on type and requirement."""
         if is_required:
             return (field_type, Field(description=description))
-        else:
-            if get_origin(field_type) is Union:
-                return (field_type, Field(default=None, description=description))
-            else:
-                return (
-                    Optional[field_type],
-                    Field(default=None, description=description),
-                )
+        if get_origin(field_type) is Union:
+            return (field_type, Field(default=None, description=description))
+        return (
+            Optional[field_type],
+            Field(default=None, description=description),
+        )

-    def _map_json_type_to_python(self, json_type: str) -> Type[Any]:
+    def _map_json_type_to_python(self, json_type: str) -> type[Any]:
         """Map basic JSON schema types to Python types."""
         type_mapping = {
             "string": str,
@@ -210,7 +214,7 @@ class EnterpriseActionTool(BaseTool):
         }
         return type_mapping.get(json_type, str)

-    def _get_required_nullable_fields(self) -> List[str]:
+    def _get_required_nullable_fields(self) -> list[str]:
         """Get a list of required nullable fields from the action schema."""
         schema_props, required = self._extract_schema_info(self.action_schema)
@@ -222,7 +226,7 @@ class EnterpriseActionTool(BaseTool):
         return required_nullable_fields

-    def _is_nullable_type(self, schema: Dict[str, Any]) -> bool:
+    def _is_nullable_type(self, schema: dict[str, Any]) -> bool:
         """Check if a schema represents a nullable type."""
         if "anyOf" in schema:
             return any(t.get("type") == "null" for t in schema["anyOf"])
@@ -242,8 +246,9 @@ class EnterpriseActionTool(BaseTool):
             if field_name not in cleaned_kwargs:
                 cleaned_kwargs[field_name] = None

-        api_url = f"{self.enterprise_api_base_url}/actions/{self.action_name}/execute"
+        api_url = (
+            f"{self.enterprise_api_base_url}/actions/{self.action_name}/execute"
+        )
         headers = {
             "Authorization": f"Bearer {self.enterprise_action_token}",
             "Content-Type": "application/json",
@@ -262,7 +267,7 @@ class EnterpriseActionTool(BaseTool):
             return json.dumps(data, indent=2)

         except Exception as e:
-            return f"Error executing action {self.action_name}: {str(e)}"
+            return f"Error executing action {self.action_name}: {e!s}"


 class EnterpriseActionKitToolAdapter:
@@ -271,15 +276,17 @@ class EnterpriseActionKitToolAdapter:
     def __init__(
         self,
         enterprise_action_token: str,
-        enterprise_api_base_url: Optional[str] = None,
+        enterprise_api_base_url: str | None = None,
     ):
         """Initialize the adapter with an enterprise action token."""
         self._set_enterprise_action_token(enterprise_action_token)
         self._actions_schema = {}
         self._tools = None
-        self.enterprise_api_base_url = enterprise_api_base_url or get_enterprise_api_base_url()
+        self.enterprise_api_base_url = (
+            enterprise_api_base_url or get_enterprise_api_base_url()
+        )

-    def tools(self) -> List[BaseTool]:
+    def tools(self) -> list[BaseTool]:
         """Get the list of tools created from enterprise actions."""
         if self._tools is None:
             self._fetch_actions()
@@ -289,13 +296,10 @@ class EnterpriseActionKitToolAdapter:
     def _fetch_actions(self):
         """Fetch available actions from the API."""
         try:
             actions_url = f"{self.enterprise_api_base_url}/actions"
             headers = {"Authorization": f"Bearer {self.enterprise_action_token}"}

-            response = requests.get(
-                actions_url, headers=headers, timeout=30
-            )
+            response = requests.get(actions_url, headers=headers, timeout=30)
             response.raise_for_status()

             raw_data = response.json()
@@ -306,7 +310,7 @@ class EnterpriseActionKitToolAdapter:
             parsed_schema = {}
             action_categories = raw_data["actions"]

-            for integration_type, action_list in action_categories.items():
+            for action_list in action_categories.values():
                 if isinstance(action_list, list):
                     for action in action_list:
                         action_name = action.get("name")
@@ -314,8 +318,10 @@ class EnterpriseActionKitToolAdapter:
                         action_schema = {
                             "function": {
                                 "name": action_name,
-                                "description": action.get("description", f"Execute {action_name}"),
-                                "parameters": action.get("parameters", {})
+                                "description": action.get(
+                                    "description", f"Execute {action_name}"
+                                ),
+                                "parameters": action.get("parameters", {}),
                             }
                         }
                         parsed_schema[action_name] = action_schema
@@ -329,8 +335,8 @@ class EnterpriseActionKitToolAdapter:
             traceback.print_exc()

     def _generate_detailed_description(
-        self, schema: Dict[str, Any], indent: int = 0
-    ) -> List[str]:
+        self, schema: dict[str, Any], indent: int = 0
+    ) -> list[str]:
         """Generate detailed description for nested schema structures."""
         descriptions = []
         indent_str = " " * indent
@@ -407,15 +413,17 @@ class EnterpriseActionKitToolAdapter:
         self._tools = tools

-    def _set_enterprise_action_token(self, enterprise_action_token: Optional[str]):
+    def _set_enterprise_action_token(self, enterprise_action_token: str | None):
         if enterprise_action_token and not enterprise_action_token.startswith("PK_"):
             warnings.warn(
                 "Legacy token detected, please consider using the new Enterprise Action Auth token. Check out our docs for more information https://docs.crewai.com/en/enterprise/features/integrations.",
                 DeprecationWarning,
-                stacklevel=2
+                stacklevel=2,
             )
-        token = enterprise_action_token or os.environ.get("CREWAI_ENTERPRISE_TOOLS_TOKEN")
+        token = enterprise_action_token or os.environ.get(
+            "CREWAI_ENTERPRISE_TOOLS_TOKEN"
+        )
         self.enterprise_action_token = token

View File

@@ -1,14 +1,14 @@
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any
+
+from crewai_tools.tools.rag.rag_tool import Adapter
 from lancedb import DBConnection as LanceDBConnection
 from lancedb import connect as lancedb_connect
 from lancedb.table import Table as LanceDBTable
 from openai import Client as OpenAIClient
 from pydantic import Field, PrivateAttr

-from crewai_tools.tools.rag.rag_tool import Adapter


 def _default_embedding_function():
     client = OpenAIClient()

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any
 from crewai.tools import BaseTool

 from crewai_tools.adapters.tool_collection import ToolCollection

+
 """
 MCPServer for CrewAI.
@@ -103,8 +104,8 @@ class MCPServerAdapter:
             try:
                 subprocess.run(["uv", "add", "mcp crewai-tools[mcp]"], check=True)
-            except subprocess.CalledProcessError:
-                raise ImportError("Failed to install mcp package")
+            except subprocess.CalledProcessError as e:
+                raise ImportError("Failed to install mcp package") from e
         else:
             raise ImportError(
                 "`mcp` package not found, please run `uv add crewai-tools[mcp]`"
@@ -112,7 +113,9 @@ class MCPServerAdapter:
         try:
             self._serverparams = serverparams
-            self._adapter = MCPAdapt(self._serverparams, CrewAIAdapter(), connect_timeout)
+            self._adapter = MCPAdapt(
+                self._serverparams, CrewAIAdapter(), connect_timeout
+            )
             self.start()
         except Exception as e:

View File

@@ -1,41 +0,0 @@
from typing import Any, Optional

from crewai_tools.tools.rag.rag_tool import Adapter

try:
    from embedchain import App

    EMBEDCHAIN_AVAILABLE = True
except ImportError:
    EMBEDCHAIN_AVAILABLE = False


class PDFEmbedchainAdapter(Adapter):
    embedchain_app: Any  # Will be App when embedchain is available
    summarize: bool = False
    src: Optional[str] = None

    def __init__(self, **data):
        if not EMBEDCHAIN_AVAILABLE:
            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
        super().__init__(**data)

    def query(self, question: str) -> str:
        where = (
            {"app_id": self.embedchain_app.config.id, "source": self.src}
            if self.src
            else None
        )
        result, sources = self.embedchain_app.query(
            question, citations=True, dry_run=(not self.summarize), where=where
        )
        if self.summarize:
            return result
        return "\n\n".join([source[0] for source in sources])

    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self.src = args[0] if args else None
        self.embedchain_app.add(*args, **kwargs)

View File

@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any

 from crewai_tools.rag.core import RAG
 from crewai_tools.tools.rag.rag_tool import Adapter
@@ -8,26 +8,23 @@ class RAGAdapter(Adapter):
     def __init__(
         self,
         collection_name: str = "crewai_knowledge_base",
-        persist_directory: Optional[str] = None,
+        persist_directory: str | None = None,
         embedding_model: str = "text-embedding-3-small",
         top_k: int = 5,
-        embedding_api_key: Optional[str] = None,
-        **embedding_kwargs
+        embedding_api_key: str | None = None,
+        **embedding_kwargs,
     ):
         super().__init__()

         # Prepare embedding configuration
-        embedding_config = {
-            "api_key": embedding_api_key,
-            **embedding_kwargs
-        }
+        embedding_config = {"api_key": embedding_api_key, **embedding_kwargs}

         self._adapter = RAG(
             collection_name=collection_name,
             persist_directory=persist_directory,
             embedding_model=embedding_model,
             top_k=top_k,
-            embedding_config=embedding_config
+            embedding_config=embedding_config,
         )

     def query(self, question: str) -> str:

View File

@@ -1,7 +1,10 @@
-from typing import List, Optional, Union, TypeVar, Generic, Dict, Callable
+from collections.abc import Callable
+from typing import Generic, TypeVar

 from crewai.tools import BaseTool

-T = TypeVar('T', bound=BaseTool)
+T = TypeVar("T", bound=BaseTool)


 class ToolCollection(list, Generic[T]):
     """
@@ -18,15 +21,15 @@ class ToolCollection(list, Generic[T]):
         search_tool = tools["search"]
     """

-    def __init__(self, tools: Optional[List[T]] = None):
+    def __init__(self, tools: list[T] | None = None):
         super().__init__(tools or [])
-        self._name_cache: Dict[str, T] = {}
+        self._name_cache: dict[str, T] = {}
         self._build_name_cache()

     def _build_name_cache(self) -> None:
         self._name_cache = {tool.name.lower(): tool for tool in self}

-    def __getitem__(self, key: Union[int, str]) -> T:
+    def __getitem__(self, key: int | str) -> T:
         if isinstance(key, str):
             return self._name_cache[key.lower()]
         return super().__getitem__(key)
@@ -35,7 +38,7 @@ class ToolCollection(list, Generic[T]):
         super().append(tool)
         self._name_cache[tool.name.lower()] = tool

-    def extend(self, tools: List[T]) -> None:
+    def extend(self, tools: list[T]) -> None:
         super().extend(tools)
         self._build_name_cache()
@@ -54,7 +57,7 @@ class ToolCollection(list, Generic[T]):
         del self._name_cache[tool.name.lower()]
         return tool

-    def filter_by_names(self, names: Optional[List[str]] = None) -> "ToolCollection[T]":
+    def filter_by_names(self, names: list[str] | None = None) -> "ToolCollection[T]":
         if names is None:
             return self
@@ -71,4 +74,4 @@ class ToolCollection(list, Generic[T]):
     def clear(self) -> None:
         super().clear()
         self._name_cache.clear()
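
A small self-contained sketch of the name-based lookup ToolCollection provides (EchoTool is a stand-in, not part of this commit):

from crewai.tools import BaseTool

from crewai_tools.adapters.tool_collection import ToolCollection


class EchoTool(BaseTool):
    name: str = "Echo"
    description: str = "Returns its input unchanged."

    def _run(self, text: str) -> str:
        return text


tools = ToolCollection([EchoTool()])
# Lookup is case-insensitive: the cache keys on tool.name.lower().
assert tools["echo"] is tools[0]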

View File

@@ -1,6 +1,5 @@
-import os
 import logging
-from typing import List
+import os

 import requests
 from crewai.tools import BaseTool
@@ -42,7 +41,7 @@ class ZapierActionTool(BaseTool):
         execute_url = f"{ACTIONS_URL}/{self.action_id}/execute/"

         response = requests.request(
-            "POST", execute_url, headers=headers, json=action_params
+            "POST", execute_url, headers=headers, json=action_params, timeout=30
         )
         response.raise_for_status()
@@ -57,7 +56,7 @@ class ZapierActionsAdapter:
     api_key: str

-    def __init__(self, api_key: str = None):
+    def __init__(self, api_key: str | None = None):
         self.api_key = api_key or os.getenv("ZAPIER_API_KEY")
         if not self.api_key:
             logger.error("Zapier Actions API key is required")
@@ -67,13 +66,12 @@ class ZapierActionsAdapter:
         headers = {
             "x-api-key": self.api_key,
         }
-        response = requests.request("GET", ACTIONS_URL, headers=headers)
+        response = requests.request("GET", ACTIONS_URL, headers=headers, timeout=30)
         response.raise_for_status()
-        response_json = response.json()
-        return response_json
+        return response.json()

-    def tools(self) -> List[BaseTool]:
+    def tools(self) -> list[BaseTool]:
         """Convert Zapier actions to BaseTool instances"""
         actions_response = self.get_zapier_actions()
         tools = []

View File

@@ -1,16 +1,16 @@
-from .s3 import S3ReaderTool, S3WriterTool
 from .bedrock import (
-    BedrockKBRetrieverTool,
     BedrockInvokeAgentTool,
+    BedrockKBRetrieverTool,
     create_browser_toolkit,
     create_code_interpreter_toolkit,
 )
+from .s3 import S3ReaderTool, S3WriterTool

 __all__ = [
+    "BedrockInvokeAgentTool",
+    "BedrockKBRetrieverTool",
     "S3ReaderTool",
     "S3WriterTool",
-    "BedrockKBRetrieverTool",
-    "BedrockInvokeAgentTool",
     "create_browser_toolkit",
-    "create_code_interpreter_toolkit"
+    "create_code_interpreter_toolkit",
 ]

View File

@@ -1,11 +1,11 @@
-from .knowledge_base.retriever_tool import BedrockKBRetrieverTool
 from .agents.invoke_agent_tool import BedrockInvokeAgentTool
 from .browser import create_browser_toolkit
 from .code_interpreter import create_code_interpreter_toolkit
+from .knowledge_base.retriever_tool import BedrockKBRetrieverTool

 __all__ = [
-    "BedrockKBRetrieverTool",
     "BedrockInvokeAgentTool",
+    "BedrockKBRetrieverTool",
     "create_browser_toolkit",
-    "create_code_interpreter_toolkit"
+    "create_code_interpreter_toolkit",
 ]

View File

@@ -1,12 +1,11 @@
-from typing import Type, Optional, Dict, Any, List
-import os
 import json
-import uuid
+import os
 import time
 from datetime import datetime, timezone
-from dotenv import load_dotenv
+from typing import ClassVar

 from crewai.tools import BaseTool
+from dotenv import load_dotenv
 from pydantic import BaseModel, Field

 from ..exceptions import BedrockAgentError, BedrockValidationError
@@ -17,29 +16,30 @@ load_dotenv()

 class BedrockInvokeAgentToolInput(BaseModel):
     """Input schema for BedrockInvokeAgentTool."""

     query: str = Field(..., description="The query to send to the agent")


 class BedrockInvokeAgentTool(BaseTool):
     name: str = "Bedrock Agent Invoke Tool"
     description: str = "An agent responsible for policy analysis."
-    args_schema: Type[BaseModel] = BedrockInvokeAgentToolInput
+    args_schema: type[BaseModel] = BedrockInvokeAgentToolInput
     agent_id: str = None
     agent_alias_id: str = None
     session_id: str = None
     enable_trace: bool = False
     end_session: bool = False
-    package_dependencies: List[str] = ["boto3"]
+    package_dependencies: ClassVar[list[str]] = ["boto3"]

     def __init__(
         self,
-        agent_id: str = None,
-        agent_alias_id: str = None,
-        session_id: str = None,
+        agent_id: str | None = None,
+        agent_alias_id: str | None = None,
+        session_id: str | None = None,
         enable_trace: bool = False,
         end_session: bool = False,
-        description: Optional[str] = None,
-        **kwargs
+        description: str | None = None,
+        **kwargs,
     ):
         """Initialize the BedrockInvokeAgentTool with agent configuration.
@@ -54,9 +54,11 @@ class BedrockInvokeAgentTool(BaseTool):
         super().__init__(**kwargs)

         # Get values from environment variables if not provided
-        self.agent_id = agent_id or os.getenv('BEDROCK_AGENT_ID')
-        self.agent_alias_id = agent_alias_id or os.getenv('BEDROCK_AGENT_ALIAS_ID')
-        self.session_id = session_id or str(int(time.time()))  # Use timestamp as session ID if not provided
+        self.agent_id = agent_id or os.getenv("BEDROCK_AGENT_ID")
+        self.agent_alias_id = agent_alias_id or os.getenv("BEDROCK_AGENT_ALIAS_ID")
+        self.session_id = session_id or str(
+            int(time.time())
+        )  # Use timestamp as session ID if not provided
         self.enable_trace = enable_trace
         self.end_session = end_session
@@ -87,20 +89,22 @@ class BedrockInvokeAgentTool(BaseTool):
                 raise BedrockValidationError("session_id must be a string")

         except BedrockValidationError as e:
-            raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
+            raise BedrockValidationError(f"Parameter validation failed: {e!s}") from e

     def _run(self, query: str) -> str:
         try:
             import boto3
             from botocore.exceptions import ClientError
-        except ImportError:
-            raise ImportError("`boto3` package not found, please run `uv add boto3`")
+        except ImportError as e:
+            raise ImportError("`boto3` package not found, please run `uv add boto3`") from e

         try:
             # Initialize the Bedrock Agent Runtime client
             bedrock_agent = boto3.client(
                 "bedrock-agent-runtime",
-                region_name=os.getenv('AWS_REGION', os.getenv('AWS_DEFAULT_REGION', 'us-west-2'))
+                region_name=os.getenv(
+                    "AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2")
+                ),
             )

             # Format the prompt with current time
@@ -119,28 +123,28 @@ Below is the users query or task. Complete it and answer it consicely and to the
                 sessionId=self.session_id,
                 inputText=prompt,
                 enableTrace=self.enable_trace,
-                endSession=self.end_session
+                endSession=self.end_session,
             )

             # Process the response
             completion = ""

             # Check if response contains a completion field
-            if 'completion' in response:
+            if "completion" in response:
                 # Process streaming response format
-                for event in response.get('completion', []):
-                    if 'chunk' in event and 'bytes' in event['chunk']:
-                        chunk_bytes = event['chunk']['bytes']
+                for event in response.get("completion", []):
+                    if "chunk" in event and "bytes" in event["chunk"]:
+                        chunk_bytes = event["chunk"]["bytes"]
                         if isinstance(chunk_bytes, (bytes, bytearray)):
-                            completion += chunk_bytes.decode('utf-8')
+                            completion += chunk_bytes.decode("utf-8")
                         else:
                             completion += str(chunk_bytes)

             # If no completion found in streaming format, try direct format
-            if not completion and 'chunk' in response and 'bytes' in response['chunk']:
-                chunk_bytes = response['chunk']['bytes']
+            if not completion and "chunk" in response and "bytes" in response["chunk"]:
+                chunk_bytes = response["chunk"]["bytes"]
                 if isinstance(chunk_bytes, (bytes, bytearray)):
-                    completion = chunk_bytes.decode('utf-8')
+                    completion = chunk_bytes.decode("utf-8")
                 else:
                     completion = str(chunk_bytes)
@@ -148,14 +152,16 @@ Below is the users query or task. Complete it and answer it consicely and to the
             if not completion:
                 debug_info = {
                     "error": "Could not extract completion from response",
-                    "response_keys": list(response.keys())
+                    "response_keys": list(response.keys()),
                 }

                 # Add more debug info
-                if 'chunk' in response:
-                    debug_info["chunk_keys"] = list(response['chunk'].keys())
+                if "chunk" in response:
+                    debug_info["chunk_keys"] = list(response["chunk"].keys())

-                raise BedrockAgentError(f"Failed to extract completion: {json.dumps(debug_info, indent=2)}")
+                raise BedrockAgentError(
+                    f"Failed to extract completion: {json.dumps(debug_info, indent=2)}"
+                )

             return completion
@@ -164,13 +170,13 @@ Below is the users query or task. Complete it and answer it consicely and to the
             error_message = str(e)

             # Try to extract error code if available
-            if hasattr(e, 'response') and 'Error' in e.response:
-                error_code = e.response['Error'].get('Code', 'Unknown')
-                error_message = e.response['Error'].get('Message', str(e))
+            if hasattr(e, "response") and "Error" in e.response:
+                error_code = e.response["Error"].get("Code", "Unknown")
+                error_message = e.response["Error"].get("Message", str(e))

-            raise BedrockAgentError(f"Error ({error_code}): {error_message}")
+            raise BedrockAgentError(f"Error ({error_code}): {error_message}") from e
         except BedrockAgentError:
             # Re-raise BedrockAgentError exceptions
             raise
         except Exception as e:
-            raise BedrockAgentError(f"Unexpected error: {str(e)}")
+            raise BedrockAgentError(f"Unexpected error: {e!s}") from e
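
A usage sketch under assumed AWS credentials; the import path and agent IDs are placeholders, and both IDs can instead come from BEDROCK_AGENT_ID / BEDROCK_AGENT_ALIAS_ID:

from crewai_tools.aws.bedrock.agents.invoke_agent_tool import BedrockInvokeAgentTool

# Placeholder IDs from the Bedrock console; session_id defaults to a timestamp.
tool = BedrockInvokeAgentTool(agent_id="AGENT123", agent_alias_id="ALIAS456")
print(tool.run(query="Summarize the latest policy change."))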

View File

@@ -1,3 +1,3 @@
 from .browser_toolkit import BrowserToolkit, create_browser_toolkit

 __all__ = ["BrowserToolkit", "create_browser_toolkit"]

View File

@@ -1,12 +1,12 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING

 if TYPE_CHECKING:
+    from bedrock_agentcore.tools.browser_client import BrowserClient
     from playwright.async_api import Browser as AsyncBrowser
     from playwright.sync_api import Browser as SyncBrowser
-    from bedrock_agentcore.tools.browser_client import BrowserClient

 logger = logging.getLogger(__name__)
@@ -28,8 +28,8 @@ class BrowserSessionManager:
             region: AWS region for browser client
         """
         self.region = region
-        self._async_sessions: Dict[str, Tuple[BrowserClient, AsyncBrowser]] = {}
-        self._sync_sessions: Dict[str, Tuple[BrowserClient, SyncBrowser]] = {}
+        self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {}
+        self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {}

     async def get_async_browser(self, thread_id: str) -> AsyncBrowser:
         """
@@ -75,6 +75,7 @@ class BrowserSessionManager:
             Exception: If browser session creation fails
         """
         from bedrock_agentcore.tools.browser_client import BrowserClient
+
         browser_client = BrowserClient(region=self.region)

         try:
@@ -132,6 +133,7 @@ class BrowserSessionManager:
             Exception: If browser session creation fails
         """
         from bedrock_agentcore.tools.browser_client import BrowserClient
+
         browser_client = BrowserClient(region=self.region)

         try:
@@ -257,4 +259,4 @@ class BrowserSessionManager:
         for thread_id in sync_thread_ids:
             self.close_sync_browser(thread_id)

         logger.info("All browser sessions closed")

View File

@@ -1,9 +1,9 @@
 """Toolkit for navigating web with AWS browser."""

+import asyncio
 import json
 import logging
-import asyncio
-from typing import Dict, List, Tuple, Any, Type
+from typing import Any
 from urllib.parse import urlparse

 from crewai.tools import BaseTool
@@ -18,78 +18,100 @@ logger = logging.getLogger(__name__)

 # Input schemas
 class NavigateToolInput(BaseModel):
     """Input for NavigateTool."""

     url: str = Field(description="URL to navigate to")
-    thread_id: str = Field(default="default", description="Thread ID for the browser session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the browser session"
+    )


 class ClickToolInput(BaseModel):
     """Input for ClickTool."""

     selector: str = Field(description="CSS selector for the element to click on")
-    thread_id: str = Field(default="default", description="Thread ID for the browser session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the browser session"
+    )


 class GetElementsToolInput(BaseModel):
     """Input for GetElementsTool."""

     selector: str = Field(description="CSS selector for elements to get")
-    thread_id: str = Field(default="default", description="Thread ID for the browser session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the browser session"
+    )


 class ExtractTextToolInput(BaseModel):
     """Input for ExtractTextTool."""

-    thread_id: str = Field(default="default", description="Thread ID for the browser session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the browser session"
+    )


 class ExtractHyperlinksToolInput(BaseModel):
     """Input for ExtractHyperlinksTool."""

-    thread_id: str = Field(default="default", description="Thread ID for the browser session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the browser session"
+    )


 class NavigateBackToolInput(BaseModel):
     """Input for NavigateBackTool."""

-    thread_id: str = Field(default="default", description="Thread ID for the browser session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the browser session"
+    )


 class CurrentWebPageToolInput(BaseModel):
     """Input for CurrentWebPageTool."""

-    thread_id: str = Field(default="default", description="Thread ID for the browser session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the browser session"
+    )


 # Base tool class
 class BrowserBaseTool(BaseTool):
     """Base class for browser tools."""

     def __init__(self, session_manager: BrowserSessionManager):
         """Initialize with a session manager."""
         super().__init__()
         self._session_manager = session_manager

-        if self._is_in_asyncio_loop() and hasattr(self, '_arun'):
+        if self._is_in_asyncio_loop() and hasattr(self, "_arun"):
             self._original_run = self._run

             # Override _run to use _arun when in an asyncio loop
             def patched_run(*args, **kwargs):
                 try:
                     import nest_asyncio

                     loop = asyncio.get_event_loop()
                     nest_asyncio.apply(loop)
                     return asyncio.get_event_loop().run_until_complete(
                         self._arun(*args, **kwargs)
                     )
                 except Exception as e:
-                    return f"Error in patched _run: {str(e)}"
+                    return f"Error in patched _run: {e!s}"

             self._run = patched_run

     async def get_async_page(self, thread_id: str) -> Any:
         """Get or create a page for the specified thread."""
         browser = await self._session_manager.get_async_browser(thread_id)
-        page = await aget_current_page(browser)
-        return page
+        return await aget_current_page(browser)

     def get_sync_page(self, thread_id: str) -> Any:
         """Get or create a page for the specified thread."""
         browser = self._session_manager.get_sync_browser(thread_id)
-        page = get_current_page(browser)
-        return page
+        return get_current_page(browser)

     def _is_in_asyncio_loop(self) -> bool:
         """Check if we're currently in an asyncio event loop."""
         try:
@@ -105,8 +127,8 @@ class NavigateTool(BrowserBaseTool):
     name: str = "navigate_browser"
     description: str = "Navigate a browser to the specified URL"
-    args_schema: Type[BaseModel] = NavigateToolInput
+    args_schema: type[BaseModel] = NavigateToolInput

     def _run(self, url: str, thread_id: str = "default", **kwargs) -> str:
         """Use the sync tool."""
         try:
@@ -123,7 +145,7 @@ class NavigateTool(BrowserBaseTool):
             status = response.status if response else "unknown"
             return f"Navigating to {url} returned status code {status}"
         except Exception as e:
-            return f"Error navigating to {url}: {str(e)}"
+            return f"Error navigating to {url}: {e!s}"

     async def _arun(self, url: str, thread_id: str = "default", **kwargs) -> str:
         """Use the async tool."""
@@ -141,7 +163,7 @@ class NavigateTool(BrowserBaseTool):
             status = response.status if response else "unknown"
             return f"Navigating to {url} returned status code {status}"
         except Exception as e:
-            return f"Error navigating to {url}: {str(e)}"
+            return f"Error navigating to {url}: {e!s}"


 class ClickTool(BrowserBaseTool):
@@ -149,8 +171,8 @@ class ClickTool(BrowserBaseTool):
     name: str = "click_element"
     description: str = "Click on an element with the given CSS selector"
-    args_schema: Type[BaseModel] = ClickToolInput
+    args_schema: type[BaseModel] = ClickToolInput
     visible_only: bool = True
     """Whether to consider only visible elements."""
     playwright_strict: bool = False
@@ -162,7 +184,7 @@ class ClickTool(BrowserBaseTool):
         if not self.visible_only:
             return selector
         return f"{selector} >> visible=1"

     def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str:
         """Use the sync tool."""
         try:
@@ -172,7 +194,7 @@ class ClickTool(BrowserBaseTool):
             # Click on the element
             selector_effective = self._selector_effective(selector=selector)
             from playwright.sync_api import TimeoutError as PlaywrightTimeoutError

             try:
                 page.click(
                     selector_effective,
@@ -182,11 +204,11 @@ class ClickTool(BrowserBaseTool):
             except PlaywrightTimeoutError:
                 return f"Unable to click on element '{selector}'"
             except Exception as click_error:
-                return f"Unable to click on element '{selector}': {str(click_error)}"
+                return f"Unable to click on element '{selector}': {click_error!s}"

             return f"Clicked element '{selector}'"
         except Exception as e:
-            return f"Error clicking on element: {str(e)}"
+            return f"Error clicking on element: {e!s}"

     async def _arun(self, selector: str, thread_id: str = "default", **kwargs) -> str:
         """Use the async tool."""
@@ -197,7 +219,7 @@ class ClickTool(BrowserBaseTool):
             # Click on the element
             selector_effective = self._selector_effective(selector=selector)
             from playwright.async_api import TimeoutError as PlaywrightTimeoutError

             try:
                 await page.click(
                     selector_effective,
@@ -207,19 +229,20 @@ class ClickTool(BrowserBaseTool):
             except PlaywrightTimeoutError:
                 return f"Unable to click on element '{selector}'"
             except Exception as click_error:
-                return f"Unable to click on element '{selector}': {str(click_error)}"
+                return f"Unable to click on element '{selector}': {click_error!s}"

             return f"Clicked element '{selector}'"
         except Exception as e:
-            return f"Error clicking on element: {str(e)}"
+            return f"Error clicking on element: {e!s}"


 class NavigateBackTool(BrowserBaseTool):
     """Tool for navigating back in browser history."""

     name: str = "navigate_back"
     description: str = "Navigate back to the previous page"
-    args_schema: Type[BaseModel] = NavigateBackToolInput
+    args_schema: type[BaseModel] = NavigateBackToolInput

     def _run(self, thread_id: str = "default", **kwargs) -> str:
         """Use the sync tool."""
         try:
@@ -231,9 +254,9 @@ class NavigateBackTool(BrowserBaseTool):
                 page.go_back()
                 return "Navigated back to the previous page"
             except Exception as nav_error:
-                return f"Unable to navigate back: {str(nav_error)}"
+                return f"Unable to navigate back: {nav_error!s}"
         except Exception as e:
-            return f"Error navigating back: {str(e)}"
+            return f"Error navigating back: {e!s}"

     async def _arun(self, thread_id: str = "default", **kwargs) -> str:
         """Use the async tool."""
@@ -246,17 +269,18 @@ class NavigateBackTool(BrowserBaseTool):
                 await page.go_back()
                 return "Navigated back to the previous page"
             except Exception as nav_error:
-                return f"Unable to navigate back: {str(nav_error)}"
+                return f"Unable to navigate back: {nav_error!s}"
         except Exception as e:
-            return f"Error navigating back: {str(e)}"
+            return f"Error navigating back: {e!s}"


 class ExtractTextTool(BrowserBaseTool):
     """Tool for extracting text from a webpage."""

     name: str = "extract_text"
     description: str = "Extract all the text on the current webpage"
-    args_schema: Type[BaseModel] = ExtractTextToolInput
+    args_schema: type[BaseModel] = ExtractTextToolInput

     def _run(self, thread_id: str = "default", **kwargs) -> str:
         """Use the sync tool."""
         try:
@@ -268,7 +292,7 @@ class ExtractTextTool(BrowserBaseTool):
                     "The 'beautifulsoup4' package is required to use this tool."
                     " Please install it with 'pip install beautifulsoup4'."
                 )

             # Get the current page
             page = self.get_sync_page(thread_id)
@@ -277,7 +301,7 @@ class ExtractTextTool(BrowserBaseTool):
             soup = BeautifulSoup(content, "html.parser")
             return soup.get_text(separator="\n").strip()
         except Exception as e:
-            return f"Error extracting text: {str(e)}"
+            return f"Error extracting text: {e!s}"

     async def _arun(self, thread_id: str = "default", **kwargs) -> str:
         """Use the async tool."""
@@ -290,7 +314,7 @@ class ExtractTextTool(BrowserBaseTool):
                     "The 'beautifulsoup4' package is required to use this tool."
                     " Please install it with 'pip install beautifulsoup4'."
                 )

             # Get the current page
             page = await self.get_async_page(thread_id)
@@ -299,15 +323,16 @@ class ExtractTextTool(BrowserBaseTool):
             soup = BeautifulSoup(content, "html.parser")
             return soup.get_text(separator="\n").strip()
         except Exception as e:
-            return f"Error extracting text: {str(e)}"
+            return f"Error extracting text: {e!s}"


 class ExtractHyperlinksTool(BrowserBaseTool):
     """Tool for extracting hyperlinks from a webpage."""

     name: str = "extract_hyperlinks"
     description: str = "Extract all hyperlinks on the current webpage"
-    args_schema: Type[BaseModel] = ExtractHyperlinksToolInput
+    args_schema: type[BaseModel] = ExtractHyperlinksToolInput

     def _run(self, thread_id: str = "default", **kwargs) -> str:
         """Use the sync tool."""
         try:
@@ -319,7 +344,7 @@ class ExtractHyperlinksTool(BrowserBaseTool):
                     "The 'beautifulsoup4' package is required to use this tool."
                     " Please install it with 'pip install beautifulsoup4'."
                 )

             # Get the current page
             page = self.get_sync_page(thread_id)
@@ -330,15 +355,15 @@ class ExtractHyperlinksTool(BrowserBaseTool):
             for link in soup.find_all("a", href=True):
                 text = link.get_text().strip()
                 href = link["href"]
-                if href.startswith("http") or href.startswith("https"):
+                if href.startswith(("http", "https")):
                     links.append({"text": text, "url": href})
if not links: if not links:
return "No hyperlinks found on the current page." return "No hyperlinks found on the current page."
return json.dumps(links, indent=2) return json.dumps(links, indent=2)
except Exception as e: except Exception as e:
return f"Error extracting hyperlinks: {str(e)}" return f"Error extracting hyperlinks: {e!s}"
async def _arun(self, thread_id: str = "default", **kwargs) -> str: async def _arun(self, thread_id: str = "default", **kwargs) -> str:
"""Use the async tool.""" """Use the async tool."""
@@ -351,7 +376,7 @@ class ExtractHyperlinksTool(BrowserBaseTool):
"The 'beautifulsoup4' package is required to use this tool." "The 'beautifulsoup4' package is required to use this tool."
" Please install it with 'pip install beautifulsoup4'." " Please install it with 'pip install beautifulsoup4'."
) )
# Get the current page # Get the current page
page = await self.get_async_page(thread_id) page = await self.get_async_page(thread_id)
@@ -362,23 +387,24 @@ class ExtractHyperlinksTool(BrowserBaseTool):
for link in soup.find_all("a", href=True): for link in soup.find_all("a", href=True):
text = link.get_text().strip() text = link.get_text().strip()
href = link["href"] href = link["href"]
if href.startswith("http") or href.startswith("https"): if href.startswith(("http", "https")):
links.append({"text": text, "url": href}) links.append({"text": text, "url": href})
if not links: if not links:
return "No hyperlinks found on the current page." return "No hyperlinks found on the current page."
return json.dumps(links, indent=2) return json.dumps(links, indent=2)
except Exception as e: except Exception as e:
return f"Error extracting hyperlinks: {str(e)}" return f"Error extracting hyperlinks: {e!s}"
class GetElementsTool(BrowserBaseTool): class GetElementsTool(BrowserBaseTool):
"""Tool for getting elements from a webpage.""" """Tool for getting elements from a webpage."""
name: str = "get_elements" name: str = "get_elements"
description: str = "Get elements from the webpage using a CSS selector" description: str = "Get elements from the webpage using a CSS selector"
args_schema: Type[BaseModel] = GetElementsToolInput args_schema: type[BaseModel] = GetElementsToolInput
def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str: def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str:
"""Use the sync tool.""" """Use the sync tool."""
try: try:
@@ -389,15 +415,15 @@ class GetElementsTool(BrowserBaseTool):
elements = page.query_selector_all(selector) elements = page.query_selector_all(selector)
if not elements: if not elements:
return f"No elements found with selector '{selector}'" return f"No elements found with selector '{selector}'"
elements_text = [] elements_text = []
for i, element in enumerate(elements): for i, element in enumerate(elements):
text = element.text_content() text = element.text_content()
elements_text.append(f"Element {i+1}: {text.strip()}") elements_text.append(f"Element {i + 1}: {text.strip()}")
return "\n".join(elements_text) return "\n".join(elements_text)
except Exception as e: except Exception as e:
return f"Error getting elements: {str(e)}" return f"Error getting elements: {e!s}"
async def _arun(self, selector: str, thread_id: str = "default", **kwargs) -> str: async def _arun(self, selector: str, thread_id: str = "default", **kwargs) -> str:
"""Use the async tool.""" """Use the async tool."""
@@ -409,23 +435,24 @@ class GetElementsTool(BrowserBaseTool):
elements = await page.query_selector_all(selector) elements = await page.query_selector_all(selector)
if not elements: if not elements:
return f"No elements found with selector '{selector}'" return f"No elements found with selector '{selector}'"
elements_text = [] elements_text = []
for i, element in enumerate(elements): for i, element in enumerate(elements):
text = await element.text_content() text = await element.text_content()
elements_text.append(f"Element {i+1}: {text.strip()}") elements_text.append(f"Element {i + 1}: {text.strip()}")
return "\n".join(elements_text) return "\n".join(elements_text)
except Exception as e: except Exception as e:
return f"Error getting elements: {str(e)}" return f"Error getting elements: {e!s}"
class CurrentWebPageTool(BrowserBaseTool): class CurrentWebPageTool(BrowserBaseTool):
"""Tool for getting information about the current webpage.""" """Tool for getting information about the current webpage."""
name: str = "current_webpage" name: str = "current_webpage"
description: str = "Get information about the current webpage" description: str = "Get information about the current webpage"
args_schema: Type[BaseModel] = CurrentWebPageToolInput args_schema: type[BaseModel] = CurrentWebPageToolInput
def _run(self, thread_id: str = "default", **kwargs) -> str: def _run(self, thread_id: str = "default", **kwargs) -> str:
"""Use the sync tool.""" """Use the sync tool."""
try: try:
@@ -437,7 +464,7 @@ class CurrentWebPageTool(BrowserBaseTool):
title = page.title() title = page.title()
return f"URL: {url}\nTitle: {title}" return f"URL: {url}\nTitle: {title}"
except Exception as e: except Exception as e:
return f"Error getting current webpage info: {str(e)}" return f"Error getting current webpage info: {e!s}"
async def _arun(self, thread_id: str = "default", **kwargs) -> str: async def _arun(self, thread_id: str = "default", **kwargs) -> str:
"""Use the async tool.""" """Use the async tool."""
@@ -450,7 +477,7 @@ class CurrentWebPageTool(BrowserBaseTool):
title = await page.title() title = await page.title()
return f"URL: {url}\nTitle: {title}" return f"URL: {url}\nTitle: {title}"
except Exception as e: except Exception as e:
return f"Error getting current webpage info: {str(e)}" return f"Error getting current webpage info: {e!s}"
class BrowserToolkit: class BrowserToolkit:
@@ -504,10 +531,10 @@ class BrowserToolkit:
""" """
self.region = region self.region = region
self.session_manager = BrowserSessionManager(region=region) self.session_manager = BrowserSessionManager(region=region)
self.tools: List[BaseTool] = [] self.tools: list[BaseTool] = []
self._nest_current_loop() self._nest_current_loop()
self._setup_tools() self._setup_tools()
def _nest_current_loop(self): def _nest_current_loop(self):
"""Apply nest_asyncio if we're in an asyncio loop.""" """Apply nest_asyncio if we're in an asyncio loop."""
try: try:
@@ -515,9 +542,10 @@ class BrowserToolkit:
if loop.is_running(): if loop.is_running():
try: try:
import nest_asyncio import nest_asyncio
nest_asyncio.apply(loop) nest_asyncio.apply(loop)
except Exception as e: except Exception as e:
logger.warning(f"Failed to apply nest_asyncio: {str(e)}") logger.warning(f"Failed to apply nest_asyncio: {e!s}")
except RuntimeError: except RuntimeError:
pass pass
@@ -530,10 +558,10 @@ class BrowserToolkit:
ExtractTextTool(session_manager=self.session_manager), ExtractTextTool(session_manager=self.session_manager),
ExtractHyperlinksTool(session_manager=self.session_manager), ExtractHyperlinksTool(session_manager=self.session_manager),
GetElementsTool(session_manager=self.session_manager), GetElementsTool(session_manager=self.session_manager),
CurrentWebPageTool(session_manager=self.session_manager) CurrentWebPageTool(session_manager=self.session_manager),
] ]
def get_tools(self) -> List[BaseTool]: def get_tools(self) -> list[BaseTool]:
""" """
Get the list of browser tools Get the list of browser tools
@@ -542,7 +570,7 @@ class BrowserToolkit:
""" """
return self.tools return self.tools
def get_tools_by_name(self) -> Dict[str, BaseTool]: def get_tools_by_name(self) -> dict[str, BaseTool]:
""" """
Get a dictionary of tools mapped by their names Get a dictionary of tools mapped by their names
@@ -555,11 +583,11 @@ class BrowserToolkit:
"""Clean up all browser sessions asynchronously""" """Clean up all browser sessions asynchronously"""
await self.session_manager.close_all_browsers() await self.session_manager.close_all_browsers()
logger.info("All browser sessions cleaned up") logger.info("All browser sessions cleaned up")
def sync_cleanup(self) -> None: def sync_cleanup(self) -> None:
"""Clean up all browser sessions from synchronous code""" """Clean up all browser sessions from synchronous code"""
import asyncio import asyncio
try: try:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if loop.is_running(): if loop.is_running():
@@ -572,7 +600,7 @@ class BrowserToolkit:
def create_browser_toolkit( def create_browser_toolkit(
region: str = "us-west-2", region: str = "us-west-2",
) -> Tuple[BrowserToolkit, List[BaseTool]]: ) -> tuple[BrowserToolkit, list[BaseTool]]:
""" """
Create a BrowserToolkit Create a BrowserToolkit
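For orientation, here is how the assembled toolkit is typically exercised end to end. This is a minimal sketch, not part of the diff: the import path is an assumption about the package layout, and AWS credentials plus Playwright are presumed to be configured.

# Minimal usage sketch (hypothetical import path).
from crewai_tools.aws.bedrock.browser import create_browser_toolkit

toolkit, tools = create_browser_toolkit(region="us-west-2")
by_name = toolkit.get_tools_by_name()

# Each tool keys its browser session off a thread_id, so one toolkit can
# serve several independent sessions side by side.
print(by_name["current_webpage"]._run(thread_id="demo"))

toolkit.sync_cleanup()  # close all browser sessions when finished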
View File
@@ -1,6 +1,6 @@
from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from playwright.async_api import Browser as AsyncBrowser
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
    from playwright.sync_api import Page as SyncPage


-async def aget_current_page(browser: Union[AsyncBrowser, Any]) -> AsyncPage:
+async def aget_current_page(browser: AsyncBrowser | Any) -> AsyncPage:
    """
    Asynchronously get the current page of the browser.

    Args:
@@ -26,7 +26,7 @@ async def aget_current_page(browser: Union[AsyncBrowser, Any]) -> AsyncPage:
    return context.pages[-1]


-def get_current_page(browser: Union[SyncBrowser, Any]) -> SyncPage:
+def get_current_page(browser: SyncBrowser | Any) -> SyncPage:
    """
    Get the current page of the browser.

    Args:
@@ -40,4 +40,4 @@ def get_current_page(browser: Union[SyncBrowser, Any]) -> SyncPage:
    context = browser.contexts[0]
    if not context.pages:
        return context.new_page()
    return context.pages[-1]
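A short sketch of how these helpers behave against a raw Playwright browser; the helper import is assumed to come from this module.

# Sketch only: get_current_page returns the last page of the browser's first
# context, creating a new page when the context has none (per the code above).
from playwright.sync_api import sync_playwright

with sync_playwright() as p:
    browser = p.chromium.launch()
    browser.new_context()  # ensure at least one context exists
    page = get_current_page(browser)  # helper defined in this file
    page.goto("https://example.com")
    print(page.title())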
View File
@@ -1,3 +1,6 @@
-from .code_interpreter_toolkit import CodeInterpreterToolkit, create_code_interpreter_toolkit
+from .code_interpreter_toolkit import (
+    CodeInterpreterToolkit,
+    create_code_interpreter_toolkit,
+)

__all__ = ["CodeInterpreterToolkit", "create_code_interpreter_toolkit"]
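The parenthesized form keeps the re-export intact, so consumers can still import the factory from the subpackage, e.g.:

# Hypothetical consumer import; the exact package prefix depends on the repo layout.
from crewai_tools.aws.bedrock.code_interpreter import create_code_interpreter_toolkit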
View File
@@ -1,9 +1,10 @@
"""Toolkit for working with AWS Bedrock Code Interpreter."""

from __future__ import annotations

import json
import logging
-from typing import TYPE_CHECKING, Dict, List, Tuple, Optional, Type, Any
+from typing import TYPE_CHECKING, Any

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -39,124 +40,184 @@ def extract_output_from_stream(response):
                output.append(f"==== File: {file_path} ====\n{file_content}\n")
            else:
                output.append(json.dumps(resource))

    return "\n".join(output)


# Input schemas
class ExecuteCodeInput(BaseModel):
    """Input for ExecuteCode."""

    code: str = Field(description="The code to execute")
-    language: str = Field(default="python", description="The programming language of the code")
-    clear_context: bool = Field(default=False, description="Whether to clear execution context")
-    thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
+    language: str = Field(
+        default="python", description="The programming language of the code"
+    )
+    clear_context: bool = Field(
+        default=False, description="Whether to clear execution context"
+    )
+    thread_id: str = Field(
+        default="default", description="Thread ID for the code interpreter session"
+    )


class ExecuteCommandInput(BaseModel):
    """Input for ExecuteCommand."""

    command: str = Field(description="The command to execute")
-    thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the code interpreter session"
+    )


class ReadFilesInput(BaseModel):
    """Input for ReadFiles."""

-    paths: List[str] = Field(description="List of file paths to read")
-    thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
+    paths: list[str] = Field(description="List of file paths to read")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the code interpreter session"
+    )


class ListFilesInput(BaseModel):
    """Input for ListFiles."""

    directory_path: str = Field(default="", description="Path to the directory to list")
-    thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the code interpreter session"
+    )


class DeleteFilesInput(BaseModel):
    """Input for DeleteFiles."""

-    paths: List[str] = Field(description="List of file paths to delete")
-    thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
+    paths: list[str] = Field(description="List of file paths to delete")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the code interpreter session"
+    )


class WriteFilesInput(BaseModel):
    """Input for WriteFiles."""

-    files: List[Dict[str, str]] = Field(description="List of dictionaries with path and text fields")
-    thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
+    files: list[dict[str, str]] = Field(
+        description="List of dictionaries with path and text fields"
+    )
+    thread_id: str = Field(
+        default="default", description="Thread ID for the code interpreter session"
+    )


class StartCommandInput(BaseModel):
    """Input for StartCommand."""

    command: str = Field(description="The command to execute asynchronously")
-    thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the code interpreter session"
+    )


class GetTaskInput(BaseModel):
    """Input for GetTask."""

    task_id: str = Field(description="The ID of the task to check")
-    thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the code interpreter session"
+    )


class StopTaskInput(BaseModel):
    """Input for StopTask."""

    task_id: str = Field(description="The ID of the task to stop")
-    thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
+    thread_id: str = Field(
+        default="default", description="Thread ID for the code interpreter session"
+    )


# Tool classes
class ExecuteCodeTool(BaseTool):
    """Tool for executing code in various languages."""

    name: str = "execute_code"
    description: str = "Execute code in various languages (primarily Python)"
-    args_schema: Type[BaseModel] = ExecuteCodeInput
+    args_schema: type[BaseModel] = ExecuteCodeInput
    toolkit: Any = Field(default=None, exclude=True)

    def __init__(self, toolkit):
        super().__init__()
        self.toolkit = toolkit

-    def _run(self, code: str, language: str = "python", clear_context: bool = False, thread_id: str = "default") -> str:
+    def _run(
+        self,
+        code: str,
+        language: str = "python",
+        clear_context: bool = False,
+        thread_id: str = "default",
+    ) -> str:
        try:
            # Get or create code interpreter
-            code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
+            code_interpreter = self.toolkit._get_or_create_interpreter(
+                thread_id=thread_id
+            )

            # Execute code
            response = code_interpreter.invoke(
                method="executeCode",
-                params={"code": code, "language": language, "clearContext": clear_context},
+                params={
+                    "code": code,
+                    "language": language,
+                    "clearContext": clear_context,
+                },
            )
            return extract_output_from_stream(response)
        except Exception as e:
-            return f"Error executing code: {str(e)}"
+            return f"Error executing code: {e!s}"

-    async def _arun(self, code: str, language: str = "python", clear_context: bool = False, thread_id: str = "default") -> str:
+    async def _arun(
+        self,
+        code: str,
+        language: str = "python",
+        clear_context: bool = False,
+        thread_id: str = "default",
+    ) -> str:
        # Use _run as we're working with a synchronous API that's thread-safe
-        return self._run(code=code, language=language, clear_context=clear_context, thread_id=thread_id)
+        return self._run(
+            code=code,
+            language=language,
+            clear_context=clear_context,
+            thread_id=thread_id,
+        )


class ExecuteCommandTool(BaseTool):
    """Tool for running shell commands in the code interpreter environment."""

    name: str = "execute_command"
    description: str = "Run shell commands in the code interpreter environment"
-    args_schema: Type[BaseModel] = ExecuteCommandInput
+    args_schema: type[BaseModel] = ExecuteCommandInput
    toolkit: Any = Field(default=None, exclude=True)

    def __init__(self, toolkit):
        super().__init__()
        self.toolkit = toolkit

    def _run(self, command: str, thread_id: str = "default") -> str:
        try:
            # Get or create code interpreter
-            code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
+            code_interpreter = self.toolkit._get_or_create_interpreter(
+                thread_id=thread_id
+            )

            # Execute command
            response = code_interpreter.invoke(
                method="executeCommand", params={"command": command}
            )
            return extract_output_from_stream(response)
        except Exception as e:
-            return f"Error executing command: {str(e)}"
+            return f"Error executing command: {e!s}"

    async def _arun(self, command: str, thread_id: str = "default") -> str:
        # Use _run as we're working with a synchronous API that's thread-safe
        return self._run(command=command, thread_id=thread_id)
@@ -164,57 +225,65 @@ class ExecuteCommandTool(BaseTool):
class ReadFilesTool(BaseTool):
    """Tool for reading content of files in the environment."""

    name: str = "read_files"
    description: str = "Read content of files in the environment"
-    args_schema: Type[BaseModel] = ReadFilesInput
+    args_schema: type[BaseModel] = ReadFilesInput
    toolkit: Any = Field(default=None, exclude=True)

    def __init__(self, toolkit):
        super().__init__()
        self.toolkit = toolkit

-    def _run(self, paths: List[str], thread_id: str = "default") -> str:
+    def _run(self, paths: list[str], thread_id: str = "default") -> str:
        try:
            # Get or create code interpreter
-            code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
+            code_interpreter = self.toolkit._get_or_create_interpreter(
+                thread_id=thread_id
+            )

            # Read files
-            response = code_interpreter.invoke(method="readFiles", params={"paths": paths})
+            response = code_interpreter.invoke(
+                method="readFiles", params={"paths": paths}
+            )
            return extract_output_from_stream(response)
        except Exception as e:
-            return f"Error reading files: {str(e)}"
+            return f"Error reading files: {e!s}"

-    async def _arun(self, paths: List[str], thread_id: str = "default") -> str:
+    async def _arun(self, paths: list[str], thread_id: str = "default") -> str:
        # Use _run as we're working with a synchronous API that's thread-safe
        return self._run(paths=paths, thread_id=thread_id)


class ListFilesTool(BaseTool):
    """Tool for listing files in directories in the environment."""

    name: str = "list_files"
    description: str = "List files in directories in the environment"
-    args_schema: Type[BaseModel] = ListFilesInput
+    args_schema: type[BaseModel] = ListFilesInput
    toolkit: Any = Field(default=None, exclude=True)

    def __init__(self, toolkit):
        super().__init__()
        self.toolkit = toolkit

    def _run(self, directory_path: str = "", thread_id: str = "default") -> str:
        try:
            # Get or create code interpreter
-            code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
+            code_interpreter = self.toolkit._get_or_create_interpreter(
+                thread_id=thread_id
+            )

            # List files
            response = code_interpreter.invoke(
                method="listFiles", params={"directoryPath": directory_path}
            )
            return extract_output_from_stream(response)
        except Exception as e:
-            return f"Error listing files: {str(e)}"
+            return f"Error listing files: {e!s}"

    async def _arun(self, directory_path: str = "", thread_id: str = "default") -> str:
        # Use _run as we're working with a synchronous API that's thread-safe
        return self._run(directory_path=directory_path, thread_id=thread_id)
@@ -222,89 +291,100 @@ class ListFilesTool(BaseTool):
class DeleteFilesTool(BaseTool):
    """Tool for removing files from the environment."""

    name: str = "delete_files"
    description: str = "Remove files from the environment"
-    args_schema: Type[BaseModel] = DeleteFilesInput
+    args_schema: type[BaseModel] = DeleteFilesInput
    toolkit: Any = Field(default=None, exclude=True)

    def __init__(self, toolkit):
        super().__init__()
        self.toolkit = toolkit

-    def _run(self, paths: List[str], thread_id: str = "default") -> str:
+    def _run(self, paths: list[str], thread_id: str = "default") -> str:
        try:
            # Get or create code interpreter
-            code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
+            code_interpreter = self.toolkit._get_or_create_interpreter(
+                thread_id=thread_id
+            )

            # Remove files
            response = code_interpreter.invoke(
                method="removeFiles", params={"paths": paths}
            )
            return extract_output_from_stream(response)
        except Exception as e:
-            return f"Error deleting files: {str(e)}"
+            return f"Error deleting files: {e!s}"

-    async def _arun(self, paths: List[str], thread_id: str = "default") -> str:
+    async def _arun(self, paths: list[str], thread_id: str = "default") -> str:
        # Use _run as we're working with a synchronous API that's thread-safe
        return self._run(paths=paths, thread_id=thread_id)


class WriteFilesTool(BaseTool):
    """Tool for creating or updating files in the environment."""

    name: str = "write_files"
    description: str = "Create or update files in the environment"
-    args_schema: Type[BaseModel] = WriteFilesInput
+    args_schema: type[BaseModel] = WriteFilesInput
    toolkit: Any = Field(default=None, exclude=True)

    def __init__(self, toolkit):
        super().__init__()
        self.toolkit = toolkit

-    def _run(self, files: List[Dict[str, str]], thread_id: str = "default") -> str:
+    def _run(self, files: list[dict[str, str]], thread_id: str = "default") -> str:
        try:
            # Get or create code interpreter
-            code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
+            code_interpreter = self.toolkit._get_or_create_interpreter(
+                thread_id=thread_id
+            )

            # Write files
            response = code_interpreter.invoke(
                method="writeFiles", params={"content": files}
            )
            return extract_output_from_stream(response)
        except Exception as e:
-            return f"Error writing files: {str(e)}"
+            return f"Error writing files: {e!s}"

-    async def _arun(self, files: List[Dict[str, str]], thread_id: str = "default") -> str:
+    async def _arun(
+        self, files: list[dict[str, str]], thread_id: str = "default"
+    ) -> str:
        # Use _run as we're working with a synchronous API that's thread-safe
        return self._run(files=files, thread_id=thread_id)


class StartCommandTool(BaseTool):
    """Tool for starting long-running commands asynchronously."""

    name: str = "start_command_execution"
    description: str = "Start long-running commands asynchronously"
-    args_schema: Type[BaseModel] = StartCommandInput
+    args_schema: type[BaseModel] = StartCommandInput
    toolkit: Any = Field(default=None, exclude=True)

    def __init__(self, toolkit):
        super().__init__()
        self.toolkit = toolkit

    def _run(self, command: str, thread_id: str = "default") -> str:
        try:
            # Get or create code interpreter
-            code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
+            code_interpreter = self.toolkit._get_or_create_interpreter(
+                thread_id=thread_id
+            )

            # Start command execution
            response = code_interpreter.invoke(
                method="startCommandExecution", params={"command": command}
            )
            return extract_output_from_stream(response)
        except Exception as e:
-            return f"Error starting command: {str(e)}"
+            return f"Error starting command: {e!s}"

    async def _arun(self, command: str, thread_id: str = "default") -> str:
        # Use _run as we're working with a synchronous API that's thread-safe
        return self._run(command=command, thread_id=thread_id)
@@ -312,27 +392,32 @@ class StartCommandTool(BaseTool):
class GetTaskTool(BaseTool):
    """Tool for checking status of async tasks."""

    name: str = "get_task"
    description: str = "Check status of async tasks"
-    args_schema: Type[BaseModel] = GetTaskInput
+    args_schema: type[BaseModel] = GetTaskInput
    toolkit: Any = Field(default=None, exclude=True)

    def __init__(self, toolkit):
        super().__init__()
        self.toolkit = toolkit

    def _run(self, task_id: str, thread_id: str = "default") -> str:
        try:
            # Get or create code interpreter
-            code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
+            code_interpreter = self.toolkit._get_or_create_interpreter(
+                thread_id=thread_id
+            )

            # Get task status
-            response = code_interpreter.invoke(method="getTask", params={"taskId": task_id})
+            response = code_interpreter.invoke(
+                method="getTask", params={"taskId": task_id}
+            )
            return extract_output_from_stream(response)
        except Exception as e:
-            return f"Error getting task status: {str(e)}"
+            return f"Error getting task status: {e!s}"

    async def _arun(self, task_id: str, thread_id: str = "default") -> str:
        # Use _run as we're working with a synchronous API that's thread-safe
        return self._run(task_id=task_id, thread_id=thread_id)
@@ -340,29 +425,32 @@ class GetTaskTool(BaseTool):
class StopTaskTool(BaseTool):
    """Tool for stopping running tasks."""

    name: str = "stop_task"
    description: str = "Stop running tasks"
-    args_schema: Type[BaseModel] = StopTaskInput
+    args_schema: type[BaseModel] = StopTaskInput
    toolkit: Any = Field(default=None, exclude=True)

    def __init__(self, toolkit):
        super().__init__()
        self.toolkit = toolkit

    def _run(self, task_id: str, thread_id: str = "default") -> str:
        try:
            # Get or create code interpreter
-            code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
+            code_interpreter = self.toolkit._get_or_create_interpreter(
+                thread_id=thread_id
+            )

            # Stop task
            response = code_interpreter.invoke(
                method="stopTask", params={"taskId": task_id}
            )
            return extract_output_from_stream(response)
        except Exception as e:
-            return f"Error stopping task: {str(e)}"
+            return f"Error stopping task: {e!s}"

    async def _arun(self, task_id: str, thread_id: str = "default") -> str:
        # Use _run as we're working with a synchronous API that's thread-safe
        return self._run(task_id=task_id, thread_id=thread_id)
@@ -429,8 +517,8 @@ class CodeInterpreterToolkit:
            region: AWS region for the code interpreter
        """
        self.region = region
-        self._code_interpreters: Dict[str, CodeInterpreter] = {}
-        self.tools: List[BaseTool] = []
+        self._code_interpreters: dict[str, CodeInterpreter] = {}
+        self.tools: list[BaseTool] = []
        self._setup_tools()

    def _setup_tools(self) -> None:
@@ -444,17 +532,15 @@ class CodeInterpreterToolkit:
            WriteFilesTool(self),
            StartCommandTool(self),
            GetTaskTool(self),
-            StopTaskTool(self)
+            StopTaskTool(self),
        ]

-    def _get_or_create_interpreter(
-        self, thread_id: str = "default"
-    ) -> CodeInterpreter:
+    def _get_or_create_interpreter(self, thread_id: str = "default") -> CodeInterpreter:
        """Get or create a code interpreter for the specified thread.

        Args:
            thread_id: Thread ID for the code interpreter session

        Returns:
            CodeInterpreter instance
        """
@@ -463,6 +549,7 @@ class CodeInterpreterToolkit:
        # Create a new code interpreter for this thread
        from bedrock_agentcore.tools.code_interpreter_client import CodeInterpreter

        code_interpreter = CodeInterpreter(region=self.region)
        code_interpreter.start()
        logger.info(
@@ -473,8 +560,7 @@ class CodeInterpreterToolkit:
        self._code_interpreters[thread_id] = code_interpreter
        return code_interpreter

-    def get_tools(self) -> List[BaseTool]:
+    def get_tools(self) -> list[BaseTool]:
        """
        Get the list of code interpreter tools
@@ -483,7 +569,7 @@ class CodeInterpreterToolkit:
        """
        return self.tools

-    def get_tools_by_name(self) -> Dict[str, BaseTool]:
+    def get_tools_by_name(self) -> dict[str, BaseTool]:
        """
        Get a dictionary of tools mapped by their names
@@ -492,9 +578,9 @@ class CodeInterpreterToolkit:
        """
        return {tool.name: tool for tool in self.tools}

-    async def cleanup(self, thread_id: Optional[str] = None) -> None:
+    async def cleanup(self, thread_id: str | None = None) -> None:
        """Clean up resources

        Args:
            thread_id: Optional thread ID to clean up. If None, cleans up all sessions.
        """
@@ -521,14 +607,14 @@ class CodeInterpreterToolkit:
                    logger.warning(
                        f"Error stopping code interpreter for thread {tid}: {e}"
                    )

        self._code_interpreters = {}
        logger.info("All code interpreter sessions cleaned up")


def create_code_interpreter_toolkit(
    region: str = "us-west-2",
-) -> Tuple[CodeInterpreterToolkit, List[BaseTool]]:
+) -> tuple[CodeInterpreterToolkit, list[BaseTool]]:
    """
    Create a CodeInterpreterToolkit
@@ -540,4 +626,4 @@ def create_code_interpreter_toolkit(
    """
    toolkit = CodeInterpreterToolkit(region=region)
    tools = toolkit.get_tools()
    return toolkit, tools
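A minimal usage sketch for the factory above, using only names defined in this file; AWS credentials and the bedrock-agentcore dependency are assumed to be configured.

import asyncio

# Create the toolkit, run one snippet in a session, then tear the session down.
toolkit, tools = create_code_interpreter_toolkit(region="us-west-2")
execute_code = toolkit.get_tools_by_name()["execute_code"]

print(execute_code._run(code="print(2 + 2)", language="python", thread_id="demo"))

asyncio.run(toolkit.cleanup(thread_id="demo"))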
View File
@@ -1,17 +1,17 @@
"""Custom exceptions for AWS Bedrock integration."""


class BedrockError(Exception):
    """Base exception for Bedrock-related errors."""
-    pass


class BedrockAgentError(BedrockError):
    """Exception raised for errors in the Bedrock Agent operations."""
-    pass


class BedrockKnowledgeBaseError(BedrockError):
    """Exception raised for errors in the Bedrock Knowledge Base operations."""
-    pass


class BedrockValidationError(BedrockError):
    """Exception raised for validation errors in Bedrock operations."""
-    pass
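Since every exception here derives from BedrockError, callers can handle the whole family with a single except clause; a small sketch:

# Sketch: the shared base class covers agent, knowledge base, and validation errors.
try:
    raise BedrockValidationError("knowledge_base_id must be a string")
except BedrockError as exc:
    print(f"Bedrock operation failed: {exc}")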
View File
@@ -1,9 +1,9 @@
-from typing import Type, Optional, List, Dict, Any
-import os
import json
-from dotenv import load_dotenv
+import os
+from typing import Any

from crewai.tools import BaseTool
+from dotenv import load_dotenv
from pydantic import BaseModel, Field

from ..exceptions import BedrockKnowledgeBaseError, BedrockValidationError
@@ -14,28 +14,33 @@ load_dotenv()
class BedrockKBRetrieverToolInput(BaseModel):
    """Input schema for BedrockKBRetrieverTool."""

-    query: str = Field(..., description="The query to retrieve information from the knowledge base")
+    query: str = Field(
+        ..., description="The query to retrieve information from the knowledge base"
+    )


class BedrockKBRetrieverTool(BaseTool):
    name: str = "Bedrock Knowledge Base Retriever Tool"
-    description: str = "Retrieves information from an Amazon Bedrock Knowledge Base given a query"
-    args_schema: Type[BaseModel] = BedrockKBRetrieverToolInput
+    description: str = (
+        "Retrieves information from an Amazon Bedrock Knowledge Base given a query"
+    )
+    args_schema: type[BaseModel] = BedrockKBRetrieverToolInput
    knowledge_base_id: str = None
-    number_of_results: Optional[int] = 5
-    retrieval_configuration: Optional[Dict[str, Any]] = None
-    guardrail_configuration: Optional[Dict[str, Any]] = None
-    next_token: Optional[str] = None
-    package_dependencies: List[str] = ["boto3"]
+    number_of_results: int | None = 5
+    retrieval_configuration: dict[str, Any] | None = None
+    guardrail_configuration: dict[str, Any] | None = None
+    next_token: str | None = None
+    package_dependencies: list[str] = ["boto3"]

    def __init__(
        self,
-        knowledge_base_id: str = None,
-        number_of_results: Optional[int] = 5,
-        retrieval_configuration: Optional[Dict[str, Any]] = None,
-        guardrail_configuration: Optional[Dict[str, Any]] = None,
-        next_token: Optional[str] = None,
-        **kwargs
+        knowledge_base_id: str | None = None,
+        number_of_results: int | None = 5,
+        retrieval_configuration: dict[str, Any] | None = None,
+        guardrail_configuration: dict[str, Any] | None = None,
+        next_token: str | None = None,
+        **kwargs,
    ):
        """Initialize the BedrockKBRetrieverTool with knowledge base configuration.
@@ -49,7 +54,7 @@ class BedrockKBRetrieverTool(BaseTool):
        super().__init__(**kwargs)

        # Get knowledge_base_id from environment variable if not provided
-        self.knowledge_base_id = knowledge_base_id or os.getenv('BEDROCK_KB_ID')
+        self.knowledge_base_id = knowledge_base_id or os.getenv("BEDROCK_KB_ID")
        self.number_of_results = number_of_results
        self.guardrail_configuration = guardrail_configuration
        self.next_token = next_token
@@ -66,7 +71,7 @@ class BedrockKBRetrieverTool(BaseTool):
        # Update the description to include the knowledge base details
        self.description = f"Retrieves information from Amazon Bedrock Knowledge Base '{self.knowledge_base_id}' given a query"

-    def _build_retrieval_configuration(self) -> Dict[str, Any]:
+    def _build_retrieval_configuration(self) -> dict[str, Any]:
        """Build the retrieval configuration based on provided parameters.

        Returns:
@@ -89,17 +94,23 @@ class BedrockKBRetrieverTool(BaseTool):
            if not isinstance(self.knowledge_base_id, str):
                raise BedrockValidationError("knowledge_base_id must be a string")
            if len(self.knowledge_base_id) > 10:
-                raise BedrockValidationError("knowledge_base_id must be 10 characters or less")
+                raise BedrockValidationError(
+                    "knowledge_base_id must be 10 characters or less"
+                )
            if not all(c.isalnum() for c in self.knowledge_base_id):
-                raise BedrockValidationError("knowledge_base_id must contain only alphanumeric characters")
+                raise BedrockValidationError(
+                    "knowledge_base_id must contain only alphanumeric characters"
+                )

            # Validate next_token if provided
            if self.next_token:
                if not isinstance(self.next_token, str):
                    raise BedrockValidationError("next_token must be a string")
                if len(self.next_token) < 1 or len(self.next_token) > 2048:
-                    raise BedrockValidationError("next_token must be between 1 and 2048 characters")
-                if ' ' in self.next_token:
+                    raise BedrockValidationError(
+                        "next_token must be between 1 and 2048 characters"
+                    )
+                if " " in self.next_token:
                    raise BedrockValidationError("next_token cannot contain spaces")

            # Validate number_of_results if provided
@@ -107,12 +118,14 @@ class BedrockKBRetrieverTool(BaseTool):
                if not isinstance(self.number_of_results, int):
                    raise BedrockValidationError("number_of_results must be an integer")
                if self.number_of_results < 1:
-                    raise BedrockValidationError("number_of_results must be greater than 0")
+                    raise BedrockValidationError(
+                        "number_of_results must be greater than 0"
+                    )

        except BedrockValidationError as e:
-            raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
+            raise BedrockValidationError(f"Parameter validation failed: {e!s}")

-    def _process_retrieval_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
+    def _process_retrieval_result(self, result: dict[str, Any]) -> dict[str, Any]:
        """Process a single retrieval result from Bedrock Knowledge Base.

        Args:
@@ -122,57 +135,57 @@ class BedrockKBRetrieverTool(BaseTool):
            Dict[str, Any]: Processed result with standardized format
        """
        # Extract content
-        content_obj = result.get('content', {})
-        content = content_obj.get('text', '')
-        content_type = content_obj.get('type', 'text')
+        content_obj = result.get("content", {})
+        content = content_obj.get("text", "")
+        content_type = content_obj.get("type", "text")

        # Extract location information
-        location = result.get('location', {})
-        location_type = location.get('type', 'unknown')
+        location = result.get("location", {})
+        location_type = location.get("type", "unknown")
        source_uri = None

        # Map for location types and their URI fields
        location_mapping = {
-            's3Location': {'field': 'uri', 'type': 'S3'},
-            'confluenceLocation': {'field': 'url', 'type': 'Confluence'},
-            'salesforceLocation': {'field': 'url', 'type': 'Salesforce'},
-            'sharePointLocation': {'field': 'url', 'type': 'SharePoint'},
-            'webLocation': {'field': 'url', 'type': 'Web'},
-            'customDocumentLocation': {'field': 'id', 'type': 'CustomDocument'},
-            'kendraDocumentLocation': {'field': 'uri', 'type': 'KendraDocument'},
-            'sqlLocation': {'field': 'query', 'type': 'SQL'}
+            "s3Location": {"field": "uri", "type": "S3"},
+            "confluenceLocation": {"field": "url", "type": "Confluence"},
+            "salesforceLocation": {"field": "url", "type": "Salesforce"},
+            "sharePointLocation": {"field": "url", "type": "SharePoint"},
+            "webLocation": {"field": "url", "type": "Web"},
+            "customDocumentLocation": {"field": "id", "type": "CustomDocument"},
+            "kendraDocumentLocation": {"field": "uri", "type": "KendraDocument"},
+            "sqlLocation": {"field": "query", "type": "SQL"},
        }

        # Extract the URI based on location type
        for loc_key, config in location_mapping.items():
            if loc_key in location:
-                source_uri = location[loc_key].get(config['field'])
-                if not location_type or location_type == 'unknown':
-                    location_type = config['type']
+                source_uri = location[loc_key].get(config["field"])
+                if not location_type or location_type == "unknown":
+                    location_type = config["type"]
                break

        # Create result object
        result_object = {
-            'content': content,
-            'content_type': content_type,
-            'source_type': location_type,
-            'source_uri': source_uri
+            "content": content,
+            "content_type": content_type,
+            "source_type": location_type,
+            "source_uri": source_uri,
        }

        # Add optional fields if available
-        if 'score' in result:
-            result_object['score'] = result['score']
-        if 'metadata' in result:
-            result_object['metadata'] = result['metadata']
+        if "score" in result:
+            result_object["score"] = result["score"]
+        if "metadata" in result:
+            result_object["metadata"] = result["metadata"]

        # Handle byte content if present
-        if 'byteContent' in content_obj:
-            result_object['byte_content'] = content_obj['byteContent']
+        if "byteContent" in content_obj:
+            result_object["byte_content"] = content_obj["byteContent"]

        # Handle row content if present
-        if 'row' in content_obj:
-            result_object['row_content'] = content_obj['row']
+        if "row" in content_obj:
+            result_object["row_content"] = content_obj["row"]

        return result_object
@@ -186,35 +199,35 @@ class BedrockKBRetrieverTool(BaseTool):
        try:
            # Initialize the Bedrock Agent Runtime client
            bedrock_agent_runtime = boto3.client(
-                'bedrock-agent-runtime',
-                region_name=os.getenv('AWS_REGION', os.getenv('AWS_DEFAULT_REGION', 'us-east-1')),
+                "bedrock-agent-runtime",
+                region_name=os.getenv(
+                    "AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-east-1")
+                ),
                # AWS SDK will automatically use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from environment
            )

            # Prepare the request parameters
            retrieve_params = {
-                'knowledgeBaseId': self.knowledge_base_id,
-                'retrievalQuery': {
-                    'text': query
-                }
+                "knowledgeBaseId": self.knowledge_base_id,
+                "retrievalQuery": {"text": query},
            }

            # Add optional parameters if provided
            if self.retrieval_configuration:
-                retrieve_params['retrievalConfiguration'] = self.retrieval_configuration
+                retrieve_params["retrievalConfiguration"] = self.retrieval_configuration
            if self.guardrail_configuration:
-                retrieve_params['guardrailConfiguration'] = self.guardrail_configuration
+                retrieve_params["guardrailConfiguration"] = self.guardrail_configuration
            if self.next_token:
-                retrieve_params['nextToken'] = self.next_token
+                retrieve_params["nextToken"] = self.next_token

            # Make the retrieve API call
            response = bedrock_agent_runtime.retrieve(**retrieve_params)

            # Process the response
            results = []
-            for result in response.get('retrievalResults', []):
+            for result in response.get("retrievalResults", []):
                processed_result = self._process_retrieval_result(result)
                results.append(processed_result)
@@ -239,10 +252,10 @@ class BedrockKBRetrieverTool(BaseTool):
            error_message = str(e)

            # Try to extract error code if available
-            if hasattr(e, 'response') and 'Error' in e.response:
-                error_code = e.response['Error'].get('Code', 'Unknown')
-                error_message = e.response['Error'].get('Message', str(e))
+            if hasattr(e, "response") and "Error" in e.response:
+                error_code = e.response["Error"].get("Code", "Unknown")
+                error_message = e.response["Error"].get("Message", str(e))

            raise BedrockKnowledgeBaseError(f"Error ({error_code}): {error_message}")
        except Exception as e:
-            raise BedrockKnowledgeBaseError(f"Unexpected error: {str(e)}")
+            raise BedrockKnowledgeBaseError(f"Unexpected error: {e!s}")
View File
@@ -1,2 +1,2 @@
from .reader_tool import S3ReaderTool
from .writer_tool import S3WriterTool
View File
@@ -1,4 +1,3 @@
-from typing import Any, Type, List
import os

from crewai.tools import BaseTool
@@ -8,14 +7,16 @@ from pydantic import BaseModel, Field
class S3ReaderToolInput(BaseModel):
    """Input schema for S3ReaderTool."""

-    file_path: str = Field(..., description="S3 file path (e.g., 's3://bucket-name/file-name')")
+    file_path: str = Field(
+        ..., description="S3 file path (e.g., 's3://bucket-name/file-name')"
+    )


class S3ReaderTool(BaseTool):
    name: str = "S3 Reader Tool"
    description: str = "Reads a file from Amazon S3 given an S3 file path"
-    args_schema: Type[BaseModel] = S3ReaderToolInput
-    package_dependencies: List[str] = ["boto3"]
+    args_schema: type[BaseModel] = S3ReaderToolInput
+    package_dependencies: list[str] = ["boto3"]

    def _run(self, file_path: str) -> str:
        try:
@@ -28,19 +29,18 @@ class S3ReaderTool(BaseTool):
            bucket_name, object_key = self._parse_s3_path(file_path)

            s3 = boto3.client(
-                's3',
-                region_name=os.getenv('CREW_AWS_REGION', 'us-east-1'),
-                aws_access_key_id=os.getenv('CREW_AWS_ACCESS_KEY_ID'),
-                aws_secret_access_key=os.getenv('CREW_AWS_SEC_ACCESS_KEY')
+                "s3",
+                region_name=os.getenv("CREW_AWS_REGION", "us-east-1"),
+                aws_access_key_id=os.getenv("CREW_AWS_ACCESS_KEY_ID"),
+                aws_secret_access_key=os.getenv("CREW_AWS_SEC_ACCESS_KEY"),
            )

            # Read file content from S3
            response = s3.get_object(Bucket=bucket_name, Key=object_key)
-            file_content = response['Body'].read().decode('utf-8')
-            return file_content
+            return response["Body"].read().decode("utf-8")
        except ClientError as e:
-            return f"Error reading file from S3: {str(e)}"
+            return f"Error reading file from S3: {e!s}"

    def _parse_s3_path(self, file_path: str) -> tuple:
        parts = file_path.replace("s3://", "").split("/", 1)


@@ -1,20 +1,23 @@
import os

from crewai.tools import BaseTool
from pydantic import BaseModel, Field


class S3WriterToolInput(BaseModel):
    """Input schema for S3WriterTool."""

    file_path: str = Field(
        ..., description="S3 file path (e.g., 's3://bucket-name/file-name')"
    )
    content: str = Field(..., description="Content to write to the file")


class S3WriterTool(BaseTool):
    name: str = "S3 Writer Tool"
    description: str = "Writes content to a file in Amazon S3 given an S3 file path"
    args_schema: type[BaseModel] = S3WriterToolInput
    package_dependencies: list[str] = ["boto3"]

    def _run(self, file_path: str, content: str) -> str:
        try:
@@ -27,16 +30,18 @@ class S3WriterTool(BaseTool):
            bucket_name, object_key = self._parse_s3_path(file_path)

            s3 = boto3.client(
                "s3",
                region_name=os.getenv("CREW_AWS_REGION", "us-east-1"),
                aws_access_key_id=os.getenv("CREW_AWS_ACCESS_KEY_ID"),
                aws_secret_access_key=os.getenv("CREW_AWS_SEC_ACCESS_KEY"),
            )

            s3.put_object(
                Bucket=bucket_name, Key=object_key, Body=content.encode("utf-8")
            )
            return f"Successfully wrote content to {file_path}"
        except ClientError as e:
            return f"Error writing file to S3: {e!s}"

    def _parse_s3_path(self, file_path: str) -> tuple:
        parts = file_path.replace("s3://", "").split("/", 1)
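
A minimal usage sketch for the pair, assuming the CREW_AWS_* environment variables are set, the bucket already exists, and crewai's BaseTool.run forwards kwargs to _run; the bucket and key below are placeholders:

writer = S3WriterTool()
reader = S3ReaderTool()

print(writer.run(file_path="s3://example-bucket/notes.txt", content="hello"))
print(reader.run(file_path="s3://example-bucket/notes.txt"))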


@@ -1,13 +1,11 @@
"""Utility for colored console output.""" """Utility for colored console output."""
from typing import Optional
class Printer: class Printer:
"""Handles colored console output formatting.""" """Handles colored console output formatting."""
@staticmethod @staticmethod
def print(content: str, color: Optional[str] = None) -> None: def print(content: str, color: str | None = None) -> None:
"""Prints content with optional color formatting. """Prints content with optional color formatting.
Args: Args:
@@ -29,7 +27,7 @@ class Printer:
Args: Args:
content: The string to be printed in bold purple. content: The string to be printed in bold purple.
""" """
print("\033[1m\033[95m {}\033[00m".format(content)) print(f"\033[1m\033[95m {content}\033[00m")
@staticmethod @staticmethod
def _print_bold_green(content: str) -> None: def _print_bold_green(content: str) -> None:
@@ -38,7 +36,7 @@ class Printer:
Args: Args:
content: The string to be printed in bold green. content: The string to be printed in bold green.
""" """
print("\033[1m\033[92m {}\033[00m".format(content)) print(f"\033[1m\033[92m {content}\033[00m")
@staticmethod @staticmethod
def _print_purple(content: str) -> None: def _print_purple(content: str) -> None:
@@ -47,7 +45,7 @@ class Printer:
Args: Args:
content: The string to be printed in purple. content: The string to be printed in purple.
""" """
print("\033[95m {}\033[00m".format(content)) print(f"\033[95m {content}\033[00m")
@staticmethod @staticmethod
def _print_red(content: str) -> None: def _print_red(content: str) -> None:
@@ -56,7 +54,7 @@ class Printer:
Args: Args:
content: The string to be printed in red. content: The string to be printed in red.
""" """
print("\033[91m {}\033[00m".format(content)) print(f"\033[91m {content}\033[00m")
@staticmethod @staticmethod
def _print_bold_blue(content: str) -> None: def _print_bold_blue(content: str) -> None:
@@ -65,7 +63,7 @@ class Printer:
Args: Args:
content: The string to be printed in bold blue. content: The string to be printed in bold blue.
""" """
print("\033[1m\033[94m {}\033[00m".format(content)) print(f"\033[1m\033[94m {content}\033[00m")
@staticmethod @staticmethod
def _print_yellow(content: str) -> None: def _print_yellow(content: str) -> None:
@@ -74,7 +72,7 @@ class Printer:
Args: Args:
content: The string to be printed in yellow. content: The string to be printed in yellow.
""" """
print("\033[93m {}\033[00m".format(content)) print(f"\033[93m {content}\033[00m")
@staticmethod @staticmethod
def _print_bold_yellow(content: str) -> None: def _print_bold_yellow(content: str) -> None:
@@ -83,7 +81,7 @@ class Printer:
Args: Args:
content: The string to be printed in bold yellow. content: The string to be printed in bold yellow.
""" """
print("\033[1m\033[93m {}\033[00m".format(content)) print(f"\033[1m\033[93m {content}\033[00m")
@staticmethod @staticmethod
def _print_cyan(content: str) -> None: def _print_cyan(content: str) -> None:
@@ -92,7 +90,7 @@ class Printer:
Args: Args:
content: The string to be printed in cyan. content: The string to be printed in cyan.
""" """
print("\033[96m {}\033[00m".format(content)) print(f"\033[96m {content}\033[00m")
@staticmethod @staticmethod
def _print_bold_cyan(content: str) -> None: def _print_bold_cyan(content: str) -> None:
@@ -101,7 +99,7 @@ class Printer:
Args: Args:
content: The string to be printed in bold cyan. content: The string to be printed in bold cyan.
""" """
print("\033[1m\033[96m {}\033[00m".format(content)) print(f"\033[1m\033[96m {content}\033[00m")
@staticmethod @staticmethod
def _print_magenta(content: str) -> None: def _print_magenta(content: str) -> None:
@@ -110,7 +108,7 @@ class Printer:
Args: Args:
content: The string to be printed in magenta. content: The string to be printed in magenta.
""" """
print("\033[35m {}\033[00m".format(content)) print(f"\033[35m {content}\033[00m")
@staticmethod @staticmethod
def _print_bold_magenta(content: str) -> None: def _print_bold_magenta(content: str) -> None:
@@ -119,7 +117,7 @@ class Printer:
Args: Args:
content: The string to be printed in bold magenta. content: The string to be printed in bold magenta.
""" """
print("\033[1m\033[35m {}\033[00m".format(content)) print(f"\033[1m\033[35m {content}\033[00m")
@staticmethod @staticmethod
def _print_green(content: str) -> None: def _print_green(content: str) -> None:
@@ -128,4 +126,4 @@ class Printer:
Args: Args:
content: The string to be printed in green. content: The string to be printed in green.
""" """
print("\033[32m {}\033[00m".format(content)) print(f"\033[32m {content}\033[00m")


@@ -3,6 +3,6 @@ from crewai_tools.rag.data_types import DataType
__all__ = [
    "RAG",
    "DataType",
    "EmbeddingService",
]


@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any

from pydantic import BaseModel, Field

from crewai_tools.rag.misc import compute_sha256
@@ -9,19 +10,22 @@ from crewai_tools.rag.source_content import SourceContent

class LoaderResult(BaseModel):
    content: str = Field(description="The text content of the source")
    source: str = Field(description="The source of the content", default="unknown")
    metadata: dict[str, Any] = Field(
        description="The metadata of the source", default_factory=dict
    )
    doc_id: str = Field(description="The id of the document")


class BaseLoader(ABC):
    def __init__(self, config: dict[str, Any] | None = None):
        self.config = config or {}

    @abstractmethod
    def load(self, content: SourceContent, **kwargs) -> LoaderResult: ...

    def generate_doc_id(
        self, source_ref: str | None = None, content: str | None = None
    ) -> str:
        """
        Generate a unique document id based on the source reference and content.
        If the source reference is not provided, the content is used as the source reference.
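
A minimal sketch of a custom loader against this interface; the class and its trivial transform are invented for illustration:

class UppercaseTextLoader(BaseLoader):
    """Toy loader: treats the raw source string as text and uppercases it."""

    def load(self, content: SourceContent, **kwargs) -> LoaderResult:
        text = str(content.source).upper()
        return LoaderResult(
            content=text,
            source=content.source_ref,
            metadata={"format": "text", "transform": "upper"},
            doc_id=self.generate_doc_id(source_ref=content.source_ref, content=text),
        )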


@@ -1,15 +1,19 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.chunkers.default_chunker import DefaultChunker
from crewai_tools.rag.chunkers.structured_chunker import (
    CsvChunker,
    JsonChunker,
    XmlChunker,
)
from crewai_tools.rag.chunkers.text_chunker import DocxChunker, MdxChunker, TextChunker

__all__ = [
    "BaseChunker",
    "CsvChunker",
    "DefaultChunker",
    "DocxChunker",
    "JsonChunker",
    "MdxChunker",
    "TextChunker",
    "XmlChunker",
]


@@ -1,6 +1,6 @@
import re


class RecursiveCharacterTextSplitter:
    """
    A text splitter that recursively splits text based on a hierarchy of separators.
@@ -10,7 +10,7 @@ class RecursiveCharacterTextSplitter:
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        """
@@ -23,7 +23,9 @@ class RecursiveCharacterTextSplitter:
            keep_separator: Whether to keep the separator in the split text
        """
        if chunk_overlap >= chunk_size:
            raise ValueError(
                f"Chunk overlap ({chunk_overlap}) cannot be >= chunk size ({chunk_size})"
            )

        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
@@ -36,10 +38,10 @@ class RecursiveCharacterTextSplitter:
            "",
        ]

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        separator = separators[-1]
        new_separators = []
@@ -49,7 +51,7 @@
                break
            if re.search(re.escape(sep), text):
                separator = sep
                new_separators = separators[i + 1 :]
                break

        splits = self._split_text_with_separator(text, separator)
@@ -68,7 +70,7 @@
        return self._merge_splits(good_splits, separator)

    def _split_text_with_separator(self, text: str, separator: str) -> list[str]:
        if separator == "":
            return list(text)
@@ -90,16 +92,15 @@
                splits[-1] += separator
            return [s for s in splits if s]
        return text.split(separator)

    def _split_by_characters(self, text: str) -> list[str]:
        chunks = []
        for i in range(0, len(text), self._chunk_size):
            chunks.append(text[i : i + self._chunk_size])
        return chunks

    def _merge_splits(self, splits: list[str], separator: str) -> list[str]:
        """Merge splits into chunks with proper overlap."""
        docs = []
        current_doc = []
@@ -112,7 +113,10 @@
                if separator == "":
                    doc = "".join(current_doc)
                else:
                    if self._keep_separator and separator == " ":
                        doc = "".join(current_doc)
                    else:
                        doc = separator.join(current_doc)
                if doc:
                    docs.append(doc)
@@ -133,15 +137,25 @@
        if separator == "":
            doc = "".join(current_doc)
        else:
            if self._keep_separator and separator == " ":
                doc = "".join(current_doc)
            else:
                doc = separator.join(current_doc)
        if doc:
            docs.append(doc)

        return docs


class BaseChunker:
    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        """
        Initialize the Chunker
@@ -159,8 +173,7 @@ class BaseChunker:
            keep_separator=keep_separator,
        )

    def chunk(self, text: str) -> list[str]:
        if not text or not text.strip():
            return []
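
A short usage sketch of the splitter and the chunker wrapper, using only what the code above defines:

splitter = RecursiveCharacterTextSplitter(chunk_size=60, chunk_overlap=10)
chunks = splitter.split_text("First paragraph.\n\nSecond paragraph with a bit more text.")
print(len(chunks), chunks[0])

chunker = BaseChunker(chunk_size=60, chunk_overlap=10)
print(chunker.chunk("   "))  # whitespace-only input short-circuits to []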


@@ -1,6 +1,12 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker


class DefaultChunker(BaseChunker):
    def __init__(
        self,
        chunk_size: int = 2000,
        chunk_overlap: int = 20,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


@@ -1,49 +1,66 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker


class CsvChunker(BaseChunker):
    def __init__(
        self,
        chunk_size: int = 1200,
        chunk_overlap: int = 100,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\nRow ",  # Row boundaries (from CSVLoader format)
                "\n",  # Line breaks
                " | ",  # Column separators
                ", ",  # Comma separators
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class JsonChunker(BaseChunker):
    def __init__(
        self,
        chunk_size: int = 2000,
        chunk_overlap: int = 200,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n\n",  # Object/array boundaries
                "\n",  # Line breaks
                "},",  # Object endings
                "],",  # Array endings
                ", ",  # Property separators
                ": ",  # Key-value separators
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class XmlChunker(BaseChunker):
    def __init__(
        self,
        chunk_size: int = 2500,
        chunk_overlap: int = 250,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n\n",  # Element boundaries
                "\n",  # Line breaks
                ">",  # Tag endings
                ". ",  # Sentence endings (for text content)
                "! ",  # Exclamation endings
                "? ",  # Question endings
                ", ",  # Comma separators
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


@@ -1,59 +1,76 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker


class TextChunker(BaseChunker):
    def __init__(
        self,
        chunk_size: int = 1500,
        chunk_overlap: int = 150,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n\n\n",  # Multiple line breaks (sections)
                "\n\n",  # Paragraph breaks
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class DocxChunker(BaseChunker):
    def __init__(
        self,
        chunk_size: int = 2500,
        chunk_overlap: int = 250,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n\n\n",  # Multiple line breaks (major sections)
                "\n\n",  # Paragraph breaks
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class MdxChunker(BaseChunker):
    def __init__(
        self,
        chunk_size: int = 3000,
        chunk_overlap: int = 300,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n## ",  # H2 headers (major sections)
                "\n### ",  # H3 headers (subsections)
                "\n#### ",  # H4 headers (sub-subsections)
                "\n\n",  # Paragraph breaks
                "\n```",  # Code block boundaries
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


@@ -1,20 +1,25 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker


class WebsiteChunker(BaseChunker):
    def __init__(
        self,
        chunk_size: int = 2500,
        chunk_overlap: int = 250,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n\n\n",  # Major section breaks
                "\n\n",  # Paragraph breaks
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


@@ -1,18 +1,18 @@
import logging
from pathlib import Path
from typing import Any
from uuid import uuid4

import chromadb
import litellm
from pydantic import BaseModel, Field, PrivateAttr

from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import compute_sha256
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.tools.rag.rag_tool import Adapter

logger = logging.getLogger(__name__)
@@ -22,29 +22,21 @@ class EmbeddingService:
        self.model = model
        self.kwargs = kwargs

    def embed_text(self, text: str) -> list[float]:
        try:
            response = litellm.embedding(model=self.model, input=[text], **self.kwargs)
            return response.data[0]["embedding"]
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            raise

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        if not texts:
            return []
        try:
            response = litellm.embedding(model=self.model, input=texts, **self.kwargs)
            return [data["embedding"] for data in response.data]
        except Exception as e:
            logger.error(f"Error generating batch embeddings: {e}")
            raise
@@ -53,18 +45,18 @@ class EmbeddingService:

class Document(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    content: str
    metadata: dict[str, Any] = Field(default_factory=dict)
    data_type: DataType = DataType.TEXT
    source: str | None = None


class RAG(Adapter):
    collection_name: str = "crewai_knowledge_base"
    persist_directory: str | None = None
    embedding_model: str = "text-embedding-3-large"
    summarize: bool = False
    top_k: int = 5
    embedding_config: dict[str, Any] = Field(default_factory=dict)

    _client: Any = PrivateAttr()
    _collection: Any = PrivateAttr()
@@ -79,10 +71,15 @@ class RAG(Adapter):
            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                metadata={
                    "hnsw:space": "cosine",
                    "description": "CrewAI Knowledge Base",
                },
            )
            self._embedding_service = EmbeddingService(
                model=self.embedding_model, **self.embedding_config
            )
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB: {e}")
            raise
@@ -92,11 +89,11 @@ class RAG(Adapter):
    def add(
        self,
        content: str | Path,
        data_type: str | DataType | None = None,
        metadata: dict[str, Any] | None = None,
        loader: BaseLoader | None = None,
        chunker: BaseChunker | None = None,
        **kwargs: Any,
    ) -> None:
        source_content = SourceContent(content)
@@ -111,11 +108,19 @@ class RAG(Adapter):
        loader_result = loader.load(source_content)
        doc_id = loader_result.doc_id

        existing_doc = self._collection.get(
            where={"source": source_content.source_ref}, limit=1
        )
        existing_doc_id = (
            existing_doc and existing_doc["metadatas"][0]["doc_id"]
            if existing_doc["metadatas"]
            else None
        )
        if existing_doc_id == doc_id:
            logger.warning(
                f"Document with source {loader_result.source} already exists"
            )
            return

        # Document with the same source ref exists but the content has changed; delete the stale reference
@@ -128,14 +133,16 @@ class RAG(Adapter):
        chunks = chunker.chunk(loader_result.content)
        for i, chunk in enumerate(chunks):
            doc_metadata = (metadata or {}).copy()
            doc_metadata["chunk_index"] = i
            documents.append(
                Document(
                    id=compute_sha256(chunk),
                    content=chunk,
                    metadata=doc_metadata,
                    data_type=data_type,
                    source=loader_result.source,
                )
            )

        if not documents:
            logger.warning("No documents to add")
@@ -153,11 +160,13 @@ class RAG(Adapter):
        for doc in documents:
            doc_metadata = doc.metadata.copy()
            doc_metadata.update(
                {
                    "data_type": doc.data_type.value,
                    "source": doc.source,
                    "doc_id": doc_id,
                }
            )
            metadatas.append(doc_metadata)

        try:
@@ -171,7 +180,7 @@ class RAG(Adapter):
        except Exception as e:
            logger.error(f"Failed to add documents to ChromaDB: {e}")

    def query(self, question: str, where: dict[str, Any] | None = None) -> str:
        try:
            question_embedding = self._embedding_service.embed_text(question)
@@ -179,10 +188,14 @@ class RAG(Adapter):
                query_embeddings=[question_embedding],
                n_results=self.top_k,
                where=where,
                include=["documents", "metadatas", "distances"],
            )

            if (
                not results
                or not results.get("documents")
                or not results["documents"][0]
            ):
                return "No relevant content found."

            documents = results["documents"][0]
@@ -195,8 +208,12 @@ class RAG(Adapter):
                metadata = metadatas[i] if i < len(metadatas) else {}
                distance = distances[i] if i < len(distances) else 1.0
                source = metadata.get("source", "unknown") if metadata else "unknown"
                score = (
                    1 - distance if distance is not None else 0
                )  # Convert distance to similarity
                formatted_results.append(
                    f"[Source: {source}, Relevance: {score:.3f}]\n{doc}"
                )

            return "\n\n".join(formatted_results)
        except Exception as e:
@@ -210,23 +227,25 @@ class RAG(Adapter):
        except Exception as e:
            logger.error(f"Failed to delete collection: {e}")

    def get_collection_info(self) -> dict[str, Any]:
        try:
            count = self._collection.count()
            return {
                "name": self.collection_name,
                "count": count,
                "embedding_model": self.embedding_model,
            }
        except Exception as e:
            logger.error(f"Failed to get collection info: {e}")
            return {"error": str(e)}

    def _get_data_type(
        self, content: SourceContent, data_type: str | DataType | None = None
    ) -> DataType:
        try:
            if isinstance(data_type, str):
                return DataType(data_type)
        except Exception:
            pass

        return content.data_type
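
An end-to-end sketch of the adapter, assuming litellm can reach the embedding model (e.g. OPENAI_API_KEY is set) and that notes.txt is a hypothetical local file:

rag = RAG(collection_name="demo", persist_directory="/tmp/demo_chroma")
rag.add("notes.txt", metadata={"team": "docs"})
print(rag.query("What do the notes say about releases?"))
print(rag.get_collection_info())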


@@ -1,9 +1,11 @@
import os
from enum import Enum
from pathlib import Path
from urllib.parse import urlparse

from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker


class DataType(str, Enum):
    PDF_FILE = "pdf_file"
@@ -25,29 +27,38 @@ class DataType(str, Enum):
    # Web types
    WEBSITE = "website"
    DOCS_SITE = "docs_site"
    YOUTUBE_VIDEO = "youtube_video"
    YOUTUBE_CHANNEL = "youtube_channel"

    # Raw types
    TEXT = "text"

    def get_chunker(self) -> BaseChunker:
        from importlib import import_module

        chunkers = {
            DataType.PDF_FILE: ("text_chunker", "TextChunker"),
            DataType.TEXT_FILE: ("text_chunker", "TextChunker"),
            DataType.TEXT: ("text_chunker", "TextChunker"),
            DataType.DOCX: ("text_chunker", "DocxChunker"),
            DataType.MDX: ("text_chunker", "MdxChunker"),
            # Structured formats
            DataType.CSV: ("structured_chunker", "CsvChunker"),
            DataType.JSON: ("structured_chunker", "JsonChunker"),
            DataType.XML: ("structured_chunker", "XmlChunker"),
            DataType.WEBSITE: ("web_chunker", "WebsiteChunker"),
            DataType.DIRECTORY: ("text_chunker", "TextChunker"),
            DataType.YOUTUBE_VIDEO: ("text_chunker", "TextChunker"),
            DataType.YOUTUBE_CHANNEL: ("text_chunker", "TextChunker"),
            DataType.GITHUB: ("text_chunker", "TextChunker"),
            DataType.DOCS_SITE: ("text_chunker", "TextChunker"),
            DataType.MYSQL: ("text_chunker", "TextChunker"),
            DataType.POSTGRES: ("text_chunker", "TextChunker"),
        }
        if self not in chunkers:
            raise ValueError(f"No chunker defined for {self}")
        module_name, class_name = chunkers[self]
        module_path = f"crewai_tools.rag.chunkers.{module_name}"
        try:
@@ -60,6 +71,7 @@ class DataType(str, Enum):
        from importlib import import_module

        loaders = {
            DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
            DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
            DataType.TEXT: ("text_loader", "TextLoader"),
            DataType.XML: ("xml_loader", "XMLLoader"),
@@ -69,9 +81,20 @@
            DataType.DOCX: ("docx_loader", "DOCXLoader"),
            DataType.CSV: ("csv_loader", "CSVLoader"),
            DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"),
            DataType.YOUTUBE_VIDEO: ("youtube_video_loader", "YoutubeVideoLoader"),
            DataType.YOUTUBE_CHANNEL: (
                "youtube_channel_loader",
                "YoutubeChannelLoader",
            ),
            DataType.GITHUB: ("github_loader", "GithubLoader"),
            DataType.DOCS_SITE: ("docs_site_loader", "DocsSiteLoader"),
            DataType.MYSQL: ("mysql_loader", "MySQLLoader"),
            DataType.POSTGRES: ("postgres_loader", "PostgresLoader"),
        }
        if self not in loaders:
            raise ValueError(f"No loader defined for {self}")
        module_name, class_name = loaders[self]
        module_path = f"crewai_tools.rag.loaders.{module_name}"
        try:
            module = import_module(module_path)
@@ -79,6 +102,7 @@
        except Exception as e:
            raise ValueError(f"Error loading loader for {self}: {e}")


class DataTypes:
    @staticmethod
    def from_content(content: str | Path | None = None) -> DataType:
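
A small sketch of the dispatch table above; the hunk truncates whether get_chunker returns an instance or a class, so the instantiation convention here is assumed:

chunker = DataType.MDX.get_chunker()  # resolves to text_chunker.MdxChunker
print(type(chunker).__name__)

Note the behavior change in this hunk: unmapped types now raise ValueError instead of silently falling back to DefaultChunker (or TextLoader on the loader side).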


@@ -1,20 +1,26 @@
from crewai_tools.rag.loaders.csv_loader import CSVLoader
from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
from crewai_tools.rag.loaders.docx_loader import DOCXLoader
from crewai_tools.rag.loaders.json_loader import JSONLoader
from crewai_tools.rag.loaders.mdx_loader import MDXLoader
from crewai_tools.rag.loaders.pdf_loader import PDFLoader
from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.loaders.xml_loader import XMLLoader
from crewai_tools.rag.loaders.youtube_channel_loader import YoutubeChannelLoader
from crewai_tools.rag.loaders.youtube_video_loader import YoutubeVideoLoader

__all__ = [
    "CSVLoader",
    "DOCXLoader",
    "DirectoryLoader",
    "JSONLoader",
    "MDXLoader",
    "PDFLoader",
    "TextFileLoader",
    "TextLoader",
    "WebPageLoader",
    "XMLLoader",
    "YoutubeChannelLoader",
    "YoutubeVideoLoader",
]


@@ -17,21 +17,23 @@ class CSVLoader(BaseLoader):
        return self._parse_csv(content_str, source_ref)

    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get(
            "headers",
            {
                "Accept": "text/csv, application/csv, text/plain",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools CSVLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching CSV from URL {url}: {e!s}")

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
@@ -57,7 +59,7 @@ class CSVLoader(BaseLoader):
            metadata = {
                "format": "csv",
                "columns": headers,
                "rows": len(text_parts) - 2 if headers else 0,
            }

        except Exception as e:
@@ -68,5 +70,5 @@
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
        )


@@ -1,6 +1,5 @@
import os
from pathlib import Path

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
@@ -22,7 +21,9 @@ class DirectoryLoader(BaseLoader):
        source_ref = source_content.source_ref

        if source_content.is_url():
            raise ValueError(
                "URL directory loading is not supported. Please provide a local directory path."
            )

        if not os.path.exists(source_ref):
            raise FileNotFoundError(f"Directory does not exist: {source_ref}")
@@ -38,7 +39,9 @@
        exclude_extensions = kwargs.get("exclude_extensions", None)
        max_files = kwargs.get("max_files", None)

        files = self._find_files(
            dir_path, recursive, include_extensions, exclude_extensions
        )

        if max_files and len(files) > max_files:
            files = files[:max_files]
@@ -52,13 +55,15 @@
                result = self._process_single_file(file_path)
                if result:
                    all_contents.append(f"=== File: {file_path} ===\n{result.content}")
                    processed_files.append(
                        {
                            "path": file_path,
                            "metadata": result.metadata,
                            "source": result.source,
                        }
                    )
            except Exception as e:
                error_msg = f"Error processing {file_path}: {e!s}"
                errors.append(error_msg)
                all_contents.append(f"=== File: {file_path} (ERROR) ===\n{error_msg}")
@@ -71,23 +76,29 @@
            "processed_files": len(processed_files),
            "errors": len(errors),
            "file_details": processed_files,
            "error_details": errors,
        }

        return LoaderResult(
            content=combined_content,
            source=dir_path,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=dir_path, content=combined_content),
        )

    def _find_files(
        self,
        dir_path: str,
        recursive: bool,
        include_ext: list[str] | None = None,
        exclude_ext: list[str] | None = None,
    ) -> list[str]:
        """Find all files in directory matching criteria."""
        files = []

        if recursive:
            for root, dirs, filenames in os.walk(dir_path):
                dirs[:] = [d for d in dirs if not d.startswith(".")]
                for filename in filenames:
                    if self._should_include_file(filename, include_ext, exclude_ext):
@@ -96,26 +107,37 @@
            try:
                for item in os.listdir(dir_path):
                    item_path = os.path.join(dir_path, item)
                    if os.path.isfile(item_path) and self._should_include_file(
                        item, include_ext, exclude_ext
                    ):
                        files.append(item_path)
            except PermissionError:
                pass

        return sorted(files)

    def _should_include_file(
        self,
        filename: str,
        include_ext: list[str] | None = None,
        exclude_ext: list[str] | None = None,
    ) -> bool:
        """Determine if a file should be included based on criteria."""
        if filename.startswith("."):
            return False

        _, ext = os.path.splitext(filename.lower())

        if include_ext:
            if ext not in [
                e.lower() if e.startswith(".") else f".{e.lower()}" for e in include_ext
            ]:
                return False

        if exclude_ext:
            if ext in [
                e.lower() if e.startswith(".") else f".{e.lower()}" for e in exclude_ext
            ]:
                return False

        return True
@@ -132,11 +154,13 @@
        if result.metadata is None:
            result.metadata = {}

        result.metadata.update(
            {
                "file_path": file_path,
                "file_size": os.path.getsize(file_path),
                "data_type": str(data_type),
                "loader_type": loader.__class__.__name__,
            }
        )

        return result
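
A usage sketch of the filtering kwargs, assuming ./docs is a real local directory:

loader = DirectoryLoader()
result = loader.load(
    SourceContent("./docs"),
    recursive=True,
    include_extensions=[".md", "txt"],  # bare extensions are normalized to ".txt"
    max_files=25,
)
print(result.metadata["processed_files"], "files loaded,", result.metadata["errors"], "errors")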


@@ -0,0 +1,106 @@
"""Documentation site loader."""
from urllib.parse import urljoin, urlparse
import requests
from bs4 import BeautifulSoup
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class DocsSiteLoader(BaseLoader):
"""Loader for documentation websites."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load content from a documentation site.
Args:
source: Documentation site URL
**kwargs: Additional arguments
Returns:
LoaderResult with documentation content
"""
docs_url = source.source
try:
response = requests.get(docs_url, timeout=30)
response.raise_for_status()
except requests.RequestException as e:
raise ValueError(f"Unable to fetch documentation from {docs_url}: {e}")
soup = BeautifulSoup(response.text, "html.parser")
for script in soup(["script", "style"]):
script.decompose()
title = soup.find("title")
title_text = title.get_text(strip=True) if title else "Documentation"
main_content = None
for selector in [
"main",
"article",
'[role="main"]',
".content",
"#content",
".documentation",
]:
main_content = soup.select_one(selector)
if main_content:
break
if not main_content:
main_content = soup.find("body")
if not main_content:
raise ValueError(
f"Unable to extract content from documentation site: {docs_url}"
)
text_parts = [f"Title: {title_text}", ""]
headings = main_content.find_all(["h1", "h2", "h3"])
if headings:
text_parts.append("Table of Contents:")
for heading in headings[:15]:
level = int(heading.name[1])
indent = " " * (level - 1)
text_parts.append(f"{indent}- {heading.get_text(strip=True)}")
text_parts.append("")
text = main_content.get_text(separator="\n", strip=True)
lines = [line.strip() for line in text.split("\n") if line.strip()]
text_parts.extend(lines)
nav_links = []
for nav_selector in ["nav", ".sidebar", ".toc", ".navigation"]:
nav = soup.select_one(nav_selector)
if nav:
links = nav.find_all("a", href=True)
for link in links[:20]:
href = link["href"]
if not href.startswith(("http://", "https://", "mailto:", "#")):
full_url = urljoin(docs_url, href)
nav_links.append(f"- {link.get_text(strip=True)}: {full_url}")
if nav_links:
text_parts.append("")
text_parts.append("Related documentation pages:")
text_parts.extend(nav_links[:10])
content = "\n".join(text_parts)
if len(content) > 100000:
content = content[:100000] + "\n\n[Content truncated...]"
return LoaderResult(
content=content,
metadata={
"source": docs_url,
"title": title_text,
"domain": urlparse(docs_url).netloc,
},
doc_id=self.generate_doc_id(source_ref=docs_url, content=content),
)
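
A brief usage sketch; any reachable documentation URL works:

loader = DocsSiteLoader()
result = loader.load(SourceContent("https://docs.crewai.com"))
print(result.metadata["title"], "-", result.metadata["domain"])
print(result.content[:200])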


@@ -10,7 +10,9 @@ class DOCXLoader(BaseLoader):
        try:
            from docx import Document as DocxDocument
        except ImportError:
            raise ImportError(
                "python-docx is required for DOCX loading. Install with: 'uv pip install python-docx' or pip install crewai-tools[rag]"
            )

        source_ref = source_content.source_ref
@@ -23,28 +25,35 @@ class DOCXLoader(BaseLoader):
        elif source_content.path_exists():
            return self._load_from_file(source_ref, source_ref, DocxDocument)
        else:
            raise ValueError(
                f"Source must be a valid file path or URL, got: {source_content.source}"
            )

    def _download_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get(
            "headers",
            {
                "Accept": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools DOCXLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()

            # Create temporary file to save the DOCX content
            with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as temp_file:
                temp_file.write(response.content)
                return temp_file.name
        except Exception as e:
            raise ValueError(f"Error fetching DOCX from URL {url}: {e!s}")

    def _load_from_file(
        self, file_path: str, source_ref: str, DocxDocument
    ) -> LoaderResult:
        try:
            doc = DocxDocument(file_path)
@@ -58,15 +67,15 @@ class DOCXLoader(BaseLoader):
            metadata = {
                "format": "docx",
                "paragraphs": len(doc.paragraphs),
                "tables": len(doc.tables),
            }

            return LoaderResult(
                content=content,
                source=source_ref,
                metadata=metadata,
                doc_id=self.generate_doc_id(source_ref=source_ref, content=content),
            )
        except Exception as e:
            raise ValueError(f"Error loading DOCX file: {e!s}")


@@ -0,0 +1,112 @@
"""GitHub repository content loader."""
from github import Github, GithubException
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class GithubLoader(BaseLoader):
"""Loader for GitHub repository content."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load content from a GitHub repository.
Args:
source: GitHub repository URL
**kwargs: Additional arguments including gh_token and content_types
Returns:
LoaderResult with repository content
"""
metadata = kwargs.get("metadata", {})
gh_token = metadata.get("gh_token")
content_types = metadata.get("content_types", ["code", "repo"])
repo_url = source.source
if not repo_url.startswith("https://github.com/"):
raise ValueError(f"Invalid GitHub URL: {repo_url}")
parts = repo_url.replace("https://github.com/", "").strip("/").split("/")
if len(parts) < 2:
raise ValueError(f"Invalid GitHub repository URL: {repo_url}")
repo_name = f"{parts[0]}/{parts[1]}"
g = Github(gh_token) if gh_token else Github()
try:
repo = g.get_repo(repo_name)
except GithubException as e:
raise ValueError(f"Unable to access repository {repo_name}: {e}")
all_content = []
if "repo" in content_types:
all_content.append(f"Repository: {repo.full_name}")
all_content.append(f"Description: {repo.description or 'No description'}")
all_content.append(f"Language: {repo.language or 'Not specified'}")
all_content.append(f"Stars: {repo.stargazers_count}")
all_content.append(f"Forks: {repo.forks_count}")
all_content.append("")
if "code" in content_types:
try:
readme = repo.get_readme()
all_content.append("README:")
all_content.append(
readme.decoded_content.decode("utf-8", errors="ignore")
)
all_content.append("")
except GithubException:
pass
try:
contents = repo.get_contents("")
if isinstance(contents, list):
all_content.append("Repository structure:")
for content_file in contents[:20]:
all_content.append(
f"- {content_file.path} ({content_file.type})"
)
all_content.append("")
except GithubException:
pass
if "pr" in content_types:
prs = repo.get_pulls(state="open")
pr_list = list(prs[:5])
if pr_list:
all_content.append("Recent Pull Requests:")
for pr in pr_list:
all_content.append(f"- PR #{pr.number}: {pr.title}")
if pr.body:
body_preview = pr.body[:200].replace("\n", " ")
all_content.append(f" {body_preview}")
all_content.append("")
if "issue" in content_types:
issues = repo.get_issues(state="open")
issue_list = [i for i in list(issues[:10]) if not i.pull_request][:5]
if issue_list:
all_content.append("Recent Issues:")
for issue in issue_list:
all_content.append(f"- Issue #{issue.number}: {issue.title}")
if issue.body:
body_preview = issue.body[:200].replace("\n", " ")
all_content.append(f" {body_preview}")
all_content.append("")
if not all_content:
raise ValueError(f"No content could be loaded from repository: {repo_url}")
content = "\n".join(all_content)
return LoaderResult(
content=content,
metadata={
"source": repo_url,
"repo": repo_name,
"content_types": content_types,
},
doc_id=self.generate_doc_id(source_ref=repo_url, content=content),
)
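
For orientation, a minimal usage sketch of the new loader. The `SourceContent(...)` construction, the import path, and the repository URL are assumptions for illustration; the diff itself only defines the class.

    # Illustrative sketch; constructor usage and module path are assumed, not documented here.
    from crewai_tools.rag.source_content import SourceContent

    loader = GithubLoader()
    result = loader.load(
        SourceContent("https://github.com/octocat/Hello-World"),  # placeholder repo
        metadata={"gh_token": None, "content_types": ["repo", "issue"]},
    )
    print(result.metadata["repo"])  # "octocat/Hello-World"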

View File

@@ -1,7 +1,7 @@
import json

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class JSONLoader(BaseLoader):
@@ -19,17 +19,24 @@ class JSONLoader(BaseLoader):
    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get(
            "headers",
            {
                "Accept": "application/json",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools JSONLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return (
                response.text
                if not self._is_json_response(response)
                else json.dumps(response.json(), indent=2)
            )
        except Exception as e:
            raise ValueError(f"Error fetching JSON from URL {url}: {e!s}")

    def _is_json_response(self, response) -> bool:
        try:
@@ -46,7 +53,9 @@ class JSONLoader(BaseLoader):
        try:
            data = json.loads(content)
            if isinstance(data, dict):
                text = "\n".join(
                    f"{k}: {json.dumps(v, indent=0)}" for k, v in data.items()
                )
            elif isinstance(data, list):
                text = "\n".join(json.dumps(item, indent=0) for item in data)
            else:
@@ -55,7 +64,7 @@ class JSONLoader(BaseLoader):
            metadata = {
                "format": "json",
                "type": type(data).__name__,
                "size": len(data) if isinstance(data, (list, dict)) else 1,
            }
        except json.JSONDecodeError as e:
            text = content
@@ -65,5 +74,5 @@ class JSONLoader(BaseLoader):
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
        )
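
One behavior worth noting in the reformatted dict branch: `json.dumps(..., indent=0)` still inserts newlines inside containers, so list- or dict-valued keys span multiple lines while scalar values stay on one. A stdlib-only illustration:

    import json

    data = {"name": "crewai-tools", "tags": ["rag", "loader"]}
    text = "\n".join(f"{k}: {json.dumps(v, indent=0)}" for k, v in data.items())
    print(text)
    # name: "crewai-tools"   <- scalar stays on one line
    # tags: [                <- indent=0 still breaks containers across lines
    # "rag",
    # "loader"
    # ]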

View File

@@ -3,6 +3,7 @@ import re
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class MDXLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
@@ -18,17 +19,20 @@ class MDXLoader(BaseLoader):
    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get(
            "headers",
            {
                "Accept": "text/markdown, text/x-markdown, text/plain",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools MDXLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching MDX from URL {url}: {e!s}")

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
@@ -38,16 +42,20 @@ class MDXLoader(BaseLoader):
        cleaned_content = content

        # Remove import statements
        cleaned_content = re.sub(
            r"^import\s+.*?\n", "", cleaned_content, flags=re.MULTILINE
        )

        # Remove export statements
        cleaned_content = re.sub(
            r"^export\s+.*?(?:\n|$)", "", cleaned_content, flags=re.MULTILINE
        )

        # Remove JSX tags (simple approach)
        cleaned_content = re.sub(r"<[^>]+>", "", cleaned_content)

        # Clean up extra whitespace
        cleaned_content = re.sub(r"\n\s*\n\s*\n", "\n\n", cleaned_content)
        cleaned_content = cleaned_content.strip()

        metadata = {"format": "mdx"}
@@ -55,5 +63,5 @@ class MDXLoader(BaseLoader):
            content=cleaned_content,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=cleaned_content),
        )
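
The regex passes compose in order: strip imports/exports, strip JSX tags, then collapse the blank runs left behind. A stdlib-only sketch of their combined effect on a tiny MDX snippet:

    import re

    mdx = 'import {Chart} from "./chart"\n\n# Title\n\n<Chart data={d} />\n\nBody text.'
    out = re.sub(r"^import\s+.*?\n", "", mdx, flags=re.MULTILINE)
    out = re.sub(r"<[^>]+>", "", out)  # simple JSX strip, as in the loader
    out = re.sub(r"\n\s*\n\s*\n", "\n\n", out).strip()
    print(out)  # "# Title\n\nBody text."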

View File

@@ -0,0 +1,100 @@
"""MySQL database loader."""
from urllib.parse import urlparse
import pymysql
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class MySQLLoader(BaseLoader):
"""Loader for MySQL database content."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load content from a MySQL database table.
Args:
source: SQL query (e.g., "SELECT * FROM table_name")
**kwargs: Additional arguments including db_uri
Returns:
LoaderResult with database content
"""
metadata = kwargs.get("metadata", {})
db_uri = metadata.get("db_uri")
if not db_uri:
raise ValueError("Database URI is required for MySQL loader")
query = source.source
parsed = urlparse(db_uri)
if parsed.scheme not in ["mysql", "mysql+pymysql"]:
raise ValueError(f"Invalid MySQL URI scheme: {parsed.scheme}")
connection_params = {
"host": parsed.hostname or "localhost",
"port": parsed.port or 3306,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path.lstrip("/") if parsed.path else None,
"charset": "utf8mb4",
"cursorclass": pymysql.cursors.DictCursor,
}
if not connection_params["database"]:
raise ValueError("Database name is required in the URI")
try:
connection = pymysql.connect(**connection_params)
try:
with connection.cursor() as cursor:
cursor.execute(query)
rows = cursor.fetchall()
if not rows:
content = "No data found in the table"
return LoaderResult(
content=content,
metadata={"source": query, "row_count": 0},
doc_id=self.generate_doc_id(
source_ref=query, content=content
),
)
text_parts = []
columns = list(rows[0].keys())
text_parts.append(f"Columns: {', '.join(columns)}")
text_parts.append(f"Total rows: {len(rows)}")
text_parts.append("")
for i, row in enumerate(rows, 1):
text_parts.append(f"Row {i}:")
for col, val in row.items():
if val is not None:
text_parts.append(f" {col}: {val}")
text_parts.append("")
content = "\n".join(text_parts)
if len(content) > 100000:
content = content[:100000] + "\n\n[Content truncated...]"
return LoaderResult(
content=content,
metadata={
"source": query,
"database": connection_params["database"],
"row_count": len(rows),
"columns": columns,
},
doc_id=self.generate_doc_id(source_ref=query, content=content),
)
finally:
connection.close()
except pymysql.Error as e:
raise ValueError(f"MySQL database error: {e}")
except Exception as e:
raise ValueError(f"Failed to load data from MySQL: {e}")
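
A hypothetical invocation, assuming `SourceContent` wraps the SQL query string directly as the loader expects; the URI, credentials, and table are placeholders:

    from crewai_tools.rag.source_content import SourceContent

    loader = MySQLLoader()
    result = loader.load(
        SourceContent("SELECT id, title FROM articles LIMIT 10"),  # placeholder query
        metadata={"db_uri": "mysql://user:pass@localhost:3306/blog"},  # placeholder URI
    )
    print(result.metadata["row_count"])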

View File

@@ -0,0 +1,71 @@
"""PDF loader for extracting text from PDF files."""
import os
from pathlib import Path
from typing import Any
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class PDFLoader(BaseLoader):
"""Loader for PDF files."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load and extract text from a PDF file.
Args:
source: The source content containing the PDF file path
Returns:
LoaderResult with extracted text content
Raises:
FileNotFoundError: If the PDF file doesn't exist
ImportError: If required PDF libraries aren't installed
"""
try:
import pypdf
except ImportError:
try:
import PyPDF2 as pypdf
except ImportError:
raise ImportError(
"PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
)
file_path = source.source
if not os.path.isfile(file_path):
raise FileNotFoundError(f"PDF file not found: {file_path}")
text_content = []
metadata: dict[str, Any] = {
"source": str(file_path),
"file_name": Path(file_path).name,
"file_type": "pdf",
}
try:
with open(file_path, "rb") as file:
pdf_reader = pypdf.PdfReader(file)
metadata["num_pages"] = len(pdf_reader.pages)
for page_num, page in enumerate(pdf_reader.pages, 1):
page_text = page.extract_text()
if page_text.strip():
text_content.append(f"Page {page_num}:\n{page_text}")
except Exception as e:
raise ValueError(f"Error reading PDF file {file_path}: {e!s}")
if not text_content:
content = f"[PDF file with no extractable text: {Path(file_path).name}]"
else:
content = "\n\n".join(text_content)
return LoaderResult(
content=content,
source=str(file_path),
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=str(file_path), content=content),
)
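
A minimal sketch of the expected call pattern; the file path is a placeholder and the `SourceContent` construction is assumed rather than documented in this diff:

    from crewai_tools.rag.source_content import SourceContent

    loader = PDFLoader()
    result = loader.load(SourceContent("./docs/report.pdf"))  # placeholder path
    print(result.metadata["num_pages"], result.content[:80])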

View File

@@ -0,0 +1,100 @@
"""PostgreSQL database loader."""
from urllib.parse import urlparse
import psycopg2
from psycopg2.extras import RealDictCursor
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class PostgresLoader(BaseLoader):
"""Loader for PostgreSQL database content."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load content from a PostgreSQL database table.
Args:
source: SQL query (e.g., "SELECT * FROM table_name")
**kwargs: Additional arguments including db_uri
Returns:
LoaderResult with database content
"""
metadata = kwargs.get("metadata", {})
db_uri = metadata.get("db_uri")
if not db_uri:
raise ValueError("Database URI is required for PostgreSQL loader")
query = source.source
parsed = urlparse(db_uri)
if parsed.scheme not in ["postgresql", "postgres", "postgresql+psycopg2"]:
raise ValueError(f"Invalid PostgreSQL URI scheme: {parsed.scheme}")
connection_params = {
"host": parsed.hostname or "localhost",
"port": parsed.port or 5432,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path.lstrip("/") if parsed.path else None,
"cursor_factory": RealDictCursor,
}
if not connection_params["database"]:
raise ValueError("Database name is required in the URI")
try:
connection = psycopg2.connect(**connection_params)
try:
with connection.cursor() as cursor:
cursor.execute(query)
rows = cursor.fetchall()
if not rows:
content = "No data found in the table"
return LoaderResult(
content=content,
metadata={"source": query, "row_count": 0},
doc_id=self.generate_doc_id(
source_ref=query, content=content
),
)
text_parts = []
columns = list(rows[0].keys())
text_parts.append(f"Columns: {', '.join(columns)}")
text_parts.append(f"Total rows: {len(rows)}")
text_parts.append("")
for i, row in enumerate(rows, 1):
text_parts.append(f"Row {i}:")
for col, val in row.items():
if val is not None:
text_parts.append(f" {col}: {val}")
text_parts.append("")
content = "\n".join(text_parts)
if len(content) > 100000:
content = content[:100000] + "\n\n[Content truncated...]"
return LoaderResult(
content=content,
metadata={
"source": query,
"database": connection_params["database"],
"row_count": len(rows),
"columns": columns,
},
doc_id=self.generate_doc_id(source_ref=query, content=content),
)
finally:
connection.close()
except psycopg2.Error as e:
raise ValueError(f"PostgreSQL database error: {e}")
except Exception as e:
raise ValueError(f"Failed to load data from PostgreSQL: {e}")
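
This loader mirrors the MySQL one above; besides the driver, the main difference is the set of accepted URI schemes. A quick stdlib check of the scheme validation:

    from urllib.parse import urlparse

    # All three spellings pass the loader's scheme check.
    for uri in (
        "postgresql://user:pass@db:5432/app",
        "postgres://user:pass@db:5432/app",
        "postgresql+psycopg2://user:pass@db:5432/app",
    ):
        assert urlparse(uri).scheme in ["postgresql", "postgres", "postgresql+psycopg2"]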

View File

@@ -1,18 +1,23 @@
import re

import requests
from bs4 import BeautifulSoup

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class WebPageLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        url = source_content.source
        headers = kwargs.get(
            "headers",
            {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
                "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
                "Accept-Language": "en-US,en;q=0.9",
            },
        )

        try:
            response = requests.get(url, timeout=15, headers=headers)
@@ -28,20 +33,22 @@ class WebPageLoader(BaseLoader):
            text = re.sub("\\s+\n\\s+", "\n", text)
            text = text.strip()

            title = (
                soup.title.string.strip() if soup.title and soup.title.string else ""
            )
            metadata = {
                "url": url,
                "title": title,
                "status_code": response.status_code,
                "content_type": response.headers.get("content-type", ""),
            }

            return LoaderResult(
                content=text,
                source=url,
                metadata=metadata,
                doc_id=self.generate_doc_id(source_ref=url, content=text),
            )
        except Exception as e:
            raise ValueError(f"Error loading webpage {url}: {e!s}")
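
A minimal usage sketch (URL is a placeholder; the `SourceContent` construction is assumed):

    from crewai_tools.rag.source_content import SourceContent

    loader = WebPageLoader()
    result = loader.load(SourceContent("https://example.com"))
    print(result.metadata["title"], result.metadata["status_code"])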

View File

@@ -1,9 +1,9 @@
import xml.etree.ElementTree as ET

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class XMLLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
@@ -11,7 +11,7 @@ class XMLLoader(BaseLoader):
        if source_content.is_url():
            content = self._load_from_url(source_ref, kwargs)
        elif source_content.path_exists():
            content = self._load_from_file(source_ref)

        return self._parse_xml(content, source_ref)
@@ -19,17 +19,20 @@ class XMLLoader(BaseLoader):
    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get(
            "headers",
            {
                "Accept": "application/xml, text/xml, text/plain",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools XMLLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching XML from URL {url}: {e!s}")

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
@@ -37,7 +40,7 @@ class XMLLoader(BaseLoader):
    def _parse_xml(self, content: str, source_ref: str) -> LoaderResult:
        try:
            if content.strip().startswith("<"):
                root = ET.fromstring(content)
            else:
                root = ET.parse(source_ref).getroot()
@@ -57,5 +60,5 @@ class XMLLoader(BaseLoader):
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
        )
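
The `_parse_xml` dispatch is just a prefix check: raw strings beginning with "<" are parsed directly, anything else is treated as a path and re-read. A stdlib-only illustration of that branch:

    import xml.etree.ElementTree as ET

    content = "<feed><item>hello</item></feed>"
    # Raw XML is parsed in place; a non-"<" value would be re-read via ET.parse().
    root = (
        ET.fromstring(content)
        if content.strip().startswith("<")
        else ET.parse(content).getroot()
    )
    print(root.tag)  # "feed"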

View File

@@ -0,0 +1,162 @@
"""YouTube channel loader for extracting content from YouTube channels."""
import re
from typing import Any
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class YoutubeChannelLoader(BaseLoader):
"""Loader for YouTube channels."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load and extract content from a YouTube channel.
Args:
source: The source content containing the YouTube channel URL
Returns:
LoaderResult with channel content
Raises:
ImportError: If required YouTube libraries aren't installed
ValueError: If the URL is not a valid YouTube channel URL
"""
try:
from pytube import Channel
except ImportError:
raise ImportError(
"YouTube channel support requires pytube. Install with: uv add pytube"
)
channel_url = source.source
if not any(
pattern in channel_url
for pattern in [
"youtube.com/channel/",
"youtube.com/c/",
"youtube.com/@",
"youtube.com/user/",
]
):
raise ValueError(f"Invalid YouTube channel URL: {channel_url}")
metadata: dict[str, Any] = {
"source": channel_url,
"data_type": "youtube_channel",
}
try:
channel = Channel(channel_url)
metadata["channel_name"] = channel.channel_name
metadata["channel_id"] = channel.channel_id
max_videos = kwargs.get("max_videos", 10)
video_urls = list(channel.video_urls)[:max_videos]
metadata["num_videos_loaded"] = len(video_urls)
metadata["total_videos"] = len(list(channel.video_urls))
content_parts = [
f"YouTube Channel: {channel.channel_name}",
f"Channel ID: {channel.channel_id}",
f"Total Videos: {metadata['total_videos']}",
f"Videos Loaded: {metadata['num_videos_loaded']}",
"\n--- Video Summaries ---\n",
]
try:
from pytube import YouTube
from youtube_transcript_api import YouTubeTranscriptApi
for i, video_url in enumerate(video_urls, 1):
try:
video_id = self._extract_video_id(video_url)
if not video_id:
continue
yt = YouTube(video_url)
title = yt.title or f"Video {i}"
description = (
yt.description[:200] if yt.description else "No description"
)
content_parts.append(f"\n{i}. {title}")
content_parts.append(f" URL: {video_url}")
content_parts.append(f" Description: {description}...")
try:
api = YouTubeTranscriptApi()
transcript_list = api.list(video_id)
transcript = None
try:
transcript = transcript_list.find_transcript(["en"])
except:
try:
transcript = (
transcript_list.find_generated_transcript(
["en"]
)
)
except:
transcript = next(iter(transcript_list), None)
if transcript:
transcript_data = transcript.fetch()
text_parts = []
char_count = 0
for entry in transcript_data:
text = (
entry.text.strip()
if hasattr(entry, "text")
else ""
)
if text:
text_parts.append(text)
char_count += len(text)
if char_count > 500:
break
if text_parts:
preview = " ".join(text_parts)[:500]
content_parts.append(
f" Transcript Preview: {preview}..."
)
except:
content_parts.append(" Transcript: Not available")
except Exception as e:
content_parts.append(f"\n{i}. Error loading video: {e!s}")
except ImportError:
for i, video_url in enumerate(video_urls, 1):
content_parts.append(f"\n{i}. {video_url}")
content = "\n".join(content_parts)
except Exception as e:
raise ValueError(
f"Unable to load YouTube channel {channel_url}: {e!s}"
) from e
return LoaderResult(
content=content,
source=channel_url,
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=channel_url, content=content),
)
def _extract_video_id(self, url: str) -> str | None:
"""Extract video ID from YouTube URL."""
patterns = [
r"(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)",
]
for pattern in patterns:
match = re.search(pattern, url)
if match:
return match.group(1)
return None
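
A hypothetical call showing the `max_videos` kwarg, which caps how many videos get summarized (default 10); the channel handle is a placeholder and the `SourceContent` construction is assumed:

    from crewai_tools.rag.source_content import SourceContent

    loader = YoutubeChannelLoader()
    result = loader.load(
        SourceContent("https://youtube.com/@SomeChannel"),  # placeholder handle
        max_videos=3,
    )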

View File

@@ -0,0 +1,134 @@
"""YouTube video loader for extracting transcripts from YouTube videos."""
import re
from typing import Any
from urllib.parse import parse_qs, urlparse
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class YoutubeVideoLoader(BaseLoader):
"""Loader for YouTube videos."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load and extract transcript from a YouTube video.
Args:
source: The source content containing the YouTube URL
Returns:
LoaderResult with transcript content
Raises:
ImportError: If required YouTube libraries aren't installed
ValueError: If the URL is not a valid YouTube video URL
"""
try:
from youtube_transcript_api import YouTubeTranscriptApi
except ImportError:
raise ImportError(
"YouTube support requires youtube-transcript-api. "
"Install with: uv add youtube-transcript-api"
)
video_url = source.source
video_id = self._extract_video_id(video_url)
if not video_id:
raise ValueError(f"Invalid YouTube URL: {video_url}")
metadata: dict[str, Any] = {
"source": video_url,
"video_id": video_id,
"data_type": "youtube_video",
}
try:
api = YouTubeTranscriptApi()
transcript_list = api.list(video_id)
transcript = None
try:
transcript = transcript_list.find_transcript(["en"])
except:
try:
transcript = transcript_list.find_generated_transcript(["en"])
except:
transcript = next(iter(transcript_list))
if transcript:
metadata["language"] = transcript.language
metadata["is_generated"] = transcript.is_generated
transcript_data = transcript.fetch()
text_content = []
for entry in transcript_data:
text = entry.text.strip() if hasattr(entry, "text") else ""
if text:
text_content.append(text)
content = " ".join(text_content)
try:
from pytube import YouTube
yt = YouTube(video_url)
metadata["title"] = yt.title
metadata["author"] = yt.author
metadata["length_seconds"] = yt.length
metadata["description"] = (
yt.description[:500] if yt.description else None
)
if yt.title:
content = f"Title: {yt.title}\n\nAuthor: {yt.author or 'Unknown'}\n\nTranscript:\n{content}"
except:
pass
else:
raise ValueError(
f"No transcript available for YouTube video: {video_id}"
)
except Exception as e:
raise ValueError(
f"Unable to extract transcript from YouTube video {video_id}: {e!s}"
) from e
return LoaderResult(
content=content,
source=video_url,
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=video_url, content=content),
)
def _extract_video_id(self, url: str) -> str | None:
"""Extract video ID from various YouTube URL formats."""
patterns = [
r"(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)",
]
for pattern in patterns:
match = re.search(pattern, url)
if match:
return match.group(1)
try:
parsed = urlparse(url)
hostname = parsed.hostname
if hostname:
hostname_lower = hostname.lower()
# Allow youtube.com and any subdomain of youtube.com, plus youtu.be shortener
if (
hostname_lower == "youtube.com"
or hostname_lower.endswith(".youtube.com")
or hostname_lower == "youtu.be"
):
query_params = parse_qs(parsed.query)
if "v" in query_params:
return query_params["v"][0]
except:
pass
return None
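
`_extract_video_id` first tries the path-style patterns, then falls back to parsing a `v=` query parameter for youtube.com hosts, so the common URL shapes all resolve to the same id:

    loader = YoutubeVideoLoader()
    for url in (
        "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
        "https://youtu.be/dQw4w9WgXcQ",
        "https://www.youtube.com/embed/dQw4w9WgXcQ",
    ):
        assert loader._extract_video_id(url) == "dQw4w9WgXcQ"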

View File

@@ -1,4 +1,31 @@
import hashlib
from typing import Any


def compute_sha256(content: str) -> str:
    return hashlib.sha256(content.encode("utf-8")).hexdigest()


def sanitize_metadata_for_chromadb(metadata: dict[str, Any]) -> dict[str, Any]:
    """Sanitize metadata to ensure ChromaDB compatibility.

    ChromaDB only accepts str, int, float, or bool values in metadata.
    This function converts other types to strings.

    Args:
        metadata: Dictionary of metadata to sanitize

    Returns:
        Sanitized metadata dictionary with only ChromaDB-compatible types
    """
    sanitized = {}
    for key, value in metadata.items():
        if isinstance(value, (str, int, float, bool)) or value is None:
            sanitized[key] = value
        elif isinstance(value, (list, tuple)):
            # Convert lists/tuples to pipe-separated strings
            sanitized[key] = " | ".join(str(v) for v in value)
        else:
            # Convert other types to string
            sanitized[key] = str(value)
    return sanitized
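
A quick illustration of the conversion rules (scalars pass through, sequences become pipe-separated strings, everything else is stringified):

    meta = {"pages": 3, "tags": ["rag", "loader"], "extra": {"nested": True}}
    print(sanitize_metadata_for_chromadb(meta))
    # {'pages': 3, 'tags': 'rag | loader', 'extra': "{'nested': True}"}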

View File

@@ -1,8 +1,8 @@
import os
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
from urllib.parse import urlparse

from crewai_tools.rag.misc import compute_sha256
@@ -34,7 +34,7 @@ class SourceContent:
    @cached_property
    def source_ref(self) -> str:
        """ "
        Returns the source reference for the content.
        If the content is a URL or a local file, returns the source.
        Otherwise, returns the hash of the content.

View File

@@ -70,6 +70,9 @@ from .oxylabs_google_search_scraper_tool.oxylabs_google_search_scraper_tool impo
from .oxylabs_universal_scraper_tool.oxylabs_universal_scraper_tool import (
    OxylabsUniversalScraperTool,
)
from .parallel_tools import (
    ParallelSearchTool,
)
from .patronus_eval_tool import (
    PatronusEvalTool,
    PatronusLocalEvaluatorTool,
@@ -122,6 +125,3 @@ from .youtube_channel_search_tool.youtube_channel_search_tool import (
)
from .youtube_video_search_tool.youtube_video_search_tool import YoutubeVideoSearchTool
from .zapier_action_tool.zapier_action_tool import ZapierActionTools

View File

@@ -1,6 +1,6 @@
import os
import secrets
from typing import Any

from crewai.tools import BaseTool, EnvVar
from openai import OpenAI
@@ -28,20 +28,22 @@ class AIMindTool(BaseTool):
        "and Google BigQuery. "
        "Input should be a question in natural language."
    )
    args_schema: type[BaseModel] = AIMindToolInputSchema
    api_key: str | None = None
    datasources: list[dict[str, Any]] | None = None
    mind_name: str | None = None
    package_dependencies: list[str] = ["minds-sdk"]
    env_vars: list[EnvVar] = [
        EnvVar(name="MINDS_API_KEY", description="API key for AI-Minds", required=True),
    ]

    def __init__(self, api_key: str | None = None, **kwargs):
        super().__init__(**kwargs)
        self.api_key = api_key or os.getenv("MINDS_API_KEY")
        if not self.api_key:
            raise ValueError(
                "API key must be provided either through constructor or MINDS_API_KEY environment variable"
            )

        try:
            from minds.client import Client  # type: ignore
@@ -74,13 +76,12 @@ class AIMindTool(BaseTool):
        self.mind_name = mind.name

    def _run(self, query: str):
        # Run the query on the AI-Mind.
        # The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.
        openai_client = OpenAI(
            base_url=AIMindToolConstants.MINDS_API_BASE_URL, api_key=self.api_key
        )

        completion = openai_client.chat.completions.create(
            model=self.mind_name,

View File

@@ -1,14 +1,20 @@
import os
from typing import TYPE_CHECKING, Any, ClassVar

from crewai.tools import BaseTool, EnvVar
from pydantic import Field

if TYPE_CHECKING:
    from langchain_apify import ApifyActorsTool as _ApifyActorsTool


class ApifyActorsTool(BaseTool):
    env_vars: ClassVar[list[EnvVar]] = [
        EnvVar(
            name="APIFY_API_TOKEN",
            description="API token for Apify platform access",
            required=True,
        ),
    ]
    """Tool that runs Apify Actors.
@@ -40,15 +46,10 @@ class ApifyActorsTool(BaseTool):
        print(f"URL: {result['metadata']['url']}")
        print(f"Content: {result.get('markdown', 'N/A')[:100]}...")
    """

    actor_tool: "_ApifyActorsTool" = Field(description="Apify Actor Tool")
    package_dependencies: ClassVar[list[str]] = ["langchain-apify"]

    def __init__(self, actor_name: str, *args: Any, **kwargs: Any) -> None:
        if not os.environ.get("APIFY_API_TOKEN"):
            msg = (
                "APIFY_API_TOKEN environment variable is not set. "
@@ -59,11 +60,11 @@ class ApifyActorsTool(BaseTool):
        try:
            from langchain_apify import ApifyActorsTool as _ApifyActorsTool
        except ImportError as e:
            raise ImportError(
                "Could not import langchain_apify python package. "
                "Please install it with `pip install langchain-apify` or `uv add langchain-apify`."
            ) from e

        actor_tool = _ApifyActorsTool(actor_name)
        kwargs.update(
@@ -76,7 +77,7 @@ class ApifyActorsTool(BaseTool):
        )
        super().__init__(*args, **kwargs)

    def _run(self, run_input: dict[str, Any]) -> list[dict[str, Any]]:
        """Run the Actor tool with the given input.

        Returns:
@@ -89,8 +90,8 @@ class ApifyActorsTool(BaseTool):
            return self.actor_tool._run(run_input)
        except Exception as e:
            msg = (
                f"Failed to run ApifyActorsTool {self.name}. "
                "Please check your Apify account Actor run logs for more details."
                f"Error: {e}"
            )
            raise RuntimeError(msg) from e
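
A minimal sketch matching the tool's own docstring usage; the actor name and `run_input` keys follow that docstring's example and are not re-verified here, and the token value is a placeholder:

    import os

    os.environ.setdefault("APIFY_API_TOKEN", "<your-token>")  # placeholder
    tool = ApifyActorsTool(actor_name="apify/rag-web-browser")
    results = tool.run(run_input={"query": "What is CrewAI?", "maxResults": 2})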

View File

@@ -1,35 +1,44 @@
import logging
import re
import time
import urllib.error
import urllib.parse
import urllib.request
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import ClassVar

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field

logger = logging.getLogger(__file__)


class ArxivToolInput(BaseModel):
    search_query: str = Field(
        ..., description="Search query for Arxiv, e.g., 'transformer neural network'"
    )
    max_results: int = Field(
        5, ge=1, le=100, description="Max results to fetch; must be between 1 and 100"
    )


class ArxivPaperTool(BaseTool):
    BASE_API_URL: ClassVar[str] = "http://export.arxiv.org/api/query"
    SLEEP_DURATION: ClassVar[int] = 1
    SUMMARY_TRUNCATE_LENGTH: ClassVar[int] = 300
    ATOM_NAMESPACE: ClassVar[str] = "{http://www.w3.org/2005/Atom}"
    REQUEST_TIMEOUT: ClassVar[int] = 10

    name: str = "Arxiv Paper Fetcher and Downloader"
    description: str = "Fetches metadata from Arxiv based on a search query and optionally downloads PDFs."
    args_schema: type[BaseModel] = ArxivToolInput
    model_config = {"extra": "allow"}
    package_dependencies: list[str] = ["pydantic"]
    env_vars: list[EnvVar] = []

    def __init__(
        self, download_pdfs=False, save_dir="./arxiv_pdfs", use_title_as_filename=False
    ):
        super().__init__()
        self.download_pdfs = download_pdfs
        self.save_dir = save_dir
@@ -38,44 +47,49 @@ class ArxivPaperTool(BaseTool):
    def _run(self, search_query: str, max_results: int = 5) -> str:
        try:
            args = ArxivToolInput(search_query=search_query, max_results=max_results)

            logger.info(
                f"Running Arxiv tool: query='{args.search_query}', max_results={args.max_results}, "
                f"download_pdfs={self.download_pdfs}, save_dir='{self.save_dir}', "
                f"use_title_as_filename={self.use_title_as_filename}"
            )

            papers = self.fetch_arxiv_data(args.search_query, args.max_results)

            if self.download_pdfs:
                save_dir = self._validate_save_path(self.save_dir)
                for paper in papers:
                    if paper["pdf_url"]:
                        if self.use_title_as_filename:
                            safe_title = re.sub(
                                r'[\\/*?:"<>|]', "_", paper["title"]
                            ).strip()
                            filename_base = safe_title or paper["arxiv_id"]
                        else:
                            filename_base = paper["arxiv_id"]

                        filename = f"{filename_base[:500]}.pdf"
                        save_path = Path(save_dir) / filename
                        self.download_pdf(paper["pdf_url"], save_path)
                        time.sleep(self.SLEEP_DURATION)

            results = [self._format_paper_result(p) for p in papers]
            return "\n\n" + "-" * 80 + "\n\n".join(results)

        except Exception as e:
            logger.error(f"ArxivTool Error: {e!s}")
            return f"Failed to fetch or download Arxiv papers: {e!s}"

    def fetch_arxiv_data(self, search_query: str, max_results: int) -> list[dict]:
        api_url = f"{self.BASE_API_URL}?search_query={urllib.parse.quote(search_query)}&start=0&max_results={max_results}"
        logger.info(f"Fetching data from Arxiv API: {api_url}")

        try:
            with urllib.request.urlopen(
                api_url, timeout=self.REQUEST_TIMEOUT
            ) as response:
                if response.status != 200:
                    raise Exception(f"HTTP {response.status}: {response.reason}")
                data = response.read().decode("utf-8")
        except urllib.error.URLError as e:
            logger.error(f"Error fetching data from Arxiv: {e}")
            raise
@@ -85,7 +99,7 @@ class ArxivPaperTool(BaseTool):
        for entry in root.findall(self.ATOM_NAMESPACE + "entry"):
            raw_id = self._get_element_text(entry, "id")
            arxiv_id = raw_id.split("/")[-1].replace(".", "_") if raw_id else "unknown"

            title = self._get_element_text(entry, "title") or "No Title"
            summary = self._get_element_text(entry, "summary") or "No Summary"
@@ -97,41 +111,48 @@ class ArxivPaperTool(BaseTool):
            pdf_url = self._extract_pdf_url(entry)

            papers.append(
                {
                    "arxiv_id": arxiv_id,
                    "title": title,
                    "summary": summary,
                    "authors": authors,
                    "published_date": published,
                    "pdf_url": pdf_url,
                }
            )

        return papers

    @staticmethod
    def _get_element_text(entry: ET.Element, element_name: str) -> str | None:
        elem = entry.find(f"{ArxivPaperTool.ATOM_NAMESPACE}{element_name}")
        return elem.text.strip() if elem is not None and elem.text else None

    def _extract_pdf_url(self, entry: ET.Element) -> str | None:
        for link in entry.findall(self.ATOM_NAMESPACE + "link"):
            if link.attrib.get("title", "").lower() == "pdf":
                return link.attrib.get("href")

        for link in entry.findall(self.ATOM_NAMESPACE + "link"):
            href = link.attrib.get("href")
            if href and "pdf" in href:
                return href

        return None

    def _format_paper_result(self, paper: dict) -> str:
        summary = (
            (paper["summary"][: self.SUMMARY_TRUNCATE_LENGTH] + "...")
            if len(paper["summary"]) > self.SUMMARY_TRUNCATE_LENGTH
            else paper["summary"]
        )
        authors_str = ", ".join(paper["authors"])
        return (
            f"Title: {paper['title']}\n"
            f"Authors: {authors_str}\n"
            f"Published: {paper['published_date']}\n"
            f"PDF: {paper['pdf_url'] or 'N/A'}\n"
            f"Summary: {summary}"
        )

    @staticmethod
    def _validate_save_path(path: str) -> Path:
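
A direct invocation sketch, calling `_run` the same way the tests that follow do; the query is just an example:

    tool = ArxivPaperTool(download_pdfs=False)
    print(tool._run(search_query="transformer neural network", max_results=3))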

View File

@@ -1,16 +1,19 @@
import urllib.error
import xml.etree.ElementTree as ET
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from crewai_tools import ArxivPaperTool


@pytest.fixture
def tool():
    return ArxivPaperTool(download_pdfs=False)


def mock_arxiv_response():
    return """<?xml version="1.0" encoding="UTF-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
    <entry>
        <id>http://arxiv.org/abs/1234.5678</id>
@@ -20,7 +23,8 @@ def mock_arxiv_response():
        <author><name>John Doe</name></author>
        <link title="pdf" href="http://arxiv.org/pdf/1234.5678.pdf"/>
    </entry>
</feed>"""


@patch("urllib.request.urlopen")
def test_fetch_arxiv_data(mock_urlopen, tool):
@@ -31,24 +35,30 @@ def test_fetch_arxiv_data(mock_urlopen, tool):
    results = tool.fetch_arxiv_data("transformer", 1)
    assert isinstance(results, list)
    assert results[0]["title"] == "Sample Paper"


@patch("urllib.request.urlopen", side_effect=urllib.error.URLError("Timeout"))
def test_fetch_arxiv_data_network_error(mock_urlopen, tool):
    with pytest.raises(urllib.error.URLError):
        tool.fetch_arxiv_data("transformer", 1)


@patch("urllib.request.urlretrieve")
def test_download_pdf_success(mock_urlretrieve):
    tool = ArxivPaperTool()
    tool.download_pdf("http://arxiv.org/pdf/1234.5678.pdf", Path("test.pdf"))
    mock_urlretrieve.assert_called_once()


@patch("urllib.request.urlretrieve", side_effect=OSError("Permission denied"))
def test_download_pdf_oserror(mock_urlretrieve):
    tool = ArxivPaperTool()
    with pytest.raises(OSError):
        tool.download_pdf(
            "http://arxiv.org/pdf/1234.5678.pdf", Path("/restricted/test.pdf")
        )


@patch("urllib.request.urlopen")
@patch("urllib.request.urlretrieve")
@@ -63,6 +73,7 @@ def test_run_with_download(mock_urlretrieve, mock_urlopen):
    assert "Title: Sample Paper" in output
    mock_urlretrieve.assert_called_once()


@patch("urllib.request.urlopen")
def test_run_no_download(mock_urlopen):
    mock_response = MagicMock()
@@ -74,12 +85,14 @@ def test_run_no_download(mock_urlopen):
    result = tool._run("transformer", 1)
    assert "Title: Sample Paper" in result


@patch("pathlib.Path.mkdir")
def test_validate_save_path_creates_directory(mock_mkdir):
    path = ArxivPaperTool._validate_save_path("new_folder")
    mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
    assert isinstance(path, Path)


@patch("urllib.request.urlopen")
def test_run_handles_exception(mock_urlopen):
    mock_urlopen.side_effect = Exception("API failure")
@@ -98,16 +111,20 @@ def test_invalid_xml_response(mock_urlopen, tool):
    with pytest.raises(ET.ParseError):
        tool.fetch_arxiv_data("quantum", 1)


@patch.object(ArxivPaperTool, "fetch_arxiv_data")
def test_run_with_max_results(mock_fetch, tool):
    mock_fetch.return_value = [
        {
            "arxiv_id": f"test_{i}",
            "title": f"Title {i}",
            "summary": "Summary",
            "authors": ["Author"],
            "published_date": "2023-01-01",
            "pdf_url": None,
        }
        for i in range(100)
    ]

    result = tool._run(search_query="test", max_results=100)
    assert result.count("Title:") == 100

View File

@@ -1,7 +1,7 @@
import datetime
import os
import time
from typing import Any, ClassVar

import requests
from crewai.tools import BaseTool, EnvVar
@@ -41,15 +41,17 @@ class BraveSearchTool(BaseTool):
    description: str = (
        "A tool that can be used to search the internet with a search_query."
    )
    args_schema: type[BaseModel] = BraveSearchToolSchema
    search_url: str = "https://api.search.brave.com/res/v1/web/search"
    country: str | None = ""
    n_results: int = 10
    save_file: bool = False
    _last_request_time: ClassVar[float] = 0
    _min_request_interval: ClassVar[float] = 1.0  # seconds
    env_vars: ClassVar[list[EnvVar]] = [
        EnvVar(
            name="BRAVE_API_KEY", description="API key for Brave Search", required=True
        ),
    ]

    def __init__(self, *args, **kwargs):
@@ -87,7 +89,9 @@ class BraveSearchTool(BaseTool):
                "Accept": "application/json",
            }

            response = requests.get(
                self.search_url, headers=headers, params=payload, timeout=30
            )
            response.raise_for_status()  # Handle non-200 responses
            results = response.json()
@@ -111,11 +115,10 @@ class BraveSearchTool(BaseTool):
            content = "\n".join(string)
        except requests.RequestException as e:
            return f"Error performing search: {e!s}"
        except KeyError as e:
            return f"Error parsing search results: {e!s}"

        if save_file:
            _save_results_to_file(content)
            return f"\nSearch results: {content}\n"
        return content
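
A usage sketch; the API key value is a placeholder, and the `search_query` argument name follows the tool's schema shown above:

    import os

    os.environ.setdefault("BRAVE_API_KEY", "<your-key>")  # placeholder
    tool = BraveSearchTool(n_results=3)
    print(tool.run(search_query="crewai tools"))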

View File

@@ -2,8 +2,4 @@ from .brightdata_dataset import BrightDataDatasetTool
from .brightdata_serp import BrightDataSearchTool
from .brightdata_unlocker import BrightDataWebUnlockerTool

__all__ = ["BrightDataDatasetTool", "BrightDataSearchTool", "BrightDataWebUnlockerTool"]

View File

@@ -1,11 +1,12 @@
import asyncio import asyncio
import os import os
from typing import Any, Dict, Optional, Type from typing import Any
import aiohttp import aiohttp
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class BrightDataConfig(BaseModel): class BrightDataConfig(BaseModel):
API_URL: str = "https://api.brightdata.com" API_URL: str = "https://api.brightdata.com"
DEFAULT_TIMEOUT: int = 600 DEFAULT_TIMEOUT: int = 600
@@ -16,8 +17,12 @@ class BrightDataConfig(BaseModel):
return cls( return cls(
API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com"), API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com"),
DEFAULT_TIMEOUT=int(os.environ.get("BRIGHTDATA_DEFAULT_TIMEOUT", "600")), DEFAULT_TIMEOUT=int(os.environ.get("BRIGHTDATA_DEFAULT_TIMEOUT", "600")),
DEFAULT_POLLING_INTERVAL=int(os.environ.get("BRIGHTDATA_DEFAULT_POLLING_INTERVAL", "1")) DEFAULT_POLLING_INTERVAL=int(
os.environ.get("BRIGHTDATA_DEFAULT_POLLING_INTERVAL", "1")
),
) )
class BrightDataDatasetToolException(Exception): class BrightDataDatasetToolException(Exception):
"""Exception raised for custom error in the application.""" """Exception raised for custom error in the application."""
@@ -43,15 +48,16 @@ class BrightDataDatasetToolSchema(BaseModel):
""" """
dataset_type: str = Field(..., description="The Bright Data Dataset Type") dataset_type: str = Field(..., description="The Bright Data Dataset Type")
format: Optional[str] = Field( format: str | None = Field(
default="json", description="Response format (json by default)" default="json", description="Response format (json by default)"
) )
url: str = Field(..., description="The URL to extract data from") url: str = Field(..., description="The URL to extract data from")
zipcode: Optional[str] = Field(default=None, description="Optional zipcode") zipcode: str | None = Field(default=None, description="Optional zipcode")
additional_params: Optional[Dict[str, Any]] = Field( additional_params: dict[str, Any] | None = Field(
default=None, description="Additional params if any" default=None, description="Additional params if any"
) )
config = BrightDataConfig.from_env() config = BrightDataConfig.from_env()
BRIGHTDATA_API_URL = config.API_URL BRIGHTDATA_API_URL = config.API_URL
@@ -404,14 +410,21 @@ class BrightDataDatasetTool(BaseTool):
name: str = "Bright Data Dataset Tool" name: str = "Bright Data Dataset Tool"
description: str = "Scrapes structured data using Bright Data Dataset API from a URL and optional input parameters" description: str = "Scrapes structured data using Bright Data Dataset API from a URL and optional input parameters"
args_schema: Type[BaseModel] = BrightDataDatasetToolSchema args_schema: type[BaseModel] = BrightDataDatasetToolSchema
dataset_type: Optional[str] = None dataset_type: str | None = None
url: Optional[str] = None url: str | None = None
format: str = "json" format: str = "json"
zipcode: Optional[str] = None zipcode: str | None = None
additional_params: Optional[Dict[str, Any]] = None additional_params: dict[str, Any] | None = None
def __init__(self, dataset_type: str = None, url: str = None, format: str = "json", zipcode: str = None, additional_params: Dict[str, Any] = None): def __init__(
self,
dataset_type: str | None = None,
url: str | None = None,
format: str = "json",
zipcode: str | None = None,
additional_params: dict[str, Any] | None = None,
):
super().__init__() super().__init__()
self.dataset_type = dataset_type self.dataset_type = dataset_type
self.url = url self.url = url
@@ -427,10 +440,10 @@ class BrightDataDatasetTool(BaseTool):
dataset_type: str, dataset_type: str,
output_format: str, output_format: str,
url: str, url: str,
zipcode: Optional[str] = None, zipcode: str | None = None,
additional_params: Optional[Dict[str, Any]] = None, additional_params: dict[str, Any] | None = None,
polling_interval: int = 1, polling_interval: int = 1,
) -> Dict: ) -> dict:
""" """
Asynchronously trigger and poll Bright Data dataset scraping. Asynchronously trigger and poll Bright Data dataset scraping.
@@ -509,7 +522,7 @@ class BrightDataDatasetTool(BaseTool):
if status_data.get("status") == "ready": if status_data.get("status") == "ready":
print("Job is ready") print("Job is ready")
break break
elif status_data.get("status") == "error": if status_data.get("status") == "error":
raise BrightDataDatasetToolException( raise BrightDataDatasetToolException(
f"Job failed: {status_data}", 0 f"Job failed: {status_data}", 0
) )
@@ -530,7 +543,15 @@ class BrightDataDatasetTool(BaseTool):
return await snapshot_response.text() return await snapshot_response.text()
def _run(self, url: str = None, dataset_type: str = None, format: str = None, zipcode: str = None, additional_params: Dict[str, Any] = None, **kwargs: Any) -> Any: def _run(
self,
url: str | None = None,
dataset_type: str | None = None,
format: str | None = None,
zipcode: str | None = None,
additional_params: dict[str, Any] | None = None,
**kwargs: Any,
) -> Any:
dataset_type = dataset_type or self.dataset_type dataset_type = dataset_type or self.dataset_type
output_format = format or self.format output_format = format or self.format
url = url or self.url url = url or self.url
@@ -538,7 +559,9 @@ class BrightDataDatasetTool(BaseTool):
additional_params = additional_params or self.additional_params additional_params = additional_params or self.additional_params
if not dataset_type: if not dataset_type:
raise ValueError("dataset_type is required either in constructor or method call") raise ValueError(
"dataset_type is required either in constructor or method call"
)
if not url: if not url:
raise ValueError("url is required either in constructor or method call") raise ValueError("url is required either in constructor or method call")
@@ -563,8 +586,10 @@ class BrightDataDatasetTool(BaseTool):
) )
) )
except TimeoutError as e: except TimeoutError as e:
return f"Timeout Exception occured in method : get_dataset_data_async. Details - {str(e)}" return f"Timeout Exception occured in method : get_dataset_data_async. Details - {e!s}"
except BrightDataDatasetToolException as e: except BrightDataDatasetToolException as e:
return f"Exception occured in method : get_dataset_data_async. Details - {str(e)}" return (
f"Exception occured in method : get_dataset_data_async. Details - {e!s}"
)
except Exception as e: except Exception as e:
return f"Bright Data API error: {str(e)}" return f"Bright Data API error: {e!s}"

View File

@@ -1,20 +1,24 @@
import os import os
import urllib.parse import urllib.parse
from typing import Any, Optional, Type from typing import Any
import requests import requests
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class BrightDataConfig(BaseModel): class BrightDataConfig(BaseModel):
API_URL: str = "https://api.brightdata.com/request" API_URL: str = "https://api.brightdata.com/request"
@classmethod @classmethod
def from_env(cls): def from_env(cls):
return cls( return cls(
API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com/request") API_URL=os.environ.get(
"BRIGHTDATA_API_URL", "https://api.brightdata.com/request"
)
) )
class BrightDataSearchToolSchema(BaseModel): class BrightDataSearchToolSchema(BaseModel):
""" """
Schema that defines the input arguments for the BrightDataSearchToolSchema. Schema that defines the input arguments for the BrightDataSearchToolSchema.
@@ -30,27 +34,27 @@ class BrightDataSearchToolSchema(BaseModel):
""" """
query: str = Field(..., description="Search query to perform") query: str = Field(..., description="Search query to perform")
search_engine: Optional[str] = Field( search_engine: str | None = Field(
default="google", default="google",
description="Search engine domain (e.g., 'google', 'bing', 'yandex')", description="Search engine domain (e.g., 'google', 'bing', 'yandex')",
) )
country: Optional[str] = Field( country: str | None = Field(
default="us", default="us",
description="Two-letter country code for geo-targeting (e.g., 'us', 'gb')", description="Two-letter country code for geo-targeting (e.g., 'us', 'gb')",
) )
language: Optional[str] = Field( language: str | None = Field(
default="en", default="en",
description="Language code (e.g., 'en', 'es') used in the query URL", description="Language code (e.g., 'en', 'es') used in the query URL",
) )
search_type: Optional[str] = Field( search_type: str | None = Field(
default=None, default=None,
description="Type of search (e.g., 'isch' for images, 'nws' for news)", description="Type of search (e.g., 'isch' for images, 'nws' for news)",
) )
device_type: Optional[str] = Field( device_type: str | None = Field(
default="desktop", default="desktop",
description="Device type to simulate (e.g., 'mobile', 'desktop', 'ios')", description="Device type to simulate (e.g., 'mobile', 'desktop', 'ios')",
) )
parse_results: Optional[bool] = Field( parse_results: bool | None = Field(
default=True, default=True,
description="Whether to parse and return JSON (True) or raw HTML/text (False)", description="Whether to parse and return JSON (True) or raw HTML/text (False)",
) )
@@ -75,20 +79,29 @@ class BrightDataSearchTool(BaseTool):
name: str = "Bright Data SERP Search" name: str = "Bright Data SERP Search"
description: str = "Tool to perform web search using Bright Data SERP API." description: str = "Tool to perform web search using Bright Data SERP API."
args_schema: Type[BaseModel] = BrightDataSearchToolSchema args_schema: type[BaseModel] = BrightDataSearchToolSchema
_config = BrightDataConfig.from_env() _config = BrightDataConfig.from_env()
base_url: str = "" base_url: str = ""
api_key: str = "" api_key: str = ""
zone: str = "" zone: str = ""
query: Optional[str] = None query: str | None = None
search_engine: str = "google" search_engine: str = "google"
country: str = "us" country: str = "us"
language: str = "en" language: str = "en"
search_type: Optional[str] = None search_type: str | None = None
device_type: str = "desktop" device_type: str = "desktop"
parse_results: bool = True parse_results: bool = True
def __init__(self, query: str = None, search_engine: str = "google", country: str = "us", language: str = "en", search_type: str = None, device_type: str = "desktop", parse_results: bool = True): def __init__(
self,
query: str | None = None,
search_engine: str = "google",
country: str = "us",
language: str = "en",
search_type: str | None = None,
device_type: str = "desktop",
parse_results: bool = True,
):
super().__init__() super().__init__()
self.base_url = self._config.API_URL self.base_url = self._config.API_URL
self.query = query self.query = query
@@ -109,11 +122,21 @@ class BrightDataSearchTool(BaseTool):
def get_search_url(self, engine: str, query: str): def get_search_url(self, engine: str, query: str):
if engine == "yandex": if engine == "yandex":
return f"https://yandex.com/search/?text=${query}" return f"https://yandex.com/search/?text=${query}"
elif engine == "bing": if engine == "bing":
return f"https://www.bing.com/search?q=${query}" return f"https://www.bing.com/search?q=${query}"
return f"https://www.google.com/search?q=${query}" return f"https://www.google.com/search?q=${query}"
def _run(self, query: str = None, search_engine: str = None, country: str = None, language: str = None, search_type: str = None, device_type: str = None, parse_results: bool = None, **kwargs) -> Any: def _run(
self,
query: str | None = None,
search_engine: str | None = None,
country: str | None = None,
language: str | None = None,
search_type: str | None = None,
device_type: str | None = None,
parse_results: bool | None = None,
**kwargs,
) -> Any:
""" """
Executes a search query using Bright Data SERP API and returns results. Executes a search query using Bright Data SERP API and returns results.
@@ -137,7 +160,9 @@ class BrightDataSearchTool(BaseTool):
language = language or self.language language = language or self.language
search_type = search_type or self.search_type search_type = search_type or self.search_type
device_type = device_type or self.device_type device_type = device_type or self.device_type
parse_results = parse_results if parse_results is not None else self.parse_results parse_results = (
parse_results if parse_results is not None else self.parse_results
)
results_count = kwargs.get("results_count", "10") results_count = kwargs.get("results_count", "10")
# Validate required parameters # Validate required parameters
@@ -161,7 +186,7 @@ class BrightDataSearchTool(BaseTool):
params.append(f"num={results_count}") params.append(f"num={results_count}")
if parse_results: if parse_results:
params.append(f"brd_json=1") params.append("brd_json=1")
if search_type: if search_type:
if search_type == "jobs": if search_type == "jobs":
@@ -202,6 +227,6 @@ class BrightDataSearchTool(BaseTool):
return response.text return response.text
except requests.RequestException as e: except requests.RequestException as e:
return f"Error performing BrightData search: {str(e)}" return f"Error performing BrightData search: {e!s}"
except Exception as e: except Exception as e:
return f"Error fetching results: {str(e)}" return f"Error fetching results: {e!s}"

View File

@@ -1,19 +1,23 @@
import os import os
from typing import Any, Optional, Type from typing import Any
import requests import requests
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class BrightDataConfig(BaseModel): class BrightDataConfig(BaseModel):
API_URL: str = "https://api.brightdata.com/request" API_URL: str = "https://api.brightdata.com/request"
@classmethod @classmethod
def from_env(cls): def from_env(cls):
return cls( return cls(
API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com/request") API_URL=os.environ.get(
"BRIGHTDATA_API_URL", "https://api.brightdata.com/request"
)
) )
class BrightDataUnlockerToolSchema(BaseModel): class BrightDataUnlockerToolSchema(BaseModel):
""" """
Pydantic schema for input parameters used by the BrightDataWebUnlockerTool. Pydantic schema for input parameters used by the BrightDataWebUnlockerTool.
@@ -28,10 +32,10 @@ class BrightDataUnlockerToolSchema(BaseModel):
""" """
url: str = Field(..., description="URL to perform the web scraping") url: str = Field(..., description="URL to perform the web scraping")
format: Optional[str] = Field( format: str | None = Field(
default="raw", description="Response format (raw is standard)" default="raw", description="Response format (raw is standard)"
) )
data_format: Optional[str] = Field( data_format: str | None = Field(
default="markdown", description="Response data format (html by default)" default="markdown", description="Response data format (html by default)"
) )
@@ -59,16 +63,18 @@ class BrightDataWebUnlockerTool(BaseTool):
name: str = "Bright Data Web Unlocker Scraping" name: str = "Bright Data Web Unlocker Scraping"
description: str = "Tool to perform web scraping using Bright Data Web Unlocker" description: str = "Tool to perform web scraping using Bright Data Web Unlocker"
args_schema: Type[BaseModel] = BrightDataUnlockerToolSchema args_schema: type[BaseModel] = BrightDataUnlockerToolSchema
_config = BrightDataConfig.from_env() _config = BrightDataConfig.from_env()
base_url: str = "" base_url: str = ""
api_key: str = "" api_key: str = ""
zone: str = "" zone: str = ""
url: Optional[str] = None url: str | None = None
format: str = "raw" format: str = "raw"
data_format: str = "markdown" data_format: str = "markdown"
def __init__(self, url: str = None, format: str = "raw", data_format: str = "markdown"): def __init__(
self, url: str | None = None, format: str = "raw", data_format: str = "markdown"
):
super().__init__() super().__init__()
self.base_url = self._config.API_URL self.base_url = self._config.API_URL
self.url = url self.url = url
@@ -82,7 +88,13 @@ class BrightDataWebUnlockerTool(BaseTool):
if not self.zone: if not self.zone:
raise ValueError("BRIGHT_DATA_ZONE environment variable is required.") raise ValueError("BRIGHT_DATA_ZONE environment variable is required.")
def _run(self, url: str = None, format: str = None, data_format: str = None, **kwargs: Any) -> Any: def _run(
self,
url: str | None = None,
format: str | None = None,
data_format: str | None = None,
**kwargs: Any,
) -> Any:
url = url or self.url url = url or self.url
format = format or self.format format = format or self.format
data_format = data_format or self.data_format data_format = data_format or self.data_format
@@ -119,4 +131,4 @@ class BrightDataWebUnlockerTool(BaseTool):
except requests.RequestException as e: except requests.RequestException as e:
return f"HTTP Error performing BrightData Web Unlocker Scrape: {e}\nResponse: {getattr(e.response, 'text', '')}" return f"HTTP Error performing BrightData Web Unlocker Scrape: {e}\nResponse: {getattr(e.response, 'text', '')}"
except Exception as e: except Exception as e:
return f"Error fetching results: {str(e)}" return f"Error fetching results: {e!s}"

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Optional, Type, List from typing import Any
from crewai.tools import BaseTool, EnvVar from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -12,26 +12,34 @@ class BrowserbaseLoadToolSchema(BaseModel):
class BrowserbaseLoadTool(BaseTool): class BrowserbaseLoadTool(BaseTool):
name: str = "Browserbase web load tool" name: str = "Browserbase web load tool"
description: str = "Load webpages url in a headless browser using Browserbase and return the contents" description: str = "Load webpages url in a headless browser using Browserbase and return the contents"
args_schema: Type[BaseModel] = BrowserbaseLoadToolSchema args_schema: type[BaseModel] = BrowserbaseLoadToolSchema
api_key: Optional[str] = os.getenv("BROWSERBASE_API_KEY") api_key: str | None = os.getenv("BROWSERBASE_API_KEY")
project_id: Optional[str] = os.getenv("BROWSERBASE_PROJECT_ID") project_id: str | None = os.getenv("BROWSERBASE_PROJECT_ID")
text_content: Optional[bool] = False text_content: bool | None = False
session_id: Optional[str] = None session_id: str | None = None
proxy: Optional[bool] = None proxy: bool | None = None
browserbase: Optional[Any] = None browserbase: Any | None = None
package_dependencies: List[str] = ["browserbase"] package_dependencies: list[str] = ["browserbase"]
env_vars: List[EnvVar] = [ env_vars: list[EnvVar] = [
EnvVar(name="BROWSERBASE_API_KEY", description="API key for Browserbase services", required=False), EnvVar(
EnvVar(name="BROWSERBASE_PROJECT_ID", description="Project ID for Browserbase services", required=False), name="BROWSERBASE_API_KEY",
description="API key for Browserbase services",
required=False,
),
EnvVar(
name="BROWSERBASE_PROJECT_ID",
description="Project ID for Browserbase services",
required=False,
),
] ]
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: str | None = None,
project_id: Optional[str] = None, project_id: str | None = None,
text_content: Optional[bool] = False, text_content: bool | None = False,
session_id: Optional[str] = None, session_id: str | None = None,
proxy: Optional[bool] = None, proxy: bool | None = None,
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
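A usage sketch, assuming the url argument from BrowserbaseLoadToolSchema and that the BROWSERBASE_* variables declared above are exported:

    # Sketch: credentials fall back to BROWSERBASE_API_KEY / BROWSERBASE_PROJECT_ID.
    from crewai_tools import BrowserbaseLoadTool

    tool = BrowserbaseLoadTool(text_content=True)  # assumed to return text, not full HTML
    contents = tool._run(url="https://example.com")  # placeholder URL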

View File

@@ -1,11 +1,4 @@
from typing import Any, Optional, Type from crewai_tools.rag.data_types import DataType
try:
from embedchain.models.data_type import DataType
EMBEDCHAIN_AVAILABLE = True
except ImportError:
EMBEDCHAIN_AVAILABLE = False
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
@@ -31,9 +24,9 @@ class CodeDocsSearchTool(RagTool):
description: str = ( description: str = (
"A tool that can be used to semantic search a query from a Code Docs content." "A tool that can be used to semantic search a query from a Code Docs content."
) )
args_schema: Type[BaseModel] = CodeDocsSearchToolSchema args_schema: type[BaseModel] = CodeDocsSearchToolSchema
def __init__(self, docs_url: Optional[str] = None, **kwargs): def __init__(self, docs_url: str | None = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if docs_url is not None: if docs_url is not None:
self.add(docs_url) self.add(docs_url)
@@ -42,15 +35,17 @@ class CodeDocsSearchTool(RagTool):
self._generate_description() self._generate_description()
def add(self, docs_url: str) -> None: def add(self, docs_url: str) -> None:
if not EMBEDCHAIN_AVAILABLE:
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
super().add(docs_url, data_type=DataType.DOCS_SITE) super().add(docs_url, data_type=DataType.DOCS_SITE)
def _run( def _run(
self, self,
search_query: str, search_query: str,
docs_url: Optional[str] = None, docs_url: str | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str: ) -> str:
if docs_url is not None: if docs_url is not None:
self.add(docs_url) self.add(docs_url)
return super()._run(query=search_query) return super()._run(
query=search_query, similarity_threshold=similarity_threshold, limit=limit
)
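Since CodeDocsSearchTool now threads similarity_threshold and limit through to the RAG backend, a usage sketch (the docs URL and threshold value are illustrative):

    # Sketch: the docs URL is ingested as DataType.DOCS_SITE at construction.
    from crewai_tools import CodeDocsSearchTool

    tool = CodeDocsSearchTool(docs_url="https://docs.crewai.com")
    answer = tool._run(
        search_query="how do I define a crew?",
        similarity_threshold=0.5,  # illustrative value
        limit=3,
    )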

View File

@@ -8,17 +8,16 @@ potentially unsafe operations and importing restricted modules.
import importlib.util import importlib.util
import os import os
from types import ModuleType from types import ModuleType
from typing import Any, Dict, List, Optional, Type from typing import Any
from crewai.tools import BaseTool from crewai.tools import BaseTool
from crewai_tools.printer import Printer
from docker import DockerClient from docker import DockerClient
from docker import from_env as docker_from_env from docker import from_env as docker_from_env
from docker.errors import ImageNotFound, NotFound from docker.errors import ImageNotFound, NotFound
from docker.models.containers import Container from docker.models.containers import Container
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai_tools.printer import Printer
class CodeInterpreterSchema(BaseModel): class CodeInterpreterSchema(BaseModel):
"""Schema for defining inputs to the CodeInterpreterTool. """Schema for defining inputs to the CodeInterpreterTool.
@@ -32,7 +31,7 @@ class CodeInterpreterSchema(BaseModel):
description="Python3 code used to be interpreted in the Docker container. ALWAYS PRINT the final result and the output of the code", description="Python3 code used to be interpreted in the Docker container. ALWAYS PRINT the final result and the output of the code",
) )
libraries_used: List[str] = Field( libraries_used: list[str] = Field(
..., ...,
description="List of libraries used in the code with proper installing names separated by commas. Example: numpy,pandas,beautifulsoup4", description="List of libraries used in the code with proper installing names separated by commas. Example: numpy,pandas,beautifulsoup4",
) )
@@ -74,9 +73,9 @@ class SandboxPython:
@staticmethod @staticmethod
def restricted_import( def restricted_import(
name: str, name: str,
custom_globals: Optional[Dict[str, Any]] = None, custom_globals: dict[str, Any] | None = None,
custom_locals: Optional[Dict[str, Any]] = None, custom_locals: dict[str, Any] | None = None,
fromlist: Optional[List[str]] = None, fromlist: list[str] | None = None,
level: int = 0, level: int = 0,
) -> ModuleType: ) -> ModuleType:
"""A restricted import function that blocks importing of unsafe modules. """A restricted import function that blocks importing of unsafe modules.
@@ -99,7 +98,7 @@ class SandboxPython:
return __import__(name, custom_globals, custom_locals, fromlist or (), level) return __import__(name, custom_globals, custom_locals, fromlist or (), level)
@staticmethod @staticmethod
def safe_builtins() -> Dict[str, Any]: def safe_builtins() -> dict[str, Any]:
"""Creates a dictionary of built-in functions with unsafe ones removed. """Creates a dictionary of built-in functions with unsafe ones removed.
Returns: Returns:
@@ -116,7 +115,7 @@ class SandboxPython:
return safe_builtins return safe_builtins
@staticmethod @staticmethod
def exec(code: str, locals: Dict[str, Any]) -> None: def exec(code: str, locals: dict[str, Any]) -> None:
"""Executes Python code in a restricted environment. """Executes Python code in a restricted environment.
Args: Args:
@@ -136,11 +135,11 @@ class CodeInterpreterTool(BaseTool):
name: str = "Code Interpreter" name: str = "Code Interpreter"
description: str = "Interprets Python3 code strings with a final print statement." description: str = "Interprets Python3 code strings with a final print statement."
args_schema: Type[BaseModel] = CodeInterpreterSchema args_schema: type[BaseModel] = CodeInterpreterSchema
default_image_tag: str = "code-interpreter:latest" default_image_tag: str = "code-interpreter:latest"
code: Optional[str] = None code: str | None = None
user_dockerfile_path: Optional[str] = None user_dockerfile_path: str | None = None
user_docker_base_url: Optional[str] = None user_docker_base_url: str | None = None
unsafe_mode: bool = False unsafe_mode: bool = False
@staticmethod @staticmethod
@@ -205,10 +204,9 @@ class CodeInterpreterTool(BaseTool):
if self.unsafe_mode: if self.unsafe_mode:
return self.run_code_unsafe(code, libraries_used) return self.run_code_unsafe(code, libraries_used)
else: return self.run_code_safety(code, libraries_used)
return self.run_code_safety(code, libraries_used)
def _install_libraries(self, container: Container, libraries: List[str]) -> None: def _install_libraries(self, container: Container, libraries: list[str]) -> None:
"""Installs required Python libraries in the Docker container. """Installs required Python libraries in the Docker container.
Args: Args:
@@ -278,7 +276,7 @@ class CodeInterpreterTool(BaseTool):
Printer.print("Docker is not installed", color="bold_purple") Printer.print("Docker is not installed", color="bold_purple")
return False return False
def run_code_safety(self, code: str, libraries_used: List[str]) -> str: def run_code_safety(self, code: str, libraries_used: list[str]) -> str:
"""Runs code in the safest available environment. """Runs code in the safest available environment.
Attempts to run code in Docker if available, falls back to a restricted Attempts to run code in Docker if available, falls back to a restricted
@@ -293,10 +291,9 @@ class CodeInterpreterTool(BaseTool):
""" """
if self._check_docker_available(): if self._check_docker_available():
return self.run_code_in_docker(code, libraries_used) return self.run_code_in_docker(code, libraries_used)
else: return self.run_code_in_restricted_sandbox(code)
return self.run_code_in_restricted_sandbox(code)
def run_code_in_docker(self, code: str, libraries_used: List[str]) -> str: def run_code_in_docker(self, code: str, libraries_used: list[str]) -> str:
"""Runs Python code in a Docker container for safe isolation. """Runs Python code in a Docker container for safe isolation.
Creates a Docker container, installs the required libraries, executes the code, Creates a Docker container, installs the required libraries, executes the code,
@@ -342,9 +339,9 @@ class CodeInterpreterTool(BaseTool):
SandboxPython.exec(code=code, locals=exec_locals) SandboxPython.exec(code=code, locals=exec_locals)
return exec_locals.get("result", "No result variable found.") return exec_locals.get("result", "No result variable found.")
except Exception as e: except Exception as e:
return f"An error occurred: {str(e)}" return f"An error occurred: {e!s}"
def run_code_unsafe(self, code: str, libraries_used: List[str]) -> str: def run_code_unsafe(self, code: str, libraries_used: list[str]) -> str:
"""Runs code directly on the host machine without any safety restrictions. """Runs code directly on the host machine without any safety restrictions.
WARNING: This mode is unsafe and should only be used in trusted environments WARNING: This mode is unsafe and should only be used in trusted environments
@@ -370,4 +367,4 @@ class CodeInterpreterTool(BaseTool):
exec(code, {}, exec_locals) exec(code, {}, exec_locals)
return exec_locals.get("result", "No result variable found.") return exec_locals.get("result", "No result variable found.")
except Exception as e: except Exception as e:
return f"An error occurred: {str(e)}" return f"An error occurred: {e!s}"

View File

@@ -12,8 +12,12 @@ class ComposioTool(BaseTool):
"""Wrapper for composio tools.""" """Wrapper for composio tools."""
composio_action: t.Callable composio_action: t.Callable
env_vars: t.List[EnvVar] = [ env_vars: list[EnvVar] = [
EnvVar(name="COMPOSIO_API_KEY", description="API key for Composio services", required=True), EnvVar(
name="COMPOSIO_API_KEY",
description="API key for Composio services",
required=True,
),
] ]
def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any: def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
@@ -31,7 +35,7 @@ class ComposioTool(BaseTool):
return return
connections = t.cast( connections = t.cast(
t.List[ConnectedAccountModel], list[ConnectedAccountModel],
toolset.client.connected_accounts.get(), toolset.client.connected_accounts.get(),
) )
if tool.app not in [connection.appUniqueId for connection in connections]: if tool.app not in [connection.appUniqueId for connection in connections]:
@@ -66,7 +70,7 @@ class ComposioTool(BaseTool):
schema = action_schema.model_dump(exclude_none=True) schema = action_schema.model_dump(exclude_none=True)
entity_id = kwargs.pop("entity_id", DEFAULT_ENTITY_ID) entity_id = kwargs.pop("entity_id", DEFAULT_ENTITY_ID)
def function(**kwargs: t.Any) -> t.Dict: def function(**kwargs: t.Any) -> dict:
"""Wrapper function for composio action.""" """Wrapper function for composio action."""
return toolset.execute_action( return toolset.execute_action(
action=Action(schema["name"]), action=Action(schema["name"]),
@@ -93,10 +97,10 @@ class ComposioTool(BaseTool):
def from_app( def from_app(
cls, cls,
*apps: t.Any, *apps: t.Any,
tags: t.Optional[t.List[str]] = None, tags: list[str] | None = None,
use_case: t.Optional[str] = None, use_case: str | None = None,
**kwargs: t.Any, **kwargs: t.Any,
) -> t.List[te.Self]: ) -> list[te.Self]:
"""Create toolset from an app.""" """Create toolset from an app."""
if len(apps) == 0: if len(apps) == 0:
raise ValueError("You need to provide at least one app name") raise ValueError("You need to provide at least one app name")

View File

@@ -1,32 +1,36 @@
from typing import Any, Optional, Type, List from typing import Any
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import os
class ContextualAICreateAgentSchema(BaseModel): class ContextualAICreateAgentSchema(BaseModel):
"""Schema for contextual create agent tool.""" """Schema for contextual create agent tool."""
agent_name: str = Field(..., description="Name for the new agent") agent_name: str = Field(..., description="Name for the new agent")
agent_description: str = Field(..., description="Description for the new agent") agent_description: str = Field(..., description="Description for the new agent")
datastore_name: str = Field(..., description="Name for the new datastore") datastore_name: str = Field(..., description="Name for the new datastore")
document_paths: List[str] = Field(..., description="List of file paths to upload") document_paths: list[str] = Field(..., description="List of file paths to upload")
class ContextualAICreateAgentTool(BaseTool): class ContextualAICreateAgentTool(BaseTool):
"""Tool to create Contextual AI RAG agents with documents.""" """Tool to create Contextual AI RAG agents with documents."""
name: str = "Contextual AI Create Agent Tool" name: str = "Contextual AI Create Agent Tool"
description: str = "Create a new Contextual AI RAG agent with documents and datastore" description: str = (
args_schema: Type[BaseModel] = ContextualAICreateAgentSchema "Create a new Contextual AI RAG agent with documents and datastore"
)
args_schema: type[BaseModel] = ContextualAICreateAgentSchema
api_key: str api_key: str
contextual_client: Any = None contextual_client: Any = None
package_dependencies: List[str] = ["contextual-client"] package_dependencies: list[str] = ["contextual-client"]
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
try: try:
from contextual import ContextualAI from contextual import ContextualAI
self.contextual_client = ContextualAI(api_key=self.api_key) self.contextual_client = ContextualAI(api_key=self.api_key)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
@@ -38,34 +42,38 @@ class ContextualAICreateAgentTool(BaseTool):
agent_name: str, agent_name: str,
agent_description: str, agent_description: str,
datastore_name: str, datastore_name: str,
document_paths: List[str] document_paths: list[str],
) -> str: ) -> str:
"""Create a complete RAG pipeline with documents.""" """Create a complete RAG pipeline with documents."""
try: try:
import os import os
# Create datastore # Create datastore
datastore = self.contextual_client.datastores.create(name=datastore_name) datastore = self.contextual_client.datastores.create(name=datastore_name)
datastore_id = datastore.id datastore_id = datastore.id
# Upload documents # Upload documents
document_ids = [] document_ids = []
for doc_path in document_paths: for doc_path in document_paths:
if not os.path.exists(doc_path): if not os.path.exists(doc_path):
raise FileNotFoundError(f"Document not found: {doc_path}") raise FileNotFoundError(f"Document not found: {doc_path}")
with open(doc_path, 'rb') as f: with open(doc_path, "rb") as f:
ingestion_result = self.contextual_client.datastores.documents.ingest(datastore_id, file=f) ingestion_result = (
self.contextual_client.datastores.documents.ingest(
datastore_id, file=f
)
)
document_ids.append(ingestion_result.id) document_ids.append(ingestion_result.id)
# Create agent # Create agent
agent = self.contextual_client.agents.create( agent = self.contextual_client.agents.create(
name=agent_name, name=agent_name,
description=agent_description, description=agent_description,
datastore_ids=[datastore_id] datastore_ids=[datastore_id],
) )
return f"Successfully created agent '{agent_name}' with ID: {agent.id} and datastore ID: {datastore_id}. Uploaded {len(document_ids)} documents." return f"Successfully created agent '{agent_name}' with ID: {agent.id} and datastore ID: {datastore_id}. Uploaded {len(document_ids)} documents."
except Exception as e: except Exception as e:
return f"Failed to create agent with documents: {str(e)}" return f"Failed to create agent with documents: {e!s}"

View File

@@ -1,51 +1,62 @@
from typing import Any, Optional, Type, List
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class ContextualAIParseSchema(BaseModel): class ContextualAIParseSchema(BaseModel):
"""Schema for contextual parse tool.""" """Schema for contextual parse tool."""
file_path: str = Field(..., description="Path to the document to parse") file_path: str = Field(..., description="Path to the document to parse")
parse_mode: str = Field(default="standard", description="Parsing mode") parse_mode: str = Field(default="standard", description="Parsing mode")
figure_caption_mode: str = Field(default="concise", description="Figure caption mode") figure_caption_mode: str = Field(
enable_document_hierarchy: bool = Field(default=True, description="Enable document hierarchy") default="concise", description="Figure caption mode"
page_range: Optional[str] = Field(default=None, description="Page range to parse (e.g., '0-5')") )
output_types: List[str] = Field(default=["markdown-per-page"], description="List of output types") enable_document_hierarchy: bool = Field(
default=True, description="Enable document hierarchy"
)
page_range: str | None = Field(
default=None, description="Page range to parse (e.g., '0-5')"
)
output_types: list[str] = Field(
default=["markdown-per-page"], description="List of output types"
)
class ContextualAIParseTool(BaseTool): class ContextualAIParseTool(BaseTool):
"""Tool to parse documents using Contextual AI's parser.""" """Tool to parse documents using Contextual AI's parser."""
name: str = "Contextual AI Document Parser" name: str = "Contextual AI Document Parser"
description: str = "Parse documents using Contextual AI's advanced document parser" description: str = "Parse documents using Contextual AI's advanced document parser"
args_schema: Type[BaseModel] = ContextualAIParseSchema args_schema: type[BaseModel] = ContextualAIParseSchema
api_key: str api_key: str
package_dependencies: List[str] = ["contextual-client"] package_dependencies: list[str] = ["contextual-client"]
def _run( def _run(
self, self,
file_path: str, file_path: str,
parse_mode: str = "standard", parse_mode: str = "standard",
figure_caption_mode: str = "concise", figure_caption_mode: str = "concise",
enable_document_hierarchy: bool = True, enable_document_hierarchy: bool = True,
page_range: Optional[str] = None, page_range: str | None = None,
output_types: List[str] = ["markdown-per-page"] output_types: list[str] | None = None,
) -> str: ) -> str:
"""Parse a document using Contextual AI's parser.""" """Parse a document using Contextual AI's parser."""
if output_types is None:
output_types = ["markdown-per-page"]
try: try:
import requests
import json import json
import os import os
from time import sleep from time import sleep
import requests
if not os.path.exists(file_path): if not os.path.exists(file_path):
raise FileNotFoundError(f"Document not found: {file_path}") raise FileNotFoundError(f"Document not found: {file_path}")
base_url = "https://api.contextual.ai/v1" base_url = "https://api.contextual.ai/v1"
headers = { headers = {
"accept": "application/json", "accept": "application/json",
"authorization": f"Bearer {self.api_key}" "authorization": f"Bearer {self.api_key}",
} }
# Submit parse job # Submit parse job
@@ -63,17 +74,17 @@ class ContextualAIParseTool(BaseTool):
file = {"raw_file": fp} file = {"raw_file": fp}
result = requests.post(url, headers=headers, data=config, files=file) result = requests.post(url, headers=headers, data=config, files=file)
response = json.loads(result.text) response = json.loads(result.text)
job_id = response['job_id'] job_id = response["job_id"]
# Monitor job status # Monitor job status
status_url = f"{base_url}/parse/jobs/{job_id}/status" status_url = f"{base_url}/parse/jobs/{job_id}/status"
while True: while True:
result = requests.get(status_url, headers=headers) result = requests.get(status_url, headers=headers)
parse_response = json.loads(result.text)['status'] parse_response = json.loads(result.text)["status"]
if parse_response == "completed": if parse_response == "completed":
break break
elif parse_response == "failed": if parse_response == "failed":
raise RuntimeError("Document parsing failed") raise RuntimeError("Document parsing failed")
sleep(5) sleep(5)
@@ -89,4 +100,4 @@ class ContextualAIParseTool(BaseTool):
return json.dumps(json.loads(result.text), indent=2) return json.dumps(json.loads(result.text), indent=2)
except Exception as e: except Exception as e:
return f"Failed to parse document: {str(e)}" return f"Failed to parse document: {e!s}"

View File

@@ -1,33 +1,39 @@
from typing import Any, Optional, Type, List import asyncio
from typing import Any
import requests
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import asyncio
import requests
import os
class ContextualAIQuerySchema(BaseModel): class ContextualAIQuerySchema(BaseModel):
"""Schema for contextual query tool.""" """Schema for contextual query tool."""
query: str = Field(..., description="Query to send to the Contextual AI agent.") query: str = Field(..., description="Query to send to the Contextual AI agent.")
agent_id: str = Field(..., description="ID of the Contextual AI agent to query") agent_id: str = Field(..., description="ID of the Contextual AI agent to query")
datastore_id: Optional[str] = Field(None, description="Optional datastore ID for document readiness verification") datastore_id: str | None = Field(
None, description="Optional datastore ID for document readiness verification"
)
class ContextualAIQueryTool(BaseTool): class ContextualAIQueryTool(BaseTool):
"""Tool to query Contextual AI RAG agents.""" """Tool to query Contextual AI RAG agents."""
name: str = "Contextual AI Query Tool" name: str = "Contextual AI Query Tool"
description: str = "Use this tool to query a Contextual AI RAG agent with access to your documents" description: str = (
args_schema: Type[BaseModel] = ContextualAIQuerySchema "Use this tool to query a Contextual AI RAG agent with access to your documents"
)
args_schema: type[BaseModel] = ContextualAIQuerySchema
api_key: str api_key: str
contextual_client: Any = None contextual_client: Any = None
package_dependencies: List[str] = ["contextual-client"] package_dependencies: list[str] = ["contextual-client"]
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
try: try:
from contextual import ContextualAI from contextual import ContextualAI
self.contextual_client = ContextualAI(api_key=self.api_key) self.contextual_client = ContextualAI(api_key=self.api_key)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
@@ -41,13 +47,17 @@ class ContextualAIQueryTool(BaseTool):
response = requests.get(url, headers=headers) response = requests.get(url, headers=headers)
if response.status_code == 200: if response.status_code == 200:
data = response.json() data = response.json()
documents = data.get('documents', []) documents = data.get("documents", [])
return not any(doc.get('status') in ('processing', 'pending') for doc in documents) return not any(
doc.get("status") in ("processing", "pending") for doc in documents
)
return True return True
async def _wait_for_documents_async(self, datastore_id: str, max_attempts: int = 20, interval: float = 30.0) -> bool: async def _wait_for_documents_async(
self, datastore_id: str, max_attempts: int = 20, interval: float = 30.0
) -> bool:
"""Asynchronously poll until documents are ready, exiting early if possible.""" """Asynchronously poll until documents are ready, exiting early if possible."""
for attempt in range(max_attempts): for _attempt in range(max_attempts):
ready = await asyncio.to_thread(self._check_documents_ready, datastore_id) ready = await asyncio.to_thread(self._check_documents_ready, datastore_id)
if ready: if ready:
return True return True
@@ -55,10 +65,10 @@ class ContextualAIQueryTool(BaseTool):
print("Processing documents ...") print("Processing documents ...")
return True # give up but don't fail hard return True # give up but don't fail hard
def _run(self, query: str, agent_id: str, datastore_id: Optional[str] = None) -> str: def _run(self, query: str, agent_id: str, datastore_id: str | None = None) -> str:
if not agent_id: if not agent_id:
raise ValueError("Agent ID is required to query the Contextual AI agent") raise ValueError("Agent ID is required to query the Contextual AI agent")
if datastore_id: if datastore_id:
ready = self._check_documents_ready(datastore_id) ready = self._check_documents_ready(datastore_id)
if not ready: if not ready:
@@ -69,31 +79,42 @@ class ContextualAIQueryTool(BaseTool):
loop = None loop = None
if loop and loop.is_running(): if loop and loop.is_running():
# Already inside an event loop # Already inside an event loop
try: try:
import nest_asyncio import nest_asyncio
nest_asyncio.apply(loop) nest_asyncio.apply(loop)
loop.run_until_complete(self._wait_for_documents_async(datastore_id)) loop.run_until_complete(
self._wait_for_documents_async(datastore_id)
)
except Exception as e: except Exception as e:
print(f"Failed to apply nest_asyncio: {str(e)}") print(f"Failed to apply nest_asyncio: {e!s}")
else: else:
asyncio.run(self._wait_for_documents_async(datastore_id)) asyncio.run(self._wait_for_documents_async(datastore_id))
else: else:
print("Warning: No datastore_id provided. Document status checking disabled.") print(
"Warning: No datastore_id provided. Document status checking disabled."
)
try: try:
response = self.contextual_client.agents.query.create( response = self.contextual_client.agents.query.create(
agent_id=agent_id, agent_id=agent_id, messages=[{"role": "user", "content": query}]
messages=[{"role": "user", "content": query}]
) )
if hasattr(response, 'content'): if hasattr(response, "content"):
return response.content return response.content
elif hasattr(response, 'message'): if hasattr(response, "message"):
return response.message.content if hasattr(response.message, 'content') else str(response.message) return (
elif hasattr(response, 'messages') and len(response.messages) > 0: response.message.content
if hasattr(response.message, "content")
else str(response.message)
)
if hasattr(response, "messages") and len(response.messages) > 0:
last_message = response.messages[-1] last_message = response.messages[-1]
return last_message.content if hasattr(last_message, 'content') else str(last_message) return (
else: last_message.content
return str(response) if hasattr(last_message, "content")
else str(last_message)
)
return str(response)
except Exception as e: except Exception as e:
return f"Error querying Contextual AI agent: {str(e)}" return f"Error querying Contextual AI agent: {e!s}"

View File

@@ -1,68 +1,79 @@
from typing import Any, Optional, Type, List from typing import ClassVar
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class ContextualAIRerankSchema(BaseModel): class ContextualAIRerankSchema(BaseModel):
"""Schema for contextual rerank tool.""" """Schema for contextual rerank tool."""
query: str = Field(..., description="The search query to rerank documents against") query: str = Field(..., description="The search query to rerank documents against")
documents: List[str] = Field(..., description="List of document texts to rerank") documents: list[str] = Field(..., description="List of document texts to rerank")
instruction: Optional[str] = Field(default=None, description="Optional instruction for reranking behavior") instruction: str | None = Field(
metadata: Optional[List[str]] = Field(default=None, description="Optional metadata for each document") default=None, description="Optional instruction for reranking behavior"
model: str = Field(default="ctxl-rerank-en-v1-instruct", description="Reranker model to use") )
metadata: list[str] | None = Field(
default=None, description="Optional metadata for each document"
)
model: str = Field(
default="ctxl-rerank-en-v1-instruct", description="Reranker model to use"
)
class ContextualAIRerankTool(BaseTool): class ContextualAIRerankTool(BaseTool):
"""Tool to rerank documents using Contextual AI's instruction-following reranker.""" """Tool to rerank documents using Contextual AI's instruction-following reranker."""
name: str = "Contextual AI Document Reranker" name: str = "Contextual AI Document Reranker"
description: str = "Rerank documents using Contextual AI's instruction-following reranker" description: str = (
args_schema: Type[BaseModel] = ContextualAIRerankSchema "Rerank documents using Contextual AI's instruction-following reranker"
)
args_schema: type[BaseModel] = ContextualAIRerankSchema
api_key: str api_key: str
package_dependencies: List[str] = ["contextual-client"] package_dependencies: ClassVar[list[str]] = ["contextual-client"]
def _run( def _run(
self, self,
query: str, query: str,
documents: List[str], documents: list[str],
instruction: Optional[str] = None, instruction: str | None = None,
metadata: Optional[List[str]] = None, metadata: list[str] | None = None,
model: str = "ctxl-rerank-en-v1-instruct" model: str = "ctxl-rerank-en-v1-instruct",
) -> str: ) -> str:
"""Rerank documents using Contextual AI's instruction-following reranker.""" """Rerank documents using Contextual AI's instruction-following reranker."""
try: try:
import requests
import json import json
import requests
base_url = "https://api.contextual.ai/v1" base_url = "https://api.contextual.ai/v1"
headers = { headers = {
"accept": "application/json", "accept": "application/json",
"content-type": "application/json", "content-type": "application/json",
"authorization": f"Bearer {self.api_key}" "authorization": f"Bearer {self.api_key}",
} }
payload = { payload = {"query": query, "documents": documents, "model": model}
"query": query,
"documents": documents,
"model": model
}
if instruction: if instruction:
payload["instruction"] = instruction payload["instruction"] = instruction
if metadata: if metadata:
if len(metadata) != len(documents): if len(metadata) != len(documents):
raise ValueError("Metadata list must have the same length as documents list") raise ValueError(
"Metadata list must have the same length as documents list"
)
payload["metadata"] = metadata payload["metadata"] = metadata
rerank_url = f"{base_url}/rerank" rerank_url = f"{base_url}/rerank"
result = requests.post(rerank_url, json=payload, headers=headers) result = requests.post(rerank_url, json=payload, headers=headers, timeout=30)
if result.status_code != 200: if result.status_code != 200:
raise RuntimeError(f"Reranker API returned status {result.status_code}: {result.text}") raise RuntimeError(
f"Reranker API returned status {result.status_code}: {result.text}"
)
return json.dumps(result.json(), indent=2) return json.dumps(result.json(), indent=2)
except Exception as e: except Exception as e:
return f"Failed to rerank documents: {str(e)}" return f"Failed to rerank documents: {e!s}"

View File

@@ -1,6 +1,6 @@
import json import json
import os from collections.abc import Callable
from typing import Any, Optional, Type, List, Dict, Callable from typing import Any
try: try:
import couchbase.search as search import couchbase.search as search
@@ -29,30 +29,33 @@ class CouchbaseToolSchema(BaseModel):
description="The query to search retrieve relevant information from the Couchbase database. Pass only the query, not the question.", description="The query to search retrieve relevant information from the Couchbase database. Pass only the query, not the question.",
) )
class CouchbaseFTSVectorSearchTool(BaseTool): class CouchbaseFTSVectorSearchTool(BaseTool):
"""Tool to search the Couchbase database""" """Tool to search the Couchbase database"""
model_config = {"arbitrary_types_allowed": True} model_config = {"arbitrary_types_allowed": True}
name: str = "CouchbaseFTSVectorSearchTool" name: str = "CouchbaseFTSVectorSearchTool"
description: str = "A tool to search the Couchbase database for relevant information on internal documents." description: str = "A tool to search the Couchbase database for relevant information on internal documents."
args_schema: Type[BaseModel] = CouchbaseToolSchema args_schema: type[BaseModel] = CouchbaseToolSchema
cluster: SkipValidation[Optional[Cluster]] = None cluster: SkipValidation[Cluster | None] = None
collection_name: Optional[str] = None, collection_name: str | None = (None,)
scope_name: Optional[str] = None, scope_name: str | None = (None,)
bucket_name: Optional[str] = None, bucket_name: str | None = (None,)
index_name: Optional[str] = None, index_name: str | None = (None,)
embedding_key: Optional[str] = Field( embedding_key: str | None = Field(
default="embedding", default="embedding",
description="Name of the field in the search index that stores the vector" description="Name of the field in the search index that stores the vector",
) )
scoped_index: Optional[bool] = Field( scoped_index: bool | None = (
default=True, Field(
description="Specify whether the index is scoped. Is True by default." default=True,
), description="Specify whether the index is scoped. Is True by default.",
limit: Optional[int] = Field(default=3) ),
embedding_function: SkipValidation[Callable[[str], List[float]]] = Field( )
limit: int | None = Field(default=3)
embedding_function: SkipValidation[Callable[[str], list[float]]] = Field(
default=None, default=None,
description="A function that takes a string and returns a list of floats. This is used to embed the query before searching the database." description="A function that takes a string and returns a list of floats. This is used to embed the query before searching the database.",
) )
def _check_bucket_exists(self) -> bool: def _check_bucket_exists(self) -> bool:
@@ -67,7 +70,7 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
def _check_scope_and_collection_exists(self) -> bool: def _check_scope_and_collection_exists(self) -> bool:
"""Check if the scope and collection exists in the linked Couchbase bucket """Check if the scope and collection exists in the linked Couchbase bucket
Raises a ValueError if either is not found""" Raises a ValueError if either is not found"""
scope_collection_map: Dict[str, Any] = {} scope_collection_map: dict[str, Any] = {}
# Get a list of all scopes in the bucket # Get a list of all scopes in the bucket
for scope in self._bucket.collections().get_all_scopes(): for scope in self._bucket.collections().get_all_scopes():
@@ -203,11 +206,7 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
search_req = search.SearchRequest.create( search_req = search.SearchRequest.create(
VectorSearch.from_vector_query( VectorSearch.from_vector_query(
VectorQuery( VectorQuery(self.embedding_key, query_embedding, self.limit)
self.embedding_key,
query_embedding,
self.limit
)
) )
) )
@@ -219,16 +218,13 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
SearchOptions( SearchOptions(
limit=self.limit, limit=self.limit,
fields=fields, fields=fields,
) ),
) )
else: else:
search_iter = self.cluster.search( search_iter = self.cluster.search(
self.index_name, self.index_name,
search_req, search_req,
SearchOptions( SearchOptions(limit=self.limit, fields=fields),
limit=self.limit,
fields=fields
)
) )
json_response = [] json_response = []
@@ -238,4 +234,4 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
except Exception as e: except Exception as e:
return f"Search failed with error: {e}" return f"Search failed with error: {e}"
return json.dumps(json_response, indent=2) return json.dumps(json_response, indent=2)
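A wiring sketch for the Couchbase tool, assuming the Couchbase Python SDK 4.x connection API and a caller-supplied embedding function; the connection string, credentials, and index names are all placeholders:

    # Sketch: connection details and the embedding function are placeholders.
    from couchbase.auth import PasswordAuthenticator
    from couchbase.cluster import Cluster
    from couchbase.options import ClusterOptions

    def embed(text: str) -> list[float]:
        raise NotImplementedError("plug in a real embedding model here")

    cluster = Cluster(
        "couchbase://localhost",
        ClusterOptions(PasswordAuthenticator("user", "password")),
    )
    tool = CouchbaseFTSVectorSearchTool(
        cluster=cluster,
        bucket_name="docs",
        scope_name="_default",
        collection_name="_default",
        index_name="vector-index",
        embedding_key="embedding",
        embedding_function=embed,  # Callable[[str], list[float]]
    )
    results = tool._run(query="quarterly revenue figures")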

View File

@@ -2,10 +2,10 @@
Crewai Enterprise Tools Crewai Enterprise Tools
""" """
import os
import typing as t
import logging
import json import json
import logging
import os
from crewai.tools import BaseTool from crewai.tools import BaseTool
from crewai_tools.adapters.enterprise_adapter import EnterpriseActionKitToolAdapter from crewai_tools.adapters.enterprise_adapter import EnterpriseActionKitToolAdapter
from crewai_tools.adapters.tool_collection import ToolCollection from crewai_tools.adapters.tool_collection import ToolCollection
@@ -13,11 +13,11 @@ from crewai_tools.adapters.tool_collection import ToolCollection
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def CrewaiEnterpriseTools( def CrewaiEnterpriseTools( # noqa: N802
enterprise_token: t.Optional[str] = None, enterprise_token: str | None = None,
actions_list: t.Optional[t.List[str]] = None, actions_list: list[str] | None = None,
enterprise_action_kit_project_id: t.Optional[str] = None, enterprise_action_kit_project_id: str | None = None,
enterprise_action_kit_project_url: t.Optional[str] = None, enterprise_action_kit_project_url: str | None = None,
) -> ToolCollection[BaseTool]: ) -> ToolCollection[BaseTool]:
"""Factory function that returns crewai enterprise tools. """Factory function that returns crewai enterprise tools.
@@ -34,10 +34,11 @@ def CrewaiEnterpriseTools(
""" """
import warnings import warnings
warnings.warn( warnings.warn(
"CrewaiEnterpriseTools will be removed in v1.0.0. Considering use `Agent(apps=[...])` instead.", "CrewaiEnterpriseTools will be removed in v1.0.0. Considering use `Agent(apps=[...])` instead.",
DeprecationWarning, DeprecationWarning,
stacklevel=2 stacklevel=2,
) )
if enterprise_token is None or enterprise_token == "": if enterprise_token is None or enterprise_token == "":
@@ -65,7 +66,7 @@ def CrewaiEnterpriseTools(
# ENTERPRISE INJECTION ONLY # ENTERPRISE INJECTION ONLY
def _parse_actions_list(actions_list: t.Optional[t.List[str]]) -> t.List[str] | None: def _parse_actions_list(actions_list: list[str] | None) -> list[str] | None:
"""Parse a string representation of a list of tool names to a list of tool names. """Parse a string representation of a list of tool names to a list of tool names.
Args: Args:
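For completeness, a sketch of the (deprecated) factory call; the token and action name are placeholders, and new code should prefer the Agent(apps=[...]) route named in the warning above:

    # Sketch: deprecated entry point; token and action names are placeholders.
    from crewai_tools import CrewaiEnterpriseTools

    tools = CrewaiEnterpriseTools(
        enterprise_token="<ENTERPRISE_TOKEN>",
        actions_list=["send_email"],  # illustrative action name
    )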

View File

@@ -4,13 +4,18 @@ This module provides tools for integrating with various platform applications
through the CrewAI platform API. through the CrewAI platform API.
""" """
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tools import CrewaiPlatformTools from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import (
from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import CrewAIPlatformActionTool CrewAIPlatformActionTool,
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder import CrewaiPlatformToolBuilder )
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder import (
CrewaiPlatformToolBuilder,
)
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tools import (
CrewaiPlatformTools,
)
__all__ = [ __all__ = [
"CrewaiPlatformTools",
"CrewAIPlatformActionTool", "CrewAIPlatformActionTool",
"CrewaiPlatformToolBuilder", "CrewaiPlatformToolBuilder",
"CrewaiPlatformTools",
] ]

View File

@@ -1,18 +1,24 @@
""" """
Crewai Enterprise Tools Crewai Enterprise Tools
""" """
import re
import json import json
import re
from typing import Any, Literal, Optional, Union, cast, get_origin
import requests import requests
from typing import Dict, Any, List, Type, Optional, Union, get_origin, cast, Literal
from pydantic import Field, create_model
from crewai.tools import BaseTool from crewai.tools import BaseTool
from crewai_tools.tools.crewai_platform_tools.misc import get_platform_api_base_url, get_platform_integration_token from pydantic import Field, create_model
from crewai_tools.tools.crewai_platform_tools.misc import (
get_platform_api_base_url,
get_platform_integration_token,
)
class CrewAIPlatformActionTool(BaseTool): class CrewAIPlatformActionTool(BaseTool):
action_name: str = Field(default="", description="The name of the action") action_name: str = Field(default="", description="The name of the action")
action_schema: Dict[str, Any] = Field( action_schema: dict[str, Any] = Field(
default_factory=dict, description="The schema of the action" default_factory=dict, description="The schema of the action"
) )
@@ -20,7 +26,7 @@ class CrewAIPlatformActionTool(BaseTool):
self, self,
description: str, description: str,
action_name: str, action_name: str,
action_schema: Dict[str, Any], action_schema: dict[str, Any],
): ):
self._model_registry = {} self._model_registry = {}
self._base_name = self._sanitize_name(action_name) self._base_name = self._sanitize_name(action_name)
@@ -36,7 +42,7 @@ class CrewAIPlatformActionTool(BaseTool):
field_type = self._process_schema_type( field_type = self._process_schema_type(
param_details, self._sanitize_name(param_name).title() param_details, self._sanitize_name(param_name).title()
) )
except Exception as e: except Exception:
field_type = str field_type = str
field_definitions[param_name] = self._create_field_definition( field_definitions[param_name] = self._create_field_definition(
@@ -60,7 +66,11 @@ class CrewAIPlatformActionTool(BaseTool):
input_text=(str, Field(description="Input for the action")), input_text=(str, Field(description="Input for the action")),
) )
super().__init__(name=action_name.lower().replace(" ", "_"), description=description, args_schema=args_schema) super().__init__(
name=action_name.lower().replace(" ", "_"),
description=description,
args_schema=args_schema,
)
self.action_name = action_name self.action_name = action_name
self.action_schema = action_schema self.action_schema = action_schema
@@ -71,8 +81,8 @@ class CrewAIPlatformActionTool(BaseTool):
return "".join(word.capitalize() for word in parts if word) return "".join(word.capitalize() for word in parts if word)
def _extract_schema_info( def _extract_schema_info(
self, action_schema: Dict[str, Any] self, action_schema: dict[str, Any]
) -> tuple[Dict[str, Any], List[str]]: ) -> tuple[dict[str, Any], list[str]]:
schema_props = ( schema_props = (
action_schema.get("function", {}) action_schema.get("function", {})
.get("parameters", {}) .get("parameters", {})
@@ -83,7 +93,7 @@ class CrewAIPlatformActionTool(BaseTool):
) )
return schema_props, required return schema_props, required
def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> Type[Any]: def _process_schema_type(self, schema: dict[str, Any], type_name: str) -> type[Any]:
if "anyOf" in schema: if "anyOf" in schema:
any_of_types = schema["anyOf"] any_of_types = schema["anyOf"]
is_nullable = any(t.get("type") == "null" for t in any_of_types) is_nullable = any(t.get("type") == "null" for t in any_of_types)
@@ -92,7 +102,7 @@ class CrewAIPlatformActionTool(BaseTool):
if non_null_types: if non_null_types:
base_type = self._process_schema_type(non_null_types[0], type_name) base_type = self._process_schema_type(non_null_types[0], type_name)
return Optional[base_type] if is_nullable else base_type return Optional[base_type] if is_nullable else base_type
return cast(Type[Any], Optional[str]) return cast(type[Any], Optional[str])
if "oneOf" in schema: if "oneOf" in schema:
return self._process_schema_type(schema["oneOf"][0], type_name) return self._process_schema_type(schema["oneOf"][0], type_name)
@@ -111,14 +121,16 @@ class CrewAIPlatformActionTool(BaseTool):
if json_type == "array": if json_type == "array":
items_schema = schema.get("items", {"type": "string"}) items_schema = schema.get("items", {"type": "string"})
item_type = self._process_schema_type(items_schema, f"{type_name}Item") item_type = self._process_schema_type(items_schema, f"{type_name}Item")
return List[item_type] return list[item_type]
if json_type == "object": if json_type == "object":
return self._create_nested_model(schema, type_name) return self._create_nested_model(schema, type_name)
return self._map_json_type_to_python(json_type) return self._map_json_type_to_python(json_type)
def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> Type[Any]: def _create_nested_model(
self, schema: dict[str, Any], model_name: str
) -> type[Any]:
full_model_name = f"{self._base_name}{model_name}" full_model_name = f"{self._base_name}{model_name}"
if full_model_name in self._model_registry: if full_model_name in self._model_registry:
@@ -139,7 +151,7 @@ class CrewAIPlatformActionTool(BaseTool):
prop_type = self._process_schema_type( prop_type = self._process_schema_type(
prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}" prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}"
) )
except Exception as e: except Exception:
prop_type = str prop_type = str
field_definitions[prop_name] = self._create_field_definition( field_definitions[prop_name] = self._create_field_definition(
@@ -155,20 +167,18 @@ class CrewAIPlatformActionTool(BaseTool):
return dict return dict
def _create_field_definition( def _create_field_definition(
self, field_type: Type[Any], is_required: bool, description: str self, field_type: type[Any], is_required: bool, description: str
) -> tuple: ) -> tuple:
if is_required: if is_required:
return (field_type, Field(description=description)) return (field_type, Field(description=description))
else: if get_origin(field_type) is Union:
if get_origin(field_type) is Union: return (field_type, Field(default=None, description=description))
return (field_type, Field(default=None, description=description)) return (
else: Optional[field_type],
return ( Field(default=None, description=description),
Optional[field_type], )
Field(default=None, description=description),
)
def _map_json_type_to_python(self, json_type: str) -> Type[Any]: def _map_json_type_to_python(self, json_type: str) -> type[Any]:
type_mapping = { type_mapping = {
"string": str, "string": str,
"integer": int, "integer": int,
@@ -180,7 +190,7 @@ class CrewAIPlatformActionTool(BaseTool):
} }
return type_mapping.get(json_type, str) return type_mapping.get(json_type, str)
def _get_required_nullable_fields(self) -> List[str]: def _get_required_nullable_fields(self) -> list[str]:
schema_props, required = self._extract_schema_info(self.action_schema) schema_props, required = self._extract_schema_info(self.action_schema)
required_nullable_fields = [] required_nullable_fields = []
@@ -191,7 +201,7 @@ class CrewAIPlatformActionTool(BaseTool):
return required_nullable_fields return required_nullable_fields
def _is_nullable_type(self, schema: Dict[str, Any]) -> bool: def _is_nullable_type(self, schema: dict[str, Any]) -> bool:
if "anyOf" in schema: if "anyOf" in schema:
return any(t.get("type") == "null" for t in schema["anyOf"]) return any(t.get("type") == "null" for t in schema["anyOf"])
return schema.get("type") == "null" return schema.get("type") == "null"
@@ -209,8 +219,9 @@ class CrewAIPlatformActionTool(BaseTool):
if field_name not in cleaned_kwargs: if field_name not in cleaned_kwargs:
cleaned_kwargs[field_name] = None cleaned_kwargs[field_name] = None
api_url = (
api_url = f"{get_platform_api_base_url()}/actions/{self.action_name}/execute" f"{get_platform_api_base_url()}/actions/{self.action_name}/execute"
)
token = get_platform_integration_token() token = get_platform_integration_token()
headers = { headers = {
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
@@ -230,4 +241,4 @@ class CrewAIPlatformActionTool(BaseTool):
return json.dumps(data, indent=2) return json.dumps(data, indent=2)
except Exception as e: except Exception as e:
return f"Error executing action {self.action_name}: {str(e)}" return f"Error executing action {self.action_name}: {e!s}"

View File

@@ -1,9 +1,15 @@
from typing import Any
import requests import requests
from typing import List, Any, Dict
from crewai.tools import BaseTool from crewai.tools import BaseTool
from crewai_tools.tools.crewai_platform_tools.misc import get_platform_api_base_url, get_platform_integration_token
from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import CrewAIPlatformActionTool from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import (
CrewAIPlatformActionTool,
)
from crewai_tools.tools.crewai_platform_tools.misc import (
get_platform_api_base_url,
get_platform_integration_token,
)
class CrewaiPlatformToolBuilder: class CrewaiPlatformToolBuilder:
@@ -27,13 +33,15 @@ class CrewaiPlatformToolBuilder:
try: try:
response = requests.get( response = requests.get(
actions_url, headers=headers, timeout=30, params={"apps": ",".join(self._apps)} actions_url,
headers=headers,
timeout=30,
params={"apps": ",".join(self._apps)},
) )
response.raise_for_status() response.raise_for_status()
except Exception as e: except Exception:
return return
raw_data = response.json() raw_data = response.json()
self._actions_schema = {} self._actions_schema = {}
@@ -46,7 +54,9 @@ class CrewaiPlatformToolBuilder:
action_schema = { action_schema = {
"function": { "function": {
"name": action_name, "name": action_name,
"description": action.get("description", f"Execute {action_name}"), "description": action.get(
"description", f"Execute {action_name}"
),
"parameters": action.get("parameters", {}), "parameters": action.get("parameters", {}),
"app": app, "app": app,
} }
@@ -54,8 +64,8 @@ class CrewaiPlatformToolBuilder:
self._actions_schema[action_name] = action_schema self._actions_schema[action_name] = action_schema
def _generate_detailed_description( def _generate_detailed_description(
self, schema: Dict[str, Any], indent: int = 0 self, schema: dict[str, Any], indent: int = 0
) -> List[str]: ) -> list[str]:
descriptions = [] descriptions = []
indent_str = " " * indent indent_str = " " * indent
@@ -127,7 +137,6 @@ class CrewaiPlatformToolBuilder:
self._tools = tools self._tools = tools
def __enter__(self): def __enter__(self):
return self.tools() return self.tools()
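
For reference, the fetch above follows the commit-wide rule of passing an explicit timeout to requests calls. A minimal sketch of the same call shape, with the endpoint path and headers assumed from the helpers in misc.py:

import requests

def fetch_action_schemas(base_url: str, token: str, apps: list[str]) -> dict:
    # The 30s timeout matches the default added across the codebase in this commit.
    response = requests.get(
        f"{base_url}/actions",
        headers={"Authorization": f"Bearer {token}"},
        params={"apps": ",".join(apps)},
        timeout=30,
    )
    response.raise_for_status()
    return response.json()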

View File

@@ -1,18 +1,16 @@
import re
import os
import typing as t
from typing import Literal
import logging import logging
import json
from crewai.tools import BaseTool from crewai.tools import BaseTool
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder import CrewaiPlatformToolBuilder
from crewai_tools.adapters.tool_collection import ToolCollection from crewai_tools.adapters.tool_collection import ToolCollection
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder import (
CrewaiPlatformToolBuilder,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def CrewaiPlatformTools( # noqa: N802
def CrewaiPlatformTools(
apps: list[str], apps: list[str],
) -> ToolCollection[BaseTool]: ) -> ToolCollection[BaseTool]:
"""Factory function that returns crewai platform tools. """Factory function that returns crewai platform tools.

View File

@@ -1,13 +1,17 @@
import os import os
def get_platform_api_base_url() -> str: def get_platform_api_base_url() -> str:
"""Get the platform API base URL from environment or use default.""" """Get the platform API base URL from environment or use default."""
base_url = os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com") base_url = os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com")
return f"{base_url}/crewai_plus/api/v1/integrations" return f"{base_url}/crewai_plus/api/v1/integrations"
def get_platform_integration_token() -> str: def get_platform_integration_token() -> str:
"""Get the platform API base URL from environment or use default.""" """Get the platform API base URL from environment or use default."""
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN") or "" token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN") or ""
if not token: if not token:
raise ValueError("No platform integration token found, please set the CREWAI_PLATFORM_INTEGRATION_TOKEN environment variable") raise ValueError(
return token # TODO: Use context manager to get token "No platform integration token found, please set the CREWAI_PLATFORM_INTEGRATION_TOKEN environment variable"
)
return token # TODO: Use context manager to get token
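
A quick usage sketch of these helpers (token value illustrative; the default base URL is the one shown above):

import os

os.environ["CREWAI_PLATFORM_INTEGRATION_TOKEN"] = "example-token"  # illustrative only
print(get_platform_api_base_url())        # https://app.crewai.com/crewai_plus/api/v1/integrations
print(get_platform_integration_token())   # example-token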

View File

@@ -1,11 +1,4 @@
from typing import Optional, Type from crewai_tools.rag.data_types import DataType
try:
from embedchain.models.data_type import DataType
EMBEDCHAIN_AVAILABLE = True
except ImportError:
EMBEDCHAIN_AVAILABLE = False
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
@@ -31,9 +24,9 @@ class CSVSearchTool(RagTool):
description: str = ( description: str = (
"A tool that can be used to semantic search a query from a CSV's content." "A tool that can be used to semantic search a query from a CSV's content."
) )
args_schema: Type[BaseModel] = CSVSearchToolSchema args_schema: type[BaseModel] = CSVSearchToolSchema
def __init__(self, csv: Optional[str] = None, **kwargs): def __init__(self, csv: str | None = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if csv is not None: if csv is not None:
self.add(csv) self.add(csv)
@@ -42,15 +35,17 @@ class CSVSearchTool(RagTool):
self._generate_description() self._generate_description()
def add(self, csv: str) -> None: def add(self, csv: str) -> None:
if not EMBEDCHAIN_AVAILABLE:
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
super().add(csv, data_type=DataType.CSV) super().add(csv, data_type=DataType.CSV)
def _run( def _run(
self, self,
search_query: str, search_query: str,
csv: Optional[str] = None, csv: str | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str: ) -> str:
if csv is not None: if csv is not None:
self.add(csv) self.add(csv)
return super()._run(query=search_query) return super()._run(
query=search_query, similarity_threshold=similarity_threshold, limit=limit
)
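
With the migration off embedchain, the RAG-backed search tools now forward optional similarity_threshold and limit arguments to the retriever. A hedged usage sketch (file path and query are illustrative):

from crewai_tools import CSVSearchTool

tool = CSVSearchTool(csv="data/customers.csv")  # hypothetical file
result = tool.run(
    search_query="customers located in Berlin",
    similarity_threshold=0.6,  # filter out weak matches
    limit=5,                   # cap the number of returned chunks
)
print(result)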

View File

@@ -1,5 +1,4 @@
import json import json
from typing import List, Type
from crewai.tools import BaseTool, EnvVar from crewai.tools import BaseTool, EnvVar
from openai import OpenAI from openai import OpenAI
@@ -9,21 +8,27 @@ from pydantic import BaseModel, Field
class ImagePromptSchema(BaseModel): class ImagePromptSchema(BaseModel):
"""Input for Dall-E Tool.""" """Input for Dall-E Tool."""
image_description: str = Field(description="Description of the image to be generated by Dall-E.") image_description: str = Field(
description="Description of the image to be generated by Dall-E."
)
class DallETool(BaseTool): class DallETool(BaseTool):
name: str = "Dall-E Tool" name: str = "Dall-E Tool"
description: str = "Generates images using OpenAI's Dall-E model." description: str = "Generates images using OpenAI's Dall-E model."
args_schema: Type[BaseModel] = ImagePromptSchema args_schema: type[BaseModel] = ImagePromptSchema
model: str = "dall-e-3" model: str = "dall-e-3"
size: str = "1024x1024" size: str = "1024x1024"
quality: str = "standard" quality: str = "standard"
n: int = 1 n: int = 1
env_vars: List[EnvVar] = [ env_vars: list[EnvVar] = [
EnvVar(name="OPENAI_API_KEY", description="API key for OpenAI services", required=True), EnvVar(
name="OPENAI_API_KEY",
description="API key for OpenAI services",
required=True,
),
] ]
def _run(self, **kwargs) -> str: def _run(self, **kwargs) -> str:
@@ -42,11 +47,9 @@ class DallETool(BaseTool):
n=self.n, n=self.n,
) )
image_data = json.dumps( return json.dumps(
{ {
"image_url": response.data[0].url, "image_url": response.data[0].url,
"image_description": response.data[0].revised_prompt, "image_description": response.data[0].revised_prompt,
} }
) )
return image_data
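
Usage sketch for the simplified return path (requires OPENAI_API_KEY in the environment; prompt illustrative):

from crewai_tools import DallETool

dalle = DallETool(model="dall-e-3", size="1024x1024", quality="standard", n=1)
output = dalle.run(image_description="a watercolor fox reading a book")
print(output)  # JSON string with "image_url" and "image_description"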

View File

@@ -1,5 +1,5 @@
import os import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union from typing import TYPE_CHECKING, Any, Optional
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
@@ -7,27 +7,31 @@ from pydantic import BaseModel, Field, model_validator
if TYPE_CHECKING: if TYPE_CHECKING:
from databricks.sdk import WorkspaceClient from databricks.sdk import WorkspaceClient
class DatabricksQueryToolSchema(BaseModel): class DatabricksQueryToolSchema(BaseModel):
"""Input schema for DatabricksQueryTool.""" """Input schema for DatabricksQueryTool."""
query: str = Field( query: str = Field(
..., description="SQL query to execute against the Databricks workspace table" ..., description="SQL query to execute against the Databricks workspace table"
) )
catalog: Optional[str] = Field( catalog: str | None = Field(
None, description="Databricks catalog name (optional, defaults to configured catalog)" None,
description="Databricks catalog name (optional, defaults to configured catalog)",
) )
db_schema: Optional[str] = Field( db_schema: str | None = Field(
None, description="Databricks schema name (optional, defaults to configured schema)" None,
description="Databricks schema name (optional, defaults to configured schema)",
) )
warehouse_id: Optional[str] = Field( warehouse_id: str | None = Field(
None, description="Databricks SQL warehouse ID (optional, defaults to configured warehouse)" None,
description="Databricks SQL warehouse ID (optional, defaults to configured warehouse)",
) )
row_limit: Optional[int] = Field( row_limit: int | None = Field(
1000, description="Maximum number of rows to return (default: 1000)" 1000, description="Maximum number of rows to return (default: 1000)"
) )
@model_validator(mode='after') @model_validator(mode="after")
def validate_input(self) -> 'DatabricksQueryToolSchema': def validate_input(self) -> "DatabricksQueryToolSchema":
"""Validate the input parameters.""" """Validate the input parameters."""
# Ensure the query is not empty # Ensure the query is not empty
if not self.query or not self.query.strip(): if not self.query or not self.query.strip():
@@ -61,21 +65,21 @@ class DatabricksQueryTool(BaseTool):
"Execute SQL queries against Databricks workspace tables and return the results." "Execute SQL queries against Databricks workspace tables and return the results."
" Provide a 'query' parameter with the SQL query to execute." " Provide a 'query' parameter with the SQL query to execute."
) )
args_schema: Type[BaseModel] = DatabricksQueryToolSchema args_schema: type[BaseModel] = DatabricksQueryToolSchema
# Optional default parameters # Optional default parameters
default_catalog: Optional[str] = None default_catalog: str | None = None
default_schema: Optional[str] = None default_schema: str | None = None
default_warehouse_id: Optional[str] = None default_warehouse_id: str | None = None
_workspace_client: Optional["WorkspaceClient"] = None _workspace_client: Optional["WorkspaceClient"] = None
package_dependencies: List[str] = ["databricks-sdk"] package_dependencies: list[str] = ["databricks-sdk"]
def __init__( def __init__(
self, self,
default_catalog: Optional[str] = None, default_catalog: str | None = None,
default_schema: Optional[str] = None, default_schema: str | None = None,
default_warehouse_id: Optional[str] = None, default_warehouse_id: str | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@@ -96,7 +100,9 @@ class DatabricksQueryTool(BaseTool):
def _validate_credentials(self) -> None: def _validate_credentials(self) -> None:
"""Validate that Databricks credentials are available.""" """Validate that Databricks credentials are available."""
has_profile = "DATABRICKS_CONFIG_PROFILE" in os.environ has_profile = "DATABRICKS_CONFIG_PROFILE" in os.environ
has_direct_auth = "DATABRICKS_HOST" in os.environ and "DATABRICKS_TOKEN" in os.environ has_direct_auth = (
"DATABRICKS_HOST" in os.environ and "DATABRICKS_TOKEN" in os.environ
)
if not (has_profile or has_direct_auth): if not (has_profile or has_direct_auth):
raise ValueError( raise ValueError(
@@ -110,6 +116,7 @@ class DatabricksQueryTool(BaseTool):
if self._workspace_client is None: if self._workspace_client is None:
try: try:
from databricks.sdk import WorkspaceClient from databricks.sdk import WorkspaceClient
self._workspace_client = WorkspaceClient() self._workspace_client = WorkspaceClient()
except ImportError: except ImportError:
raise ImportError( raise ImportError(
@@ -117,7 +124,7 @@ class DatabricksQueryTool(BaseTool):
) )
return self._workspace_client return self._workspace_client
def _format_results(self, results: List[Dict[str, Any]]) -> str: def _format_results(self, results: list[dict[str, Any]]) -> str:
"""Format query results as a readable string.""" """Format query results as a readable string."""
if not results: if not results:
return "Query returned no results." return "Query returned no results."
@@ -149,8 +156,13 @@ class DatabricksQueryTool(BaseTool):
data_rows = [] data_rows = []
for row in results: for row in results:
# Handle None values by displaying "NULL" # Handle None values by displaying "NULL"
row_values = {col: str(row[col]) if row[col] is not None else "NULL" for col in columns} row_values = {
data_row = " | ".join(f"{row_values[col]:{col_widths[col]}}" for col in columns) col: str(row[col]) if row[col] is not None else "NULL"
for col in columns
}
data_row = " | ".join(
f"{row_values[col]:{col_widths[col]}}" for col in columns
)
data_rows.append(data_row) data_rows.append(data_row)
# Add row count information # Add row count information
@@ -190,7 +202,7 @@ class DatabricksQueryTool(BaseTool):
catalog=catalog, catalog=catalog,
db_schema=db_schema, db_schema=db_schema,
warehouse_id=warehouse_id, warehouse_id=warehouse_id,
row_limit=row_limit row_limit=row_limit,
) )
# Extract validated parameters # Extract validated parameters
@@ -212,18 +224,17 @@ class DatabricksQueryTool(BaseTool):
try: try:
# Execute the statement # Execute the statement
execution = statement.execute_statement( execution = statement.execute_statement(
warehouse_id=warehouse_id, warehouse_id=warehouse_id, statement=query, **context
statement=query,
**context
) )
statement_id = execution.statement_id statement_id = execution.statement_id
except Exception as execute_error: except Exception as execute_error:
# Handle immediate execution errors # Handle immediate execution errors
return f"Error starting query execution: {str(execute_error)}" return f"Error starting query execution: {execute_error!s}"
# Poll for results with better error handling # Poll for results with better error handling
import time import time
result = None result = None
timeout = 300 # 5 minutes timeout timeout = 300 # 5 minutes timeout
start_time = time.time() start_time = time.time()
@@ -237,8 +248,10 @@ class DatabricksQueryTool(BaseTool):
result = statement.get_statement(statement_id) result = statement.get_statement(statement_id)
# Check if finished - be very explicit about state checking # Check if finished - be very explicit about state checking
if hasattr(result, 'status') and hasattr(result.status, 'state'): if hasattr(result, "status") and hasattr(result.status, "state"):
state_value = str(result.status.state) # Convert to string to handle both string and enum state_value = str(
result.status.state
) # Convert to string to handle both string and enum
# Track state changes for debugging # Track state changes for debugging
if previous_state != state_value: if previous_state != state_value:
@@ -247,33 +260,38 @@ class DatabricksQueryTool(BaseTool):
# Check if state indicates completion # Check if state indicates completion
if "SUCCEEDED" in state_value: if "SUCCEEDED" in state_value:
break break
elif "FAILED" in state_value: if "FAILED" in state_value:
# Extract error message with more robust handling # Extract error message with more robust handling
error_info = "No detailed error info" error_info = "No detailed error info"
try: try:
# First try direct access to error.message # First try direct access to error.message
if hasattr(result.status, 'error') and result.status.error: if (
if hasattr(result.status.error, 'message'): hasattr(result.status, "error")
and result.status.error
):
if hasattr(result.status.error, "message"):
error_info = result.status.error.message error_info = result.status.error.message
# Some APIs may have a different structure # Some APIs may have a different structure
elif hasattr(result.status.error, 'error_message'): elif hasattr(result.status.error, "error_message"):
error_info = result.status.error.error_message error_info = result.status.error.error_message
# Last resort, try to convert the whole error object to string # Last resort, try to convert the whole error object to string
else: else:
error_info = str(result.status.error) error_info = str(result.status.error)
except Exception as err_extract_error: except Exception as err_extract_error:
# If all else fails, try to get any info we can # If all else fails, try to get any info we can
error_info = f"Error details unavailable: {str(err_extract_error)}" error_info = (
f"Error details unavailable: {err_extract_error!s}"
)
# Return immediately on first FAILED state detection # Return immediately on first FAILED state detection
return f"Query execution failed: {error_info}" return f"Query execution failed: {error_info}"
elif "CANCELED" in state_value: if "CANCELED" in state_value:
return "Query was canceled" return "Query was canceled"
except Exception as poll_error: except Exception as poll_error:
# Don't immediately fail - try again a few times # Don't immediately fail - try again a few times
if poll_count > 3: if poll_count > 3:
return f"Error checking query status: {str(poll_error)}" return f"Error checking query status: {poll_error!s}"
# Wait before polling again # Wait before polling again
time.sleep(2) time.sleep(2)
@@ -282,21 +300,27 @@ class DatabricksQueryTool(BaseTool):
if result is None: if result is None:
return "Query returned no result (likely timed out or failed)" return "Query returned no result (likely timed out or failed)"
if not hasattr(result, 'status') or not hasattr(result.status, 'state'): if not hasattr(result, "status") or not hasattr(result.status, "state"):
return "Query completed but returned an invalid result structure" return "Query completed but returned an invalid result structure"
# Convert state to string for comparison # Convert state to string for comparison
state_value = str(result.status.state) state_value = str(result.status.state)
if not any(state in state_value for state in ["SUCCEEDED", "FAILED", "CANCELED"]): if not any(
state in state_value for state in ["SUCCEEDED", "FAILED", "CANCELED"]
):
return f"Query timed out after 5 minutes (last state: {state_value})" return f"Query timed out after 5 minutes (last state: {state_value})"
# Get results - adapt this based on the actual structure of the result object # Get results - adapt this based on the actual structure of the result object
chunk_results = [] chunk_results = []
# Check if we have results and a schema in a very defensive way # Check if we have results and a schema in a very defensive way
has_schema = (hasattr(result, 'manifest') and result.manifest is not None and has_schema = (
hasattr(result.manifest, 'schema') and result.manifest.schema is not None) hasattr(result, "manifest")
has_result = (hasattr(result, 'result') and result.result is not None) and result.manifest is not None
and hasattr(result.manifest, "schema")
and result.manifest.schema is not None
)
has_result = hasattr(result, "result") and result.result is not None
if has_schema and has_result: if has_schema and has_result:
try: try:
@@ -309,10 +333,12 @@ class DatabricksQueryTool(BaseTool):
all_columns = set(columns) all_columns = set(columns)
# Dump the raw structure of result data to help troubleshoot # Dump the raw structure of result data to help troubleshoot
if hasattr(result.result, 'data_array'): if hasattr(result.result, "data_array"):
# Add defensive check for None data_array # Add defensive check for None data_array
if result.result.data_array is None: if result.result.data_array is None:
print("data_array is None - likely an empty result set or DDL query") print(
"data_array is None - likely an empty result set or DDL query"
)
# Return empty result handling rather than trying to process null data # Return empty result handling rather than trying to process null data
return "Query executed successfully (no data returned)" return "Query executed successfully (no data returned)"
@@ -321,7 +347,12 @@ class DatabricksQueryTool(BaseTool):
is_likely_incorrect_row_structure = False is_likely_incorrect_row_structure = False
# Only try to analyze sample if data_array exists and has content # Only try to analyze sample if data_array exists and has content
if hasattr(result.result, 'data_array') and result.result.data_array and len(result.result.data_array) > 0 and len(result.result.data_array[0]) > 0: if (
hasattr(result.result, "data_array")
and result.result.data_array
and len(result.result.data_array) > 0
and len(result.result.data_array[0]) > 0
):
sample_size = min(20, len(result.result.data_array[0])) sample_size = min(20, len(result.result.data_array[0]))
if sample_size > 0: if sample_size > 0:
@@ -332,40 +363,81 @@ class DatabricksQueryTool(BaseTool):
for i in range(sample_size): for i in range(sample_size):
val = result.result.data_array[0][i] val = result.result.data_array[0][i]
total_items += 1 total_items += 1
if isinstance(val, str) and len(val) == 1 and not val.isdigit(): if (
isinstance(val, str)
and len(val) == 1
and not val.isdigit()
):
single_char_count += 1 single_char_count += 1
elif isinstance(val, str) and len(val) == 1 and val.isdigit(): elif (
isinstance(val, str)
and len(val) == 1
and val.isdigit()
):
single_digit_count += 1 single_digit_count += 1
# If a significant portion of the first values are single characters or digits, # If a significant portion of the first values are single characters or digits,
# this likely indicates data is being incorrectly structured # this likely indicates data is being incorrectly structured
if total_items > 0 and (single_char_count + single_digit_count) / total_items > 0.5: if (
total_items > 0
and (single_char_count + single_digit_count)
/ total_items
> 0.5
):
is_likely_incorrect_row_structure = True is_likely_incorrect_row_structure = True
# Additional check: if many rows have just 1 item when we expect multiple columns # Additional check: if many rows have just 1 item when we expect multiple columns
rows_with_single_item = 0 rows_with_single_item = 0
if hasattr(result.result, 'data_array') and result.result.data_array and len(result.result.data_array) > 0: if (
sample_size_for_rows = min(sample_size, len(result.result.data_array[0])) if 'sample_size' in locals() else min(20, len(result.result.data_array[0])) hasattr(result.result, "data_array")
rows_with_single_item = sum(1 for row in result.result.data_array[0][:sample_size_for_rows] if isinstance(row, list) and len(row) == 1) and result.result.data_array
if rows_with_single_item > sample_size_for_rows * 0.5 and len(columns) > 1: and len(result.result.data_array) > 0
):
sample_size_for_rows = (
min(sample_size, len(result.result.data_array[0]))
if "sample_size" in locals()
else min(20, len(result.result.data_array[0]))
)
rows_with_single_item = sum(
1
for row in result.result.data_array[0][
:sample_size_for_rows
]
if isinstance(row, list) and len(row) == 1
)
if (
rows_with_single_item > sample_size_for_rows * 0.5
and len(columns) > 1
):
is_likely_incorrect_row_structure = True is_likely_incorrect_row_structure = True
# Check if we're getting primarily single characters or the data structure seems off, # Check if we're getting primarily single characters or the data structure seems off,
# we should use special handling # we should use special handling
if 'is_likely_incorrect_row_structure' in locals() and is_likely_incorrect_row_structure: if (
print("Data appears to be malformed - will use special row reconstruction") "is_likely_incorrect_row_structure" in locals()
and is_likely_incorrect_row_structure
):
print(
"Data appears to be malformed - will use special row reconstruction"
)
needs_special_string_handling = True needs_special_string_handling = True
else: else:
needs_special_string_handling = False needs_special_string_handling = False
# Process results differently based on detection # Process results differently based on detection
if 'needs_special_string_handling' in locals() and needs_special_string_handling: if (
"needs_special_string_handling" in locals()
and needs_special_string_handling
):
# We're dealing with data where the rows may be incorrectly structured # We're dealing with data where the rows may be incorrectly structured
print("Using row reconstruction processing mode") print("Using row reconstruction processing mode")
# Collect all values into a flat list # Collect all values into a flat list
all_values = [] all_values = []
if hasattr(result.result, 'data_array') and result.result.data_array: if (
hasattr(result.result, "data_array")
and result.result.data_array
):
# Flatten all values into a single list # Flatten all values into a single list
for chunk in result.result.data_array: for chunk in result.result.data_array:
for item in chunk: for item in chunk:
@@ -386,32 +458,43 @@ class DatabricksQueryTool(BaseTool):
# Use regex pattern to identify ID columns that likely start a new row # Use regex pattern to identify ID columns that likely start a new row
import re import re
id_pattern = re.compile(r'^\d{5,9}$') # Netflix IDs are often 5-9 digits
id_pattern = re.compile(
r"^\d{5,9}$"
) # Netflix IDs are often 5-9 digits
id_indices = [] id_indices = []
for i, val in enumerate(all_values): for i, val in enumerate(all_values):
if isinstance(val, str) and id_pattern.match(val): if isinstance(val, str) and id_pattern.match(val):
# This value looks like an ID, might be the start of a row # This value looks like an ID, might be the start of a row
if i < len(all_values) - 1: if i < len(all_values) - 1:
next_few_values = all_values[i+1:i+5] next_few_values = all_values[i + 1 : i + 5]
# If following values look like they could be part of a title # If following values look like they could be part of a title
if any(isinstance(v, str) and len(v) > 1 for v in next_few_values): if any(
isinstance(v, str) and len(v) > 1
for v in next_few_values
):
id_indices.append(i) id_indices.append(i)
if id_indices: if id_indices:
# If we found potential row starts, use them to extract rows # If we found potential row starts, use them to extract rows
for i in range(len(id_indices)): for i in range(len(id_indices)):
start_idx = id_indices[i] start_idx = id_indices[i]
end_idx = id_indices[i+1] if i+1 < len(id_indices) else len(all_values) end_idx = (
id_indices[i + 1]
if i + 1 < len(id_indices)
else len(all_values)
)
# Extract values for this row # Extract values for this row
row_values = all_values[start_idx:end_idx] row_values = all_values[start_idx:end_idx]
# Special handling for Netflix title data # Special handling for Netflix title data
# Titles might be split into individual characters # Titles might be split into individual characters
if 'Title' in columns and len(row_values) > expected_column_count: if (
"Title" in columns
and len(row_values) > expected_column_count
):
# Try to reconstruct by looking for patterns # Try to reconstruct by looking for patterns
# We know ID is first, then Title (which may be split) # We know ID is first, then Title (which may be split)
# Then other fields like Genre, etc. # Then other fields like Genre, etc.
@@ -424,7 +507,14 @@ class DatabricksQueryTool(BaseTool):
for j in range(2, min(100, len(row_values))): for j in range(2, min(100, len(row_values))):
val = row_values[j] val = row_values[j]
# Check for common genres or non-title markers # Check for common genres or non-title markers
if isinstance(val, str) and val in ['Comedy', 'Drama', 'Action', 'Horror', 'Thriller', 'Documentary']: if isinstance(val, str) and val in [
"Comedy",
"Drama",
"Action",
"Horror",
"Thriller",
"Documentary",
]:
# Likely found the Genre field # Likely found the Genre field
title_end_idx = j title_end_idx = j
break break
@@ -433,15 +523,24 @@ class DatabricksQueryTool(BaseTool):
if title_end_idx > 1: if title_end_idx > 1:
title_chars = row_values[1:title_end_idx] title_chars = row_values[1:title_end_idx]
# Check if they're individual characters # Check if they're individual characters
if all(isinstance(c, str) and len(c) == 1 for c in title_chars): if all(
title = ''.join(title_chars) isinstance(c, str) and len(c) == 1
row_dict['Title'] = title for c in title_chars
):
title = "".join(title_chars)
row_dict["Title"] = title
# Assign remaining values to columns # Assign remaining values to columns
remaining_values = row_values[title_end_idx:] remaining_values = row_values[
for j, col_name in enumerate(columns[2:], 2): title_end_idx:
if j-2 < len(remaining_values): ]
row_dict[col_name] = remaining_values[j-2] for j, col_name in enumerate(
columns[2:], 2
):
if j - 2 < len(remaining_values):
row_dict[col_name] = (
remaining_values[j - 2]
)
else: else:
row_dict[col_name] = None row_dict[col_name] = None
else: else:
@@ -463,7 +562,9 @@ class DatabricksQueryTool(BaseTool):
reconstructed_rows.append(row_dict) reconstructed_rows.append(row_dict)
else: else:
# More intelligent chunking - try to detect where columns like Title might be split # More intelligent chunking - try to detect where columns like Title might be split
title_idx = columns.index('Title') if 'Title' in columns else -1 title_idx = (
columns.index("Title") if "Title" in columns else -1
)
if title_idx >= 0: if title_idx >= 0:
print("Attempting title reconstruction method") print("Attempting title reconstruction method")
@@ -471,21 +572,27 @@ class DatabricksQueryTool(BaseTool):
i = 0 i = 0
while i < len(all_values): while i < len(all_values):
# Check if this could be an ID (start of a row) # Check if this could be an ID (start of a row)
if isinstance(all_values[i], str) and id_pattern.match(all_values[i]): if isinstance(
all_values[i], str
) and id_pattern.match(all_values[i]):
row_dict = {columns[0]: all_values[i]} row_dict = {columns[0]: all_values[i]}
i += 1 i += 1
# Try to reconstruct title if it appears to be split # Try to reconstruct title if it appears to be split
title_chars = [] title_chars = []
while (i < len(all_values) and while (
isinstance(all_values[i], str) and i < len(all_values)
len(all_values[i]) <= 1 and and isinstance(all_values[i], str)
len(title_chars) < 100): # Cap title length and len(all_values[i]) <= 1
and len(title_chars) < 100
): # Cap title length
title_chars.append(all_values[i]) title_chars.append(all_values[i])
i += 1 i += 1
if title_chars: if title_chars:
row_dict[columns[title_idx]] = ''.join(title_chars) row_dict[columns[title_idx]] = "".join(
title_chars
)
# Add remaining fields # Add remaining fields
for j in range(title_idx + 1, len(columns)): for j in range(title_idx + 1, len(columns)):
@@ -502,11 +609,18 @@ class DatabricksQueryTool(BaseTool):
# If we still don't have rows, use simple chunking as fallback # If we still don't have rows, use simple chunking as fallback
if not reconstructed_rows: if not reconstructed_rows:
print("Falling back to basic chunking approach") print("Falling back to basic chunking approach")
chunks = [all_values[i:i+expected_column_count] for i in range(0, len(all_values), expected_column_count)] chunks = [
all_values[i : i + expected_column_count]
for i in range(
0, len(all_values), expected_column_count
)
]
for chunk in chunks: for chunk in chunks:
# Skip chunks that seem to be partial/incomplete rows # Skip chunks that seem to be partial/incomplete rows
if len(chunk) < expected_column_count * 0.75: # Allow for some missing values if (
len(chunk) < expected_column_count * 0.75
): # Allow for some missing values
continue continue
row_dict = {} row_dict = {}
@@ -521,13 +635,16 @@ class DatabricksQueryTool(BaseTool):
reconstructed_rows.append(row_dict) reconstructed_rows.append(row_dict)
# Apply post-processing to fix known issues # Apply post-processing to fix known issues
if reconstructed_rows and 'Title' in columns: if reconstructed_rows and "Title" in columns:
print("Applying post-processing to improve data quality") print("Applying post-processing to improve data quality")
for row in reconstructed_rows: for row in reconstructed_rows:
# Fix titles that might still have issues # Fix titles that might still have issues
if isinstance(row.get('Title'), str) and len(row.get('Title')) <= 1: if (
isinstance(row.get("Title"), str)
and len(row.get("Title")) <= 1
):
# This is likely still a fragmented title - mark as potentially incomplete # This is likely still a fragmented title - mark as potentially incomplete
row['Title'] = f"[INCOMPLETE] {row.get('Title')}" row["Title"] = f"[INCOMPLETE] {row.get('Title')}"
# Ensure we respect the row limit # Ensure we respect the row limit
if row_limit and len(reconstructed_rows) > row_limit: if row_limit and len(reconstructed_rows) > row_limit:
@@ -539,28 +656,53 @@ class DatabricksQueryTool(BaseTool):
print("Using standard processing mode") print("Using standard processing mode")
# Check different result structures # Check different result structures
if hasattr(result.result, 'data_array') and result.result.data_array: if (
hasattr(result.result, "data_array")
and result.result.data_array
):
# Check if data appears to be malformed within chunks # Check if data appears to be malformed within chunks
for chunk_idx, chunk in enumerate(result.result.data_array): for _chunk_idx, chunk in enumerate(
result.result.data_array
):
# Check if chunk might actually contain individual columns of a single row # Check if chunk might actually contain individual columns of a single row
# This is another way data might be malformed - check the first few values # This is another way data might be malformed - check the first few values
if len(chunk) > 0 and len(columns) > 1: if len(chunk) > 0 and len(columns) > 1:
# If there seems to be a mismatch between chunk structure and expected columns # If there seems to be a mismatch between chunk structure and expected columns
first_few_values = chunk[:min(5, len(chunk))] first_few_values = chunk[: min(5, len(chunk))]
if all(isinstance(val, (str, int, float)) and not isinstance(val, (list, dict)) for val in first_few_values): if all(
if len(chunk) > len(columns) * 3: # Heuristic: if chunk has way more items than columns isinstance(val, (str, int, float))
print("Chunk appears to contain individual values rather than rows - switching to row reconstruction") and not isinstance(val, (list, dict))
for val in first_few_values
):
if (
len(chunk) > len(columns) * 3
): # Heuristic: if chunk has way more items than columns
print(
"Chunk appears to contain individual values rather than rows - switching to row reconstruction"
)
# This chunk might actually be values of multiple rows - try to reconstruct # This chunk might actually be values of multiple rows - try to reconstruct
values = chunk # All values in this chunk values = chunk # All values in this chunk
reconstructed_rows = [] reconstructed_rows = []
# Try to create rows based on expected column count # Try to create rows based on expected column count
for i in range(0, len(values), len(columns)): for i in range(
if i + len(columns) <= len(values): # Ensure we have enough values 0, len(values), len(columns)
row_values = values[i:i+len(columns)] ):
row_dict = {col: val for col, val in zip(columns, row_values)} if i + len(columns) <= len(
values
): # Ensure we have enough values
row_values = values[
i : i + len(columns)
]
row_dict = {
col: val
for col, val in zip(
columns,
row_values,
strict=False,
)
}
reconstructed_rows.append(row_dict) reconstructed_rows.append(row_dict)
if reconstructed_rows: if reconstructed_rows:
@@ -569,21 +711,36 @@ class DatabricksQueryTool(BaseTool):
# Special case: when chunk contains exactly the right number of values for a single row # Special case: when chunk contains exactly the right number of values for a single row
# This handles the case where instead of a list of rows, we just got all values in a flat list # This handles the case where instead of a list of rows, we just got all values in a flat list
if all(isinstance(val, (str, int, float)) and not isinstance(val, (list, dict)) for val in chunk): if all(
if len(chunk) == len(columns) or (len(chunk) > 0 and len(chunk) % len(columns) == 0): isinstance(val, (str, int, float))
and not isinstance(val, (list, dict))
for val in chunk
):
if len(chunk) == len(columns) or (
len(chunk) > 0
and len(chunk) % len(columns) == 0
):
# Process flat list of values as rows # Process flat list of values as rows
for i in range(0, len(chunk), len(columns)): for i in range(0, len(chunk), len(columns)):
row_values = chunk[i:i+len(columns)] row_values = chunk[i : i + len(columns)]
if len(row_values) == len(columns): # Only process complete rows if len(row_values) == len(
row_dict = {col: val for col, val in zip(columns, row_values)} columns
): # Only process complete rows
row_dict = {
col: val
for col, val in zip(
columns,
row_values,
strict=False,
)
}
chunk_results.append(row_dict) chunk_results.append(row_dict)
# Skip regular row processing for this chunk # Skip regular row processing for this chunk
continue continue
# Normal processing for typical row structure # Normal processing for typical row structure
for row_idx, row in enumerate(chunk): for _row_idx, row in enumerate(chunk):
# Ensure row is actually a collection of values # Ensure row is actually a collection of values
if not isinstance(row, (list, tuple, dict)): if not isinstance(row, (list, tuple, dict)):
# This might be a single value; skip it or handle specially # This might be a single value; skip it or handle specially
@@ -599,7 +756,9 @@ class DatabricksQueryTool(BaseTool):
elif isinstance(row, (list, tuple)): elif isinstance(row, (list, tuple)):
# Map list of values to columns # Map list of values to columns
for i, val in enumerate(row): for i, val in enumerate(row):
if i < len(columns): # Only process if we have a matching column if (
i < len(columns)
): # Only process if we have a matching column
row_dict[columns[i]] = val row_dict[columns[i]] = val
else: else:
# Extra values without column names # Extra values without column names
@@ -614,16 +773,18 @@ class DatabricksQueryTool(BaseTool):
chunk_results.append(row_dict) chunk_results.append(row_dict)
elif hasattr(result.result, 'data') and result.result.data: elif hasattr(result.result, "data") and result.result.data:
# Alternative data structure # Alternative data structure
for row_idx, row in enumerate(result.result.data): for _row_idx, row in enumerate(result.result.data):
# Debug info # Debug info
# Safely create dictionary matching column names to values # Safely create dictionary matching column names to values
row_dict = {} row_dict = {}
for i, val in enumerate(row): for i, val in enumerate(row):
if i < len(columns): # Only process if we have a matching column if i < len(
columns
): # Only process if we have a matching column
row_dict[columns[i]] = val row_dict[columns[i]] = val
else: else:
# Extra values without column names # Extra values without column names
@@ -642,7 +803,9 @@ class DatabricksQueryTool(BaseTool):
normalized_results = [] normalized_results = []
for row in chunk_results: for row in chunk_results:
# Create a new row with all columns, defaulting to None for missing ones # Create a new row with all columns, defaulting to None for missing ones
normalized_row = {col: row.get(col, None) for col in all_columns} normalized_row = {
col: row.get(col, None) for col in all_columns
}
normalized_results.append(normalized_row) normalized_results.append(normalized_row)
# Replace the original results with normalized ones # Replace the original results with normalized ones
@@ -651,11 +814,12 @@ class DatabricksQueryTool(BaseTool):
except Exception as results_error: except Exception as results_error:
# Enhanced error message with more context # Enhanced error message with more context
import traceback import traceback
error_details = traceback.format_exc() error_details = traceback.format_exc()
return f"Error processing query results: {str(results_error)}\n\nDetails:\n{error_details}" return f"Error processing query results: {results_error!s}\n\nDetails:\n{error_details}"
# If we have no results but the query succeeded (e.g., for DDL statements) # If we have no results but the query succeeded (e.g., for DDL statements)
if not chunk_results and hasattr(result, 'status'): if not chunk_results and hasattr(result, "status"):
state_value = str(result.status.state) state_value = str(result.status.state)
if "SUCCEEDED" in state_value: if "SUCCEEDED" in state_value:
return "Query executed successfully (no results to display)" return "Query executed successfully (no results to display)"
@@ -666,5 +830,8 @@ class DatabricksQueryTool(BaseTool):
except Exception as e: except Exception as e:
# Include more details in the error message to help with debugging # Include more details in the error message to help with debugging
import traceback import traceback
error_details = traceback.format_exc() error_details = traceback.format_exc()
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}" return (
f"Error executing Databricks query: {e!s}\n\nDetails:\n{error_details}"
)
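
The control flow above is easier to see in isolation: execute the statement, then poll until a terminal state or a timeout. A simplified sketch of that loop, assuming the databricks-sdk statement API used in the diff:

import time

TERMINAL_STATES = ("SUCCEEDED", "FAILED", "CANCELED")

def wait_for_statement(statement_api, statement_id: str, timeout: float = 300.0):
    # Poll every 2 seconds within a 5-minute budget, mirroring the tool's loop.
    deadline = time.time() + timeout
    while time.time() < deadline:
        result = statement_api.get_statement(statement_id)
        state = str(result.status.state)  # handles both string and enum states
        if any(terminal in state for terminal in TERMINAL_STATES):
            return result
        time.sleep(2)
    raise TimeoutError(f"Statement {statement_id} not finished after {timeout}s")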

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Optional, Type from typing import Any
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -20,10 +20,10 @@ class DirectoryReadTool(BaseTool):
description: str = ( description: str = (
"A tool that can be used to recursively list a directory's content." "A tool that can be used to recursively list a directory's content."
) )
args_schema: Type[BaseModel] = DirectoryReadToolSchema args_schema: type[BaseModel] = DirectoryReadToolSchema
directory: Optional[str] = None directory: str | None = None
def __init__(self, directory: Optional[str] = None, **kwargs): def __init__(self, directory: str | None = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if directory is not None: if directory is not None:
self.directory = directory self.directory = directory

View File

@@ -1,11 +1,4 @@
from typing import Optional, Type from crewai_tools.rag.data_types import DataType
try:
from embedchain.loaders.directory_loader import DirectoryLoader
EMBEDCHAIN_AVAILABLE = True
except ImportError:
EMBEDCHAIN_AVAILABLE = False
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
@@ -31,11 +24,9 @@ class DirectorySearchTool(RagTool):
description: str = ( description: str = (
"A tool that can be used to semantic search a query from a directory's content." "A tool that can be used to semantic search a query from a directory's content."
) )
args_schema: Type[BaseModel] = DirectorySearchToolSchema args_schema: type[BaseModel] = DirectorySearchToolSchema
def __init__(self, directory: Optional[str] = None, **kwargs): def __init__(self, directory: str | None = None, **kwargs):
if not EMBEDCHAIN_AVAILABLE:
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
super().__init__(**kwargs) super().__init__(**kwargs)
if directory is not None: if directory is not None:
self.add(directory) self.add(directory)
@@ -44,16 +35,17 @@ class DirectorySearchTool(RagTool):
self._generate_description() self._generate_description()
def add(self, directory: str) -> None: def add(self, directory: str) -> None:
super().add( super().add(directory, data_type=DataType.DIRECTORY)
directory,
loader=DirectoryLoader(config=dict(recursive=True)),
)
def _run( def _run(
self, self,
search_query: str, search_query: str,
directory: Optional[str] = None, directory: str | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str: ) -> str:
if directory is not None: if directory is not None:
self.add(directory) self.add(directory)
return super()._run(query=search_query) return super()._run(
query=search_query, similarity_threshold=similarity_threshold, limit=limit
)

View File

@@ -1,11 +1,6 @@
from typing import Any, Optional, Type from typing import Any
try:
from embedchain.models.data_type import DataType
EMBEDCHAIN_AVAILABLE = True
except ImportError:
EMBEDCHAIN_AVAILABLE = False
from crewai_tools.rag.data_types import DataType
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
@@ -14,7 +9,7 @@ from ..rag.rag_tool import RagTool
class FixedDOCXSearchToolSchema(BaseModel): class FixedDOCXSearchToolSchema(BaseModel):
"""Input for DOCXSearchTool.""" """Input for DOCXSearchTool."""
docx: Optional[str] = Field( docx: str | None = Field(
..., description="File path or URL of a DOCX file to be searched" ..., description="File path or URL of a DOCX file to be searched"
) )
search_query: str = Field( search_query: str = Field(
@@ -37,9 +32,9 @@ class DOCXSearchTool(RagTool):
description: str = ( description: str = (
"A tool that can be used to semantic search a query from a DOCX's content." "A tool that can be used to semantic search a query from a DOCX's content."
) )
args_schema: Type[BaseModel] = DOCXSearchToolSchema args_schema: type[BaseModel] = DOCXSearchToolSchema
def __init__(self, docx: Optional[str] = None, **kwargs): def __init__(self, docx: str | None = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if docx is not None: if docx is not None:
self.add(docx) self.add(docx)
@@ -48,15 +43,17 @@ class DOCXSearchTool(RagTool):
self._generate_description() self._generate_description()
def add(self, docx: str) -> None: def add(self, docx: str) -> None:
if not EMBEDCHAIN_AVAILABLE:
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
super().add(docx, data_type=DataType.DOCX) super().add(docx, data_type=DataType.DOCX)
def _run( def _run(
self, self,
search_query: str, search_query: str,
docx: Optional[str] = None, docx: str | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> Any: ) -> Any:
if docx is not None: if docx is not None:
self.add(docx) self.add(docx)
return super()._run(query=search_query) return super()._run(
query=search_query, similarity_threshold=similarity_threshold, limit=limit
)

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, List, Optional, Type from typing import Any, Optional
from crewai.tools import BaseTool, EnvVar from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -17,13 +17,11 @@ class EXABaseToolSchema(BaseModel):
search_query: str = Field( search_query: str = Field(
..., description="Mandatory search query you want to use to search the internet" ..., description="Mandatory search query you want to use to search the internet"
) )
start_published_date: Optional[str] = Field( start_published_date: str | None = Field(
None, description="Start date for the search" None, description="Start date for the search"
) )
end_published_date: Optional[str] = Field( end_published_date: str | None = Field(None, description="End date for the search")
None, description="End date for the search" include_domains: list[str] | None = Field(
)
include_domains: Optional[list[str]] = Field(
None, description="List of domains to include in the search" None, description="List of domains to include in the search"
) )
@@ -32,18 +30,18 @@ class EXASearchTool(BaseTool):
model_config = {"arbitrary_types_allowed": True} model_config = {"arbitrary_types_allowed": True}
name: str = "EXASearchTool" name: str = "EXASearchTool"
description: str = "Search the internet using Exa" description: str = "Search the internet using Exa"
args_schema: Type[BaseModel] = EXABaseToolSchema args_schema: type[BaseModel] = EXABaseToolSchema
client: Optional["Exa"] = None client: Optional["Exa"] = None
content: Optional[bool] = False content: bool | None = False
summary: Optional[bool] = False summary: bool | None = False
type: Optional[str] = "auto" type: str | None = "auto"
package_dependencies: List[str] = ["exa_py"] package_dependencies: list[str] = ["exa_py"]
api_key: Optional[str] = Field( api_key: str | None = Field(
default_factory=lambda: os.getenv("EXA_API_KEY"), default_factory=lambda: os.getenv("EXA_API_KEY"),
description="API key for Exa services", description="API key for Exa services",
json_schema_extra={"required": False}, json_schema_extra={"required": False},
) )
env_vars: List[EnvVar] = [ env_vars: list[EnvVar] = [
EnvVar( EnvVar(
name="EXA_API_KEY", description="API key for Exa services", required=False name="EXA_API_KEY", description="API key for Exa services", required=False
), ),
@@ -51,9 +49,9 @@ class EXASearchTool(BaseTool):
def __init__( def __init__(
self, self,
content: Optional[bool] = False, content: bool | None = False,
summary: Optional[bool] = False, summary: bool | None = False,
type: Optional[str] = "auto", type: str | None = "auto",
**kwargs, **kwargs,
): ):
super().__init__( super().__init__(
@@ -81,9 +79,9 @@ class EXASearchTool(BaseTool):
def _run( def _run(
self, self,
search_query: str, search_query: str,
start_published_date: Optional[str] = None, start_published_date: str | None = None,
end_published_date: Optional[str] = None, end_published_date: str | None = None,
include_domains: Optional[list[str]] = None, include_domains: list[str] | None = None,
) -> Any: ) -> Any:
if self.client is None: if self.client is None:
raise ValueError("Client not initialized") raise ValueError("Client not initialized")

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type from typing import Any
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -8,8 +8,12 @@ class FileReadToolSchema(BaseModel):
"""Input for FileReadTool.""" """Input for FileReadTool."""
file_path: str = Field(..., description="Mandatory file full path to read the file") file_path: str = Field(..., description="Mandatory file full path to read the file")
start_line: Optional[int] = Field(1, description="Line number to start reading from (1-indexed)") start_line: int | None = Field(
line_count: Optional[int] = Field(None, description="Number of lines to read. If None, reads the entire file") 1, description="Line number to start reading from (1-indexed)"
)
line_count: int | None = Field(
None, description="Number of lines to read. If None, reads the entire file"
)
class FileReadTool(BaseTool): class FileReadTool(BaseTool):
@@ -38,10 +42,10 @@ class FileReadTool(BaseTool):
name: str = "Read a file's content" name: str = "Read a file's content"
description: str = "A tool that reads the content of a file. To use this tool, provide a 'file_path' parameter with the path to the file you want to read. Optionally, provide 'start_line' to start reading from a specific line and 'line_count' to limit the number of lines read." description: str = "A tool that reads the content of a file. To use this tool, provide a 'file_path' parameter with the path to the file you want to read. Optionally, provide 'start_line' to start reading from a specific line and 'line_count' to limit the number of lines read."
args_schema: Type[BaseModel] = FileReadToolSchema args_schema: type[BaseModel] = FileReadToolSchema
file_path: Optional[str] = None file_path: str | None = None
def __init__(self, file_path: Optional[str] = None, **kwargs: Any) -> None: def __init__(self, file_path: str | None = None, **kwargs: Any) -> None:
"""Initialize the FileReadTool. """Initialize the FileReadTool.
Args: Args:
@@ -59,18 +63,16 @@ class FileReadTool(BaseTool):
def _run( def _run(
self, self,
file_path: Optional[str] = None, file_path: str | None = None,
start_line: Optional[int] = 1, start_line: int | None = 1,
line_count: Optional[int] = None, line_count: int | None = None,
) -> str: ) -> str:
file_path = file_path or self.file_path file_path = file_path or self.file_path
start_line = start_line or 1 start_line = start_line or 1
line_count = line_count or None line_count = line_count or None
if file_path is None: if file_path is None:
return ( return "Error: No file path provided. Please provide a file path either in the constructor or as an argument."
"Error: No file path provided. Please provide a file path either in the constructor or as an argument."
)
try: try:
with open(file_path, "r") as file: with open(file_path, "r") as file:
@@ -82,7 +84,8 @@ class FileReadTool(BaseTool):
selected_lines = [ selected_lines = [
line line
for i, line in enumerate(file) for i, line in enumerate(file)
if i >= start_idx and (line_count is None or i < start_idx + line_count) if i >= start_idx
and (line_count is None or i < start_idx + line_count)
] ]
if not selected_lines and start_idx > 0: if not selected_lines and start_idx > 0:
@@ -94,4 +97,4 @@ class FileReadTool(BaseTool):
except PermissionError: except PermissionError:
return f"Error: Permission denied when trying to read file: {file_path}" return f"Error: Permission denied when trying to read file: {file_path}"
except Exception as e: except Exception as e:
return f"Error: Failed to read file {file_path}. {str(e)}" return f"Error: Failed to read file {file_path}. {e!s}"

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Optional, Type from typing import Any
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel from pydantic import BaseModel
@@ -11,25 +11,22 @@ def strtobool(val) -> bool:
val = val.lower() val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"): if val in ("y", "yes", "t", "true", "on", "1"):
return True return True
elif val in ("n", "no", "f", "false", "off", "0"): if val in ("n", "no", "f", "false", "off", "0"):
return False return False
else: raise ValueError(f"invalid value to cast to bool: {val!r}")
raise ValueError(f"invalid value to cast to bool: {val!r}")
class FileWriterToolInput(BaseModel): class FileWriterToolInput(BaseModel):
filename: str filename: str
directory: Optional[str] = "./" directory: str | None = "./"
overwrite: str | bool = False overwrite: str | bool = False
content: str content: str
class FileWriterTool(BaseTool): class FileWriterTool(BaseTool):
name: str = "File Writer Tool" name: str = "File Writer Tool"
description: str = ( description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
"A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input." args_schema: type[BaseModel] = FileWriterToolInput
)
args_schema: Type[BaseModel] = FileWriterToolInput
def _run(self, **kwargs: Any) -> str: def _run(self, **kwargs: Any) -> str:
try: try:
@@ -57,6 +54,6 @@ class FileWriterTool(BaseTool):
f"File {filepath} already exists and overwrite option was not passed." f"File {filepath} already exists and overwrite option was not passed."
) )
except KeyError as e: except KeyError as e:
return f"An error occurred while accessing key: {str(e)}" return f"An error occurred while accessing key: {e!s}"
except Exception as e: except Exception as e:
return f"An error occurred while writing to the file: {str(e)}" return f"An error occurred while writing to the file: {e!s}"

View File

@@ -3,7 +3,6 @@ import shutil
import tempfile import tempfile
import pytest import pytest
from crewai_tools.tools.file_writer_tool.file_writer_tool import FileWriterTool from crewai_tools.tools.file_writer_tool.file_writer_tool import FileWriterTool

View File

@@ -1,17 +1,28 @@
import os import os
import zipfile
import tarfile import tarfile
from typing import Type, Optional import zipfile
from pydantic import BaseModel, Field
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field
class FileCompressorToolInput(BaseModel): class FileCompressorToolInput(BaseModel):
"""Input schema for FileCompressorTool.""" """Input schema for FileCompressorTool."""
input_path: str = Field(..., description="Path to the file or directory to compress.")
output_path: Optional[str] = Field(default=None, description="Optional output archive filename.") input_path: str = Field(
overwrite: bool = Field(default=False, description="Whether to overwrite the archive if it already exists.") ..., description="Path to the file or directory to compress."
format: str = Field(default="zip", description="Compression format ('zip', 'tar', 'tar.gz', 'tar.bz2', 'tar.xz').") )
output_path: str | None = Field(
default=None, description="Optional output archive filename."
)
overwrite: bool = Field(
default=False,
description="Whether to overwrite the archive if it already exists.",
)
format: str = Field(
default="zip",
description="Compression format ('zip', 'tar', 'tar.gz', 'tar.bz2', 'tar.xz').",
)
class FileCompressorTool(BaseTool): class FileCompressorTool(BaseTool):
@@ -20,58 +31,65 @@ class FileCompressorTool(BaseTool):
"Compresses a file or directory into an archive (.zip currently supported). " "Compresses a file or directory into an archive (.zip currently supported). "
"Useful for archiving logs, documents, or backups." "Useful for archiving logs, documents, or backups."
) )
args_schema: Type[BaseModel] = FileCompressorToolInput args_schema: type[BaseModel] = FileCompressorToolInput
def _run(
def _run(self, input_path: str, output_path: Optional[str] = None, overwrite: bool = False, format: str = "zip") -> str: self,
input_path: str,
if not os.path.exists(input_path): output_path: str | None = None,
return f"Input path '{input_path}' does not exist." overwrite: bool = False,
format: str = "zip",
if not output_path: ) -> str:
output_path = self._generate_output_path(input_path, format) if not os.path.exists(input_path):
return f"Input path '{input_path}' does not exist."
FORMAT_EXTENSION = {
"zip": ".zip",
"tar": ".tar",
"tar.gz": ".tar.gz",
"tar.bz2": ".tar.bz2",
"tar.xz": ".tar.xz"
}
if format not in FORMAT_EXTENSION:
return f"Compression format '{format}' is not supported. Allowed formats: {', '.join(FORMAT_EXTENSION.keys())}"
elif not output_path.endswith(FORMAT_EXTENSION[format]):
return f"Error: If '{format}' format is chosen, output file must have a '{FORMAT_EXTENSION[format]}' extension."
if not self._prepare_output(output_path, overwrite):
return f"Output '{output_path}' already exists and overwrite is set to False."
try: if not output_path:
format_compression = { output_path = self._generate_output_path(input_path, format)
"zip": self._compress_zip,
"tar": self._compress_tar,
"tar.gz": self._compress_tar,
"tar.bz2": self._compress_tar,
"tar.xz": self._compress_tar
}
if format == "zip":
format_compression[format](input_path, output_path)
else:
format_compression[format](input_path, output_path, format)
return f"Successfully compressed '{input_path}' into '{output_path}'"
except FileNotFoundError:
return f"Error: File not found at path: {input_path}"
except PermissionError:
return f"Error: Permission denied when accessing '{input_path}' or writing '{output_path}'"
except Exception as e:
return f"An unexpected error occurred during compression: {str(e)}"
FORMAT_EXTENSION = {
"zip": ".zip",
"tar": ".tar",
"tar.gz": ".tar.gz",
"tar.bz2": ".tar.bz2",
"tar.xz": ".tar.xz",
}
if format not in FORMAT_EXTENSION:
return f"Compression format '{format}' is not supported. Allowed formats: {', '.join(FORMAT_EXTENSION.keys())}"
if not output_path.endswith(FORMAT_EXTENSION[format]):
return f"Error: If '{format}' format is chosen, output file must have a '{FORMAT_EXTENSION[format]}' extension."
if not self._prepare_output(output_path, overwrite):
return (
f"Output '{output_path}' already exists and overwrite is set to False."
)
try:
format_compression = {
"zip": self._compress_zip,
"tar": self._compress_tar,
"tar.gz": self._compress_tar,
"tar.bz2": self._compress_tar,
"tar.xz": self._compress_tar,
}
if format == "zip":
format_compression[format](input_path, output_path)
else:
format_compression[format](input_path, output_path, format)
return f"Successfully compressed '{input_path}' into '{output_path}'"
except FileNotFoundError:
return f"Error: File not found at path: {input_path}"
except PermissionError:
return f"Error: Permission denied when accessing '{input_path}' or writing '{output_path}'"
except Exception as e:
return f"An unexpected error occurred during compression: {e!s}"
def _generate_output_path(self, input_path: str, format: str) -> str: def _generate_output_path(self, input_path: str, format: str) -> str:
"""Generates output path based on input path and format.""" """Generates output path based on input path and format."""
if os.path.isfile(input_path): if os.path.isfile(input_path):
base_name = os.path.splitext(os.path.basename(input_path))[0] # Remove extension base_name = os.path.splitext(os.path.basename(input_path))[
0
] # Remove extension
else: else:
base_name = os.path.basename(os.path.normpath(input_path)) # Directory name base_name = os.path.basename(os.path.normpath(input_path)) # Directory name
return os.path.join(os.getcwd(), f"{base_name}.{format}") return os.path.join(os.getcwd(), f"{base_name}.{format}")
@@ -87,7 +105,7 @@ class FileCompressorTool(BaseTool):
def _compress_zip(self, input_path: str, output_path: str): def _compress_zip(self, input_path: str, output_path: str):
"""Compresses input into a zip archive.""" """Compresses input into a zip archive."""
with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf: with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
if os.path.isfile(input_path): if os.path.isfile(input_path):
zipf.write(input_path, os.path.basename(input_path)) zipf.write(input_path, os.path.basename(input_path))
else: else:
@@ -97,19 +115,18 @@ class FileCompressorTool(BaseTool):
arcname = os.path.relpath(full_path, start=input_path) arcname = os.path.relpath(full_path, start=input_path)
zipf.write(full_path, arcname) zipf.write(full_path, arcname)
def _compress_tar(self, input_path: str, output_path: str, format: str): def _compress_tar(self, input_path: str, output_path: str, format: str):
"""Compresses input into a tar archive with the given format.""" """Compresses input into a tar archive with the given format."""
format_mode = { format_mode = {
"tar": "w", "tar": "w",
"tar.gz": "w:gz", "tar.gz": "w:gz",
"tar.bz2": "w:bz2", "tar.bz2": "w:bz2",
"tar.xz": "w:xz" "tar.xz": "w:xz",
} }
if format not in format_mode: if format not in format_mode:
raise ValueError(f"Unsupported tar format: {format}") raise ValueError(f"Unsupported tar format: {format}")
mode = format_mode[format] mode = format_mode[format]
with tarfile.open(output_path, mode) as tarf: with tarfile.open(output_path, mode) as tarf:
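A usage sketch for the reflowed FileCompressorTool, assuming the top-level crewai_tools export; note that the FORMAT_EXTENSION check above requires output_path to end with the extension matching the chosen format.

    # Sketch: compress a directory into a tar.gz archive.
    from crewai_tools import FileCompressorTool

    tool = FileCompressorTool()
    result = tool._run(
        input_path="logs/",          # placeholder directory
        output_path="logs.tar.gz",   # must end in .tar.gz when format="tar.gz"
        format="tar.gz",
        overwrite=True,
    )
    print(result)  # "Successfully compressed 'logs/' into 'logs.tar.gz'" on success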

View File

@@ -1,88 +1,126 @@
from unittest.mock import patch
import os
import pytest import pytest
from crewai_tools.tools.files_compressor_tool import FileCompressorTool from crewai_tools.tools.files_compressor_tool import FileCompressorTool
from unittest.mock import patch, MagicMock
@pytest.fixture @pytest.fixture
def tool(): def tool():
return FileCompressorTool() return FileCompressorTool()
@patch("os.path.exists", return_value=False) @patch("os.path.exists", return_value=False)
def test_input_path_does_not_exist(mock_exists, tool): def test_input_path_does_not_exist(mock_exists, tool):
result = tool._run("nonexistent_path") result = tool._run("nonexistent_path")
assert "does not exist" in result assert "does not exist" in result
@patch("os.path.exists", return_value=True) @patch("os.path.exists", return_value=True)
@patch("os.getcwd", return_value="/mocked/cwd") @patch("os.getcwd", return_value="/mocked/cwd")
@patch.object(FileCompressorTool, "_compress_zip") # Mock actual compression @patch.object(FileCompressorTool, "_compress_zip") # Mock actual compression
@patch.object(FileCompressorTool, "_prepare_output", return_value=True) @patch.object(FileCompressorTool, "_prepare_output", return_value=True)
def test_generate_output_path_default(mock_prepare, mock_compress, mock_cwd, mock_exists, tool): def test_generate_output_path_default(
mock_prepare, mock_compress, mock_cwd, mock_exists, tool
):
result = tool._run(input_path="mydir", format="zip") result = tool._run(input_path="mydir", format="zip")
assert "Successfully compressed" in result assert "Successfully compressed" in result
mock_compress.assert_called_once() mock_compress.assert_called_once()
@patch("os.path.exists", return_value=True) @patch("os.path.exists", return_value=True)
@patch.object(FileCompressorTool, "_compress_zip") @patch.object(FileCompressorTool, "_compress_zip")
@patch.object(FileCompressorTool, "_prepare_output", return_value=True) @patch.object(FileCompressorTool, "_prepare_output", return_value=True)
def test_zip_compression(mock_prepare, mock_compress, mock_exists, tool): def test_zip_compression(mock_prepare, mock_compress, mock_exists, tool):
result = tool._run(input_path="some/path", output_path="archive.zip", format="zip", overwrite=True) result = tool._run(
input_path="some/path", output_path="archive.zip", format="zip", overwrite=True
)
assert "Successfully compressed" in result assert "Successfully compressed" in result
mock_compress.assert_called_once() mock_compress.assert_called_once()
@patch("os.path.exists", return_value=True) @patch("os.path.exists", return_value=True)
@patch.object(FileCompressorTool, "_compress_tar") @patch.object(FileCompressorTool, "_compress_tar")
@patch.object(FileCompressorTool, "_prepare_output", return_value=True) @patch.object(FileCompressorTool, "_prepare_output", return_value=True)
def test_tar_gz_compression(mock_prepare, mock_compress, mock_exists, tool): def test_tar_gz_compression(mock_prepare, mock_compress, mock_exists, tool):
result = tool._run(input_path="some/path", output_path="archive.tar.gz", format="tar.gz", overwrite=True) result = tool._run(
input_path="some/path",
output_path="archive.tar.gz",
format="tar.gz",
overwrite=True,
)
assert "Successfully compressed" in result assert "Successfully compressed" in result
mock_compress.assert_called_once() mock_compress.assert_called_once()
@pytest.mark.parametrize("format", ["tar", "tar.bz2", "tar.xz"]) @pytest.mark.parametrize("format", ["tar", "tar.bz2", "tar.xz"])
@patch("os.path.exists", return_value=True) @patch("os.path.exists", return_value=True)
@patch.object(FileCompressorTool, "_compress_tar") @patch.object(FileCompressorTool, "_compress_tar")
@patch.object(FileCompressorTool, "_prepare_output", return_value=True) @patch.object(FileCompressorTool, "_prepare_output", return_value=True)
def test_other_tar_formats(mock_prepare, mock_compress, mock_exists, format, tool): def test_other_tar_formats(mock_prepare, mock_compress, mock_exists, format, tool):
result = tool._run(input_path="path/to/input", output_path=f"archive.{format}", format=format, overwrite=True) result = tool._run(
input_path="path/to/input",
output_path=f"archive.{format}",
format=format,
overwrite=True,
)
assert "Successfully compressed" in result assert "Successfully compressed" in result
mock_compress.assert_called_once() mock_compress.assert_called_once()
@pytest.mark.parametrize("format", ["rar", "7z"]) @pytest.mark.parametrize("format", ["rar", "7z"])
@patch("os.path.exists", return_value=True) #Ensure input_path exists @patch("os.path.exists", return_value=True) # Ensure input_path exists
def test_unsupported_format(_, tool, format): def test_unsupported_format(_, tool, format):
result = tool._run(input_path="some/path", output_path=f"archive.{format}", format=format) result = tool._run(
input_path="some/path", output_path=f"archive.{format}", format=format
)
assert "not supported" in result assert "not supported" in result
@patch("os.path.exists", return_value=True)
def test_extension_mismatch(_ , tool): @patch("os.path.exists", return_value=True)
result = tool._run(input_path="some/path", output_path="archive.zip", format="tar.gz") def test_extension_mismatch(_, tool):
result = tool._run(
input_path="some/path", output_path="archive.zip", format="tar.gz"
)
assert "must have a '.tar.gz' extension" in result assert "must have a '.tar.gz' extension" in result
@patch("os.path.exists", return_value=True) @patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True) @patch("os.path.isfile", return_value=True)
@patch("os.path.exists", return_value=True) @patch("os.path.exists", return_value=True)
def test_existing_output_no_overwrite(_, __, ___, tool): def test_existing_output_no_overwrite(_, __, ___, tool):
result = tool._run(input_path="some/path", output_path="archive.zip", format="zip", overwrite=False) result = tool._run(
input_path="some/path", output_path="archive.zip", format="zip", overwrite=False
)
assert "overwrite is set to False" in result assert "overwrite is set to False" in result
@patch("os.path.exists", return_value=True) @patch("os.path.exists", return_value=True)
@patch("zipfile.ZipFile", side_effect=PermissionError) @patch("zipfile.ZipFile", side_effect=PermissionError)
def test_permission_error(mock_zip, _, tool): def test_permission_error(mock_zip, _, tool):
result = tool._run(input_path="file.txt", output_path="file.zip", format="zip", overwrite=True) result = tool._run(
input_path="file.txt", output_path="file.zip", format="zip", overwrite=True
)
assert "Permission denied" in result assert "Permission denied" in result
@patch("os.path.exists", return_value=True) @patch("os.path.exists", return_value=True)
@patch("zipfile.ZipFile", side_effect=FileNotFoundError) @patch("zipfile.ZipFile", side_effect=FileNotFoundError)
def test_file_not_found_during_zip(mock_zip, _, tool): def test_file_not_found_during_zip(mock_zip, _, tool):
result = tool._run(input_path="file.txt", output_path="file.zip", format="zip", overwrite=True) result = tool._run(
input_path="file.txt", output_path="file.zip", format="zip", overwrite=True
)
assert "File not found" in result assert "File not found" in result
@patch("os.path.exists", return_value=True) @patch("os.path.exists", return_value=True)
@patch("zipfile.ZipFile", side_effect=Exception("Unexpected")) @patch("zipfile.ZipFile", side_effect=Exception("Unexpected"))
def test_general_exception_during_zip(mock_zip, _, tool): def test_general_exception_during_zip(mock_zip, _, tool):
result = tool._run(input_path="file.txt", output_path="file.zip", format="zip", overwrite=True) result = tool._run(
input_path="file.txt", output_path="file.zip", format="zip", overwrite=True
)
assert "unexpected error" in result assert "unexpected error" in result
# Test: Output directory is created when missing # Test: Output directory is created when missing
@patch("os.makedirs") @patch("os.makedirs")
@patch("os.path.exists", return_value=False) @patch("os.path.exists", return_value=False)

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type, List, TYPE_CHECKING from typing import TYPE_CHECKING, Any, Optional
from crewai.tools import BaseTool, EnvVar from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -43,9 +43,9 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
) )
name: str = "Firecrawl web crawl tool" name: str = "Firecrawl web crawl tool"
description: str = "Crawl webpages using Firecrawl and return the contents" description: str = "Crawl webpages using Firecrawl and return the contents"
args_schema: Type[BaseModel] = FirecrawlCrawlWebsiteToolSchema args_schema: type[BaseModel] = FirecrawlCrawlWebsiteToolSchema
api_key: Optional[str] = None api_key: str | None = None
config: Optional[dict[str, Any]] = Field( config: dict[str, Any] | None = Field(
default_factory=lambda: { default_factory=lambda: {
"maxDepth": 2, "maxDepth": 2,
"ignoreSitemap": True, "ignoreSitemap": True,
@@ -60,12 +60,16 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
} }
) )
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None) _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"] package_dependencies: list[str] = ["firecrawl-py"]
env_vars: List[EnvVar] = [ env_vars: list[EnvVar] = [
EnvVar(name="FIRECRAWL_API_KEY", description="API key for Firecrawl services", required=True), EnvVar(
name="FIRECRAWL_API_KEY",
description="API key for Firecrawl services",
required=True,
),
] ]
def __init__(self, api_key: Optional[str] = None, **kwargs): def __init__(self, api_key: str | None = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.api_key = api_key self.api_key = api_key
self._initialize_firecrawl() self._initialize_firecrawl()

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type, Dict, List, TYPE_CHECKING from typing import TYPE_CHECKING, Any, Optional
from crewai.tools import BaseTool, EnvVar from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -41,9 +41,9 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
) )
name: str = "Firecrawl web scrape tool" name: str = "Firecrawl web scrape tool"
description: str = "Scrape webpages using Firecrawl and return the contents" description: str = "Scrape webpages using Firecrawl and return the contents"
args_schema: Type[BaseModel] = FirecrawlScrapeWebsiteToolSchema args_schema: type[BaseModel] = FirecrawlScrapeWebsiteToolSchema
api_key: Optional[str] = None api_key: str | None = None
config: Dict[str, Any] = Field( config: dict[str, Any] = Field(
default_factory=lambda: { default_factory=lambda: {
"formats": ["markdown"], "formats": ["markdown"],
"onlyMainContent": True, "onlyMainContent": True,
@@ -55,12 +55,16 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
) )
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None) _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"] package_dependencies: list[str] = ["firecrawl-py"]
env_vars: List[EnvVar] = [ env_vars: list[EnvVar] = [
EnvVar(name="FIRECRAWL_API_KEY", description="API key for Firecrawl services", required=True), EnvVar(
name="FIRECRAWL_API_KEY",
description="API key for Firecrawl services",
required=True,
),
] ]
def __init__(self, api_key: Optional[str] = None, **kwargs): def __init__(self, api_key: str | None = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
try: try:
from firecrawl import FirecrawlApp # type: ignore from firecrawl import FirecrawlApp # type: ignore

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, List from typing import TYPE_CHECKING, Any, Optional
from crewai.tools import BaseTool, EnvVar from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -36,17 +36,14 @@ class FirecrawlSearchTool(BaseTool):
timeout (int): Timeout in milliseconds. Default: 60000 timeout (int): Timeout in milliseconds. Default: 60000
""" """
model_config = ConfigDict(
arbitrary_types_allowed=True, validate_assignment=True, frozen=False
)
model_config = ConfigDict( model_config = ConfigDict(
arbitrary_types_allowed=True, validate_assignment=True, frozen=False arbitrary_types_allowed=True, validate_assignment=True, frozen=False
) )
name: str = "Firecrawl web search tool" name: str = "Firecrawl web search tool"
description: str = "Search webpages using Firecrawl and return the results" description: str = "Search webpages using Firecrawl and return the results"
args_schema: Type[BaseModel] = FirecrawlSearchToolSchema args_schema: type[BaseModel] = FirecrawlSearchToolSchema
api_key: Optional[str] = None api_key: str | None = None
config: Optional[dict[str, Any]] = Field( config: dict[str, Any] | None = Field(
default_factory=lambda: { default_factory=lambda: {
"limit": 5, "limit": 5,
"tbs": None, "tbs": None,
@@ -57,12 +54,16 @@ class FirecrawlSearchTool(BaseTool):
} }
) )
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None) _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"] package_dependencies: list[str] = ["firecrawl-py"]
env_vars: List[EnvVar] = [ env_vars: list[EnvVar] = [
EnvVar(name="FIRECRAWL_API_KEY", description="API key for Firecrawl services", required=True), EnvVar(
name="FIRECRAWL_API_KEY",
description="API key for Firecrawl services",
required=True,
),
] ]
def __init__(self, api_key: Optional[str] = None, **kwargs): def __init__(self, api_key: str | None = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.api_key = api_key self.api_key = api_key
self._initialize_firecrawl() self._initialize_firecrawl()
@@ -116,4 +117,3 @@ except ImportError:
""" """
When this tool is not used, then exception can be ignored. When this tool is not used, then exception can be ignored.
""" """
pass
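The three Firecrawl tools now share the same modernized typing (type[BaseModel], str | None) and a multi-line EnvVar declaration. A hedged instantiation sketch, assuming firecrawl-py is installed and FIRECRAWL_API_KEY is set in the environment:

    # Sketch: all three Firecrawl tools accept the same optional api_key.
    # Requires `pip install firecrawl-py` and a FIRECRAWL_API_KEY env var.
    import os
    from crewai_tools import FirecrawlScrapeWebsiteTool

    tool = FirecrawlScrapeWebsiteTool(api_key=os.getenv("FIRECRAWL_API_KEY"))
    # config defaults come from the default_factory shown in the diff
    # (formats=["markdown"], onlyMainContent=True, ...).
    print(tool.config)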

View File

@@ -1,5 +1,4 @@
import os import os
from typing import List, Optional, Type
import requests import requests
from crewai.tools import BaseTool, EnvVar from crewai.tools import BaseTool, EnvVar
@@ -10,7 +9,7 @@ class GenerateCrewaiAutomationToolSchema(BaseModel):
prompt: str = Field( prompt: str = Field(
description="The prompt to generate the CrewAI automation, e.g. 'Generate a CrewAI automation that will scrape the website and store the data in a database.'" description="The prompt to generate the CrewAI automation, e.g. 'Generate a CrewAI automation that will scrape the website and store the data in a database.'"
) )
organization_id: Optional[str] = Field( organization_id: str | None = Field(
default=None, default=None,
description="The identifier for the CrewAI Enterprise organization. If not specified, a default organization will be used.", description="The identifier for the CrewAI Enterprise organization. If not specified, a default organization will be used.",
) )
@@ -23,16 +22,16 @@ class GenerateCrewaiAutomationTool(BaseTool):
"automations based on natural language descriptions. It translates high-level requirements into " "automations based on natural language descriptions. It translates high-level requirements into "
"functional CrewAI implementations." "functional CrewAI implementations."
) )
args_schema: Type[BaseModel] = GenerateCrewaiAutomationToolSchema args_schema: type[BaseModel] = GenerateCrewaiAutomationToolSchema
crewai_enterprise_url: str = Field( crewai_enterprise_url: str = Field(
default_factory=lambda: os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com"), default_factory=lambda: os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com"),
description="The base URL of CrewAI Enterprise. If not provided, it will be loaded from the environment variable CREWAI_PLUS_URL with default https://app.crewai.com.", description="The base URL of CrewAI Enterprise. If not provided, it will be loaded from the environment variable CREWAI_PLUS_URL with default https://app.crewai.com.",
) )
personal_access_token: Optional[str] = Field( personal_access_token: str | None = Field(
default_factory=lambda: os.getenv("CREWAI_PERSONAL_ACCESS_TOKEN"), default_factory=lambda: os.getenv("CREWAI_PERSONAL_ACCESS_TOKEN"),
description="The user's Personal Access Token to access CrewAI Enterprise API. If not provided, it will be loaded from the environment variable CREWAI_PERSONAL_ACCESS_TOKEN.", description="The user's Personal Access Token to access CrewAI Enterprise API. If not provided, it will be loaded from the environment variable CREWAI_PERSONAL_ACCESS_TOKEN.",
) )
env_vars: List[EnvVar] = [ env_vars: list[EnvVar] = [
EnvVar( EnvVar(
name="CREWAI_PERSONAL_ACCESS_TOKEN", name="CREWAI_PERSONAL_ACCESS_TOKEN",
description="Personal Access Token for CrewAI Enterprise API", description="Personal Access Token for CrewAI Enterprise API",
@@ -57,7 +56,7 @@ class GenerateCrewaiAutomationTool(BaseTool):
studio_project_url = response.json().get("url") studio_project_url = response.json().get("url")
return f"Generated CrewAI Studio project URL: {studio_project_url}" return f"Generated CrewAI Studio project URL: {studio_project_url}"
def _get_headers(self, organization_id: Optional[str] = None) -> dict: def _get_headers(self, organization_id: str | None = None) -> dict:
headers = { headers = {
"Authorization": f"Bearer {self.personal_access_token}", "Authorization": f"Bearer {self.personal_access_token}",
"Content-Type": "application/json", "Content-Type": "application/json",

View File

@@ -1,12 +1,5 @@
from typing import List, Optional, Type, Any from crewai_tools.rag.data_types import DataType
from pydantic import BaseModel, Field
try:
from embedchain.loaders.github import GithubLoader
EMBEDCHAIN_AVAILABLE = True
except ImportError:
EMBEDCHAIN_AVAILABLE = False
from pydantic import BaseModel, Field, PrivateAttr
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
@@ -24,7 +17,7 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):
"""Input for GithubSearchTool.""" """Input for GithubSearchTool."""
github_repo: str = Field(..., description="Mandatory github you want to search") github_repo: str = Field(..., description="Mandatory github you want to search")
content_types: List[str] = Field( content_types: list[str] = Field(
..., ...,
description="Mandatory content types you want to be included search, options: [code, repo, pr, issue]", description="Mandatory content types you want to be included search, options: [code, repo, pr, issue]",
) )
@@ -32,28 +25,22 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):
class GithubSearchTool(RagTool): class GithubSearchTool(RagTool):
name: str = "Search a github repo's content" name: str = "Search a github repo's content"
description: str = ( description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
"A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
)
summarize: bool = False summarize: bool = False
gh_token: str gh_token: str
args_schema: Type[BaseModel] = GithubSearchToolSchema args_schema: type[BaseModel] = GithubSearchToolSchema
content_types: List[str] = Field( content_types: list[str] = Field(
default_factory=lambda: ["code", "repo", "pr", "issue"], default_factory=lambda: ["code", "repo", "pr", "issue"],
description="Content types you want to be included search, options: [code, repo, pr, issue]", description="Content types you want to be included search, options: [code, repo, pr, issue]",
) )
_loader: Any | None = PrivateAttr(default=None)
def __init__( def __init__(
self, self,
github_repo: Optional[str] = None, github_repo: str | None = None,
content_types: Optional[List[str]] = None, content_types: list[str] | None = None,
**kwargs, **kwargs,
): ):
if not EMBEDCHAIN_AVAILABLE:
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
super().__init__(**kwargs) super().__init__(**kwargs)
self._loader = GithubLoader(config={"token": self.gh_token})
if github_repo and content_types: if github_repo and content_types:
self.add(repo=github_repo, content_types=content_types) self.add(repo=github_repo, content_types=content_types)
@@ -64,25 +51,28 @@ class GithubSearchTool(RagTool):
def add( def add(
self, self,
repo: str, repo: str,
content_types: Optional[List[str]] = None, content_types: list[str] | None = None,
) -> None: ) -> None:
content_types = content_types or self.content_types content_types = content_types or self.content_types
super().add( super().add(
f"repo:{repo} type:{','.join(content_types)}", f"https://github.com/{repo}",
data_type="github", data_type=DataType.GITHUB,
loader=self._loader, metadata={"content_types": content_types, "gh_token": self.gh_token},
) )
def _run( def _run(
self, self,
search_query: str, search_query: str,
github_repo: Optional[str] = None, github_repo: str | None = None,
content_types: Optional[List[str]] = None, content_types: list[str] | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str: ) -> str:
if github_repo: if github_repo:
self.add( self.add(
repo=github_repo, repo=github_repo,
content_types=content_types, content_types=content_types,
) )
return super()._run(query=search_query) return super()._run(
query=search_query, similarity_threshold=similarity_threshold, limit=limit
)
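With embedchain removed, add() now hands the RAG layer a plain GitHub URL plus metadata instead of a loader instance. A usage sketch with placeholder repo and token, assuming the package's RAG extras are installed:

    # Sketch: semantic search over a repo with the adapter-based GithubSearchTool.
    from crewai_tools import GithubSearchTool

    tool = GithubSearchTool(
        gh_token="ghp_...",                  # placeholder token
        github_repo="crewAIInc/crewAI",      # indexed on construction
        content_types=["code", "issue"],
    )
    # The new signature forwards retrieval knobs through to RagTool._run.
    print(tool._run(search_query="vector store", similarity_threshold=0.6, limit=5))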

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Optional, Type, Dict, Literal, Union, List from typing import Any, Literal
from crewai.tools import BaseTool, EnvVar from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -7,8 +7,13 @@ from pydantic import BaseModel, Field
class HyperbrowserLoadToolSchema(BaseModel): class HyperbrowserLoadToolSchema(BaseModel):
url: str = Field(description="Website URL") url: str = Field(description="Website URL")
operation: Literal['scrape', 'crawl'] = Field(description="Operation to perform on the website. Either 'scrape' or 'crawl'") operation: Literal["scrape", "crawl"] = Field(
params: Optional[Dict] = Field(description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait") description="Operation to perform on the website. Either 'scrape' or 'crawl'"
)
params: dict | None = Field(
description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait"
)
class HyperbrowserLoadTool(BaseTool): class HyperbrowserLoadTool(BaseTool):
"""HyperbrowserLoadTool. """HyperbrowserLoadTool.
@@ -20,19 +25,24 @@ class HyperbrowserLoadTool(BaseTool):
Args: Args:
api_key: The Hyperbrowser API key, can be set as an environment variable `HYPERBROWSER_API_KEY` or passed directly api_key: The Hyperbrowser API key, can be set as an environment variable `HYPERBROWSER_API_KEY` or passed directly
""" """
name: str = "Hyperbrowser web load tool" name: str = "Hyperbrowser web load tool"
description: str = "Scrape or crawl a website using Hyperbrowser and return the contents in properly formatted markdown or html" description: str = "Scrape or crawl a website using Hyperbrowser and return the contents in properly formatted markdown or html"
args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema args_schema: type[BaseModel] = HyperbrowserLoadToolSchema
api_key: Optional[str] = None api_key: str | None = None
hyperbrowser: Optional[Any] = None hyperbrowser: Any | None = None
package_dependencies: List[str] = ["hyperbrowser"] package_dependencies: list[str] = ["hyperbrowser"]
env_vars: List[EnvVar] = [ env_vars: list[EnvVar] = [
EnvVar(name="HYPERBROWSER_API_KEY", description="API key for Hyperbrowser services", required=False), EnvVar(
name="HYPERBROWSER_API_KEY",
description="API key for Hyperbrowser services",
required=False,
),
] ]
def __init__(self, api_key: Optional[str] = None, **kwargs): def __init__(self, api_key: str | None = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.api_key = api_key or os.getenv('HYPERBROWSER_API_KEY') self.api_key = api_key or os.getenv("HYPERBROWSER_API_KEY")
if not api_key: if not api_key:
raise ValueError( raise ValueError(
"`api_key` is required, please set the `HYPERBROWSER_API_KEY` environment variable or pass it directly" "`api_key` is required, please set the `HYPERBROWSER_API_KEY` environment variable or pass it directly"
@@ -41,18 +51,22 @@ class HyperbrowserLoadTool(BaseTool):
try: try:
from hyperbrowser import Hyperbrowser from hyperbrowser import Hyperbrowser
except ImportError: except ImportError:
raise ImportError("`hyperbrowser` package not found, please run `pip install hyperbrowser`") raise ImportError(
"`hyperbrowser` package not found, please run `pip install hyperbrowser`"
)
if not self.api_key: if not self.api_key:
raise ValueError("HYPERBROWSER_API_KEY is not set. Please provide it either via the constructor with the `api_key` argument or by setting the HYPERBROWSER_API_KEY environment variable.") raise ValueError(
"HYPERBROWSER_API_KEY is not set. Please provide it either via the constructor with the `api_key` argument or by setting the HYPERBROWSER_API_KEY environment variable."
)
self.hyperbrowser = Hyperbrowser(api_key=self.api_key) self.hyperbrowser = Hyperbrowser(api_key=self.api_key)
def _prepare_params(self, params: Dict) -> Dict: def _prepare_params(self, params: dict) -> dict:
"""Prepare session and scrape options parameters.""" """Prepare session and scrape options parameters."""
try: try:
from hyperbrowser.models.session import CreateSessionParams
from hyperbrowser.models.scrape import ScrapeOptions from hyperbrowser.models.scrape import ScrapeOptions
from hyperbrowser.models.session import CreateSessionParams
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"`hyperbrowser` package not found, please run `pip install hyperbrowser`" "`hyperbrowser` package not found, please run `pip install hyperbrowser`"
@@ -70,17 +84,24 @@ class HyperbrowserLoadTool(BaseTool):
params["scrape_options"] = ScrapeOptions(**params["scrape_options"]) params["scrape_options"] = ScrapeOptions(**params["scrape_options"])
return params return params
def _extract_content(self, data: Union[Any, None]): def _extract_content(self, data: Any | None):
"""Extract content from response data.""" """Extract content from response data."""
content = "" content = ""
if data: if data:
content = data.markdown or data.html or "" content = data.markdown or data.html or ""
return content return content
def _run(self, url: str, operation: Literal['scrape', 'crawl'] = 'scrape', params: Optional[Dict] = {}): def _run(
self,
url: str,
operation: Literal["scrape", "crawl"] = "scrape",
params: dict | None = None,
):
if params is None:
params = {}
try: try:
from hyperbrowser.models.scrape import StartScrapeJobParams
from hyperbrowser.models.crawl import StartCrawlJobParams from hyperbrowser.models.crawl import StartCrawlJobParams
from hyperbrowser.models.scrape import StartScrapeJobParams
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"`hyperbrowser` package not found, please run `pip install hyperbrowser`" "`hyperbrowser` package not found, please run `pip install hyperbrowser`"
@@ -88,20 +109,18 @@ class HyperbrowserLoadTool(BaseTool):
params = self._prepare_params(params) params = self._prepare_params(params)
if operation == 'scrape': if operation == "scrape":
scrape_params = StartScrapeJobParams(url=url, **params) scrape_params = StartScrapeJobParams(url=url, **params)
scrape_resp = self.hyperbrowser.scrape.start_and_wait(scrape_params) scrape_resp = self.hyperbrowser.scrape.start_and_wait(scrape_params)
content = self._extract_content(scrape_resp.data) return self._extract_content(scrape_resp.data)
return content crawl_params = StartCrawlJobParams(url=url, **params)
else: crawl_resp = self.hyperbrowser.crawl.start_and_wait(crawl_params)
crawl_params = StartCrawlJobParams(url=url, **params) content = ""
crawl_resp = self.hyperbrowser.crawl.start_and_wait(crawl_params) if crawl_resp.data:
content = "" for page in crawl_resp.data:
if crawl_resp.data: page_content = self._extract_content(page)
for page in crawl_resp.data: if page_content:
page_content = self._extract_content(page) content += (
if page_content: f"\n{'-' * 50}\nUrl: {page.url}\nContent:\n{page_content}\n"
content += ( )
f"\n{'-'*50}\nUrl: {page.url}\nContent:\n{page_content}\n" return content
)
return content
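Besides the reflow, this hunk fixes a classic pitfall: the old params: Optional[Dict] = {} default created one dict shared by every call, while the new params: dict | None = None plus an in-body check gives each call a fresh dict. A minimal illustration of why that matters:

    # Why `params=None` beats `params={}` as a default: the {} is created once
    # at definition time and shared by every call that omits the argument.
    def bad(params: dict = {}):
        params.setdefault("count", 0)
        params["count"] += 1
        return params

    def good(params: dict | None = None):
        if params is None:
            params = {}  # fresh dict per call, as in the Hyperbrowser fix
        params.setdefault("count", 0)
        params["count"] += 1
        return params

    assert bad() == {"count": 1} and bad() == {"count": 2}    # state leaks
    assert good() == {"count": 1} and good() == {"count": 1}  # independent calls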

View File

@@ -1,23 +1,27 @@
import time
from typing import Any
import requests
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field, create_model from pydantic import BaseModel, Field, create_model
from typing import Any, Type
import requests
import time
class InvokeCrewAIAutomationInput(BaseModel): class InvokeCrewAIAutomationInput(BaseModel):
"""Input schema for InvokeCrewAIAutomationTool.""" """Input schema for InvokeCrewAIAutomationTool."""
prompt: str = Field(..., description="The prompt or query to send to the crew") prompt: str = Field(..., description="The prompt or query to send to the crew")
class InvokeCrewAIAutomationTool(BaseTool): class InvokeCrewAIAutomationTool(BaseTool):
""" """
A CrewAI tool for invoking external crew/flows APIs. A CrewAI tool for invoking external crew/flows APIs.
This tool provides CrewAI Platform API integration with external crew services, supporting: This tool provides CrewAI Platform API integration with external crew services, supporting:
- Dynamic input schema configuration - Dynamic input schema configuration
- Automatic polling for task completion - Automatic polling for task completion
- Bearer token authentication - Bearer token authentication
- Comprehensive error handling - Comprehensive error handling
Example: Example:
Basic usage: Basic usage:
>>> tool = InvokeCrewAIAutomationTool( >>> tool = InvokeCrewAIAutomationTool(
@@ -26,7 +30,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
... crew_name="My Crew", ... crew_name="My Crew",
... crew_description="Description of what the crew does" ... crew_description="Description of what the crew does"
... ) ... )
With custom inputs: With custom inputs:
>>> custom_inputs = { >>> custom_inputs = {
... "param1": Field(..., description="Description of param1"), ... "param1": Field(..., description="Description of param1"),
@@ -39,7 +43,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
... crew_description="Description of what the crew does", ... crew_description="Description of what the crew does",
... crew_inputs=custom_inputs ... crew_inputs=custom_inputs
... ) ... )
Example: Example:
>>> tools=[ >>> tools=[
... InvokeCrewAIAutomationTool( ... InvokeCrewAIAutomationTool(
@@ -53,25 +57,27 @@ class InvokeCrewAIAutomationTool(BaseTool):
... ) ... )
... ] ... ]
""" """
name: str = "invoke_amp_automation" name: str = "invoke_amp_automation"
description: str = "Invokes an CrewAI Platform Automation using API" description: str = "Invokes an CrewAI Platform Automation using API"
args_schema: Type[BaseModel] = InvokeCrewAIAutomationInput args_schema: type[BaseModel] = InvokeCrewAIAutomationInput
crew_api_url: str crew_api_url: str
crew_bearer_token: str crew_bearer_token: str
max_polling_time: int = 10 * 60 # 10 minutes max_polling_time: int = 10 * 60 # 10 minutes
def __init__( def __init__(
self, self,
crew_api_url: str, crew_api_url: str,
crew_bearer_token: str, crew_bearer_token: str,
crew_name: str, crew_name: str,
crew_description: str, crew_description: str,
max_polling_time: int = 10 * 60, max_polling_time: int = 10 * 60,
crew_inputs: dict[str, Any] = None): crew_inputs: dict[str, Any] | None = None,
):
""" """
Initialize the InvokeCrewAIAutomationTool. Initialize the InvokeCrewAIAutomationTool.
Args: Args:
crew_api_url: Base URL of the crew API service crew_api_url: Base URL of the crew API service
crew_bearer_token: Bearer token for API authentication crew_bearer_token: Bearer token for API authentication
@@ -84,7 +90,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
if crew_inputs: if crew_inputs:
# Start with the base prompt field # Start with the base prompt field
fields = {} fields = {}
# Add custom fields # Add custom fields
for field_name, field_def in crew_inputs.items(): for field_name, field_def in crew_inputs.items():
if isinstance(field_def, tuple): if isinstance(field_def, tuple):
@@ -92,12 +98,12 @@ class InvokeCrewAIAutomationTool(BaseTool):
else: else:
# Assume it's a Field object, extract type from annotation if available # Assume it's a Field object, extract type from annotation if available
fields[field_name] = (str, field_def) fields[field_name] = (str, field_def)
# Create dynamic model # Create dynamic model
args_schema = create_model('DynamicInvokeCrewAIAutomationInput', **fields) args_schema = create_model("DynamicInvokeCrewAIAutomationInput", **fields)
else: else:
args_schema = InvokeCrewAIAutomationInput args_schema = InvokeCrewAIAutomationInput
# Initialize the parent class with proper field values # Initialize the parent class with proper field values
super().__init__( super().__init__(
name=crew_name, name=crew_name,
@@ -105,7 +111,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
args_schema=args_schema, args_schema=args_schema,
crew_api_url=crew_api_url, crew_api_url=crew_api_url,
crew_bearer_token=crew_bearer_token, crew_bearer_token=crew_bearer_token,
max_polling_time=max_polling_time max_polling_time=max_polling_time,
) )
def _kickoff_crew(self, inputs: dict[str, Any]) -> dict[str, Any]: def _kickoff_crew(self, inputs: dict[str, Any]) -> dict[str, Any]:
@@ -125,8 +131,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
}, },
json={"inputs": inputs}, json={"inputs": inputs},
) )
response_json = response.json() return response.json()
return response_json
def _get_crew_status(self, crew_id: str) -> dict[str, Any]: def _get_crew_status(self, crew_id: str) -> dict[str, Any]:
"""Get the status of a crew task """Get the status of a crew task
@@ -150,27 +155,27 @@ class InvokeCrewAIAutomationTool(BaseTool):
"""Execute the crew invocation tool.""" """Execute the crew invocation tool."""
if kwargs is None: if kwargs is None:
kwargs = {} kwargs = {}
# Start the crew # Start the crew
response = self._kickoff_crew(inputs=kwargs) response = self._kickoff_crew(inputs=kwargs)
if response.get("kickoff_id") is None: if response.get("kickoff_id") is None:
return f"Error: Failed to kickoff crew. Response: {response}" return f"Error: Failed to kickoff crew. Response: {response}"
kickoff_id = response.get("kickoff_id") kickoff_id = response.get("kickoff_id")
# Poll for completion # Poll for completion
for i in range(self.max_polling_time): for i in range(self.max_polling_time):
try: try:
status_response = self._get_crew_status(crew_id=kickoff_id) status_response = self._get_crew_status(crew_id=kickoff_id)
if status_response.get("state", "").lower() == "success": if status_response.get("state", "").lower() == "success":
return status_response.get("result", "No result returned") return status_response.get("result", "No result returned")
elif status_response.get("state", "").lower() == "failed": if status_response.get("state", "").lower() == "failed":
return f"Error: Crew task failed. Response: {status_response}" return f"Error: Crew task failed. Response: {status_response}"
except Exception as e: except Exception as e:
if i == self.max_polling_time - 1: # Last attempt if i == self.max_polling_time - 1: # Last attempt
return f"Error: Failed to get crew status after {self.max_polling_time} attempts. Last error: {e}" return f"Error: Failed to get crew status after {self.max_polling_time} attempts. Last error: {e}"
time.sleep(1) time.sleep(1)
return f"Error: Crew did not complete within {self.max_polling_time} seconds" return f"Error: Crew did not complete within {self.max_polling_time} seconds"

View File

@@ -1,5 +1,3 @@
from typing import Optional, Type
import requests import requests
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -14,16 +12,16 @@ class JinaScrapeWebsiteToolInput(BaseModel):
class JinaScrapeWebsiteTool(BaseTool): class JinaScrapeWebsiteTool(BaseTool):
name: str = "JinaScrapeWebsiteTool" name: str = "JinaScrapeWebsiteTool"
description: str = "A tool that can be used to read a website content using Jina.ai reader and return markdown content." description: str = "A tool that can be used to read a website content using Jina.ai reader and return markdown content."
args_schema: Type[BaseModel] = JinaScrapeWebsiteToolInput args_schema: type[BaseModel] = JinaScrapeWebsiteToolInput
website_url: Optional[str] = None website_url: str | None = None
api_key: Optional[str] = None api_key: str | None = None
headers: dict = {} headers: dict = {}
def __init__( def __init__(
self, self,
website_url: Optional[str] = None, website_url: str | None = None,
api_key: Optional[str] = None, api_key: str | None = None,
custom_headers: Optional[dict] = None, custom_headers: dict | None = None,
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -38,7 +36,7 @@ class JinaScrapeWebsiteTool(BaseTool):
if api_key is not None: if api_key is not None:
self.headers["Authorization"] = f"Bearer {api_key}" self.headers["Authorization"] = f"Bearer {api_key}"
def _run(self, website_url: Optional[str] = None) -> str: def _run(self, website_url: str | None = None) -> str:
url = website_url or self.website_url url = website_url or self.website_url
if not url: if not url:
raise ValueError( raise ValueError(
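A usage sketch for the retyped JinaScrapeWebsiteTool; the API key is optional and only adds an Authorization header when present (placeholder values, assuming the top-level export):

    # Sketch: scrape a page as markdown via the Jina.ai reader.
    from crewai_tools import JinaScrapeWebsiteTool

    tool = JinaScrapeWebsiteTool(
        website_url="https://example.com",  # placeholder URL
        api_key=None,                       # optional; sets a Bearer header when provided
    )
    markdown = tool._run()  # falls back to the constructor's website_url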

View File

@@ -1,5 +1,3 @@
from typing import Optional, Type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
@@ -27,9 +25,9 @@ class JSONSearchTool(RagTool):
description: str = ( description: str = (
"A tool that can be used to semantic search a query from a JSON's content." "A tool that can be used to semantic search a query from a JSON's content."
) )
args_schema: Type[BaseModel] = JSONSearchToolSchema args_schema: type[BaseModel] = JSONSearchToolSchema
def __init__(self, json_path: Optional[str] = None, **kwargs): def __init__(self, json_path: str | None = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if json_path is not None: if json_path is not None:
self.add(json_path) self.add(json_path)
@@ -40,8 +38,12 @@ class JSONSearchTool(RagTool):
def _run( def _run(
self, self,
search_query: str, search_query: str,
json_path: Optional[str] = None, json_path: str | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str: ) -> str:
if json_path is not None: if json_path is not None:
self.add(json_path) self.add(json_path)
return super()._run(query=search_query) return super()._run(
query=search_query, similarity_threshold=similarity_threshold, limit=limit
)
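Like the other RAG-backed tools in this merge, JSONSearchTool._run now forwards similarity_threshold and limit to RagTool._run. A short sketch with a placeholder path:

    # Sketch: query a JSON file with the new retrieval parameters.
    from crewai_tools import JSONSearchTool

    tool = JSONSearchTool(json_path="data/config.json")  # placeholder path
    print(tool._run(search_query="api endpoints", limit=3))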

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, List from typing import Any
from crewai.tools import BaseTool, EnvVar from crewai.tools import BaseTool, EnvVar
@@ -20,12 +20,8 @@ class LinkupSearchTool(BaseTool):
"Performs an API call to Linkup to retrieve contextual information." "Performs an API call to Linkup to retrieve contextual information."
) )
_client: LinkupClient = PrivateAttr() # type: ignore _client: LinkupClient = PrivateAttr() # type: ignore
description: str = ( package_dependencies: list[str] = ["linkup-sdk"]
"Performs an API call to Linkup to retrieve contextual information." env_vars: list[EnvVar] = [
)
_client: LinkupClient = PrivateAttr() # type: ignore
package_dependencies: List[str] = ["linkup-sdk"]
env_vars: List[EnvVar] = [
EnvVar(name="LINKUP_API_KEY", description="API key for Linkup", required=True), EnvVar(name="LINKUP_API_KEY", description="API key for Linkup", required=True),
] ]
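The hunk above drops a duplicated description/_client block, leaving a single package_dependencies and env_vars declaration. A sketch of that shared metadata pattern on a hypothetical custom tool, assuming the crewai.tools BaseTool/EnvVar interfaces shown throughout this diff:

    # Sketch: the EnvVar / package_dependencies pattern shared by these tools.
    from crewai.tools import BaseTool, EnvVar

    class MyLinkupTool(BaseTool):  # hypothetical example class
        name: str = "My Linkup tool"
        description: str = "Example of declaring env vars and dependencies."
        package_dependencies: list[str] = ["linkup-sdk"]
        env_vars: list[EnvVar] = [
            EnvVar(name="LINKUP_API_KEY", description="API key for Linkup", required=True),
        ]

        def _run(self) -> str:
            return "ok"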

Some files were not shown because too many files have changed in this diff.