Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-11 17:18:29 +00:00
feat: merge latest changes from crewAI-tools main into packages/tools
- Merged upstream changes from crewAI-tools main branch
- Resolved conflicts due to monorepo structure (crewai_tools -> src/crewai_tools)
- Removed deprecated embedchain adapters
- Added new RAG loaders and crewai_rag_adapter
- Consolidated dependencies in pyproject.toml

Fixed critical linting issues:
- Added ClassVar annotations for mutable class attributes
- Added timeouts to requests calls (30s default)
- Fixed exception handling with proper 'from' clauses
- Added noqa comments for public API functions (backward compatibility)
- Updated ruff config to ignore expected patterns:
  - F401 in __init__ files (intentional re-exports)
  - S101 in test files (assertions are expected)
  - S607 for subprocess calls (uv/pip commands are safe)

Remaining issues are from upstream code and will be addressed in separate PRs.
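The lint remediations listed above follow standard ruff fixes. A minimal illustrative sketch of the patterns (not code from this repository; the class and rule codes are only for illustration):

# Illustrative only -- shows the ClassVar, timeout, and `from` patterns named above.
from typing import Any, ClassVar

import requests


class ExampleTool:
    # Mutable class-level attribute annotated as ClassVar (ruff RUF012).
    package_dependencies: ClassVar[list[str]] = ["boto3"]

    def fetch(self, url: str) -> dict[str, Any]:
        # Network calls carry an explicit timeout (ruff S113).
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.json()

    def load_backend(self) -> None:
        try:
            import boto3  # noqa: F401
        except ImportError as e:
            # Re-raise with an explicit `from` clause (ruff B904).
            raise ImportError("`boto3` package not found") from e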
This commit is contained in:
@@ -9,12 +9,18 @@ authors = [
requires-python = ">=3.10,<3.14"
dependencies = [
    "crewai-core",
    "click>=8.1.8",
    "lancedb>=0.5.4",
    "pytube>=15.0.0",
    "requests>=2.31.0",
    "docker>=7.1.0",
    "tiktoken>=0.8.0",
    "stagehand>=0.4.1",
    "portalocker==2.7.0",
    "beautifulsoup4>=4.13.4",
    "pypdf>=5.9.0",
    "python-docx>=1.2.0",
    "youtube-transcript-api>=1.2.2",
]

[project.urls]
@@ -24,9 +30,6 @@ Documentation = "https://docs.crewai.com"


[project.optional-dependencies]
embedchain = [
    "embedchain>=0.1.114",
]
scrapfly-sdk = [
    "scrapfly-sdk>=0.8.19",
]
@@ -124,6 +127,12 @@ oxylabs = [
mongodb = [
    "pymongo>=4.13"
]
mysql = [
    "pymysql>=1.1.1"
]
postgresql = [
    "psycopg2-binary>=2.9.10"
]
bedrock = [
    "beautifulsoup4>=4.13.4",
    "bedrock-agentcore>=0.1.0",
@@ -135,6 +144,9 @@ contextual = [
    "nest-asyncio>=1.6.0",
]

[tool.hatch.metadata]
allow-direct-references = true

[tool.pytest.ini_options]
testpaths = ["tests"]

@@ -149,3 +161,12 @@ build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/crewai_tools"]

[dependency-groups]
dev = [
    "pytest-asyncio>=0.25.2",
    "pytest>=8.0.0",
    "pytest-recording>=0.13.3",
    "mypy>=1.18.1",
    "ruff>=0.13.0",
]
@@ -59,6 +59,7 @@ from .tools import (
    OxylabsAmazonSearchScraperTool,
    OxylabsGoogleSearchScraperTool,
    OxylabsUniversalScraperTool,
    ParallelSearchTool,
    PatronusEvalTool,
    PatronusLocalEvaluatorTool,
    PatronusPredefinedCriteriaEvalTool,
@@ -96,5 +97,4 @@ from .tools import (
    YoutubeChannelSearchTool,
    YoutubeVideoSearchTool,
    ZapierActionTools,
    ParallelSearchTool,
)
267 packages/tools/src/crewai_tools/adapters/crewai_rag_adapter.py Normal file
@@ -0,0 +1,267 @@
"""Adapter for CrewAI's native RAG system."""

import hashlib
from pathlib import Path
from typing import Any, TypeAlias, TypedDict

from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
from crewai_tools.tools.rag.rag_tool import Adapter
from pydantic import PrivateAttr
from typing_extensions import Unpack

ContentItem: TypeAlias = str | Path | dict[str, Any]


class AddDocumentParams(TypedDict, total=False):
    """Parameters for adding documents to the RAG system."""

    data_type: DataType
    metadata: dict[str, Any]
    website: str
    url: str
    file_path: str | Path
    github_url: str
    youtube_url: str
    directory_path: str | Path


class CrewAIRagAdapter(Adapter):
    """Adapter that uses CrewAI's native RAG system.

    Supports custom vector database configuration through the config parameter.
    """

    collection_name: str = "default"
    summarize: bool = False
    similarity_threshold: float = 0.6
    limit: int = 5
    config: RagConfigType | None = None
    _client: BaseClient | None = PrivateAttr(default=None)

    def model_post_init(self, __context: Any) -> None:
        """Initialize the CrewAI RAG client after model initialization."""
        if self.config is not None:
            self._client = create_client(self.config)
        else:
            self._client = get_rag_client()
        self._client.get_or_create_collection(collection_name=self.collection_name)

    def query(
        self,
        question: str,
        similarity_threshold: float | None = None,
        limit: int | None = None,
    ) -> str:
        """Query the knowledge base with a question.

        Args:
            question: The question to ask
            similarity_threshold: Minimum similarity score for results (default: 0.6)
            limit: Maximum number of results to return (default: 5)

        Returns:
            Relevant content from the knowledge base
        """
        search_limit = limit if limit is not None else self.limit
        search_threshold = (
            similarity_threshold
            if similarity_threshold is not None
            else self.similarity_threshold
        )

        results: list[SearchResult] = self._client.search(
            collection_name=self.collection_name,
            query=question,
            limit=search_limit,
            score_threshold=search_threshold,
        )

        if not results:
            return "No relevant content found."

        contents: list[str] = []
        for result in results:
            content: str = result.get("content", "")
            if content:
                contents.append(content)

        return "\n\n".join(contents)

    def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
        """Add content to the knowledge base.

        This method handles various input types and converts them to documents
        for the vector database. It supports the data_type parameter for
        compatibility with existing tools.

        Args:
            *args: Content items to add (strings, paths, or document dicts)
            **kwargs: Additional parameters including data_type, metadata, etc.
        """
        import os

        from crewai_tools.rag.base_loader import LoaderResult
        from crewai_tools.rag.data_types import DataType, DataTypes
        from crewai_tools.rag.source_content import SourceContent

        documents: list[BaseRecord] = []
        data_type: DataType | None = kwargs.get("data_type")
        base_metadata: dict[str, Any] = kwargs.get("metadata", {})

        for arg in args:
            source_ref: str
            if isinstance(arg, dict):
                source_ref = str(arg.get("source", arg.get("content", "")))
            else:
                source_ref = str(arg)

            if not data_type:
                data_type = DataTypes.from_content(source_ref)

            if data_type == DataType.DIRECTORY:
                if not os.path.isdir(source_ref):
                    raise ValueError(f"Directory does not exist: {source_ref}")

                # Define binary and non-text file extensions to skip
                binary_extensions = {
                    ".pyc",
                    ".pyo",
                    ".png",
                    ".jpg",
                    ".jpeg",
                    ".gif",
                    ".bmp",
                    ".ico",
                    ".svg",
                    ".webp",
                    ".pdf",
                    ".zip",
                    ".tar",
                    ".gz",
                    ".bz2",
                    ".7z",
                    ".rar",
                    ".exe",
                    ".dll",
                    ".so",
                    ".dylib",
                    ".bin",
                    ".dat",
                    ".db",
                    ".sqlite",
                    ".class",
                    ".jar",
                    ".war",
                    ".ear",
                }

                for root, dirs, files in os.walk(source_ref):
                    dirs[:] = [d for d in dirs if not d.startswith(".")]

                    for filename in files:
                        if filename.startswith("."):
                            continue

                        # Skip binary files based on extension
                        file_ext = os.path.splitext(filename)[1].lower()
                        if file_ext in binary_extensions:
                            continue

                        # Skip __pycache__ directories
                        if "__pycache__" in root:
                            continue

                        file_path: str = os.path.join(root, filename)
                        try:
                            file_data_type: DataType = DataTypes.from_content(file_path)
                            file_loader = file_data_type.get_loader()
                            file_chunker = file_data_type.get_chunker()

                            file_source = SourceContent(file_path)
                            file_result: LoaderResult = file_loader.load(file_source)

                            file_chunks = file_chunker.chunk(file_result.content)

                            for chunk_idx, file_chunk in enumerate(file_chunks):
                                file_metadata: dict[str, Any] = base_metadata.copy()
                                file_metadata.update(file_result.metadata)
                                file_metadata["data_type"] = str(file_data_type)
                                file_metadata["file_path"] = file_path
                                file_metadata["chunk_index"] = chunk_idx
                                file_metadata["total_chunks"] = len(file_chunks)

                                if isinstance(arg, dict):
                                    file_metadata.update(arg.get("metadata", {}))

                                chunk_id = hashlib.sha256(
                                    f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
                                ).hexdigest()

                                documents.append(
                                    {
                                        "doc_id": chunk_id,
                                        "content": file_chunk,
                                        "metadata": sanitize_metadata_for_chromadb(
                                            file_metadata
                                        ),
                                    }
                                )
                        except Exception:
                            # Silently skip files that can't be processed
                            continue
            else:
                metadata: dict[str, Any] = base_metadata.copy()

                if data_type in [
                    DataType.PDF_FILE,
                    DataType.TEXT_FILE,
                    DataType.DOCX,
                    DataType.CSV,
                    DataType.JSON,
                    DataType.XML,
                    DataType.MDX,
                ]:
                    if not os.path.isfile(source_ref):
                        raise FileNotFoundError(f"File does not exist: {source_ref}")

                loader = data_type.get_loader()
                chunker = data_type.get_chunker()

                source_content = SourceContent(source_ref)
                loader_result: LoaderResult = loader.load(source_content)

                chunks = chunker.chunk(loader_result.content)

                for i, chunk in enumerate(chunks):
                    chunk_metadata: dict[str, Any] = metadata.copy()
                    chunk_metadata.update(loader_result.metadata)
                    chunk_metadata["data_type"] = str(data_type)
                    chunk_metadata["chunk_index"] = i
                    chunk_metadata["total_chunks"] = len(chunks)
                    chunk_metadata["source"] = source_ref

                    if isinstance(arg, dict):
                        chunk_metadata.update(arg.get("metadata", {}))

                    chunk_id = hashlib.sha256(
                        f"{loader_result.doc_id}_{i}_{chunk}".encode()
                    ).hexdigest()

                    documents.append(
                        {
                            "doc_id": chunk_id,
                            "content": chunk,
                            "metadata": sanitize_metadata_for_chromadb(chunk_metadata),
                        }
                    )

        if documents:
            self._client.add_documents(
                collection_name=self.collection_name, documents=documents
            )
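Based on the adapter definition above, usage would look roughly like the following sketch. The file paths, metadata, and collection name are placeholders, and the default CrewAI RAG client configuration is assumed:

# Hypothetical usage of CrewAIRagAdapter; paths and names are placeholders.
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
from crewai_tools.rag.data_types import DataType

adapter = CrewAIRagAdapter(collection_name="project_docs")

# Index a single file (data type inferred) and a whole directory.
adapter.add("docs/architecture.md", metadata={"project": "demo"})
adapter.add("src/", data_type=DataType.DIRECTORY)

# Retrieve the most relevant chunks for a question.
print(adapter.query("How is the vector store configured?", limit=3))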
@@ -1,34 +0,0 @@
from typing import Any

from crewai_tools.tools.rag.rag_tool import Adapter

try:
    from embedchain import App
    EMBEDCHAIN_AVAILABLE = True
except ImportError:
    EMBEDCHAIN_AVAILABLE = False


class EmbedchainAdapter(Adapter):
    embedchain_app: Any  # Will be App when embedchain is available
    summarize: bool = False

    def __init__(self, **data):
        if not EMBEDCHAIN_AVAILABLE:
            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
        super().__init__(**data)

    def query(self, question: str) -> str:
        result, sources = self.embedchain_app.query(
            question, citations=True, dry_run=(not self.summarize)
        )
        if self.summarize:
            return result
        return "\n\n".join([source[0] for source in sources])

    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self.embedchain_app.add(*args, **kwargs)
@@ -1,11 +1,12 @@
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
import warnings
|
||||
from typing import List, Any, Dict, Literal, Optional, Union, get_origin, Type, cast
|
||||
from pydantic import Field, create_model
|
||||
from crewai.tools import BaseTool
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from typing import Any, Literal, Optional, Union, cast, get_origin
|
||||
|
||||
import requests
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import Field, create_model
|
||||
|
||||
|
||||
def get_enterprise_api_base_url() -> str:
|
||||
@@ -13,6 +14,7 @@ def get_enterprise_api_base_url() -> str:
|
||||
base_url = os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com")
|
||||
return f"{base_url}/crewai_plus/api/v1/integrations"
|
||||
|
||||
|
||||
ENTERPRISE_API_BASE_URL = get_enterprise_api_base_url()
|
||||
|
||||
|
||||
@@ -23,7 +25,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
default="", description="The enterprise action token"
|
||||
)
|
||||
action_name: str = Field(default="", description="The name of the action")
|
||||
action_schema: Dict[str, Any] = Field(
|
||||
action_schema: dict[str, Any] = Field(
|
||||
default={}, description="The schema of the action"
|
||||
)
|
||||
enterprise_api_base_url: str = Field(
|
||||
@@ -36,8 +38,8 @@ class EnterpriseActionTool(BaseTool):
|
||||
description: str,
|
||||
enterprise_action_token: str,
|
||||
action_name: str,
|
||||
action_schema: Dict[str, Any],
|
||||
enterprise_api_base_url: Optional[str] = None,
|
||||
action_schema: dict[str, Any],
|
||||
enterprise_api_base_url: str | None = None,
|
||||
):
|
||||
self._model_registry = {}
|
||||
self._base_name = self._sanitize_name(name)
|
||||
@@ -86,7 +88,9 @@ class EnterpriseActionTool(BaseTool):
|
||||
self.enterprise_action_token = enterprise_action_token
|
||||
self.action_name = action_name
|
||||
self.action_schema = action_schema
|
||||
self.enterprise_api_base_url = enterprise_api_base_url or get_enterprise_api_base_url()
|
||||
self.enterprise_api_base_url = (
|
||||
enterprise_api_base_url or get_enterprise_api_base_url()
|
||||
)
|
||||
|
||||
def _sanitize_name(self, name: str) -> str:
|
||||
"""Sanitize names to create proper Python class names."""
|
||||
@@ -95,8 +99,8 @@ class EnterpriseActionTool(BaseTool):
|
||||
return "".join(word.capitalize() for word in parts if word)
|
||||
|
||||
def _extract_schema_info(
|
||||
self, action_schema: Dict[str, Any]
|
||||
) -> tuple[Dict[str, Any], List[str]]:
|
||||
self, action_schema: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Extract schema properties and required fields from action schema."""
|
||||
schema_props = (
|
||||
action_schema.get("function", {})
|
||||
@@ -108,7 +112,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
)
|
||||
return schema_props, required
|
||||
|
||||
def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> Type[Any]:
|
||||
def _process_schema_type(self, schema: dict[str, Any], type_name: str) -> type[Any]:
|
||||
"""Process a JSON schema and return appropriate Python type."""
|
||||
if "anyOf" in schema:
|
||||
any_of_types = schema["anyOf"]
|
||||
@@ -118,7 +122,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
if non_null_types:
|
||||
base_type = self._process_schema_type(non_null_types[0], type_name)
|
||||
return Optional[base_type] if is_nullable else base_type
|
||||
return cast(Type[Any], Optional[str])
|
||||
return cast(type[Any], Optional[str])
|
||||
|
||||
if "oneOf" in schema:
|
||||
return self._process_schema_type(schema["oneOf"][0], type_name)
|
||||
@@ -137,14 +141,16 @@ class EnterpriseActionTool(BaseTool):
|
||||
if json_type == "array":
|
||||
items_schema = schema.get("items", {"type": "string"})
|
||||
item_type = self._process_schema_type(items_schema, f"{type_name}Item")
|
||||
return List[item_type]
|
||||
return list[item_type]
|
||||
|
||||
if json_type == "object":
|
||||
return self._create_nested_model(schema, type_name)
|
||||
|
||||
return self._map_json_type_to_python(json_type)
|
||||
|
||||
def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> Type[Any]:
|
||||
def _create_nested_model(
|
||||
self, schema: dict[str, Any], model_name: str
|
||||
) -> type[Any]:
|
||||
"""Create a nested Pydantic model for complex objects."""
|
||||
full_model_name = f"{self._base_name}{model_name}"
|
||||
|
||||
@@ -183,21 +189,19 @@ class EnterpriseActionTool(BaseTool):
|
||||
return dict
|
||||
|
||||
def _create_field_definition(
|
||||
self, field_type: Type[Any], is_required: bool, description: str
|
||||
self, field_type: type[Any], is_required: bool, description: str
|
||||
) -> tuple:
|
||||
"""Create Pydantic field definition based on type and requirement."""
|
||||
if is_required:
|
||||
return (field_type, Field(description=description))
|
||||
else:
|
||||
if get_origin(field_type) is Union:
|
||||
return (field_type, Field(default=None, description=description))
|
||||
else:
|
||||
return (
|
||||
Optional[field_type],
|
||||
Field(default=None, description=description),
|
||||
)
|
||||
if get_origin(field_type) is Union:
|
||||
return (field_type, Field(default=None, description=description))
|
||||
return (
|
||||
Optional[field_type],
|
||||
Field(default=None, description=description),
|
||||
)
|
||||
|
||||
def _map_json_type_to_python(self, json_type: str) -> Type[Any]:
|
||||
def _map_json_type_to_python(self, json_type: str) -> type[Any]:
|
||||
"""Map basic JSON schema types to Python types."""
|
||||
type_mapping = {
|
||||
"string": str,
|
||||
@@ -210,7 +214,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
}
|
||||
return type_mapping.get(json_type, str)
|
||||
|
||||
def _get_required_nullable_fields(self) -> List[str]:
|
||||
def _get_required_nullable_fields(self) -> list[str]:
|
||||
"""Get a list of required nullable fields from the action schema."""
|
||||
schema_props, required = self._extract_schema_info(self.action_schema)
|
||||
|
||||
@@ -222,7 +226,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
|
||||
return required_nullable_fields
|
||||
|
||||
def _is_nullable_type(self, schema: Dict[str, Any]) -> bool:
|
||||
def _is_nullable_type(self, schema: dict[str, Any]) -> bool:
|
||||
"""Check if a schema represents a nullable type."""
|
||||
if "anyOf" in schema:
|
||||
return any(t.get("type") == "null" for t in schema["anyOf"])
|
||||
@@ -242,8 +246,9 @@ class EnterpriseActionTool(BaseTool):
|
||||
if field_name not in cleaned_kwargs:
|
||||
cleaned_kwargs[field_name] = None
|
||||
|
||||
|
||||
api_url = f"{self.enterprise_api_base_url}/actions/{self.action_name}/execute"
|
||||
api_url = (
|
||||
f"{self.enterprise_api_base_url}/actions/{self.action_name}/execute"
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.enterprise_action_token}",
|
||||
"Content-Type": "application/json",
|
||||
@@ -262,7 +267,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
return json.dumps(data, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing action {self.action_name}: {str(e)}"
|
||||
return f"Error executing action {self.action_name}: {e!s}"
|
||||
|
||||
|
||||
class EnterpriseActionKitToolAdapter:
|
||||
@@ -271,15 +276,17 @@ class EnterpriseActionKitToolAdapter:
|
||||
def __init__(
|
||||
self,
|
||||
enterprise_action_token: str,
|
||||
enterprise_api_base_url: Optional[str] = None,
|
||||
enterprise_api_base_url: str | None = None,
|
||||
):
|
||||
"""Initialize the adapter with an enterprise action token."""
|
||||
self._set_enterprise_action_token(enterprise_action_token)
|
||||
self._actions_schema = {}
|
||||
self._tools = None
|
||||
self.enterprise_api_base_url = enterprise_api_base_url or get_enterprise_api_base_url()
|
||||
self.enterprise_api_base_url = (
|
||||
enterprise_api_base_url or get_enterprise_api_base_url()
|
||||
)
|
||||
|
||||
def tools(self) -> List[BaseTool]:
|
||||
def tools(self) -> list[BaseTool]:
|
||||
"""Get the list of tools created from enterprise actions."""
|
||||
if self._tools is None:
|
||||
self._fetch_actions()
|
||||
@@ -289,13 +296,10 @@ class EnterpriseActionKitToolAdapter:
|
||||
def _fetch_actions(self):
|
||||
"""Fetch available actions from the API."""
|
||||
try:
|
||||
|
||||
actions_url = f"{self.enterprise_api_base_url}/actions"
|
||||
headers = {"Authorization": f"Bearer {self.enterprise_action_token}"}
|
||||
|
||||
response = requests.get(
|
||||
actions_url, headers=headers, timeout=30
|
||||
)
|
||||
response = requests.get(actions_url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
raw_data = response.json()
|
||||
@@ -306,7 +310,7 @@ class EnterpriseActionKitToolAdapter:
|
||||
parsed_schema = {}
|
||||
action_categories = raw_data["actions"]
|
||||
|
||||
for integration_type, action_list in action_categories.items():
|
||||
for action_list in action_categories.values():
|
||||
if isinstance(action_list, list):
|
||||
for action in action_list:
|
||||
action_name = action.get("name")
|
||||
@@ -314,8 +318,10 @@ class EnterpriseActionKitToolAdapter:
|
||||
action_schema = {
|
||||
"function": {
|
||||
"name": action_name,
|
||||
"description": action.get("description", f"Execute {action_name}"),
|
||||
"parameters": action.get("parameters", {})
|
||||
"description": action.get(
|
||||
"description", f"Execute {action_name}"
|
||||
),
|
||||
"parameters": action.get("parameters", {}),
|
||||
}
|
||||
}
|
||||
parsed_schema[action_name] = action_schema
|
||||
@@ -329,8 +335,8 @@ class EnterpriseActionKitToolAdapter:
|
||||
traceback.print_exc()
|
||||
|
||||
def _generate_detailed_description(
|
||||
self, schema: Dict[str, Any], indent: int = 0
|
||||
) -> List[str]:
|
||||
self, schema: dict[str, Any], indent: int = 0
|
||||
) -> list[str]:
|
||||
"""Generate detailed description for nested schema structures."""
|
||||
descriptions = []
|
||||
indent_str = " " * indent
|
||||
@@ -407,15 +413,17 @@ class EnterpriseActionKitToolAdapter:
|
||||
|
||||
self._tools = tools
|
||||
|
||||
def _set_enterprise_action_token(self, enterprise_action_token: Optional[str]):
|
||||
def _set_enterprise_action_token(self, enterprise_action_token: str | None):
|
||||
if enterprise_action_token and not enterprise_action_token.startswith("PK_"):
|
||||
warnings.warn(
|
||||
"Legacy token detected, please consider using the new Enterprise Action Auth token. Check out our docs for more information https://docs.crewai.com/en/enterprise/features/integrations.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
token = enterprise_action_token or os.environ.get("CREWAI_ENTERPRISE_TOOLS_TOKEN")
|
||||
token = enterprise_action_token or os.environ.get(
|
||||
"CREWAI_ENTERPRISE_TOOLS_TOKEN"
|
||||
)
|
||||
|
||||
self.enterprise_action_token = token
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable
from typing import Any

from crewai_tools.tools.rag.rag_tool import Adapter
from lancedb import DBConnection as LanceDBConnection
from lancedb import connect as lancedb_connect
from lancedb.table import Table as LanceDBTable
from openai import Client as OpenAIClient
from pydantic import Field, PrivateAttr

from crewai_tools.tools.rag.rag_tool import Adapter


def _default_embedding_function():
    client = OpenAIClient()
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any

from crewai.tools import BaseTool
from crewai_tools.adapters.tool_collection import ToolCollection

"""
MCPServer for CrewAI.

@@ -103,8 +104,8 @@ class MCPServerAdapter:
            try:
                subprocess.run(["uv", "add", "mcp crewai-tools[mcp]"], check=True)

            except subprocess.CalledProcessError:
                raise ImportError("Failed to install mcp package")
            except subprocess.CalledProcessError as e:
                raise ImportError("Failed to install mcp package") from e
        else:
            raise ImportError(
                "`mcp` package not found, please run `uv add crewai-tools[mcp]`"
@@ -112,7 +113,9 @@ class MCPServerAdapter:

        try:
            self._serverparams = serverparams
            self._adapter = MCPAdapt(self._serverparams, CrewAIAdapter(), connect_timeout)
            self._adapter = MCPAdapt(
                self._serverparams, CrewAIAdapter(), connect_timeout
            )
            self.start()

        except Exception as e:
@@ -1,41 +0,0 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
try:
|
||||
from embedchain import App
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
|
||||
class PDFEmbedchainAdapter(Adapter):
|
||||
embedchain_app: Any # Will be App when embedchain is available
|
||||
summarize: bool = False
|
||||
src: Optional[str] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**data)
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
where = (
|
||||
{"app_id": self.embedchain_app.config.id, "source": self.src}
|
||||
if self.src
|
||||
else None
|
||||
)
|
||||
result, sources = self.embedchain_app.query(
|
||||
question, citations=True, dry_run=(not self.summarize), where=where
|
||||
)
|
||||
if self.summarize:
|
||||
return result
|
||||
return "\n\n".join([source[0] for source in sources])
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.src = args[0] if args else None
|
||||
self.embedchain_app.add(*args, **kwargs)
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai_tools.rag.core import RAG
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
@@ -8,26 +8,23 @@ class RAGAdapter(Adapter):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str = "crewai_knowledge_base",
|
||||
persist_directory: Optional[str] = None,
|
||||
persist_directory: str | None = None,
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
top_k: int = 5,
|
||||
embedding_api_key: Optional[str] = None,
|
||||
**embedding_kwargs
|
||||
embedding_api_key: str | None = None,
|
||||
**embedding_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Prepare embedding configuration
|
||||
embedding_config = {
|
||||
"api_key": embedding_api_key,
|
||||
**embedding_kwargs
|
||||
}
|
||||
embedding_config = {"api_key": embedding_api_key, **embedding_kwargs}
|
||||
|
||||
self._adapter = RAG(
|
||||
collection_name=collection_name,
|
||||
persist_directory=persist_directory,
|
||||
embedding_model=embedding_model,
|
||||
top_k=top_k,
|
||||
embedding_config=embedding_config
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
|
||||
@@ -1,7 +1,10 @@
from typing import List, Optional, Union, TypeVar, Generic, Dict, Callable
from collections.abc import Callable
from typing import Generic, TypeVar

from crewai.tools import BaseTool

T = TypeVar('T', bound=BaseTool)
T = TypeVar("T", bound=BaseTool)


class ToolCollection(list, Generic[T]):
    """
@@ -18,15 +21,15 @@ class ToolCollection(list, Generic[T]):
        search_tool = tools["search"]
    """

    def __init__(self, tools: Optional[List[T]] = None):
    def __init__(self, tools: list[T] | None = None):
        super().__init__(tools or [])
        self._name_cache: Dict[str, T] = {}
        self._name_cache: dict[str, T] = {}
        self._build_name_cache()

    def _build_name_cache(self) -> None:
        self._name_cache = {tool.name.lower(): tool for tool in self}

    def __getitem__(self, key: Union[int, str]) -> T:
    def __getitem__(self, key: int | str) -> T:
        if isinstance(key, str):
            return self._name_cache[key.lower()]
        return super().__getitem__(key)
@@ -35,7 +38,7 @@ class ToolCollection(list, Generic[T]):
        super().append(tool)
        self._name_cache[tool.name.lower()] = tool

    def extend(self, tools: List[T]) -> None:
    def extend(self, tools: list[T]) -> None:
        super().extend(tools)
        self._build_name_cache()

@@ -54,7 +57,7 @@ class ToolCollection(list, Generic[T]):
        del self._name_cache[tool.name.lower()]
        return tool

    def filter_by_names(self, names: Optional[List[str]] = None) -> "ToolCollection[T]":
    def filter_by_names(self, names: list[str] | None = None) -> "ToolCollection[T]":
        if names is None:
            return self

@@ -71,4 +74,4 @@ class ToolCollection(list, Generic[T]):

    def clear(self) -> None:
        super().clear()
        self._name_cache.clear()
        self._name_cache.clear()
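The collection above behaves like a list that also supports dict-style lookup by tool name. A minimal sketch of how it might be exercised, assuming a trivial BaseTool subclass (the tool name and behavior here are invented for illustration):

# Illustrative only; SearchTool is a made-up minimal BaseTool subclass.
from crewai.tools import BaseTool
from crewai_tools.adapters.tool_collection import ToolCollection


class SearchTool(BaseTool):
    name: str = "search"
    description: str = "Search the web"

    def _run(self, query: str) -> str:
        return f"results for {query}"


tools = ToolCollection([SearchTool()])
assert tools["search"] is tools[0]          # name lookup is case-insensitive
subset = tools.filter_by_names(["search"])  # new collection limited to these names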
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List
|
||||
import os
|
||||
|
||||
import requests
|
||||
from crewai.tools import BaseTool
|
||||
@@ -42,7 +41,7 @@ class ZapierActionTool(BaseTool):
|
||||
|
||||
execute_url = f"{ACTIONS_URL}/{self.action_id}/execute/"
|
||||
response = requests.request(
|
||||
"POST", execute_url, headers=headers, json=action_params
|
||||
"POST", execute_url, headers=headers, json=action_params, timeout=30
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
@@ -57,7 +56,7 @@ class ZapierActionsAdapter:
|
||||
|
||||
api_key: str
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or os.getenv("ZAPIER_API_KEY")
|
||||
if not self.api_key:
|
||||
logger.error("Zapier Actions API key is required")
|
||||
@@ -67,13 +66,12 @@ class ZapierActionsAdapter:
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
}
|
||||
response = requests.request("GET", ACTIONS_URL, headers=headers)
|
||||
response = requests.request("GET", ACTIONS_URL, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
return response_json
|
||||
return response.json()
|
||||
|
||||
def tools(self) -> List[BaseTool]:
|
||||
def tools(self) -> list[BaseTool]:
|
||||
"""Convert Zapier actions to BaseTool instances"""
|
||||
actions_response = self.get_zapier_actions()
|
||||
tools = []
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from .s3 import S3ReaderTool, S3WriterTool
|
||||
from .bedrock import (
|
||||
BedrockKBRetrieverTool,
|
||||
BedrockInvokeAgentTool,
|
||||
BedrockKBRetrieverTool,
|
||||
create_browser_toolkit,
|
||||
create_code_interpreter_toolkit,
|
||||
)
|
||||
from .s3 import S3ReaderTool, S3WriterTool
|
||||
|
||||
__all__ = [
|
||||
"BedrockInvokeAgentTool",
|
||||
"BedrockKBRetrieverTool",
|
||||
"S3ReaderTool",
|
||||
"S3WriterTool",
|
||||
"BedrockKBRetrieverTool",
|
||||
"BedrockInvokeAgentTool",
|
||||
"create_browser_toolkit",
|
||||
"create_code_interpreter_toolkit"
|
||||
"create_code_interpreter_toolkit",
|
||||
]
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from .knowledge_base.retriever_tool import BedrockKBRetrieverTool
|
||||
from .agents.invoke_agent_tool import BedrockInvokeAgentTool
|
||||
from .browser import create_browser_toolkit
|
||||
from .code_interpreter import create_code_interpreter_toolkit
|
||||
from .knowledge_base.retriever_tool import BedrockKBRetrieverTool
|
||||
|
||||
__all__ = [
|
||||
"BedrockKBRetrieverTool",
|
||||
"BedrockInvokeAgentTool",
|
||||
"BedrockKBRetrieverTool",
|
||||
"create_browser_toolkit",
|
||||
"create_code_interpreter_toolkit"
|
||||
"create_code_interpreter_toolkit",
|
||||
]
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from typing import Type, Optional, Dict, Any, List
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from dotenv import load_dotenv
|
||||
from typing import ClassVar
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..exceptions import BedrockAgentError, BedrockValidationError
|
||||
@@ -17,29 +16,30 @@ load_dotenv()
|
||||
|
||||
class BedrockInvokeAgentToolInput(BaseModel):
|
||||
"""Input schema for BedrockInvokeAgentTool."""
|
||||
|
||||
query: str = Field(..., description="The query to send to the agent")
|
||||
|
||||
|
||||
class BedrockInvokeAgentTool(BaseTool):
|
||||
name: str = "Bedrock Agent Invoke Tool"
|
||||
description: str = "An agent responsible for policy analysis."
|
||||
args_schema: Type[BaseModel] = BedrockInvokeAgentToolInput
|
||||
args_schema: type[BaseModel] = BedrockInvokeAgentToolInput
|
||||
agent_id: str = None
|
||||
agent_alias_id: str = None
|
||||
session_id: str = None
|
||||
enable_trace: bool = False
|
||||
end_session: bool = False
|
||||
package_dependencies: List[str] = ["boto3"]
|
||||
package_dependencies: ClassVar[list[str]] = ["boto3"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str = None,
|
||||
agent_alias_id: str = None,
|
||||
session_id: str = None,
|
||||
agent_id: str | None = None,
|
||||
agent_alias_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
enable_trace: bool = False,
|
||||
end_session: bool = False,
|
||||
description: Optional[str] = None,
|
||||
**kwargs
|
||||
description: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the BedrockInvokeAgentTool with agent configuration.
|
||||
|
||||
@@ -54,9 +54,11 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Get values from environment variables if not provided
|
||||
self.agent_id = agent_id or os.getenv('BEDROCK_AGENT_ID')
|
||||
self.agent_alias_id = agent_alias_id or os.getenv('BEDROCK_AGENT_ALIAS_ID')
|
||||
self.session_id = session_id or str(int(time.time())) # Use timestamp as session ID if not provided
|
||||
self.agent_id = agent_id or os.getenv("BEDROCK_AGENT_ID")
|
||||
self.agent_alias_id = agent_alias_id or os.getenv("BEDROCK_AGENT_ALIAS_ID")
|
||||
self.session_id = session_id or str(
|
||||
int(time.time())
|
||||
) # Use timestamp as session ID if not provided
|
||||
self.enable_trace = enable_trace
|
||||
self.end_session = end_session
|
||||
|
||||
@@ -87,20 +89,22 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
raise BedrockValidationError("session_id must be a string")
|
||||
|
||||
except BedrockValidationError as e:
|
||||
raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
|
||||
raise BedrockValidationError(f"Parameter validation failed: {e!s}") from e
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError:
|
||||
raise ImportError("`boto3` package not found, please run `uv add boto3`")
|
||||
except ImportError as e:
|
||||
raise ImportError("`boto3` package not found, please run `uv add boto3`") from e
|
||||
|
||||
try:
|
||||
# Initialize the Bedrock Agent Runtime client
|
||||
bedrock_agent = boto3.client(
|
||||
"bedrock-agent-runtime",
|
||||
region_name=os.getenv('AWS_REGION', os.getenv('AWS_DEFAULT_REGION', 'us-west-2'))
|
||||
region_name=os.getenv(
|
||||
"AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2")
|
||||
),
|
||||
)
|
||||
|
||||
# Format the prompt with current time
|
||||
@@ -119,28 +123,28 @@ Below is the users query or task. Complete it and answer it consicely and to the
|
||||
sessionId=self.session_id,
|
||||
inputText=prompt,
|
||||
enableTrace=self.enable_trace,
|
||||
endSession=self.end_session
|
||||
endSession=self.end_session,
|
||||
)
|
||||
|
||||
# Process the response
|
||||
completion = ""
|
||||
|
||||
# Check if response contains a completion field
|
||||
if 'completion' in response:
|
||||
if "completion" in response:
|
||||
# Process streaming response format
|
||||
for event in response.get('completion', []):
|
||||
if 'chunk' in event and 'bytes' in event['chunk']:
|
||||
chunk_bytes = event['chunk']['bytes']
|
||||
for event in response.get("completion", []):
|
||||
if "chunk" in event and "bytes" in event["chunk"]:
|
||||
chunk_bytes = event["chunk"]["bytes"]
|
||||
if isinstance(chunk_bytes, (bytes, bytearray)):
|
||||
completion += chunk_bytes.decode('utf-8')
|
||||
completion += chunk_bytes.decode("utf-8")
|
||||
else:
|
||||
completion += str(chunk_bytes)
|
||||
|
||||
# If no completion found in streaming format, try direct format
|
||||
if not completion and 'chunk' in response and 'bytes' in response['chunk']:
|
||||
chunk_bytes = response['chunk']['bytes']
|
||||
if not completion and "chunk" in response and "bytes" in response["chunk"]:
|
||||
chunk_bytes = response["chunk"]["bytes"]
|
||||
if isinstance(chunk_bytes, (bytes, bytearray)):
|
||||
completion = chunk_bytes.decode('utf-8')
|
||||
completion = chunk_bytes.decode("utf-8")
|
||||
else:
|
||||
completion = str(chunk_bytes)
|
||||
|
||||
@@ -148,14 +152,16 @@ Below is the users query or task. Complete it and answer it consicely and to the
|
||||
if not completion:
|
||||
debug_info = {
|
||||
"error": "Could not extract completion from response",
|
||||
"response_keys": list(response.keys())
|
||||
"response_keys": list(response.keys()),
|
||||
}
|
||||
|
||||
# Add more debug info
|
||||
if 'chunk' in response:
|
||||
debug_info["chunk_keys"] = list(response['chunk'].keys())
|
||||
if "chunk" in response:
|
||||
debug_info["chunk_keys"] = list(response["chunk"].keys())
|
||||
|
||||
raise BedrockAgentError(f"Failed to extract completion: {json.dumps(debug_info, indent=2)}")
|
||||
raise BedrockAgentError(
|
||||
f"Failed to extract completion: {json.dumps(debug_info, indent=2)}"
|
||||
)
|
||||
|
||||
return completion
|
||||
|
||||
@@ -164,13 +170,13 @@ Below is the users query or task. Complete it and answer it consicely and to the
|
||||
error_message = str(e)
|
||||
|
||||
# Try to extract error code if available
|
||||
if hasattr(e, 'response') and 'Error' in e.response:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', str(e))
|
||||
if hasattr(e, "response") and "Error" in e.response:
|
||||
error_code = e.response["Error"].get("Code", "Unknown")
|
||||
error_message = e.response["Error"].get("Message", str(e))
|
||||
|
||||
raise BedrockAgentError(f"Error ({error_code}): {error_message}")
|
||||
raise BedrockAgentError(f"Error ({error_code}): {error_message}") from e
|
||||
except BedrockAgentError:
|
||||
# Re-raise BedrockAgentError exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
raise BedrockAgentError(f"Unexpected error: {str(e)}")
|
||||
raise BedrockAgentError(f"Unexpected error: {e!s}") from e
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .browser_toolkit import BrowserToolkit, create_browser_toolkit
|
||||
|
||||
__all__ = ["BrowserToolkit", "create_browser_toolkit"]
|
||||
__all__ = ["BrowserToolkit", "create_browser_toolkit"]
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bedrock_agentcore.tools.browser_client import BrowserClient
|
||||
from playwright.async_api import Browser as AsyncBrowser
|
||||
from playwright.sync_api import Browser as SyncBrowser
|
||||
from bedrock_agentcore.tools.browser_client import BrowserClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,8 +28,8 @@ class BrowserSessionManager:
|
||||
region: AWS region for browser client
|
||||
"""
|
||||
self.region = region
|
||||
self._async_sessions: Dict[str, Tuple[BrowserClient, AsyncBrowser]] = {}
|
||||
self._sync_sessions: Dict[str, Tuple[BrowserClient, SyncBrowser]] = {}
|
||||
self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {}
|
||||
self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {}
|
||||
|
||||
async def get_async_browser(self, thread_id: str) -> AsyncBrowser:
|
||||
"""
|
||||
@@ -75,6 +75,7 @@ class BrowserSessionManager:
|
||||
Exception: If browser session creation fails
|
||||
"""
|
||||
from bedrock_agentcore.tools.browser_client import BrowserClient
|
||||
|
||||
browser_client = BrowserClient(region=self.region)
|
||||
|
||||
try:
|
||||
@@ -132,6 +133,7 @@ class BrowserSessionManager:
|
||||
Exception: If browser session creation fails
|
||||
"""
|
||||
from bedrock_agentcore.tools.browser_client import BrowserClient
|
||||
|
||||
browser_client = BrowserClient(region=self.region)
|
||||
|
||||
try:
|
||||
@@ -257,4 +259,4 @@ class BrowserSessionManager:
|
||||
for thread_id in sync_thread_ids:
|
||||
self.close_sync_browser(thread_id)
|
||||
|
||||
logger.info("All browser sessions closed")
|
||||
logger.info("All browser sessions closed")
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Toolkit for navigating web with AWS browser."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, List, Tuple, Any, Type
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -18,78 +18,100 @@ logger = logging.getLogger(__name__)
|
||||
# Input schemas
|
||||
class NavigateToolInput(BaseModel):
|
||||
"""Input for NavigateTool."""
|
||||
|
||||
url: str = Field(description="URL to navigate to")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the browser session")
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the browser session"
|
||||
)
|
||||
|
||||
|
||||
class ClickToolInput(BaseModel):
|
||||
"""Input for ClickTool."""
|
||||
|
||||
selector: str = Field(description="CSS selector for the element to click on")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the browser session")
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the browser session"
|
||||
)
|
||||
|
||||
|
||||
class GetElementsToolInput(BaseModel):
|
||||
"""Input for GetElementsTool."""
|
||||
|
||||
selector: str = Field(description="CSS selector for elements to get")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the browser session")
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the browser session"
|
||||
)
|
||||
|
||||
|
||||
class ExtractTextToolInput(BaseModel):
|
||||
"""Input for ExtractTextTool."""
|
||||
thread_id: str = Field(default="default", description="Thread ID for the browser session")
|
||||
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the browser session"
|
||||
)
|
||||
|
||||
|
||||
class ExtractHyperlinksToolInput(BaseModel):
|
||||
"""Input for ExtractHyperlinksTool."""
|
||||
thread_id: str = Field(default="default", description="Thread ID for the browser session")
|
||||
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the browser session"
|
||||
)
|
||||
|
||||
|
||||
class NavigateBackToolInput(BaseModel):
|
||||
"""Input for NavigateBackTool."""
|
||||
thread_id: str = Field(default="default", description="Thread ID for the browser session")
|
||||
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the browser session"
|
||||
)
|
||||
|
||||
|
||||
class CurrentWebPageToolInput(BaseModel):
|
||||
"""Input for CurrentWebPageTool."""
|
||||
thread_id: str = Field(default="default", description="Thread ID for the browser session")
|
||||
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the browser session"
|
||||
)
|
||||
|
||||
|
||||
# Base tool class
|
||||
class BrowserBaseTool(BaseTool):
|
||||
"""Base class for browser tools."""
|
||||
|
||||
|
||||
def __init__(self, session_manager: BrowserSessionManager):
|
||||
"""Initialize with a session manager."""
|
||||
super().__init__()
|
||||
self._session_manager = session_manager
|
||||
|
||||
if self._is_in_asyncio_loop() and hasattr(self, '_arun'):
|
||||
|
||||
if self._is_in_asyncio_loop() and hasattr(self, "_arun"):
|
||||
self._original_run = self._run
|
||||
|
||||
# Override _run to use _arun when in an asyncio loop
|
||||
def patched_run(*args, **kwargs):
|
||||
try:
|
||||
import nest_asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
nest_asyncio.apply(loop)
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
self._arun(*args, **kwargs)
|
||||
)
|
||||
except Exception as e:
|
||||
return f"Error in patched _run: {str(e)}"
|
||||
return f"Error in patched _run: {e!s}"
|
||||
|
||||
self._run = patched_run
|
||||
|
||||
|
||||
async def get_async_page(self, thread_id: str) -> Any:
|
||||
"""Get or create a page for the specified thread."""
|
||||
browser = await self._session_manager.get_async_browser(thread_id)
|
||||
page = await aget_current_page(browser)
|
||||
return page
|
||||
|
||||
return await aget_current_page(browser)
|
||||
|
||||
def get_sync_page(self, thread_id: str) -> Any:
|
||||
"""Get or create a page for the specified thread."""
|
||||
browser = self._session_manager.get_sync_browser(thread_id)
|
||||
page = get_current_page(browser)
|
||||
return page
|
||||
|
||||
return get_current_page(browser)
|
||||
|
||||
def _is_in_asyncio_loop(self) -> bool:
|
||||
"""Check if we're currently in an asyncio event loop."""
|
||||
try:
|
||||
@@ -105,8 +127,8 @@ class NavigateTool(BrowserBaseTool):
|
||||
|
||||
name: str = "navigate_browser"
|
||||
description: str = "Navigate a browser to the specified URL"
|
||||
args_schema: Type[BaseModel] = NavigateToolInput
|
||||
|
||||
args_schema: type[BaseModel] = NavigateToolInput
|
||||
|
||||
def _run(self, url: str, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
@@ -123,7 +145,7 @@ class NavigateTool(BrowserBaseTool):
|
||||
status = response.status if response else "unknown"
|
||||
return f"Navigating to {url} returned status code {status}"
|
||||
except Exception as e:
|
||||
return f"Error navigating to {url}: {str(e)}"
|
||||
return f"Error navigating to {url}: {e!s}"
|
||||
|
||||
async def _arun(self, url: str, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the async tool."""
|
||||
@@ -141,7 +163,7 @@ class NavigateTool(BrowserBaseTool):
|
||||
status = response.status if response else "unknown"
|
||||
return f"Navigating to {url} returned status code {status}"
|
||||
except Exception as e:
|
||||
return f"Error navigating to {url}: {str(e)}"
|
||||
return f"Error navigating to {url}: {e!s}"
|
||||
|
||||
|
||||
class ClickTool(BrowserBaseTool):
|
||||
@@ -149,8 +171,8 @@ class ClickTool(BrowserBaseTool):
|
||||
|
||||
name: str = "click_element"
|
||||
description: str = "Click on an element with the given CSS selector"
|
||||
args_schema: Type[BaseModel] = ClickToolInput
|
||||
|
||||
args_schema: type[BaseModel] = ClickToolInput
|
||||
|
||||
visible_only: bool = True
|
||||
"""Whether to consider only visible elements."""
|
||||
playwright_strict: bool = False
|
||||
@@ -162,7 +184,7 @@ class ClickTool(BrowserBaseTool):
|
||||
if not self.visible_only:
|
||||
return selector
|
||||
return f"{selector} >> visible=1"
|
||||
|
||||
|
||||
def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
@@ -172,7 +194,7 @@ class ClickTool(BrowserBaseTool):
|
||||
# Click on the element
|
||||
selector_effective = self._selector_effective(selector=selector)
|
||||
from playwright.sync_api import TimeoutError as PlaywrightTimeoutError
|
||||
|
||||
|
||||
try:
|
||||
page.click(
|
||||
selector_effective,
|
||||
@@ -182,11 +204,11 @@ class ClickTool(BrowserBaseTool):
|
||||
except PlaywrightTimeoutError:
|
||||
return f"Unable to click on element '{selector}'"
|
||||
except Exception as click_error:
|
||||
return f"Unable to click on element '{selector}': {str(click_error)}"
|
||||
|
||||
return f"Unable to click on element '{selector}': {click_error!s}"
|
||||
|
||||
return f"Clicked element '{selector}'"
|
||||
except Exception as e:
|
||||
return f"Error clicking on element: {str(e)}"
|
||||
return f"Error clicking on element: {e!s}"
|
||||
|
||||
async def _arun(self, selector: str, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the async tool."""
|
||||
@@ -197,7 +219,7 @@ class ClickTool(BrowserBaseTool):
|
||||
# Click on the element
|
||||
selector_effective = self._selector_effective(selector=selector)
|
||||
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
||||
|
||||
|
||||
try:
|
||||
await page.click(
|
||||
selector_effective,
|
||||
@@ -207,19 +229,20 @@ class ClickTool(BrowserBaseTool):
|
||||
except PlaywrightTimeoutError:
|
||||
return f"Unable to click on element '{selector}'"
|
||||
except Exception as click_error:
|
||||
return f"Unable to click on element '{selector}': {str(click_error)}"
|
||||
|
||||
return f"Unable to click on element '{selector}': {click_error!s}"
|
||||
|
||||
return f"Clicked element '{selector}'"
|
||||
except Exception as e:
|
||||
return f"Error clicking on element: {str(e)}"
|
||||
return f"Error clicking on element: {e!s}"
|
||||
|
||||
|
||||
class NavigateBackTool(BrowserBaseTool):
|
||||
"""Tool for navigating back in browser history."""
|
||||
|
||||
name: str = "navigate_back"
|
||||
description: str = "Navigate back to the previous page"
|
||||
args_schema: Type[BaseModel] = NavigateBackToolInput
|
||||
|
||||
args_schema: type[BaseModel] = NavigateBackToolInput
|
||||
|
||||
def _run(self, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
@@ -231,9 +254,9 @@ class NavigateBackTool(BrowserBaseTool):
|
||||
page.go_back()
|
||||
return "Navigated back to the previous page"
|
||||
except Exception as nav_error:
|
||||
return f"Unable to navigate back: {str(nav_error)}"
|
||||
return f"Unable to navigate back: {nav_error!s}"
|
||||
except Exception as e:
|
||||
return f"Error navigating back: {str(e)}"
|
||||
return f"Error navigating back: {e!s}"
|
||||
|
||||
async def _arun(self, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the async tool."""
|
||||
@@ -246,17 +269,18 @@ class NavigateBackTool(BrowserBaseTool):
|
||||
await page.go_back()
|
||||
return "Navigated back to the previous page"
|
||||
except Exception as nav_error:
|
||||
return f"Unable to navigate back: {str(nav_error)}"
|
||||
return f"Unable to navigate back: {nav_error!s}"
|
||||
except Exception as e:
|
||||
return f"Error navigating back: {str(e)}"
|
||||
return f"Error navigating back: {e!s}"
|
||||
|
||||
|
||||
class ExtractTextTool(BrowserBaseTool):
|
||||
"""Tool for extracting text from a webpage."""
|
||||
|
||||
name: str = "extract_text"
|
||||
description: str = "Extract all the text on the current webpage"
|
||||
args_schema: Type[BaseModel] = ExtractTextToolInput
|
||||
|
||||
args_schema: type[BaseModel] = ExtractTextToolInput
|
||||
|
||||
def _run(self, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
@@ -268,7 +292,7 @@ class ExtractTextTool(BrowserBaseTool):
|
||||
"The 'beautifulsoup4' package is required to use this tool."
|
||||
" Please install it with 'pip install beautifulsoup4'."
|
||||
)
|
||||
|
||||
|
||||
# Get the current page
|
||||
page = self.get_sync_page(thread_id)
|
||||
|
||||
@@ -277,7 +301,7 @@ class ExtractTextTool(BrowserBaseTool):
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
return soup.get_text(separator="\n").strip()
|
||||
except Exception as e:
|
||||
return f"Error extracting text: {str(e)}"
|
||||
return f"Error extracting text: {e!s}"
|
||||
|
||||
async def _arun(self, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the async tool."""
|
||||
@@ -290,7 +314,7 @@ class ExtractTextTool(BrowserBaseTool):
|
||||
"The 'beautifulsoup4' package is required to use this tool."
|
||||
" Please install it with 'pip install beautifulsoup4'."
|
||||
)
|
||||
|
||||
|
||||
# Get the current page
|
||||
page = await self.get_async_page(thread_id)
|
||||
|
||||
@@ -299,15 +323,16 @@ class ExtractTextTool(BrowserBaseTool):
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
return soup.get_text(separator="\n").strip()
|
||||
except Exception as e:
|
||||
return f"Error extracting text: {str(e)}"
|
||||
return f"Error extracting text: {e!s}"
|
||||
|
||||
|
||||
class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
"""Tool for extracting hyperlinks from a webpage."""
|
||||
|
||||
name: str = "extract_hyperlinks"
|
||||
description: str = "Extract all hyperlinks on the current webpage"
|
||||
args_schema: Type[BaseModel] = ExtractHyperlinksToolInput
|
||||
|
||||
args_schema: type[BaseModel] = ExtractHyperlinksToolInput
|
||||
|
||||
def _run(self, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the sync tool."""
|
||||
try:
|
||||
@@ -319,7 +344,7 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
"The 'beautifulsoup4' package is required to use this tool."
|
||||
" Please install it with 'pip install beautifulsoup4'."
|
||||
)
|
||||
|
||||
|
||||
# Get the current page
|
||||
page = self.get_sync_page(thread_id)
|
||||
|
||||
@@ -330,15 +355,15 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
for link in soup.find_all("a", href=True):
|
||||
text = link.get_text().strip()
|
||||
href = link["href"]
|
||||
if href.startswith("http") or href.startswith("https"):
|
||||
if href.startswith(("http", "https")):
|
||||
links.append({"text": text, "url": href})
|
||||
|
||||
|
||||
if not links:
|
||||
return "No hyperlinks found on the current page."
|
||||
|
||||
|
||||
return json.dumps(links, indent=2)
|
||||
except Exception as e:
|
||||
return f"Error extracting hyperlinks: {str(e)}"
|
||||
return f"Error extracting hyperlinks: {e!s}"
|
||||
|
||||
async def _arun(self, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the async tool."""
|
||||
@@ -351,7 +376,7 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
"The 'beautifulsoup4' package is required to use this tool."
|
||||
" Please install it with 'pip install beautifulsoup4'."
|
||||
)
|
||||
|
||||
|
||||
# Get the current page
|
||||
page = await self.get_async_page(thread_id)
|
||||
|
||||
@@ -362,23 +387,24 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
for link in soup.find_all("a", href=True):
|
||||
text = link.get_text().strip()
|
||||
href = link["href"]
|
||||
if href.startswith("http") or href.startswith("https"):
|
||||
if href.startswith(("http", "https")):
|
||||
links.append({"text": text, "url": href})
|
||||
|
||||
|
||||
if not links:
|
||||
return "No hyperlinks found on the current page."
|
||||
|
||||
|
||||
return json.dumps(links, indent=2)
|
||||
except Exception as e:
|
||||
return f"Error extracting hyperlinks: {str(e)}"
|
||||
return f"Error extracting hyperlinks: {e!s}"
|
||||
|
||||
|
||||
class GetElementsTool(BrowserBaseTool):
"""Tool for getting elements from a webpage."""

name: str = "get_elements"
description: str = "Get elements from the webpage using a CSS selector"
args_schema: Type[BaseModel] = GetElementsToolInput
args_schema: type[BaseModel] = GetElementsToolInput

def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str:
"""Use the sync tool."""
try:
@@ -389,15 +415,15 @@ class GetElementsTool(BrowserBaseTool):
elements = page.query_selector_all(selector)
if not elements:
return f"No elements found with selector '{selector}'"

elements_text = []
for i, element in enumerate(elements):
text = element.text_content()
elements_text.append(f"Element {i+1}: {text.strip()}")
elements_text.append(f"Element {i + 1}: {text.strip()}")

return "\n".join(elements_text)
except Exception as e:
return f"Error getting elements: {str(e)}"
return f"Error getting elements: {e!s}"

async def _arun(self, selector: str, thread_id: str = "default", **kwargs) -> str:
"""Use the async tool."""
@@ -409,23 +435,24 @@ class GetElementsTool(BrowserBaseTool):
elements = await page.query_selector_all(selector)
if not elements:
return f"No elements found with selector '{selector}'"

elements_text = []
for i, element in enumerate(elements):
text = await element.text_content()
elements_text.append(f"Element {i+1}: {text.strip()}")
elements_text.append(f"Element {i + 1}: {text.strip()}")

return "\n".join(elements_text)
except Exception as e:
return f"Error getting elements: {str(e)}"
return f"Error getting elements: {e!s}"

class CurrentWebPageTool(BrowserBaseTool):
"""Tool for getting information about the current webpage."""

name: str = "current_webpage"
description: str = "Get information about the current webpage"
args_schema: Type[BaseModel] = CurrentWebPageToolInput
args_schema: type[BaseModel] = CurrentWebPageToolInput

def _run(self, thread_id: str = "default", **kwargs) -> str:
"""Use the sync tool."""
try:
@@ -437,7 +464,7 @@ class CurrentWebPageTool(BrowserBaseTool):
title = page.title()
return f"URL: {url}\nTitle: {title}"
except Exception as e:
return f"Error getting current webpage info: {str(e)}"
return f"Error getting current webpage info: {e!s}"

async def _arun(self, thread_id: str = "default", **kwargs) -> str:
"""Use the async tool."""
@@ -450,7 +477,7 @@ class CurrentWebPageTool(BrowserBaseTool):
title = await page.title()
return f"URL: {url}\nTitle: {title}"
except Exception as e:
return f"Error getting current webpage info: {str(e)}"
return f"Error getting current webpage info: {e!s}"

class BrowserToolkit:
@@ -504,10 +531,10 @@ class BrowserToolkit:
"""
self.region = region
self.session_manager = BrowserSessionManager(region=region)
self.tools: List[BaseTool] = []
self.tools: list[BaseTool] = []
self._nest_current_loop()
self._setup_tools()

def _nest_current_loop(self):
"""Apply nest_asyncio if we're in an asyncio loop."""
try:
@@ -515,9 +542,10 @@ class BrowserToolkit:
if loop.is_running():
try:
import nest_asyncio

nest_asyncio.apply(loop)
except Exception as e:
logger.warning(f"Failed to apply nest_asyncio: {str(e)}")
logger.warning(f"Failed to apply nest_asyncio: {e!s}")
except RuntimeError:
pass

@@ -530,10 +558,10 @@ class BrowserToolkit:
ExtractTextTool(session_manager=self.session_manager),
ExtractHyperlinksTool(session_manager=self.session_manager),
GetElementsTool(session_manager=self.session_manager),
CurrentWebPageTool(session_manager=self.session_manager)
CurrentWebPageTool(session_manager=self.session_manager),
]

def get_tools(self) -> List[BaseTool]:
def get_tools(self) -> list[BaseTool]:
"""
Get the list of browser tools

@@ -542,7 +570,7 @@ class BrowserToolkit:
"""
return self.tools

def get_tools_by_name(self) -> Dict[str, BaseTool]:
def get_tools_by_name(self) -> dict[str, BaseTool]:
"""
Get a dictionary of tools mapped by their names

@@ -555,11 +583,11 @@ class BrowserToolkit:
"""Clean up all browser sessions asynchronously"""
await self.session_manager.close_all_browsers()
logger.info("All browser sessions cleaned up")

def sync_cleanup(self) -> None:
"""Clean up all browser sessions from synchronous code"""
import asyncio

try:
loop = asyncio.get_event_loop()
if loop.is_running():
@@ -572,7 +600,7 @@ class BrowserToolkit:

def create_browser_toolkit(
region: str = "us-west-2",
) -> Tuple[BrowserToolkit, List[BaseTool]]:
) -> tuple[BrowserToolkit, list[BaseTool]]:
"""
Create a BrowserToolkit

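As a side note, here is a minimal usage sketch for the factory shown in this hunk; the import path and the Agent wiring are assumptions for illustration and are not part of this commit:

    from crewai import Agent
    from crewai_tools.aws.bedrock.browser import create_browser_toolkit  # assumed import path

    # The factory returns both the toolkit (for lifecycle management) and its tools.
    toolkit, tools = create_browser_toolkit(region="us-west-2")

    researcher = Agent(
        role="Web researcher",
        goal="Collect page text and hyperlinks",
        backstory="Browses with the Bedrock-managed browser tools.",
        tools=tools,
    )

    # ... run the crew with this agent ...

    # Release all browser sessions from synchronous code when finished.
    toolkit.sync_cleanup()
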
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from playwright.async_api import Browser as AsyncBrowser
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
from playwright.sync_api import Page as SyncPage


async def aget_current_page(browser: Union[AsyncBrowser, Any]) -> AsyncPage:
async def aget_current_page(browser: AsyncBrowser | Any) -> AsyncPage:
"""
Asynchronously get the current page of the browser.
Args:
@@ -26,7 +26,7 @@ async def aget_current_page(browser: Union[AsyncBrowser, Any]) -> AsyncPage:
return context.pages[-1]


def get_current_page(browser: Union[SyncBrowser, Any]) -> SyncPage:
def get_current_page(browser: SyncBrowser | Any) -> SyncPage:
"""
Get the current page of the browser.
Args:
@@ -40,4 +40,4 @@ def get_current_page(browser: Union[SyncBrowser, Any]) -> SyncPage:
context = browser.contexts[0]
if not context.pages:
return context.new_page()
return context.pages[-1]
return context.pages[-1]

@@ -1,3 +1,6 @@
from .code_interpreter_toolkit import CodeInterpreterToolkit, create_code_interpreter_toolkit
from .code_interpreter_toolkit import (
CodeInterpreterToolkit,
create_code_interpreter_toolkit,
)

__all__ = ["CodeInterpreterToolkit", "create_code_interpreter_toolkit"]
__all__ = ["CodeInterpreterToolkit", "create_code_interpreter_toolkit"]

@@ -1,9 +1,10 @@
|
||||
"""Toolkit for working with AWS Bedrock Code Interpreter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Optional, Type, Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -39,124 +40,184 @@ def extract_output_from_stream(response):
|
||||
output.append(f"==== File: {file_path} ====\n{file_content}\n")
|
||||
else:
|
||||
output.append(json.dumps(resource))
|
||||
|
||||
|
||||
return "\n".join(output)
|
||||
|
||||
|
||||
# Input schemas
|
||||
class ExecuteCodeInput(BaseModel):
|
||||
"""Input for ExecuteCode."""
|
||||
|
||||
code: str = Field(description="The code to execute")
|
||||
language: str = Field(default="python", description="The programming language of the code")
|
||||
clear_context: bool = Field(default=False, description="Whether to clear execution context")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
|
||||
language: str = Field(
|
||||
default="python", description="The programming language of the code"
|
||||
)
|
||||
clear_context: bool = Field(
|
||||
default=False, description="Whether to clear execution context"
|
||||
)
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the code interpreter session"
|
||||
)
|
||||
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
"""Input for ExecuteCommand."""
|
||||
|
||||
command: str = Field(description="The command to execute")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the code interpreter session"
|
||||
)
|
||||
|
||||
|
||||
class ReadFilesInput(BaseModel):
|
||||
"""Input for ReadFiles."""
|
||||
paths: List[str] = Field(description="List of file paths to read")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
|
||||
|
||||
paths: list[str] = Field(description="List of file paths to read")
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the code interpreter session"
|
||||
)
|
||||
|
||||
|
||||
class ListFilesInput(BaseModel):
|
||||
"""Input for ListFiles."""
|
||||
|
||||
directory_path: str = Field(default="", description="Path to the directory to list")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the code interpreter session"
|
||||
)
|
||||
|
||||
|
||||
class DeleteFilesInput(BaseModel):
|
||||
"""Input for DeleteFiles."""
|
||||
paths: List[str] = Field(description="List of file paths to delete")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
|
||||
|
||||
paths: list[str] = Field(description="List of file paths to delete")
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the code interpreter session"
|
||||
)
|
||||
|
||||
|
||||
class WriteFilesInput(BaseModel):
|
||||
"""Input for WriteFiles."""
|
||||
files: List[Dict[str, str]] = Field(description="List of dictionaries with path and text fields")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
|
||||
|
||||
files: list[dict[str, str]] = Field(
|
||||
description="List of dictionaries with path and text fields"
|
||||
)
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the code interpreter session"
|
||||
)
|
||||
|
||||
|
||||
class StartCommandInput(BaseModel):
|
||||
"""Input for StartCommand."""
|
||||
|
||||
command: str = Field(description="The command to execute asynchronously")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the code interpreter session"
|
||||
)
|
||||
|
||||
|
||||
class GetTaskInput(BaseModel):
|
||||
"""Input for GetTask."""
|
||||
|
||||
task_id: str = Field(description="The ID of the task to check")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the code interpreter session"
|
||||
)
|
||||
|
||||
|
||||
class StopTaskInput(BaseModel):
|
||||
"""Input for StopTask."""
|
||||
|
||||
task_id: str = Field(description="The ID of the task to stop")
|
||||
thread_id: str = Field(default="default", description="Thread ID for the code interpreter session")
|
||||
thread_id: str = Field(
|
||||
default="default", description="Thread ID for the code interpreter session"
|
||||
)
|
||||
|
||||
|
||||
# Tool classes
|
||||
class ExecuteCodeTool(BaseTool):
|
||||
"""Tool for executing code in various languages."""
|
||||
|
||||
name: str = "execute_code"
|
||||
description: str = "Execute code in various languages (primarily Python)"
|
||||
args_schema: Type[BaseModel] = ExecuteCodeInput
|
||||
args_schema: type[BaseModel] = ExecuteCodeInput
|
||||
toolkit: Any = Field(default=None, exclude=True)
|
||||
|
||||
|
||||
def __init__(self, toolkit):
|
||||
super().__init__()
|
||||
self.toolkit = toolkit
|
||||
|
||||
def _run(self, code: str, language: str = "python", clear_context: bool = False, thread_id: str = "default") -> str:
|
||||
|
||||
def _run(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
clear_context: bool = False,
|
||||
thread_id: str = "default",
|
||||
) -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
|
||||
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Execute code
|
||||
response = code_interpreter.invoke(
|
||||
method="executeCode",
|
||||
params={"code": code, "language": language, "clearContext": clear_context},
|
||||
params={
|
||||
"code": code,
|
||||
"language": language,
|
||||
"clearContext": clear_context,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
return extract_output_from_stream(response)
|
||||
except Exception as e:
|
||||
return f"Error executing code: {str(e)}"
|
||||
|
||||
async def _arun(self, code: str, language: str = "python", clear_context: bool = False, thread_id: str = "default") -> str:
|
||||
return f"Error executing code: {e!s}"
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
clear_context: bool = False,
|
||||
thread_id: str = "default",
|
||||
) -> str:
|
||||
# Use _run as we're working with a synchronous API that's thread-safe
|
||||
return self._run(code=code, language=language, clear_context=clear_context, thread_id=thread_id)
|
||||
return self._run(
|
||||
code=code,
|
||||
language=language,
|
||||
clear_context=clear_context,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
|
||||
class ExecuteCommandTool(BaseTool):
|
||||
"""Tool for running shell commands in the code interpreter environment."""
|
||||
|
||||
name: str = "execute_command"
|
||||
description: str = "Run shell commands in the code interpreter environment"
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
args_schema: type[BaseModel] = ExecuteCommandInput
|
||||
toolkit: Any = Field(default=None, exclude=True)
|
||||
|
||||
|
||||
def __init__(self, toolkit):
|
||||
super().__init__()
|
||||
self.toolkit = toolkit
|
||||
|
||||
|
||||
def _run(self, command: str, thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
|
||||
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Execute command
|
||||
response = code_interpreter.invoke(
|
||||
method="executeCommand", params={"command": command}
|
||||
)
|
||||
|
||||
|
||||
return extract_output_from_stream(response)
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
|
||||
return f"Error executing command: {e!s}"
|
||||
|
||||
async def _arun(self, command: str, thread_id: str = "default") -> str:
|
||||
# Use _run as we're working with a synchronous API that's thread-safe
|
||||
return self._run(command=command, thread_id=thread_id)
|
||||
@@ -164,57 +225,65 @@ class ExecuteCommandTool(BaseTool):
|
||||
|
||||
class ReadFilesTool(BaseTool):
|
||||
"""Tool for reading content of files in the environment."""
|
||||
|
||||
name: str = "read_files"
|
||||
description: str = "Read content of files in the environment"
|
||||
args_schema: Type[BaseModel] = ReadFilesInput
|
||||
args_schema: type[BaseModel] = ReadFilesInput
|
||||
toolkit: Any = Field(default=None, exclude=True)
|
||||
|
||||
|
||||
def __init__(self, toolkit):
|
||||
super().__init__()
|
||||
self.toolkit = toolkit
|
||||
|
||||
def _run(self, paths: List[str], thread_id: str = "default") -> str:
|
||||
|
||||
def _run(self, paths: list[str], thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
|
||||
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Read files
|
||||
response = code_interpreter.invoke(method="readFiles", params={"paths": paths})
|
||||
|
||||
response = code_interpreter.invoke(
|
||||
method="readFiles", params={"paths": paths}
|
||||
)
|
||||
|
||||
return extract_output_from_stream(response)
|
||||
except Exception as e:
|
||||
return f"Error reading files: {str(e)}"
|
||||
|
||||
async def _arun(self, paths: List[str], thread_id: str = "default") -> str:
|
||||
return f"Error reading files: {e!s}"
|
||||
|
||||
async def _arun(self, paths: list[str], thread_id: str = "default") -> str:
|
||||
# Use _run as we're working with a synchronous API that's thread-safe
|
||||
return self._run(paths=paths, thread_id=thread_id)
|
||||
|
||||
|
||||
class ListFilesTool(BaseTool):
|
||||
"""Tool for listing files in directories in the environment."""
|
||||
|
||||
name: str = "list_files"
|
||||
description: str = "List files in directories in the environment"
|
||||
args_schema: Type[BaseModel] = ListFilesInput
|
||||
args_schema: type[BaseModel] = ListFilesInput
|
||||
toolkit: Any = Field(default=None, exclude=True)
|
||||
|
||||
|
||||
def __init__(self, toolkit):
|
||||
super().__init__()
|
||||
self.toolkit = toolkit
|
||||
|
||||
|
||||
def _run(self, directory_path: str = "", thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
|
||||
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# List files
|
||||
response = code_interpreter.invoke(
|
||||
method="listFiles", params={"directoryPath": directory_path}
|
||||
)
|
||||
|
||||
|
||||
return extract_output_from_stream(response)
|
||||
except Exception as e:
|
||||
return f"Error listing files: {str(e)}"
|
||||
|
||||
return f"Error listing files: {e!s}"
|
||||
|
||||
async def _arun(self, directory_path: str = "", thread_id: str = "default") -> str:
|
||||
# Use _run as we're working with a synchronous API that's thread-safe
|
||||
return self._run(directory_path=directory_path, thread_id=thread_id)
|
||||
@@ -222,89 +291,100 @@ class ListFilesTool(BaseTool):
|
||||
|
||||
class DeleteFilesTool(BaseTool):
|
||||
"""Tool for removing files from the environment."""
|
||||
|
||||
name: str = "delete_files"
|
||||
description: str = "Remove files from the environment"
|
||||
args_schema: Type[BaseModel] = DeleteFilesInput
|
||||
args_schema: type[BaseModel] = DeleteFilesInput
|
||||
toolkit: Any = Field(default=None, exclude=True)
|
||||
|
||||
|
||||
def __init__(self, toolkit):
|
||||
super().__init__()
|
||||
self.toolkit = toolkit
|
||||
|
||||
def _run(self, paths: List[str], thread_id: str = "default") -> str:
|
||||
|
||||
def _run(self, paths: list[str], thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
|
||||
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Remove files
|
||||
response = code_interpreter.invoke(
|
||||
method="removeFiles", params={"paths": paths}
|
||||
)
|
||||
|
||||
|
||||
return extract_output_from_stream(response)
|
||||
except Exception as e:
|
||||
return f"Error deleting files: {str(e)}"
|
||||
|
||||
async def _arun(self, paths: List[str], thread_id: str = "default") -> str:
|
||||
return f"Error deleting files: {e!s}"
|
||||
|
||||
async def _arun(self, paths: list[str], thread_id: str = "default") -> str:
|
||||
# Use _run as we're working with a synchronous API that's thread-safe
|
||||
return self._run(paths=paths, thread_id=thread_id)
|
||||
|
||||
|
||||
class WriteFilesTool(BaseTool):
|
||||
"""Tool for creating or updating files in the environment."""
|
||||
|
||||
name: str = "write_files"
|
||||
description: str = "Create or update files in the environment"
|
||||
args_schema: Type[BaseModel] = WriteFilesInput
|
||||
args_schema: type[BaseModel] = WriteFilesInput
|
||||
toolkit: Any = Field(default=None, exclude=True)
|
||||
|
||||
|
||||
def __init__(self, toolkit):
|
||||
super().__init__()
|
||||
self.toolkit = toolkit
|
||||
|
||||
def _run(self, files: List[Dict[str, str]], thread_id: str = "default") -> str:
|
||||
|
||||
def _run(self, files: list[dict[str, str]], thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
|
||||
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Write files
|
||||
response = code_interpreter.invoke(
|
||||
method="writeFiles", params={"content": files}
|
||||
)
|
||||
|
||||
|
||||
return extract_output_from_stream(response)
|
||||
except Exception as e:
|
||||
return f"Error writing files: {str(e)}"
|
||||
|
||||
async def _arun(self, files: List[Dict[str, str]], thread_id: str = "default") -> str:
|
||||
return f"Error writing files: {e!s}"
|
||||
|
||||
async def _arun(
|
||||
self, files: list[dict[str, str]], thread_id: str = "default"
|
||||
) -> str:
|
||||
# Use _run as we're working with a synchronous API that's thread-safe
|
||||
return self._run(files=files, thread_id=thread_id)
|
||||
|
||||
|
||||
class StartCommandTool(BaseTool):
|
||||
"""Tool for starting long-running commands asynchronously."""
|
||||
|
||||
name: str = "start_command_execution"
|
||||
description: str = "Start long-running commands asynchronously"
|
||||
args_schema: Type[BaseModel] = StartCommandInput
|
||||
args_schema: type[BaseModel] = StartCommandInput
|
||||
toolkit: Any = Field(default=None, exclude=True)
|
||||
|
||||
|
||||
def __init__(self, toolkit):
|
||||
super().__init__()
|
||||
self.toolkit = toolkit
|
||||
|
||||
|
||||
def _run(self, command: str, thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
|
||||
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Start command execution
|
||||
response = code_interpreter.invoke(
|
||||
method="startCommandExecution", params={"command": command}
|
||||
)
|
||||
|
||||
|
||||
return extract_output_from_stream(response)
|
||||
except Exception as e:
|
||||
return f"Error starting command: {str(e)}"
|
||||
|
||||
return f"Error starting command: {e!s}"
|
||||
|
||||
async def _arun(self, command: str, thread_id: str = "default") -> str:
|
||||
# Use _run as we're working with a synchronous API that's thread-safe
|
||||
return self._run(command=command, thread_id=thread_id)
|
||||
@@ -312,27 +392,32 @@ class StartCommandTool(BaseTool):
|
||||
|
||||
class GetTaskTool(BaseTool):
|
||||
"""Tool for checking status of async tasks."""
|
||||
|
||||
name: str = "get_task"
|
||||
description: str = "Check status of async tasks"
|
||||
args_schema: Type[BaseModel] = GetTaskInput
|
||||
args_schema: type[BaseModel] = GetTaskInput
|
||||
toolkit: Any = Field(default=None, exclude=True)
|
||||
|
||||
|
||||
def __init__(self, toolkit):
|
||||
super().__init__()
|
||||
self.toolkit = toolkit
|
||||
|
||||
|
||||
def _run(self, task_id: str, thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
|
||||
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Get task status
|
||||
response = code_interpreter.invoke(method="getTask", params={"taskId": task_id})
|
||||
|
||||
response = code_interpreter.invoke(
|
||||
method="getTask", params={"taskId": task_id}
|
||||
)
|
||||
|
||||
return extract_output_from_stream(response)
|
||||
except Exception as e:
|
||||
return f"Error getting task status: {str(e)}"
|
||||
|
||||
return f"Error getting task status: {e!s}"
|
||||
|
||||
async def _arun(self, task_id: str, thread_id: str = "default") -> str:
|
||||
# Use _run as we're working with a synchronous API that's thread-safe
|
||||
return self._run(task_id=task_id, thread_id=thread_id)
|
||||
@@ -340,29 +425,32 @@ class GetTaskTool(BaseTool):
|
||||
|
||||
class StopTaskTool(BaseTool):
|
||||
"""Tool for stopping running tasks."""
|
||||
|
||||
name: str = "stop_task"
|
||||
description: str = "Stop running tasks"
|
||||
args_schema: Type[BaseModel] = StopTaskInput
|
||||
args_schema: type[BaseModel] = StopTaskInput
|
||||
toolkit: Any = Field(default=None, exclude=True)
|
||||
|
||||
|
||||
def __init__(self, toolkit):
|
||||
super().__init__()
|
||||
self.toolkit = toolkit
|
||||
|
||||
|
||||
def _run(self, task_id: str, thread_id: str = "default") -> str:
|
||||
try:
|
||||
# Get or create code interpreter
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(thread_id=thread_id)
|
||||
|
||||
code_interpreter = self.toolkit._get_or_create_interpreter(
|
||||
thread_id=thread_id
|
||||
)
|
||||
|
||||
# Stop task
|
||||
response = code_interpreter.invoke(
|
||||
method="stopTask", params={"taskId": task_id}
|
||||
)
|
||||
|
||||
|
||||
return extract_output_from_stream(response)
|
||||
except Exception as e:
|
||||
return f"Error stopping task: {str(e)}"
|
||||
|
||||
return f"Error stopping task: {e!s}"
|
||||
|
||||
async def _arun(self, task_id: str, thread_id: str = "default") -> str:
|
||||
# Use _run as we're working with a synchronous API that's thread-safe
|
||||
return self._run(task_id=task_id, thread_id=thread_id)
|
||||
@@ -429,8 +517,8 @@ class CodeInterpreterToolkit:
|
||||
region: AWS region for the code interpreter
|
||||
"""
|
||||
self.region = region
|
||||
self._code_interpreters: Dict[str, CodeInterpreter] = {}
|
||||
self.tools: List[BaseTool] = []
|
||||
self._code_interpreters: dict[str, CodeInterpreter] = {}
|
||||
self.tools: list[BaseTool] = []
|
||||
self._setup_tools()
|
||||
|
||||
def _setup_tools(self) -> None:
|
||||
@@ -444,17 +532,15 @@ class CodeInterpreterToolkit:
|
||||
WriteFilesTool(self),
|
||||
StartCommandTool(self),
|
||||
GetTaskTool(self),
|
||||
StopTaskTool(self)
|
||||
StopTaskTool(self),
|
||||
]
|
||||
|
||||
def _get_or_create_interpreter(
|
||||
self, thread_id: str = "default"
|
||||
) -> CodeInterpreter:
|
||||
def _get_or_create_interpreter(self, thread_id: str = "default") -> CodeInterpreter:
|
||||
"""Get or create a code interpreter for the specified thread.
|
||||
|
||||
|
||||
Args:
|
||||
thread_id: Thread ID for the code interpreter session
|
||||
|
||||
|
||||
Returns:
|
||||
CodeInterpreter instance
|
||||
"""
|
||||
@@ -463,6 +549,7 @@ class CodeInterpreterToolkit:
|
||||
|
||||
# Create a new code interpreter for this thread
|
||||
from bedrock_agentcore.tools.code_interpreter_client import CodeInterpreter
|
||||
|
||||
code_interpreter = CodeInterpreter(region=self.region)
|
||||
code_interpreter.start()
|
||||
logger.info(
|
||||
@@ -473,8 +560,7 @@ class CodeInterpreterToolkit:
|
||||
self._code_interpreters[thread_id] = code_interpreter
|
||||
return code_interpreter
|
||||
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
def get_tools(self) -> list[BaseTool]:
|
||||
"""
|
||||
Get the list of code interpreter tools
|
||||
|
||||
@@ -483,7 +569,7 @@ class CodeInterpreterToolkit:
|
||||
"""
|
||||
return self.tools
|
||||
|
||||
def get_tools_by_name(self) -> Dict[str, BaseTool]:
|
||||
def get_tools_by_name(self) -> dict[str, BaseTool]:
|
||||
"""
|
||||
Get a dictionary of tools mapped by their names
|
||||
|
||||
@@ -492,9 +578,9 @@ class CodeInterpreterToolkit:
|
||||
"""
|
||||
return {tool.name: tool for tool in self.tools}
|
||||
|
||||
async def cleanup(self, thread_id: Optional[str] = None) -> None:
|
||||
async def cleanup(self, thread_id: str | None = None) -> None:
|
||||
"""Clean up resources
|
||||
|
||||
|
||||
Args:
|
||||
thread_id: Optional thread ID to clean up. If None, cleans up all sessions.
|
||||
"""
|
||||
@@ -521,14 +607,14 @@ class CodeInterpreterToolkit:
|
||||
logger.warning(
|
||||
f"Error stopping code interpreter for thread {tid}: {e}"
|
||||
)
|
||||
|
||||
|
||||
self._code_interpreters = {}
|
||||
logger.info("All code interpreter sessions cleaned up")
|
||||
|
||||
|
||||
def create_code_interpreter_toolkit(
|
||||
region: str = "us-west-2",
|
||||
) -> Tuple[CodeInterpreterToolkit, List[BaseTool]]:
|
||||
) -> tuple[CodeInterpreterToolkit, list[BaseTool]]:
|
||||
"""
|
||||
Create a CodeInterpreterToolkit
|
||||
|
||||
@@ -540,4 +626,4 @@ def create_code_interpreter_toolkit(
|
||||
"""
|
||||
toolkit = CodeInterpreterToolkit(region=region)
|
||||
tools = toolkit.get_tools()
|
||||
return toolkit, tools
|
||||
return toolkit, tools
|
||||
|
||||
@@ -1,17 +1,17 @@
"""Custom exceptions for AWS Bedrock integration."""


class BedrockError(Exception):
"""Base exception for Bedrock-related errors."""
pass


class BedrockAgentError(BedrockError):
"""Exception raised for errors in the Bedrock Agent operations."""
pass


class BedrockKnowledgeBaseError(BedrockError):
"""Exception raised for errors in the Bedrock Knowledge Base operations."""
pass


class BedrockValidationError(BedrockError):
"""Exception raised for validation errors in Bedrock operations."""
pass
@@ -1,9 +1,9 @@
|
||||
from typing import Type, Optional, List, Dict, Any
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..exceptions import BedrockKnowledgeBaseError, BedrockValidationError
|
||||
@@ -14,28 +14,33 @@ load_dotenv()
|
||||
|
||||
class BedrockKBRetrieverToolInput(BaseModel):
|
||||
"""Input schema for BedrockKBRetrieverTool."""
|
||||
query: str = Field(..., description="The query to retrieve information from the knowledge base")
|
||||
|
||||
query: str = Field(
|
||||
..., description="The query to retrieve information from the knowledge base"
|
||||
)
|
||||
|
||||
|
||||
class BedrockKBRetrieverTool(BaseTool):
|
||||
name: str = "Bedrock Knowledge Base Retriever Tool"
|
||||
description: str = "Retrieves information from an Amazon Bedrock Knowledge Base given a query"
|
||||
args_schema: Type[BaseModel] = BedrockKBRetrieverToolInput
|
||||
description: str = (
|
||||
"Retrieves information from an Amazon Bedrock Knowledge Base given a query"
|
||||
)
|
||||
args_schema: type[BaseModel] = BedrockKBRetrieverToolInput
|
||||
knowledge_base_id: str = None
|
||||
number_of_results: Optional[int] = 5
|
||||
retrieval_configuration: Optional[Dict[str, Any]] = None
|
||||
guardrail_configuration: Optional[Dict[str, Any]] = None
|
||||
next_token: Optional[str] = None
|
||||
package_dependencies: List[str] = ["boto3"]
|
||||
number_of_results: int | None = 5
|
||||
retrieval_configuration: dict[str, Any] | None = None
|
||||
guardrail_configuration: dict[str, Any] | None = None
|
||||
next_token: str | None = None
|
||||
package_dependencies: list[str] = ["boto3"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
knowledge_base_id: str = None,
|
||||
number_of_results: Optional[int] = 5,
|
||||
retrieval_configuration: Optional[Dict[str, Any]] = None,
|
||||
guardrail_configuration: Optional[Dict[str, Any]] = None,
|
||||
next_token: Optional[str] = None,
|
||||
**kwargs
|
||||
knowledge_base_id: str | None = None,
|
||||
number_of_results: int | None = 5,
|
||||
retrieval_configuration: dict[str, Any] | None = None,
|
||||
guardrail_configuration: dict[str, Any] | None = None,
|
||||
next_token: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the BedrockKBRetrieverTool with knowledge base configuration.
|
||||
|
||||
@@ -49,7 +54,7 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Get knowledge_base_id from environment variable if not provided
|
||||
self.knowledge_base_id = knowledge_base_id or os.getenv('BEDROCK_KB_ID')
|
||||
self.knowledge_base_id = knowledge_base_id or os.getenv("BEDROCK_KB_ID")
|
||||
self.number_of_results = number_of_results
|
||||
self.guardrail_configuration = guardrail_configuration
|
||||
self.next_token = next_token
|
||||
@@ -66,7 +71,7 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
# Update the description to include the knowledge base details
|
||||
self.description = f"Retrieves information from Amazon Bedrock Knowledge Base '{self.knowledge_base_id}' given a query"
|
||||
|
||||
def _build_retrieval_configuration(self) -> Dict[str, Any]:
|
||||
def _build_retrieval_configuration(self) -> dict[str, Any]:
|
||||
"""Build the retrieval configuration based on provided parameters.
|
||||
|
||||
Returns:
|
||||
@@ -89,17 +94,23 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
if not isinstance(self.knowledge_base_id, str):
|
||||
raise BedrockValidationError("knowledge_base_id must be a string")
|
||||
if len(self.knowledge_base_id) > 10:
|
||||
raise BedrockValidationError("knowledge_base_id must be 10 characters or less")
|
||||
raise BedrockValidationError(
|
||||
"knowledge_base_id must be 10 characters or less"
|
||||
)
|
||||
if not all(c.isalnum() for c in self.knowledge_base_id):
|
||||
raise BedrockValidationError("knowledge_base_id must contain only alphanumeric characters")
|
||||
raise BedrockValidationError(
|
||||
"knowledge_base_id must contain only alphanumeric characters"
|
||||
)
|
||||
|
||||
# Validate next_token if provided
|
||||
if self.next_token:
|
||||
if not isinstance(self.next_token, str):
|
||||
raise BedrockValidationError("next_token must be a string")
|
||||
if len(self.next_token) < 1 or len(self.next_token) > 2048:
|
||||
raise BedrockValidationError("next_token must be between 1 and 2048 characters")
|
||||
if ' ' in self.next_token:
|
||||
raise BedrockValidationError(
|
||||
"next_token must be between 1 and 2048 characters"
|
||||
)
|
||||
if " " in self.next_token:
|
||||
raise BedrockValidationError("next_token cannot contain spaces")
|
||||
|
||||
# Validate number_of_results if provided
|
||||
@@ -107,12 +118,14 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
if not isinstance(self.number_of_results, int):
|
||||
raise BedrockValidationError("number_of_results must be an integer")
|
||||
if self.number_of_results < 1:
|
||||
raise BedrockValidationError("number_of_results must be greater than 0")
|
||||
raise BedrockValidationError(
|
||||
"number_of_results must be greater than 0"
|
||||
)
|
||||
|
||||
except BedrockValidationError as e:
|
||||
raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
|
||||
raise BedrockValidationError(f"Parameter validation failed: {e!s}")
|
||||
|
||||
def _process_retrieval_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _process_retrieval_result(self, result: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Process a single retrieval result from Bedrock Knowledge Base.
|
||||
|
||||
Args:
|
||||
@@ -122,57 +135,57 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
Dict[str, Any]: Processed result with standardized format
|
||||
"""
|
||||
# Extract content
|
||||
content_obj = result.get('content', {})
|
||||
content = content_obj.get('text', '')
|
||||
content_type = content_obj.get('type', 'text')
|
||||
content_obj = result.get("content", {})
|
||||
content = content_obj.get("text", "")
|
||||
content_type = content_obj.get("type", "text")
|
||||
|
||||
# Extract location information
|
||||
location = result.get('location', {})
|
||||
location_type = location.get('type', 'unknown')
|
||||
location = result.get("location", {})
|
||||
location_type = location.get("type", "unknown")
|
||||
source_uri = None
|
||||
|
||||
# Map for location types and their URI fields
|
||||
location_mapping = {
|
||||
's3Location': {'field': 'uri', 'type': 'S3'},
|
||||
'confluenceLocation': {'field': 'url', 'type': 'Confluence'},
|
||||
'salesforceLocation': {'field': 'url', 'type': 'Salesforce'},
|
||||
'sharePointLocation': {'field': 'url', 'type': 'SharePoint'},
|
||||
'webLocation': {'field': 'url', 'type': 'Web'},
|
||||
'customDocumentLocation': {'field': 'id', 'type': 'CustomDocument'},
|
||||
'kendraDocumentLocation': {'field': 'uri', 'type': 'KendraDocument'},
|
||||
'sqlLocation': {'field': 'query', 'type': 'SQL'}
|
||||
"s3Location": {"field": "uri", "type": "S3"},
|
||||
"confluenceLocation": {"field": "url", "type": "Confluence"},
|
||||
"salesforceLocation": {"field": "url", "type": "Salesforce"},
|
||||
"sharePointLocation": {"field": "url", "type": "SharePoint"},
|
||||
"webLocation": {"field": "url", "type": "Web"},
|
||||
"customDocumentLocation": {"field": "id", "type": "CustomDocument"},
|
||||
"kendraDocumentLocation": {"field": "uri", "type": "KendraDocument"},
|
||||
"sqlLocation": {"field": "query", "type": "SQL"},
|
||||
}
|
||||
|
||||
# Extract the URI based on location type
|
||||
for loc_key, config in location_mapping.items():
|
||||
if loc_key in location:
|
||||
source_uri = location[loc_key].get(config['field'])
|
||||
if not location_type or location_type == 'unknown':
|
||||
location_type = config['type']
|
||||
source_uri = location[loc_key].get(config["field"])
|
||||
if not location_type or location_type == "unknown":
|
||||
location_type = config["type"]
|
||||
break
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
'content': content,
|
||||
'content_type': content_type,
|
||||
'source_type': location_type,
|
||||
'source_uri': source_uri
|
||||
"content": content,
|
||||
"content_type": content_type,
|
||||
"source_type": location_type,
|
||||
"source_uri": source_uri,
|
||||
}
|
||||
|
||||
# Add optional fields if available
|
||||
if 'score' in result:
|
||||
result_object['score'] = result['score']
|
||||
if "score" in result:
|
||||
result_object["score"] = result["score"]
|
||||
|
||||
if 'metadata' in result:
|
||||
result_object['metadata'] = result['metadata']
|
||||
if "metadata" in result:
|
||||
result_object["metadata"] = result["metadata"]
|
||||
|
||||
# Handle byte content if present
|
||||
if 'byteContent' in content_obj:
|
||||
result_object['byte_content'] = content_obj['byteContent']
|
||||
if "byteContent" in content_obj:
|
||||
result_object["byte_content"] = content_obj["byteContent"]
|
||||
|
||||
# Handle row content if present
|
||||
if 'row' in content_obj:
|
||||
result_object['row_content'] = content_obj['row']
|
||||
if "row" in content_obj:
|
||||
result_object["row_content"] = content_obj["row"]
|
||||
|
||||
return result_object
|
||||
|
||||
@@ -186,35 +199,35 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
try:
|
||||
# Initialize the Bedrock Agent Runtime client
|
||||
bedrock_agent_runtime = boto3.client(
|
||||
'bedrock-agent-runtime',
|
||||
region_name=os.getenv('AWS_REGION', os.getenv('AWS_DEFAULT_REGION', 'us-east-1')),
|
||||
"bedrock-agent-runtime",
|
||||
region_name=os.getenv(
|
||||
"AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-east-1")
|
||||
),
|
||||
# AWS SDK will automatically use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from environment
|
||||
)
|
||||
|
||||
# Prepare the request parameters
|
||||
retrieve_params = {
|
||||
'knowledgeBaseId': self.knowledge_base_id,
|
||||
'retrievalQuery': {
|
||||
'text': query
|
||||
}
|
||||
"knowledgeBaseId": self.knowledge_base_id,
|
||||
"retrievalQuery": {"text": query},
|
||||
}
|
||||
|
||||
# Add optional parameters if provided
|
||||
if self.retrieval_configuration:
|
||||
retrieve_params['retrievalConfiguration'] = self.retrieval_configuration
|
||||
retrieve_params["retrievalConfiguration"] = self.retrieval_configuration
|
||||
|
||||
if self.guardrail_configuration:
|
||||
retrieve_params['guardrailConfiguration'] = self.guardrail_configuration
|
||||
retrieve_params["guardrailConfiguration"] = self.guardrail_configuration
|
||||
|
||||
if self.next_token:
|
||||
retrieve_params['nextToken'] = self.next_token
|
||||
retrieve_params["nextToken"] = self.next_token
|
||||
|
||||
# Make the retrieve API call
|
||||
response = bedrock_agent_runtime.retrieve(**retrieve_params)
|
||||
|
||||
# Process the response
|
||||
results = []
|
||||
for result in response.get('retrievalResults', []):
|
||||
for result in response.get("retrievalResults", []):
|
||||
processed_result = self._process_retrieval_result(result)
|
||||
results.append(processed_result)
|
||||
|
||||
@@ -239,10 +252,10 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
error_message = str(e)
|
||||
|
||||
# Try to extract error code if available
|
||||
if hasattr(e, 'response') and 'Error' in e.response:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', str(e))
|
||||
if hasattr(e, "response") and "Error" in e.response:
|
||||
error_code = e.response["Error"].get("Code", "Unknown")
|
||||
error_message = e.response["Error"].get("Message", str(e))
|
||||
|
||||
raise BedrockKnowledgeBaseError(f"Error ({error_code}): {error_message}")
|
||||
except Exception as e:
|
||||
raise BedrockKnowledgeBaseError(f"Unexpected error: {str(e)}")
|
||||
raise BedrockKnowledgeBaseError(f"Unexpected error: {e!s}")
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from .reader_tool import S3ReaderTool
|
||||
from .writer_tool import S3WriterTool
|
||||
from .writer_tool import S3WriterTool
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import Any, Type, List
|
||||
import os
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -8,14 +7,16 @@ from pydantic import BaseModel, Field
|
||||
class S3ReaderToolInput(BaseModel):
|
||||
"""Input schema for S3ReaderTool."""
|
||||
|
||||
file_path: str = Field(..., description="S3 file path (e.g., 's3://bucket-name/file-name')")
|
||||
file_path: str = Field(
|
||||
..., description="S3 file path (e.g., 's3://bucket-name/file-name')"
|
||||
)
|
||||
|
||||
|
||||
class S3ReaderTool(BaseTool):
|
||||
name: str = "S3 Reader Tool"
|
||||
description: str = "Reads a file from Amazon S3 given an S3 file path"
|
||||
args_schema: Type[BaseModel] = S3ReaderToolInput
|
||||
package_dependencies: List[str] = ["boto3"]
|
||||
args_schema: type[BaseModel] = S3ReaderToolInput
|
||||
package_dependencies: list[str] = ["boto3"]
|
||||
|
||||
def _run(self, file_path: str) -> str:
|
||||
try:
|
||||
@@ -28,19 +29,18 @@ class S3ReaderTool(BaseTool):
|
||||
bucket_name, object_key = self._parse_s3_path(file_path)
|
||||
|
||||
s3 = boto3.client(
|
||||
's3',
|
||||
region_name=os.getenv('CREW_AWS_REGION', 'us-east-1'),
|
||||
aws_access_key_id=os.getenv('CREW_AWS_ACCESS_KEY_ID'),
|
||||
aws_secret_access_key=os.getenv('CREW_AWS_SEC_ACCESS_KEY')
|
||||
"s3",
|
||||
region_name=os.getenv("CREW_AWS_REGION", "us-east-1"),
|
||||
aws_access_key_id=os.getenv("CREW_AWS_ACCESS_KEY_ID"),
|
||||
aws_secret_access_key=os.getenv("CREW_AWS_SEC_ACCESS_KEY"),
|
||||
)
|
||||
|
||||
# Read file content from S3
|
||||
response = s3.get_object(Bucket=bucket_name, Key=object_key)
|
||||
file_content = response['Body'].read().decode('utf-8')
|
||||
return response["Body"].read().decode("utf-8")
|
||||
|
||||
return file_content
|
||||
except ClientError as e:
|
||||
return f"Error reading file from S3: {str(e)}"
|
||||
return f"Error reading file from S3: {e!s}"
|
||||
|
||||
def _parse_s3_path(self, file_path: str) -> tuple:
|
||||
parts = file_path.replace("s3://", "").split("/", 1)
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
from typing import Type, List
|
||||
import os
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class S3WriterToolInput(BaseModel):
|
||||
"""Input schema for S3WriterTool."""
|
||||
file_path: str = Field(..., description="S3 file path (e.g., 's3://bucket-name/file-name')")
|
||||
|
||||
file_path: str = Field(
|
||||
..., description="S3 file path (e.g., 's3://bucket-name/file-name')"
|
||||
)
|
||||
content: str = Field(..., description="Content to write to the file")
|
||||
|
||||
|
||||
class S3WriterTool(BaseTool):
|
||||
name: str = "S3 Writer Tool"
|
||||
description: str = "Writes content to a file in Amazon S3 given an S3 file path"
|
||||
args_schema: Type[BaseModel] = S3WriterToolInput
|
||||
package_dependencies: List[str] = ["boto3"]
|
||||
args_schema: type[BaseModel] = S3WriterToolInput
|
||||
package_dependencies: list[str] = ["boto3"]
|
||||
|
||||
def _run(self, file_path: str, content: str) -> str:
|
||||
try:
|
||||
@@ -27,16 +30,18 @@ class S3WriterTool(BaseTool):
|
||||
bucket_name, object_key = self._parse_s3_path(file_path)
|
||||
|
||||
s3 = boto3.client(
|
||||
's3',
|
||||
region_name=os.getenv('CREW_AWS_REGION', 'us-east-1'),
|
||||
aws_access_key_id=os.getenv('CREW_AWS_ACCESS_KEY_ID'),
|
||||
aws_secret_access_key=os.getenv('CREW_AWS_SEC_ACCESS_KEY')
|
||||
"s3",
|
||||
region_name=os.getenv("CREW_AWS_REGION", "us-east-1"),
|
||||
aws_access_key_id=os.getenv("CREW_AWS_ACCESS_KEY_ID"),
|
||||
aws_secret_access_key=os.getenv("CREW_AWS_SEC_ACCESS_KEY"),
|
||||
)
|
||||
|
||||
s3.put_object(Bucket=bucket_name, Key=object_key, Body=content.encode('utf-8'))
|
||||
s3.put_object(
|
||||
Bucket=bucket_name, Key=object_key, Body=content.encode("utf-8")
|
||||
)
|
||||
return f"Successfully wrote content to {file_path}"
|
||||
except ClientError as e:
|
||||
return f"Error writing file to S3: {str(e)}"
|
||||
return f"Error writing file to S3: {e!s}"
|
||||
|
||||
def _parse_s3_path(self, file_path: str) -> tuple:
|
||||
parts = file_path.replace("s3://", "").split("/", 1)
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
"""Utility for colored console output."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Printer:
|
||||
"""Handles colored console output formatting."""
|
||||
|
||||
@staticmethod
|
||||
def print(content: str, color: Optional[str] = None) -> None:
|
||||
def print(content: str, color: str | None = None) -> None:
|
||||
"""Prints content with optional color formatting.
|
||||
|
||||
Args:
|
||||
@@ -29,7 +27,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold purple.
|
||||
"""
|
||||
print("\033[1m\033[95m {}\033[00m".format(content))
|
||||
print(f"\033[1m\033[95m {content}\033[00m")
|
||||
|
||||
@staticmethod
|
||||
def _print_bold_green(content: str) -> None:
|
||||
@@ -38,7 +36,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold green.
|
||||
"""
|
||||
print("\033[1m\033[92m {}\033[00m".format(content))
|
||||
print(f"\033[1m\033[92m {content}\033[00m")
|
||||
|
||||
@staticmethod
|
||||
def _print_purple(content: str) -> None:
|
||||
@@ -47,7 +45,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in purple.
|
||||
"""
|
||||
print("\033[95m {}\033[00m".format(content))
|
||||
print(f"\033[95m {content}\033[00m")
|
||||
|
||||
@staticmethod
|
||||
def _print_red(content: str) -> None:
|
||||
@@ -56,7 +54,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in red.
|
||||
"""
|
||||
print("\033[91m {}\033[00m".format(content))
|
||||
print(f"\033[91m {content}\033[00m")
|
||||
|
||||
@staticmethod
|
||||
def _print_bold_blue(content: str) -> None:
|
||||
@@ -65,7 +63,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold blue.
|
||||
"""
|
||||
print("\033[1m\033[94m {}\033[00m".format(content))
|
||||
print(f"\033[1m\033[94m {content}\033[00m")
|
||||
|
||||
@staticmethod
|
||||
def _print_yellow(content: str) -> None:
|
||||
@@ -74,7 +72,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in yellow.
|
||||
"""
|
||||
print("\033[93m {}\033[00m".format(content))
|
||||
print(f"\033[93m {content}\033[00m")
|
||||
|
||||
@staticmethod
|
||||
def _print_bold_yellow(content: str) -> None:
|
||||
@@ -83,7 +81,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold yellow.
|
||||
"""
|
||||
print("\033[1m\033[93m {}\033[00m".format(content))
|
||||
print(f"\033[1m\033[93m {content}\033[00m")
|
||||
|
||||
@staticmethod
|
||||
def _print_cyan(content: str) -> None:
|
||||
@@ -92,7 +90,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in cyan.
|
||||
"""
|
||||
print("\033[96m {}\033[00m".format(content))
|
||||
print(f"\033[96m {content}\033[00m")
|
||||
|
||||
@staticmethod
|
||||
def _print_bold_cyan(content: str) -> None:
|
||||
@@ -101,7 +99,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold cyan.
|
||||
"""
|
||||
print("\033[1m\033[96m {}\033[00m".format(content))
|
||||
print(f"\033[1m\033[96m {content}\033[00m")
|
||||
|
||||
@staticmethod
|
||||
def _print_magenta(content: str) -> None:
|
||||
@@ -110,7 +108,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in magenta.
|
||||
"""
|
||||
print("\033[35m {}\033[00m".format(content))
|
||||
print(f"\033[35m {content}\033[00m")
|
||||
|
||||
@staticmethod
|
||||
def _print_bold_magenta(content: str) -> None:
|
||||
@@ -119,7 +117,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold magenta.
|
||||
"""
|
||||
print("\033[1m\033[35m {}\033[00m".format(content))
|
||||
print(f"\033[1m\033[35m {content}\033[00m")
|
||||
|
||||
@staticmethod
|
||||
def _print_green(content: str) -> None:
|
||||
@@ -128,4 +126,4 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in green.
|
||||
"""
|
||||
print("\033[32m {}\033[00m".format(content))
|
||||
print(f"\033[32m {content}\033[00m")
|
||||
|
||||
@@ -3,6 +3,6 @@ from crewai_tools.rag.data_types import DataType
__all__ = [
"RAG",
"EmbeddingService",
"DataType",
"EmbeddingService",
]

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import Any

from pydantic import BaseModel, Field

from crewai_tools.rag.misc import compute_sha256
@@ -9,19 +10,22 @@ from crewai_tools.rag.source_content import SourceContent
class LoaderResult(BaseModel):
content: str = Field(description="The text content of the source")
source: str = Field(description="The source of the content", default="unknown")
metadata: Dict[str, Any] = Field(description="The metadata of the source", default_factory=dict)
metadata: dict[str, Any] = Field(
description="The metadata of the source", default_factory=dict
)
doc_id: str = Field(description="The id of the document")


class BaseLoader(ABC):
def __init__(self, config: Optional[Dict[str, Any]] = None):
def __init__(self, config: dict[str, Any] | None = None):
self.config = config or {}

@abstractmethod
def load(self, content: SourceContent, **kwargs) -> LoaderResult:
...
def load(self, content: SourceContent, **kwargs) -> LoaderResult: ...

def generate_doc_id(self, source_ref: str | None = None, content: str | None = None) -> str:
def generate_doc_id(
self, source_ref: str | None = None, content: str | None = None
) -> str:
"""
Generate a unique document id based on the source reference and content.
If the source reference is not provided, the content is used as the source reference.

@@ -1,15 +1,19 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.chunkers.default_chunker import DefaultChunker
from crewai_tools.rag.chunkers.text_chunker import TextChunker, DocxChunker, MdxChunker
from crewai_tools.rag.chunkers.structured_chunker import CsvChunker, JsonChunker, XmlChunker
from crewai_tools.rag.chunkers.structured_chunker import (
CsvChunker,
JsonChunker,
XmlChunker,
)
from crewai_tools.rag.chunkers.text_chunker import DocxChunker, MdxChunker, TextChunker

__all__ = [
"BaseChunker",
"DefaultChunker",
"TextChunker",
"DocxChunker",
"MdxChunker",
"CsvChunker",
"DefaultChunker",
"DocxChunker",
"JsonChunker",
"MdxChunker",
"TextChunker",
"XmlChunker",
]

@@ -1,6 +1,6 @@
|
||||
from typing import List, Optional
|
||||
import re
|
||||
|
||||
|
||||
class RecursiveCharacterTextSplitter:
|
||||
"""
|
||||
A text splitter that recursively splits text based on a hierarchy of separators.
|
||||
@@ -10,7 +10,7 @@ class RecursiveCharacterTextSplitter:
|
||||
self,
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
separators: Optional[List[str]] = None,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -23,7 +23,9 @@ class RecursiveCharacterTextSplitter:
|
||||
keep_separator: Whether to keep the separator in the split text
|
||||
"""
|
||||
if chunk_overlap >= chunk_size:
|
||||
raise ValueError(f"Chunk overlap ({chunk_overlap}) cannot be >= chunk size ({chunk_size})")
|
||||
raise ValueError(
|
||||
f"Chunk overlap ({chunk_overlap}) cannot be >= chunk size ({chunk_size})"
|
||||
)
|
||||
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
@@ -36,10 +38,10 @@ class RecursiveCharacterTextSplitter:
|
||||
"",
|
||||
]
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
return self._split_text(text, self._separators)
|
||||
|
||||
def _split_text(self, text: str, separators: List[str]) -> List[str]:
|
||||
def _split_text(self, text: str, separators: list[str]) -> list[str]:
|
||||
separator = separators[-1]
|
||||
new_separators = []
|
||||
|
||||
@@ -49,7 +51,7 @@ class RecursiveCharacterTextSplitter:
|
||||
break
|
||||
if re.search(re.escape(sep), text):
|
||||
separator = sep
|
||||
new_separators = separators[i + 1:]
|
||||
new_separators = separators[i + 1 :]
|
||||
break
|
||||
|
||||
splits = self._split_text_with_separator(text, separator)
|
||||
@@ -68,7 +70,7 @@ class RecursiveCharacterTextSplitter:
|
||||
|
||||
return self._merge_splits(good_splits, separator)
|
||||
|
||||
def _split_text_with_separator(self, text: str, separator: str) -> List[str]:
|
||||
def _split_text_with_separator(self, text: str, separator: str) -> list[str]:
|
||||
if separator == "":
|
||||
return list(text)
|
||||
|
||||
@@ -90,16 +92,15 @@ class RecursiveCharacterTextSplitter:
|
||||
splits[-1] += separator
|
||||
|
||||
return [s for s in splits if s]
|
||||
else:
|
||||
return text.split(separator)
|
||||
return text.split(separator)
|
||||
|
||||
def _split_by_characters(self, text: str) -> List[str]:
|
||||
def _split_by_characters(self, text: str) -> list[str]:
|
||||
chunks = []
|
||||
for i in range(0, len(text), self._chunk_size):
|
||||
chunks.append(text[i:i + self._chunk_size])
|
||||
chunks.append(text[i : i + self._chunk_size])
|
||||
return chunks
|
||||
|
||||
def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
|
||||
def _merge_splits(self, splits: list[str], separator: str) -> list[str]:
|
||||
"""Merge splits into chunks with proper overlap."""
|
||||
docs = []
|
||||
current_doc = []
|
||||
@@ -112,7 +113,10 @@ class RecursiveCharacterTextSplitter:
|
||||
if separator == "":
|
||||
doc = "".join(current_doc)
|
||||
else:
|
||||
doc = separator.join(current_doc)
|
||||
if self._keep_separator and separator == " ":
|
||||
doc = "".join(current_doc)
|
||||
else:
|
||||
doc = separator.join(current_doc)
|
||||
|
||||
if doc:
|
||||
docs.append(doc)
|
||||
@@ -133,15 +137,25 @@ class RecursiveCharacterTextSplitter:
|
||||
if separator == "":
|
||||
doc = "".join(current_doc)
|
||||
else:
|
||||
doc = separator.join(current_doc)
|
||||
if self._keep_separator and separator == " ":
|
||||
doc = "".join(current_doc)
|
||||
else:
|
||||
doc = separator.join(current_doc)
|
||||
|
||||
if doc:
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
class BaseChunker:
|
||||
def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize the Chunker
|
||||
|
||||
@@ -159,8 +173,7 @@ class BaseChunker:
|
||||
keep_separator=keep_separator,
|
||||
)
|
||||
|
||||
|
||||
def chunk(self, text: str) -> List[str]:
|
||||
def chunk(self, text: str) -> list[str]:
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class DefaultChunker(BaseChunker):
|
||||
def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 20, separators: Optional[List[str]] = None, keep_separator: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 2000,
|
||||
chunk_overlap: int = 20,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
@@ -1,49 +1,66 @@
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class CsvChunker(BaseChunker):
|
||||
def __init__(self, chunk_size: int = 1200, chunk_overlap: int = 100, separators: Optional[List[str]] = None, keep_separator: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 1200,
|
||||
chunk_overlap: int = 100,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
separators = [
|
||||
"\nRow ", # Row boundaries (from CSVLoader format)
|
||||
"\n", # Line breaks
|
||||
" | ", # Column separators
|
||||
", ", # Comma separators
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
"\nRow ", # Row boundaries (from CSVLoader format)
|
||||
"\n", # Line breaks
|
||||
" | ", # Column separators
|
||||
", ", # Comma separators
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
]
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
|
||||
class JsonChunker(BaseChunker):
|
||||
def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 2000,
|
||||
chunk_overlap: int = 200,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
separators = [
|
||||
"\n\n", # Object/array boundaries
|
||||
"\n", # Line breaks
|
||||
"},", # Object endings
|
||||
"],", # Array endings
|
||||
", ", # Property separators
|
||||
": ", # Key-value separators
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
"\n\n", # Object/array boundaries
|
||||
"\n", # Line breaks
|
||||
"},", # Object endings
|
||||
"],", # Array endings
|
||||
", ", # Property separators
|
||||
": ", # Key-value separators
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
]
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
|
||||
class XmlChunker(BaseChunker):
|
||||
def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 2500,
|
||||
chunk_overlap: int = 250,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
separators = [
|
||||
"\n\n", # Element boundaries
|
||||
"\n", # Line breaks
|
||||
">", # Tag endings
|
||||
". ", # Sentence endings (for text content)
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
", ", # Comma separators
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
"\n\n", # Element boundaries
|
||||
"\n", # Line breaks
|
||||
">", # Tag endings
|
||||
". ", # Sentence endings (for text content)
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
", ", # Comma separators
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
]
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
@@ -1,59 +1,76 @@
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class TextChunker(BaseChunker):
|
||||
def __init__(self, chunk_size: int = 1500, chunk_overlap: int = 150, separators: Optional[List[str]] = None, keep_separator: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 1500,
|
||||
chunk_overlap: int = 150,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
separators = [
|
||||
"\n\n\n", # Multiple line breaks (sections)
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
]
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
|
||||
class DocxChunker(BaseChunker):
|
||||
def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 2500,
|
||||
chunk_overlap: int = 250,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
separators = [
|
||||
"\n\n\n", # Multiple line breaks (major sections)
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
]
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
|
||||
class MdxChunker(BaseChunker):
|
||||
def __init__(self, chunk_size: int = 3000, chunk_overlap: int = 300, separators: Optional[List[str]] = None, keep_separator: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 3000,
|
||||
chunk_overlap: int = 300,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
separators = [
|
||||
"\n## ", # H2 headers (major sections)
|
||||
"\n## ", # H2 headers (major sections)
|
||||
"\n### ", # H3 headers (subsections)
|
||||
"\n#### ", # H4 headers (sub-subsections)
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n```", # Code block boundaries
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
"\n#### ", # H4 headers (sub-subsections)
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n```", # Code block boundaries
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
]
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
@@ -1,20 +1,25 @@
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class WebsiteChunker(BaseChunker):
|
||||
def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 2500,
|
||||
chunk_overlap: int = 250,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
separators = [
|
||||
"\n\n\n", # Major section breaks
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
"\n\n", # Paragraph breaks
|
||||
"\n", # Line breaks
|
||||
". ", # Sentence endings
|
||||
"! ", # Exclamation endings
|
||||
"? ", # Question endings
|
||||
"; ", # Semicolon breaks
|
||||
", ", # Comma breaks
|
||||
" ", # Word breaks
|
||||
"", # Character level
|
||||
]
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
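For reference, a minimal sketch of how these chunkers are consumed (the module path is taken from the DataType chunker mapping later in this diff; the sample text is a placeholder):

from crewai_tools.rag.chunkers.text_chunker import TextChunker

# Defaults match the diff: 1500-character chunks with 150-character overlap.
chunker = TextChunker()
chunks = chunker.chunk("A long plain-text document goes here ...")
for chunk in chunks:
    print(len(chunk), chunk[:60])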
@@ -1,18 +1,18 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import chromadb
|
||||
import litellm
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.base_loader import BaseLoader
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.misc import compute_sha256
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,29 +22,21 @@ class EmbeddingService:
|
||||
self.model = model
|
||||
self.kwargs = kwargs
|
||||
|
||||
def embed_text(self, text: str) -> List[float]:
|
||||
def embed_text(self, text: str) -> list[float]:
|
||||
try:
|
||||
response = litellm.embedding(
|
||||
model=self.model,
|
||||
input=[text],
|
||||
**self.kwargs
|
||||
)
|
||||
return response.data[0]['embedding']
|
||||
response = litellm.embedding(model=self.model, input=[text], **self.kwargs)
|
||||
return response.data[0]["embedding"]
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating embedding: {e}")
|
||||
raise
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
response = litellm.embedding(
|
||||
model=self.model,
|
||||
input=texts,
|
||||
**self.kwargs
|
||||
)
|
||||
return [data['embedding'] for data in response.data]
|
||||
response = litellm.embedding(model=self.model, input=texts, **self.kwargs)
|
||||
return [data["embedding"] for data in response.data]
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating batch embeddings: {e}")
|
||||
raise
|
||||
@@ -53,18 +45,18 @@ class EmbeddingService:
|
||||
class Document(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
content: str
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
data_type: DataType = DataType.TEXT
|
||||
source: Optional[str] = None
|
||||
source: str | None = None
|
||||
|
||||
|
||||
class RAG(Adapter):
|
||||
collection_name: str = "crewai_knowledge_base"
|
||||
persist_directory: Optional[str] = None
|
||||
persist_directory: str | None = None
|
||||
embedding_model: str = "text-embedding-3-large"
|
||||
summarize: bool = False
|
||||
top_k: int = 5
|
||||
embedding_config: Dict[str, Any] = Field(default_factory=dict)
|
||||
embedding_config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
_client: Any = PrivateAttr()
|
||||
_collection: Any = PrivateAttr()
|
||||
@@ -79,10 +71,15 @@ class RAG(Adapter):
|
||||
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={"hnsw:space": "cosine", "description": "CrewAI Knowledge Base"}
|
||||
metadata={
|
||||
"hnsw:space": "cosine",
|
||||
"description": "CrewAI Knowledge Base",
|
||||
},
|
||||
)
|
||||
|
||||
self._embedding_service = EmbeddingService(model=self.embedding_model, **self.embedding_config)
|
||||
self._embedding_service = EmbeddingService(
|
||||
model=self.embedding_model, **self.embedding_config
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize ChromaDB: {e}")
|
||||
raise
|
||||
@@ -92,11 +89,11 @@ class RAG(Adapter):
|
||||
def add(
|
||||
self,
|
||||
content: str | Path,
|
||||
data_type: Optional[Union[str, DataType]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
loader: Optional[BaseLoader] = None,
|
||||
chunker: Optional[BaseChunker] = None,
|
||||
**kwargs: Any
|
||||
data_type: str | DataType | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
loader: BaseLoader | None = None,
|
||||
chunker: BaseChunker | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
source_content = SourceContent(content)
|
||||
|
||||
@@ -111,11 +108,19 @@ class RAG(Adapter):
|
||||
loader_result = loader.load(source_content)
|
||||
doc_id = loader_result.doc_id
|
||||
|
||||
existing_doc = self._collection.get(where={"source": source_content.source_ref}, limit=1)
|
||||
existing_doc_id = existing_doc and existing_doc['metadatas'][0]['doc_id'] if existing_doc['metadatas'] else None
|
||||
existing_doc = self._collection.get(
|
||||
where={"source": source_content.source_ref}, limit=1
|
||||
)
|
||||
existing_doc_id = (
|
||||
existing_doc and existing_doc["metadatas"][0]["doc_id"]
|
||||
if existing_doc["metadatas"]
|
||||
else None
|
||||
)
|
||||
|
||||
if existing_doc_id == doc_id:
|
||||
logger.warning(f"Document with source {loader_result.source} already exists")
|
||||
logger.warning(
|
||||
f"Document with source {loader_result.source} already exists"
|
||||
)
|
||||
return
|
||||
|
||||
# Document with the same source ref exists but the content has changed; delete the oldest reference
|
||||
@@ -128,14 +133,16 @@ class RAG(Adapter):
|
||||
chunks = chunker.chunk(loader_result.content)
|
||||
for i, chunk in enumerate(chunks):
|
||||
doc_metadata = (metadata or {}).copy()
|
||||
doc_metadata['chunk_index'] = i
|
||||
documents.append(Document(
|
||||
id=compute_sha256(chunk),
|
||||
content=chunk,
|
||||
metadata=doc_metadata,
|
||||
data_type=data_type,
|
||||
source=loader_result.source
|
||||
))
|
||||
doc_metadata["chunk_index"] = i
|
||||
documents.append(
|
||||
Document(
|
||||
id=compute_sha256(chunk),
|
||||
content=chunk,
|
||||
metadata=doc_metadata,
|
||||
data_type=data_type,
|
||||
source=loader_result.source,
|
||||
)
|
||||
)
|
||||
|
||||
if not documents:
|
||||
logger.warning("No documents to add")
|
||||
@@ -153,11 +160,13 @@ class RAG(Adapter):
|
||||
|
||||
for doc in documents:
|
||||
doc_metadata = doc.metadata.copy()
|
||||
doc_metadata.update({
|
||||
"data_type": doc.data_type.value,
|
||||
"source": doc.source,
|
||||
"doc_id": doc_id
|
||||
})
|
||||
doc_metadata.update(
|
||||
{
|
||||
"data_type": doc.data_type.value,
|
||||
"source": doc.source,
|
||||
"doc_id": doc_id,
|
||||
}
|
||||
)
|
||||
metadatas.append(doc_metadata)
|
||||
|
||||
try:
|
||||
@@ -171,7 +180,7 @@ class RAG(Adapter):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||
|
||||
def query(self, question: str, where: Optional[Dict[str, Any]] = None) -> str:
|
||||
def query(self, question: str, where: dict[str, Any] | None = None) -> str:
|
||||
try:
|
||||
question_embedding = self._embedding_service.embed_text(question)
|
||||
|
||||
@@ -179,10 +188,14 @@ class RAG(Adapter):
|
||||
query_embeddings=[question_embedding],
|
||||
n_results=self.top_k,
|
||||
where=where,
|
||||
include=["documents", "metadatas", "distances"]
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
|
||||
if not results or not results.get("documents") or not results["documents"][0]:
|
||||
if (
|
||||
not results
|
||||
or not results.get("documents")
|
||||
or not results["documents"][0]
|
||||
):
|
||||
return "No relevant content found."
|
||||
|
||||
documents = results["documents"][0]
|
||||
@@ -195,8 +208,12 @@ class RAG(Adapter):
|
||||
metadata = metadatas[i] if i < len(metadatas) else {}
|
||||
distance = distances[i] if i < len(distances) else 1.0
|
||||
source = metadata.get("source", "unknown") if metadata else "unknown"
|
||||
score = 1 - distance if distance is not None else 0 # Convert distance to similarity
|
||||
formatted_results.append(f"[Source: {source}, Relevance: {score:.3f}]\n{doc}")
|
||||
score = (
|
||||
1 - distance if distance is not None else 0
|
||||
) # Convert distance to similarity
|
||||
formatted_results.append(
|
||||
f"[Source: {source}, Relevance: {score:.3f}]\n{doc}"
|
||||
)
|
||||
|
||||
return "\n\n".join(formatted_results)
|
||||
except Exception as e:
|
||||
@@ -210,23 +227,25 @@ class RAG(Adapter):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
||||
def get_collection_info(self) -> Dict[str, Any]:
|
||||
def get_collection_info(self) -> dict[str, Any]:
|
||||
try:
|
||||
count = self._collection.count()
|
||||
return {
|
||||
"name": self.collection_name,
|
||||
"count": count,
|
||||
"embedding_model": self.embedding_model
|
||||
"embedding_model": self.embedding_model,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get collection info: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def _get_data_type(self, content: SourceContent, data_type: str | DataType | None = None) -> DataType:
|
||||
def _get_data_type(
|
||||
self, content: SourceContent, data_type: str | DataType | None = None
|
||||
) -> DataType:
|
||||
try:
|
||||
if isinstance(data_type, str):
|
||||
return DataType(data_type)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return content.data_type
|
||||
|
||||
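A rough usage sketch of the reworked adapter above; the import path is an assumption, while the RAG class name, its defaults, and the add/query/get_collection_info signatures come from the diff:

from crewai_tools.adapters.rag_adapter import RAG  # hypothetical import path

rag = RAG(collection_name="crewai_knowledge_base", embedding_model="text-embedding-3-large")
rag.add("./docs/handbook.pdf", metadata={"team": "support"})  # placeholder file path
rag.add("Plain text can also be added directly.")
print(rag.query("What does the handbook say about onboarding?"))
print(rag.get_collection_info())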
@@ -1,9 +1,11 @@
import os
from enum import Enum
from pathlib import Path
from urllib.parse import urlparse
import os
from crewai_tools.rag.chunkers.base_chunker import BaseChunker

from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker


class DataType(str, Enum):
PDF_FILE = "pdf_file"
|
||||
@@ -25,29 +27,38 @@ class DataType(str, Enum):
|
||||
# Web types
|
||||
WEBSITE = "website"
|
||||
DOCS_SITE = "docs_site"
|
||||
YOUTUBE_VIDEO = "youtube_video"
|
||||
YOUTUBE_CHANNEL = "youtube_channel"
|
||||
|
||||
# Raw types
|
||||
TEXT = "text"
|
||||
|
||||
|
||||
def get_chunker(self) -> BaseChunker:
|
||||
from importlib import import_module
|
||||
|
||||
chunkers = {
|
||||
DataType.PDF_FILE: ("text_chunker", "TextChunker"),
|
||||
DataType.TEXT_FILE: ("text_chunker", "TextChunker"),
|
||||
DataType.TEXT: ("text_chunker", "TextChunker"),
|
||||
DataType.DOCX: ("text_chunker", "DocxChunker"),
|
||||
DataType.MDX: ("text_chunker", "MdxChunker"),
|
||||
|
||||
# Structured formats
|
||||
DataType.CSV: ("structured_chunker", "CsvChunker"),
|
||||
DataType.JSON: ("structured_chunker", "JsonChunker"),
|
||||
DataType.XML: ("structured_chunker", "XmlChunker"),
|
||||
|
||||
DataType.WEBSITE: ("web_chunker", "WebsiteChunker"),
|
||||
DataType.DIRECTORY: ("text_chunker", "TextChunker"),
|
||||
DataType.YOUTUBE_VIDEO: ("text_chunker", "TextChunker"),
|
||||
DataType.YOUTUBE_CHANNEL: ("text_chunker", "TextChunker"),
|
||||
DataType.GITHUB: ("text_chunker", "TextChunker"),
|
||||
DataType.DOCS_SITE: ("text_chunker", "TextChunker"),
|
||||
DataType.MYSQL: ("text_chunker", "TextChunker"),
|
||||
DataType.POSTGRES: ("text_chunker", "TextChunker"),
|
||||
}
|
||||
|
||||
module_name, class_name = chunkers.get(self, ("default_chunker", "DefaultChunker"))
|
||||
if self not in chunkers:
|
||||
raise ValueError(f"No chunker defined for {self}")
|
||||
module_name, class_name = chunkers[self]
|
||||
module_path = f"crewai_tools.rag.chunkers.{module_name}"
|
||||
|
||||
try:
|
||||
@@ -60,6 +71,7 @@ class DataType(str, Enum):
|
||||
from importlib import import_module
|
||||
|
||||
loaders = {
|
||||
DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
|
||||
DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
|
||||
DataType.TEXT: ("text_loader", "TextLoader"),
|
||||
DataType.XML: ("xml_loader", "XMLLoader"),
|
||||
@@ -69,9 +81,20 @@ class DataType(str, Enum):
|
||||
DataType.DOCX: ("docx_loader", "DOCXLoader"),
|
||||
DataType.CSV: ("csv_loader", "CSVLoader"),
|
||||
DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"),
|
||||
DataType.YOUTUBE_VIDEO: ("youtube_video_loader", "YoutubeVideoLoader"),
|
||||
DataType.YOUTUBE_CHANNEL: (
|
||||
"youtube_channel_loader",
|
||||
"YoutubeChannelLoader",
|
||||
),
|
||||
DataType.GITHUB: ("github_loader", "GithubLoader"),
|
||||
DataType.DOCS_SITE: ("docs_site_loader", "DocsSiteLoader"),
|
||||
DataType.MYSQL: ("mysql_loader", "MySQLLoader"),
|
||||
DataType.POSTGRES: ("postgres_loader", "PostgresLoader"),
|
||||
}
|
||||
|
||||
module_name, class_name = loaders.get(self, ("text_loader", "TextLoader"))
|
||||
if self not in loaders:
|
||||
raise ValueError(f"No loader defined for {self}")
|
||||
module_name, class_name = loaders[self]
|
||||
module_path = f"crewai_tools.rag.loaders.{module_name}"
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
@@ -79,6 +102,7 @@ class DataType(str, Enum):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading loader for {self}: {e}")
|
||||
|
||||
|
||||
class DataTypes:
|
||||
@staticmethod
|
||||
def from_content(content: str | Path | None = None) -> DataType:
|
||||
|
||||
@@ -1,20 +1,26 @@
from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
from crewai_tools.rag.loaders.xml_loader import XMLLoader
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.loaders.mdx_loader import MDXLoader
from crewai_tools.rag.loaders.json_loader import JSONLoader
from crewai_tools.rag.loaders.docx_loader import DOCXLoader
from crewai_tools.rag.loaders.csv_loader import CSVLoader
from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
from crewai_tools.rag.loaders.docx_loader import DOCXLoader
from crewai_tools.rag.loaders.json_loader import JSONLoader
from crewai_tools.rag.loaders.mdx_loader import MDXLoader
from crewai_tools.rag.loaders.pdf_loader import PDFLoader
from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.loaders.xml_loader import XMLLoader
from crewai_tools.rag.loaders.youtube_channel_loader import YoutubeChannelLoader
from crewai_tools.rag.loaders.youtube_video_loader import YoutubeVideoLoader

__all__ = [
    "CSVLoader",
    "DOCXLoader",
    "DirectoryLoader",
    "JSONLoader",
    "MDXLoader",
    "PDFLoader",
    "TextFileLoader",
    "TextLoader",
    "XMLLoader",
    "WebPageLoader",
    "MDXLoader",
    "JSONLoader",
    "DOCXLoader",
    "CSVLoader",
    "DirectoryLoader",
    "XMLLoader",
    "YoutubeChannelLoader",
    "YoutubeVideoLoader",
]
@@ -17,21 +17,23 @@ class CSVLoader(BaseLoader):
|
||||
|
||||
return self._parse_csv(content_str, source_ref)
|
||||
|
||||
|
||||
def _load_from_url(self, url: str, kwargs: dict) -> str:
|
||||
import requests
|
||||
|
||||
headers = kwargs.get("headers", {
|
||||
"Accept": "text/csv, application/csv, text/plain",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools CSVLoader)"
|
||||
})
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
"Accept": "text/csv, application/csv, text/plain",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools CSVLoader)",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching CSV from URL {url}: {str(e)}")
|
||||
raise ValueError(f"Error fetching CSV from URL {url}: {e!s}")
|
||||
|
||||
def _load_from_file(self, path: str) -> str:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
@@ -57,7 +59,7 @@ class CSVLoader(BaseLoader):
|
||||
metadata = {
|
||||
"format": "csv",
|
||||
"columns": headers,
|
||||
"rows": len(text_parts) - 2 if headers else 0
|
||||
"rows": len(text_parts) - 2 if headers else 0,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -68,5 +70,5 @@ class CSVLoader(BaseLoader):
|
||||
content=text,
|
||||
source=source_ref,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
|
||||
doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
@@ -22,7 +21,9 @@ class DirectoryLoader(BaseLoader):
|
||||
source_ref = source_content.source_ref
|
||||
|
||||
if source_content.is_url():
|
||||
raise ValueError("URL directory loading is not supported. Please provide a local directory path.")
|
||||
raise ValueError(
|
||||
"URL directory loading is not supported. Please provide a local directory path."
|
||||
)
|
||||
|
||||
if not os.path.exists(source_ref):
|
||||
raise FileNotFoundError(f"Directory does not exist: {source_ref}")
|
||||
@@ -38,7 +39,9 @@ class DirectoryLoader(BaseLoader):
|
||||
exclude_extensions = kwargs.get("exclude_extensions", None)
|
||||
max_files = kwargs.get("max_files", None)
|
||||
|
||||
files = self._find_files(dir_path, recursive, include_extensions, exclude_extensions)
|
||||
files = self._find_files(
|
||||
dir_path, recursive, include_extensions, exclude_extensions
|
||||
)
|
||||
|
||||
if max_files and len(files) > max_files:
|
||||
files = files[:max_files]
|
||||
@@ -52,13 +55,15 @@ class DirectoryLoader(BaseLoader):
|
||||
result = self._process_single_file(file_path)
|
||||
if result:
|
||||
all_contents.append(f"=== File: {file_path} ===\n{result.content}")
|
||||
processed_files.append({
|
||||
"path": file_path,
|
||||
"metadata": result.metadata,
|
||||
"source": result.source
|
||||
})
|
||||
processed_files.append(
|
||||
{
|
||||
"path": file_path,
|
||||
"metadata": result.metadata,
|
||||
"source": result.source,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error processing {file_path}: {str(e)}"
|
||||
error_msg = f"Error processing {file_path}: {e!s}"
|
||||
errors.append(error_msg)
|
||||
all_contents.append(f"=== File: {file_path} (ERROR) ===\n{error_msg}")
|
||||
|
||||
@@ -71,23 +76,29 @@ class DirectoryLoader(BaseLoader):
|
||||
"processed_files": len(processed_files),
|
||||
"errors": len(errors),
|
||||
"file_details": processed_files,
|
||||
"error_details": errors
|
||||
"error_details": errors,
|
||||
}
|
||||
|
||||
return LoaderResult(
|
||||
content=combined_content,
|
||||
source=dir_path,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=dir_path, content=combined_content)
|
||||
doc_id=self.generate_doc_id(source_ref=dir_path, content=combined_content),
|
||||
)
|
||||
|
||||
def _find_files(self, dir_path: str, recursive: bool, include_ext: List[str] | None = None, exclude_ext: List[str] | None = None) -> List[str]:
|
||||
def _find_files(
|
||||
self,
|
||||
dir_path: str,
|
||||
recursive: bool,
|
||||
include_ext: list[str] | None = None,
|
||||
exclude_ext: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Find all files in directory matching criteria."""
|
||||
files = []
|
||||
|
||||
if recursive:
|
||||
for root, dirs, filenames in os.walk(dir_path):
|
||||
dirs[:] = [d for d in dirs if not d.startswith('.')]
|
||||
dirs[:] = [d for d in dirs if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if self._should_include_file(filename, include_ext, exclude_ext):
|
||||
@@ -96,26 +107,37 @@ class DirectoryLoader(BaseLoader):
|
||||
try:
|
||||
for item in os.listdir(dir_path):
|
||||
item_path = os.path.join(dir_path, item)
|
||||
if os.path.isfile(item_path) and self._should_include_file(item, include_ext, exclude_ext):
|
||||
if os.path.isfile(item_path) and self._should_include_file(
|
||||
item, include_ext, exclude_ext
|
||||
):
|
||||
files.append(item_path)
|
||||
except PermissionError:
|
||||
pass
|
||||
|
||||
return sorted(files)
|
||||
|
||||
def _should_include_file(self, filename: str, include_ext: List[str] = None, exclude_ext: List[str] = None) -> bool:
|
||||
def _should_include_file(
|
||||
self,
|
||||
filename: str,
|
||||
include_ext: list[str] | None = None,
|
||||
exclude_ext: list[str] | None = None,
|
||||
) -> bool:
|
||||
"""Determine if a file should be included based on criteria."""
|
||||
if filename.startswith('.'):
|
||||
if filename.startswith("."):
|
||||
return False
|
||||
|
||||
_, ext = os.path.splitext(filename.lower())
|
||||
|
||||
if include_ext:
|
||||
if ext not in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in include_ext]:
|
||||
if ext not in [
|
||||
e.lower() if e.startswith(".") else f".{e.lower()}" for e in include_ext
|
||||
]:
|
||||
return False
|
||||
|
||||
if exclude_ext:
|
||||
if ext in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in exclude_ext]:
|
||||
if ext in [
|
||||
e.lower() if e.startswith(".") else f".{e.lower()}" for e in exclude_ext
|
||||
]:
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -132,11 +154,13 @@ class DirectoryLoader(BaseLoader):
|
||||
if result.metadata is None:
|
||||
result.metadata = {}
|
||||
|
||||
result.metadata.update({
|
||||
"file_path": file_path,
|
||||
"file_size": os.path.getsize(file_path),
|
||||
"data_type": str(data_type),
|
||||
"loader_type": loader.__class__.__name__
|
||||
})
|
||||
result.metadata.update(
|
||||
{
|
||||
"file_path": file_path,
|
||||
"file_size": os.path.getsize(file_path),
|
||||
"data_type": str(data_type),
|
||||
"loader_type": loader.__class__.__name__,
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
packages/tools/src/crewai_tools/rag/loaders/docs_site_loader.py (new file, 106 lines)
@@ -0,0 +1,106 @@
|
||||
"""Documentation site loader."""
|
||||
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class DocsSiteLoader(BaseLoader):
|
||||
"""Loader for documentation websites."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a documentation site.
|
||||
|
||||
Args:
|
||||
source: Documentation site URL
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
LoaderResult with documentation content
|
||||
"""
|
||||
docs_url = source.source
|
||||
|
||||
try:
|
||||
response = requests.get(docs_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(f"Unable to fetch documentation from {docs_url}: {e}")
|
||||
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
for script in soup(["script", "style"]):
|
||||
script.decompose()
|
||||
|
||||
title = soup.find("title")
|
||||
title_text = title.get_text(strip=True) if title else "Documentation"
|
||||
|
||||
main_content = None
|
||||
for selector in [
|
||||
"main",
|
||||
"article",
|
||||
'[role="main"]',
|
||||
".content",
|
||||
"#content",
|
||||
".documentation",
|
||||
]:
|
||||
main_content = soup.select_one(selector)
|
||||
if main_content:
|
||||
break
|
||||
|
||||
if not main_content:
|
||||
main_content = soup.find("body")
|
||||
|
||||
if not main_content:
|
||||
raise ValueError(
|
||||
f"Unable to extract content from documentation site: {docs_url}"
|
||||
)
|
||||
|
||||
text_parts = [f"Title: {title_text}", ""]
|
||||
|
||||
headings = main_content.find_all(["h1", "h2", "h3"])
|
||||
if headings:
|
||||
text_parts.append("Table of Contents:")
|
||||
for heading in headings[:15]:
|
||||
level = int(heading.name[1])
|
||||
indent = " " * (level - 1)
|
||||
text_parts.append(f"{indent}- {heading.get_text(strip=True)}")
|
||||
text_parts.append("")
|
||||
|
||||
text = main_content.get_text(separator="\n", strip=True)
|
||||
lines = [line.strip() for line in text.split("\n") if line.strip()]
|
||||
text_parts.extend(lines)
|
||||
|
||||
nav_links = []
|
||||
for nav_selector in ["nav", ".sidebar", ".toc", ".navigation"]:
|
||||
nav = soup.select_one(nav_selector)
|
||||
if nav:
|
||||
links = nav.find_all("a", href=True)
|
||||
for link in links[:20]:
|
||||
href = link["href"]
|
||||
if not href.startswith(("http://", "https://", "mailto:", "#")):
|
||||
full_url = urljoin(docs_url, href)
|
||||
nav_links.append(f"- {link.get_text(strip=True)}: {full_url}")
|
||||
|
||||
if nav_links:
|
||||
text_parts.append("")
|
||||
text_parts.append("Related documentation pages:")
|
||||
text_parts.extend(nav_links[:10])
|
||||
|
||||
content = "\n".join(text_parts)
|
||||
|
||||
if len(content) > 100000:
|
||||
content = content[:100000] + "\n\n[Content truncated...]"
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={
|
||||
"source": docs_url,
|
||||
"title": title_text,
|
||||
"domain": urlparse(docs_url).netloc,
|
||||
},
|
||||
doc_id=self.generate_doc_id(source_ref=docs_url, content=content),
|
||||
)
|
||||
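A hedged example of exercising the new DocsSiteLoader on its own (the URL is a placeholder):

from crewai_tools.rag.loaders.docs_site_loader import DocsSiteLoader
from crewai_tools.rag.source_content import SourceContent

loader = DocsSiteLoader()
result = loader.load(SourceContent("https://docs.example.com"))  # placeholder URL
print(result.metadata["title"], result.metadata["domain"])
print(result.content[:500])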
@@ -10,7 +10,9 @@ class DOCXLoader(BaseLoader):
|
||||
try:
|
||||
from docx import Document as DocxDocument
|
||||
except ImportError:
|
||||
raise ImportError("python-docx is required for DOCX loading. Install with: 'uv pip install python-docx' or pip install crewai-tools[rag]")
|
||||
raise ImportError(
|
||||
"python-docx is required for DOCX loading. Install with: 'uv pip install python-docx' or pip install crewai-tools[rag]"
|
||||
)
|
||||
|
||||
source_ref = source_content.source_ref
|
||||
|
||||
@@ -23,28 +25,35 @@ class DOCXLoader(BaseLoader):
|
||||
elif source_content.path_exists():
|
||||
return self._load_from_file(source_ref, source_ref, DocxDocument)
|
||||
else:
|
||||
raise ValueError(f"Source must be a valid file path or URL, got: {source_content.source}")
|
||||
raise ValueError(
|
||||
f"Source must be a valid file path or URL, got: {source_content.source}"
|
||||
)
|
||||
|
||||
def _download_from_url(self, url: str, kwargs: dict) -> str:
|
||||
import requests
|
||||
|
||||
headers = kwargs.get("headers", {
|
||||
"Accept": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools DOCXLoader)"
|
||||
})
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
"Accept": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools DOCXLoader)",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
# Create temporary file to save the DOCX content
|
||||
with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as temp_file:
|
||||
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as temp_file:
|
||||
temp_file.write(response.content)
|
||||
return temp_file.name
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching DOCX from URL {url}: {str(e)}")
|
||||
raise ValueError(f"Error fetching DOCX from URL {url}: {e!s}")
|
||||
|
||||
def _load_from_file(self, file_path: str, source_ref: str, DocxDocument) -> LoaderResult:
|
||||
def _load_from_file(
|
||||
self, file_path: str, source_ref: str, DocxDocument
|
||||
) -> LoaderResult:
|
||||
try:
|
||||
doc = DocxDocument(file_path)
|
||||
|
||||
@@ -58,15 +67,15 @@ class DOCXLoader(BaseLoader):
|
||||
metadata = {
|
||||
"format": "docx",
|
||||
"paragraphs": len(doc.paragraphs),
|
||||
"tables": len(doc.tables)
|
||||
"tables": len(doc.tables),
|
||||
}
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
source=source_ref,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=source_ref, content=content)
|
||||
doc_id=self.generate_doc_id(source_ref=source_ref, content=content),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading DOCX file: {str(e)}")
|
||||
raise ValueError(f"Error loading DOCX file: {e!s}")
|
||||
|
||||
packages/tools/src/crewai_tools/rag/loaders/github_loader.py (new file, 112 lines)
@@ -0,0 +1,112 @@
|
||||
"""GitHub repository content loader."""
|
||||
|
||||
from github import Github, GithubException
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class GithubLoader(BaseLoader):
|
||||
"""Loader for GitHub repository content."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a GitHub repository.
|
||||
|
||||
Args:
|
||||
source: GitHub repository URL
|
||||
**kwargs: Additional arguments including gh_token and content_types
|
||||
|
||||
Returns:
|
||||
LoaderResult with repository content
|
||||
"""
|
||||
metadata = kwargs.get("metadata", {})
|
||||
gh_token = metadata.get("gh_token")
|
||||
content_types = metadata.get("content_types", ["code", "repo"])
|
||||
|
||||
repo_url = source.source
|
||||
if not repo_url.startswith("https://github.com/"):
|
||||
raise ValueError(f"Invalid GitHub URL: {repo_url}")
|
||||
|
||||
parts = repo_url.replace("https://github.com/", "").strip("/").split("/")
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid GitHub repository URL: {repo_url}")
|
||||
|
||||
repo_name = f"{parts[0]}/{parts[1]}"
|
||||
|
||||
g = Github(gh_token) if gh_token else Github()
|
||||
|
||||
try:
|
||||
repo = g.get_repo(repo_name)
|
||||
except GithubException as e:
|
||||
raise ValueError(f"Unable to access repository {repo_name}: {e}")
|
||||
|
||||
all_content = []
|
||||
|
||||
if "repo" in content_types:
|
||||
all_content.append(f"Repository: {repo.full_name}")
|
||||
all_content.append(f"Description: {repo.description or 'No description'}")
|
||||
all_content.append(f"Language: {repo.language or 'Not specified'}")
|
||||
all_content.append(f"Stars: {repo.stargazers_count}")
|
||||
all_content.append(f"Forks: {repo.forks_count}")
|
||||
all_content.append("")
|
||||
|
||||
if "code" in content_types:
|
||||
try:
|
||||
readme = repo.get_readme()
|
||||
all_content.append("README:")
|
||||
all_content.append(
|
||||
readme.decoded_content.decode("utf-8", errors="ignore")
|
||||
)
|
||||
all_content.append("")
|
||||
except GithubException:
|
||||
pass
|
||||
|
||||
try:
|
||||
contents = repo.get_contents("")
|
||||
if isinstance(contents, list):
|
||||
all_content.append("Repository structure:")
|
||||
for content_file in contents[:20]:
|
||||
all_content.append(
|
||||
f"- {content_file.path} ({content_file.type})"
|
||||
)
|
||||
all_content.append("")
|
||||
except GithubException:
|
||||
pass
|
||||
|
||||
if "pr" in content_types:
|
||||
prs = repo.get_pulls(state="open")
|
||||
pr_list = list(prs[:5])
|
||||
if pr_list:
|
||||
all_content.append("Recent Pull Requests:")
|
||||
for pr in pr_list:
|
||||
all_content.append(f"- PR #{pr.number}: {pr.title}")
|
||||
if pr.body:
|
||||
body_preview = pr.body[:200].replace("\n", " ")
|
||||
all_content.append(f" {body_preview}")
|
||||
all_content.append("")
|
||||
|
||||
if "issue" in content_types:
|
||||
issues = repo.get_issues(state="open")
|
||||
issue_list = [i for i in list(issues[:10]) if not i.pull_request][:5]
|
||||
if issue_list:
|
||||
all_content.append("Recent Issues:")
|
||||
for issue in issue_list:
|
||||
all_content.append(f"- Issue #{issue.number}: {issue.title}")
|
||||
if issue.body:
|
||||
body_preview = issue.body[:200].replace("\n", " ")
|
||||
all_content.append(f" {body_preview}")
|
||||
all_content.append("")
|
||||
|
||||
if not all_content:
|
||||
raise ValueError(f"No content could be loaded from repository: {repo_url}")
|
||||
|
||||
content = "\n".join(all_content)
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={
|
||||
"source": repo_url,
|
||||
"repo": repo_name,
|
||||
"content_types": content_types,
|
||||
},
|
||||
doc_id=self.generate_doc_id(source_ref=repo_url, content=content),
|
||||
)
|
||||
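Likewise, a sketch for the new GithubLoader; the repository URL is a placeholder, and a "gh_token" entry can be added to the metadata for private repositories:

from crewai_tools.rag.loaders.github_loader import GithubLoader
from crewai_tools.rag.source_content import SourceContent

loader = GithubLoader()
result = loader.load(
    SourceContent("https://github.com/org/repo"),  # placeholder repository URL
    metadata={"content_types": ["repo", "code", "issue"]},
)
print(result.metadata["repo"], result.metadata["content_types"])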
@@ -1,7 +1,7 @@
import json

from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class JSONLoader(BaseLoader):
@@ -19,17 +19,24 @@ class JSONLoader(BaseLoader):
|
||||
def _load_from_url(self, url: str, kwargs: dict) -> str:
|
||||
import requests
|
||||
|
||||
headers = kwargs.get("headers", {
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools JSONLoader)"
|
||||
})
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools JSONLoader)",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.text if not self._is_json_response(response) else json.dumps(response.json(), indent=2)
|
||||
return (
|
||||
response.text
|
||||
if not self._is_json_response(response)
|
||||
else json.dumps(response.json(), indent=2)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching JSON from URL {url}: {str(e)}")
|
||||
raise ValueError(f"Error fetching JSON from URL {url}: {e!s}")
|
||||
|
||||
def _is_json_response(self, response) -> bool:
|
||||
try:
|
||||
@@ -46,7 +53,9 @@ class JSONLoader(BaseLoader):
|
||||
try:
|
||||
data = json.loads(content)
|
||||
if isinstance(data, dict):
|
||||
text = "\n".join(f"{k}: {json.dumps(v, indent=0)}" for k, v in data.items())
|
||||
text = "\n".join(
|
||||
f"{k}: {json.dumps(v, indent=0)}" for k, v in data.items()
|
||||
)
|
||||
elif isinstance(data, list):
|
||||
text = "\n".join(json.dumps(item, indent=0) for item in data)
|
||||
else:
|
||||
@@ -55,7 +64,7 @@ class JSONLoader(BaseLoader):
|
||||
metadata = {
|
||||
"format": "json",
|
||||
"type": type(data).__name__,
|
||||
"size": len(data) if isinstance(data, (list, dict)) else 1
|
||||
"size": len(data) if isinstance(data, (list, dict)) else 1,
|
||||
}
|
||||
except json.JSONDecodeError as e:
|
||||
text = content
|
||||
@@ -65,5 +74,5 @@ class JSONLoader(BaseLoader):
|
||||
content=text,
|
||||
source=source_ref,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
|
||||
doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ import re
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class MDXLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
source_ref = source_content.source_ref
|
||||
@@ -18,17 +19,20 @@ class MDXLoader(BaseLoader):
|
||||
def _load_from_url(self, url: str, kwargs: dict) -> str:
|
||||
import requests
|
||||
|
||||
headers = kwargs.get("headers", {
|
||||
"Accept": "text/markdown, text/x-markdown, text/plain",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools MDXLoader)"
|
||||
})
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
"Accept": "text/markdown, text/x-markdown, text/plain",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools MDXLoader)",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching MDX from URL {url}: {str(e)}")
|
||||
raise ValueError(f"Error fetching MDX from URL {url}: {e!s}")
|
||||
|
||||
def _load_from_file(self, path: str) -> str:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
@@ -38,16 +42,20 @@ class MDXLoader(BaseLoader):
|
||||
cleaned_content = content
|
||||
|
||||
# Remove import statements
|
||||
cleaned_content = re.sub(r'^import\s+.*?\n', '', cleaned_content, flags=re.MULTILINE)
|
||||
cleaned_content = re.sub(
|
||||
r"^import\s+.*?\n", "", cleaned_content, flags=re.MULTILINE
|
||||
)
|
||||
|
||||
# Remove export statements
|
||||
cleaned_content = re.sub(r'^export\s+.*?(?:\n|$)', '', cleaned_content, flags=re.MULTILINE)
|
||||
cleaned_content = re.sub(
|
||||
r"^export\s+.*?(?:\n|$)", "", cleaned_content, flags=re.MULTILINE
|
||||
)
|
||||
|
||||
# Remove JSX tags (simple approach)
|
||||
cleaned_content = re.sub(r'<[^>]+>', '', cleaned_content)
|
||||
cleaned_content = re.sub(r"<[^>]+>", "", cleaned_content)
|
||||
|
||||
# Clean up extra whitespace
|
||||
cleaned_content = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_content)
|
||||
cleaned_content = re.sub(r"\n\s*\n\s*\n", "\n\n", cleaned_content)
|
||||
cleaned_content = cleaned_content.strip()
|
||||
|
||||
metadata = {"format": "mdx"}
|
||||
@@ -55,5 +63,5 @@ class MDXLoader(BaseLoader):
|
||||
content=cleaned_content,
|
||||
source=source_ref,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=source_ref, content=cleaned_content)
|
||||
doc_id=self.generate_doc_id(source_ref=source_ref, content=cleaned_content),
|
||||
)
|
||||
|
||||
packages/tools/src/crewai_tools/rag/loaders/mysql_loader.py (new file, 100 lines)
@@ -0,0 +1,100 @@
|
||||
"""MySQL database loader."""
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pymysql
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class MySQLLoader(BaseLoader):
|
||||
"""Loader for MySQL database content."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a MySQL database table.
|
||||
|
||||
Args:
|
||||
source: SQL query (e.g., "SELECT * FROM table_name")
|
||||
**kwargs: Additional arguments including db_uri
|
||||
|
||||
Returns:
|
||||
LoaderResult with database content
|
||||
"""
|
||||
metadata = kwargs.get("metadata", {})
|
||||
db_uri = metadata.get("db_uri")
|
||||
|
||||
if not db_uri:
|
||||
raise ValueError("Database URI is required for MySQL loader")
|
||||
|
||||
query = source.source
|
||||
|
||||
parsed = urlparse(db_uri)
|
||||
if parsed.scheme not in ["mysql", "mysql+pymysql"]:
|
||||
raise ValueError(f"Invalid MySQL URI scheme: {parsed.scheme}")
|
||||
|
||||
connection_params = {
|
||||
"host": parsed.hostname or "localhost",
|
||||
"port": parsed.port or 3306,
|
||||
"user": parsed.username,
|
||||
"password": parsed.password,
|
||||
"database": parsed.path.lstrip("/") if parsed.path else None,
|
||||
"charset": "utf8mb4",
|
||||
"cursorclass": pymysql.cursors.DictCursor,
|
||||
}
|
||||
|
||||
if not connection_params["database"]:
|
||||
raise ValueError("Database name is required in the URI")
|
||||
|
||||
try:
|
||||
connection = pymysql.connect(**connection_params)
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
content = "No data found in the table"
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={"source": query, "row_count": 0},
|
||||
doc_id=self.generate_doc_id(
|
||||
source_ref=query, content=content
|
||||
),
|
||||
)
|
||||
|
||||
text_parts = []
|
||||
|
||||
columns = list(rows[0].keys())
|
||||
text_parts.append(f"Columns: {', '.join(columns)}")
|
||||
text_parts.append(f"Total rows: {len(rows)}")
|
||||
text_parts.append("")
|
||||
|
||||
for i, row in enumerate(rows, 1):
|
||||
text_parts.append(f"Row {i}:")
|
||||
for col, val in row.items():
|
||||
if val is not None:
|
||||
text_parts.append(f" {col}: {val}")
|
||||
text_parts.append("")
|
||||
|
||||
content = "\n".join(text_parts)
|
||||
|
||||
if len(content) > 100000:
|
||||
content = content[:100000] + "\n\n[Content truncated...]"
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={
|
||||
"source": query,
|
||||
"database": connection_params["database"],
|
||||
"row_count": len(rows),
|
||||
"columns": columns,
|
||||
},
|
||||
doc_id=self.generate_doc_id(source_ref=query, content=content),
|
||||
)
|
||||
finally:
|
||||
connection.close()
|
||||
except pymysql.Error as e:
|
||||
raise ValueError(f"MySQL database error: {e}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load data from MySQL: {e}")
|
||||
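A sketch of the new MySQLLoader in isolation; the query and connection URI are placeholders:

from crewai_tools.rag.loaders.mysql_loader import MySQLLoader
from crewai_tools.rag.source_content import SourceContent

loader = MySQLLoader()
result = loader.load(
    SourceContent("SELECT id, title FROM articles LIMIT 10"),  # placeholder query
    metadata={"db_uri": "mysql://user:password@localhost:3306/appdb"},  # placeholder URI
)
print(result.metadata["row_count"], result.metadata["columns"])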
packages/tools/src/crewai_tools/rag/loaders/pdf_loader.py (new file, 71 lines)
@@ -0,0 +1,71 @@
|
||||
"""PDF loader for extracting text from PDF files."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class PDFLoader(BaseLoader):
|
||||
"""Loader for PDF files."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load and extract text from a PDF file.
|
||||
|
||||
Args:
|
||||
source: The source content containing the PDF file path
|
||||
|
||||
Returns:
|
||||
LoaderResult with extracted text content
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the PDF file doesn't exist
|
||||
ImportError: If required PDF libraries aren't installed
|
||||
"""
|
||||
try:
|
||||
import pypdf
|
||||
except ImportError:
|
||||
try:
|
||||
import PyPDF2 as pypdf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
|
||||
)
|
||||
|
||||
file_path = source.source
|
||||
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f"PDF file not found: {file_path}")
|
||||
|
||||
text_content = []
|
||||
metadata: dict[str, Any] = {
|
||||
"source": str(file_path),
|
||||
"file_name": Path(file_path).name,
|
||||
"file_type": "pdf",
|
||||
}
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as file:
|
||||
pdf_reader = pypdf.PdfReader(file)
|
||||
metadata["num_pages"] = len(pdf_reader.pages)
|
||||
|
||||
for page_num, page in enumerate(pdf_reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text.strip():
|
||||
text_content.append(f"Page {page_num}:\n{page_text}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading PDF file {file_path}: {e!s}")
|
||||
|
||||
if not text_content:
|
||||
content = f"[PDF file with no extractable text: {Path(file_path).name}]"
|
||||
else:
|
||||
content = "\n\n".join(text_content)
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
source=str(file_path),
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=str(file_path), content=content),
|
||||
)
|
||||
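And a sketch for the new PDFLoader; the file path is a placeholder:

from crewai_tools.rag.loaders.pdf_loader import PDFLoader
from crewai_tools.rag.source_content import SourceContent

loader = PDFLoader()
result = loader.load(SourceContent("./reports/q3.pdf"))  # placeholder path
print(result.metadata["num_pages"])
print(result.content[:500])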
packages/tools/src/crewai_tools/rag/loaders/postgres_loader.py (new file, 100 lines)
@@ -0,0 +1,100 @@
|
||||
"""PostgreSQL database loader."""
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import RealDictCursor
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class PostgresLoader(BaseLoader):
|
||||
"""Loader for PostgreSQL database content."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a PostgreSQL database table.
|
||||
|
||||
Args:
|
||||
source: SQL query (e.g., "SELECT * FROM table_name")
|
||||
**kwargs: Additional arguments including db_uri
|
||||
|
||||
Returns:
|
||||
LoaderResult with database content
|
||||
"""
|
||||
metadata = kwargs.get("metadata", {})
|
||||
db_uri = metadata.get("db_uri")
|
||||
|
||||
if not db_uri:
|
||||
raise ValueError("Database URI is required for PostgreSQL loader")
|
||||
|
||||
query = source.source
|
||||
|
||||
parsed = urlparse(db_uri)
|
||||
if parsed.scheme not in ["postgresql", "postgres", "postgresql+psycopg2"]:
|
||||
raise ValueError(f"Invalid PostgreSQL URI scheme: {parsed.scheme}")
|
||||
|
||||
connection_params = {
|
||||
"host": parsed.hostname or "localhost",
|
||||
"port": parsed.port or 5432,
|
||||
"user": parsed.username,
|
||||
"password": parsed.password,
|
||||
"database": parsed.path.lstrip("/") if parsed.path else None,
|
||||
"cursor_factory": RealDictCursor,
|
||||
}
|
||||
|
||||
if not connection_params["database"]:
|
||||
raise ValueError("Database name is required in the URI")
|
||||
|
||||
try:
|
||||
connection = psycopg2.connect(**connection_params)
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
content = "No data found in the table"
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={"source": query, "row_count": 0},
|
||||
doc_id=self.generate_doc_id(
|
||||
source_ref=query, content=content
|
||||
),
|
||||
)
|
||||
|
||||
text_parts = []
|
||||
|
||||
columns = list(rows[0].keys())
|
||||
text_parts.append(f"Columns: {', '.join(columns)}")
|
||||
text_parts.append(f"Total rows: {len(rows)}")
|
||||
text_parts.append("")
|
||||
|
||||
for i, row in enumerate(rows, 1):
|
||||
text_parts.append(f"Row {i}:")
|
||||
for col, val in row.items():
|
||||
if val is not None:
|
||||
text_parts.append(f" {col}: {val}")
|
||||
text_parts.append("")
|
||||
|
||||
content = "\n".join(text_parts)
|
||||
|
||||
if len(content) > 100000:
|
||||
content = content[:100000] + "\n\n[Content truncated...]"
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={
|
||||
"source": query,
|
||||
"database": connection_params["database"],
|
||||
"row_count": len(rows),
|
||||
"columns": columns,
|
||||
},
|
||||
doc_id=self.generate_doc_id(source_ref=query, content=content),
|
||||
)
|
||||
finally:
|
||||
connection.close()
|
||||
except psycopg2.Error as e:
|
||||
raise ValueError(f"PostgreSQL database error: {e}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load data from PostgreSQL: {e}")
|
||||
@@ -1,18 +1,23 @@
|
||||
import re
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class WebPageLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
url = source_content.source
|
||||
headers = kwargs.get("headers", {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
|
||||
"Accept-Language": "en-US,en;q=0.9",
|
||||
})
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
|
||||
"Accept-Language": "en-US,en;q=0.9",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=15, headers=headers)
|
||||
@@ -28,20 +33,22 @@ class WebPageLoader(BaseLoader):
|
||||
text = re.sub("\\s+\n\\s+", "\n", text)
|
||||
text = text.strip()
|
||||
|
||||
title = soup.title.string.strip() if soup.title and soup.title.string else ""
|
||||
title = (
|
||||
soup.title.string.strip() if soup.title and soup.title.string else ""
|
||||
)
|
||||
metadata = {
|
||||
"url": url,
|
||||
"title": title,
|
||||
"status_code": response.status_code,
|
||||
"content_type": response.headers.get("content-type", "")
|
||||
"content_type": response.headers.get("content-type", ""),
|
||||
}
|
||||
|
||||
return LoaderResult(
|
||||
content=text,
|
||||
source=url,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=url, content=text)
|
||||
doc_id=self.generate_doc_id(source_ref=url, content=text),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading webpage {url}: {str(e)}")
|
||||
raise ValueError(f"Error loading webpage {url}: {e!s}")
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class XMLLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
source_ref = source_content.source_ref
|
||||
@@ -11,7 +11,7 @@ class XMLLoader(BaseLoader):
|
||||
|
||||
if source_content.is_url():
|
||||
content = self._load_from_url(source_ref, kwargs)
|
||||
elif os.path.exists(source_ref):
|
||||
elif source_content.path_exists():
|
||||
content = self._load_from_file(source_ref)
|
||||
|
||||
return self._parse_xml(content, source_ref)
|
||||
@@ -19,17 +19,20 @@ class XMLLoader(BaseLoader):
|
||||
def _load_from_url(self, url: str, kwargs: dict) -> str:
|
||||
import requests
|
||||
|
||||
headers = kwargs.get("headers", {
|
||||
"Accept": "application/xml, text/xml, text/plain",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools XMLLoader)"
|
||||
})
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
"Accept": "application/xml, text/xml, text/plain",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools XMLLoader)",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching XML from URL {url}: {str(e)}")
|
||||
raise ValueError(f"Error fetching XML from URL {url}: {e!s}")
|
||||
|
||||
def _load_from_file(self, path: str) -> str:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
@@ -37,7 +40,7 @@ class XMLLoader(BaseLoader):
|
||||
|
||||
def _parse_xml(self, content: str, source_ref: str) -> LoaderResult:
|
||||
try:
|
||||
if content.strip().startswith('<'):
|
||||
if content.strip().startswith("<"):
|
||||
root = ET.fromstring(content)
|
||||
else:
|
||||
root = ET.parse(source_ref).getroot()
|
||||
@@ -57,5 +60,5 @@ class XMLLoader(BaseLoader):
|
||||
content=text,
|
||||
source=source_ref,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
|
||||
doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
|
||||
)
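
A short usage sketch for the XMLLoader, under the same assumptions about SourceContent and the module path.

from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.loaders.xml_loader import XMLLoader  # assumed module path

loader = XMLLoader()
# Accepts a local file path or a URL; URLs are fetched with the 30s timeout shown above.
result = loader.load(SourceContent("data/feed.xml"))
print(result.doc_id, result.content[:200])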
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
"""YouTube channel loader for extracting content from YouTube channels."""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class YoutubeChannelLoader(BaseLoader):
|
||||
"""Loader for YouTube channels."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load and extract content from a YouTube channel.
|
||||
|
||||
Args:
|
||||
source: The source content containing the YouTube channel URL
|
||||
|
||||
Returns:
|
||||
LoaderResult with channel content
|
||||
|
||||
Raises:
|
||||
ImportError: If required YouTube libraries aren't installed
|
||||
ValueError: If the URL is not a valid YouTube channel URL
|
||||
"""
|
||||
try:
|
||||
from pytube import Channel
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"YouTube channel support requires pytube. Install with: uv add pytube"
|
||||
)
|
||||
|
||||
channel_url = source.source
|
||||
|
||||
if not any(
|
||||
pattern in channel_url
|
||||
for pattern in [
|
||||
"youtube.com/channel/",
|
||||
"youtube.com/c/",
|
||||
"youtube.com/@",
|
||||
"youtube.com/user/",
|
||||
]
|
||||
):
|
||||
raise ValueError(f"Invalid YouTube channel URL: {channel_url}")
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"source": channel_url,
|
||||
"data_type": "youtube_channel",
|
||||
}
|
||||
|
||||
try:
|
||||
channel = Channel(channel_url)
|
||||
|
||||
metadata["channel_name"] = channel.channel_name
|
||||
metadata["channel_id"] = channel.channel_id
|
||||
|
||||
max_videos = kwargs.get("max_videos", 10)
|
||||
video_urls = list(channel.video_urls)[:max_videos]
|
||||
metadata["num_videos_loaded"] = len(video_urls)
|
||||
metadata["total_videos"] = len(list(channel.video_urls))
|
||||
|
||||
content_parts = [
|
||||
f"YouTube Channel: {channel.channel_name}",
|
||||
f"Channel ID: {channel.channel_id}",
|
||||
f"Total Videos: {metadata['total_videos']}",
|
||||
f"Videos Loaded: {metadata['num_videos_loaded']}",
|
||||
"\n--- Video Summaries ---\n",
|
||||
]
|
||||
|
||||
try:
|
||||
from pytube import YouTube
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
|
||||
for i, video_url in enumerate(video_urls, 1):
|
||||
try:
|
||||
video_id = self._extract_video_id(video_url)
|
||||
if not video_id:
|
||||
continue
|
||||
yt = YouTube(video_url)
|
||||
title = yt.title or f"Video {i}"
|
||||
description = (
|
||||
yt.description[:200] if yt.description else "No description"
|
||||
)
|
||||
|
||||
content_parts.append(f"\n{i}. {title}")
|
||||
content_parts.append(f" URL: {video_url}")
|
||||
content_parts.append(f" Description: {description}...")
|
||||
|
||||
try:
|
||||
api = YouTubeTranscriptApi()
|
||||
transcript_list = api.list(video_id)
|
||||
transcript = None
|
||||
|
||||
try:
|
||||
transcript = transcript_list.find_transcript(["en"])
|
||||
except:
|
||||
try:
|
||||
transcript = (
|
||||
transcript_list.find_generated_transcript(
|
||||
["en"]
|
||||
)
|
||||
)
|
||||
except:
|
||||
transcript = next(iter(transcript_list), None)
|
||||
|
||||
if transcript:
|
||||
transcript_data = transcript.fetch()
|
||||
text_parts = []
|
||||
char_count = 0
|
||||
for entry in transcript_data:
|
||||
text = (
|
||||
entry.text.strip()
|
||||
if hasattr(entry, "text")
|
||||
else ""
|
||||
)
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
char_count += len(text)
|
||||
if char_count > 500:
|
||||
break
|
||||
|
||||
if text_parts:
|
||||
preview = " ".join(text_parts)[:500]
|
||||
content_parts.append(
|
||||
f" Transcript Preview: {preview}..."
|
||||
)
|
||||
except:
|
||||
content_parts.append(" Transcript: Not available")
|
||||
|
||||
except Exception as e:
|
||||
content_parts.append(f"\n{i}. Error loading video: {e!s}")
|
||||
|
||||
except ImportError:
|
||||
for i, video_url in enumerate(video_urls, 1):
|
||||
content_parts.append(f"\n{i}. {video_url}")
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Unable to load YouTube channel {channel_url}: {e!s}"
|
||||
) from e
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
source=channel_url,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=channel_url, content=content),
|
||||
)
|
||||
|
||||
def _extract_video_id(self, url: str) -> str | None:
|
||||
"""Extract video ID from YouTube URL."""
|
||||
patterns = [
|
||||
r"(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)",
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
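
A hedged usage sketch for the channel loader above; the import path is assumed and the @handle URL is an invented example.

from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.loaders.youtube_channel_loader import YoutubeChannelLoader  # assumed path

loader = YoutubeChannelLoader()
result = loader.load(
    SourceContent("https://www.youtube.com/@examplechannel"),  # any accepted channel URL form
    max_videos=3,                                               # caps how many videos are summarized
)
print(result.metadata["channel_name"], result.metadata["num_videos_loaded"])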
|
||||
@@ -0,0 +1,134 @@
|
||||
"""YouTube video loader for extracting transcripts from YouTube videos."""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class YoutubeVideoLoader(BaseLoader):
|
||||
"""Loader for YouTube videos."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load and extract transcript from a YouTube video.
|
||||
|
||||
Args:
|
||||
source: The source content containing the YouTube URL
|
||||
|
||||
Returns:
|
||||
LoaderResult with transcript content
|
||||
|
||||
Raises:
|
||||
ImportError: If required YouTube libraries aren't installed
|
||||
ValueError: If the URL is not a valid YouTube video URL
|
||||
"""
|
||||
try:
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"YouTube support requires youtube-transcript-api. "
|
||||
"Install with: uv add youtube-transcript-api"
|
||||
)
|
||||
|
||||
video_url = source.source
|
||||
video_id = self._extract_video_id(video_url)
|
||||
|
||||
if not video_id:
|
||||
raise ValueError(f"Invalid YouTube URL: {video_url}")
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"source": video_url,
|
||||
"video_id": video_id,
|
||||
"data_type": "youtube_video",
|
||||
}
|
||||
|
||||
try:
|
||||
api = YouTubeTranscriptApi()
|
||||
transcript_list = api.list(video_id)
|
||||
|
||||
transcript = None
|
||||
try:
|
||||
transcript = transcript_list.find_transcript(["en"])
|
||||
except:
|
||||
try:
|
||||
transcript = transcript_list.find_generated_transcript(["en"])
|
||||
except:
|
||||
transcript = next(iter(transcript_list))
|
||||
|
||||
if transcript:
|
||||
metadata["language"] = transcript.language
|
||||
metadata["is_generated"] = transcript.is_generated
|
||||
|
||||
transcript_data = transcript.fetch()
|
||||
|
||||
text_content = []
|
||||
for entry in transcript_data:
|
||||
text = entry.text.strip() if hasattr(entry, "text") else ""
|
||||
if text:
|
||||
text_content.append(text)
|
||||
|
||||
content = " ".join(text_content)
|
||||
|
||||
try:
|
||||
from pytube import YouTube
|
||||
|
||||
yt = YouTube(video_url)
|
||||
metadata["title"] = yt.title
|
||||
metadata["author"] = yt.author
|
||||
metadata["length_seconds"] = yt.length
|
||||
metadata["description"] = (
|
||||
yt.description[:500] if yt.description else None
|
||||
)
|
||||
|
||||
if yt.title:
|
||||
content = f"Title: {yt.title}\n\nAuthor: {yt.author or 'Unknown'}\n\nTranscript:\n{content}"
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No transcript available for YouTube video: {video_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Unable to extract transcript from YouTube video {video_id}: {e!s}"
|
||||
) from e
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
source=video_url,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=video_url, content=content),
|
||||
)
|
||||
|
||||
def _extract_video_id(self, url: str) -> str | None:
|
||||
"""Extract video ID from various YouTube URL formats."""
|
||||
patterns = [
|
||||
r"(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)",
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
hostname = parsed.hostname
|
||||
if hostname:
|
||||
hostname_lower = hostname.lower()
|
||||
# Allow youtube.com and any subdomain of youtube.com, plus youtu.be shortener
|
||||
if (
|
||||
hostname_lower == "youtube.com"
|
||||
or hostname_lower.endswith(".youtube.com")
|
||||
or hostname_lower == "youtu.be"
|
||||
):
|
||||
query_params = parse_qs(parsed.query)
|
||||
if "v" in query_params:
|
||||
return query_params["v"][0]
|
||||
except:
|
||||
pass
|
||||
|
||||
return None
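
The equivalent sketch for single videos (import path assumed; requires youtube-transcript-api, plus pytube for the title/author enrichment).

from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.loaders.youtube_video_loader import YoutubeVideoLoader  # assumed path

loader = YoutubeVideoLoader()
result = loader.load(SourceContent("https://www.youtube.com/watch?v=dQw4w9WgXcQ"))
print(result.metadata["language"], result.metadata.get("title"))
print(result.content[:300])  # transcript text, prefixed with title/author when pytube is installed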
|
||||
@@ -1,4 +1,31 @@
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
|
||||
def compute_sha256(content: str) -> str:
|
||||
return hashlib.sha256(content.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def sanitize_metadata_for_chromadb(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Sanitize metadata to ensure ChromaDB compatibility.
|
||||
|
||||
ChromaDB only accepts str, int, float, or bool values in metadata.
|
||||
This function converts other types to strings.
|
||||
|
||||
Args:
|
||||
metadata: Dictionary of metadata to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized metadata dictionary with only ChromaDB-compatible types
|
||||
"""
|
||||
sanitized = {}
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
sanitized[key] = value
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# Convert lists/tuples to pipe-separated strings
|
||||
sanitized[key] = " | ".join(str(v) for v in value)
|
||||
else:
|
||||
# Convert other types to string
|
||||
sanitized[key] = str(value)
|
||||
return sanitized
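
A small worked example of the conversion rules implemented above.

from crewai_tools.rag.misc import sanitize_metadata_for_chromadb

raw = {
    "source": "https://example.com",       # str -> kept as-is
    "row_count": 3,                        # int -> kept as-is
    "columns": ["id", "title", "body"],    # list -> "id | title | body"
    "extra": {"nested": True},             # other types -> str(value)
}
print(sanitize_metadata_for_chromadb(raw))
# {'source': 'https://example.com', 'row_count': 3, 'columns': 'id | title | body', 'extra': "{'nested': True}"}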
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
from typing import TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai_tools.rag.misc import compute_sha256
|
||||
|
||||
@@ -34,7 +34,7 @@ class SourceContent:
|
||||
|
||||
@cached_property
|
||||
def source_ref(self) -> str:
|
||||
""""
|
||||
""" "
|
||||
Returns the source reference for the content.
|
||||
If the content is a URL or a local file, returns the source.
|
||||
Otherwise, returns the hash of the content.
|
||||
|
||||
@@ -70,6 +70,9 @@ from .oxylabs_google_search_scraper_tool.oxylabs_google_search_scraper_tool impo
|
||||
from .oxylabs_universal_scraper_tool.oxylabs_universal_scraper_tool import (
|
||||
OxylabsUniversalScraperTool,
|
||||
)
|
||||
from .parallel_tools import (
|
||||
ParallelSearchTool,
|
||||
)
|
||||
from .patronus_eval_tool import (
|
||||
PatronusEvalTool,
|
||||
PatronusLocalEvaluatorTool,
|
||||
@@ -122,6 +125,3 @@ from .youtube_channel_search_tool.youtube_channel_search_tool import (
|
||||
)
|
||||
from .youtube_video_search_tool.youtube_video_search_tool import YoutubeVideoSearchTool
|
||||
from .zapier_action_tool.zapier_action_tool import ZapierActionTools
|
||||
from .parallel_tools import (
|
||||
ParallelSearchTool,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import secrets
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from openai import OpenAI
|
||||
@@ -28,20 +28,22 @@ class AIMindTool(BaseTool):
|
||||
"and Google BigQuery. "
|
||||
"Input should be a question in natural language."
|
||||
)
|
||||
args_schema: Type[BaseModel] = AIMindToolInputSchema
|
||||
api_key: Optional[str] = None
|
||||
datasources: Optional[List[Dict[str, Any]]] = None
|
||||
mind_name: Optional[str] = None
|
||||
package_dependencies: List[str] = ["minds-sdk"]
|
||||
env_vars: List[EnvVar] = [
|
||||
args_schema: type[BaseModel] = AIMindToolInputSchema
|
||||
api_key: str | None = None
|
||||
datasources: list[dict[str, Any]] | None = None
|
||||
mind_name: str | None = None
|
||||
package_dependencies: list[str] = ["minds-sdk"]
|
||||
env_vars: list[EnvVar] = [
|
||||
EnvVar(name="MINDS_API_KEY", description="API key for AI-Minds", required=True),
|
||||
]
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, **kwargs):
|
||||
def __init__(self, api_key: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.api_key = api_key or os.getenv("MINDS_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("API key must be provided either through constructor or MINDS_API_KEY environment variable")
|
||||
raise ValueError(
|
||||
"API key must be provided either through constructor or MINDS_API_KEY environment variable"
|
||||
)
|
||||
|
||||
try:
|
||||
from minds.client import Client # type: ignore
|
||||
@@ -74,13 +76,12 @@ class AIMindTool(BaseTool):
|
||||
|
||||
self.mind_name = mind.name
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str
|
||||
):
|
||||
def _run(self, query: str):
|
||||
# Run the query on the AI-Mind.
|
||||
# The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.
|
||||
openai_client = OpenAI(base_url=AIMindToolConstants.MINDS_API_BASE_URL, api_key=self.api_key)
|
||||
openai_client = OpenAI(
|
||||
base_url=AIMindToolConstants.MINDS_API_BASE_URL, api_key=self.api_key
|
||||
)
|
||||
|
||||
completion = openai_client.chat.completions.create(
|
||||
model=self.mind_name,
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import Field
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
import os
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_apify import ApifyActorsTool as _ApifyActorsTool
|
||||
|
||||
|
||||
class ApifyActorsTool(BaseTool):
|
||||
env_vars: List[EnvVar] = [
|
||||
EnvVar(name="APIFY_API_TOKEN", description="API token for Apify platform access", required=True),
|
||||
env_vars: ClassVar[list[EnvVar]] = [
|
||||
EnvVar(
|
||||
name="APIFY_API_TOKEN",
|
||||
description="API token for Apify platform access",
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
"""Tool that runs Apify Actors.
|
||||
|
||||
@@ -40,15 +46,10 @@ class ApifyActorsTool(BaseTool):
|
||||
print(f"URL: {result['metadata']['url']}")
|
||||
print(f"Content: {result.get('markdown', 'N/A')[:100]}...")
|
||||
"""
|
||||
actor_tool: '_ApifyActorsTool' = Field(description="Apify Actor Tool")
|
||||
package_dependencies: List[str] = ["langchain-apify"]
|
||||
actor_tool: "_ApifyActorsTool" = Field(description="Apify Actor Tool")
|
||||
package_dependencies: ClassVar[list[str]] = ["langchain-apify"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor_name: str,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
def __init__(self, actor_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
if not os.environ.get("APIFY_API_TOKEN"):
|
||||
msg = (
|
||||
"APIFY_API_TOKEN environment variable is not set. "
|
||||
@@ -59,11 +60,11 @@ class ApifyActorsTool(BaseTool):
|
||||
|
||||
try:
|
||||
from langchain_apify import ApifyActorsTool as _ApifyActorsTool
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import langchain_apify python package. "
|
||||
"Please install it with `pip install langchain-apify` or `uv add langchain-apify`."
|
||||
)
|
||||
) from e
|
||||
actor_tool = _ApifyActorsTool(actor_name)
|
||||
|
||||
kwargs.update(
|
||||
@@ -76,7 +77,7 @@ class ApifyActorsTool(BaseTool):
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _run(self, run_input: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
def _run(self, run_input: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Run the Actor tool with the given input.
|
||||
|
||||
Returns:
|
||||
@@ -89,8 +90,8 @@ class ApifyActorsTool(BaseTool):
|
||||
return self.actor_tool._run(run_input)
|
||||
except Exception as e:
|
||||
msg = (
|
||||
f'Failed to run ApifyActorsTool {self.name}. '
|
||||
'Please check your Apify account Actor run logs for more details.'
|
||||
f'Error: {e}'
|
||||
f"Failed to run ApifyActorsTool {self.name}. "
|
||||
"Please check your Apify account Actor run logs for more details."
|
||||
f"Error: {e}"
|
||||
)
|
||||
raise RuntimeError(msg) from e
|
||||
|
||||
@@ -1,35 +1,44 @@
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Type, List, Optional, ClassVar
|
||||
from pydantic import BaseModel, Field
|
||||
from crewai.tools import BaseTool,EnvVar
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class ArxivToolInput(BaseModel):
|
||||
search_query: str = Field(..., description="Search query for Arxiv, e.g., 'transformer neural network'")
|
||||
max_results: int = Field(5, ge=1, le=100, description="Max results to fetch; must be between 1 and 100")
|
||||
search_query: str = Field(
|
||||
..., description="Search query for Arxiv, e.g., 'transformer neural network'"
|
||||
)
|
||||
max_results: int = Field(
|
||||
5, ge=1, le=100, description="Max results to fetch; must be between 1 and 100"
|
||||
)
|
||||
|
||||
|
||||
class ArxivPaperTool(BaseTool):
|
||||
BASE_API_URL: ClassVar[str] = "http://export.arxiv.org/api/query"
|
||||
SLEEP_DURATION: ClassVar[int] = 1
|
||||
SUMMARY_TRUNCATE_LENGTH: ClassVar[int] = 300
|
||||
ATOM_NAMESPACE: ClassVar[str] = "{http://www.w3.org/2005/Atom}"
|
||||
REQUEST_TIMEOUT: ClassVar[int] = 10
|
||||
REQUEST_TIMEOUT: ClassVar[int] = 10
|
||||
name: str = "Arxiv Paper Fetcher and Downloader"
|
||||
description: str = "Fetches metadata from Arxiv based on a search query and optionally downloads PDFs."
|
||||
args_schema: Type[BaseModel] = ArxivToolInput
|
||||
model_config = {"extra": "allow"}
|
||||
package_dependencies: List[str] = ["pydantic"]
|
||||
env_vars: List[EnvVar] = []
|
||||
|
||||
def __init__(self, download_pdfs=False, save_dir="./arxiv_pdfs", use_title_as_filename=False):
|
||||
args_schema: type[BaseModel] = ArxivToolInput
|
||||
model_config = {"extra": "allow"}
|
||||
package_dependencies: list[str] = ["pydantic"]
|
||||
env_vars: list[EnvVar] = []
|
||||
|
||||
def __init__(
|
||||
self, download_pdfs=False, save_dir="./arxiv_pdfs", use_title_as_filename=False
|
||||
):
|
||||
super().__init__()
|
||||
self.download_pdfs = download_pdfs
|
||||
self.save_dir = save_dir
|
||||
@@ -38,44 +47,49 @@ class ArxivPaperTool(BaseTool):
|
||||
def _run(self, search_query: str, max_results: int = 5) -> str:
|
||||
try:
|
||||
args = ArxivToolInput(search_query=search_query, max_results=max_results)
|
||||
logger.info(f"Running Arxiv tool: query='{args.search_query}', max_results={args.max_results}, "
|
||||
f"download_pdfs={self.download_pdfs}, save_dir='{self.save_dir}', "
|
||||
f"use_title_as_filename={self.use_title_as_filename}")
|
||||
logger.info(
|
||||
f"Running Arxiv tool: query='{args.search_query}', max_results={args.max_results}, "
|
||||
f"download_pdfs={self.download_pdfs}, save_dir='{self.save_dir}', "
|
||||
f"use_title_as_filename={self.use_title_as_filename}"
|
||||
)
|
||||
|
||||
papers = self.fetch_arxiv_data(args.search_query, args.max_results)
|
||||
|
||||
if self.download_pdfs:
|
||||
save_dir = self._validate_save_path(self.save_dir)
|
||||
for paper in papers:
|
||||
if paper['pdf_url']:
|
||||
if paper["pdf_url"]:
|
||||
if self.use_title_as_filename:
|
||||
safe_title = re.sub(r'[\\/*?:"<>|]', "_", paper['title']).strip()
|
||||
filename_base = safe_title or paper['arxiv_id']
|
||||
safe_title = re.sub(
|
||||
r'[\\/*?:"<>|]', "_", paper["title"]
|
||||
).strip()
|
||||
filename_base = safe_title or paper["arxiv_id"]
|
||||
else:
|
||||
filename_base = paper['arxiv_id']
|
||||
filename_base = paper["arxiv_id"]
|
||||
filename = f"{filename_base[:500]}.pdf"
|
||||
save_path = Path(save_dir) / filename
|
||||
|
||||
self.download_pdf(paper['pdf_url'], save_path)
|
||||
self.download_pdf(paper["pdf_url"], save_path)
|
||||
time.sleep(self.SLEEP_DURATION)
|
||||
|
||||
results = [self._format_paper_result(p) for p in papers]
|
||||
return "\n\n" + "-" * 80 + "\n\n".join(results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ArxivTool Error: {str(e)}")
|
||||
return f"Failed to fetch or download Arxiv papers: {str(e)}"
|
||||
|
||||
logger.error(f"ArxivTool Error: {e!s}")
|
||||
return f"Failed to fetch or download Arxiv papers: {e!s}"
|
||||
|
||||
def fetch_arxiv_data(self, search_query: str, max_results: int) -> List[dict]:
|
||||
def fetch_arxiv_data(self, search_query: str, max_results: int) -> list[dict]:
|
||||
api_url = f"{self.BASE_API_URL}?search_query={urllib.parse.quote(search_query)}&start=0&max_results={max_results}"
|
||||
logger.info(f"Fetching data from Arxiv API: {api_url}")
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(api_url, timeout=self.REQUEST_TIMEOUT) as response:
|
||||
with urllib.request.urlopen(
|
||||
api_url, timeout=self.REQUEST_TIMEOUT
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"HTTP {response.status}: {response.reason}")
|
||||
data = response.read().decode('utf-8')
|
||||
data = response.read().decode("utf-8")
|
||||
except urllib.error.URLError as e:
|
||||
logger.error(f"Error fetching data from Arxiv: {e}")
|
||||
raise
|
||||
@@ -85,7 +99,7 @@ class ArxivPaperTool(BaseTool):
|
||||
|
||||
for entry in root.findall(self.ATOM_NAMESPACE + "entry"):
|
||||
raw_id = self._get_element_text(entry, "id")
|
||||
arxiv_id = raw_id.split('/')[-1].replace('.', '_') if raw_id else "unknown"
|
||||
arxiv_id = raw_id.split("/")[-1].replace(".", "_") if raw_id else "unknown"
|
||||
|
||||
title = self._get_element_text(entry, "title") or "No Title"
|
||||
summary = self._get_element_text(entry, "summary") or "No Summary"
|
||||
@@ -97,41 +111,48 @@ class ArxivPaperTool(BaseTool):
|
||||
|
||||
pdf_url = self._extract_pdf_url(entry)
|
||||
|
||||
papers.append({
|
||||
"arxiv_id": arxiv_id,
|
||||
"title": title,
|
||||
"summary": summary,
|
||||
"authors": authors,
|
||||
"published_date": published,
|
||||
"pdf_url": pdf_url
|
||||
})
|
||||
papers.append(
|
||||
{
|
||||
"arxiv_id": arxiv_id,
|
||||
"title": title,
|
||||
"summary": summary,
|
||||
"authors": authors,
|
||||
"published_date": published,
|
||||
"pdf_url": pdf_url,
|
||||
}
|
||||
)
|
||||
|
||||
return papers
|
||||
|
||||
@staticmethod
|
||||
def _get_element_text(entry: ET.Element, element_name: str) -> Optional[str]:
|
||||
elem = entry.find(f'{ArxivPaperTool.ATOM_NAMESPACE}{element_name}')
|
||||
def _get_element_text(entry: ET.Element, element_name: str) -> str | None:
|
||||
elem = entry.find(f"{ArxivPaperTool.ATOM_NAMESPACE}{element_name}")
|
||||
return elem.text.strip() if elem is not None and elem.text else None
|
||||
|
||||
def _extract_pdf_url(self, entry: ET.Element) -> Optional[str]:
|
||||
def _extract_pdf_url(self, entry: ET.Element) -> str | None:
|
||||
for link in entry.findall(self.ATOM_NAMESPACE + "link"):
|
||||
if link.attrib.get('title', '').lower() == 'pdf':
|
||||
return link.attrib.get('href')
|
||||
if link.attrib.get("title", "").lower() == "pdf":
|
||||
return link.attrib.get("href")
|
||||
for link in entry.findall(self.ATOM_NAMESPACE + "link"):
|
||||
href = link.attrib.get('href')
|
||||
if href and 'pdf' in href:
|
||||
href = link.attrib.get("href")
|
||||
if href and "pdf" in href:
|
||||
return href
|
||||
return None
|
||||
|
||||
def _format_paper_result(self, paper: dict) -> str:
|
||||
summary = (paper['summary'][:self.SUMMARY_TRUNCATE_LENGTH] + '...') \
|
||||
if len(paper['summary']) > self.SUMMARY_TRUNCATE_LENGTH else paper['summary']
|
||||
authors_str = ', '.join(paper['authors'])
|
||||
return (f"Title: {paper['title']}\n"
|
||||
f"Authors: {authors_str}\n"
|
||||
f"Published: {paper['published_date']}\n"
|
||||
f"PDF: {paper['pdf_url'] or 'N/A'}\n"
|
||||
f"Summary: {summary}")
|
||||
summary = (
|
||||
(paper["summary"][: self.SUMMARY_TRUNCATE_LENGTH] + "...")
|
||||
if len(paper["summary"]) > self.SUMMARY_TRUNCATE_LENGTH
|
||||
else paper["summary"]
|
||||
)
|
||||
authors_str = ", ".join(paper["authors"])
|
||||
return (
|
||||
f"Title: {paper['title']}\n"
|
||||
f"Authors: {authors_str}\n"
|
||||
f"Published: {paper['published_date']}\n"
|
||||
f"PDF: {paper['pdf_url'] or 'N/A'}\n"
|
||||
f"Summary: {summary}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_save_path(path: str) -> Path:
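
Following the ArxivPaperTool changes, a short usage sketch; the top-level import is the one the tests below use, and run() is assumed to be the usual BaseTool wrapper around _run().

from crewai_tools import ArxivPaperTool

tool = ArxivPaperTool(download_pdfs=True, save_dir="./arxiv_pdfs")
print(tool.run(search_query="transformer neural network", max_results=3))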
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
import pytest
|
||||
import urllib.error
|
||||
from unittest.mock import patch, MagicMock, mock_open
|
||||
from pathlib import Path
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai_tools import ArxivPaperTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return ArxivPaperTool(download_pdfs=False)
|
||||
|
||||
|
||||
def mock_arxiv_response():
|
||||
return '''<?xml version="1.0" encoding="UTF-8"?>
|
||||
return """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<feed xmlns="http://www.w3.org/2005/Atom">
|
||||
<entry>
|
||||
<id>http://arxiv.org/abs/1234.5678</id>
|
||||
@@ -20,7 +23,8 @@ def mock_arxiv_response():
|
||||
<author><name>John Doe</name></author>
|
||||
<link title="pdf" href="http://arxiv.org/pdf/1234.5678.pdf"/>
|
||||
</entry>
|
||||
</feed>'''
|
||||
</feed>"""
|
||||
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_fetch_arxiv_data(mock_urlopen, tool):
|
||||
@@ -31,24 +35,30 @@ def test_fetch_arxiv_data(mock_urlopen, tool):
|
||||
|
||||
results = tool.fetch_arxiv_data("transformer", 1)
|
||||
assert isinstance(results, list)
|
||||
assert results[0]['title'] == "Sample Paper"
|
||||
assert results[0]["title"] == "Sample Paper"
|
||||
|
||||
|
||||
@patch("urllib.request.urlopen", side_effect=urllib.error.URLError("Timeout"))
|
||||
def test_fetch_arxiv_data_network_error(mock_urlopen, tool):
|
||||
with pytest.raises(urllib.error.URLError):
|
||||
tool.fetch_arxiv_data("transformer", 1)
|
||||
|
||||
|
||||
@patch("urllib.request.urlretrieve")
|
||||
def test_download_pdf_success(mock_urlretrieve):
|
||||
tool = ArxivPaperTool()
|
||||
tool.download_pdf("http://arxiv.org/pdf/1234.5678.pdf", Path("test.pdf"))
|
||||
mock_urlretrieve.assert_called_once()
|
||||
|
||||
|
||||
@patch("urllib.request.urlretrieve", side_effect=OSError("Permission denied"))
|
||||
def test_download_pdf_oserror(mock_urlretrieve):
|
||||
tool = ArxivPaperTool()
|
||||
with pytest.raises(OSError):
|
||||
tool.download_pdf("http://arxiv.org/pdf/1234.5678.pdf", Path("/restricted/test.pdf"))
|
||||
tool.download_pdf(
|
||||
"http://arxiv.org/pdf/1234.5678.pdf", Path("/restricted/test.pdf")
|
||||
)
|
||||
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
@patch("urllib.request.urlretrieve")
|
||||
@@ -63,6 +73,7 @@ def test_run_with_download(mock_urlretrieve, mock_urlopen):
|
||||
assert "Title: Sample Paper" in output
|
||||
mock_urlretrieve.assert_called_once()
|
||||
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_run_no_download(mock_urlopen):
|
||||
mock_response = MagicMock()
|
||||
@@ -74,12 +85,14 @@ def test_run_no_download(mock_urlopen):
|
||||
result = tool._run("transformer", 1)
|
||||
assert "Title: Sample Paper" in result
|
||||
|
||||
|
||||
@patch("pathlib.Path.mkdir")
|
||||
def test_validate_save_path_creates_directory(mock_mkdir):
|
||||
path = ArxivPaperTool._validate_save_path("new_folder")
|
||||
mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
|
||||
assert isinstance(path, Path)
|
||||
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_run_handles_exception(mock_urlopen):
|
||||
mock_urlopen.side_effect = Exception("API failure")
|
||||
@@ -98,16 +111,20 @@ def test_invalid_xml_response(mock_urlopen, tool):
|
||||
with pytest.raises(ET.ParseError):
|
||||
tool.fetch_arxiv_data("quantum", 1)
|
||||
|
||||
|
||||
@patch.object(ArxivPaperTool, "fetch_arxiv_data")
|
||||
def test_run_with_max_results(mock_fetch, tool):
|
||||
mock_fetch.return_value = [{
|
||||
"arxiv_id": f"test_{i}",
|
||||
"title": f"Title {i}",
|
||||
"summary": "Summary",
|
||||
"authors": ["Author"],
|
||||
"published_date": "2023-01-01",
|
||||
"pdf_url": None
|
||||
} for i in range(100)]
|
||||
mock_fetch.return_value = [
|
||||
{
|
||||
"arxiv_id": f"test_{i}",
|
||||
"title": f"Title {i}",
|
||||
"summary": "Summary",
|
||||
"authors": ["Author"],
|
||||
"published_date": "2023-01-01",
|
||||
"pdf_url": None,
|
||||
}
|
||||
for i in range(100)
|
||||
]
|
||||
|
||||
result = tool._run(search_query="test", max_results=100)
|
||||
assert result.count("Title:") == 100
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
from typing import Any, ClassVar, List, Optional, Type
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import requests
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
@@ -41,15 +41,17 @@ class BraveSearchTool(BaseTool):
|
||||
description: str = (
|
||||
"A tool that can be used to search the internet with a search_query."
|
||||
)
|
||||
args_schema: Type[BaseModel] = BraveSearchToolSchema
|
||||
args_schema: type[BaseModel] = BraveSearchToolSchema
|
||||
search_url: str = "https://api.search.brave.com/res/v1/web/search"
|
||||
country: Optional[str] = ""
|
||||
country: str | None = ""
|
||||
n_results: int = 10
|
||||
save_file: bool = False
|
||||
_last_request_time: ClassVar[float] = 0
|
||||
_min_request_interval: ClassVar[float] = 1.0 # seconds
|
||||
env_vars: List[EnvVar] = [
|
||||
EnvVar(name="BRAVE_API_KEY", description="API key for Brave Search", required=True),
|
||||
env_vars: ClassVar[list[EnvVar]] = [
|
||||
EnvVar(
|
||||
name="BRAVE_API_KEY", description="API key for Brave Search", required=True
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -87,7 +89,9 @@ class BraveSearchTool(BaseTool):
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
response = requests.get(self.search_url, headers=headers, params=payload)
|
||||
response = requests.get(
|
||||
self.search_url, headers=headers, params=payload, timeout=30
|
||||
)
|
||||
response.raise_for_status() # Handle non-200 responses
|
||||
results = response.json()
|
||||
|
||||
@@ -111,11 +115,10 @@ class BraveSearchTool(BaseTool):
|
||||
|
||||
content = "\n".join(string)
|
||||
except requests.RequestException as e:
|
||||
return f"Error performing search: {str(e)}"
|
||||
return f"Error performing search: {e!s}"
|
||||
except KeyError as e:
|
||||
return f"Error parsing search results: {str(e)}"
|
||||
return f"Error parsing search results: {e!s}"
|
||||
if save_file:
|
||||
_save_results_to_file(content)
|
||||
return f"\nSearch results: {content}\n"
|
||||
else:
|
||||
return content
|
||||
return content
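
A hedged usage sketch for BraveSearchTool; the top-level export and the search_query argument name are assumptions (the input schema is not shown in this hunk), and BRAVE_API_KEY must be set.

import os

from crewai_tools import BraveSearchTool  # assumed top-level export

os.environ.setdefault("BRAVE_API_KEY", "<your-key>")
tool = BraveSearchTool(country="us", n_results=5, save_file=False)
print(tool.run(search_query="crewai agents"))  # argument name assumed from the tool description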
|
||||
|
||||
@@ -2,8 +2,4 @@ from .brightdata_dataset import BrightDataDatasetTool
|
||||
from .brightdata_serp import BrightDataSearchTool
|
||||
from .brightdata_unlocker import BrightDataWebUnlockerTool
|
||||
|
||||
__all__ = [
|
||||
"BrightDataDatasetTool",
|
||||
"BrightDataSearchTool",
|
||||
"BrightDataWebUnlockerTool"
|
||||
]
|
||||
__all__ = ["BrightDataDatasetTool", "BrightDataSearchTool", "BrightDataWebUnlockerTool"]
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BrightDataConfig(BaseModel):
|
||||
API_URL: str = "https://api.brightdata.com"
|
||||
DEFAULT_TIMEOUT: int = 600
|
||||
@@ -16,8 +17,12 @@ class BrightDataConfig(BaseModel):
|
||||
return cls(
|
||||
API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com"),
|
||||
DEFAULT_TIMEOUT=int(os.environ.get("BRIGHTDATA_DEFAULT_TIMEOUT", "600")),
|
||||
DEFAULT_POLLING_INTERVAL=int(os.environ.get("BRIGHTDATA_DEFAULT_POLLING_INTERVAL", "1"))
|
||||
DEFAULT_POLLING_INTERVAL=int(
|
||||
os.environ.get("BRIGHTDATA_DEFAULT_POLLING_INTERVAL", "1")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class BrightDataDatasetToolException(Exception):
|
||||
"""Exception raised for custom error in the application."""
|
||||
|
||||
@@ -43,15 +48,16 @@ class BrightDataDatasetToolSchema(BaseModel):
|
||||
"""
|
||||
|
||||
dataset_type: str = Field(..., description="The Bright Data Dataset Type")
|
||||
format: Optional[str] = Field(
|
||||
format: str | None = Field(
|
||||
default="json", description="Response format (json by default)"
|
||||
)
|
||||
url: str = Field(..., description="The URL to extract data from")
|
||||
zipcode: Optional[str] = Field(default=None, description="Optional zipcode")
|
||||
additional_params: Optional[Dict[str, Any]] = Field(
|
||||
zipcode: str | None = Field(default=None, description="Optional zipcode")
|
||||
additional_params: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional params if any"
|
||||
)
|
||||
|
||||
|
||||
config = BrightDataConfig.from_env()
|
||||
|
||||
BRIGHTDATA_API_URL = config.API_URL
|
||||
@@ -404,14 +410,21 @@ class BrightDataDatasetTool(BaseTool):
|
||||
|
||||
name: str = "Bright Data Dataset Tool"
|
||||
description: str = "Scrapes structured data using Bright Data Dataset API from a URL and optional input parameters"
|
||||
args_schema: Type[BaseModel] = BrightDataDatasetToolSchema
|
||||
dataset_type: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
args_schema: type[BaseModel] = BrightDataDatasetToolSchema
|
||||
dataset_type: str | None = None
|
||||
url: str | None = None
|
||||
format: str = "json"
|
||||
zipcode: Optional[str] = None
|
||||
additional_params: Optional[Dict[str, Any]] = None
|
||||
zipcode: str | None = None
|
||||
additional_params: dict[str, Any] | None = None
|
||||
|
||||
def __init__(self, dataset_type: str = None, url: str = None, format: str = "json", zipcode: str = None, additional_params: Dict[str, Any] = None):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_type: str | None = None,
|
||||
url: str | None = None,
|
||||
format: str = "json",
|
||||
zipcode: str | None = None,
|
||||
additional_params: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_type = dataset_type
|
||||
self.url = url
|
||||
@@ -427,10 +440,10 @@ class BrightDataDatasetTool(BaseTool):
|
||||
dataset_type: str,
|
||||
output_format: str,
|
||||
url: str,
|
||||
zipcode: Optional[str] = None,
|
||||
additional_params: Optional[Dict[str, Any]] = None,
|
||||
zipcode: str | None = None,
|
||||
additional_params: dict[str, Any] | None = None,
|
||||
polling_interval: int = 1,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
"""
|
||||
Asynchronously trigger and poll Bright Data dataset scraping.
|
||||
|
||||
@@ -509,7 +522,7 @@ class BrightDataDatasetTool(BaseTool):
|
||||
if status_data.get("status") == "ready":
|
||||
print("Job is ready")
|
||||
break
|
||||
elif status_data.get("status") == "error":
|
||||
if status_data.get("status") == "error":
|
||||
raise BrightDataDatasetToolException(
|
||||
f"Job failed: {status_data}", 0
|
||||
)
|
||||
@@ -530,7 +543,15 @@ class BrightDataDatasetTool(BaseTool):
|
||||
|
||||
return await snapshot_response.text()
|
||||
|
||||
def _run(self, url: str = None, dataset_type: str = None, format: str = None, zipcode: str = None, additional_params: Dict[str, Any] = None, **kwargs: Any) -> Any:
|
||||
def _run(
|
||||
self,
|
||||
url: str | None = None,
|
||||
dataset_type: str | None = None,
|
||||
format: str | None = None,
|
||||
zipcode: str | None = None,
|
||||
additional_params: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
dataset_type = dataset_type or self.dataset_type
|
||||
output_format = format or self.format
|
||||
url = url or self.url
|
||||
@@ -538,7 +559,9 @@ class BrightDataDatasetTool(BaseTool):
|
||||
additional_params = additional_params or self.additional_params
|
||||
|
||||
if not dataset_type:
|
||||
raise ValueError("dataset_type is required either in constructor or method call")
|
||||
raise ValueError(
|
||||
"dataset_type is required either in constructor or method call"
|
||||
)
|
||||
if not url:
|
||||
raise ValueError("url is required either in constructor or method call")
|
||||
|
||||
@@ -563,8 +586,10 @@ class BrightDataDatasetTool(BaseTool):
|
||||
)
|
||||
)
|
||||
except TimeoutError as e:
|
||||
return f"Timeout Exception occured in method : get_dataset_data_async. Details - {str(e)}"
|
||||
return f"Timeout Exception occured in method : get_dataset_data_async. Details - {e!s}"
|
||||
except BrightDataDatasetToolException as e:
|
||||
return f"Exception occured in method : get_dataset_data_async. Details - {str(e)}"
|
||||
return (
|
||||
f"Exception occured in method : get_dataset_data_async. Details - {e!s}"
|
||||
)
|
||||
except Exception as e:
|
||||
return f"Bright Data API error: {str(e)}"
|
||||
return f"Bright Data API error: {e!s}"
|
||||
|
||||
@@ -1,20 +1,24 @@
|
||||
import os
|
||||
import urllib.parse
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BrightDataConfig(BaseModel):
|
||||
API_URL: str = "https://api.brightdata.com/request"
|
||||
|
||||
@classmethod
|
||||
def from_env(cls):
|
||||
return cls(
|
||||
API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com/request")
|
||||
API_URL=os.environ.get(
|
||||
"BRIGHTDATA_API_URL", "https://api.brightdata.com/request"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class BrightDataSearchToolSchema(BaseModel):
|
||||
"""
|
||||
Schema that defines the input arguments for the BrightDataSearchToolSchema.
|
||||
@@ -30,27 +34,27 @@ class BrightDataSearchToolSchema(BaseModel):
|
||||
"""
|
||||
|
||||
query: str = Field(..., description="Search query to perform")
|
||||
search_engine: Optional[str] = Field(
|
||||
search_engine: str | None = Field(
|
||||
default="google",
|
||||
description="Search engine domain (e.g., 'google', 'bing', 'yandex')",
|
||||
)
|
||||
country: Optional[str] = Field(
|
||||
country: str | None = Field(
|
||||
default="us",
|
||||
description="Two-letter country code for geo-targeting (e.g., 'us', 'gb')",
|
||||
)
|
||||
language: Optional[str] = Field(
|
||||
language: str | None = Field(
|
||||
default="en",
|
||||
description="Language code (e.g., 'en', 'es') used in the query URL",
|
||||
)
|
||||
search_type: Optional[str] = Field(
|
||||
search_type: str | None = Field(
|
||||
default=None,
|
||||
description="Type of search (e.g., 'isch' for images, 'nws' for news)",
|
||||
)
|
||||
device_type: Optional[str] = Field(
|
||||
device_type: str | None = Field(
|
||||
default="desktop",
|
||||
description="Device type to simulate (e.g., 'mobile', 'desktop', 'ios')",
|
||||
)
|
||||
parse_results: Optional[bool] = Field(
|
||||
parse_results: bool | None = Field(
|
||||
default=True,
|
||||
description="Whether to parse and return JSON (True) or raw HTML/text (False)",
|
||||
)
|
||||
@@ -75,20 +79,29 @@ class BrightDataSearchTool(BaseTool):
|
||||
|
||||
name: str = "Bright Data SERP Search"
|
||||
description: str = "Tool to perform web search using Bright Data SERP API."
|
||||
args_schema: Type[BaseModel] = BrightDataSearchToolSchema
|
||||
args_schema: type[BaseModel] = BrightDataSearchToolSchema
|
||||
_config = BrightDataConfig.from_env()
|
||||
base_url: str = ""
|
||||
api_key: str = ""
|
||||
zone: str = ""
|
||||
query: Optional[str] = None
|
||||
query: str | None = None
|
||||
search_engine: str = "google"
|
||||
country: str = "us"
|
||||
language: str = "en"
|
||||
search_type: Optional[str] = None
|
||||
search_type: str | None = None
|
||||
device_type: str = "desktop"
|
||||
parse_results: bool = True
|
||||
|
||||
def __init__(self, query: str = None, search_engine: str = "google", country: str = "us", language: str = "en", search_type: str = None, device_type: str = "desktop", parse_results: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
query: str | None = None,
|
||||
search_engine: str = "google",
|
||||
country: str = "us",
|
||||
language: str = "en",
|
||||
search_type: str | None = None,
|
||||
device_type: str = "desktop",
|
||||
parse_results: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.base_url = self._config.API_URL
|
||||
self.query = query
|
||||
@@ -109,11 +122,21 @@ class BrightDataSearchTool(BaseTool):
|
||||
def get_search_url(self, engine: str, query: str):
|
||||
if engine == "yandex":
|
||||
return f"https://yandex.com/search/?text=${query}"
|
||||
elif engine == "bing":
|
||||
if engine == "bing":
|
||||
return f"https://www.bing.com/search?q=${query}"
|
||||
return f"https://www.google.com/search?q=${query}"
|
||||
|
||||
def _run(self, query: str = None, search_engine: str = None, country: str = None, language: str = None, search_type: str = None, device_type: str = None, parse_results: bool = None, **kwargs) -> Any:
|
||||
def _run(
|
||||
self,
|
||||
query: str | None = None,
|
||||
search_engine: str | None = None,
|
||||
country: str | None = None,
|
||||
language: str | None = None,
|
||||
search_type: str | None = None,
|
||||
device_type: str | None = None,
|
||||
parse_results: bool | None = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
Executes a search query using Bright Data SERP API and returns results.
|
||||
|
||||
@@ -137,7 +160,9 @@ class BrightDataSearchTool(BaseTool):
|
||||
language = language or self.language
|
||||
search_type = search_type or self.search_type
|
||||
device_type = device_type or self.device_type
|
||||
parse_results = parse_results if parse_results is not None else self.parse_results
|
||||
parse_results = (
|
||||
parse_results if parse_results is not None else self.parse_results
|
||||
)
|
||||
results_count = kwargs.get("results_count", "10")
|
||||
|
||||
# Validate required parameters
|
||||
@@ -161,7 +186,7 @@ class BrightDataSearchTool(BaseTool):
|
||||
params.append(f"num={results_count}")
|
||||
|
||||
if parse_results:
|
||||
params.append(f"brd_json=1")
|
||||
params.append("brd_json=1")
|
||||
|
||||
if search_type:
|
||||
if search_type == "jobs":
|
||||
@@ -202,6 +227,6 @@ class BrightDataSearchTool(BaseTool):
|
||||
return response.text
|
||||
|
||||
except requests.RequestException as e:
|
||||
return f"Error performing BrightData search: {str(e)}"
|
||||
return f"Error performing BrightData search: {e!s}"
|
||||
except Exception as e:
|
||||
return f"Error fetching results: {str(e)}"
|
||||
return f"Error fetching results: {e!s}"
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import os
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BrightDataConfig(BaseModel):
|
||||
API_URL: str = "https://api.brightdata.com/request"
|
||||
|
||||
@classmethod
|
||||
def from_env(cls):
|
||||
return cls(
|
||||
API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com/request")
|
||||
API_URL=os.environ.get(
|
||||
"BRIGHTDATA_API_URL", "https://api.brightdata.com/request"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class BrightDataUnlockerToolSchema(BaseModel):
|
||||
"""
|
||||
Pydantic schema for input parameters used by the BrightDataWebUnlockerTool.
|
||||
@@ -28,10 +32,10 @@ class BrightDataUnlockerToolSchema(BaseModel):
|
||||
"""
|
||||
|
||||
url: str = Field(..., description="URL to perform the web scraping")
|
||||
format: Optional[str] = Field(
|
||||
format: str | None = Field(
|
||||
default="raw", description="Response format (raw is standard)"
|
||||
)
|
||||
data_format: Optional[str] = Field(
|
||||
data_format: str | None = Field(
|
||||
default="markdown", description="Response data format (html by default)"
|
||||
)
|
||||
|
||||
@@ -59,16 +63,18 @@ class BrightDataWebUnlockerTool(BaseTool):
|
||||
|
||||
name: str = "Bright Data Web Unlocker Scraping"
|
||||
description: str = "Tool to perform web scraping using Bright Data Web Unlocker"
|
||||
args_schema: Type[BaseModel] = BrightDataUnlockerToolSchema
|
||||
args_schema: type[BaseModel] = BrightDataUnlockerToolSchema
|
||||
_config = BrightDataConfig.from_env()
|
||||
base_url: str = ""
|
||||
api_key: str = ""
|
||||
zone: str = ""
|
||||
url: Optional[str] = None
|
||||
url: str | None = None
|
||||
format: str = "raw"
|
||||
data_format: str = "markdown"
|
||||
|
||||
def __init__(self, url: str = None, format: str = "raw", data_format: str = "markdown"):
|
||||
def __init__(
|
||||
self, url: str | None = None, format: str = "raw", data_format: str = "markdown"
|
||||
):
|
||||
super().__init__()
|
||||
self.base_url = self._config.API_URL
|
||||
self.url = url
|
||||
@@ -82,7 +88,13 @@ class BrightDataWebUnlockerTool(BaseTool):
|
||||
if not self.zone:
|
||||
raise ValueError("BRIGHT_DATA_ZONE environment variable is required.")
|
||||
|
||||
def _run(self, url: str = None, format: str = None, data_format: str = None, **kwargs: Any) -> Any:
|
||||
def _run(
|
||||
self,
|
||||
url: str | None = None,
|
||||
format: str | None = None,
|
||||
data_format: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
url = url or self.url
|
||||
format = format or self.format
|
||||
data_format = data_format or self.data_format
|
||||
@@ -119,4 +131,4 @@ class BrightDataWebUnlockerTool(BaseTool):
|
||||
except requests.RequestException as e:
|
||||
return f"HTTP Error performing BrightData Web Unlocker Scrape: {e}\nResponse: {getattr(e.response, 'text', '')}"
|
||||
except Exception as e:
|
||||
return f"Error fetching results: {str(e)}"
|
||||
return f"Error fetching results: {e!s}"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Optional, Type, List
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -12,26 +12,34 @@ class BrowserbaseLoadToolSchema(BaseModel):
|
||||
class BrowserbaseLoadTool(BaseTool):
|
||||
name: str = "Browserbase web load tool"
|
||||
description: str = "Load webpages url in a headless browser using Browserbase and return the contents"
|
||||
args_schema: Type[BaseModel] = BrowserbaseLoadToolSchema
|
||||
api_key: Optional[str] = os.getenv("BROWSERBASE_API_KEY")
|
||||
project_id: Optional[str] = os.getenv("BROWSERBASE_PROJECT_ID")
|
||||
text_content: Optional[bool] = False
|
||||
session_id: Optional[str] = None
|
||||
proxy: Optional[bool] = None
|
||||
browserbase: Optional[Any] = None
|
||||
package_dependencies: List[str] = ["browserbase"]
|
||||
env_vars: List[EnvVar] = [
|
||||
EnvVar(name="BROWSERBASE_API_KEY", description="API key for Browserbase services", required=False),
|
||||
EnvVar(name="BROWSERBASE_PROJECT_ID", description="Project ID for Browserbase services", required=False),
|
||||
args_schema: type[BaseModel] = BrowserbaseLoadToolSchema
|
||||
api_key: str | None = os.getenv("BROWSERBASE_API_KEY")
|
||||
project_id: str | None = os.getenv("BROWSERBASE_PROJECT_ID")
|
||||
text_content: bool | None = False
|
||||
session_id: str | None = None
|
||||
proxy: bool | None = None
|
||||
browserbase: Any | None = None
|
||||
package_dependencies: list[str] = ["browserbase"]
|
||||
env_vars: list[EnvVar] = [
|
||||
EnvVar(
|
||||
name="BROWSERBASE_API_KEY",
|
||||
description="API key for Browserbase services",
|
||||
required=False,
|
||||
),
|
||||
EnvVar(
|
||||
name="BROWSERBASE_PROJECT_ID",
|
||||
description="Project ID for Browserbase services",
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
text_content: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
proxy: Optional[bool] = None,
|
||||
api_key: str | None = None,
|
||||
project_id: str | None = None,
|
||||
text_content: bool | None = False,
|
||||
session_id: str | None = None,
|
||||
proxy: bool | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -1,11 +1,4 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -31,9 +24,9 @@ class CodeDocsSearchTool(RagTool):
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a Code Docs content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = CodeDocsSearchToolSchema
|
||||
args_schema: type[BaseModel] = CodeDocsSearchToolSchema
|
||||
|
||||
def __init__(self, docs_url: Optional[str] = None, **kwargs):
|
||||
def __init__(self, docs_url: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if docs_url is not None:
|
||||
self.add(docs_url)
|
||||
@@ -42,15 +35,17 @@ class CodeDocsSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, docs_url: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(docs_url, data_type=DataType.DOCS_SITE)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
docs_url: Optional[str] = None,
|
||||
docs_url: str | None = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if docs_url is not None:
|
||||
self.add(docs_url)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(
|
||||
query=search_query, similarity_threshold=similarity_threshold, limit=limit
|
||||
)
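
A usage sketch for the updated CodeDocsSearchTool signature, showing the new optional filters from the _run change above (top-level export assumed).

from crewai_tools import CodeDocsSearchTool  # assumed top-level export

tool = CodeDocsSearchTool(docs_url="https://docs.crewai.com")  # indexes the docs site up front
# Calling _run directly to match the signature shown above; whether the public
# args schema also exposes similarity_threshold/limit is not visible in this hunk.
print(tool._run(search_query="how do I define a custom tool?", similarity_threshold=0.5, limit=3))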
|
||||
|
||||
@@ -8,17 +8,16 @@ potentially unsafe operations and importing restricted modules.
|
||||
import importlib.util
|
||||
import os
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from crewai_tools.printer import Printer
|
||||
from docker import DockerClient
|
||||
from docker import from_env as docker_from_env
|
||||
from docker.errors import ImageNotFound, NotFound
|
||||
from docker.models.containers import Container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.printer import Printer
|
||||
|
||||
|
||||
class CodeInterpreterSchema(BaseModel):
|
||||
"""Schema for defining inputs to the CodeInterpreterTool.
|
||||
@@ -32,7 +31,7 @@ class CodeInterpreterSchema(BaseModel):
|
||||
description="Python3 code used to be interpreted in the Docker container. ALWAYS PRINT the final result and the output of the code",
|
||||
)
|
||||
|
||||
libraries_used: List[str] = Field(
|
||||
libraries_used: list[str] = Field(
|
||||
...,
|
||||
description="List of libraries used in the code with proper installing names separated by commas. Example: numpy,pandas,beautifulsoup4",
|
||||
)
|
||||
@@ -74,9 +73,9 @@ class SandboxPython:
|
||||
@staticmethod
|
||||
def restricted_import(
|
||||
name: str,
|
||||
custom_globals: Optional[Dict[str, Any]] = None,
|
||||
custom_locals: Optional[Dict[str, Any]] = None,
|
||||
fromlist: Optional[List[str]] = None,
|
||||
custom_globals: dict[str, Any] | None = None,
|
||||
custom_locals: dict[str, Any] | None = None,
|
||||
fromlist: list[str] | None = None,
|
||||
level: int = 0,
|
||||
) -> ModuleType:
|
||||
"""A restricted import function that blocks importing of unsafe modules.
|
||||
@@ -99,7 +98,7 @@ class SandboxPython:
|
||||
return __import__(name, custom_globals, custom_locals, fromlist or (), level)
|
||||
|
||||
@staticmethod
|
||||
def safe_builtins() -> Dict[str, Any]:
|
||||
def safe_builtins() -> dict[str, Any]:
|
||||
"""Creates a dictionary of built-in functions with unsafe ones removed.
|
||||
|
||||
Returns:
|
||||
@@ -116,7 +115,7 @@ class SandboxPython:
|
||||
return safe_builtins
|
||||
|
||||
@staticmethod
|
||||
def exec(code: str, locals: Dict[str, Any]) -> None:
|
||||
def exec(code: str, locals: dict[str, Any]) -> None:
|
||||
"""Executes Python code in a restricted environment.
|
||||
|
||||
Args:
|
||||
@@ -136,11 +135,11 @@ class CodeInterpreterTool(BaseTool):
|
||||
|
||||
name: str = "Code Interpreter"
|
||||
description: str = "Interprets Python3 code strings with a final print statement."
|
||||
args_schema: Type[BaseModel] = CodeInterpreterSchema
|
||||
args_schema: type[BaseModel] = CodeInterpreterSchema
|
||||
default_image_tag: str = "code-interpreter:latest"
|
||||
code: Optional[str] = None
|
||||
user_dockerfile_path: Optional[str] = None
|
||||
user_docker_base_url: Optional[str] = None
|
||||
code: str | None = None
|
||||
user_dockerfile_path: str | None = None
|
||||
user_docker_base_url: str | None = None
|
||||
unsafe_mode: bool = False
|
||||
|
||||
@staticmethod
|
||||
@@ -205,10 +204,9 @@ class CodeInterpreterTool(BaseTool):
|
||||
|
||||
if self.unsafe_mode:
|
||||
return self.run_code_unsafe(code, libraries_used)
|
||||
else:
|
||||
return self.run_code_safety(code, libraries_used)
|
||||
return self.run_code_safety(code, libraries_used)
|
||||
|
||||
def _install_libraries(self, container: Container, libraries: List[str]) -> None:
|
||||
def _install_libraries(self, container: Container, libraries: list[str]) -> None:
|
||||
"""Installs required Python libraries in the Docker container.
|
||||
|
||||
Args:
|
||||
@@ -278,7 +276,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
Printer.print("Docker is not installed", color="bold_purple")
|
||||
return False
|
||||
|
||||
def run_code_safety(self, code: str, libraries_used: List[str]) -> str:
|
||||
def run_code_safety(self, code: str, libraries_used: list[str]) -> str:
|
||||
"""Runs code in the safest available environment.
|
||||
|
||||
Attempts to run code in Docker if available, falls back to a restricted
|
||||
@@ -293,10 +291,9 @@ class CodeInterpreterTool(BaseTool):
|
||||
"""
|
||||
if self._check_docker_available():
|
||||
return self.run_code_in_docker(code, libraries_used)
|
||||
else:
|
||||
return self.run_code_in_restricted_sandbox(code)
|
||||
return self.run_code_in_restricted_sandbox(code)
|
||||
|
||||
def run_code_in_docker(self, code: str, libraries_used: List[str]) -> str:
|
||||
def run_code_in_docker(self, code: str, libraries_used: list[str]) -> str:
|
||||
"""Runs Python code in a Docker container for safe isolation.
|
||||
|
||||
Creates a Docker container, installs the required libraries, executes the code,
|
||||
@@ -342,9 +339,9 @@ class CodeInterpreterTool(BaseTool):
|
||||
SandboxPython.exec(code=code, locals=exec_locals)
|
||||
return exec_locals.get("result", "No result variable found.")
|
||||
except Exception as e:
|
||||
return f"An error occurred: {str(e)}"
|
||||
return f"An error occurred: {e!s}"
|
||||
|
||||
def run_code_unsafe(self, code: str, libraries_used: List[str]) -> str:
|
||||
def run_code_unsafe(self, code: str, libraries_used: list[str]) -> str:
|
||||
"""Runs code directly on the host machine without any safety restrictions.
|
||||
|
||||
WARNING: This mode is unsafe and should only be used in trusted environments
|
||||
@@ -370,4 +367,4 @@ class CodeInterpreterTool(BaseTool):
|
||||
exec(code, {}, exec_locals)
|
||||
return exec_locals.get("result", "No result variable found.")
|
||||
except Exception as e:
|
||||
return f"An error occurred: {str(e)}"
|
||||
return f"An error occurred: {e!s}"
|
||||
|
||||
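A standalone sketch of the restricted-execution pattern SandboxPython implements above: strip unsafe builtins, exec the code, and read back a `result` variable. The exact set of blocked names here is an assumption for illustration.

import builtins

UNSAFE_BUILTINS = {"open", "exec", "eval", "compile", "__import__", "input"}  # assumed subset

def safe_builtins() -> dict[str, object]:
    # Copy the real builtins table minus the names considered unsafe.
    return {k: v for k, v in vars(builtins).items() if k not in UNSAFE_BUILTINS}

def run_restricted(code: str) -> object:
    exec_locals: dict[str, object] = {}
    exec(code, {"__builtins__": safe_builtins()}, exec_locals)
    return exec_locals.get("result", "No result variable found.")

# run_restricted("result = sum(range(5))") -> 10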
@@ -12,8 +12,12 @@ class ComposioTool(BaseTool):
|
||||
"""Wrapper for composio tools."""
|
||||
|
||||
composio_action: t.Callable
|
||||
env_vars: t.List[EnvVar] = [
|
||||
EnvVar(name="COMPOSIO_API_KEY", description="API key for Composio services", required=True),
|
||||
env_vars: list[EnvVar] = [
|
||||
EnvVar(
|
||||
name="COMPOSIO_API_KEY",
|
||||
description="API key for Composio services",
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
|
||||
def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
@@ -31,7 +35,7 @@ class ComposioTool(BaseTool):
|
||||
return
|
||||
|
||||
connections = t.cast(
|
||||
t.List[ConnectedAccountModel],
|
||||
list[ConnectedAccountModel],
|
||||
toolset.client.connected_accounts.get(),
|
||||
)
|
||||
if tool.app not in [connection.appUniqueId for connection in connections]:
|
||||
@@ -66,7 +70,7 @@ class ComposioTool(BaseTool):
|
||||
schema = action_schema.model_dump(exclude_none=True)
|
||||
entity_id = kwargs.pop("entity_id", DEFAULT_ENTITY_ID)
|
||||
|
||||
def function(**kwargs: t.Any) -> t.Dict:
|
||||
def function(**kwargs: t.Any) -> dict:
|
||||
"""Wrapper function for composio action."""
|
||||
return toolset.execute_action(
|
||||
action=Action(schema["name"]),
|
||||
@@ -93,10 +97,10 @@ class ComposioTool(BaseTool):
|
||||
def from_app(
|
||||
cls,
|
||||
*apps: t.Any,
|
||||
tags: t.Optional[t.List[str]] = None,
|
||||
use_case: t.Optional[str] = None,
|
||||
tags: list[str] | None = None,
|
||||
use_case: str | None = None,
|
||||
**kwargs: t.Any,
|
||||
) -> t.List[te.Self]:
|
||||
) -> list[te.Self]:
|
||||
"""Create toolset from an app."""
|
||||
if len(apps) == 0:
|
||||
raise ValueError("You need to provide at least one app name")
|
||||
|
||||
@@ -1,32 +1,36 @@
|
||||
from typing import Any, Optional, Type, List
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
|
||||
|
||||
class ContextualAICreateAgentSchema(BaseModel):
|
||||
"""Schema for contextual create agent tool."""
|
||||
|
||||
agent_name: str = Field(..., description="Name for the new agent")
|
||||
agent_description: str = Field(..., description="Description for the new agent")
|
||||
datastore_name: str = Field(..., description="Name for the new datastore")
|
||||
document_paths: List[str] = Field(..., description="List of file paths to upload")
|
||||
document_paths: list[str] = Field(..., description="List of file paths to upload")
|
||||
|
||||
|
||||
class ContextualAICreateAgentTool(BaseTool):
|
||||
"""Tool to create Contextual AI RAG agents with documents."""
|
||||
|
||||
|
||||
name: str = "Contextual AI Create Agent Tool"
|
||||
description: str = "Create a new Contextual AI RAG agent with documents and datastore"
|
||||
args_schema: Type[BaseModel] = ContextualAICreateAgentSchema
|
||||
|
||||
description: str = (
|
||||
"Create a new Contextual AI RAG agent with documents and datastore"
|
||||
)
|
||||
args_schema: type[BaseModel] = ContextualAICreateAgentSchema
|
||||
|
||||
api_key: str
|
||||
contextual_client: Any = None
|
||||
package_dependencies: List[str] = ["contextual-client"]
|
||||
package_dependencies: list[str] = ["contextual-client"]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
from contextual import ContextualAI
|
||||
|
||||
self.contextual_client = ContextualAI(api_key=self.api_key)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -38,34 +42,38 @@ class ContextualAICreateAgentTool(BaseTool):
|
||||
agent_name: str,
|
||||
agent_description: str,
|
||||
datastore_name: str,
|
||||
document_paths: List[str]
|
||||
document_paths: list[str],
|
||||
) -> str:
|
||||
"""Create a complete RAG pipeline with documents."""
|
||||
try:
|
||||
import os
|
||||
|
||||
|
||||
# Create datastore
|
||||
datastore = self.contextual_client.datastores.create(name=datastore_name)
|
||||
datastore_id = datastore.id
|
||||
|
||||
|
||||
# Upload documents
|
||||
document_ids = []
|
||||
for doc_path in document_paths:
|
||||
if not os.path.exists(doc_path):
|
||||
raise FileNotFoundError(f"Document not found: {doc_path}")
|
||||
|
||||
with open(doc_path, 'rb') as f:
|
||||
ingestion_result = self.contextual_client.datastores.documents.ingest(datastore_id, file=f)
|
||||
|
||||
with open(doc_path, "rb") as f:
|
||||
ingestion_result = (
|
||||
self.contextual_client.datastores.documents.ingest(
|
||||
datastore_id, file=f
|
||||
)
|
||||
)
|
||||
document_ids.append(ingestion_result.id)
|
||||
|
||||
|
||||
# Create agent
|
||||
agent = self.contextual_client.agents.create(
|
||||
name=agent_name,
|
||||
description=agent_description,
|
||||
datastore_ids=[datastore_id]
|
||||
datastore_ids=[datastore_id],
|
||||
)
|
||||
|
||||
|
||||
return f"Successfully created agent '{agent_name}' with ID: {agent.id} and datastore ID: {datastore_id}. Uploaded {len(document_ids)} documents."
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to create agent with documents: {str(e)}"
|
||||
return f"Failed to create agent with documents: {e!s}"
|
||||
|
||||
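A hypothetical end-to-end call of the create-agent tool above; the environment variable name and document path are placeholders:

import os

from crewai_tools import ContextualAICreateAgentTool

tool = ContextualAICreateAgentTool(api_key=os.environ["CONTEXTUAL_API_KEY"])
result = tool.run(
    agent_name="support-agent",
    agent_description="Answers questions about internal handbooks",
    datastore_name="support-docs",
    document_paths=["./docs/handbook.pdf"],
)
print(result)  # reports the new agent and datastore IDs on success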
@@ -1,51 +1,62 @@
|
||||
from typing import Any, Optional, Type, List
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ContextualAIParseSchema(BaseModel):
|
||||
"""Schema for contextual parse tool."""
|
||||
|
||||
file_path: str = Field(..., description="Path to the document to parse")
|
||||
parse_mode: str = Field(default="standard", description="Parsing mode")
|
||||
figure_caption_mode: str = Field(default="concise", description="Figure caption mode")
|
||||
enable_document_hierarchy: bool = Field(default=True, description="Enable document hierarchy")
|
||||
page_range: Optional[str] = Field(default=None, description="Page range to parse (e.g., '0-5')")
|
||||
output_types: List[str] = Field(default=["markdown-per-page"], description="List of output types")
|
||||
figure_caption_mode: str = Field(
|
||||
default="concise", description="Figure caption mode"
|
||||
)
|
||||
enable_document_hierarchy: bool = Field(
|
||||
default=True, description="Enable document hierarchy"
|
||||
)
|
||||
page_range: str | None = Field(
|
||||
default=None, description="Page range to parse (e.g., '0-5')"
|
||||
)
|
||||
output_types: list[str] = Field(
|
||||
default=["markdown-per-page"], description="List of output types"
|
||||
)
|
||||
|
||||
|
||||
class ContextualAIParseTool(BaseTool):
|
||||
"""Tool to parse documents using Contextual AI's parser."""
|
||||
|
||||
|
||||
name: str = "Contextual AI Document Parser"
|
||||
description: str = "Parse documents using Contextual AI's advanced document parser"
|
||||
args_schema: Type[BaseModel] = ContextualAIParseSchema
|
||||
|
||||
args_schema: type[BaseModel] = ContextualAIParseSchema
|
||||
|
||||
api_key: str
|
||||
package_dependencies: List[str] = ["contextual-client"]
|
||||
package_dependencies: list[str] = ["contextual-client"]
|
||||
|
||||
def _run(
|
||||
self,
|
||||
file_path: str,
|
||||
self,
|
||||
file_path: str,
|
||||
parse_mode: str = "standard",
|
||||
figure_caption_mode: str = "concise",
|
||||
enable_document_hierarchy: bool = True,
|
||||
page_range: Optional[str] = None,
|
||||
output_types: List[str] = ["markdown-per-page"]
|
||||
page_range: str | None = None,
|
||||
output_types: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Parse a document using Contextual AI's parser."""
|
||||
if output_types is None:
|
||||
output_types = ["markdown-per-page"]
|
||||
try:
|
||||
import requests
|
||||
import json
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
import requests
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Document not found: {file_path}")
|
||||
|
||||
base_url = "https://api.contextual.ai/v1"
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"authorization": f"Bearer {self.api_key}"
|
||||
"authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
# Submit parse job
|
||||
@@ -63,17 +74,17 @@ class ContextualAIParseTool(BaseTool):
|
||||
file = {"raw_file": fp}
|
||||
result = requests.post(url, headers=headers, data=config, files=file)
|
||||
response = json.loads(result.text)
|
||||
job_id = response['job_id']
|
||||
job_id = response["job_id"]
|
||||
|
||||
# Monitor job status
|
||||
status_url = f"{base_url}/parse/jobs/{job_id}/status"
|
||||
while True:
|
||||
result = requests.get(status_url, headers=headers)
|
||||
parse_response = json.loads(result.text)['status']
|
||||
parse_response = json.loads(result.text)["status"]
|
||||
|
||||
if parse_response == "completed":
|
||||
break
|
||||
elif parse_response == "failed":
|
||||
if parse_response == "failed":
|
||||
raise RuntimeError("Document parsing failed")
|
||||
|
||||
sleep(5)
|
||||
@@ -89,4 +100,4 @@ class ContextualAIParseTool(BaseTool):
|
||||
return json.dumps(json.loads(result.text), indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to parse document: {str(e)}"
|
||||
return f"Failed to parse document: {e!s}"
|
||||
|
||||
@@ -1,33 +1,39 @@
|
||||
from typing import Any, Optional, Type, List
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
import asyncio
|
||||
import requests
|
||||
import os
|
||||
|
||||
|
||||
class ContextualAIQuerySchema(BaseModel):
|
||||
"""Schema for contextual query tool."""
|
||||
|
||||
query: str = Field(..., description="Query to send to the Contextual AI agent.")
|
||||
agent_id: str = Field(..., description="ID of the Contextual AI agent to query")
|
||||
datastore_id: Optional[str] = Field(None, description="Optional datastore ID for document readiness verification")
|
||||
datastore_id: str | None = Field(
|
||||
None, description="Optional datastore ID for document readiness verification"
|
||||
)
|
||||
|
||||
|
||||
class ContextualAIQueryTool(BaseTool):
|
||||
"""Tool to query Contextual AI RAG agents."""
|
||||
|
||||
|
||||
name: str = "Contextual AI Query Tool"
|
||||
description: str = "Use this tool to query a Contextual AI RAG agent with access to your documents"
|
||||
args_schema: Type[BaseModel] = ContextualAIQuerySchema
|
||||
|
||||
description: str = (
|
||||
"Use this tool to query a Contextual AI RAG agent with access to your documents"
|
||||
)
|
||||
args_schema: type[BaseModel] = ContextualAIQuerySchema
|
||||
|
||||
api_key: str
|
||||
contextual_client: Any = None
|
||||
package_dependencies: List[str] = ["contextual-client"]
|
||||
package_dependencies: list[str] = ["contextual-client"]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
from contextual import ContextualAI
|
||||
|
||||
self.contextual_client = ContextualAI(api_key=self.api_key)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -41,13 +47,17 @@ class ContextualAIQueryTool(BaseTool):
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
documents = data.get('documents', [])
|
||||
return not any(doc.get('status') in ('processing', 'pending') for doc in documents)
|
||||
documents = data.get("documents", [])
|
||||
return not any(
|
||||
doc.get("status") in ("processing", "pending") for doc in documents
|
||||
)
|
||||
return True
|
||||
|
||||
async def _wait_for_documents_async(self, datastore_id: str, max_attempts: int = 20, interval: float = 30.0) -> bool:
|
||||
async def _wait_for_documents_async(
|
||||
self, datastore_id: str, max_attempts: int = 20, interval: float = 30.0
|
||||
) -> bool:
|
||||
"""Asynchronously poll until documents are ready, exiting early if possible."""
|
||||
for attempt in range(max_attempts):
|
||||
for _attempt in range(max_attempts):
|
||||
ready = await asyncio.to_thread(self._check_documents_ready, datastore_id)
|
||||
if ready:
|
||||
return True
|
||||
@@ -55,10 +65,10 @@ class ContextualAIQueryTool(BaseTool):
|
||||
print("Processing documents ...")
|
||||
return True # give up but don't fail hard
|
||||
|
||||
def _run(self, query: str, agent_id: str, datastore_id: Optional[str] = None) -> str:
|
||||
def _run(self, query: str, agent_id: str, datastore_id: str | None = None) -> str:
|
||||
if not agent_id:
|
||||
raise ValueError("Agent ID is required to query the Contextual AI agent")
|
||||
|
||||
|
||||
if datastore_id:
|
||||
ready = self._check_documents_ready(datastore_id)
|
||||
if not ready:
|
||||
@@ -69,31 +79,42 @@ class ContextualAIQueryTool(BaseTool):
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
# Already inside an event loop
|
||||
# Already inside an event loop
|
||||
try:
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply(loop)
|
||||
loop.run_until_complete(self._wait_for_documents_async(datastore_id))
|
||||
loop.run_until_complete(
|
||||
self._wait_for_documents_async(datastore_id)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to apply nest_asyncio: {str(e)}")
|
||||
print(f"Failed to apply nest_asyncio: {e!s}")
|
||||
else:
|
||||
asyncio.run(self._wait_for_documents_async(datastore_id))
|
||||
else:
|
||||
print("Warning: No datastore_id provided. Document status checking disabled.")
|
||||
print(
|
||||
"Warning: No datastore_id provided. Document status checking disabled."
|
||||
)
|
||||
|
||||
try:
|
||||
response = self.contextual_client.agents.query.create(
|
||||
agent_id=agent_id,
|
||||
messages=[{"role": "user", "content": query}]
|
||||
agent_id=agent_id, messages=[{"role": "user", "content": query}]
|
||||
)
|
||||
if hasattr(response, 'content'):
|
||||
if hasattr(response, "content"):
|
||||
return response.content
|
||||
elif hasattr(response, 'message'):
|
||||
return response.message.content if hasattr(response.message, 'content') else str(response.message)
|
||||
elif hasattr(response, 'messages') and len(response.messages) > 0:
|
||||
if hasattr(response, "message"):
|
||||
return (
|
||||
response.message.content
|
||||
if hasattr(response.message, "content")
|
||||
else str(response.message)
|
||||
)
|
||||
if hasattr(response, "messages") and len(response.messages) > 0:
|
||||
last_message = response.messages[-1]
|
||||
return last_message.content if hasattr(last_message, 'content') else str(last_message)
|
||||
else:
|
||||
return str(response)
|
||||
return (
|
||||
last_message.content
|
||||
if hasattr(last_message, "content")
|
||||
else str(last_message)
|
||||
)
|
||||
return str(response)
|
||||
except Exception as e:
|
||||
return f"Error querying Contextual AI agent: {str(e)}"
|
||||
return f"Error querying Contextual AI agent: {e!s}"
|
||||
|
||||
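The document-readiness polling above boils down to running a blocking check off the event loop. A minimal sketch of that pattern, with an illustrative stand-in for the readiness check:

import asyncio

def check_ready() -> bool:
    # stand-in for _check_documents_ready(datastore_id)
    return True

async def wait_until_ready(max_attempts: int = 20, interval: float = 30.0) -> bool:
    for _ in range(max_attempts):
        if await asyncio.to_thread(check_ready):
            return True
        await asyncio.sleep(interval)
    return True  # give up without failing hard, matching the tool's behaviour

# asyncio.run(wait_until_ready())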
@@ -1,68 +1,79 @@
|
||||
from typing import Any, Optional, Type, List
|
||||
from typing import ClassVar
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ContextualAIRerankSchema(BaseModel):
|
||||
"""Schema for contextual rerank tool."""
|
||||
|
||||
query: str = Field(..., description="The search query to rerank documents against")
|
||||
documents: List[str] = Field(..., description="List of document texts to rerank")
|
||||
instruction: Optional[str] = Field(default=None, description="Optional instruction for reranking behavior")
|
||||
metadata: Optional[List[str]] = Field(default=None, description="Optional metadata for each document")
|
||||
model: str = Field(default="ctxl-rerank-en-v1-instruct", description="Reranker model to use")
|
||||
documents: list[str] = Field(..., description="List of document texts to rerank")
|
||||
instruction: str | None = Field(
|
||||
default=None, description="Optional instruction for reranking behavior"
|
||||
)
|
||||
metadata: list[str] | None = Field(
|
||||
default=None, description="Optional metadata for each document"
|
||||
)
|
||||
model: str = Field(
|
||||
default="ctxl-rerank-en-v1-instruct", description="Reranker model to use"
|
||||
)
|
||||
|
||||
|
||||
class ContextualAIRerankTool(BaseTool):
|
||||
"""Tool to rerank documents using Contextual AI's instruction-following reranker."""
|
||||
|
||||
|
||||
name: str = "Contextual AI Document Reranker"
|
||||
description: str = "Rerank documents using Contextual AI's instruction-following reranker"
|
||||
args_schema: Type[BaseModel] = ContextualAIRerankSchema
|
||||
|
||||
description: str = (
|
||||
"Rerank documents using Contextual AI's instruction-following reranker"
|
||||
)
|
||||
args_schema: type[BaseModel] = ContextualAIRerankSchema
|
||||
|
||||
api_key: str
|
||||
package_dependencies: List[str] = ["contextual-client"]
|
||||
package_dependencies: ClassVar[list[str]] = ["contextual-client"]
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
documents: List[str],
|
||||
instruction: Optional[str] = None,
|
||||
metadata: Optional[List[str]] = None,
|
||||
model: str = "ctxl-rerank-en-v1-instruct"
|
||||
documents: list[str],
|
||||
instruction: str | None = None,
|
||||
metadata: list[str] | None = None,
|
||||
model: str = "ctxl-rerank-en-v1-instruct",
|
||||
) -> str:
|
||||
"""Rerank documents using Contextual AI's instruction-following reranker."""
|
||||
try:
|
||||
import requests
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
base_url = "https://api.contextual.ai/v1"
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"authorization": f"Bearer {self.api_key}"
|
||||
"authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"model": model
|
||||
}
|
||||
payload = {"query": query, "documents": documents, "model": model}
|
||||
|
||||
if instruction:
|
||||
payload["instruction"] = instruction
|
||||
|
||||
if metadata:
|
||||
if len(metadata) != len(documents):
|
||||
raise ValueError("Metadata list must have the same length as documents list")
|
||||
raise ValueError(
|
||||
"Metadata list must have the same length as documents list"
|
||||
)
|
||||
payload["metadata"] = metadata
|
||||
|
||||
rerank_url = f"{base_url}/rerank"
|
||||
result = requests.post(rerank_url, json=payload, headers=headers)
|
||||
result = requests.post(rerank_url, json=payload, headers=headers, timeout=30)
|
||||
|
||||
if result.status_code != 200:
|
||||
raise RuntimeError(f"Reranker API returned status {result.status_code}: {result.text}")
|
||||
raise RuntimeError(
|
||||
f"Reranker API returned status {result.status_code}: {result.text}"
|
||||
)
|
||||
|
||||
return json.dumps(result.json(), indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to rerank documents: {str(e)}"
|
||||
return f"Failed to rerank documents: {e!s}"
|
||||
|
||||
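Stripped down, the rerank call above is a single POST against the /rerank endpoint. A self-contained sketch with the same payload shape; the API key is a placeholder:

import json

import requests

def rerank(query: str, documents: list[str], api_key: str) -> str:
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "authorization": f"Bearer {api_key}",
    }
    payload = {"query": query, "documents": documents, "model": "ctxl-rerank-en-v1-instruct"}
    result = requests.post("https://api.contextual.ai/v1/rerank", json=payload, headers=headers, timeout=30)
    result.raise_for_status()
    return json.dumps(result.json(), indent=2)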
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, Type, List, Dict, Callable
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
import couchbase.search as search
|
||||
@@ -29,30 +29,33 @@ class CouchbaseToolSchema(BaseModel):
|
||||
description="The query to search retrieve relevant information from the Couchbase database. Pass only the query, not the question.",
|
||||
)
|
||||
|
||||
|
||||
class CouchbaseFTSVectorSearchTool(BaseTool):
|
||||
"""Tool to search the Couchbase database"""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
name: str = "CouchbaseFTSVectorSearchTool"
|
||||
description: str = "A tool to search the Couchbase database for relevant information on internal documents."
|
||||
args_schema: Type[BaseModel] = CouchbaseToolSchema
|
||||
cluster: SkipValidation[Optional[Cluster]] = None
|
||||
collection_name: Optional[str] = None,
|
||||
scope_name: Optional[str] = None,
|
||||
bucket_name: Optional[str] = None,
|
||||
index_name: Optional[str] = None,
|
||||
embedding_key: Optional[str] = Field(
|
||||
args_schema: type[BaseModel] = CouchbaseToolSchema
|
||||
cluster: SkipValidation[Cluster | None] = None
|
||||
collection_name: str | None = (None,)
|
||||
scope_name: str | None = (None,)
|
||||
bucket_name: str | None = (None,)
|
||||
index_name: str | None = (None,)
|
||||
embedding_key: str | None = Field(
|
||||
default="embedding",
|
||||
description="Name of the field in the search index that stores the vector"
|
||||
description="Name of the field in the search index that stores the vector",
|
||||
)
|
||||
scoped_index: Optional[bool] = Field(
|
||||
default=True,
|
||||
description="Specify whether the index is scoped. Is True by default."
|
||||
),
|
||||
limit: Optional[int] = Field(default=3)
|
||||
embedding_function: SkipValidation[Callable[[str], List[float]]] = Field(
|
||||
scoped_index: bool | None = (
|
||||
Field(
|
||||
default=True,
|
||||
description="Specify whether the index is scoped. Is True by default.",
|
||||
),
|
||||
)
|
||||
limit: int | None = Field(default=3)
|
||||
embedding_function: SkipValidation[Callable[[str], list[float]]] = Field(
|
||||
default=None,
|
||||
description="A function that takes a string and returns a list of floats. This is used to embed the query before searching the database."
|
||||
description="A function that takes a string and returns a list of floats. This is used to embed the query before searching the database.",
|
||||
)
|
||||
|
||||
def _check_bucket_exists(self) -> bool:
|
||||
@@ -67,7 +70,7 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
|
||||
def _check_scope_and_collection_exists(self) -> bool:
|
||||
"""Check if the scope and collection exists in the linked Couchbase bucket
|
||||
Raises a ValueError if either is not found"""
|
||||
scope_collection_map: Dict[str, Any] = {}
|
||||
scope_collection_map: dict[str, Any] = {}
|
||||
|
||||
# Get a list of all scopes in the bucket
|
||||
for scope in self._bucket.collections().get_all_scopes():
|
||||
@@ -203,11 +206,7 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
|
||||
|
||||
search_req = search.SearchRequest.create(
|
||||
VectorSearch.from_vector_query(
|
||||
VectorQuery(
|
||||
self.embedding_key,
|
||||
query_embedding,
|
||||
self.limit
|
||||
)
|
||||
VectorQuery(self.embedding_key, query_embedding, self.limit)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -219,16 +218,13 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
|
||||
SearchOptions(
|
||||
limit=self.limit,
|
||||
fields=fields,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
search_iter = self.cluster.search(
|
||||
self.index_name,
|
||||
search_req,
|
||||
SearchOptions(
|
||||
limit=self.limit,
|
||||
fields=fields
|
||||
)
|
||||
SearchOptions(limit=self.limit, fields=fields),
|
||||
)
|
||||
|
||||
json_response = []
|
||||
@@ -238,4 +234,4 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
|
||||
except Exception as e:
|
||||
return f"Search failed with error: {e}"
|
||||
|
||||
return json.dumps(json_response, indent=2)
|
||||
return json.dumps(json_response, indent=2)
|
||||
|
||||
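The embedding_function field above only requires a callable from str to list[float]. A toy example that satisfies the signature; a real setup would call an embedding model:

def embed(text: str) -> list[float]:
    # 26-dimensional character-frequency vector, purely illustrative
    return [text.count(chr(c)) / max(len(text), 1) for c in range(97, 123)]

# tool = CouchbaseFTSVectorSearchTool(
#     cluster=cluster, bucket_name="docs", scope_name="main",
#     collection_name="chunks", index_name="vector_index",
#     embedding_function=embed, limit=3,
# )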
@@ -2,10 +2,10 @@
|
||||
Crewai Enterprise Tools
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
import logging
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from crewai_tools.adapters.enterprise_adapter import EnterpriseActionKitToolAdapter
|
||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
@@ -13,11 +13,11 @@ from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def CrewaiEnterpriseTools(
|
||||
enterprise_token: t.Optional[str] = None,
|
||||
actions_list: t.Optional[t.List[str]] = None,
|
||||
enterprise_action_kit_project_id: t.Optional[str] = None,
|
||||
enterprise_action_kit_project_url: t.Optional[str] = None,
|
||||
def CrewaiEnterpriseTools( # noqa: N802
|
||||
enterprise_token: str | None = None,
|
||||
actions_list: list[str] | None = None,
|
||||
enterprise_action_kit_project_id: str | None = None,
|
||||
enterprise_action_kit_project_url: str | None = None,
|
||||
) -> ToolCollection[BaseTool]:
|
||||
"""Factory function that returns crewai enterprise tools.
|
||||
|
||||
@@ -34,10 +34,11 @@ def CrewaiEnterpriseTools(
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"CrewaiEnterpriseTools will be removed in v1.0.0. Considering use `Agent(apps=[...])` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if enterprise_token is None or enterprise_token == "":
|
||||
@@ -65,7 +66,7 @@ def CrewaiEnterpriseTools(
|
||||
|
||||
|
||||
# ENTERPRISE INJECTION ONLY
|
||||
def _parse_actions_list(actions_list: t.Optional[t.List[str]]) -> t.List[str] | None:
|
||||
def _parse_actions_list(actions_list: list[str] | None) -> list[str] | None:
|
||||
"""Parse a string representation of a list of tool names to a list of tool names.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,13 +4,18 @@ This module provides tools for integrating with various platform applications
|
||||
through the CrewAI platform API.
|
||||
"""
|
||||
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tools import CrewaiPlatformTools
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import CrewAIPlatformActionTool
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder import CrewaiPlatformToolBuilder
|
||||
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import (
|
||||
CrewAIPlatformActionTool,
|
||||
)
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder import (
|
||||
CrewaiPlatformToolBuilder,
|
||||
)
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tools import (
|
||||
CrewaiPlatformTools,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CrewaiPlatformTools",
|
||||
"CrewAIPlatformActionTool",
|
||||
"CrewaiPlatformToolBuilder",
|
||||
"CrewaiPlatformTools",
|
||||
]
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
"""
|
||||
Crewai Enterprise Tools
|
||||
"""
|
||||
import re
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Literal, Optional, Union, cast, get_origin
|
||||
|
||||
import requests
|
||||
from typing import Dict, Any, List, Type, Optional, Union, get_origin, cast, Literal
|
||||
from pydantic import Field, create_model
|
||||
from crewai.tools import BaseTool
|
||||
from crewai_tools.tools.crewai_platform_tools.misc import get_platform_api_base_url, get_platform_integration_token
|
||||
from pydantic import Field, create_model
|
||||
|
||||
from crewai_tools.tools.crewai_platform_tools.misc import (
|
||||
get_platform_api_base_url,
|
||||
get_platform_integration_token,
|
||||
)
|
||||
|
||||
|
||||
class CrewAIPlatformActionTool(BaseTool):
|
||||
action_name: str = Field(default="", description="The name of the action")
|
||||
action_schema: Dict[str, Any] = Field(
|
||||
action_schema: dict[str, Any] = Field(
|
||||
default_factory=dict, description="The schema of the action"
|
||||
)
|
||||
|
||||
@@ -20,7 +26,7 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
self,
|
||||
description: str,
|
||||
action_name: str,
|
||||
action_schema: Dict[str, Any],
|
||||
action_schema: dict[str, Any],
|
||||
):
|
||||
self._model_registry = {}
|
||||
self._base_name = self._sanitize_name(action_name)
|
||||
@@ -36,7 +42,7 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
field_type = self._process_schema_type(
|
||||
param_details, self._sanitize_name(param_name).title()
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
field_type = str
|
||||
|
||||
field_definitions[param_name] = self._create_field_definition(
|
||||
@@ -60,7 +66,11 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
input_text=(str, Field(description="Input for the action")),
|
||||
)
|
||||
|
||||
super().__init__(name=action_name.lower().replace(" ", "_"), description=description, args_schema=args_schema)
|
||||
super().__init__(
|
||||
name=action_name.lower().replace(" ", "_"),
|
||||
description=description,
|
||||
args_schema=args_schema,
|
||||
)
|
||||
self.action_name = action_name
|
||||
self.action_schema = action_schema
|
||||
|
||||
@@ -71,8 +81,8 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
return "".join(word.capitalize() for word in parts if word)
|
||||
|
||||
def _extract_schema_info(
|
||||
self, action_schema: Dict[str, Any]
|
||||
) -> tuple[Dict[str, Any], List[str]]:
|
||||
self, action_schema: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
schema_props = (
|
||||
action_schema.get("function", {})
|
||||
.get("parameters", {})
|
||||
@@ -83,7 +93,7 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
)
|
||||
return schema_props, required
|
||||
|
||||
def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> Type[Any]:
|
||||
def _process_schema_type(self, schema: dict[str, Any], type_name: str) -> type[Any]:
|
||||
if "anyOf" in schema:
|
||||
any_of_types = schema["anyOf"]
|
||||
is_nullable = any(t.get("type") == "null" for t in any_of_types)
|
||||
@@ -92,7 +102,7 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
if non_null_types:
|
||||
base_type = self._process_schema_type(non_null_types[0], type_name)
|
||||
return Optional[base_type] if is_nullable else base_type
|
||||
return cast(Type[Any], Optional[str])
|
||||
return cast(type[Any], Optional[str])
|
||||
|
||||
if "oneOf" in schema:
|
||||
return self._process_schema_type(schema["oneOf"][0], type_name)
|
||||
@@ -111,14 +121,16 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
if json_type == "array":
|
||||
items_schema = schema.get("items", {"type": "string"})
|
||||
item_type = self._process_schema_type(items_schema, f"{type_name}Item")
|
||||
return List[item_type]
|
||||
return list[item_type]
|
||||
|
||||
if json_type == "object":
|
||||
return self._create_nested_model(schema, type_name)
|
||||
|
||||
return self._map_json_type_to_python(json_type)
|
||||
|
||||
def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> Type[Any]:
|
||||
def _create_nested_model(
|
||||
self, schema: dict[str, Any], model_name: str
|
||||
) -> type[Any]:
|
||||
full_model_name = f"{self._base_name}{model_name}"
|
||||
|
||||
if full_model_name in self._model_registry:
|
||||
@@ -139,7 +151,7 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
prop_type = self._process_schema_type(
|
||||
prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}"
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
prop_type = str
|
||||
|
||||
field_definitions[prop_name] = self._create_field_definition(
|
||||
@@ -155,20 +167,18 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
return dict
|
||||
|
||||
def _create_field_definition(
|
||||
self, field_type: Type[Any], is_required: bool, description: str
|
||||
self, field_type: type[Any], is_required: bool, description: str
|
||||
) -> tuple:
|
||||
if is_required:
|
||||
return (field_type, Field(description=description))
|
||||
else:
|
||||
if get_origin(field_type) is Union:
|
||||
return (field_type, Field(default=None, description=description))
|
||||
else:
|
||||
return (
|
||||
Optional[field_type],
|
||||
Field(default=None, description=description),
|
||||
)
|
||||
if get_origin(field_type) is Union:
|
||||
return (field_type, Field(default=None, description=description))
|
||||
return (
|
||||
Optional[field_type],
|
||||
Field(default=None, description=description),
|
||||
)
|
||||
|
||||
def _map_json_type_to_python(self, json_type: str) -> Type[Any]:
|
||||
def _map_json_type_to_python(self, json_type: str) -> type[Any]:
|
||||
type_mapping = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
@@ -180,7 +190,7 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
}
|
||||
return type_mapping.get(json_type, str)
|
||||
|
||||
def _get_required_nullable_fields(self) -> List[str]:
|
||||
def _get_required_nullable_fields(self) -> list[str]:
|
||||
schema_props, required = self._extract_schema_info(self.action_schema)
|
||||
|
||||
required_nullable_fields = []
|
||||
@@ -191,7 +201,7 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
|
||||
return required_nullable_fields
|
||||
|
||||
def _is_nullable_type(self, schema: Dict[str, Any]) -> bool:
|
||||
def _is_nullable_type(self, schema: dict[str, Any]) -> bool:
|
||||
if "anyOf" in schema:
|
||||
return any(t.get("type") == "null" for t in schema["anyOf"])
|
||||
return schema.get("type") == "null"
|
||||
@@ -209,8 +219,9 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
if field_name not in cleaned_kwargs:
|
||||
cleaned_kwargs[field_name] = None
|
||||
|
||||
|
||||
api_url = f"{get_platform_api_base_url()}/actions/{self.action_name}/execute"
|
||||
api_url = (
|
||||
f"{get_platform_api_base_url()}/actions/{self.action_name}/execute"
|
||||
)
|
||||
token = get_platform_integration_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
@@ -230,4 +241,4 @@ class CrewAIPlatformActionTool(BaseTool):
|
||||
return json.dumps(data, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing action {self.action_name}: {str(e)}"
|
||||
return f"Error executing action {self.action_name}: {e!s}"
|
||||
|
||||
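The schema handling above ultimately falls back to a small JSON-to-Python type table. A sketch of that mapping; entries beyond the two visible in the hunk are assumptions:

def map_json_type_to_python(json_type: str) -> type:
    type_mapping = {
        "string": str,
        "integer": int,
        "number": float,   # assumed
        "boolean": bool,   # assumed
        "array": list,     # assumed
        "object": dict,    # assumed
    }
    return type_mapping.get(json_type, str)  # unknown types fall back to str

assert map_json_type_to_python("integer") is int
assert map_json_type_to_python("unknown") is str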
@@ -1,9 +1,15 @@
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from typing import List, Any, Dict
|
||||
from crewai.tools import BaseTool
|
||||
from crewai_tools.tools.crewai_platform_tools.misc import get_platform_api_base_url, get_platform_integration_token
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import CrewAIPlatformActionTool
|
||||
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import (
|
||||
CrewAIPlatformActionTool,
|
||||
)
|
||||
from crewai_tools.tools.crewai_platform_tools.misc import (
|
||||
get_platform_api_base_url,
|
||||
get_platform_integration_token,
|
||||
)
|
||||
|
||||
|
||||
class CrewaiPlatformToolBuilder:
|
||||
@@ -27,13 +33,15 @@ class CrewaiPlatformToolBuilder:
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
actions_url, headers=headers, timeout=30, params={"apps": ",".join(self._apps)}
|
||||
actions_url,
|
||||
headers=headers,
|
||||
timeout=30,
|
||||
params={"apps": ",".join(self._apps)},
|
||||
)
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
raw_data = response.json()
|
||||
|
||||
self._actions_schema = {}
|
||||
@@ -46,7 +54,9 @@ class CrewaiPlatformToolBuilder:
|
||||
action_schema = {
|
||||
"function": {
|
||||
"name": action_name,
|
||||
"description": action.get("description", f"Execute {action_name}"),
|
||||
"description": action.get(
|
||||
"description", f"Execute {action_name}"
|
||||
),
|
||||
"parameters": action.get("parameters", {}),
|
||||
"app": app,
|
||||
}
|
||||
@@ -54,8 +64,8 @@ class CrewaiPlatformToolBuilder:
|
||||
self._actions_schema[action_name] = action_schema
|
||||
|
||||
def _generate_detailed_description(
|
||||
self, schema: Dict[str, Any], indent: int = 0
|
||||
) -> List[str]:
|
||||
self, schema: dict[str, Any], indent: int = 0
|
||||
) -> list[str]:
|
||||
descriptions = []
|
||||
indent_str = " " * indent
|
||||
|
||||
@@ -127,7 +137,6 @@ class CrewaiPlatformToolBuilder:
|
||||
|
||||
self._tools = tools
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
return self.tools()
|
||||
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
import re
|
||||
import os
|
||||
import typing as t
|
||||
from typing import Literal
|
||||
import logging
|
||||
import json
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder import CrewaiPlatformToolBuilder
|
||||
|
||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder import (
|
||||
CrewaiPlatformToolBuilder,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
def CrewaiPlatformTools(
|
||||
def CrewaiPlatformTools( # noqa: N802
|
||||
apps: list[str],
|
||||
) -> ToolCollection[BaseTool]:
|
||||
"""Factory function that returns crewai platform tools.
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import os
|
||||
|
||||
|
||||
def get_platform_api_base_url() -> str:
|
||||
"""Get the platform API base URL from environment or use default."""
|
||||
base_url = os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com")
|
||||
return f"{base_url}/crewai_plus/api/v1/integrations"
|
||||
|
||||
|
||||
def get_platform_integration_token() -> str:
|
||||
"""Get the platform API base URL from environment or use default."""
|
||||
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN") or ""
|
||||
if not token:
|
||||
raise ValueError("No platform integration token found, please set the CREWAI_PLATFORM_INTEGRATION_TOKEN environment variable")
|
||||
return token # TODO: Use context manager to get token
|
||||
raise ValueError(
|
||||
"No platform integration token found, please set the CREWAI_PLATFORM_INTEGRATION_TOKEN environment variable"
|
||||
)
|
||||
return token # TODO: Use context manager to get token
|
||||
|
||||
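A quick check of the two helpers above; the token value is a placeholder:

import os

os.environ["CREWAI_PLATFORM_INTEGRATION_TOKEN"] = "example-token"

from crewai_tools.tools.crewai_platform_tools.misc import (
    get_platform_api_base_url,
    get_platform_integration_token,
)

print(get_platform_api_base_url())       # https://app.crewai.com/crewai_plus/api/v1/integrations by default
print(get_platform_integration_token())  # example-token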
@@ -1,11 +1,4 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -31,9 +24,9 @@ class CSVSearchTool(RagTool):
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a CSV's content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = CSVSearchToolSchema
|
||||
args_schema: type[BaseModel] = CSVSearchToolSchema
|
||||
|
||||
def __init__(self, csv: Optional[str] = None, **kwargs):
|
||||
def __init__(self, csv: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if csv is not None:
|
||||
self.add(csv)
|
||||
@@ -42,15 +35,17 @@ class CSVSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, csv: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(csv, data_type=DataType.CSV)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
csv: Optional[str] = None,
|
||||
csv: str | None = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if csv is not None:
|
||||
self.add(csv)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(
|
||||
query=search_query, similarity_threshold=similarity_threshold, limit=limit
|
||||
)
|
||||
|
||||
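Illustrative call of the CSV search tool with the newly forwarded parameters; the file path is a placeholder:

from crewai_tools import CSVSearchTool

tool = CSVSearchTool(csv="./data/customers.csv")
print(tool.run(search_query="customers based in Germany", similarity_threshold=0.5, limit=5))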
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
from typing import List, Type
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from openai import OpenAI
|
||||
@@ -9,21 +8,27 @@ from pydantic import BaseModel, Field
|
||||
class ImagePromptSchema(BaseModel):
|
||||
"""Input for Dall-E Tool."""
|
||||
|
||||
image_description: str = Field(description="Description of the image to be generated by Dall-E.")
|
||||
image_description: str = Field(
|
||||
description="Description of the image to be generated by Dall-E."
|
||||
)
|
||||
|
||||
|
||||
class DallETool(BaseTool):
|
||||
name: str = "Dall-E Tool"
|
||||
description: str = "Generates images using OpenAI's Dall-E model."
|
||||
args_schema: Type[BaseModel] = ImagePromptSchema
|
||||
args_schema: type[BaseModel] = ImagePromptSchema
|
||||
|
||||
model: str = "dall-e-3"
|
||||
size: str = "1024x1024"
|
||||
quality: str = "standard"
|
||||
n: int = 1
|
||||
|
||||
env_vars: List[EnvVar] = [
|
||||
EnvVar(name="OPENAI_API_KEY", description="API key for OpenAI services", required=True),
|
||||
env_vars: list[EnvVar] = [
|
||||
EnvVar(
|
||||
name="OPENAI_API_KEY",
|
||||
description="API key for OpenAI services",
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
|
||||
def _run(self, **kwargs) -> str:
|
||||
@@ -42,11 +47,9 @@ class DallETool(BaseTool):
|
||||
n=self.n,
|
||||
)
|
||||
|
||||
image_data = json.dumps(
|
||||
return json.dumps(
|
||||
{
|
||||
"image_url": response.data[0].url,
|
||||
"image_description": response.data[0].revised_prompt,
|
||||
}
|
||||
)
|
||||
|
||||
return image_data
|
||||
|
||||
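A minimal usage sketch of the Dall-E tool above; it assumes a valid OPENAI_API_KEY in the environment, and the prompt is just an example:

from crewai_tools import DallETool

tool = DallETool(model="dall-e-3", size="1024x1024", quality="standard", n=1)
print(tool.run(image_description="A lighthouse at dusk, watercolor style"))
# returns JSON with image_url and the revised prompt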
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
@@ -7,27 +7,31 @@ from pydantic import BaseModel, Field, model_validator
|
||||
if TYPE_CHECKING:
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
|
||||
class DatabricksQueryToolSchema(BaseModel):
|
||||
"""Input schema for DatabricksQueryTool."""
|
||||
|
||||
query: str = Field(
|
||||
..., description="SQL query to execute against the Databricks workspace table"
|
||||
)
|
||||
catalog: Optional[str] = Field(
|
||||
None, description="Databricks catalog name (optional, defaults to configured catalog)"
|
||||
catalog: str | None = Field(
|
||||
None,
|
||||
description="Databricks catalog name (optional, defaults to configured catalog)",
|
||||
)
|
||||
db_schema: Optional[str] = Field(
|
||||
None, description="Databricks schema name (optional, defaults to configured schema)"
|
||||
db_schema: str | None = Field(
|
||||
None,
|
||||
description="Databricks schema name (optional, defaults to configured schema)",
|
||||
)
|
||||
warehouse_id: Optional[str] = Field(
|
||||
None, description="Databricks SQL warehouse ID (optional, defaults to configured warehouse)"
|
||||
warehouse_id: str | None = Field(
|
||||
None,
|
||||
description="Databricks SQL warehouse ID (optional, defaults to configured warehouse)",
|
||||
)
|
||||
row_limit: Optional[int] = Field(
|
||||
row_limit: int | None = Field(
|
||||
1000, description="Maximum number of rows to return (default: 1000)"
|
||||
)
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_input(self) -> 'DatabricksQueryToolSchema':
|
||||
@model_validator(mode="after")
|
||||
def validate_input(self) -> "DatabricksQueryToolSchema":
|
||||
"""Validate the input parameters."""
|
||||
# Ensure the query is not empty
|
||||
if not self.query or not self.query.strip():
|
||||
@@ -61,21 +65,21 @@ class DatabricksQueryTool(BaseTool):
|
||||
"Execute SQL queries against Databricks workspace tables and return the results."
|
||||
" Provide a 'query' parameter with the SQL query to execute."
|
||||
)
|
||||
args_schema: Type[BaseModel] = DatabricksQueryToolSchema
|
||||
args_schema: type[BaseModel] = DatabricksQueryToolSchema
|
||||
|
||||
# Optional default parameters
|
||||
default_catalog: Optional[str] = None
|
||||
default_schema: Optional[str] = None
|
||||
default_warehouse_id: Optional[str] = None
|
||||
default_catalog: str | None = None
|
||||
default_schema: str | None = None
|
||||
default_warehouse_id: str | None = None
|
||||
|
||||
_workspace_client: Optional["WorkspaceClient"] = None
|
||||
package_dependencies: List[str] = ["databricks-sdk"]
|
||||
package_dependencies: list[str] = ["databricks-sdk"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_catalog: Optional[str] = None,
|
||||
default_schema: Optional[str] = None,
|
||||
default_warehouse_id: Optional[str] = None,
|
||||
default_catalog: str | None = None,
|
||||
default_schema: str | None = None,
|
||||
default_warehouse_id: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -96,7 +100,9 @@ class DatabricksQueryTool(BaseTool):
|
||||
def _validate_credentials(self) -> None:
|
||||
"""Validate that Databricks credentials are available."""
|
||||
has_profile = "DATABRICKS_CONFIG_PROFILE" in os.environ
|
||||
has_direct_auth = "DATABRICKS_HOST" in os.environ and "DATABRICKS_TOKEN" in os.environ
|
||||
has_direct_auth = (
|
||||
"DATABRICKS_HOST" in os.environ and "DATABRICKS_TOKEN" in os.environ
|
||||
)
|
||||
|
||||
if not (has_profile or has_direct_auth):
|
||||
raise ValueError(
|
||||
@@ -110,6 +116,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
if self._workspace_client is None:
|
||||
try:
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
self._workspace_client = WorkspaceClient()
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -117,7 +124,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
)
|
||||
return self._workspace_client
|
||||
|
||||
def _format_results(self, results: List[Dict[str, Any]]) -> str:
|
||||
def _format_results(self, results: list[dict[str, Any]]) -> str:
|
||||
"""Format query results as a readable string."""
|
||||
if not results:
|
||||
return "Query returned no results."
|
||||
@@ -149,8 +156,13 @@ class DatabricksQueryTool(BaseTool):
|
||||
data_rows = []
|
||||
for row in results:
|
||||
# Handle None values by displaying "NULL"
|
||||
row_values = {col: str(row[col]) if row[col] is not None else "NULL" for col in columns}
|
||||
data_row = " | ".join(f"{row_values[col]:{col_widths[col]}}" for col in columns)
|
||||
row_values = {
|
||||
col: str(row[col]) if row[col] is not None else "NULL"
|
||||
for col in columns
|
||||
}
|
||||
data_row = " | ".join(
|
||||
f"{row_values[col]:{col_widths[col]}}" for col in columns
|
||||
)
|
||||
data_rows.append(data_row)
|
||||
|
||||
# Add row count information
|
||||
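The table formatting above can be summarised in a compact standalone form: fixed-width columns sized from the data, with NULL shown for missing values. A sketch, not the tool's exact code:

def format_results(results: list[dict[str, object]]) -> str:
    if not results:
        return "Query returned no results."
    columns = list(results[0].keys())
    display = [
        {c: ("NULL" if row.get(c) is None else str(row[c])) for c in columns}
        for row in results
    ]
    col_widths = {c: max(len(c), *(len(r[c]) for r in display)) for c in columns}
    header = " | ".join(f"{c:{col_widths[c]}}" for c in columns)
    lines = [" | ".join(f"{r[c]:{col_widths[c]}}" for c in columns) for r in display]
    return "\n".join([header, "-" * len(header), *lines, f"({len(lines)} rows)"])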
@@ -190,7 +202,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
catalog=catalog,
|
||||
db_schema=db_schema,
|
||||
warehouse_id=warehouse_id,
|
||||
row_limit=row_limit
|
||||
row_limit=row_limit,
|
||||
)
|
||||
|
||||
# Extract validated parameters
|
||||
@@ -212,18 +224,17 @@ class DatabricksQueryTool(BaseTool):
|
||||
try:
|
||||
# Execute the statement
|
||||
execution = statement.execute_statement(
|
||||
warehouse_id=warehouse_id,
|
||||
statement=query,
|
||||
**context
|
||||
warehouse_id=warehouse_id, statement=query, **context
|
||||
)
|
||||
|
||||
statement_id = execution.statement_id
|
||||
except Exception as execute_error:
|
||||
# Handle immediate execution errors
|
||||
return f"Error starting query execution: {str(execute_error)}"
|
||||
return f"Error starting query execution: {execute_error!s}"
|
||||
|
||||
# Poll for results with better error handling
|
||||
import time
|
||||
|
||||
result = None
|
||||
timeout = 300 # 5 minutes timeout
|
||||
start_time = time.time()
|
||||
@@ -237,8 +248,10 @@ class DatabricksQueryTool(BaseTool):
|
||||
result = statement.get_statement(statement_id)
|
||||
|
||||
# Check if finished - be very explicit about state checking
|
||||
if hasattr(result, 'status') and hasattr(result.status, 'state'):
|
||||
state_value = str(result.status.state) # Convert to string to handle both string and enum
|
||||
if hasattr(result, "status") and hasattr(result.status, "state"):
|
||||
state_value = str(
|
||||
result.status.state
|
||||
) # Convert to string to handle both string and enum
|
||||
|
||||
# Track state changes for debugging
|
||||
if previous_state != state_value:
|
||||
@@ -247,33 +260,38 @@ class DatabricksQueryTool(BaseTool):
|
||||
# Check if state indicates completion
|
||||
if "SUCCEEDED" in state_value:
|
||||
break
|
||||
elif "FAILED" in state_value:
|
||||
if "FAILED" in state_value:
|
||||
# Extract error message with more robust handling
|
||||
error_info = "No detailed error info"
|
||||
try:
|
||||
# First try direct access to error.message
|
||||
if hasattr(result.status, 'error') and result.status.error:
|
||||
if hasattr(result.status.error, 'message'):
|
||||
if (
|
||||
hasattr(result.status, "error")
|
||||
and result.status.error
|
||||
):
|
||||
if hasattr(result.status.error, "message"):
|
||||
error_info = result.status.error.message
|
||||
# Some APIs may have a different structure
|
||||
elif hasattr(result.status.error, 'error_message'):
|
||||
elif hasattr(result.status.error, "error_message"):
|
||||
error_info = result.status.error.error_message
|
||||
# Last resort, try to convert the whole error object to string
|
||||
else:
|
||||
error_info = str(result.status.error)
|
||||
except Exception as err_extract_error:
|
||||
# If all else fails, try to get any info we can
|
||||
error_info = f"Error details unavailable: {str(err_extract_error)}"
|
||||
error_info = (
|
||||
f"Error details unavailable: {err_extract_error!s}"
|
||||
)
|
||||
|
||||
# Return immediately on first FAILED state detection
|
||||
return f"Query execution failed: {error_info}"
|
||||
elif "CANCELED" in state_value:
|
||||
if "CANCELED" in state_value:
|
||||
return "Query was canceled"
|
||||
|
||||
except Exception as poll_error:
|
||||
# Don't immediately fail - try again a few times
|
||||
if poll_count > 3:
|
||||
return f"Error checking query status: {str(poll_error)}"
|
||||
return f"Error checking query status: {poll_error!s}"
|
||||
|
||||
# Wait before polling again
|
||||
time.sleep(2)
|
||||
@@ -282,21 +300,27 @@ class DatabricksQueryTool(BaseTool):
|
||||
if result is None:
|
||||
return "Query returned no result (likely timed out or failed)"
|
||||
|
||||
if not hasattr(result, 'status') or not hasattr(result.status, 'state'):
|
||||
if not hasattr(result, "status") or not hasattr(result.status, "state"):
|
||||
return "Query completed but returned an invalid result structure"
|
||||
|
||||
# Convert state to string for comparison
|
||||
state_value = str(result.status.state)
|
||||
if not any(state in state_value for state in ["SUCCEEDED", "FAILED", "CANCELED"]):
|
||||
if not any(
|
||||
state in state_value for state in ["SUCCEEDED", "FAILED", "CANCELED"]
|
||||
):
|
||||
return f"Query timed out after 5 minutes (last state: {state_value})"
|
||||
|
||||
# Get results - adapt this based on the actual structure of the result object
|
||||
chunk_results = []
|
||||
|
||||
# Check if we have results and a schema in a very defensive way
|
||||
has_schema = (hasattr(result, 'manifest') and result.manifest is not None and
|
||||
hasattr(result.manifest, 'schema') and result.manifest.schema is not None)
|
||||
has_result = (hasattr(result, 'result') and result.result is not None)
|
||||
has_schema = (
|
||||
hasattr(result, "manifest")
|
||||
and result.manifest is not None
|
||||
and hasattr(result.manifest, "schema")
|
||||
and result.manifest.schema is not None
|
||||
)
|
||||
has_result = hasattr(result, "result") and result.result is not None
|
||||
|
||||
if has_schema and has_result:
|
||||
try:
|
||||
@@ -309,10 +333,12 @@ class DatabricksQueryTool(BaseTool):
|
||||
all_columns = set(columns)
|
||||
|
||||
# Dump the raw structure of result data to help troubleshoot
|
||||
if hasattr(result.result, 'data_array'):
|
||||
if hasattr(result.result, "data_array"):
|
||||
# Add defensive check for None data_array
|
||||
if result.result.data_array is None:
|
||||
print("data_array is None - likely an empty result set or DDL query")
|
||||
print(
|
||||
"data_array is None - likely an empty result set or DDL query"
|
||||
)
|
||||
# Return empty result handling rather than trying to process null data
|
||||
return "Query executed successfully (no data returned)"
|
||||
|
||||
@@ -321,7 +347,12 @@ class DatabricksQueryTool(BaseTool):
|
||||
is_likely_incorrect_row_structure = False
|
||||
|
||||
# Only try to analyze sample if data_array exists and has content
|
||||
if hasattr(result.result, 'data_array') and result.result.data_array and len(result.result.data_array) > 0 and len(result.result.data_array[0]) > 0:
|
||||
if (
|
||||
hasattr(result.result, "data_array")
|
||||
and result.result.data_array
|
||||
and len(result.result.data_array) > 0
|
||||
and len(result.result.data_array[0]) > 0
|
||||
):
|
||||
sample_size = min(20, len(result.result.data_array[0]))
|
||||
|
||||
if sample_size > 0:
|
||||
@@ -332,40 +363,81 @@ class DatabricksQueryTool(BaseTool):
|
||||
for i in range(sample_size):
|
||||
val = result.result.data_array[0][i]
|
||||
total_items += 1
|
||||
if isinstance(val, str) and len(val) == 1 and not val.isdigit():
|
||||
if (
|
||||
isinstance(val, str)
|
||||
and len(val) == 1
|
||||
and not val.isdigit()
|
||||
):
|
||||
single_char_count += 1
|
||||
elif isinstance(val, str) and len(val) == 1 and val.isdigit():
|
||||
elif (
|
||||
isinstance(val, str)
|
||||
and len(val) == 1
|
||||
and val.isdigit()
|
||||
):
|
||||
single_digit_count += 1
|
||||
|
||||
# If a significant portion of the first values are single characters or digits,
|
||||
# this likely indicates data is being incorrectly structured
|
||||
if total_items > 0 and (single_char_count + single_digit_count) / total_items > 0.5:
|
||||
if (
|
||||
total_items > 0
|
||||
and (single_char_count + single_digit_count)
|
||||
/ total_items
|
||||
> 0.5
|
||||
):
|
||||
is_likely_incorrect_row_structure = True
|
||||
|
||||
# Additional check: if many rows have just 1 item when we expect multiple columns
|
||||
rows_with_single_item = 0
|
||||
if hasattr(result.result, 'data_array') and result.result.data_array and len(result.result.data_array) > 0:
|
||||
sample_size_for_rows = min(sample_size, len(result.result.data_array[0])) if 'sample_size' in locals() else min(20, len(result.result.data_array[0]))
|
||||
rows_with_single_item = sum(1 for row in result.result.data_array[0][:sample_size_for_rows] if isinstance(row, list) and len(row) == 1)
|
||||
if rows_with_single_item > sample_size_for_rows * 0.5 and len(columns) > 1:
|
||||
if (
|
||||
hasattr(result.result, "data_array")
|
||||
and result.result.data_array
|
||||
and len(result.result.data_array) > 0
|
||||
):
|
||||
sample_size_for_rows = (
|
||||
min(sample_size, len(result.result.data_array[0]))
|
||||
if "sample_size" in locals()
|
||||
else min(20, len(result.result.data_array[0]))
|
||||
)
|
||||
rows_with_single_item = sum(
|
||||
1
|
||||
for row in result.result.data_array[0][
|
||||
:sample_size_for_rows
|
||||
]
|
||||
if isinstance(row, list) and len(row) == 1
|
||||
)
|
||||
if (
|
||||
rows_with_single_item > sample_size_for_rows * 0.5
|
||||
and len(columns) > 1
|
||||
):
|
||||
is_likely_incorrect_row_structure = True
|
||||
|
||||
# Check if we're getting primarily single characters or the data structure seems off,
|
||||
# we should use special handling
|
||||
if 'is_likely_incorrect_row_structure' in locals() and is_likely_incorrect_row_structure:
|
||||
print("Data appears to be malformed - will use special row reconstruction")
|
||||
if (
|
||||
"is_likely_incorrect_row_structure" in locals()
|
||||
and is_likely_incorrect_row_structure
|
||||
):
|
||||
print(
|
||||
"Data appears to be malformed - will use special row reconstruction"
|
||||
)
|
||||
needs_special_string_handling = True
|
||||
else:
|
||||
needs_special_string_handling = False
|
||||
|
||||
# Process results differently based on detection
|
||||
if 'needs_special_string_handling' in locals() and needs_special_string_handling:
|
||||
if (
|
||||
"needs_special_string_handling" in locals()
|
||||
and needs_special_string_handling
|
||||
):
|
||||
# We're dealing with data where the rows may be incorrectly structured
|
||||
print("Using row reconstruction processing mode")
|
||||
|
||||
# Collect all values into a flat list
|
||||
all_values = []
|
||||
if hasattr(result.result, 'data_array') and result.result.data_array:
|
||||
if (
|
||||
hasattr(result.result, "data_array")
|
||||
and result.result.data_array
|
||||
):
|
||||
# Flatten all values into a single list
|
||||
for chunk in result.result.data_array:
|
||||
for item in chunk:
|
||||
@@ -386,32 +458,43 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
# Use regex pattern to identify ID columns that likely start a new row
|
||||
import re
|
||||
id_pattern = re.compile(r'^\d{5,9}$') # Netflix IDs are often 5-9 digits
|
||||
|
||||
id_pattern = re.compile(
|
||||
r"^\d{5,9}$"
|
||||
) # Netflix IDs are often 5-9 digits
|
||||
id_indices = []
|
||||
|
||||
for i, val in enumerate(all_values):
|
||||
if isinstance(val, str) and id_pattern.match(val):
|
||||
# This value looks like an ID, might be the start of a row
|
||||
if i < len(all_values) - 1:
|
||||
next_few_values = all_values[i+1:i+5]
|
||||
next_few_values = all_values[i + 1 : i + 5]
|
||||
# If following values look like they could be part of a title
|
||||
if any(isinstance(v, str) and len(v) > 1 for v in next_few_values):
|
||||
if any(
|
||||
isinstance(v, str) and len(v) > 1
|
||||
for v in next_few_values
|
||||
):
|
||||
id_indices.append(i)
|
||||
|
||||
if id_indices:
|
||||
|
||||
# If we found potential row starts, use them to extract rows
|
||||
for i in range(len(id_indices)):
|
||||
start_idx = id_indices[i]
|
||||
end_idx = id_indices[i+1] if i+1 < len(id_indices) else len(all_values)
|
||||
end_idx = (
|
||||
id_indices[i + 1]
|
||||
if i + 1 < len(id_indices)
|
||||
else len(all_values)
|
||||
)
|
||||
|
||||
# Extract values for this row
|
||||
row_values = all_values[start_idx:end_idx]
|
||||
|
||||
# Special handling for Netflix title data
|
||||
# Titles might be split into individual characters
|
||||
if 'Title' in columns and len(row_values) > expected_column_count:
|
||||
|
||||
if (
|
||||
"Title" in columns
|
||||
and len(row_values) > expected_column_count
|
||||
):
|
||||
# Try to reconstruct by looking for patterns
|
||||
# We know ID is first, then Title (which may be split)
|
||||
# Then other fields like Genre, etc.
|
||||
@@ -424,7 +507,14 @@ class DatabricksQueryTool(BaseTool):
|
||||
for j in range(2, min(100, len(row_values))):
|
||||
val = row_values[j]
|
||||
# Check for common genres or non-title markers
|
||||
if isinstance(val, str) and val in ['Comedy', 'Drama', 'Action', 'Horror', 'Thriller', 'Documentary']:
|
||||
if isinstance(val, str) and val in [
|
||||
"Comedy",
|
||||
"Drama",
|
||||
"Action",
|
||||
"Horror",
|
||||
"Thriller",
|
||||
"Documentary",
|
||||
]:
|
||||
# Likely found the Genre field
|
||||
title_end_idx = j
|
||||
break
|
||||
@@ -433,15 +523,24 @@ class DatabricksQueryTool(BaseTool):
|
||||
if title_end_idx > 1:
|
||||
title_chars = row_values[1:title_end_idx]
|
||||
# Check if they're individual characters
|
||||
if all(isinstance(c, str) and len(c) == 1 for c in title_chars):
|
||||
title = ''.join(title_chars)
|
||||
row_dict['Title'] = title
|
||||
if all(
|
||||
isinstance(c, str) and len(c) == 1
|
||||
for c in title_chars
|
||||
):
|
||||
title = "".join(title_chars)
|
||||
row_dict["Title"] = title
|
||||
|
||||
# Assign remaining values to columns
|
||||
remaining_values = row_values[title_end_idx:]
|
||||
for j, col_name in enumerate(columns[2:], 2):
|
||||
if j-2 < len(remaining_values):
|
||||
row_dict[col_name] = remaining_values[j-2]
|
||||
remaining_values = row_values[
|
||||
title_end_idx:
|
||||
]
|
||||
for j, col_name in enumerate(
|
||||
columns[2:], 2
|
||||
):
|
||||
if j - 2 < len(remaining_values):
|
||||
row_dict[col_name] = (
|
||||
remaining_values[j - 2]
|
||||
)
|
||||
else:
|
||||
row_dict[col_name] = None
|
||||
else:
|
||||
@@ -463,7 +562,9 @@ class DatabricksQueryTool(BaseTool):
|
||||
reconstructed_rows.append(row_dict)
|
||||
else:
|
||||
# More intelligent chunking - try to detect where columns like Title might be split
|
||||
title_idx = columns.index('Title') if 'Title' in columns else -1
|
||||
title_idx = (
|
||||
columns.index("Title") if "Title" in columns else -1
|
||||
)
|
||||
|
||||
if title_idx >= 0:
|
||||
print("Attempting title reconstruction method")
|
||||
@@ -471,21 +572,27 @@ class DatabricksQueryTool(BaseTool):
|
||||
i = 0
|
||||
while i < len(all_values):
|
||||
# Check if this could be an ID (start of a row)
|
||||
if isinstance(all_values[i], str) and id_pattern.match(all_values[i]):
|
||||
if isinstance(
|
||||
all_values[i], str
|
||||
) and id_pattern.match(all_values[i]):
|
||||
row_dict = {columns[0]: all_values[i]}
|
||||
i += 1
|
||||
|
||||
# Try to reconstruct title if it appears to be split
|
||||
title_chars = []
|
||||
while (i < len(all_values) and
|
||||
isinstance(all_values[i], str) and
|
||||
len(all_values[i]) <= 1 and
|
||||
len(title_chars) < 100): # Cap title length
|
||||
while (
|
||||
i < len(all_values)
|
||||
and isinstance(all_values[i], str)
|
||||
and len(all_values[i]) <= 1
|
||||
and len(title_chars) < 100
|
||||
): # Cap title length
|
||||
title_chars.append(all_values[i])
|
||||
i += 1
|
||||
|
||||
if title_chars:
|
||||
row_dict[columns[title_idx]] = ''.join(title_chars)
|
||||
row_dict[columns[title_idx]] = "".join(
|
||||
title_chars
|
||||
)
|
||||
|
||||
# Add remaining fields
|
||||
for j in range(title_idx + 1, len(columns)):
|
||||
@@ -502,11 +609,18 @@ class DatabricksQueryTool(BaseTool):
|
||||
# If we still don't have rows, use simple chunking as fallback
|
||||
if not reconstructed_rows:
|
||||
print("Falling back to basic chunking approach")
|
||||
chunks = [all_values[i:i+expected_column_count] for i in range(0, len(all_values), expected_column_count)]
|
||||
chunks = [
|
||||
all_values[i : i + expected_column_count]
|
||||
for i in range(
|
||||
0, len(all_values), expected_column_count
|
||||
)
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
# Skip chunks that seem to be partial/incomplete rows
|
||||
if len(chunk) < expected_column_count * 0.75: # Allow for some missing values
|
||||
if (
|
||||
len(chunk) < expected_column_count * 0.75
|
||||
): # Allow for some missing values
|
||||
continue
|
||||
|
||||
row_dict = {}
|
||||
@@ -521,13 +635,16 @@ class DatabricksQueryTool(BaseTool):
|
||||
reconstructed_rows.append(row_dict)
|
||||
|
||||
# Apply post-processing to fix known issues
|
||||
if reconstructed_rows and 'Title' in columns:
|
||||
if reconstructed_rows and "Title" in columns:
|
||||
print("Applying post-processing to improve data quality")
|
||||
for row in reconstructed_rows:
|
||||
# Fix titles that might still have issues
|
||||
if isinstance(row.get('Title'), str) and len(row.get('Title')) <= 1:
|
||||
if (
|
||||
isinstance(row.get("Title"), str)
|
||||
and len(row.get("Title")) <= 1
|
||||
):
|
||||
# This is likely still a fragmented title - mark as potentially incomplete
|
||||
row['Title'] = f"[INCOMPLETE] {row.get('Title')}"
|
||||
row["Title"] = f"[INCOMPLETE] {row.get('Title')}"
|
||||
|
||||
# Ensure we respect the row limit
|
||||
if row_limit and len(reconstructed_rows) > row_limit:
|
||||
@@ -539,28 +656,53 @@ class DatabricksQueryTool(BaseTool):
|
||||
print("Using standard processing mode")
|
||||
|
||||
# Check different result structures
|
||||
if hasattr(result.result, 'data_array') and result.result.data_array:
|
||||
if (
|
||||
hasattr(result.result, "data_array")
|
||||
and result.result.data_array
|
||||
):
|
||||
# Check if data appears to be malformed within chunks
|
||||
for chunk_idx, chunk in enumerate(result.result.data_array):
|
||||
|
||||
for _chunk_idx, chunk in enumerate(
|
||||
result.result.data_array
|
||||
):
|
||||
# Check if chunk might actually contain individual columns of a single row
|
||||
# This is another way data might be malformed - check the first few values
|
||||
if len(chunk) > 0 and len(columns) > 1:
|
||||
# If there seems to be a mismatch between chunk structure and expected columns
|
||||
first_few_values = chunk[:min(5, len(chunk))]
|
||||
if all(isinstance(val, (str, int, float)) and not isinstance(val, (list, dict)) for val in first_few_values):
|
||||
if len(chunk) > len(columns) * 3: # Heuristic: if chunk has way more items than columns
|
||||
print("Chunk appears to contain individual values rather than rows - switching to row reconstruction")
|
||||
first_few_values = chunk[: min(5, len(chunk))]
|
||||
if all(
|
||||
isinstance(val, (str, int, float))
|
||||
and not isinstance(val, (list, dict))
|
||||
for val in first_few_values
|
||||
):
|
||||
if (
|
||||
len(chunk) > len(columns) * 3
|
||||
): # Heuristic: if chunk has way more items than columns
|
||||
print(
|
||||
"Chunk appears to contain individual values rather than rows - switching to row reconstruction"
|
||||
)
|
||||
|
||||
# This chunk might actually be values of multiple rows - try to reconstruct
|
||||
values = chunk # All values in this chunk
|
||||
reconstructed_rows = []
|
||||
|
||||
# Try to create rows based on expected column count
|
||||
for i in range(0, len(values), len(columns)):
|
||||
if i + len(columns) <= len(values): # Ensure we have enough values
|
||||
row_values = values[i:i+len(columns)]
|
||||
row_dict = {col: val for col, val in zip(columns, row_values)}
|
||||
for i in range(
|
||||
0, len(values), len(columns)
|
||||
):
|
||||
if i + len(columns) <= len(
|
||||
values
|
||||
): # Ensure we have enough values
|
||||
row_values = values[
|
||||
i : i + len(columns)
|
||||
]
|
||||
row_dict = {
|
||||
col: val
|
||||
for col, val in zip(
|
||||
columns,
|
||||
row_values,
|
||||
strict=False,
|
||||
)
|
||||
}
|
||||
reconstructed_rows.append(row_dict)
|
||||
|
||||
if reconstructed_rows:
|
||||
@@ -569,21 +711,36 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
# Special case: when chunk contains exactly the right number of values for a single row
|
||||
# This handles the case where instead of a list of rows, we just got all values in a flat list
|
||||
if all(isinstance(val, (str, int, float)) and not isinstance(val, (list, dict)) for val in chunk):
|
||||
if len(chunk) == len(columns) or (len(chunk) > 0 and len(chunk) % len(columns) == 0):
|
||||
|
||||
if all(
|
||||
isinstance(val, (str, int, float))
|
||||
and not isinstance(val, (list, dict))
|
||||
for val in chunk
|
||||
):
|
||||
if len(chunk) == len(columns) or (
|
||||
len(chunk) > 0
|
||||
and len(chunk) % len(columns) == 0
|
||||
):
|
||||
# Process flat list of values as rows
|
||||
for i in range(0, len(chunk), len(columns)):
|
||||
row_values = chunk[i:i+len(columns)]
|
||||
if len(row_values) == len(columns): # Only process complete rows
|
||||
row_dict = {col: val for col, val in zip(columns, row_values)}
|
||||
row_values = chunk[i : i + len(columns)]
|
||||
if len(row_values) == len(
|
||||
columns
|
||||
): # Only process complete rows
|
||||
row_dict = {
|
||||
col: val
|
||||
for col, val in zip(
|
||||
columns,
|
||||
row_values,
|
||||
strict=False,
|
||||
)
|
||||
}
|
||||
chunk_results.append(row_dict)
|
||||
|
||||
# Skip regular row processing for this chunk
|
||||
continue
|
||||
|
||||
# Normal processing for typical row structure
|
||||
for row_idx, row in enumerate(chunk):
|
||||
for _row_idx, row in enumerate(chunk):
|
||||
# Ensure row is actually a collection of values
|
||||
if not isinstance(row, (list, tuple, dict)):
|
||||
# This might be a single value; skip it or handle specially
|
||||
@@ -599,7 +756,9 @@ class DatabricksQueryTool(BaseTool):
|
||||
elif isinstance(row, (list, tuple)):
|
||||
# Map list of values to columns
|
||||
for i, val in enumerate(row):
|
||||
if i < len(columns): # Only process if we have a matching column
|
||||
if (
|
||||
i < len(columns)
|
||||
): # Only process if we have a matching column
|
||||
row_dict[columns[i]] = val
|
||||
else:
|
||||
# Extra values without column names
|
||||
@@ -614,16 +773,18 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
chunk_results.append(row_dict)
|
||||
|
||||
elif hasattr(result.result, 'data') and result.result.data:
|
||||
elif hasattr(result.result, "data") and result.result.data:
|
||||
# Alternative data structure
|
||||
|
||||
for row_idx, row in enumerate(result.result.data):
|
||||
for _row_idx, row in enumerate(result.result.data):
|
||||
# Debug info
|
||||
|
||||
# Safely create dictionary matching column names to values
|
||||
row_dict = {}
|
||||
for i, val in enumerate(row):
|
||||
if i < len(columns): # Only process if we have a matching column
|
||||
if i < len(
|
||||
columns
|
||||
): # Only process if we have a matching column
|
||||
row_dict[columns[i]] = val
|
||||
else:
|
||||
# Extra values without column names
|
||||
@@ -642,7 +803,9 @@ class DatabricksQueryTool(BaseTool):
normalized_results = []
for row in chunk_results:
# Create a new row with all columns, defaulting to None for missing ones
normalized_row = {col: row.get(col, None) for col in all_columns}
normalized_row = {
col: row.get(col, None) for col in all_columns
}
normalized_results.append(normalized_row)

# Replace the original results with normalized ones
@@ -651,11 +814,12 @@ class DatabricksQueryTool(BaseTool):
except Exception as results_error:
# Enhanced error message with more context
import traceback

error_details = traceback.format_exc()
return f"Error processing query results: {str(results_error)}\n\nDetails:\n{error_details}"
return f"Error processing query results: {results_error!s}\n\nDetails:\n{error_details}"

# If we have no results but the query succeeded (e.g., for DDL statements)
if not chunk_results and hasattr(result, 'status'):
if not chunk_results and hasattr(result, "status"):
state_value = str(result.status.state)
if "SUCCEEDED" in state_value:
return "Query executed successfully (no results to display)"
@@ -666,5 +830,8 @@ class DatabricksQueryTool(BaseTool):
except Exception as e:
# Include more details in the error message to help with debugging
import traceback

error_details = traceback.format_exc()
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"
return (
f"Error executing Databricks query: {e!s}\n\nDetails:\n{error_details}"
)

@@ -1,5 +1,5 @@
import os
from typing import Any, Optional, Type
from typing import Any

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -20,10 +20,10 @@ class DirectoryReadTool(BaseTool):
description: str = (
"A tool that can be used to recursively list a directory's content."
)
args_schema: Type[BaseModel] = DirectoryReadToolSchema
directory: Optional[str] = None
args_schema: type[BaseModel] = DirectoryReadToolSchema
directory: str | None = None

def __init__(self, directory: Optional[str] = None, **kwargs):
def __init__(self, directory: str | None = None, **kwargs):
super().__init__(**kwargs)
if directory is not None:
self.directory = directory

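For quick reference, a minimal usage sketch of the updated DirectoryReadTool; the directory path below is illustrative and the tool is assumed to still be exported from crewai_tools as before:

from crewai_tools import DirectoryReadTool

# Directory can be fixed at construction time or supplied per call through the args schema.
tool = DirectoryReadTool(directory="./reports")  # "./reports" is a hypothetical path
print(tool.run())  # recursively lists the directory's contents
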
@@ -1,11 +1,4 @@
from typing import Optional, Type

try:
from embedchain.loaders.directory_loader import DirectoryLoader
EMBEDCHAIN_AVAILABLE = True
except ImportError:
EMBEDCHAIN_AVAILABLE = False

from crewai_tools.rag.data_types import DataType
from pydantic import BaseModel, Field

from ..rag.rag_tool import RagTool
@@ -31,11 +24,9 @@ class DirectorySearchTool(RagTool):
description: str = (
"A tool that can be used to semantic search a query from a directory's content."
)
args_schema: Type[BaseModel] = DirectorySearchToolSchema
args_schema: type[BaseModel] = DirectorySearchToolSchema

def __init__(self, directory: Optional[str] = None, **kwargs):
if not EMBEDCHAIN_AVAILABLE:
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
def __init__(self, directory: str | None = None, **kwargs):
super().__init__(**kwargs)
if directory is not None:
self.add(directory)
@@ -44,16 +35,17 @@ class DirectorySearchTool(RagTool):
self._generate_description()

def add(self, directory: str) -> None:
super().add(
directory,
loader=DirectoryLoader(config=dict(recursive=True)),
)
super().add(directory, data_type=DataType.DIRECTORY)

def _run(
self,
search_query: str,
directory: Optional[str] = None,
directory: str | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str:
if directory is not None:
self.add(directory)
return super()._run(query=search_query)
return super()._run(
query=search_query, similarity_threshold=similarity_threshold, limit=limit
)

@@ -1,11 +1,6 @@
from typing import Any, Optional, Type

try:
from embedchain.models.data_type import DataType
EMBEDCHAIN_AVAILABLE = True
except ImportError:
EMBEDCHAIN_AVAILABLE = False
from typing import Any

from crewai_tools.rag.data_types import DataType
from pydantic import BaseModel, Field

from ..rag.rag_tool import RagTool
@@ -14,7 +9,7 @@ from ..rag.rag_tool import RagTool
class FixedDOCXSearchToolSchema(BaseModel):
"""Input for DOCXSearchTool."""

docx: Optional[str] = Field(
docx: str | None = Field(
..., description="File path or URL of a DOCX file to be searched"
)
search_query: str = Field(
@@ -37,9 +32,9 @@ class DOCXSearchTool(RagTool):
description: str = (
"A tool that can be used to semantic search a query from a DOCX's content."
)
args_schema: Type[BaseModel] = DOCXSearchToolSchema
args_schema: type[BaseModel] = DOCXSearchToolSchema

def __init__(self, docx: Optional[str] = None, **kwargs):
def __init__(self, docx: str | None = None, **kwargs):
super().__init__(**kwargs)
if docx is not None:
self.add(docx)
@@ -48,15 +43,17 @@ class DOCXSearchTool(RagTool):
self._generate_description()

def add(self, docx: str) -> None:
if not EMBEDCHAIN_AVAILABLE:
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
super().add(docx, data_type=DataType.DOCX)

def _run(
self,
search_query: str,
docx: Optional[str] = None,
docx: str | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> Any:
if docx is not None:
self.add(docx)
return super()._run(query=search_query)
return super()._run(
query=search_query, similarity_threshold=similarity_threshold, limit=limit
)

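A minimal sketch of the updated DOCXSearchTool call, assuming a local sample.docx exists; similarity_threshold and limit are the new optional arguments wired through to RagTool._run in the hunk above:

from crewai_tools import DOCXSearchTool

tool = DOCXSearchTool(docx="sample.docx")  # "sample.docx" is a hypothetical file
# The new keyword arguments are forwarded to the underlying RAG query.
print(tool.run(search_query="termination clause", similarity_threshold=0.6, limit=3))
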
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, List, Optional, Type
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -17,13 +17,11 @@ class EXABaseToolSchema(BaseModel):
|
||||
search_query: str = Field(
|
||||
..., description="Mandatory search query you want to use to search the internet"
|
||||
)
|
||||
start_published_date: Optional[str] = Field(
|
||||
start_published_date: str | None = Field(
|
||||
None, description="Start date for the search"
|
||||
)
|
||||
end_published_date: Optional[str] = Field(
|
||||
None, description="End date for the search"
|
||||
)
|
||||
include_domains: Optional[list[str]] = Field(
|
||||
end_published_date: str | None = Field(None, description="End date for the search")
|
||||
include_domains: list[str] | None = Field(
|
||||
None, description="List of domains to include in the search"
|
||||
)
|
||||
|
||||
@@ -32,18 +30,18 @@ class EXASearchTool(BaseTool):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
name: str = "EXASearchTool"
|
||||
description: str = "Search the internet using Exa"
|
||||
args_schema: Type[BaseModel] = EXABaseToolSchema
|
||||
args_schema: type[BaseModel] = EXABaseToolSchema
|
||||
client: Optional["Exa"] = None
|
||||
content: Optional[bool] = False
|
||||
summary: Optional[bool] = False
|
||||
type: Optional[str] = "auto"
|
||||
package_dependencies: List[str] = ["exa_py"]
|
||||
api_key: Optional[str] = Field(
|
||||
content: bool | None = False
|
||||
summary: bool | None = False
|
||||
type: str | None = "auto"
|
||||
package_dependencies: list[str] = ["exa_py"]
|
||||
api_key: str | None = Field(
|
||||
default_factory=lambda: os.getenv("EXA_API_KEY"),
|
||||
description="API key for Exa services",
|
||||
json_schema_extra={"required": False},
|
||||
)
|
||||
env_vars: List[EnvVar] = [
|
||||
env_vars: list[EnvVar] = [
|
||||
EnvVar(
|
||||
name="EXA_API_KEY", description="API key for Exa services", required=False
|
||||
),
|
||||
@@ -51,9 +49,9 @@ class EXASearchTool(BaseTool):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[bool] = False,
|
||||
summary: Optional[bool] = False,
|
||||
type: Optional[str] = "auto",
|
||||
content: bool | None = False,
|
||||
summary: bool | None = False,
|
||||
type: str | None = "auto",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -81,9 +79,9 @@ class EXASearchTool(BaseTool):
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
start_published_date: Optional[str] = None,
|
||||
end_published_date: Optional[str] = None,
|
||||
include_domains: Optional[list[str]] = None,
|
||||
start_published_date: str | None = None,
|
||||
end_published_date: str | None = None,
|
||||
include_domains: list[str] | None = None,
|
||||
) -> Any:
|
||||
if self.client is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
@@ -1,4 +1,4 @@
from typing import Any, Optional, Type
from typing import Any

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -8,8 +8,12 @@ class FileReadToolSchema(BaseModel):
"""Input for FileReadTool."""

file_path: str = Field(..., description="Mandatory file full path to read the file")
start_line: Optional[int] = Field(1, description="Line number to start reading from (1-indexed)")
line_count: Optional[int] = Field(None, description="Number of lines to read. If None, reads the entire file")
start_line: int | None = Field(
1, description="Line number to start reading from (1-indexed)"
)
line_count: int | None = Field(
None, description="Number of lines to read. If None, reads the entire file"
)


class FileReadTool(BaseTool):
@@ -38,10 +42,10 @@ class FileReadTool(BaseTool):

name: str = "Read a file's content"
description: str = "A tool that reads the content of a file. To use this tool, provide a 'file_path' parameter with the path to the file you want to read. Optionally, provide 'start_line' to start reading from a specific line and 'line_count' to limit the number of lines read."
args_schema: Type[BaseModel] = FileReadToolSchema
file_path: Optional[str] = None
args_schema: type[BaseModel] = FileReadToolSchema
file_path: str | None = None

def __init__(self, file_path: Optional[str] = None, **kwargs: Any) -> None:
def __init__(self, file_path: str | None = None, **kwargs: Any) -> None:
"""Initialize the FileReadTool.

Args:
@@ -59,18 +63,16 @@ class FileReadTool(BaseTool):

def _run(
self,
file_path: Optional[str] = None,
start_line: Optional[int] = 1,
line_count: Optional[int] = None,
file_path: str | None = None,
start_line: int | None = 1,
line_count: int | None = None,
) -> str:
file_path = file_path or self.file_path
start_line = start_line or 1
line_count = line_count or None

if file_path is None:
return (
"Error: No file path provided. Please provide a file path either in the constructor or as an argument."
)
return "Error: No file path provided. Please provide a file path either in the constructor or as an argument."

try:
with open(file_path, "r") as file:
@@ -82,7 +84,8 @@ class FileReadTool(BaseTool):
selected_lines = [
line
for i, line in enumerate(file)
if i >= start_idx and (line_count is None or i < start_idx + line_count)
if i >= start_idx
and (line_count is None or i < start_idx + line_count)
]

if not selected_lines and start_idx > 0:
@@ -94,4 +97,4 @@ class FileReadTool(BaseTool):
except PermissionError:
return f"Error: Permission denied when trying to read file: {file_path}"
except Exception as e:
return f"Error: Failed to read file {file_path}. {str(e)}"
return f"Error: Failed to read file {file_path}. {e!s}"

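For reference, a minimal sketch of reading a slice of a file with the optional start_line and line_count parameters kept by this change; the file name is illustrative:

from crewai_tools import FileReadTool

tool = FileReadTool(file_path="notes.txt")  # "notes.txt" is a hypothetical file
# Read 10 lines starting at line 5; omitting line_count reads to the end of the file.
print(tool.run(start_line=5, line_count=10))
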
@@ -1,5 +1,5 @@
import os
from typing import Any, Optional, Type
from typing import Any

from crewai.tools import BaseTool
from pydantic import BaseModel
@@ -11,25 +11,22 @@ def strtobool(val) -> bool:
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
elif val in ("n", "no", "f", "false", "off", "0"):
if val in ("n", "no", "f", "false", "off", "0"):
return False
else:
raise ValueError(f"invalid value to cast to bool: {val!r}")
raise ValueError(f"invalid value to cast to bool: {val!r}")


class FileWriterToolInput(BaseModel):
filename: str
directory: Optional[str] = "./"
directory: str | None = "./"
overwrite: str | bool = False
content: str


class FileWriterTool(BaseTool):
name: str = "File Writer Tool"
description: str = (
"A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
)
args_schema: Type[BaseModel] = FileWriterToolInput
description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
args_schema: type[BaseModel] = FileWriterToolInput

def _run(self, **kwargs: Any) -> str:
try:
@@ -57,6 +54,6 @@ class FileWriterTool(BaseTool):
f"File {filepath} already exists and overwrite option was not passed."
)
except KeyError as e:
return f"An error occurred while accessing key: {str(e)}"
return f"An error occurred while accessing key: {e!s}"
except Exception as e:
return f"An error occurred while writing to the file: {str(e)}"
return f"An error occurred while writing to the file: {e!s}"

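A minimal usage sketch; the filename, directory, and content are illustrative. The overwrite flag may arrive as a string from an agent, which the strtobool helper above coerces to a bool:

from crewai_tools import FileWriterTool

tool = FileWriterTool()
# "true" is accepted as well as a bool, per the strtobool conversion used in _run.
print(tool.run(filename="report.md", directory="./out", overwrite="true", content="# Draft"))
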
@@ -3,7 +3,6 @@ import shutil
import tempfile

import pytest

from crewai_tools.tools.file_writer_tool.file_writer_tool import FileWriterTool


@@ -1,17 +1,28 @@
|
||||
import os
|
||||
import zipfile
|
||||
import tarfile
|
||||
from typing import Type, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import zipfile
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FileCompressorToolInput(BaseModel):
|
||||
"""Input schema for FileCompressorTool."""
|
||||
input_path: str = Field(..., description="Path to the file or directory to compress.")
|
||||
output_path: Optional[str] = Field(default=None, description="Optional output archive filename.")
|
||||
overwrite: bool = Field(default=False, description="Whether to overwrite the archive if it already exists.")
|
||||
format: str = Field(default="zip", description="Compression format ('zip', 'tar', 'tar.gz', 'tar.bz2', 'tar.xz').")
|
||||
|
||||
input_path: str = Field(
|
||||
..., description="Path to the file or directory to compress."
|
||||
)
|
||||
output_path: str | None = Field(
|
||||
default=None, description="Optional output archive filename."
|
||||
)
|
||||
overwrite: bool = Field(
|
||||
default=False,
|
||||
description="Whether to overwrite the archive if it already exists.",
|
||||
)
|
||||
format: str = Field(
|
||||
default="zip",
|
||||
description="Compression format ('zip', 'tar', 'tar.gz', 'tar.bz2', 'tar.xz').",
|
||||
)
|
||||
|
||||
|
||||
class FileCompressorTool(BaseTool):
|
||||
@@ -20,58 +31,65 @@ class FileCompressorTool(BaseTool):
|
||||
"Compresses a file or directory into an archive (.zip currently supported). "
|
||||
"Useful for archiving logs, documents, or backups."
|
||||
)
|
||||
args_schema: Type[BaseModel] = FileCompressorToolInput
|
||||
args_schema: type[BaseModel] = FileCompressorToolInput
|
||||
|
||||
|
||||
def _run(self, input_path: str, output_path: Optional[str] = None, overwrite: bool = False, format: str = "zip") -> str:
|
||||
|
||||
if not os.path.exists(input_path):
|
||||
return f"Input path '{input_path}' does not exist."
|
||||
|
||||
if not output_path:
|
||||
output_path = self._generate_output_path(input_path, format)
|
||||
|
||||
FORMAT_EXTENSION = {
|
||||
"zip": ".zip",
|
||||
"tar": ".tar",
|
||||
"tar.gz": ".tar.gz",
|
||||
"tar.bz2": ".tar.bz2",
|
||||
"tar.xz": ".tar.xz"
|
||||
}
|
||||
|
||||
if format not in FORMAT_EXTENSION:
|
||||
return f"Compression format '{format}' is not supported. Allowed formats: {', '.join(FORMAT_EXTENSION.keys())}"
|
||||
elif not output_path.endswith(FORMAT_EXTENSION[format]):
|
||||
return f"Error: If '{format}' format is chosen, output file must have a '{FORMAT_EXTENSION[format]}' extension."
|
||||
if not self._prepare_output(output_path, overwrite):
|
||||
return f"Output '{output_path}' already exists and overwrite is set to False."
|
||||
def _run(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str | None = None,
|
||||
overwrite: bool = False,
|
||||
format: str = "zip",
|
||||
) -> str:
|
||||
if not os.path.exists(input_path):
|
||||
return f"Input path '{input_path}' does not exist."
|
||||
|
||||
try:
|
||||
format_compression = {
|
||||
"zip": self._compress_zip,
|
||||
"tar": self._compress_tar,
|
||||
"tar.gz": self._compress_tar,
|
||||
"tar.bz2": self._compress_tar,
|
||||
"tar.xz": self._compress_tar
|
||||
}
|
||||
if format == "zip":
|
||||
format_compression[format](input_path, output_path)
|
||||
else:
|
||||
format_compression[format](input_path, output_path, format)
|
||||
|
||||
return f"Successfully compressed '{input_path}' into '{output_path}'"
|
||||
except FileNotFoundError:
|
||||
return f"Error: File not found at path: {input_path}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied when accessing '{input_path}' or writing '{output_path}'"
|
||||
except Exception as e:
|
||||
return f"An unexpected error occurred during compression: {str(e)}"
|
||||
if not output_path:
|
||||
output_path = self._generate_output_path(input_path, format)
|
||||
|
||||
FORMAT_EXTENSION = {
|
||||
"zip": ".zip",
|
||||
"tar": ".tar",
|
||||
"tar.gz": ".tar.gz",
|
||||
"tar.bz2": ".tar.bz2",
|
||||
"tar.xz": ".tar.xz",
|
||||
}
|
||||
|
||||
if format not in FORMAT_EXTENSION:
|
||||
return f"Compression format '{format}' is not supported. Allowed formats: {', '.join(FORMAT_EXTENSION.keys())}"
|
||||
if not output_path.endswith(FORMAT_EXTENSION[format]):
|
||||
return f"Error: If '{format}' format is chosen, output file must have a '{FORMAT_EXTENSION[format]}' extension."
|
||||
if not self._prepare_output(output_path, overwrite):
|
||||
return (
|
||||
f"Output '{output_path}' already exists and overwrite is set to False."
|
||||
)
|
||||
|
||||
try:
|
||||
format_compression = {
|
||||
"zip": self._compress_zip,
|
||||
"tar": self._compress_tar,
|
||||
"tar.gz": self._compress_tar,
|
||||
"tar.bz2": self._compress_tar,
|
||||
"tar.xz": self._compress_tar,
|
||||
}
|
||||
if format == "zip":
|
||||
format_compression[format](input_path, output_path)
|
||||
else:
|
||||
format_compression[format](input_path, output_path, format)
|
||||
|
||||
return f"Successfully compressed '{input_path}' into '{output_path}'"
|
||||
except FileNotFoundError:
|
||||
return f"Error: File not found at path: {input_path}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied when accessing '{input_path}' or writing '{output_path}'"
|
||||
except Exception as e:
|
||||
return f"An unexpected error occurred during compression: {e!s}"
|
||||
|
||||
def _generate_output_path(self, input_path: str, format: str) -> str:
|
||||
"""Generates output path based on input path and format."""
|
||||
if os.path.isfile(input_path):
|
||||
base_name = os.path.splitext(os.path.basename(input_path))[0] # Remove extension
|
||||
base_name = os.path.splitext(os.path.basename(input_path))[
|
||||
0
|
||||
] # Remove extension
|
||||
else:
|
||||
base_name = os.path.basename(os.path.normpath(input_path)) # Directory name
|
||||
return os.path.join(os.getcwd(), f"{base_name}.{format}")
|
||||
@@ -87,7 +105,7 @@ class FileCompressorTool(BaseTool):
|
||||
|
||||
def _compress_zip(self, input_path: str, output_path: str):
|
||||
"""Compresses input into a zip archive."""
|
||||
with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
||||
with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||
if os.path.isfile(input_path):
|
||||
zipf.write(input_path, os.path.basename(input_path))
|
||||
else:
|
||||
@@ -97,19 +115,18 @@ class FileCompressorTool(BaseTool):
|
||||
arcname = os.path.relpath(full_path, start=input_path)
|
||||
zipf.write(full_path, arcname)
|
||||
|
||||
|
||||
def _compress_tar(self, input_path: str, output_path: str, format: str):
|
||||
"""Compresses input into a tar archive with the given format."""
|
||||
format_mode = {
|
||||
"tar": "w",
|
||||
"tar.gz": "w:gz",
|
||||
"tar.bz2": "w:bz2",
|
||||
"tar.xz": "w:xz"
|
||||
"tar.xz": "w:xz",
|
||||
}
|
||||
|
||||
if format not in format_mode:
|
||||
raise ValueError(f"Unsupported tar format: {format}")
|
||||
|
||||
|
||||
mode = format_mode[format]
|
||||
|
||||
with tarfile.open(output_path, mode) as tarf:
|
||||
|
||||
@@ -1,88 +1,126 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from crewai_tools.tools.files_compressor_tool import FileCompressorTool
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return FileCompressorTool()
|
||||
|
||||
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_input_path_does_not_exist(mock_exists, tool):
|
||||
result = tool._run("nonexistent_path")
|
||||
assert "does not exist" in result
|
||||
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("os.getcwd", return_value="/mocked/cwd")
|
||||
@patch.object(FileCompressorTool, "_compress_zip") # Mock actual compression
|
||||
@patch.object(FileCompressorTool, "_prepare_output", return_value=True)
|
||||
def test_generate_output_path_default(mock_prepare, mock_compress, mock_cwd, mock_exists, tool):
|
||||
def test_generate_output_path_default(
|
||||
mock_prepare, mock_compress, mock_cwd, mock_exists, tool
|
||||
):
|
||||
result = tool._run(input_path="mydir", format="zip")
|
||||
assert "Successfully compressed" in result
|
||||
mock_compress.assert_called_once()
|
||||
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch.object(FileCompressorTool, "_compress_zip")
|
||||
@patch.object(FileCompressorTool, "_prepare_output", return_value=True)
|
||||
def test_zip_compression(mock_prepare, mock_compress, mock_exists, tool):
|
||||
result = tool._run(input_path="some/path", output_path="archive.zip", format="zip", overwrite=True)
|
||||
result = tool._run(
|
||||
input_path="some/path", output_path="archive.zip", format="zip", overwrite=True
|
||||
)
|
||||
assert "Successfully compressed" in result
|
||||
mock_compress.assert_called_once()
|
||||
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch.object(FileCompressorTool, "_compress_tar")
|
||||
@patch.object(FileCompressorTool, "_prepare_output", return_value=True)
|
||||
def test_tar_gz_compression(mock_prepare, mock_compress, mock_exists, tool):
|
||||
result = tool._run(input_path="some/path", output_path="archive.tar.gz", format="tar.gz", overwrite=True)
|
||||
result = tool._run(
|
||||
input_path="some/path",
|
||||
output_path="archive.tar.gz",
|
||||
format="tar.gz",
|
||||
overwrite=True,
|
||||
)
|
||||
assert "Successfully compressed" in result
|
||||
mock_compress.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("format", ["tar", "tar.bz2", "tar.xz"])
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch.object(FileCompressorTool, "_compress_tar")
|
||||
@patch.object(FileCompressorTool, "_prepare_output", return_value=True)
|
||||
def test_other_tar_formats(mock_prepare, mock_compress, mock_exists, format, tool):
|
||||
result = tool._run(input_path="path/to/input", output_path=f"archive.{format}", format=format, overwrite=True)
|
||||
result = tool._run(
|
||||
input_path="path/to/input",
|
||||
output_path=f"archive.{format}",
|
||||
format=format,
|
||||
overwrite=True,
|
||||
)
|
||||
assert "Successfully compressed" in result
|
||||
mock_compress.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("format", ["rar", "7z"])
|
||||
@patch("os.path.exists", return_value=True) #Ensure input_path exists
|
||||
@patch("os.path.exists", return_value=True) # Ensure input_path exists
|
||||
def test_unsupported_format(_, tool, format):
|
||||
result = tool._run(input_path="some/path", output_path=f"archive.{format}", format=format)
|
||||
result = tool._run(
|
||||
input_path="some/path", output_path=f"archive.{format}", format=format
|
||||
)
|
||||
assert "not supported" in result
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_extension_mismatch(_ , tool):
|
||||
result = tool._run(input_path="some/path", output_path="archive.zip", format="tar.gz")
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_extension_mismatch(_, tool):
|
||||
result = tool._run(
|
||||
input_path="some/path", output_path="archive.zip", format="tar.gz"
|
||||
)
|
||||
assert "must have a '.tar.gz' extension" in result
|
||||
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("os.path.isfile", return_value=True)
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_existing_output_no_overwrite(_, __, ___, tool):
|
||||
result = tool._run(input_path="some/path", output_path="archive.zip", format="zip", overwrite=False)
|
||||
result = tool._run(
|
||||
input_path="some/path", output_path="archive.zip", format="zip", overwrite=False
|
||||
)
|
||||
assert "overwrite is set to False" in result
|
||||
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("zipfile.ZipFile", side_effect=PermissionError)
|
||||
def test_permission_error(mock_zip, _, tool):
|
||||
result = tool._run(input_path="file.txt", output_path="file.zip", format="zip", overwrite=True)
|
||||
result = tool._run(
|
||||
input_path="file.txt", output_path="file.zip", format="zip", overwrite=True
|
||||
)
|
||||
assert "Permission denied" in result
|
||||
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("zipfile.ZipFile", side_effect=FileNotFoundError)
|
||||
def test_file_not_found_during_zip(mock_zip, _, tool):
|
||||
result = tool._run(input_path="file.txt", output_path="file.zip", format="zip", overwrite=True)
|
||||
result = tool._run(
|
||||
input_path="file.txt", output_path="file.zip", format="zip", overwrite=True
|
||||
)
|
||||
assert "File not found" in result
|
||||
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("zipfile.ZipFile", side_effect=Exception("Unexpected"))
|
||||
def test_general_exception_during_zip(mock_zip, _, tool):
|
||||
result = tool._run(input_path="file.txt", output_path="file.zip", format="zip", overwrite=True)
|
||||
result = tool._run(
|
||||
input_path="file.txt", output_path="file.zip", format="zip", overwrite=True
|
||||
)
|
||||
assert "unexpected error" in result
|
||||
|
||||
|
||||
|
||||
# Test: Output directory is created when missing
|
||||
@patch("os.makedirs")
|
||||
@patch("os.path.exists", return_value=False)
|
||||
|
||||
@@ -1,4 +1,4 @@
from typing import Any, Optional, Type, List, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -43,9 +43,9 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
)
name: str = "Firecrawl web crawl tool"
description: str = "Crawl webpages using Firecrawl and return the contents"
args_schema: Type[BaseModel] = FirecrawlCrawlWebsiteToolSchema
api_key: Optional[str] = None
config: Optional[dict[str, Any]] = Field(
args_schema: type[BaseModel] = FirecrawlCrawlWebsiteToolSchema
api_key: str | None = None
config: dict[str, Any] | None = Field(
default_factory=lambda: {
"maxDepth": 2,
"ignoreSitemap": True,
@@ -60,12 +60,16 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
}
)
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"]
env_vars: List[EnvVar] = [
EnvVar(name="FIRECRAWL_API_KEY", description="API key for Firecrawl services", required=True),
package_dependencies: list[str] = ["firecrawl-py"]
env_vars: list[EnvVar] = [
EnvVar(
name="FIRECRAWL_API_KEY",
description="API key for Firecrawl services",
required=True,
),
]

def __init__(self, api_key: Optional[str] = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
self.api_key = api_key
self._initialize_firecrawl()

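A minimal sketch of the crawl tool after this change, assuming FIRECRAWL_API_KEY is set in the environment; the URL and config override are illustrative, and passing config replaces the defaults shown above:

from crewai_tools import FirecrawlCrawlWebsiteTool

tool = FirecrawlCrawlWebsiteTool(config={"maxDepth": 1, "limit": 5})
print(tool.run(url="https://example.com"))
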
@@ -1,4 +1,4 @@
from typing import Any, Optional, Type, Dict, List, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -41,9 +41,9 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
)
name: str = "Firecrawl web scrape tool"
description: str = "Scrape webpages using Firecrawl and return the contents"
args_schema: Type[BaseModel] = FirecrawlScrapeWebsiteToolSchema
api_key: Optional[str] = None
config: Dict[str, Any] = Field(
args_schema: type[BaseModel] = FirecrawlScrapeWebsiteToolSchema
api_key: str | None = None
config: dict[str, Any] = Field(
default_factory=lambda: {
"formats": ["markdown"],
"onlyMainContent": True,
@@ -55,12 +55,16 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
)

_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"]
env_vars: List[EnvVar] = [
EnvVar(name="FIRECRAWL_API_KEY", description="API key for Firecrawl services", required=True),
package_dependencies: list[str] = ["firecrawl-py"]
env_vars: list[EnvVar] = [
EnvVar(
name="FIRECRAWL_API_KEY",
description="API key for Firecrawl services",
required=True,
),
]

def __init__(self, api_key: Optional[str] = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
try:
from firecrawl import FirecrawlApp  # type: ignore

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, List
from typing import TYPE_CHECKING, Any, Optional

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -36,17 +36,14 @@ class FirecrawlSearchTool(BaseTool):
timeout (int): Timeout in milliseconds. Default: 60000
"""

model_config = ConfigDict(
arbitrary_types_allowed=True, validate_assignment=True, frozen=False
)
model_config = ConfigDict(
arbitrary_types_allowed=True, validate_assignment=True, frozen=False
)
name: str = "Firecrawl web search tool"
description: str = "Search webpages using Firecrawl and return the results"
args_schema: Type[BaseModel] = FirecrawlSearchToolSchema
api_key: Optional[str] = None
config: Optional[dict[str, Any]] = Field(
args_schema: type[BaseModel] = FirecrawlSearchToolSchema
api_key: str | None = None
config: dict[str, Any] | None = Field(
default_factory=lambda: {
"limit": 5,
"tbs": None,
@@ -57,12 +54,16 @@ class FirecrawlSearchTool(BaseTool):
}
)
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"]
env_vars: List[EnvVar] = [
EnvVar(name="FIRECRAWL_API_KEY", description="API key for Firecrawl services", required=True),
package_dependencies: list[str] = ["firecrawl-py"]
env_vars: list[EnvVar] = [
EnvVar(
name="FIRECRAWL_API_KEY",
description="API key for Firecrawl services",
required=True,
),
]

def __init__(self, api_key: Optional[str] = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
self.api_key = api_key
self._initialize_firecrawl()
@@ -116,4 +117,3 @@ except ImportError:
"""
When this tool is not used, then exception can be ignored.
"""
pass

@@ -1,5 +1,4 @@
import os
from typing import List, Optional, Type

import requests
from crewai.tools import BaseTool, EnvVar
@@ -10,7 +9,7 @@ class GenerateCrewaiAutomationToolSchema(BaseModel):
prompt: str = Field(
description="The prompt to generate the CrewAI automation, e.g. 'Generate a CrewAI automation that will scrape the website and store the data in a database.'"
)
organization_id: Optional[str] = Field(
organization_id: str | None = Field(
default=None,
description="The identifier for the CrewAI Enterprise organization. If not specified, a default organization will be used.",
)
@@ -23,16 +22,16 @@ class GenerateCrewaiAutomationTool(BaseTool):
"automations based on natural language descriptions. It translates high-level requirements into "
"functional CrewAI implementations."
)
args_schema: Type[BaseModel] = GenerateCrewaiAutomationToolSchema
args_schema: type[BaseModel] = GenerateCrewaiAutomationToolSchema
crewai_enterprise_url: str = Field(
default_factory=lambda: os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com"),
description="The base URL of CrewAI Enterprise. If not provided, it will be loaded from the environment variable CREWAI_PLUS_URL with default https://app.crewai.com.",
)
personal_access_token: Optional[str] = Field(
personal_access_token: str | None = Field(
default_factory=lambda: os.getenv("CREWAI_PERSONAL_ACCESS_TOKEN"),
description="The user's Personal Access Token to access CrewAI Enterprise API. If not provided, it will be loaded from the environment variable CREWAI_PERSONAL_ACCESS_TOKEN.",
)
env_vars: List[EnvVar] = [
env_vars: list[EnvVar] = [
EnvVar(
name="CREWAI_PERSONAL_ACCESS_TOKEN",
description="Personal Access Token for CrewAI Enterprise API",
@@ -57,7 +56,7 @@ class GenerateCrewaiAutomationTool(BaseTool):
studio_project_url = response.json().get("url")
return f"Generated CrewAI Studio project URL: {studio_project_url}"

def _get_headers(self, organization_id: Optional[str] = None) -> dict:
def _get_headers(self, organization_id: str | None = None) -> dict:
headers = {
"Authorization": f"Bearer {self.personal_access_token}",
"Content-Type": "application/json",

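A minimal sketch of invoking the generator tool; it assumes CREWAI_PERSONAL_ACCESS_TOKEN is set, that the tool is exported from crewai_tools like the others, and the prompt text is illustrative:

from crewai_tools import GenerateCrewaiAutomationTool

tool = GenerateCrewaiAutomationTool()
# organization_id is optional per the schema above; omit it to use the default organization.
print(tool.run(prompt="Generate a CrewAI automation that summarizes daily support tickets"))
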
@@ -1,12 +1,5 @@
|
||||
from typing import List, Optional, Type, Any
|
||||
|
||||
try:
|
||||
from embedchain.loaders.github import GithubLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
@@ -24,7 +17,7 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):
|
||||
"""Input for GithubSearchTool."""
|
||||
|
||||
github_repo: str = Field(..., description="Mandatory github you want to search")
|
||||
content_types: List[str] = Field(
|
||||
content_types: list[str] = Field(
|
||||
...,
|
||||
description="Mandatory content types you want to be included search, options: [code, repo, pr, issue]",
|
||||
)
|
||||
@@ -32,28 +25,22 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):
|
||||
|
||||
class GithubSearchTool(RagTool):
|
||||
name: str = "Search a github repo's content"
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
|
||||
)
|
||||
description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
|
||||
summarize: bool = False
|
||||
gh_token: str
|
||||
args_schema: Type[BaseModel] = GithubSearchToolSchema
|
||||
content_types: List[str] = Field(
|
||||
args_schema: type[BaseModel] = GithubSearchToolSchema
|
||||
content_types: list[str] = Field(
|
||||
default_factory=lambda: ["code", "repo", "pr", "issue"],
|
||||
description="Content types you want to be included search, options: [code, repo, pr, issue]",
|
||||
)
|
||||
_loader: Any | None = PrivateAttr(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
github_repo: Optional[str] = None,
|
||||
content_types: Optional[List[str]] = None,
|
||||
github_repo: str | None = None,
|
||||
content_types: list[str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
self._loader = GithubLoader(config={"token": self.gh_token})
|
||||
|
||||
if github_repo and content_types:
|
||||
self.add(repo=github_repo, content_types=content_types)
|
||||
@@ -64,25 +51,28 @@ class GithubSearchTool(RagTool):
|
||||
def add(
|
||||
self,
|
||||
repo: str,
|
||||
content_types: Optional[List[str]] = None,
|
||||
content_types: list[str] | None = None,
|
||||
) -> None:
|
||||
content_types = content_types or self.content_types
|
||||
|
||||
super().add(
|
||||
f"repo:{repo} type:{','.join(content_types)}",
|
||||
data_type="github",
|
||||
loader=self._loader,
|
||||
f"https://github.com/{repo}",
|
||||
data_type=DataType.GITHUB,
|
||||
metadata={"content_types": content_types, "gh_token": self.gh_token},
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
github_repo: Optional[str] = None,
|
||||
content_types: Optional[List[str]] = None,
|
||||
github_repo: str | None = None,
|
||||
content_types: list[str] | None = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if github_repo:
|
||||
self.add(
|
||||
repo=github_repo,
|
||||
content_types=content_types,
|
||||
)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(
|
||||
query=search_query, similarity_threshold=similarity_threshold, limit=limit
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
import os
from typing import Any, Optional, Type, Dict, Literal, Union, List
from typing import Any, Literal

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field
@@ -7,8 +7,13 @@ from pydantic import BaseModel, Field

class HyperbrowserLoadToolSchema(BaseModel):
url: str = Field(description="Website URL")
operation: Literal['scrape', 'crawl'] = Field(description="Operation to perform on the website. Either 'scrape' or 'crawl'")
params: Optional[Dict] = Field(description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait")
operation: Literal["scrape", "crawl"] = Field(
description="Operation to perform on the website. Either 'scrape' or 'crawl'"
)
params: dict | None = Field(
description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait"
)


class HyperbrowserLoadTool(BaseTool):
"""HyperbrowserLoadTool.
@@ -20,19 +25,24 @@ class HyperbrowserLoadTool(BaseTool):
Args:
api_key: The Hyperbrowser API key, can be set as an environment variable `HYPERBROWSER_API_KEY` or passed directly
"""

name: str = "Hyperbrowser web load tool"
description: str = "Scrape or crawl a website using Hyperbrowser and return the contents in properly formatted markdown or html"
args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema
api_key: Optional[str] = None
hyperbrowser: Optional[Any] = None
package_dependencies: List[str] = ["hyperbrowser"]
env_vars: List[EnvVar] = [
EnvVar(name="HYPERBROWSER_API_KEY", description="API key for Hyperbrowser services", required=False),
args_schema: type[BaseModel] = HyperbrowserLoadToolSchema
api_key: str | None = None
hyperbrowser: Any | None = None
package_dependencies: list[str] = ["hyperbrowser"]
env_vars: list[EnvVar] = [
EnvVar(
name="HYPERBROWSER_API_KEY",
description="API key for Hyperbrowser services",
required=False,
),
]

def __init__(self, api_key: Optional[str] = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
self.api_key = api_key or os.getenv('HYPERBROWSER_API_KEY')
self.api_key = api_key or os.getenv("HYPERBROWSER_API_KEY")
if not api_key:
raise ValueError(
"`api_key` is required, please set the `HYPERBROWSER_API_KEY` environment variable or pass it directly"
@@ -41,18 +51,22 @@ class HyperbrowserLoadTool(BaseTool):
try:
from hyperbrowser import Hyperbrowser
except ImportError:
raise ImportError("`hyperbrowser` package not found, please run `pip install hyperbrowser`")
raise ImportError(
"`hyperbrowser` package not found, please run `pip install hyperbrowser`"
)

if not self.api_key:
raise ValueError("HYPERBROWSER_API_KEY is not set. Please provide it either via the constructor with the `api_key` argument or by setting the HYPERBROWSER_API_KEY environment variable.")
raise ValueError(
"HYPERBROWSER_API_KEY is not set. Please provide it either via the constructor with the `api_key` argument or by setting the HYPERBROWSER_API_KEY environment variable."
)

self.hyperbrowser = Hyperbrowser(api_key=self.api_key)

def _prepare_params(self, params: Dict) -> Dict:
def _prepare_params(self, params: dict) -> dict:
"""Prepare session and scrape options parameters."""
try:
from hyperbrowser.models.session import CreateSessionParams
from hyperbrowser.models.scrape import ScrapeOptions
from hyperbrowser.models.session import CreateSessionParams
except ImportError:
raise ImportError(
"`hyperbrowser` package not found, please run `pip install hyperbrowser`"
@@ -70,17 +84,24 @@ class HyperbrowserLoadTool(BaseTool):
params["scrape_options"] = ScrapeOptions(**params["scrape_options"])
return params

def _extract_content(self, data: Union[Any, None]):
def _extract_content(self, data: Any | None):
"""Extract content from response data."""
content = ""
if data:
content = data.markdown or data.html or ""
return content

def _run(self, url: str, operation: Literal['scrape', 'crawl'] = 'scrape', params: Optional[Dict] = {}):
def _run(
self,
url: str,
operation: Literal["scrape", "crawl"] = "scrape",
params: dict | None = None,
):
if params is None:
params = {}
try:
from hyperbrowser.models.scrape import StartScrapeJobParams
from hyperbrowser.models.crawl import StartCrawlJobParams
from hyperbrowser.models.scrape import StartScrapeJobParams
except ImportError:
raise ImportError(
"`hyperbrowser` package not found, please run `pip install hyperbrowser`"
@@ -88,20 +109,18 @@ class HyperbrowserLoadTool(BaseTool):

params = self._prepare_params(params)

if operation == 'scrape':
if operation == "scrape":
scrape_params = StartScrapeJobParams(url=url, **params)
scrape_resp = self.hyperbrowser.scrape.start_and_wait(scrape_params)
content = self._extract_content(scrape_resp.data)
return content
else:
crawl_params = StartCrawlJobParams(url=url, **params)
crawl_resp = self.hyperbrowser.crawl.start_and_wait(crawl_params)
content = ""
if crawl_resp.data:
for page in crawl_resp.data:
page_content = self._extract_content(page)
if page_content:
content += (
f"\n{'-'*50}\nUrl: {page.url}\nContent:\n{page_content}\n"
)
return content
return self._extract_content(scrape_resp.data)
crawl_params = StartCrawlJobParams(url=url, **params)
crawl_resp = self.hyperbrowser.crawl.start_and_wait(crawl_params)
content = ""
if crawl_resp.data:
for page in crawl_resp.data:
page_content = self._extract_content(page)
if page_content:
content += (
f"\n{'-' * 50}\nUrl: {page.url}\nContent:\n{page_content}\n"
)
return content
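A minimal sketch of how the refactored _run() above is driven, assuming the tool is exported from crewai_tools and HYPERBROWSER_API_KEY is set in the environment; the URL is a placeholder:

from crewai_tools import HyperbrowserLoadTool

tool = HyperbrowserLoadTool()  # reads HYPERBROWSER_API_KEY from the environment

# "scrape" returns the markdown/html of a single page; "crawl" walks the site and
# concatenates each page behind a "Url: ..." separator, as in the loop above.
markdown = tool.run(
    url="https://example.com",  # placeholder URL
    operation="scrape",
)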
@@ -1,23 +1,27 @@
import time
from typing import Any

import requests
from crewai.tools import BaseTool
from pydantic import BaseModel, Field, create_model
from typing import Any, Type
import requests
import time


class InvokeCrewAIAutomationInput(BaseModel):
"""Input schema for InvokeCrewAIAutomationTool."""

prompt: str = Field(..., description="The prompt or query to send to the crew")


class InvokeCrewAIAutomationTool(BaseTool):
"""
A CrewAI tool for invoking external crew/flows APIs.


This tool provides CrewAI Platform API integration with external crew services, supporting:
- Dynamic input schema configuration
- Automatic polling for task completion
- Bearer token authentication
- Comprehensive error handling


Example:
Basic usage:
>>> tool = InvokeCrewAIAutomationTool(
@@ -26,7 +30,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
... crew_name="My Crew",
... crew_description="Description of what the crew does"
... )


With custom inputs:
>>> custom_inputs = {
... "param1": Field(..., description="Description of param1"),
@@ -39,7 +43,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
... crew_description="Description of what the crew does",
... crew_inputs=custom_inputs
... )


Example:
>>> tools=[
... InvokeCrewAIAutomationTool(
@@ -53,25 +57,27 @@ class InvokeCrewAIAutomationTool(BaseTool):
... )
... ]
"""

name: str = "invoke_amp_automation"
description: str = "Invokes an CrewAI Platform Automation using API"
args_schema: Type[BaseModel] = InvokeCrewAIAutomationInput

args_schema: type[BaseModel] = InvokeCrewAIAutomationInput

crew_api_url: str
crew_bearer_token: str
max_polling_time: int = 10 * 60 # 10 minutes

max_polling_time: int = 10 * 60 # 10 minutes

def __init__(
self,
crew_api_url: str,
crew_bearer_token: str,
self,
crew_api_url: str,
crew_bearer_token: str,
crew_name: str,
crew_description: str,
max_polling_time: int = 10 * 60,
crew_inputs: dict[str, Any] = None):
crew_inputs: dict[str, Any] | None = None,
):
"""
Initialize the InvokeCrewAIAutomationTool.


Args:
crew_api_url: Base URL of the crew API service
crew_bearer_token: Bearer token for API authentication
@@ -84,7 +90,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
if crew_inputs:
# Start with the base prompt field
fields = {}


# Add custom fields
for field_name, field_def in crew_inputs.items():
if isinstance(field_def, tuple):
@@ -92,12 +98,12 @@ class InvokeCrewAIAutomationTool(BaseTool):
else:
# Assume it's a Field object, extract type from annotation if available
fields[field_name] = (str, field_def)


# Create dynamic model
args_schema = create_model('DynamicInvokeCrewAIAutomationInput', **fields)
args_schema = create_model("DynamicInvokeCrewAIAutomationInput", **fields)
else:
args_schema = InvokeCrewAIAutomationInput


# Initialize the parent class with proper field values
super().__init__(
name=crew_name,
@@ -105,7 +111,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
args_schema=args_schema,
crew_api_url=crew_api_url,
crew_bearer_token=crew_bearer_token,
max_polling_time=max_polling_time
max_polling_time=max_polling_time,
)

def _kickoff_crew(self, inputs: dict[str, Any]) -> dict[str, Any]:
@@ -125,8 +131,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
},
json={"inputs": inputs},
)
response_json = response.json()
return response_json
return response.json()

def _get_crew_status(self, crew_id: str) -> dict[str, Any]:
"""Get the status of a crew task
@@ -150,27 +155,27 @@ class InvokeCrewAIAutomationTool(BaseTool):
"""Execute the crew invocation tool."""
if kwargs is None:
kwargs = {}


# Start the crew
response = self._kickoff_crew(inputs=kwargs)


if response.get("kickoff_id") is None:
return f"Error: Failed to kickoff crew. Response: {response}"

kickoff_id = response.get("kickoff_id")


# Poll for completion
for i in range(self.max_polling_time):
try:
status_response = self._get_crew_status(crew_id=kickoff_id)
if status_response.get("state", "").lower() == "success":
return status_response.get("result", "No result returned")
elif status_response.get("state", "").lower() == "failed":
if status_response.get("state", "").lower() == "failed":
return f"Error: Crew task failed. Response: {status_response}"
except Exception as e:
if i == self.max_polling_time - 1: # Last attempt
return f"Error: Failed to get crew status after {self.max_polling_time} attempts. Last error: {e}"


time.sleep(1)


return f"Error: Crew did not complete within {self.max_polling_time} seconds"
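For context, a brief sketch of the dynamic-schema path added above. The URL, token, and field name are placeholders rather than values from this change, and it assumes the tool is importable from crewai_tools:

from pydantic import Field
from crewai_tools import InvokeCrewAIAutomationTool

tool = InvokeCrewAIAutomationTool(
    crew_api_url="https://my-crew.example.crewai.com",  # placeholder deployment URL
    crew_bearer_token="example-token",                   # placeholder token
    crew_name="Research Crew",
    crew_description="Runs a hosted research automation",
    crew_inputs={
        # Each Field becomes a required str field on the generated
        # DynamicInvokeCrewAIAutomationInput schema via create_model().
        "topic": Field(..., description="Topic to research"),
    },
)

# Invoking the tool kicks off the crew over HTTP, then polls its status roughly
# once per second until it reports "success" or "failed", or until
# max_polling_time seconds have elapsed.
# result = tool.run(topic="open-source agent frameworks")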
@@ -1,5 +1,3 @@
from typing import Optional, Type

import requests
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -14,16 +12,16 @@ class JinaScrapeWebsiteToolInput(BaseModel):
class JinaScrapeWebsiteTool(BaseTool):
name: str = "JinaScrapeWebsiteTool"
description: str = "A tool that can be used to read a website content using Jina.ai reader and return markdown content."
args_schema: Type[BaseModel] = JinaScrapeWebsiteToolInput
website_url: Optional[str] = None
api_key: Optional[str] = None
args_schema: type[BaseModel] = JinaScrapeWebsiteToolInput
website_url: str | None = None
api_key: str | None = None
headers: dict = {}

def __init__(
self,
website_url: Optional[str] = None,
api_key: Optional[str] = None,
custom_headers: Optional[dict] = None,
website_url: str | None = None,
api_key: str | None = None,
custom_headers: dict | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -38,7 +36,7 @@ class JinaScrapeWebsiteTool(BaseTool):
if api_key is not None:
self.headers["Authorization"] = f"Bearer {api_key}"

def _run(self, website_url: Optional[str] = None) -> str:
def _run(self, website_url: str | None = None) -> str:
url = website_url or self.website_url
if not url:
raise ValueError(

@@ -1,5 +1,3 @@
from typing import Optional, Type

from pydantic import BaseModel, Field

from ..rag.rag_tool import RagTool
@@ -27,9 +25,9 @@ class JSONSearchTool(RagTool):
description: str = (
"A tool that can be used to semantic search a query from a JSON's content."
)
args_schema: Type[BaseModel] = JSONSearchToolSchema
args_schema: type[BaseModel] = JSONSearchToolSchema

def __init__(self, json_path: Optional[str] = None, **kwargs):
def __init__(self, json_path: str | None = None, **kwargs):
super().__init__(**kwargs)
if json_path is not None:
self.add(json_path)
@@ -40,8 +38,12 @@ class JSONSearchTool(RagTool):
def _run(
self,
search_query: str,
json_path: Optional[str] = None,
json_path: str | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str:
if json_path is not None:
self.add(json_path)
return super()._run(query=search_query)
return super()._run(
query=search_query, similarity_threshold=similarity_threshold, limit=limit
)

@@ -1,5 +1,5 @@
import os
from typing import Any, List
from typing import Any

from crewai.tools import BaseTool, EnvVar

@@ -20,12 +20,8 @@ class LinkupSearchTool(BaseTool):
"Performs an API call to Linkup to retrieve contextual information."
)
_client: LinkupClient = PrivateAttr() # type: ignore
description: str = (
"Performs an API call to Linkup to retrieve contextual information."
)
_client: LinkupClient = PrivateAttr() # type: ignore
package_dependencies: List[str] = ["linkup-sdk"]
env_vars: List[EnvVar] = [
package_dependencies: list[str] = ["linkup-sdk"]
env_vars: list[EnvVar] = [
EnvVar(name="LINKUP_API_KEY", description="API key for Linkup", required=True),
]
Some files were not shown because too many files have changed in this diff.