Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-10 00:28:31 +00:00)

chore: apply linting fixes to crewai-tools
@@ -24,7 +24,7 @@ CrewAI provides an extensive collection of powerful tools ready to enhance your
- **File Management**: `FileReadTool`, `FileWriteTool`
- **Web Scraping**: `ScrapeWebsiteTool`, `SeleniumScrapingTool`
- **Database Integrations**: `PGSearchTool`, `MySQLSearchTool`
- **Database Integrations**: `MySQLSearchTool`
- **Vector Database Integrations**: `MongoDBVectorSearchTool`, `QdrantVectorSearchTool`, `WeaviateVectorSearchTool`
- **API Integrations**: `SerperApiTool`, `EXASearchTool`
- **AI-powered Tools**: `DallETool`, `VisionTool`, `StagehandTool`
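All of the tools named above are re-exported from the package root, so they can be pulled into a crew with plain imports. A minimal sketch, assuming `SERPER_API_KEY` is set in the environment; the agent and task wiring is illustrative and not part of this commit:

```python
from crewai import Agent, Crew, Task
from crewai_tools import FileReadTool, ScrapeWebsiteTool, SerperDevTool

# A few of the tools listed above; SerperDevTool reads SERPER_API_KEY from the environment.
researcher = Agent(
    role="Researcher",
    goal="Collect background material on a topic",
    backstory="A focused web researcher.",
    tools=[SerperDevTool(), ScrapeWebsiteTool(), FileReadTool()],
)
task = Task(
    description="Find and summarize three recent articles about CrewAI tools.",
    expected_output="A short bullet list of findings.",
    agent=researcher,
)
Crew(agents=[researcher], tasks=[task]).kickoff()
```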
@@ -1,12 +1,14 @@
#!/usr/bin/env python3
from collections.abc import Mapping
import inspect
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, cast
from crewai.tools.base_tool import BaseTool, EnvVar
from crewai_tools import tools
from pydantic import BaseModel
from pydantic.json_schema import GenerateJsonSchema
from pydantic_core import PydanticOmit
@@ -18,19 +20,19 @@ class SchemaGenerator(GenerateJsonSchema):
class ToolSpecExtractor:
def __init__(self) -> None:
self.tools_spec: List[Dict[str, Any]] = []
self.tools_spec: list[dict[str, Any]] = []
self.processed_tools: set[str] = set()
def extract_all_tools(self) -> List[Dict[str, Any]]:
def extract_all_tools(self) -> list[dict[str, Any]]:
for name in dir(tools):
if name.endswith("Tool") and name not in self.processed_tools:
obj = getattr(tools, name, None)
if inspect.isclass(obj):
if inspect.isclass(obj) and issubclass(obj, BaseTool):
self.extract_tool_info(obj)
self.processed_tools.add(name)
return self.tools_spec
def extract_tool_info(self, tool_class: BaseTool) -> None:
def extract_tool_info(self, tool_class: type[BaseTool]) -> None:
try:
core_schema = tool_class.__pydantic_core_schema__
if not core_schema:
@@ -44,8 +46,8 @@ class ToolSpecExtractor:
"humanized_name": self._extract_field_default(
fields.get("name"), fallback=tool_class.__name__
),
"description": self._extract_field_default(
fields.get("description")
"description": str(
self._extract_field_default(fields.get("description"))
).strip(),
"run_params_schema": self._extract_params(fields.get("args_schema")),
"init_params_schema": self._extract_init_params(tool_class),
@@ -57,17 +59,22 @@ class ToolSpecExtractor:
self.tools_spec.append(tool_info)
except Exception as e:
print(f"Error extracting {tool_class.__name__}: {e}")
except Exception: # noqa: S110
pass
def _unwrap_schema(self, schema: Dict) -> Dict:
@staticmethod
def _unwrap_schema(schema: Mapping[str, Any] | dict[str, Any]) -> dict[str, Any]:
result: dict[str, Any] = dict(schema)
while (
schema.get("type") in {"function-after", "default"} and "schema" in schema
result.get("type") in {"function-after", "default"} and "schema" in result
):
schema = schema["schema"]
return schema
result = dict(result["schema"])
return result
def _extract_field_default(self, field: Optional[Dict], fallback: str = "") -> str:
@staticmethod
def _extract_field_default(
field: dict | None, fallback: str | list[Any] = ""
) -> str | list[Any] | int:
if not field:
return fallback
@@ -75,45 +82,43 @@ class ToolSpecExtractor:
default = schema.get("default")
return default if isinstance(default, (list, str, int)) else fallback
def _extract_params(
self, args_schema_field: Optional[Dict]
) -> List[Dict[str, str]]:
@staticmethod
def _extract_params(args_schema_field: dict | None) -> dict[str, Any]:
if not args_schema_field:
return {}
args_schema_class = args_schema_field.get("schema", {}).get("default")
if not (
inspect.isclass(args_schema_class)
and hasattr(args_schema_class, "__pydantic_core_schema__")
and issubclass(args_schema_class, BaseModel)
):
return {}
# Cast to type[BaseModel] after runtime check
schema_class = cast(type[BaseModel], args_schema_class)
try:
return args_schema_class.model_json_schema(
schema_generator=SchemaGenerator, mode="validation"
)
except Exception as e:
print(f"Error extracting params from {args_schema_class}: {e}")
return schema_class.model_json_schema(schema_generator=SchemaGenerator)
except Exception:
return {}
def _extract_env_vars(self, env_vars_field: Optional[Dict]) -> List[Dict[str, str]]:
@staticmethod
def _extract_env_vars(env_vars_field: dict | None) -> list[dict[str, Any]]:
if not env_vars_field:
return []
env_vars = []
for env_var in env_vars_field.get("schema", {}).get("default", []):
if isinstance(env_var, EnvVar):
env_vars.append(
{
"name": env_var.name,
"description": env_var.description,
"required": env_var.required,
"default": env_var.default,
}
)
return env_vars
return [
{
"name": env_var.name,
"description": env_var.description,
"required": env_var.required,
"default": env_var.default,
}
for env_var in env_vars_field.get("schema", {}).get("default", [])
if isinstance(env_var, EnvVar)
]
def _extract_init_params(self, tool_class: BaseTool) -> dict:
@staticmethod
def _extract_init_params(tool_class: type[BaseTool]) -> dict[str, Any]:
ignored_init_params = [
"name",
"description",
@@ -131,25 +136,21 @@ class ToolSpecExtractor:
schema_generator=SchemaGenerator, mode="serialization"
)
properties = {}
for key, value in json_schema["properties"].items():
if key not in ignored_init_params:
properties[key] = value
json_schema["properties"] = properties
json_schema["properties"] = {
key: value
for key, value in json_schema["properties"].items()
if key not in ignored_init_params
}
return json_schema
def save_to_json(self, output_path: str) -> None:
with open(output_path, "w", encoding="utf-8") as f:
json.dump({"tools": self.tools_spec}, f, indent=2, sort_keys=True)
print(f"Saved tool specs to {output_path}")
if __name__ == "__main__":
output_file = Path(__file__).parent / "tool.specs.json"
extractor = ToolSpecExtractor()
specs = extractor.extract_all_tools()
extractor.extract_all_tools()
extractor.save_to_json(str(output_file))
print(f"Extracted {len(specs)} tool classes.")
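The script above builds `tool.specs.json` by walking each tool's pydantic-core schema; the reworked `_unwrap_schema` peels wrapper layers without mutating the caller's mapping. A standalone sketch of that unwrapping step, assuming only that pydantic v2 exposes `__pydantic_core_schema__` on model classes (its exact layout is an internal detail):

```python
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel, Field


def unwrap_schema(schema: Mapping[str, Any]) -> dict[str, Any]:
    # Peel off "function-after" / "default" wrappers until the core schema remains.
    result: dict[str, Any] = dict(schema)
    while result.get("type") in {"function-after", "default"} and "schema" in result:
        result = dict(result["schema"])
    return result


class Demo(BaseModel):
    name: str = Field(default="demo")


core = Demo.__pydantic_core_schema__      # pydantic-core's internal representation
print(unwrap_schema(core).get("type"))    # typically "model" once unwrapped
```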
@@ -10,7 +10,7 @@ requires-python = ">=3.10, <3.14"
dependencies = [
"lancedb>=0.5.4",
"pytube>=15.0.0",
"requests>=2.32.0",
"requests>=2.32.5",
"docker>=7.1.0",
"crewai==1.0.0a1",
"lancedb>=0.5.4",
@@ -1,102 +1,294 @@
|
||||
# ruff: noqa: F401
|
||||
from .adapters.enterprise_adapter import EnterpriseActionTool
|
||||
from .adapters.mcp_adapter import MCPServerAdapter
|
||||
from .adapters.zapier_adapter import ZapierActionTool
|
||||
from .aws import (
|
||||
BedrockInvokeAgentTool,
|
||||
from crewai_tools.adapters.enterprise_adapter import EnterpriseActionTool
|
||||
from crewai_tools.adapters.mcp_adapter import MCPServerAdapter
|
||||
from crewai_tools.adapters.zapier_adapter import ZapierActionTool
|
||||
from crewai_tools.aws.bedrock.agents.invoke_agent_tool import BedrockInvokeAgentTool
|
||||
from crewai_tools.aws.bedrock.knowledge_base.retriever_tool import (
|
||||
BedrockKBRetrieverTool,
|
||||
S3ReaderTool,
|
||||
S3WriterTool,
|
||||
)
|
||||
from .tools import (
|
||||
AIMindTool,
|
||||
ApifyActorsTool,
|
||||
ArxivPaperTool,
|
||||
BraveSearchTool,
|
||||
from crewai_tools.aws.s3.reader_tool import S3ReaderTool
|
||||
from crewai_tools.aws.s3.writer_tool import S3WriterTool
|
||||
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
|
||||
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
|
||||
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from crewai_tools.tools.brightdata_tool.brightdata_dataset import (
|
||||
BrightDataDatasetTool,
|
||||
BrightDataSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.brightdata_tool.brightdata_serp import BrightDataSearchTool
|
||||
from crewai_tools.tools.brightdata_tool.brightdata_unlocker import (
|
||||
BrightDataWebUnlockerTool,
|
||||
)
|
||||
from crewai_tools.tools.browserbase_load_tool.browserbase_load_tool import (
|
||||
BrowserbaseLoadTool,
|
||||
CSVSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.code_docs_search_tool.code_docs_search_tool import (
|
||||
CodeDocsSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import (
|
||||
CodeInterpreterTool,
|
||||
ComposioTool,
|
||||
)
|
||||
from crewai_tools.tools.composio_tool.composio_tool import ComposioTool
|
||||
from crewai_tools.tools.contextualai_create_agent_tool.contextual_create_agent_tool import (
|
||||
ContextualAICreateAgentTool,
|
||||
)
|
||||
from crewai_tools.tools.contextualai_parse_tool.contextual_parse_tool import (
|
||||
ContextualAIParseTool,
|
||||
)
|
||||
from crewai_tools.tools.contextualai_query_tool.contextual_query_tool import (
|
||||
ContextualAIQueryTool,
|
||||
)
|
||||
from crewai_tools.tools.contextualai_rerank_tool.contextual_rerank_tool import (
|
||||
ContextualAIRerankTool,
|
||||
)
|
||||
from crewai_tools.tools.couchbase_tool.couchbase_tool import (
|
||||
CouchbaseFTSVectorSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.crewai_enterprise_tools.crewai_enterprise_tools import (
|
||||
CrewaiEnterpriseTools,
|
||||
)
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tools import (
|
||||
CrewaiPlatformTools,
|
||||
DOCXSearchTool,
|
||||
DallETool,
|
||||
)
|
||||
from crewai_tools.tools.csv_search_tool.csv_search_tool import CSVSearchTool
|
||||
from crewai_tools.tools.dalle_tool.dalle_tool import DallETool
|
||||
from crewai_tools.tools.databricks_query_tool.databricks_query_tool import (
|
||||
DatabricksQueryTool,
|
||||
)
|
||||
from crewai_tools.tools.directory_read_tool.directory_read_tool import (
|
||||
DirectoryReadTool,
|
||||
)
|
||||
from crewai_tools.tools.directory_search_tool.directory_search_tool import (
|
||||
DirectorySearchTool,
|
||||
EXASearchTool,
|
||||
)
|
||||
from crewai_tools.tools.docx_search_tool.docx_search_tool import DOCXSearchTool
|
||||
from crewai_tools.tools.exa_tools.exa_search_tool import EXASearchTool
|
||||
from crewai_tools.tools.file_read_tool.file_read_tool import FileReadTool
|
||||
from crewai_tools.tools.file_writer_tool.file_writer_tool import FileWriterTool
|
||||
from crewai_tools.tools.files_compressor_tool.files_compressor_tool import (
|
||||
FileCompressorTool,
|
||||
FileReadTool,
|
||||
FileWriterTool,
|
||||
)
|
||||
from crewai_tools.tools.firecrawl_crawl_website_tool.firecrawl_crawl_website_tool import (
|
||||
FirecrawlCrawlWebsiteTool,
|
||||
)
|
||||
from crewai_tools.tools.firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
|
||||
FirecrawlScrapeWebsiteTool,
|
||||
)
|
||||
from crewai_tools.tools.firecrawl_search_tool.firecrawl_search_tool import (
|
||||
FirecrawlSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.generate_crewai_automation_tool.generate_crewai_automation_tool import (
|
||||
GenerateCrewaiAutomationTool,
|
||||
GithubSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.github_search_tool.github_search_tool import GithubSearchTool
|
||||
from crewai_tools.tools.hyperbrowser_load_tool.hyperbrowser_load_tool import (
|
||||
HyperbrowserLoadTool,
|
||||
)
|
||||
from crewai_tools.tools.invoke_crewai_automation_tool.invoke_crewai_automation_tool import (
|
||||
InvokeCrewAIAutomationTool,
|
||||
JSONSearchTool,
|
||||
LinkupSearchTool,
|
||||
LlamaIndexTool,
|
||||
MDXSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.jina_scrape_website_tool.jina_scrape_website_tool import (
|
||||
JinaScrapeWebsiteTool,
|
||||
)
|
||||
from crewai_tools.tools.json_search_tool.json_search_tool import JSONSearchTool
|
||||
from crewai_tools.tools.linkup.linkup_search_tool import LinkupSearchTool
|
||||
from crewai_tools.tools.llamaindex_tool.llamaindex_tool import LlamaIndexTool
|
||||
from crewai_tools.tools.mdx_search_tool.mdx_search_tool import MDXSearchTool
|
||||
from crewai_tools.tools.mongodb_vector_search_tool.vector_search import (
|
||||
MongoDBVectorSearchConfig,
|
||||
MongoDBVectorSearchTool,
|
||||
MultiOnTool,
|
||||
MySQLSearchTool,
|
||||
NL2SQLTool,
|
||||
OCRTool,
|
||||
)
|
||||
from crewai_tools.tools.multion_tool.multion_tool import MultiOnTool
|
||||
from crewai_tools.tools.mysql_search_tool.mysql_search_tool import MySQLSearchTool
|
||||
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
|
||||
from crewai_tools.tools.ocr_tool.ocr_tool import OCRTool
|
||||
from crewai_tools.tools.oxylabs_amazon_product_scraper_tool.oxylabs_amazon_product_scraper_tool import (
|
||||
OxylabsAmazonProductScraperTool,
|
||||
)
|
||||
from crewai_tools.tools.oxylabs_amazon_search_scraper_tool.oxylabs_amazon_search_scraper_tool import (
|
||||
OxylabsAmazonSearchScraperTool,
|
||||
)
|
||||
from crewai_tools.tools.oxylabs_google_search_scraper_tool.oxylabs_google_search_scraper_tool import (
|
||||
OxylabsGoogleSearchScraperTool,
|
||||
)
|
||||
from crewai_tools.tools.oxylabs_universal_scraper_tool.oxylabs_universal_scraper_tool import (
|
||||
OxylabsUniversalScraperTool,
|
||||
PDFSearchTool,
|
||||
PGSearchTool,
|
||||
ParallelSearchTool,
|
||||
PatronusEvalTool,
|
||||
)
|
||||
from crewai_tools.tools.parallel_tools.parallel_search_tool import ParallelSearchTool
|
||||
from crewai_tools.tools.patronus_eval_tool.patronus_eval_tool import PatronusEvalTool
|
||||
from crewai_tools.tools.patronus_eval_tool.patronus_local_evaluator_tool import (
|
||||
PatronusLocalEvaluatorTool,
|
||||
)
|
||||
from crewai_tools.tools.patronus_eval_tool.patronus_predefined_criteria_eval_tool import (
|
||||
PatronusPredefinedCriteriaEvalTool,
|
||||
)
|
||||
from crewai_tools.tools.pdf_search_tool.pdf_search_tool import PDFSearchTool
|
||||
from crewai_tools.tools.qdrant_vector_search_tool.qdrant_search_tool import (
|
||||
QdrantVectorSearchTool,
|
||||
RagTool,
|
||||
)
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
from crewai_tools.tools.scrape_element_from_website.scrape_element_from_website import (
|
||||
ScrapeElementFromWebsiteTool,
|
||||
)
|
||||
from crewai_tools.tools.scrape_website_tool.scrape_website_tool import (
|
||||
ScrapeWebsiteTool,
|
||||
)
|
||||
from crewai_tools.tools.scrapegraph_scrape_tool.scrapegraph_scrape_tool import (
|
||||
ScrapegraphScrapeTool,
|
||||
ScrapegraphScrapeToolSchema,
|
||||
)
|
||||
from crewai_tools.tools.scrapfly_scrape_website_tool.scrapfly_scrape_website_tool import (
|
||||
ScrapflyScrapeWebsiteTool,
|
||||
)
|
||||
from crewai_tools.tools.selenium_scraping_tool.selenium_scraping_tool import (
|
||||
SeleniumScrapingTool,
|
||||
)
|
||||
from crewai_tools.tools.serpapi_tool.serpapi_google_search_tool import (
|
||||
SerpApiGoogleSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.serpapi_tool.serpapi_google_shopping_tool import (
|
||||
SerpApiGoogleShoppingTool,
|
||||
SerperDevTool,
|
||||
)
|
||||
from crewai_tools.tools.serper_dev_tool.serper_dev_tool import SerperDevTool
|
||||
from crewai_tools.tools.serper_scrape_website_tool.serper_scrape_website_tool import (
|
||||
SerperScrapeWebsiteTool,
|
||||
)
|
||||
from crewai_tools.tools.serply_api_tool.serply_job_search_tool import (
|
||||
SerplyJobSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.serply_api_tool.serply_news_search_tool import (
|
||||
SerplyNewsSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.serply_api_tool.serply_scholar_search_tool import (
|
||||
SerplyScholarSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.serply_api_tool.serply_web_search_tool import (
|
||||
SerplyWebSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.serply_api_tool.serply_webpage_to_markdown_tool import (
|
||||
SerplyWebpageToMarkdownTool,
|
||||
)
|
||||
from crewai_tools.tools.singlestore_search_tool.singlestore_search_tool import (
|
||||
SingleStoreSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.snowflake_search_tool.snowflake_search_tool import (
|
||||
SnowflakeConfig,
|
||||
SnowflakeSearchTool,
|
||||
SpiderTool,
|
||||
StagehandTool,
|
||||
TXTSearchTool,
|
||||
TavilyExtractorTool,
|
||||
TavilySearchTool,
|
||||
VisionTool,
|
||||
WeaviateVectorSearchTool,
|
||||
WebsiteSearchTool,
|
||||
XMLSearchTool,
|
||||
YoutubeChannelSearchTool,
|
||||
YoutubeVideoSearchTool,
|
||||
ZapierActionTools,
|
||||
)
|
||||
from crewai_tools.tools.spider_tool.spider_tool import SpiderTool
|
||||
from crewai_tools.tools.stagehand_tool.stagehand_tool import StagehandTool
|
||||
from crewai_tools.tools.tavily_extractor_tool.tavily_extractor_tool import (
|
||||
TavilyExtractorTool,
|
||||
)
|
||||
from crewai_tools.tools.tavily_search_tool.tavily_search_tool import TavilySearchTool
|
||||
from crewai_tools.tools.txt_search_tool.txt_search_tool import TXTSearchTool
|
||||
from crewai_tools.tools.vision_tool.vision_tool import VisionTool
|
||||
from crewai_tools.tools.weaviate_tool.vector_search import WeaviateVectorSearchTool
|
||||
from crewai_tools.tools.website_search.website_search_tool import WebsiteSearchTool
|
||||
from crewai_tools.tools.xml_search_tool.xml_search_tool import XMLSearchTool
|
||||
from crewai_tools.tools.youtube_channel_search_tool.youtube_channel_search_tool import (
|
||||
YoutubeChannelSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.youtube_video_search_tool.youtube_video_search_tool import (
|
||||
YoutubeVideoSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.zapier_action_tool.zapier_action_tool import ZapierActionTools
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AIMindTool",
|
||||
"ApifyActorsTool",
|
||||
"ArxivPaperTool",
|
||||
"BedrockInvokeAgentTool",
|
||||
"BedrockKBRetrieverTool",
|
||||
"BraveSearchTool",
|
||||
"BrightDataDatasetTool",
|
||||
"BrightDataSearchTool",
|
||||
"BrightDataWebUnlockerTool",
|
||||
"BrowserbaseLoadTool",
|
||||
"CSVSearchTool",
|
||||
"CodeDocsSearchTool",
|
||||
"CodeInterpreterTool",
|
||||
"ComposioTool",
|
||||
"ContextualAICreateAgentTool",
|
||||
"ContextualAIParseTool",
|
||||
"ContextualAIQueryTool",
|
||||
"ContextualAIRerankTool",
|
||||
"CouchbaseFTSVectorSearchTool",
|
||||
"CrewaiEnterpriseTools",
|
||||
"CrewaiPlatformTools",
|
||||
"DOCXSearchTool",
|
||||
"DallETool",
|
||||
"DatabricksQueryTool",
|
||||
"DirectoryReadTool",
|
||||
"DirectorySearchTool",
|
||||
"EXASearchTool",
|
||||
"EnterpriseActionTool",
|
||||
"FileCompressorTool",
|
||||
"FileReadTool",
|
||||
"FileWriterTool",
|
||||
"FirecrawlCrawlWebsiteTool",
|
||||
"FirecrawlScrapeWebsiteTool",
|
||||
"FirecrawlSearchTool",
|
||||
"GenerateCrewaiAutomationTool",
|
||||
"GithubSearchTool",
|
||||
"HyperbrowserLoadTool",
|
||||
"InvokeCrewAIAutomationTool",
|
||||
"JSONSearchTool",
|
||||
"JinaScrapeWebsiteTool",
|
||||
"LinkupSearchTool",
|
||||
"LlamaIndexTool",
|
||||
"MCPServerAdapter",
|
||||
"MDXSearchTool",
|
||||
"MongoDBVectorSearchConfig",
|
||||
"MongoDBVectorSearchTool",
|
||||
"MultiOnTool",
|
||||
"MySQLSearchTool",
|
||||
"NL2SQLTool",
|
||||
"OCRTool",
|
||||
"OxylabsAmazonProductScraperTool",
|
||||
"OxylabsAmazonSearchScraperTool",
|
||||
"OxylabsGoogleSearchScraperTool",
|
||||
"OxylabsUniversalScraperTool",
|
||||
"PDFSearchTool",
|
||||
"ParallelSearchTool",
|
||||
"PatronusEvalTool",
|
||||
"PatronusLocalEvaluatorTool",
|
||||
"PatronusPredefinedCriteriaEvalTool",
|
||||
"QdrantVectorSearchTool",
|
||||
"RagTool",
|
||||
"S3ReaderTool",
|
||||
"S3WriterTool",
|
||||
"ScrapeElementFromWebsiteTool",
|
||||
"ScrapeWebsiteTool",
|
||||
"ScrapegraphScrapeTool",
|
||||
"ScrapegraphScrapeToolSchema",
|
||||
"ScrapflyScrapeWebsiteTool",
|
||||
"SeleniumScrapingTool",
|
||||
"SerpApiGoogleSearchTool",
|
||||
"SerpApiGoogleShoppingTool",
|
||||
"SerperDevTool",
|
||||
"SerperScrapeWebsiteTool",
|
||||
"SerplyJobSearchTool",
|
||||
"SerplyNewsSearchTool",
|
||||
"SerplyScholarSearchTool",
|
||||
"SerplyWebSearchTool",
|
||||
"SerplyWebpageToMarkdownTool",
|
||||
"SingleStoreSearchTool",
|
||||
"SnowflakeConfig",
|
||||
"SnowflakeSearchTool",
|
||||
"SpiderTool",
|
||||
"StagehandTool",
|
||||
"TXTSearchTool",
|
||||
"TavilyExtractorTool",
|
||||
"TavilySearchTool",
|
||||
"VisionTool",
|
||||
"WeaviateVectorSearchTool",
|
||||
"WebsiteSearchTool",
|
||||
"XMLSearchTool",
|
||||
"YoutubeChannelSearchTool",
|
||||
"YoutubeVideoSearchTool",
|
||||
"ZapierActionTool",
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.0.0a1"
|
||||
|
||||
@@ -9,11 +9,12 @@ from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from pydantic import PrivateAttr
from typing_extensions import Unpack
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
from crewai_tools.tools.rag.rag_tool import Adapter
from pydantic import PrivateAttr
from typing_extensions import Unpack
ContentItem: TypeAlias = str | Path | dict[str, Any]
@@ -213,7 +214,7 @@ class CrewAIRagAdapter(Adapter):
),
}
)
except Exception:
except Exception: # noqa: S112
# Silently skip files that can't be processed
continue
else:
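The `# noqa: S112` above silences Ruff's try-except-continue warning: the adapter intentionally skips content items it cannot process instead of aborting the whole ingest. A generic sketch of that pattern; the helper below is illustrative and not part of the adapter:

```python
from pathlib import Path
from typing import Any

ContentItem = str | Path | dict[str, Any]


def collect_documents(items: list[ContentItem]) -> list[str]:
    documents: list[str] = []
    for item in items:
        try:
            if isinstance(item, Path):
                documents.append(item.read_text(encoding="utf-8"))
            elif isinstance(item, str):
                documents.append(item)
            else:
                documents.append(str(item.get("content", "")))
        except Exception:  # noqa: S112
            # Silently skip items that can't be processed, mirroring the adapter.
            continue
    return documents
```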
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast, get_origin
|
||||
from typing import Any, Literal, Optional, Union, cast, get_origin
|
||||
import warnings
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -25,7 +25,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
default="", description="The enterprise action token"
|
||||
)
|
||||
action_name: str = Field(default="", description="The name of the action")
|
||||
action_schema: Dict[str, Any] = Field(
|
||||
action_schema: dict[str, Any] = Field(
|
||||
default={}, description="The schema of the action"
|
||||
)
|
||||
enterprise_api_base_url: str = Field(
|
||||
@@ -38,8 +38,8 @@ class EnterpriseActionTool(BaseTool):
|
||||
description: str,
|
||||
enterprise_action_token: str,
|
||||
action_name: str,
|
||||
action_schema: Dict[str, Any],
|
||||
enterprise_api_base_url: Optional[str] = None,
|
||||
action_schema: dict[str, Any],
|
||||
enterprise_api_base_url: str | None = None,
|
||||
):
|
||||
self._model_registry = {}
|
||||
self._base_name = self._sanitize_name(name)
|
||||
@@ -56,8 +56,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
field_type = self._process_schema_type(
|
||||
param_details, self._sanitize_name(param_name).title()
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not process schema for {param_name}: {e}")
|
||||
except Exception:
|
||||
field_type = str
|
||||
|
||||
# Create field definition based on requirement
|
||||
@@ -71,8 +70,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
args_schema = create_model(
|
||||
f"{self._base_name}Schema", **field_definitions
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not create main schema model: {e}")
|
||||
except Exception:
|
||||
args_schema = create_model(
|
||||
f"{self._base_name}Schema",
|
||||
input_text=(str, Field(description="Input for the action")),
|
||||
@@ -99,8 +97,8 @@ class EnterpriseActionTool(BaseTool):
|
||||
return "".join(word.capitalize() for word in parts if word)
|
||||
|
||||
def _extract_schema_info(
|
||||
self, action_schema: Dict[str, Any]
|
||||
) -> tuple[Dict[str, Any], List[str]]:
|
||||
self, action_schema: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Extract schema properties and required fields from action schema."""
|
||||
schema_props = (
|
||||
action_schema.get("function", {})
|
||||
@@ -112,7 +110,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
)
|
||||
return schema_props, required
|
||||
|
||||
def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> Type[Any]:
|
||||
def _process_schema_type(self, schema: dict[str, Any], type_name: str) -> type[Any]:
|
||||
"""Process a JSON schema and return appropriate Python type."""
|
||||
if "anyOf" in schema:
|
||||
any_of_types = schema["anyOf"]
|
||||
@@ -121,8 +119,8 @@ class EnterpriseActionTool(BaseTool):
|
||||
|
||||
if non_null_types:
|
||||
base_type = self._process_schema_type(non_null_types[0], type_name)
|
||||
return Optional[base_type] if is_nullable else base_type
|
||||
return cast(Type[Any], Optional[str])
|
||||
return Optional[base_type] if is_nullable else base_type # noqa: UP045
|
||||
return cast(type[Any], Optional[str]) # noqa: UP045
|
||||
|
||||
if "oneOf" in schema:
|
||||
return self._process_schema_type(schema["oneOf"][0], type_name)
|
||||
@@ -141,7 +139,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
if json_type == "array":
|
||||
items_schema = schema.get("items", {"type": "string"})
|
||||
item_type = self._process_schema_type(items_schema, f"{type_name}Item")
|
||||
return List[item_type]
|
||||
return list[item_type]
|
||||
|
||||
if json_type == "object":
|
||||
return self._create_nested_model(schema, type_name)
|
||||
@@ -149,8 +147,8 @@ class EnterpriseActionTool(BaseTool):
|
||||
return self._map_json_type_to_python(json_type)
|
||||
|
||||
def _create_nested_model(
|
||||
self, schema: Dict[str, Any], model_name: str
|
||||
) -> Type[Any]:
|
||||
self, schema: dict[str, Any], model_name: str
|
||||
) -> type[Any]:
|
||||
"""Create a nested Pydantic model for complex objects."""
|
||||
full_model_name = f"{self._base_name}{model_name}"
|
||||
|
||||
@@ -172,8 +170,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
prop_type = self._process_schema_type(
|
||||
prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not process schema for {prop_name}: {e}")
|
||||
except Exception:
|
||||
prop_type = str
|
||||
|
||||
field_definitions[prop_name] = self._create_field_definition(
|
||||
@@ -184,12 +181,11 @@ class EnterpriseActionTool(BaseTool):
|
||||
nested_model = create_model(full_model_name, **field_definitions)
|
||||
self._model_registry[full_model_name] = nested_model
|
||||
return nested_model
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not create nested model {full_model_name}: {e}")
|
||||
except Exception:
|
||||
return dict
|
||||
|
||||
def _create_field_definition(
|
||||
self, field_type: Type[Any], is_required: bool, description: str
|
||||
self, field_type: type[Any], is_required: bool, description: str
|
||||
) -> tuple:
|
||||
"""Create Pydantic field definition based on type and requirement."""
|
||||
if is_required:
|
||||
@@ -197,11 +193,11 @@ class EnterpriseActionTool(BaseTool):
|
||||
if get_origin(field_type) is Union:
|
||||
return (field_type, Field(default=None, description=description))
|
||||
return (
|
||||
Optional[field_type],
|
||||
Optional[field_type], # noqa: UP045
|
||||
Field(default=None, description=description),
|
||||
)
|
||||
|
||||
def _map_json_type_to_python(self, json_type: str) -> Type[Any]:
|
||||
def _map_json_type_to_python(self, json_type: str) -> type[Any]:
|
||||
"""Map basic JSON schema types to Python types."""
|
||||
type_mapping = {
|
||||
"string": str,
|
||||
@@ -214,7 +210,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
}
|
||||
return type_mapping.get(json_type, str)
|
||||
|
||||
def _get_required_nullable_fields(self) -> List[str]:
|
||||
def _get_required_nullable_fields(self) -> list[str]:
|
||||
"""Get a list of required nullable fields from the action schema."""
|
||||
schema_props, required = self._extract_schema_info(self.action_schema)
|
||||
|
||||
@@ -226,7 +222,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
|
||||
return required_nullable_fields
|
||||
|
||||
def _is_nullable_type(self, schema: Dict[str, Any]) -> bool:
|
||||
def _is_nullable_type(self, schema: dict[str, Any]) -> bool:
|
||||
"""Check if a schema represents a nullable type."""
|
||||
if "anyOf" in schema:
|
||||
return any(t.get("type") == "null" for t in schema["anyOf"])
|
||||
@@ -238,7 +234,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
cleaned_kwargs = {}
|
||||
for key, value in kwargs.items():
|
||||
if value is not None:
|
||||
cleaned_kwargs[key] = value
|
||||
cleaned_kwargs[key] = value # noqa: PERF403
|
||||
|
||||
required_nullable_fields = self._get_required_nullable_fields()
|
||||
|
||||
@@ -276,7 +272,7 @@ class EnterpriseActionKitToolAdapter:
|
||||
def __init__(
|
||||
self,
|
||||
enterprise_action_token: str,
|
||||
enterprise_api_base_url: Optional[str] = None,
|
||||
enterprise_api_base_url: str | None = None,
|
||||
):
|
||||
"""Initialize the adapter with an enterprise action token."""
|
||||
self._set_enterprise_action_token(enterprise_action_token)
|
||||
@@ -286,7 +282,7 @@ class EnterpriseActionKitToolAdapter:
|
||||
enterprise_api_base_url or get_enterprise_api_base_url()
|
||||
)
|
||||
|
||||
def tools(self) -> List[BaseTool]:
|
||||
def tools(self) -> list[BaseTool]:
|
||||
"""Get the list of tools created from enterprise actions."""
|
||||
if self._tools is None:
|
||||
self._fetch_actions()
|
||||
@@ -304,13 +300,12 @@ class EnterpriseActionKitToolAdapter:
|
||||
|
||||
raw_data = response.json()
|
||||
if "actions" not in raw_data:
|
||||
print(f"Unexpected API response structure: {raw_data}")
|
||||
return
|
||||
|
||||
parsed_schema = {}
|
||||
action_categories = raw_data["actions"]
|
||||
|
||||
for integration_type, action_list in action_categories.items():
|
||||
for action_list in action_categories.values():
|
||||
if isinstance(action_list, list):
|
||||
for action in action_list:
|
||||
action_name = action.get("name")
|
||||
@@ -328,15 +323,14 @@ class EnterpriseActionKitToolAdapter:
|
||||
|
||||
self._actions_schema = parsed_schema
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error fetching actions: {e}")
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def _generate_detailed_description(
|
||||
self, schema: Dict[str, Any], indent: int = 0
|
||||
) -> List[str]:
|
||||
self, schema: dict[str, Any], indent: int = 0
|
||||
) -> list[str]:
|
||||
"""Generate detailed description for nested schema structures."""
|
||||
descriptions = []
|
||||
indent_str = " " * indent
|
||||
@@ -413,7 +407,7 @@ class EnterpriseActionKitToolAdapter:
|
||||
|
||||
self._tools = tools
|
||||
|
||||
def _set_enterprise_action_token(self, enterprise_action_token: Optional[str]):
|
||||
def _set_enterprise_action_token(self, enterprise_action_token: str | None):
|
||||
if enterprise_action_token and not enterprise_action_token.startswith("PK_"):
|
||||
warnings.warn(
|
||||
"Legacy token detected, please consider using the new Enterprise Action Auth token. Check out our docs for more information https://docs.crewai.com/en/enterprise/features/integrations.",
|
||||
|
||||
@@ -1,12 +1,14 @@
from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable
from typing import Any
from crewai_tools.tools.rag.rag_tool import Adapter
from lancedb import DBConnection as LanceDBConnection, connect as lancedb_connect
from lancedb.table import Table as LanceDBTable
from openai import Client as OpenAIClient
from pydantic import Field, PrivateAttr
from crewai_tools.tools.rag.rag_tool import Adapter
def _default_embedding_function():
client = OpenAIClient()
@@ -1,17 +1,15 @@
"""MCPServer for CrewAI."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from crewai.tools import BaseTool
from crewai_tools.adapters.tool_collection import ToolCollection
"""
MCPServer for CrewAI.
"""
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
@@ -77,8 +75,8 @@ class MCPServerAdapter:
serverparams: StdioServerParameters | dict[str, Any],
*tool_names: str,
connect_timeout: int = 30,
):
"""Initialize the MCP Server
) -> None:
"""Initialize the MCP Server.
Args:
serverparams: The parameters for the MCP server it supports either a
@@ -88,7 +86,6 @@ class MCPServerAdapter:
connect_timeout: Connection timeout in seconds to the MCP server (default is 30s).
"""
super().__init__()
self._adapter = None
self._tools = None
@@ -103,10 +100,10 @@ class MCPServerAdapter:
import subprocess
try:
subprocess.run(["uv", "add", "mcp crewai-tools[mcp]"], check=True)
subprocess.run(["uv", "add", "mcp crewai-tools[mcp]"], check=True) # noqa: S607
except subprocess.CalledProcessError:
raise ImportError("Failed to install mcp package")
except subprocess.CalledProcessError as e:
raise ImportError("Failed to install mcp package") from e
else:
raise ImportError(
"`mcp` package not found, please run `uv add crewai-tools[mcp]`"
@@ -132,7 +129,7 @@ class MCPServerAdapter:
self._tools = self._adapter.__enter__()
def stop(self):
"""Stop the MCP server"""
"""Stop the MCP server."""
self._adapter.__exit__(None, None, None)
@property
@@ -156,8 +153,7 @@ class MCPServerAdapter:
return tools_collection
def __enter__(self):
"""
Enter the context manager. Note that `__init__()` already starts the MCP server.
"""Enter the context manager. Note that `__init__()` already starts the MCP server.
So tools should already be available.
"""
return self.tools
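For orientation, `MCPServerAdapter` is meant to be used as a context manager: `__init__` already connects, `__enter__` hands back the tool collection, and `stop()` (or leaving the `with` block) tears the connection down. A minimal sketch; the `StdioServerParameters` command and the agent wiring are illustrative, not taken from this commit:

```python
from crewai import Agent
from crewai_tools import MCPServerAdapter
from mcp import StdioServerParameters  # requires the mcp extra: uv add crewai-tools[mcp]

server_params = StdioServerParameters(
    command="python",
    args=["my_mcp_server.py"],  # hypothetical local MCP server script
)

# __init__ already connects; the context manager yields the adapter's tool collection.
with MCPServerAdapter(server_params, connect_timeout=30) as tools:
    agent = Agent(
        role="MCP user",
        goal="Call tools exposed by the MCP server",
        backstory="Demonstrates MCP-backed tools.",
        tools=list(tools),
    )
```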
@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
from crewai_tools.rag.core import RAG
from crewai_tools.tools.rag.rag_tool import Adapter
@@ -8,10 +8,10 @@ class RAGAdapter(Adapter):
def __init__(
self,
collection_name: str = "crewai_knowledge_base",
persist_directory: Optional[str] = None,
persist_directory: str | None = None,
embedding_model: str = "text-embedding-3-small",
top_k: int = 5,
embedding_api_key: Optional[str] = None,
embedding_api_key: str | None = None,
**embedding_kwargs,
):
super().__init__()
@@ -1,4 +1,5 @@
from typing import Callable, Dict, Generic, List, Optional, TypeVar, Union
from collections.abc import Callable
from typing import Generic, TypeVar
from crewai.tools import BaseTool
@@ -7,8 +8,7 @@ T = TypeVar("T", bound=BaseTool)
class ToolCollection(list, Generic[T]):
"""
A collection of tools that can be accessed by index or name
"""A collection of tools that can be accessed by index or name.
This class extends the built-in list to provide dictionary-like
access to tools based on their name property.
@@ -21,15 +21,15 @@ class ToolCollection(list, Generic[T]):
search_tool = tools["search"]
"""
def __init__(self, tools: Optional[List[T]] = None):
def __init__(self, tools: list[T] | None = None):
super().__init__(tools or [])
self._name_cache: Dict[str, T] = {}
self._name_cache: dict[str, T] = {}
self._build_name_cache()
def _build_name_cache(self) -> None:
self._name_cache = {tool.name.lower(): tool for tool in self}
def __getitem__(self, key: Union[int, str]) -> T:
def __getitem__(self, key: int | str) -> T:
if isinstance(key, str):
return self._name_cache[key.lower()]
return super().__getitem__(key)
@@ -38,7 +38,7 @@ class ToolCollection(list, Generic[T]):
super().append(tool)
self._name_cache[tool.name.lower()] = tool
def extend(self, tools: List[T]) -> None:
def extend(self, tools: list[T]) -> None:
super().extend(tools)
self._build_name_cache()
@@ -57,7 +57,7 @@ class ToolCollection(list, Generic[T]):
del self._name_cache[tool.name.lower()]
return tool
def filter_by_names(self, names: Optional[List[str]] = None) -> "ToolCollection[T]":
def filter_by_names(self, names: list[str] | None = None) -> "ToolCollection[T]":
if names is None:
return self
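The collection above keeps a lowercase name cache so tools can be fetched either by position or by name, as in the `tools["search"]` docstring example. A self-contained sketch of the same access pattern, using a simplified stand-in rather than the real `ToolCollection` and CrewAI tools:

```python
from dataclasses import dataclass


@dataclass
class FakeTool:
    name: str


class NamedToolList(list):
    """Simplified stand-in: a list of tools addressable by index or case-insensitive name."""

    def __init__(self, tools=None):
        super().__init__(tools or [])
        self._name_cache = {tool.name.lower(): tool for tool in self}

    def __getitem__(self, key):
        if isinstance(key, str):
            return self._name_cache[key.lower()]
        return super().__getitem__(key)


tools = NamedToolList([FakeTool("Search"), FakeTool("Scrape Website")])
assert tools[0].name == "Search"
assert tools["search"].name == "Search"        # name lookup is case-insensitive
assert tools["scrape website"] is tools[1]
```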
@@ -1,6 +1,5 @@
import logging
import os
from typing import List
from crewai.tools import BaseTool
from pydantic import Field, create_model
@@ -13,9 +12,7 @@ logger = logging.getLogger(__name__)
class ZapierActionTool(BaseTool):
"""
A tool that wraps a Zapier action
"""
"""A tool that wraps a Zapier action."""
name: str = Field(description="Tool name")
description: str = Field(description="Tool description")
@@ -23,7 +20,7 @@ class ZapierActionTool(BaseTool):
api_key: str = Field(description="Zapier API key")
def _run(self, **kwargs) -> str:
"""Execute the Zapier action"""
"""Execute the Zapier action."""
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}
instructions = kwargs.pop(
@@ -43,7 +40,11 @@ class ZapierActionTool(BaseTool):
execute_url = f"{ACTIONS_URL}/{self.action_id}/execute/"
response = requests.request(
"POST", execute_url, headers=headers, json=action_params
"POST",
execute_url,
headers=headers,
json=action_params,
timeout=30,
)
response.raise_for_status()
@@ -52,13 +53,11 @@ class ZapierActionTool(BaseTool):
class ZapierActionsAdapter:
"""
Adapter for Zapier Actions
"""
"""Adapter for Zapier Actions."""
api_key: str
def __init__(self, api_key: str = None):
def __init__(self, api_key: str | None = None):
self.api_key = api_key or os.getenv("ZAPIER_API_KEY")
if not self.api_key:
logger.error("Zapier Actions API key is required")
@@ -68,14 +67,18 @@ class ZapierActionsAdapter:
headers = {
"x-api-key": self.api_key,
}
response = requests.request("GET", ACTIONS_URL, headers=headers)
response = requests.request(
"GET",
ACTIONS_URL,
headers=headers,
timeout=30,
)
response.raise_for_status()
response_json = response.json()
return response_json
return response.json()
def tools(self) -> List[BaseTool]:
"""Convert Zapier actions to BaseTool instances"""
def tools(self) -> list[BaseTool]:
"""Convert Zapier actions to BaseTool instances."""
actions_response = self.get_zapier_actions()
tools = []
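The Zapier changes above mainly add an explicit `timeout` to every `requests.request(...)` call (the requests-without-timeout lint), so a stalled Zapier endpoint can no longer block an agent indefinitely. A small illustration; the endpoint URL and API key are placeholders:

```python
import requests

ACTIONS_URL = "https://actions.zapier.com/api/v2/ai-actions"  # placeholder endpoint
headers = {"x-api-key": "PLACEHOLDER_KEY"}

# Without a timeout, the call can block forever if the server stops responding:
# response = requests.request("GET", ACTIONS_URL, headers=headers)

# With a timeout, a stalled connection raises requests.exceptions.Timeout after 30s.
response = requests.request(
    "GET",
    ACTIONS_URL,
    headers=headers,
    timeout=30,
)
response.raise_for_status()
print(response.json())
```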
@@ -2,7 +2,6 @@ from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from dotenv import load_dotenv
|
||||
@@ -24,22 +23,22 @@ class BedrockInvokeAgentToolInput(BaseModel):
|
||||
class BedrockInvokeAgentTool(BaseTool):
|
||||
name: str = "Bedrock Agent Invoke Tool"
|
||||
description: str = "An agent responsible for policy analysis."
|
||||
args_schema: Type[BaseModel] = BedrockInvokeAgentToolInput
|
||||
args_schema: type[BaseModel] = BedrockInvokeAgentToolInput
|
||||
agent_id: str = None
|
||||
agent_alias_id: str = None
|
||||
session_id: str = None
|
||||
enable_trace: bool = False
|
||||
end_session: bool = False
|
||||
package_dependencies: List[str] = ["boto3"]
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["boto3"])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str = None,
|
||||
agent_alias_id: str = None,
|
||||
session_id: str = None,
|
||||
agent_id: str | None = None,
|
||||
agent_alias_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
enable_trace: bool = False,
|
||||
end_session: bool = False,
|
||||
description: Optional[str] = None,
|
||||
description: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the BedrockInvokeAgentTool with agent configuration.
|
||||
@@ -90,14 +89,16 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
raise BedrockValidationError("session_id must be a string")
|
||||
|
||||
except BedrockValidationError as e:
|
||||
raise BedrockValidationError(f"Parameter validation failed: {e!s}")
|
||||
raise BedrockValidationError(f"Parameter validation failed: {e!s}") from e
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError:
|
||||
raise ImportError("`boto3` package not found, please run `uv add boto3`")
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"`boto3` package not found, please run `uv add boto3`"
|
||||
) from e
|
||||
|
||||
try:
|
||||
# Initialize the Bedrock Agent Runtime client
|
||||
@@ -175,9 +176,9 @@ Below is the users query or task. Complete it and answer it consicely and to the
|
||||
error_code = e.response["Error"].get("Code", "Unknown")
|
||||
error_message = e.response["Error"].get("Message", str(e))
|
||||
|
||||
raise BedrockAgentError(f"Error ({error_code}): {error_message}")
|
||||
raise BedrockAgentError(f"Error ({error_code}): {error_message}") from e
|
||||
except BedrockAgentError:
|
||||
# Re-raise BedrockAgentError exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
raise BedrockAgentError(f"Unexpected error: {e!s}")
|
||||
raise BedrockAgentError(f"Unexpected error: {e!s}") from e
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -13,8 +13,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BrowserSessionManager:
|
||||
"""
|
||||
Manages browser sessions for different threads.
|
||||
"""Manages browser sessions for different threads.
|
||||
|
||||
This class maintains separate browser sessions for different threads,
|
||||
enabling concurrent usage of browsers in multi-threaded environments.
|
||||
@@ -22,19 +21,17 @@ class BrowserSessionManager:
|
||||
"""
|
||||
|
||||
def __init__(self, region: str = "us-west-2"):
|
||||
"""
|
||||
Initialize the browser session manager.
|
||||
"""Initialize the browser session manager.
|
||||
|
||||
Args:
|
||||
region: AWS region for browser client
|
||||
"""
|
||||
self.region = region
|
||||
self._async_sessions: Dict[str, Tuple[BrowserClient, AsyncBrowser]] = {}
|
||||
self._sync_sessions: Dict[str, Tuple[BrowserClient, SyncBrowser]] = {}
|
||||
self._async_sessions: dict[str, tuple[BrowserClient, AsyncBrowser]] = {}
|
||||
self._sync_sessions: dict[str, tuple[BrowserClient, SyncBrowser]] = {}
|
||||
|
||||
async def get_async_browser(self, thread_id: str) -> AsyncBrowser:
|
||||
"""
|
||||
Get or create an async browser for the specified thread.
|
||||
"""Get or create an async browser for the specified thread.
|
||||
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread requesting the browser
|
||||
@@ -48,8 +45,7 @@ class BrowserSessionManager:
|
||||
return await self._create_async_browser_session(thread_id)
|
||||
|
||||
def get_sync_browser(self, thread_id: str) -> SyncBrowser:
|
||||
"""
|
||||
Get or create a sync browser for the specified thread.
|
||||
"""Get or create a sync browser for the specified thread.
|
||||
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread requesting the browser
|
||||
@@ -63,8 +59,7 @@ class BrowserSessionManager:
|
||||
return self._create_sync_browser_session(thread_id)
|
||||
|
||||
async def _create_async_browser_session(self, thread_id: str) -> AsyncBrowser:
|
||||
"""
|
||||
Create a new async browser session for the specified thread.
|
||||
"""Create a new async browser session for the specified thread.
|
||||
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
@@ -121,8 +116,7 @@ class BrowserSessionManager:
|
||||
raise
|
||||
|
||||
def _create_sync_browser_session(self, thread_id: str) -> SyncBrowser:
|
||||
"""
|
||||
Create a new sync browser session for the specified thread.
|
||||
"""Create a new sync browser session for the specified thread.
|
||||
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
@@ -179,8 +173,7 @@ class BrowserSessionManager:
|
||||
raise
|
||||
|
||||
async def close_async_browser(self, thread_id: str) -> None:
|
||||
"""
|
||||
Close the async browser session for the specified thread.
|
||||
"""Close the async browser session for the specified thread.
|
||||
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
@@ -214,8 +207,7 @@ class BrowserSessionManager:
|
||||
logger.info(f"Async browser session cleaned up for thread {thread_id}")
|
||||
|
||||
def close_sync_browser(self, thread_id: str) -> None:
|
||||
"""
|
||||
Close the sync browser session for the specified thread.
|
||||
"""Close the sync browser session for the specified thread.
|
||||
|
||||
Args:
|
||||
thread_id: Unique identifier for the thread
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple, Type
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -106,14 +106,12 @@ class BrowserBaseTool(BaseTool):
|
||||
async def get_async_page(self, thread_id: str) -> Any:
|
||||
"""Get or create a page for the specified thread."""
|
||||
browser = await self._session_manager.get_async_browser(thread_id)
|
||||
page = await aget_current_page(browser)
|
||||
return page
|
||||
return await aget_current_page(browser)
|
||||
|
||||
def get_sync_page(self, thread_id: str) -> Any:
|
||||
"""Get or create a page for the specified thread."""
|
||||
browser = self._session_manager.get_sync_browser(thread_id)
|
||||
page = get_current_page(browser)
|
||||
return page
|
||||
return get_current_page(browser)
|
||||
|
||||
def _is_in_asyncio_loop(self) -> bool:
|
||||
"""Check if we're currently in an asyncio event loop."""
|
||||
@@ -130,7 +128,7 @@ class NavigateTool(BrowserBaseTool):
|
||||
|
||||
name: str = "navigate_browser"
|
||||
description: str = "Navigate a browser to the specified URL"
|
||||
args_schema: Type[BaseModel] = NavigateToolInput
|
||||
args_schema: type[BaseModel] = NavigateToolInput
|
||||
|
||||
def _run(self, url: str, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the sync tool."""
|
||||
@@ -174,7 +172,7 @@ class ClickTool(BrowserBaseTool):
|
||||
|
||||
name: str = "click_element"
|
||||
description: str = "Click on an element with the given CSS selector"
|
||||
args_schema: Type[BaseModel] = ClickToolInput
|
||||
args_schema: type[BaseModel] = ClickToolInput
|
||||
|
||||
visible_only: bool = True
|
||||
"""Whether to consider only visible elements."""
|
||||
@@ -244,7 +242,7 @@ class NavigateBackTool(BrowserBaseTool):
|
||||
|
||||
name: str = "navigate_back"
|
||||
description: str = "Navigate back to the previous page"
|
||||
args_schema: Type[BaseModel] = NavigateBackToolInput
|
||||
args_schema: type[BaseModel] = NavigateBackToolInput
|
||||
|
||||
def _run(self, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the sync tool."""
|
||||
@@ -282,7 +280,7 @@ class ExtractTextTool(BrowserBaseTool):
|
||||
|
||||
name: str = "extract_text"
|
||||
description: str = "Extract all the text on the current webpage"
|
||||
args_schema: Type[BaseModel] = ExtractTextToolInput
|
||||
args_schema: type[BaseModel] = ExtractTextToolInput
|
||||
|
||||
def _run(self, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the sync tool."""
|
||||
@@ -334,7 +332,7 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
|
||||
name: str = "extract_hyperlinks"
|
||||
description: str = "Extract all hyperlinks on the current webpage"
|
||||
args_schema: Type[BaseModel] = ExtractHyperlinksToolInput
|
||||
args_schema: type[BaseModel] = ExtractHyperlinksToolInput
|
||||
|
||||
def _run(self, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the sync tool."""
|
||||
@@ -358,7 +356,7 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
for link in soup.find_all("a", href=True):
|
||||
text = link.get_text().strip()
|
||||
href = link["href"]
|
||||
if href.startswith("http") or href.startswith("https"):
|
||||
if href.startswith(("http", "https")):
|
||||
links.append({"text": text, "url": href})
|
||||
|
||||
if not links:
|
||||
@@ -390,7 +388,7 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
for link in soup.find_all("a", href=True):
|
||||
text = link.get_text().strip()
|
||||
href = link["href"]
|
||||
if href.startswith("http") or href.startswith("https"):
|
||||
if href.startswith(("http", "https")):
|
||||
links.append({"text": text, "url": href})
|
||||
|
||||
if not links:
|
||||
@@ -406,7 +404,7 @@ class GetElementsTool(BrowserBaseTool):
|
||||
|
||||
name: str = "get_elements"
|
||||
description: str = "Get elements from the webpage using a CSS selector"
|
||||
args_schema: Type[BaseModel] = GetElementsToolInput
|
||||
args_schema: type[BaseModel] = GetElementsToolInput
|
||||
|
||||
def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the sync tool."""
|
||||
@@ -454,7 +452,7 @@ class CurrentWebPageTool(BrowserBaseTool):
|
||||
|
||||
name: str = "current_webpage"
|
||||
description: str = "Get information about the current webpage"
|
||||
args_schema: Type[BaseModel] = CurrentWebPageToolInput
|
||||
args_schema: type[BaseModel] = CurrentWebPageToolInput
|
||||
|
||||
def _run(self, thread_id: str = "default", **kwargs) -> str:
|
||||
"""Use the sync tool."""
|
||||
@@ -524,15 +522,14 @@ class BrowserToolkit:
|
||||
"""
|
||||
|
||||
def __init__(self, region: str = "us-west-2"):
|
||||
"""
|
||||
Initialize the toolkit
|
||||
"""Initialize the toolkit.
|
||||
|
||||
Args:
|
||||
region: AWS region for the browser client
|
||||
"""
|
||||
self.region = region
|
||||
self.session_manager = BrowserSessionManager(region=region)
|
||||
self.tools: List[BaseTool] = []
|
||||
self.tools: list[BaseTool] = []
|
||||
self._nest_current_loop()
|
||||
self._setup_tools()
|
||||
|
||||
@@ -562,18 +559,16 @@ class BrowserToolkit:
|
||||
CurrentWebPageTool(session_manager=self.session_manager),
|
||||
]
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""
|
||||
Get the list of browser tools
|
||||
def get_tools(self) -> list[BaseTool]:
|
||||
"""Get the list of browser tools.
|
||||
|
||||
Returns:
|
||||
List of CrewAI tools
|
||||
"""
|
||||
return self.tools
|
||||
|
||||
def get_tools_by_name(self) -> Dict[str, BaseTool]:
|
||||
"""
|
||||
Get a dictionary of tools mapped by their names
|
||||
def get_tools_by_name(self) -> dict[str, BaseTool]:
|
||||
"""Get a dictionary of tools mapped by their names.
|
||||
|
||||
Returns:
|
||||
Dictionary of {tool_name: tool}
|
||||
@@ -581,18 +576,18 @@ class BrowserToolkit:
|
||||
return {tool.name: tool for tool in self.tools}
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Clean up all browser sessions asynchronously"""
|
||||
"""Clean up all browser sessions asynchronously."""
|
||||
await self.session_manager.close_all_browsers()
|
||||
logger.info("All browser sessions cleaned up")
|
||||
|
||||
def sync_cleanup(self) -> None:
|
||||
"""Clean up all browser sessions from synchronous code"""
|
||||
"""Clean up all browser sessions from synchronous code."""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
asyncio.create_task(self.cleanup())
|
||||
asyncio.create_task(self.cleanup()) # noqa: RUF006
|
||||
else:
|
||||
loop.run_until_complete(self.cleanup())
|
||||
except RuntimeError:
|
||||
@@ -601,9 +596,8 @@ class BrowserToolkit:
|
||||
|
||||
def create_browser_toolkit(
|
||||
region: str = "us-west-2",
|
||||
) -> Tuple[BrowserToolkit, List[BaseTool]]:
|
||||
"""
|
||||
Create a BrowserToolkit
|
||||
) -> tuple[BrowserToolkit, list[BaseTool]]:
|
||||
"""Create a BrowserToolkit.
|
||||
|
||||
Args:
|
||||
region: AWS region for browser client
|
||||
|
||||
@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
@@ -8,11 +8,12 @@ if TYPE_CHECKING:
from playwright.sync_api import Browser as SyncBrowser, Page as SyncPage
async def aget_current_page(browser: Union[AsyncBrowser, Any]) -> AsyncPage:
"""
Asynchronously get the current page of the browser.
async def aget_current_page(browser: AsyncBrowser | Any) -> AsyncPage:
"""Asynchronously get the current page of the browser.
Args:
browser: The browser (AsyncBrowser) to get the current page from.
Returns:
AsyncPage: The current page.
"""
@@ -25,11 +26,12 @@ async def aget_current_page(browser: Union[AsyncBrowser, Any]) -> AsyncPage:
return context.pages[-1]
def get_current_page(browser: Union[SyncBrowser, Any]) -> SyncPage:
"""
Get the current page of the browser.
def get_current_page(browser: SyncBrowser | Any) -> SyncPage:
"""Get the current page of the browser.
Args:
browser: The browser to get the current page from.
Returns:
SyncPage: The current page.
"""
@@ -4,7 +4,7 @@ from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -17,8 +17,7 @@ logger = logging.getLogger(__name__)


def extract_output_from_stream(response):
"""
Extract output from code interpreter response stream
"""Extract output from code interpreter response stream.

Args:
response: Response from code interpreter execution
@@ -73,7 +72,7 @@ class ExecuteCommandInput(BaseModel):
class ReadFilesInput(BaseModel):
"""Input for ReadFiles."""

paths: List[str] = Field(description="List of file paths to read")
paths: list[str] = Field(description="List of file paths to read")
thread_id: str = Field(
default="default", description="Thread ID for the code interpreter session"
)
@@ -91,7 +90,7 @@ class ListFilesInput(BaseModel):
class DeleteFilesInput(BaseModel):
"""Input for DeleteFiles."""

paths: List[str] = Field(description="List of file paths to delete")
paths: list[str] = Field(description="List of file paths to delete")
thread_id: str = Field(
default="default", description="Thread ID for the code interpreter session"
)
@@ -100,7 +99,7 @@ class DeleteFilesInput(BaseModel):
class WriteFilesInput(BaseModel):
"""Input for WriteFiles."""

files: List[Dict[str, str]] = Field(
files: list[dict[str, str]] = Field(
description="List of dictionaries with path and text fields"
)
thread_id: str = Field(
@@ -141,7 +140,7 @@ class ExecuteCodeTool(BaseTool):

name: str = "execute_code"
description: str = "Execute code in various languages (primarily Python)"
args_schema: Type[BaseModel] = ExecuteCodeInput
args_schema: type[BaseModel] = ExecuteCodeInput
toolkit: Any = Field(default=None, exclude=True)

def __init__(self, toolkit):
@@ -196,7 +195,7 @@ class ExecuteCommandTool(BaseTool):

name: str = "execute_command"
description: str = "Run shell commands in the code interpreter environment"
args_schema: Type[BaseModel] = ExecuteCommandInput
args_schema: type[BaseModel] = ExecuteCommandInput
toolkit: Any = Field(default=None, exclude=True)

def __init__(self, toolkit):
@@ -229,14 +228,14 @@ class ReadFilesTool(BaseTool):

name: str = "read_files"
description: str = "Read content of files in the environment"
args_schema: Type[BaseModel] = ReadFilesInput
args_schema: type[BaseModel] = ReadFilesInput
toolkit: Any = Field(default=None, exclude=True)

def __init__(self, toolkit):
super().__init__()
self.toolkit = toolkit

def _run(self, paths: List[str], thread_id: str = "default") -> str:
def _run(self, paths: list[str], thread_id: str = "default") -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
@@ -252,7 +251,7 @@ class ReadFilesTool(BaseTool):
except Exception as e:
return f"Error reading files: {e!s}"

async def _arun(self, paths: List[str], thread_id: str = "default") -> str:
async def _arun(self, paths: list[str], thread_id: str = "default") -> str:
# Use _run as we're working with a synchronous API that's thread-safe
return self._run(paths=paths, thread_id=thread_id)

@@ -262,7 +261,7 @@ class ListFilesTool(BaseTool):

name: str = "list_files"
description: str = "List files in directories in the environment"
args_schema: Type[BaseModel] = ListFilesInput
args_schema: type[BaseModel] = ListFilesInput
toolkit: Any = Field(default=None, exclude=True)

def __init__(self, toolkit):
@@ -295,14 +294,14 @@ class DeleteFilesTool(BaseTool):

name: str = "delete_files"
description: str = "Remove files from the environment"
args_schema: Type[BaseModel] = DeleteFilesInput
args_schema: type[BaseModel] = DeleteFilesInput
toolkit: Any = Field(default=None, exclude=True)

def __init__(self, toolkit):
super().__init__()
self.toolkit = toolkit

def _run(self, paths: List[str], thread_id: str = "default") -> str:
def _run(self, paths: list[str], thread_id: str = "default") -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
@@ -318,7 +317,7 @@ class DeleteFilesTool(BaseTool):
except Exception as e:
return f"Error deleting files: {e!s}"

async def _arun(self, paths: List[str], thread_id: str = "default") -> str:
async def _arun(self, paths: list[str], thread_id: str = "default") -> str:
# Use _run as we're working with a synchronous API that's thread-safe
return self._run(paths=paths, thread_id=thread_id)

@@ -328,14 +327,14 @@ class WriteFilesTool(BaseTool):

name: str = "write_files"
description: str = "Create or update files in the environment"
args_schema: Type[BaseModel] = WriteFilesInput
args_schema: type[BaseModel] = WriteFilesInput
toolkit: Any = Field(default=None, exclude=True)

def __init__(self, toolkit):
super().__init__()
self.toolkit = toolkit

def _run(self, files: List[Dict[str, str]], thread_id: str = "default") -> str:
def _run(self, files: list[dict[str, str]], thread_id: str = "default") -> str:
try:
# Get or create code interpreter
code_interpreter = self.toolkit._get_or_create_interpreter(
@@ -352,7 +351,7 @@ class WriteFilesTool(BaseTool):
return f"Error writing files: {e!s}"

async def _arun(
self, files: List[Dict[str, str]], thread_id: str = "default"
self, files: list[dict[str, str]], thread_id: str = "default"
) -> str:
# Use _run as we're working with a synchronous API that's thread-safe
return self._run(files=files, thread_id=thread_id)
@@ -363,7 +362,7 @@ class StartCommandTool(BaseTool):

name: str = "start_command_execution"
description: str = "Start long-running commands asynchronously"
args_schema: Type[BaseModel] = StartCommandInput
args_schema: type[BaseModel] = StartCommandInput
toolkit: Any = Field(default=None, exclude=True)

def __init__(self, toolkit):
@@ -396,7 +395,7 @@ class GetTaskTool(BaseTool):

name: str = "get_task"
description: str = "Check status of async tasks"
args_schema: Type[BaseModel] = GetTaskInput
args_schema: type[BaseModel] = GetTaskInput
toolkit: Any = Field(default=None, exclude=True)

def __init__(self, toolkit):
@@ -429,7 +428,7 @@ class StopTaskTool(BaseTool):

name: str = "stop_task"
description: str = "Stop running tasks"
args_schema: Type[BaseModel] = StopTaskInput
args_schema: type[BaseModel] = StopTaskInput
toolkit: Any = Field(default=None, exclude=True)

def __init__(self, toolkit):
@@ -511,15 +510,14 @@ class CodeInterpreterToolkit:
"""

def __init__(self, region: str = "us-west-2"):
"""
Initialize the toolkit
"""Initialize the toolkit.

Args:
region: AWS region for the code interpreter
"""
self.region = region
self._code_interpreters: Dict[str, CodeInterpreter] = {}
self.tools: List[BaseTool] = []
self._code_interpreters: dict[str, CodeInterpreter] = {}
self.tools: list[BaseTool] = []
self._setup_tools()

def _setup_tools(self) -> None:
@@ -561,26 +559,24 @@ class CodeInterpreterToolkit:
self._code_interpreters[thread_id] = code_interpreter
return code_interpreter

def get_tools(self) -> List[BaseTool]:
"""
Get the list of code interpreter tools
def get_tools(self) -> list[BaseTool]:
"""Get the list of code interpreter tools.

Returns:
List of CrewAI tools
"""
return self.tools

def get_tools_by_name(self) -> Dict[str, BaseTool]:
"""
Get a dictionary of tools mapped by their names
def get_tools_by_name(self) -> dict[str, BaseTool]:
"""Get a dictionary of tools mapped by their names.

Returns:
Dictionary of {tool_name: tool}
"""
return {tool.name: tool for tool in self.tools}

async def cleanup(self, thread_id: Optional[str] = None) -> None:
"""Clean up resources
async def cleanup(self, thread_id: str | None = None) -> None:
"""Clean up resources.

Args:
thread_id: Optional thread ID to clean up. If None, cleans up all sessions.
@@ -604,7 +600,7 @@ class CodeInterpreterToolkit:
for tid in thread_ids:
try:
self._code_interpreters[tid].stop()
except Exception as e:
except Exception as e:  # noqa: PERF203
logger.warning(
f"Error stopping code interpreter for thread {tid}: {e}"
)
@@ -615,9 +611,8 @@ class CodeInterpreterToolkit:

def create_code_interpreter_toolkit(
region: str = "us-west-2",
) -> Tuple[CodeInterpreterToolkit, List[BaseTool]]:
"""
Create a CodeInterpreterToolkit
) -> tuple[CodeInterpreterToolkit, list[BaseTool]]:
"""Create a CodeInterpreterToolkit.

Args:
region: AWS region for code interpreter

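For context, a minimal usage sketch of the factory above (illustrative only; the agent wiring and the cleanup call site are assumptions, not part of this commit):

# Hypothetical usage; assumes create_code_interpreter_toolkit is imported from its module.
import asyncio

toolkit, tools = create_code_interpreter_toolkit(region="us-west-2")
# `tools` is the list[BaseTool] returned alongside the toolkit and can be handed to a crewAI Agent.
try:
    ...  # run code and commands through the returned tools
finally:
    asyncio.run(toolkit.cleanup())  # stop every code interpreter session
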
@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Dict, List, Optional, Type
from typing import Any

from crewai.tools import BaseTool
from dotenv import load_dotenv
@@ -26,21 +26,21 @@ class BedrockKBRetrieverTool(BaseTool):
description: str = (
"Retrieves information from an Amazon Bedrock Knowledge Base given a query"
)
args_schema: Type[BaseModel] = BedrockKBRetrieverToolInput
args_schema: type[BaseModel] = BedrockKBRetrieverToolInput
knowledge_base_id: str = None
number_of_results: Optional[int] = 5
retrieval_configuration: Optional[Dict[str, Any]] = None
guardrail_configuration: Optional[Dict[str, Any]] = None
next_token: Optional[str] = None
package_dependencies: List[str] = ["boto3"]
number_of_results: int | None = 5
retrieval_configuration: dict[str, Any] | None = None
guardrail_configuration: dict[str, Any] | None = None
next_token: str | None = None
package_dependencies: list[str] = Field(default_factory=lambda: ["boto3"])

def __init__(
self,
knowledge_base_id: str = None,
number_of_results: Optional[int] = 5,
retrieval_configuration: Optional[Dict[str, Any]] = None,
guardrail_configuration: Optional[Dict[str, Any]] = None,
next_token: Optional[str] = None,
knowledge_base_id: str | None = None,
number_of_results: int | None = 5,
retrieval_configuration: dict[str, Any] | None = None,
guardrail_configuration: dict[str, Any] | None = None,
next_token: str | None = None,
**kwargs,
):
"""Initialize the BedrockKBRetrieverTool with knowledge base configuration.
@@ -72,7 +72,7 @@ class BedrockKBRetrieverTool(BaseTool):
# Update the description to include the knowledge base details
self.description = f"Retrieves information from Amazon Bedrock Knowledge Base '{self.knowledge_base_id}' given a query"

def _build_retrieval_configuration(self) -> Dict[str, Any]:
def _build_retrieval_configuration(self) -> dict[str, Any]:
"""Build the retrieval configuration based on provided parameters.

Returns:
@@ -124,9 +124,9 @@ class BedrockKBRetrieverTool(BaseTool):
)

except BedrockValidationError as e:
raise BedrockValidationError(f"Parameter validation failed: {e!s}")
raise BedrockValidationError(f"Parameter validation failed: {e!s}") from e

def _process_retrieval_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
def _process_retrieval_result(self, result: dict[str, Any]) -> dict[str, Any]:
"""Process a single retrieval result from Bedrock Knowledge Base.

Args:
@@ -194,8 +194,10 @@ class BedrockKBRetrieverTool(BaseTool):
try:
import boto3
from botocore.exceptions import ClientError
except ImportError:
raise ImportError("`boto3` package not found, please run `uv add boto3`")
except ImportError as e:
raise ImportError(
"`boto3` package not found, please run `uv add boto3`"
) from e

try:
# Initialize the Bedrock Agent Runtime client
@@ -257,6 +259,8 @@ class BedrockKBRetrieverTool(BaseTool):
error_code = e.response["Error"].get("Code", "Unknown")
error_message = e.response["Error"].get("Message", str(e))

raise BedrockKnowledgeBaseError(f"Error ({error_code}): {error_message}")
raise BedrockKnowledgeBaseError(
f"Error ({error_code}): {error_message}"
) from e
except Exception as e:
raise BedrockKnowledgeBaseError(f"Unexpected error: {e!s}")
raise BedrockKnowledgeBaseError(f"Unexpected error: {e!s}") from e

@@ -1,5 +1,4 @@
import os
from typing import List, Type

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -16,15 +15,17 @@ class S3ReaderToolInput(BaseModel):
class S3ReaderTool(BaseTool):
name: str = "S3 Reader Tool"
description: str = "Reads a file from Amazon S3 given an S3 file path"
args_schema: Type[BaseModel] = S3ReaderToolInput
package_dependencies: List[str] = ["boto3"]
args_schema: type[BaseModel] = S3ReaderToolInput
package_dependencies: list[str] = Field(default_factory=lambda: ["boto3"])

def _run(self, file_path: str) -> str:
try:
import boto3
from botocore.exceptions import ClientError
except ImportError:
raise ImportError("`boto3` package not found, please run `uv add boto3`")
except ImportError as e:
raise ImportError(
"`boto3` package not found, please run `uv add boto3`"
) from e

try:
bucket_name, object_key = self._parse_s3_path(file_path)
@@ -38,9 +39,8 @@ class S3ReaderTool(BaseTool):

# Read file content from S3
response = s3.get_object(Bucket=bucket_name, Key=object_key)
file_content = response["Body"].read().decode("utf-8")
return response["Body"].read().decode("utf-8")

return file_content
except ClientError as e:
return f"Error reading file from S3: {e!s}"

@@ -1,5 +1,4 @@
import os
from typing import List, Type

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -17,15 +16,17 @@ class S3WriterToolInput(BaseModel):
class S3WriterTool(BaseTool):
name: str = "S3 Writer Tool"
description: str = "Writes content to a file in Amazon S3 given an S3 file path"
args_schema: Type[BaseModel] = S3WriterToolInput
package_dependencies: List[str] = ["boto3"]
args_schema: type[BaseModel] = S3WriterToolInput
package_dependencies: list[str] = Field(default_factory=lambda: ["boto3"])

def _run(self, file_path: str, content: str) -> str:
try:
import boto3
from botocore.exceptions import ClientError
except ImportError:
raise ImportError("`boto3` package not found, please run `uv add boto3`")
except ImportError as e:
raise ImportError(
"`boto3` package not found, please run `uv add boto3`"
) from e

try:
bucket_name, object_key = self._parse_s3_path(file_path)

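For context, a minimal sketch of invoking the two S3 tools above (illustrative only; the bucket, key, and import locations are placeholders, and AWS credentials are assumed to come from the usual environment variables or instance profile):

# Hypothetical usage; assumes S3ReaderTool and S3WriterTool are imported from their modules.
writer = S3WriterTool()
print(writer.run(file_path="s3://my-bucket/notes/todo.txt", content="hello"))

reader = S3ReaderTool()
print(reader.run(file_path="s3://my-bucket/notes/todo.txt"))
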
@@ -1,13 +1,11 @@
|
||||
"""Utility for colored console output."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Printer:
|
||||
"""Handles colored console output formatting."""
|
||||
|
||||
@staticmethod
|
||||
def print(content: str, color: Optional[str] = None) -> None:
|
||||
def print(content: str, color: str | None = None) -> None:
|
||||
"""Prints content with optional color formatting.
|
||||
|
||||
Args:
|
||||
@@ -20,7 +18,7 @@ class Printer:
|
||||
if hasattr(Printer, f"_print_{color}"):
|
||||
getattr(Printer, f"_print_{color}")(content)
|
||||
else:
|
||||
print(content)
|
||||
print(content) # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_bold_purple(content: str) -> None:
|
||||
@@ -29,7 +27,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold purple.
|
||||
"""
|
||||
print(f"\033[1m\033[95m {content}\033[00m")
|
||||
print(f"\033[1m\033[95m {content}\033[00m") # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_bold_green(content: str) -> None:
|
||||
@@ -38,7 +36,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold green.
|
||||
"""
|
||||
print(f"\033[1m\033[92m {content}\033[00m")
|
||||
print(f"\033[1m\033[92m {content}\033[00m") # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_purple(content: str) -> None:
|
||||
@@ -47,7 +45,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in purple.
|
||||
"""
|
||||
print(f"\033[95m {content}\033[00m")
|
||||
print(f"\033[95m {content}\033[00m") # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_red(content: str) -> None:
|
||||
@@ -56,7 +54,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in red.
|
||||
"""
|
||||
print(f"\033[91m {content}\033[00m")
|
||||
print(f"\033[91m {content}\033[00m") # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_bold_blue(content: str) -> None:
|
||||
@@ -65,7 +63,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold blue.
|
||||
"""
|
||||
print(f"\033[1m\033[94m {content}\033[00m")
|
||||
print(f"\033[1m\033[94m {content}\033[00m") # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_yellow(content: str) -> None:
|
||||
@@ -74,7 +72,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in yellow.
|
||||
"""
|
||||
print(f"\033[93m {content}\033[00m")
|
||||
print(f"\033[93m {content}\033[00m") # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_bold_yellow(content: str) -> None:
|
||||
@@ -83,7 +81,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold yellow.
|
||||
"""
|
||||
print(f"\033[1m\033[93m {content}\033[00m")
|
||||
print(f"\033[1m\033[93m {content}\033[00m") # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_cyan(content: str) -> None:
|
||||
@@ -92,7 +90,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in cyan.
|
||||
"""
|
||||
print(f"\033[96m {content}\033[00m")
|
||||
print(f"\033[96m {content}\033[00m") # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_bold_cyan(content: str) -> None:
|
||||
@@ -101,7 +99,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold cyan.
|
||||
"""
|
||||
print(f"\033[1m\033[96m {content}\033[00m")
|
||||
print(f"\033[1m\033[96m {content}\033[00m") # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_magenta(content: str) -> None:
|
||||
@@ -110,7 +108,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in magenta.
|
||||
"""
|
||||
print(f"\033[35m {content}\033[00m")
|
||||
print(f"\033[35m {content}\033[00m") # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_bold_magenta(content: str) -> None:
|
||||
@@ -119,7 +117,7 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in bold magenta.
|
||||
"""
|
||||
print(f"\033[1m\033[35m {content}\033[00m")
|
||||
print(f"\033[1m\033[35m {content}\033[00m") # noqa: T201
|
||||
|
||||
@staticmethod
|
||||
def _print_green(content: str) -> None:
|
||||
@@ -128,4 +126,4 @@ class Printer:
|
||||
Args:
|
||||
content: The string to be printed in green.
|
||||
"""
|
||||
print(f"\033[32m {content}\033[00m")
|
||||
print(f"\033[32m {content}\033[00m") # noqa: T201
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -10,14 +10,14 @@ from crewai_tools.rag.source_content import SourceContent
|
||||
class LoaderResult(BaseModel):
|
||||
content: str = Field(description="The text content of the source")
|
||||
source: str = Field(description="The source of the content", default="unknown")
|
||||
metadata: Dict[str, Any] = Field(
|
||||
metadata: dict[str, Any] = Field(
|
||||
description="The metadata of the source", default_factory=dict
|
||||
)
|
||||
doc_id: str = Field(description="The id of the document")
|
||||
|
||||
|
||||
class BaseLoader(ABC):
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
def __init__(self, config: dict[str, Any] | None = None):
|
||||
self.config = config or {}
|
||||
|
||||
@abstractmethod
|
||||
@@ -26,15 +26,13 @@ class BaseLoader(ABC):
|
||||
def generate_doc_id(
|
||||
self, source_ref: str | None = None, content: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate a unique document id based on the source reference and content.
|
||||
"""Generate a unique document id based on the source reference and content.
|
||||
If the source reference is not provided, the content is used as the source reference.
|
||||
If the content is not provided, the source reference is used as the content.
|
||||
If both are provided, the source reference is used as the content.
|
||||
|
||||
Both are optional because the TEXT content type does not have a source reference. In this case, the content is used as the source reference.
|
||||
"""
|
||||
|
||||
source_ref = source_ref or ""
|
||||
content = content or ""
|
||||
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class RecursiveCharacterTextSplitter:
|
||||
"""
|
||||
A text splitter that recursively splits text based on a hierarchy of separators.
|
||||
"""
|
||||
"""A text splitter that recursively splits text based on a hierarchy of separators."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
separators: Optional[List[str]] = None,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize the RecursiveCharacterTextSplitter.
|
||||
"""Initialize the RecursiveCharacterTextSplitter.
|
||||
|
||||
Args:
|
||||
chunk_size: Maximum size of each chunk
|
||||
@@ -39,10 +35,10 @@ class RecursiveCharacterTextSplitter:
|
||||
"",
|
||||
]
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
return self._split_text(text, self._separators)
|
||||
|
||||
def _split_text(self, text: str, separators: List[str]) -> List[str]:
|
||||
def _split_text(self, text: str, separators: list[str]) -> list[str]:
|
||||
separator = separators[-1]
|
||||
new_separators = []
|
||||
|
||||
@@ -71,7 +67,7 @@ class RecursiveCharacterTextSplitter:
|
||||
|
||||
return self._merge_splits(good_splits, separator)
|
||||
|
||||
def _split_text_with_separator(self, text: str, separator: str) -> List[str]:
|
||||
def _split_text_with_separator(self, text: str, separator: str) -> list[str]:
|
||||
if separator == "":
|
||||
return list(text)
|
||||
|
||||
@@ -95,13 +91,13 @@ class RecursiveCharacterTextSplitter:
|
||||
return [s for s in splits if s]
|
||||
return text.split(separator)
|
||||
|
||||
def _split_by_characters(self, text: str) -> List[str]:
|
||||
def _split_by_characters(self, text: str) -> list[str]:
|
||||
chunks = []
|
||||
for i in range(0, len(text), self._chunk_size):
|
||||
chunks.append(text[i : i + self._chunk_size])
|
||||
chunks.append(text[i : i + self._chunk_size]) # noqa: PERF401
|
||||
return chunks
|
||||
|
||||
def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
|
||||
def _merge_splits(self, splits: list[str], separator: str) -> list[str]:
|
||||
"""Merge splits into chunks with proper overlap."""
|
||||
docs = []
|
||||
current_doc = []
|
||||
@@ -154,11 +150,10 @@ class BaseChunker:
|
||||
self,
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200,
|
||||
separators: Optional[List[str]] = None,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize the Chunker
|
||||
"""Initialize the Chunker.
|
||||
|
||||
Args:
|
||||
chunk_size: Maximum size of each chunk
|
||||
@@ -166,7 +161,6 @@ class BaseChunker:
|
||||
separators: List of separators to use for splitting
|
||||
keep_separator: Whether to keep separators in the chunks
|
||||
"""
|
||||
|
||||
self._splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
@@ -174,7 +168,7 @@ class BaseChunker:
|
||||
keep_separator=keep_separator,
|
||||
)
|
||||
|
||||
def chunk(self, text: str) -> List[str]:
|
||||
def chunk(self, text: str) -> list[str]:
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
|
||||
|
||||
@@ -8,7 +6,7 @@ class DefaultChunker(BaseChunker):
|
||||
self,
|
||||
chunk_size: int = 2000,
|
||||
chunk_overlap: int = 20,
|
||||
separators: Optional[List[str]] = None,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
|
||||
|
||||
@@ -8,7 +6,7 @@ class CsvChunker(BaseChunker):
|
||||
self,
|
||||
chunk_size: int = 1200,
|
||||
chunk_overlap: int = 100,
|
||||
separators: Optional[List[str]] = None,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
@@ -28,7 +26,7 @@ class JsonChunker(BaseChunker):
|
||||
self,
|
||||
chunk_size: int = 2000,
|
||||
chunk_overlap: int = 200,
|
||||
separators: Optional[List[str]] = None,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
@@ -50,7 +48,7 @@ class XmlChunker(BaseChunker):
|
||||
self,
|
||||
chunk_size: int = 2500,
|
||||
chunk_overlap: int = 250,
|
||||
separators: Optional[List[str]] = None,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
|
||||
|
||||
@@ -8,7 +6,7 @@ class TextChunker(BaseChunker):
|
||||
self,
|
||||
chunk_size: int = 1500,
|
||||
chunk_overlap: int = 150,
|
||||
separators: Optional[List[str]] = None,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
@@ -32,7 +30,7 @@ class DocxChunker(BaseChunker):
|
||||
self,
|
||||
chunk_size: int = 2500,
|
||||
chunk_overlap: int = 250,
|
||||
separators: Optional[List[str]] = None,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
@@ -56,7 +54,7 @@ class MdxChunker(BaseChunker):
|
||||
self,
|
||||
chunk_size: int = 3000,
|
||||
chunk_overlap: int = 300,
|
||||
separators: Optional[List[str]] = None,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
|
||||
|
||||
@@ -8,7 +6,7 @@ class WebsiteChunker(BaseChunker):
|
||||
self,
|
||||
chunk_size: int = 2500,
|
||||
chunk_overlap: int = 250,
|
||||
separators: Optional[List[str]] = None,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
):
|
||||
if separators is None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import chromadb
|
||||
@@ -23,7 +23,7 @@ class EmbeddingService:
|
||||
self.model = model
|
||||
self.kwargs = kwargs
|
||||
|
||||
def embed_text(self, text: str) -> List[float]:
|
||||
def embed_text(self, text: str) -> list[float]:
|
||||
try:
|
||||
response = litellm.embedding(model=self.model, input=[text], **self.kwargs)
|
||||
return response.data[0]["embedding"]
|
||||
@@ -31,7 +31,7 @@ class EmbeddingService:
|
||||
logger.error(f"Error generating embedding: {e}")
|
||||
raise
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
@@ -46,18 +46,18 @@ class EmbeddingService:
|
||||
class Document(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
content: str
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
data_type: DataType = DataType.TEXT
|
||||
source: Optional[str] = None
|
||||
source: str | None = None
|
||||
|
||||
|
||||
class RAG(Adapter):
|
||||
collection_name: str = "crewai_knowledge_base"
|
||||
persist_directory: Optional[str] = None
|
||||
persist_directory: str | None = None
|
||||
embedding_model: str = "text-embedding-3-large"
|
||||
summarize: bool = False
|
||||
top_k: int = 5
|
||||
embedding_config: Dict[str, Any] = Field(default_factory=dict)
|
||||
embedding_config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
_client: Any = PrivateAttr()
|
||||
_collection: Any = PrivateAttr()
|
||||
@@ -90,10 +90,10 @@ class RAG(Adapter):
|
||||
def add(
|
||||
self,
|
||||
content: str | Path,
|
||||
data_type: Optional[Union[str, DataType]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
loader: Optional[BaseLoader] = None,
|
||||
chunker: Optional[BaseChunker] = None,
|
||||
data_type: str | DataType | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
loader: BaseLoader | None = None,
|
||||
chunker: BaseChunker | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
source_content = SourceContent(content)
|
||||
@@ -181,7 +181,7 @@ class RAG(Adapter):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||
|
||||
def query(self, question: str, where: Optional[Dict[str, Any]] = None) -> str:
|
||||
def query(self, question: str, where: dict[str, Any] | None = None) -> str:
|
||||
try:
|
||||
question_embedding = self._embedding_service.embed_text(question)
|
||||
|
||||
@@ -228,7 +228,7 @@ class RAG(Adapter):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
||||
def get_collection_info(self) -> Dict[str, Any]:
|
||||
def get_collection_info(self) -> dict[str, Any]:
|
||||
try:
|
||||
count = self._collection.count()
|
||||
return {
|
||||
@@ -246,7 +246,7 @@ class RAG(Adapter):
|
||||
try:
|
||||
if isinstance(data_type, str):
|
||||
return DataType(data_type)
|
||||
except Exception:
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
return content.data_type
|
||||
|
||||
@@ -65,7 +65,7 @@ class DataType(str, Enum):
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading chunker for {self}: {e}")
|
||||
raise ValueError(f"Error loading chunker for {self}: {e}") from e
|
||||
|
||||
def get_loader(self) -> BaseLoader:
|
||||
from importlib import import_module
|
||||
@@ -100,7 +100,7 @@ class DataType(str, Enum):
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading loader for {self}: {e}")
|
||||
raise ValueError(f"Error loading loader for {self}: {e}") from e
|
||||
|
||||
|
||||
class DataTypes:
|
||||
@@ -117,7 +117,7 @@ class DataTypes:
|
||||
try:
|
||||
url = urlparse(content)
|
||||
is_url = (url.scheme and url.netloc) or url.scheme == "file"
|
||||
except Exception:
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def get_file_type(path: str) -> DataType | None:
|
||||
|
||||
@@ -33,7 +33,7 @@ class CSVLoader(BaseLoader):
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching CSV from URL {url}: {e!s}")
|
||||
raise ValueError(f"Error fetching CSV from URL {url}: {e!s}") from e
|
||||
|
||||
def _load_from_file(self, path: str) -> str:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
@@ -8,8 +7,7 @@ from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
class DirectoryLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""
|
||||
Load and process all files from a directory recursively.
|
||||
"""Load and process all files from a directory recursively.
|
||||
|
||||
Args:
|
||||
source: Directory path or URL to a directory listing
|
||||
@@ -63,7 +61,7 @@ class DirectoryLoader(BaseLoader):
|
||||
"source": result.source,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception as e: # noqa: PERF203
|
||||
error_msg = f"Error processing {file_path}: {e!s}"
|
||||
errors.append(error_msg)
|
||||
all_contents.append(f"=== File: {file_path} (ERROR) ===\n{error_msg}")
|
||||
@@ -91,9 +89,9 @@ class DirectoryLoader(BaseLoader):
|
||||
self,
|
||||
dir_path: str,
|
||||
recursive: bool,
|
||||
include_ext: List[str] | None = None,
|
||||
exclude_ext: List[str] | None = None,
|
||||
) -> List[str]:
|
||||
include_ext: list[str] | None = None,
|
||||
exclude_ext: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Find all files in directory matching criteria."""
|
||||
files = []
|
||||
|
||||
@@ -103,7 +101,7 @@ class DirectoryLoader(BaseLoader):
|
||||
|
||||
for filename in filenames:
|
||||
if self._should_include_file(filename, include_ext, exclude_ext):
|
||||
files.append(os.path.join(root, filename))
|
||||
files.append(os.path.join(root, filename)) # noqa: PERF401
|
||||
else:
|
||||
try:
|
||||
for item in os.listdir(dir_path):
|
||||
@@ -120,8 +118,8 @@ class DirectoryLoader(BaseLoader):
|
||||
def _should_include_file(
|
||||
self,
|
||||
filename: str,
|
||||
include_ext: List[str] = None,
|
||||
exclude_ext: List[str] = None,
|
||||
include_ext: list[str] | None = None,
|
||||
exclude_ext: list[str] | None = None,
|
||||
) -> bool:
|
||||
"""Determine if a file should be included based on criteria."""
|
||||
if filename.startswith("."):
|
||||
|
||||
@@ -28,7 +28,9 @@ class DocsSiteLoader(BaseLoader):
|
||||
response = requests.get(docs_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(f"Unable to fetch documentation from {docs_url}: {e}")
|
||||
raise ValueError(
|
||||
f"Unable to fetch documentation from {docs_url}: {e}"
|
||||
) from e
|
||||
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
|
||||
@@ -9,10 +9,10 @@ class DOCXLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
try:
|
||||
from docx import Document as DocxDocument
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"python-docx is required for DOCX loading. Install with: 'uv pip install python-docx' or pip install crewai-tools[rag]"
|
||||
)
|
||||
) from e
|
||||
|
||||
source_ref = source_content.source_ref
|
||||
|
||||
@@ -49,10 +49,13 @@ class DOCXLoader(BaseLoader):
|
||||
temp_file.write(response.content)
|
||||
return temp_file.name
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching DOCX from URL {url}: {e!s}")
|
||||
raise ValueError(f"Error fetching DOCX from URL {url}: {e!s}") from e
|
||||
|
||||
def _load_from_file(
|
||||
self, file_path: str, source_ref: str, DocxDocument
|
||||
self,
|
||||
file_path: str,
|
||||
source_ref: str,
|
||||
DocxDocument, # noqa: N803
|
||||
) -> LoaderResult:
|
||||
try:
|
||||
doc = DocxDocument(file_path)
|
||||
@@ -60,7 +63,7 @@ class DOCXLoader(BaseLoader):
|
||||
text_parts = []
|
||||
for paragraph in doc.paragraphs:
|
||||
if paragraph.text.strip():
|
||||
text_parts.append(paragraph.text)
|
||||
text_parts.append(paragraph.text) # noqa: PERF401
|
||||
|
||||
content = "\n".join(text_parts)
|
||||
|
||||
@@ -78,4 +81,4 @@ class DOCXLoader(BaseLoader):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading DOCX file: {e!s}")
|
||||
raise ValueError(f"Error loading DOCX file: {e!s}") from e
|
||||
|
||||
@@ -38,7 +38,7 @@ class GithubLoader(BaseLoader):
|
||||
try:
|
||||
repo = g.get_repo(repo_name)
|
||||
except GithubException as e:
|
||||
raise ValueError(f"Unable to access repository {repo_name}: {e}")
|
||||
raise ValueError(f"Unable to access repository {repo_name}: {e}") from e
|
||||
|
||||
all_content = []
|
||||
|
||||
@@ -66,7 +66,7 @@ class GithubLoader(BaseLoader):
|
||||
if isinstance(contents, list):
|
||||
all_content.append("Repository structure:")
|
||||
for content_file in contents[:20]:
|
||||
all_content.append(
|
||||
all_content.append( # noqa: PERF401
|
||||
f"- {content_file.path} ({content_file.type})"
|
||||
)
|
||||
all_content.append("")
|
||||
|
||||
@@ -36,7 +36,7 @@ class JSONLoader(BaseLoader):
|
||||
else json.dumps(response.json(), indent=2)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching JSON from URL {url}: {e!s}")
|
||||
raise ValueError(f"Error fetching JSON from URL {url}: {e!s}") from e
|
||||
|
||||
def _is_json_response(self, response) -> bool:
|
||||
try:
|
||||
|
||||
@@ -32,7 +32,7 @@ class MDXLoader(BaseLoader):
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching MDX from URL {url}: {e!s}")
|
||||
raise ValueError(f"Error fetching MDX from URL {url}: {e!s}") from e
|
||||
|
||||
def _load_from_file(self, path: str) -> str:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
|
||||
@@ -95,6 +95,6 @@ class MySQLLoader(BaseLoader):
|
||||
finally:
|
||||
connection.close()
|
||||
except pymysql.Error as e:
|
||||
raise ValueError(f"MySQL database error: {e}")
|
||||
raise ValueError(f"MySQL database error: {e}") from e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load data from MySQL: {e}")
|
||||
raise ValueError(f"Failed to load data from MySQL: {e}") from e
|
||||
|
||||
@@ -28,11 +28,11 @@ class PDFLoader(BaseLoader):
|
||||
import pypdf
|
||||
except ImportError:
|
||||
try:
|
||||
import PyPDF2 as pypdf
|
||||
except ImportError:
|
||||
import PyPDF2 as pypdf # noqa: N813
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
|
||||
)
|
||||
) from e
|
||||
|
||||
file_path = source.source
|
||||
|
||||
@@ -56,7 +56,7 @@ class PDFLoader(BaseLoader):
|
||||
if page_text.strip():
|
||||
text_content.append(f"Page {page_num}:\n{page_text}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading PDF file {file_path}: {e!s}")
|
||||
raise ValueError(f"Error reading PDF file {file_path}: {e!s}") from e
|
||||
|
||||
if not text_content:
|
||||
content = f"[PDF file with no extractable text: {Path(file_path).name}]"
|
||||
|
||||
@@ -95,6 +95,6 @@ class PostgresLoader(BaseLoader):
|
||||
finally:
|
||||
connection.close()
|
||||
except psycopg2.Error as e:
|
||||
raise ValueError(f"PostgreSQL database error: {e}")
|
||||
raise ValueError(f"PostgreSQL database error: {e}") from e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load data from PostgreSQL: {e}")
|
||||
raise ValueError(f"Failed to load data from PostgreSQL: {e}") from e
|
||||
|
||||
@@ -51,4 +51,4 @@ class WebPageLoader(BaseLoader):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading webpage {url}: {e!s}")
|
||||
raise ValueError(f"Error loading webpage {url}: {e!s}") from e
|
||||
|
||||
@@ -32,7 +32,7 @@ class XMLLoader(BaseLoader):
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching XML from URL {url}: {e!s}")
|
||||
raise ValueError(f"Error fetching XML from URL {url}: {e!s}") from e
|
||||
|
||||
def _load_from_file(self, path: str) -> str:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
@@ -41,14 +41,14 @@ class XMLLoader(BaseLoader):
|
||||
def _parse_xml(self, content: str, source_ref: str) -> LoaderResult:
|
||||
try:
|
||||
if content.strip().startswith("<"):
|
||||
root = ET.fromstring(content)
|
||||
root = ET.fromstring(content) # noqa: S314
|
||||
else:
|
||||
root = ET.parse(source_ref).getroot()
|
||||
root = ET.parse(source_ref).getroot() # noqa: S314
|
||||
|
||||
text_parts = []
|
||||
for text_content in root.itertext():
|
||||
if text_content and text_content.strip():
|
||||
text_parts.append(text_content.strip())
|
||||
text_parts.append(text_content.strip()) # noqa: PERF401
|
||||
|
||||
text = "\n".join(text_parts)
|
||||
metadata = {"format": "xml", "root_tag": root.tag}
|
||||
|
||||
@@ -25,10 +25,10 @@ class YoutubeChannelLoader(BaseLoader):
|
||||
"""
|
||||
try:
|
||||
from pytube import Channel
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"YouTube channel support requires pytube. Install with: uv add pytube"
|
||||
)
|
||||
) from e
|
||||
|
||||
channel_url = source.source
|
||||
|
||||
@@ -93,14 +93,14 @@ class YoutubeChannelLoader(BaseLoader):
|
||||
|
||||
try:
|
||||
transcript = transcript_list.find_transcript(["en"])
|
||||
except:
|
||||
except Exception:
|
||||
try:
|
||||
transcript = (
|
||||
transcript_list.find_generated_transcript(
|
||||
["en"]
|
||||
)
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
transcript = next(iter(transcript_list), None)
|
||||
|
||||
if transcript:
|
||||
@@ -124,7 +124,7 @@ class YoutubeChannelLoader(BaseLoader):
|
||||
content_parts.append(
|
||||
f" Transcript Preview: {preview}..."
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
content_parts.append(" Transcript: Not available")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -26,11 +26,11 @@ class YoutubeVideoLoader(BaseLoader):
|
||||
"""
|
||||
try:
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"YouTube support requires youtube-transcript-api. "
|
||||
"Install with: uv add youtube-transcript-api"
|
||||
)
|
||||
) from e
|
||||
|
||||
video_url = source.source
|
||||
video_id = self._extract_video_id(video_url)
|
||||
@@ -51,10 +51,10 @@ class YoutubeVideoLoader(BaseLoader):
|
||||
transcript = None
|
||||
try:
|
||||
transcript = transcript_list.find_transcript(["en"])
|
||||
except:
|
||||
except Exception:
|
||||
try:
|
||||
transcript = transcript_list.find_generated_transcript(["en"])
|
||||
except:
|
||||
except Exception:
|
||||
transcript = next(iter(transcript_list))
|
||||
|
||||
if transcript:
|
||||
@@ -84,7 +84,7 @@ class YoutubeVideoLoader(BaseLoader):
|
||||
|
||||
if yt.title:
|
||||
content = f"Title: {yt.title}\n\nAuthor: {yt.author or 'Unknown'}\n\nTranscript:\n{content}"
|
||||
except:
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -128,7 +128,7 @@ class YoutubeVideoLoader(BaseLoader):
|
||||
query_params = parse_qs(parsed.query)
|
||||
if "v" in query_params:
|
||||
return query_params["v"][0]
|
||||
except:
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
@@ -40,7 +40,6 @@ class SourceContent:
|
||||
If the content is a URL or a local file, returns the source.
|
||||
Otherwise, returns the hash of the content.
|
||||
"""
|
||||
|
||||
if self.is_url() or self.path_exists():
|
||||
return self.source
|
||||
|
||||
|
||||
@@ -1,127 +1,274 @@
|
||||
from .ai_mind_tool.ai_mind_tool import AIMindTool
|
||||
from .apify_actors_tool.apify_actors_tool import ApifyActorsTool
|
||||
from .arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
|
||||
from .brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from .brightdata_tool import (
|
||||
from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool
|
||||
from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool
|
||||
from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool
|
||||
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
|
||||
from crewai_tools.tools.brightdata_tool import (
|
||||
BrightDataDatasetTool,
|
||||
BrightDataSearchTool,
|
||||
BrightDataWebUnlockerTool,
|
||||
)
|
||||
from .browserbase_load_tool.browserbase_load_tool import BrowserbaseLoadTool
|
||||
from .code_docs_search_tool.code_docs_search_tool import CodeDocsSearchTool
|
||||
from .code_interpreter_tool.code_interpreter_tool import CodeInterpreterTool
|
||||
from .composio_tool.composio_tool import ComposioTool
|
||||
from .contextualai_create_agent_tool.contextual_create_agent_tool import (
|
||||
from crewai_tools.tools.browserbase_load_tool.browserbase_load_tool import (
|
||||
BrowserbaseLoadTool,
|
||||
)
|
||||
from crewai_tools.tools.code_docs_search_tool.code_docs_search_tool import (
|
||||
CodeDocsSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import (
|
||||
CodeInterpreterTool,
|
||||
)
|
||||
from crewai_tools.tools.composio_tool.composio_tool import ComposioTool
|
||||
from crewai_tools.tools.contextualai_create_agent_tool.contextual_create_agent_tool import (
|
||||
ContextualAICreateAgentTool,
|
||||
)
|
||||
from .contextualai_parse_tool.contextual_parse_tool import ContextualAIParseTool
|
||||
from .contextualai_query_tool.contextual_query_tool import ContextualAIQueryTool
|
||||
from .contextualai_rerank_tool.contextual_rerank_tool import ContextualAIRerankTool
|
||||
from .couchbase_tool.couchbase_tool import CouchbaseFTSVectorSearchTool
|
||||
from .crewai_enterprise_tools.crewai_enterprise_tools import CrewaiEnterpriseTools
|
||||
from .crewai_platform_tools.crewai_platform_tools import CrewaiPlatformTools
|
||||
from .csv_search_tool.csv_search_tool import CSVSearchTool
|
||||
from .dalle_tool.dalle_tool import DallETool
|
||||
from .databricks_query_tool.databricks_query_tool import DatabricksQueryTool
|
||||
from .directory_read_tool.directory_read_tool import DirectoryReadTool
|
||||
from .directory_search_tool.directory_search_tool import DirectorySearchTool
|
||||
from .docx_search_tool.docx_search_tool import DOCXSearchTool
|
||||
from .exa_tools.exa_search_tool import EXASearchTool
|
||||
from .file_read_tool.file_read_tool import FileReadTool
|
||||
from .file_writer_tool.file_writer_tool import FileWriterTool
|
||||
from .files_compressor_tool.files_compressor_tool import FileCompressorTool
|
||||
from .firecrawl_crawl_website_tool.firecrawl_crawl_website_tool import (
|
||||
from crewai_tools.tools.contextualai_parse_tool.contextual_parse_tool import (
|
||||
ContextualAIParseTool,
|
||||
)
|
||||
from crewai_tools.tools.contextualai_query_tool.contextual_query_tool import (
|
||||
ContextualAIQueryTool,
|
||||
)
|
||||
from crewai_tools.tools.contextualai_rerank_tool.contextual_rerank_tool import (
|
||||
ContextualAIRerankTool,
|
||||
)
|
||||
from crewai_tools.tools.couchbase_tool.couchbase_tool import (
|
||||
CouchbaseFTSVectorSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.crewai_enterprise_tools.crewai_enterprise_tools import (
|
||||
CrewaiEnterpriseTools,
|
||||
)
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tools import (
|
||||
CrewaiPlatformTools,
|
||||
)
|
||||
from crewai_tools.tools.csv_search_tool.csv_search_tool import CSVSearchTool
|
||||
from crewai_tools.tools.dalle_tool.dalle_tool import DallETool
|
||||
from crewai_tools.tools.databricks_query_tool.databricks_query_tool import (
|
||||
DatabricksQueryTool,
|
||||
)
|
||||
from crewai_tools.tools.directory_read_tool.directory_read_tool import (
|
||||
DirectoryReadTool,
|
||||
)
|
||||
from crewai_tools.tools.directory_search_tool.directory_search_tool import (
|
||||
DirectorySearchTool,
|
||||
)
|
||||
from crewai_tools.tools.docx_search_tool.docx_search_tool import DOCXSearchTool
|
||||
from crewai_tools.tools.exa_tools.exa_search_tool import EXASearchTool
|
||||
from crewai_tools.tools.file_read_tool.file_read_tool import FileReadTool
|
||||
from crewai_tools.tools.file_writer_tool.file_writer_tool import FileWriterTool
|
||||
from crewai_tools.tools.files_compressor_tool.files_compressor_tool import (
|
||||
FileCompressorTool,
|
||||
)
|
||||
from crewai_tools.tools.firecrawl_crawl_website_tool.firecrawl_crawl_website_tool import (
|
||||
FirecrawlCrawlWebsiteTool,
|
||||
)
|
||||
from .firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
|
||||
from crewai_tools.tools.firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
|
||||
FirecrawlScrapeWebsiteTool,
|
||||
)
|
||||
from .firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool
|
||||
from .generate_crewai_automation_tool.generate_crewai_automation_tool import (
|
||||
from crewai_tools.tools.firecrawl_search_tool.firecrawl_search_tool import (
|
||||
FirecrawlSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.generate_crewai_automation_tool.generate_crewai_automation_tool import (
|
||||
GenerateCrewaiAutomationTool,
|
||||
)
|
||||
from .github_search_tool.github_search_tool import GithubSearchTool
|
||||
from .hyperbrowser_load_tool.hyperbrowser_load_tool import HyperbrowserLoadTool
|
||||
from .invoke_crewai_automation_tool.invoke_crewai_automation_tool import (
|
||||
from crewai_tools.tools.github_search_tool.github_search_tool import GithubSearchTool
|
||||
from crewai_tools.tools.hyperbrowser_load_tool.hyperbrowser_load_tool import (
|
||||
HyperbrowserLoadTool,
|
||||
)
|
||||
from crewai_tools.tools.invoke_crewai_automation_tool.invoke_crewai_automation_tool import (
|
||||
InvokeCrewAIAutomationTool,
|
||||
)
|
||||
from .json_search_tool.json_search_tool import JSONSearchTool
|
||||
from .linkup.linkup_search_tool import LinkupSearchTool
|
||||
from .llamaindex_tool.llamaindex_tool import LlamaIndexTool
|
||||
from .mdx_search_tool.mdx_search_tool import MDXSearchTool
|
||||
from .mongodb_vector_search_tool import (
|
||||
from crewai_tools.tools.jina_scrape_website_tool.jina_scrape_website_tool import (
|
||||
JinaScrapeWebsiteTool,
|
||||
)
|
||||
from crewai_tools.tools.json_search_tool.json_search_tool import JSONSearchTool
|
||||
from crewai_tools.tools.linkup.linkup_search_tool import LinkupSearchTool
|
||||
from crewai_tools.tools.llamaindex_tool.llamaindex_tool import LlamaIndexTool
|
||||
from crewai_tools.tools.mdx_search_tool.mdx_search_tool import MDXSearchTool
|
||||
from crewai_tools.tools.mongodb_vector_search_tool import (
|
||||
MongoDBToolSchema,
|
||||
MongoDBVectorSearchConfig,
|
||||
MongoDBVectorSearchTool,
|
||||
)
|
||||
from .multion_tool.multion_tool import MultiOnTool
|
||||
from .mysql_search_tool.mysql_search_tool import MySQLSearchTool
|
||||
from .nl2sql.nl2sql_tool import NL2SQLTool
|
||||
from .ocr_tool.ocr_tool import OCRTool
|
||||
from .oxylabs_amazon_product_scraper_tool.oxylabs_amazon_product_scraper_tool import (
|
||||
from crewai_tools.tools.multion_tool.multion_tool import MultiOnTool
|
||||
from crewai_tools.tools.mysql_search_tool.mysql_search_tool import MySQLSearchTool
|
||||
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
|
||||
from crewai_tools.tools.ocr_tool.ocr_tool import OCRTool
|
||||
from crewai_tools.tools.oxylabs_amazon_product_scraper_tool.oxylabs_amazon_product_scraper_tool import (
|
||||
OxylabsAmazonProductScraperTool,
|
||||
)
|
||||
from .oxylabs_amazon_search_scraper_tool.oxylabs_amazon_search_scraper_tool import (
|
||||
from crewai_tools.tools.oxylabs_amazon_search_scraper_tool.oxylabs_amazon_search_scraper_tool import (
|
||||
OxylabsAmazonSearchScraperTool,
|
||||
)
|
||||
from .oxylabs_google_search_scraper_tool.oxylabs_google_search_scraper_tool import (
|
||||
from crewai_tools.tools.oxylabs_google_search_scraper_tool.oxylabs_google_search_scraper_tool import (
|
||||
OxylabsGoogleSearchScraperTool,
|
||||
)
|
||||
from .oxylabs_universal_scraper_tool.oxylabs_universal_scraper_tool import (
|
||||
from crewai_tools.tools.oxylabs_universal_scraper_tool.oxylabs_universal_scraper_tool import (
|
||||
OxylabsUniversalScraperTool,
|
||||
)
|
||||
from .parallel_tools import (
|
||||
ParallelSearchTool,
|
||||
)
|
||||
from .patronus_eval_tool import (
|
||||
from crewai_tools.tools.parallel_tools import ParallelSearchTool
|
||||
from crewai_tools.tools.patronus_eval_tool import (
|
||||
PatronusEvalTool,
|
||||
PatronusLocalEvaluatorTool,
|
||||
PatronusPredefinedCriteriaEvalTool,
|
||||
)
|
||||
from .pdf_search_tool.pdf_search_tool import PDFSearchTool
|
||||
from .pg_search_tool.pg_search_tool import PGSearchTool
|
||||
from .qdrant_vector_search_tool.qdrant_search_tool import QdrantVectorSearchTool
from .rag.rag_tool import RagTool
from .scrape_element_from_website.scrape_element_from_website import (
from crewai_tools.tools.pdf_search_tool.pdf_search_tool import PDFSearchTool
from crewai_tools.tools.qdrant_vector_search_tool.qdrant_search_tool import (
QdrantVectorSearchTool,
)
from crewai_tools.tools.rag.rag_tool import RagTool
from crewai_tools.tools.scrape_element_from_website.scrape_element_from_website import (
ScrapeElementFromWebsiteTool,
)
from .scrape_website_tool.scrape_website_tool import ScrapeWebsiteTool
from .scrapegraph_scrape_tool.scrapegraph_scrape_tool import (
from crewai_tools.tools.scrape_website_tool.scrape_website_tool import (
ScrapeWebsiteTool,
)
from crewai_tools.tools.scrapegraph_scrape_tool.scrapegraph_scrape_tool import (
ScrapegraphScrapeTool,
ScrapegraphScrapeToolSchema,
)
from .scrapfly_scrape_website_tool.scrapfly_scrape_website_tool import (
from crewai_tools.tools.scrapfly_scrape_website_tool.scrapfly_scrape_website_tool import (
ScrapflyScrapeWebsiteTool,
)
from .selenium_scraping_tool.selenium_scraping_tool import SeleniumScrapingTool
from .serpapi_tool.serpapi_google_search_tool import SerpApiGoogleSearchTool
from .serpapi_tool.serpapi_google_shopping_tool import SerpApiGoogleShoppingTool
from .serper_dev_tool.serper_dev_tool import SerperDevTool
from .serper_scrape_website_tool.serper_scrape_website_tool import (
from crewai_tools.tools.selenium_scraping_tool.selenium_scraping_tool import (
SeleniumScrapingTool,
)
from crewai_tools.tools.serpapi_tool.serpapi_google_search_tool import (
SerpApiGoogleSearchTool,
)
from crewai_tools.tools.serpapi_tool.serpapi_google_shopping_tool import (
SerpApiGoogleShoppingTool,
)
from crewai_tools.tools.serper_dev_tool.serper_dev_tool import SerperDevTool
from crewai_tools.tools.serper_scrape_website_tool.serper_scrape_website_tool import (
SerperScrapeWebsiteTool,
)
from .serply_api_tool.serply_job_search_tool import SerplyJobSearchTool
from .serply_api_tool.serply_news_search_tool import SerplyNewsSearchTool
from .serply_api_tool.serply_scholar_search_tool import SerplyScholarSearchTool
from .serply_api_tool.serply_web_search_tool import SerplyWebSearchTool
from .serply_api_tool.serply_webpage_to_markdown_tool import SerplyWebpageToMarkdownTool
from .singlestore_search_tool import SingleStoreSearchTool
from .snowflake_search_tool import (
from crewai_tools.tools.serply_api_tool.serply_job_search_tool import (
SerplyJobSearchTool,
)
from crewai_tools.tools.serply_api_tool.serply_news_search_tool import (
SerplyNewsSearchTool,
)
from crewai_tools.tools.serply_api_tool.serply_scholar_search_tool import (
SerplyScholarSearchTool,
)
from crewai_tools.tools.serply_api_tool.serply_web_search_tool import (
SerplyWebSearchTool,
)
from crewai_tools.tools.serply_api_tool.serply_webpage_to_markdown_tool import (
SerplyWebpageToMarkdownTool,
)
from crewai_tools.tools.singlestore_search_tool import SingleStoreSearchTool
from crewai_tools.tools.snowflake_search_tool import (
SnowflakeConfig,
SnowflakeSearchTool,
SnowflakeSearchToolInput,
)
from .spider_tool.spider_tool import SpiderTool
from .stagehand_tool.stagehand_tool import StagehandTool
from .tavily_extractor_tool.tavily_extractor_tool import TavilyExtractorTool
from .tavily_search_tool.tavily_search_tool import TavilySearchTool
from .txt_search_tool.txt_search_tool import TXTSearchTool
from .vision_tool.vision_tool import VisionTool
from .weaviate_tool.vector_search import WeaviateVectorSearchTool
from .website_search.website_search_tool import WebsiteSearchTool
from .xml_search_tool.xml_search_tool import XMLSearchTool
from .youtube_channel_search_tool.youtube_channel_search_tool import (
from crewai_tools.tools.spider_tool.spider_tool import SpiderTool
from crewai_tools.tools.stagehand_tool.stagehand_tool import StagehandTool
from crewai_tools.tools.tavily_extractor_tool.tavily_extractor_tool import (
TavilyExtractorTool,
)
from crewai_tools.tools.tavily_search_tool.tavily_search_tool import TavilySearchTool
from crewai_tools.tools.txt_search_tool.txt_search_tool import TXTSearchTool
from crewai_tools.tools.vision_tool.vision_tool import VisionTool
from crewai_tools.tools.weaviate_tool.vector_search import WeaviateVectorSearchTool
from crewai_tools.tools.website_search.website_search_tool import WebsiteSearchTool
from crewai_tools.tools.xml_search_tool.xml_search_tool import XMLSearchTool
from crewai_tools.tools.youtube_channel_search_tool.youtube_channel_search_tool import (
YoutubeChannelSearchTool,
)
from .youtube_video_search_tool.youtube_video_search_tool import YoutubeVideoSearchTool
from .zapier_action_tool.zapier_action_tool import ZapierActionTools
from crewai_tools.tools.youtube_video_search_tool.youtube_video_search_tool import (
YoutubeVideoSearchTool,
)
from crewai_tools.tools.zapier_action_tool.zapier_action_tool import ZapierActionTools


__all__ = [
"AIMindTool",
"ApifyActorsTool",
"ArxivPaperTool",
"BraveSearchTool",
"BrightDataDatasetTool",
"BrightDataSearchTool",
"BrightDataWebUnlockerTool",
"BrowserbaseLoadTool",
"CSVSearchTool",
"CodeDocsSearchTool",
"CodeInterpreterTool",
"ComposioTool",
"ContextualAICreateAgentTool",
"ContextualAIParseTool",
"ContextualAIQueryTool",
"ContextualAIRerankTool",
"CouchbaseFTSVectorSearchTool",
"CrewaiEnterpriseTools",
"CrewaiPlatformTools",
"DOCXSearchTool",
"DallETool",
"DatabricksQueryTool",
"DirectoryReadTool",
"DirectorySearchTool",
"EXASearchTool",
"FileCompressorTool",
"FileReadTool",
"FileWriterTool",
"FirecrawlCrawlWebsiteTool",
"FirecrawlScrapeWebsiteTool",
"FirecrawlSearchTool",
"GenerateCrewaiAutomationTool",
"GithubSearchTool",
"HyperbrowserLoadTool",
"InvokeCrewAIAutomationTool",
"JSONSearchTool",
"JinaScrapeWebsiteTool",
"LinkupSearchTool",
"LlamaIndexTool",
"MDXSearchTool",
"MongoDBToolSchema",
"MongoDBVectorSearchConfig",
"MongoDBVectorSearchTool",
"MultiOnTool",
"MySQLSearchTool",
"NL2SQLTool",
"OCRTool",
"OxylabsAmazonProductScraperTool",
"OxylabsAmazonSearchScraperTool",
"OxylabsGoogleSearchScraperTool",
"OxylabsUniversalScraperTool",
"PDFSearchTool",
"ParallelSearchTool",
"PatronusEvalTool",
"PatronusLocalEvaluatorTool",
"PatronusPredefinedCriteriaEvalTool",
"QdrantVectorSearchTool",
"RagTool",
"ScrapeElementFromWebsiteTool",
"ScrapeWebsiteTool",
"ScrapegraphScrapeTool",
"ScrapegraphScrapeToolSchema",
"ScrapflyScrapeWebsiteTool",
"SeleniumScrapingTool",
"SerpApiGoogleSearchTool",
"SerpApiGoogleShoppingTool",
"SerperDevTool",
"SerperScrapeWebsiteTool",
"SerplyJobSearchTool",
"SerplyNewsSearchTool",
"SerplyScholarSearchTool",
"SerplyWebSearchTool",
"SerplyWebpageToMarkdownTool",
"SingleStoreSearchTool",
"SnowflakeConfig",
"SnowflakeSearchTool",
"SnowflakeSearchToolInput",
"SpiderTool",
"StagehandTool",
"TXTSearchTool",
"TavilyExtractorTool",
"TavilySearchTool",
"VisionTool",
"WeaviateVectorSearchTool",
"WebsiteSearchTool",
"XMLSearchTool",
"YoutubeChannelSearchTool",
"YoutubeVideoSearchTool",
"ZapierActionTools",
]

@@ -1,6 +1,6 @@
import os
import secrets
from typing import Any, Dict, List, Optional, Type
from typing import Any

from crewai.tools import BaseTool, EnvVar
from openai import OpenAI
@@ -28,16 +28,20 @@ class AIMindTool(BaseTool):
"and Google BigQuery. "
"Input should be a question in natural language."
)
args_schema: Type[BaseModel] = AIMindToolInputSchema
api_key: Optional[str] = None
datasources: Optional[List[Dict[str, Any]]] = None
mind_name: Optional[str] = None
package_dependencies: List[str] = ["minds-sdk"]
env_vars: List[EnvVar] = [
EnvVar(name="MINDS_API_KEY", description="API key for AI-Minds", required=True),
]
args_schema: type[BaseModel] = AIMindToolInputSchema
api_key: str | None = None
datasources: list[dict[str, Any]] | None = None
mind_name: str | None = None
package_dependencies: list[str] = Field(default_factory=lambda: ["minds-sdk"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
name="MINDS_API_KEY", description="API key for AI-Minds", required=True
),
]
)

def __init__(self, api_key: Optional[str] = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
self.api_key = api_key or os.getenv("MINDS_API_KEY")
if not self.api_key:
@@ -48,10 +52,10 @@ class AIMindTool(BaseTool):
try:
from minds.client import Client  # type: ignore
from minds.datasources import DatabaseConfig  # type: ignore
except ImportError:
except ImportError as e:
raise ImportError(
"`minds_sdk` package not found, please run `pip install minds-sdk`"
)
) from e

minds_client = Client(api_key=self.api_key)

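The recurring change in these hunks is replacing bare mutable class-level defaults (such as `package_dependencies: List[str] = ["minds-sdk"]`) with `Field(default_factory=...)`. A minimal sketch of that pattern, assuming pydantic v2; the `ExampleTool` model here is illustrative and not part of crewai-tools:

```python
from pydantic import BaseModel, Field


class ExampleTool(BaseModel):
    # A factory builds a fresh list for every instance, so the default can never be
    # shared or mutated across instances, and linters stop flagging the mutable default.
    package_dependencies: list[str] = Field(default_factory=lambda: ["minds-sdk"])


a, b = ExampleTool(), ExampleTool()
a.package_dependencies.append("extra")
print(b.package_dependencies)  # ['minds-sdk'] - unaffected by the mutation on `a`
```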
@@ -1,5 +1,5 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any

from crewai.tools import BaseTool, EnvVar
from pydantic import Field
@@ -10,13 +10,15 @@ if TYPE_CHECKING:


class ApifyActorsTool(BaseTool):
env_vars: List[EnvVar] = [
EnvVar(
name="APIFY_API_TOKEN",
description="API token for Apify platform access",
required=True,
),
]
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
name="APIFY_API_TOKEN",
description="API token for Apify platform access",
required=True,
),
]
)
"""Tool that runs Apify Actors.

To use, you should have the environment variable `APIFY_API_TOKEN` set
@@ -48,7 +50,7 @@ class ApifyActorsTool(BaseTool):
print(f"Content: {result.get('markdown', 'N/A')[:100]}...")
"""
actor_tool: "_ApifyActorsTool" = Field(description="Apify Actor Tool")
package_dependencies: List[str] = ["langchain-apify"]
package_dependencies: list[str] = Field(default_factory=lambda: ["langchain-apify"])

def __init__(self, actor_name: str, *args: Any, **kwargs: Any) -> None:
if not os.environ.get("APIFY_API_TOKEN"):
@@ -61,11 +63,11 @@ class ApifyActorsTool(BaseTool):

try:
from langchain_apify import ApifyActorsTool as _ApifyActorsTool
except ImportError:
except ImportError as e:
raise ImportError(
"Could not import langchain_apify python package. "
"Please install it with `pip install langchain-apify` or `uv add langchain-apify`."
)
) from e
actor_tool = _ApifyActorsTool(actor_name)

kwargs.update(
@@ -78,7 +80,7 @@ class ApifyActorsTool(BaseTool):
)
super().__init__(*args, **kwargs)

def _run(self, run_input: Dict[str, Any]) -> List[Dict[str, Any]]:
def _run(self, run_input: dict[str, Any]) -> list[dict[str, Any]]:
"""Run the Actor tool with the given input.

Returns:

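Another pattern applied throughout is re-raising `ImportError` with `from e`. A small hedged sketch of the idea; the wrapper function is illustrative, only the module name and message come from the hunk above:

```python
def import_langchain_apify():
    """Minimal sketch of the `raise ... from e` pattern used in the diff."""
    try:
        from langchain_apify import ApifyActorsTool  # optional dependency
    except ImportError as e:
        # `from e` chains the original ImportError as __cause__, so the traceback
        # keeps both the missing-module error and the install hint.
        raise ImportError(
            "Could not import langchain_apify python package. "
            "Please install it with `pip install langchain-apify`."
        ) from e
    return ApifyActorsTool
```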
@@ -2,14 +2,14 @@ import logging
from pathlib import Path
import re
import time
from typing import ClassVar, List, Optional, Type
from typing import ClassVar
import urllib.error
import urllib.parse
import urllib.request
import xml.etree.ElementTree as ET

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field


logger = logging.getLogger(__file__)
@@ -32,10 +32,10 @@ class ArxivPaperTool(BaseTool):
REQUEST_TIMEOUT: ClassVar[int] = 10
name: str = "Arxiv Paper Fetcher and Downloader"
description: str = "Fetches metadata from Arxiv based on a search query and optionally downloads PDFs."
args_schema: Type[BaseModel] = ArxivToolInput
model_config = {"extra": "allow"}
package_dependencies: List[str] = ["pydantic"]
env_vars: List[EnvVar] = []
args_schema: type[BaseModel] = ArxivToolInput
model_config = ConfigDict(extra="allow")
package_dependencies: list[str] = Field(default_factory=lambda: ["pydantic"])
env_vars: list[EnvVar] = Field(default_factory=list)

def __init__(
self, download_pdfs=False, save_dir="./arxiv_pdfs", use_title_as_filename=False
@@ -80,12 +80,12 @@ class ArxivPaperTool(BaseTool):
logger.error(f"ArxivTool Error: {e!s}")
return f"Failed to fetch or download Arxiv papers: {e!s}"

def fetch_arxiv_data(self, search_query: str, max_results: int) -> List[dict]:
def fetch_arxiv_data(self, search_query: str, max_results: int) -> list[dict]:
api_url = f"{self.BASE_API_URL}?search_query={urllib.parse.quote(search_query)}&start=0&max_results={max_results}"
logger.info(f"Fetching data from Arxiv API: {api_url}")

try:
with urllib.request.urlopen(
with urllib.request.urlopen(  # noqa: S310
api_url, timeout=self.REQUEST_TIMEOUT
) as response:
if response.status != 200:
@@ -95,7 +95,7 @@ class ArxivPaperTool(BaseTool):
logger.error(f"Error fetching data from Arxiv: {e}")
raise

root = ET.fromstring(data)
root = ET.fromstring(data)  # noqa: S314
papers = []

for entry in root.findall(self.ATOM_NAMESPACE + "entry"):
@@ -126,11 +126,11 @@ class ArxivPaperTool(BaseTool):
return papers

@staticmethod
def _get_element_text(entry: ET.Element, element_name: str) -> Optional[str]:
def _get_element_text(entry: ET.Element, element_name: str) -> str | None:
elem = entry.find(f"{ArxivPaperTool.ATOM_NAMESPACE}{element_name}")
return elem.text.strip() if elem is not None and elem.text else None

def _extract_pdf_url(self, entry: ET.Element) -> Optional[str]:
def _extract_pdf_url(self, entry: ET.Element) -> str | None:
for link in entry.findall(self.ATOM_NAMESPACE + "link"):
if link.attrib.get("title", "").lower() == "pdf":
return link.attrib.get("href")
@@ -164,7 +164,7 @@ class ArxivPaperTool(BaseTool):
def download_pdf(self, pdf_url: str, save_path: str):
try:
logger.info(f"Downloading PDF from {pdf_url} to {save_path}")
urllib.request.urlretrieve(pdf_url, str(save_path))
urllib.request.urlretrieve(pdf_url, str(save_path))  # noqa: S310
logger.info(f"PDF saved: {save_path}")
except urllib.error.URLError as e:
logger.error(f"Network error occurred while downloading {pdf_url}: {e}")

@@ -1,7 +1,7 @@
import datetime
import os
import time
from typing import Any, ClassVar, List, Optional, Type
from typing import Any, ClassVar

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field
@@ -13,7 +13,6 @@ def _save_results_to_file(content: str) -> None:
filename = f"search_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
with open(filename, "w") as file:
file.write(content)
print(f"Results saved to {filename}")


class BraveSearchToolSchema(BaseModel):
@@ -25,8 +24,7 @@ class BraveSearchToolSchema(BaseModel):


class BraveSearchTool(BaseTool):
"""
BraveSearchTool - A tool for performing web searches using the Brave Search API.
"""BraveSearchTool - A tool for performing web searches using the Brave Search API.

This module provides functionality to search the internet using Brave's Search API,
supporting customizable result counts and country-specific searches.
@@ -41,18 +39,22 @@ class BraveSearchTool(BaseTool):
description: str = (
"A tool that can be used to search the internet with a search_query."
)
args_schema: Type[BaseModel] = BraveSearchToolSchema
args_schema: type[BaseModel] = BraveSearchToolSchema
search_url: str = "https://api.search.brave.com/res/v1/web/search"
country: Optional[str] = ""
country: str | None = ""
n_results: int = 10
save_file: bool = False
_last_request_time: ClassVar[float] = 0
_min_request_interval: ClassVar[float] = 1.0  # seconds
env_vars: List[EnvVar] = [
EnvVar(
name="BRAVE_API_KEY", description="API key for Brave Search", required=True
),
]
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
name="BRAVE_API_KEY",
description="API key for Brave Search",
required=True,
),
]
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -89,7 +91,9 @@ class BraveSearchTool(BaseTool):
"Accept": "application/json",
}

response = requests.get(self.search_url, headers=headers, params=payload)
response = requests.get(
self.search_url, headers=headers, params=payload, timeout=30
)
response.raise_for_status()  # Handle non-200 responses
results = response.json()

@@ -108,7 +112,7 @@ class BraveSearchTool(BaseTool):
]
)
)
except KeyError:
except KeyError:  # noqa: PERF203
continue

content = "\n".join(string)

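Several hunks here and below add an explicit `timeout=30` to `requests` calls. A minimal sketch of the pattern under that assumption; the function name and endpoint are illustrative, not part of crewai-tools:

```python
import requests


def fetch_json(url: str) -> dict:
    # An explicit timeout keeps a stalled connection from hanging the tool forever;
    # 30 seconds mirrors the value the diff adds to the requests calls.
    response = requests.get(url, headers={"Accept": "application/json"}, timeout=30)
    response.raise_for_status()  # surface non-200 responses as exceptions
    return response.json()
```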
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from crewai.tools import BaseTool
|
||||
@@ -23,7 +23,7 @@ class BrightDataConfig(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class BrightDataDatasetToolException(Exception):
|
||||
class BrightDataDatasetToolException(Exception): # noqa: N818
|
||||
"""Exception raised for custom error in the application."""
|
||||
|
||||
def __init__(self, message, error_code):
|
||||
@@ -36,8 +36,7 @@ class BrightDataDatasetToolException(Exception):
|
||||
|
||||
|
||||
class BrightDataDatasetToolSchema(BaseModel):
|
||||
"""
|
||||
Schema for validating input parameters for the BrightDataDatasetTool.
|
||||
"""Schema for validating input parameters for the BrightDataDatasetTool.
|
||||
|
||||
Attributes:
|
||||
dataset_type (str): Required Bright Data Dataset Type used to specify which dataset to access.
|
||||
@@ -48,12 +47,12 @@ class BrightDataDatasetToolSchema(BaseModel):
|
||||
"""
|
||||
|
||||
dataset_type: str = Field(..., description="The Bright Data Dataset Type")
|
||||
format: Optional[str] = Field(
|
||||
format: str | None = Field(
|
||||
default="json", description="Response format (json by default)"
|
||||
)
|
||||
url: str = Field(..., description="The URL to extract data from")
|
||||
zipcode: Optional[str] = Field(default=None, description="Optional zipcode")
|
||||
additional_params: Optional[Dict[str, Any]] = Field(
|
||||
zipcode: str | None = Field(default=None, description="Optional zipcode")
|
||||
additional_params: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional params if any"
|
||||
)
|
||||
|
||||
@@ -399,8 +398,7 @@ datasets = [
|
||||
|
||||
|
||||
class BrightDataDatasetTool(BaseTool):
|
||||
"""
|
||||
CrewAI-compatible tool for scraping structured data using Bright Data Datasets.
|
||||
"""CrewAI-compatible tool for scraping structured data using Bright Data Datasets.
|
||||
|
||||
Attributes:
|
||||
name (str): Tool name displayed in the CrewAI environment.
|
||||
@@ -410,20 +408,20 @@ class BrightDataDatasetTool(BaseTool):
|
||||
|
||||
name: str = "Bright Data Dataset Tool"
|
||||
description: str = "Scrapes structured data using Bright Data Dataset API from a URL and optional input parameters"
|
||||
args_schema: Type[BaseModel] = BrightDataDatasetToolSchema
|
||||
dataset_type: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
args_schema: type[BaseModel] = BrightDataDatasetToolSchema
|
||||
dataset_type: str | None = None
|
||||
url: str | None = None
|
||||
format: str = "json"
|
||||
zipcode: Optional[str] = None
|
||||
additional_params: Optional[Dict[str, Any]] = None
|
||||
zipcode: str | None = None
|
||||
additional_params: dict[str, Any] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_type: str = None,
|
||||
url: str = None,
|
||||
dataset_type: str | None = None,
|
||||
url: str | None = None,
|
||||
format: str = "json",
|
||||
zipcode: str = None,
|
||||
additional_params: Dict[str, Any] = None,
|
||||
zipcode: str | None = None,
|
||||
additional_params: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_type = dataset_type
|
||||
@@ -440,12 +438,11 @@ class BrightDataDatasetTool(BaseTool):
|
||||
dataset_type: str,
|
||||
output_format: str,
|
||||
url: str,
|
||||
zipcode: Optional[str] = None,
|
||||
additional_params: Optional[Dict[str, Any]] = None,
|
||||
zipcode: str | None = None,
|
||||
additional_params: dict[str, Any] | None = None,
|
||||
polling_interval: int = 1,
|
||||
) -> Dict:
|
||||
"""
|
||||
Asynchronously trigger and poll Bright Data dataset scraping.
|
||||
) -> dict:
|
||||
"""Asynchronously trigger and poll Bright Data dataset scraping.
|
||||
|
||||
Args:
|
||||
dataset_type (str): Bright Data Dataset Type.
|
||||
@@ -500,7 +497,6 @@ class BrightDataDatasetTool(BaseTool):
|
||||
trigger_response.status,
|
||||
)
|
||||
trigger_data = await trigger_response.json()
|
||||
print(trigger_data)
|
||||
snapshot_id = trigger_data.get("snapshot_id")
|
||||
|
||||
# Step 2: Poll for completion
|
||||
@@ -520,7 +516,6 @@ class BrightDataDatasetTool(BaseTool):
|
||||
)
|
||||
status_data = await status_response.json()
|
||||
if status_data.get("status") == "ready":
|
||||
print("Job is ready")
|
||||
break
|
||||
if status_data.get("status") == "error":
|
||||
raise BrightDataDatasetToolException(
|
||||
@@ -545,11 +540,11 @@ class BrightDataDatasetTool(BaseTool):
|
||||
|
||||
def _run(
|
||||
self,
|
||||
url: str = None,
|
||||
dataset_type: str = None,
|
||||
format: str = None,
|
||||
zipcode: str = None,
|
||||
additional_params: Dict[str, Any] = None,
|
||||
url: str | None = None,
|
||||
dataset_type: str | None = None,
|
||||
format: str | None = None,
|
||||
zipcode: str | None = None,
|
||||
additional_params: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
dataset_type = dataset_type or self.dataset_type
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any
|
||||
import urllib.parse
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -20,8 +20,7 @@ class BrightDataConfig(BaseModel):
|
||||
|
||||
|
||||
class BrightDataSearchToolSchema(BaseModel):
|
||||
"""
|
||||
Schema that defines the input arguments for the BrightDataSearchToolSchema.
|
||||
"""Schema that defines the input arguments for the BrightDataSearchToolSchema.
|
||||
|
||||
Attributes:
|
||||
query (str): The search query to be executed (e.g., "latest AI news").
|
||||
@@ -34,35 +33,34 @@ class BrightDataSearchToolSchema(BaseModel):
|
||||
"""
|
||||
|
||||
query: str = Field(..., description="Search query to perform")
|
||||
search_engine: Optional[str] = Field(
|
||||
search_engine: str | None = Field(
|
||||
default="google",
|
||||
description="Search engine domain (e.g., 'google', 'bing', 'yandex')",
|
||||
)
|
||||
country: Optional[str] = Field(
|
||||
country: str | None = Field(
|
||||
default="us",
|
||||
description="Two-letter country code for geo-targeting (e.g., 'us', 'gb')",
|
||||
)
|
||||
language: Optional[str] = Field(
|
||||
language: str | None = Field(
|
||||
default="en",
|
||||
description="Language code (e.g., 'en', 'es') used in the query URL",
|
||||
)
|
||||
search_type: Optional[str] = Field(
|
||||
search_type: str | None = Field(
|
||||
default=None,
|
||||
description="Type of search (e.g., 'isch' for images, 'nws' for news)",
|
||||
)
|
||||
device_type: Optional[str] = Field(
|
||||
device_type: str | None = Field(
|
||||
default="desktop",
|
||||
description="Device type to simulate (e.g., 'mobile', 'desktop', 'ios')",
|
||||
)
|
||||
parse_results: Optional[bool] = Field(
|
||||
parse_results: bool | None = Field(
|
||||
default=True,
|
||||
description="Whether to parse and return JSON (True) or raw HTML/text (False)",
|
||||
)
|
||||
|
||||
|
||||
class BrightDataSearchTool(BaseTool):
|
||||
"""
|
||||
A web search tool that utilizes Bright Data's SERP API to perform queries and return either structured results
|
||||
"""A web search tool that utilizes Bright Data's SERP API to perform queries and return either structured results
|
||||
or raw page content from search engines like Google or Bing.
|
||||
|
||||
Attributes:
|
||||
@@ -79,26 +77,26 @@ class BrightDataSearchTool(BaseTool):
|
||||
|
||||
name: str = "Bright Data SERP Search"
|
||||
description: str = "Tool to perform web search using Bright Data SERP API."
|
||||
args_schema: Type[BaseModel] = BrightDataSearchToolSchema
|
||||
args_schema: type[BaseModel] = BrightDataSearchToolSchema
|
||||
_config = BrightDataConfig.from_env()
|
||||
base_url: str = ""
|
||||
api_key: str = ""
|
||||
zone: str = ""
|
||||
query: Optional[str] = None
|
||||
query: str | None = None
|
||||
search_engine: str = "google"
|
||||
country: str = "us"
|
||||
language: str = "en"
|
||||
search_type: Optional[str] = None
|
||||
search_type: str | None = None
|
||||
device_type: str = "desktop"
|
||||
parse_results: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query: str = None,
|
||||
query: str | None = None,
|
||||
search_engine: str = "google",
|
||||
country: str = "us",
|
||||
language: str = "en",
|
||||
search_type: str = None,
|
||||
search_type: str | None = None,
|
||||
device_type: str = "desktop",
|
||||
parse_results: bool = True,
|
||||
):
|
||||
@@ -128,17 +126,16 @@ class BrightDataSearchTool(BaseTool):
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str = None,
|
||||
search_engine: str = None,
|
||||
country: str = None,
|
||||
language: str = None,
|
||||
search_type: str = None,
|
||||
device_type: str = None,
|
||||
parse_results: bool = None,
|
||||
query: str | None = None,
|
||||
search_engine: str | None = None,
|
||||
country: str | None = None,
|
||||
language: str | None = None,
|
||||
search_type: str | None = None,
|
||||
device_type: str | None = None,
|
||||
parse_results: bool | None = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
Executes a search query using Bright Data SERP API and returns results.
|
||||
"""Executes a search query using Bright Data SERP API and returns results.
|
||||
|
||||
Args:
|
||||
query (str): The search query string (URL encoded internally).
|
||||
@@ -153,7 +150,6 @@ class BrightDataSearchTool(BaseTool):
|
||||
Returns:
|
||||
dict or str: Parsed JSON data from Bright Data if available, otherwise error message.
|
||||
"""
|
||||
|
||||
query = query or self.query
|
||||
search_engine = search_engine or self.search_engine
|
||||
country = country or self.country
|
||||
@@ -218,10 +214,9 @@ class BrightDataSearchTool(BaseTool):
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
self.base_url, json=request_params, headers=headers
|
||||
self.base_url, json=request_params, headers=headers, timeout=30
|
||||
)
|
||||
|
||||
print(f"Status code: {response.status_code}")
|
||||
response.raise_for_status()
|
||||
|
||||
return response.text
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -19,8 +19,7 @@ class BrightDataConfig(BaseModel):
|
||||
|
||||
|
||||
class BrightDataUnlockerToolSchema(BaseModel):
|
||||
"""
|
||||
Pydantic schema for input parameters used by the BrightDataWebUnlockerTool.
|
||||
"""Pydantic schema for input parameters used by the BrightDataWebUnlockerTool.
|
||||
|
||||
This schema defines the structure and validation for parameters passed when performing
|
||||
a web scraping request using Bright Data's Web Unlocker.
|
||||
@@ -32,17 +31,16 @@ class BrightDataUnlockerToolSchema(BaseModel):
|
||||
"""
|
||||
|
||||
url: str = Field(..., description="URL to perform the web scraping")
|
||||
format: Optional[str] = Field(
|
||||
format: str | None = Field(
|
||||
default="raw", description="Response format (raw is standard)"
|
||||
)
|
||||
data_format: Optional[str] = Field(
|
||||
data_format: str | None = Field(
|
||||
default="markdown", description="Response data format (html by default)"
|
||||
)
|
||||
|
||||
|
||||
class BrightDataWebUnlockerTool(BaseTool):
|
||||
"""
|
||||
A tool for performing web scraping using the Bright Data Web Unlocker API.
|
||||
"""A tool for performing web scraping using the Bright Data Web Unlocker API.
|
||||
|
||||
This tool allows automated and programmatic access to web pages by routing requests
|
||||
through Bright Data's unlocking and proxy infrastructure, which can bypass bot
|
||||
@@ -63,17 +61,17 @@ class BrightDataWebUnlockerTool(BaseTool):
|
||||
|
||||
name: str = "Bright Data Web Unlocker Scraping"
|
||||
description: str = "Tool to perform web scraping using Bright Data Web Unlocker"
|
||||
args_schema: Type[BaseModel] = BrightDataUnlockerToolSchema
|
||||
args_schema: type[BaseModel] = BrightDataUnlockerToolSchema
|
||||
_config = BrightDataConfig.from_env()
|
||||
base_url: str = ""
|
||||
api_key: str = ""
|
||||
zone: str = ""
|
||||
url: Optional[str] = None
|
||||
url: str | None = None
|
||||
format: str = "raw"
|
||||
data_format: str = "markdown"
|
||||
|
||||
def __init__(
|
||||
self, url: str = None, format: str = "raw", data_format: str = "markdown"
|
||||
self, url: str | None = None, format: str = "raw", data_format: str = "markdown"
|
||||
):
|
||||
super().__init__()
|
||||
self.base_url = self._config.API_URL
|
||||
@@ -90,9 +88,9 @@ class BrightDataWebUnlockerTool(BaseTool):
|
||||
|
||||
def _run(
|
||||
self,
|
||||
url: str = None,
|
||||
format: str = None,
|
||||
data_format: str = None,
|
||||
url: str | None = None,
|
||||
format: str | None = None,
|
||||
data_format: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
url = url or self.url
|
||||
@@ -122,8 +120,9 @@ class BrightDataWebUnlockerTool(BaseTool):
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(self.base_url, json=payload, headers=headers)
|
||||
print(f"Status Code: {response.status_code}")
|
||||
response = requests.post(
|
||||
self.base_url, json=payload, headers=headers, timeout=30
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return response.text
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, List, Optional, Type
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -12,34 +12,36 @@ class BrowserbaseLoadToolSchema(BaseModel):
|
||||
class BrowserbaseLoadTool(BaseTool):
|
||||
name: str = "Browserbase web load tool"
|
||||
description: str = "Load webpages url in a headless browser using Browserbase and return the contents"
|
||||
args_schema: Type[BaseModel] = BrowserbaseLoadToolSchema
|
||||
api_key: Optional[str] = os.getenv("BROWSERBASE_API_KEY")
|
||||
project_id: Optional[str] = os.getenv("BROWSERBASE_PROJECT_ID")
|
||||
text_content: Optional[bool] = False
|
||||
session_id: Optional[str] = None
|
||||
proxy: Optional[bool] = None
|
||||
browserbase: Optional[Any] = None
|
||||
package_dependencies: List[str] = ["browserbase"]
|
||||
env_vars: List[EnvVar] = [
|
||||
EnvVar(
|
||||
name="BROWSERBASE_API_KEY",
|
||||
description="API key for Browserbase services",
|
||||
required=False,
|
||||
),
|
||||
EnvVar(
|
||||
name="BROWSERBASE_PROJECT_ID",
|
||||
description="Project ID for Browserbase services",
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
args_schema: type[BaseModel] = BrowserbaseLoadToolSchema
|
||||
api_key: str | None = os.getenv("BROWSERBASE_API_KEY")
|
||||
project_id: str | None = os.getenv("BROWSERBASE_PROJECT_ID")
|
||||
text_content: bool | None = False
|
||||
session_id: str | None = None
|
||||
proxy: bool | None = None
|
||||
browserbase: Any | None = None
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["browserbase"])
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
name="BROWSERBASE_API_KEY",
|
||||
description="API key for Browserbase services",
|
||||
required=False,
|
||||
),
|
||||
EnvVar(
|
||||
name="BROWSERBASE_PROJECT_ID",
|
||||
description="Project ID for Browserbase services",
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
text_content: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
proxy: Optional[bool] = None,
|
||||
api_key: str | None = None,
|
||||
project_id: str | None = None,
|
||||
text_content: bool | None = False,
|
||||
session_id: str | None = None,
|
||||
proxy: bool | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@@ -57,12 +59,12 @@ class BrowserbaseLoadTool(BaseTool):
|
||||
):
|
||||
import subprocess
|
||||
|
||||
subprocess.run(["uv", "add", "browserbase"], check=True)
|
||||
subprocess.run(["uv", "add", "browserbase"], check=True) # noqa: S607
|
||||
from browserbase import Browserbase # type: ignore
|
||||
else:
|
||||
raise ImportError(
|
||||
"`browserbase` package not found, please run `uv add browserbase`"
|
||||
)
|
||||
) from None
|
||||
|
||||
self.browserbase = Browserbase(api_key=self.api_key)
|
||||
self.text_content = text_content
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
@@ -26,9 +25,9 @@ class CodeDocsSearchTool(RagTool):
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a Code Docs content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = CodeDocsSearchToolSchema
|
||||
args_schema: type[BaseModel] = CodeDocsSearchToolSchema
|
||||
|
||||
def __init__(self, docs_url: Optional[str] = None, **kwargs):
|
||||
def __init__(self, docs_url: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if docs_url is not None:
|
||||
self.add(docs_url)
|
||||
@@ -42,7 +41,7 @@ class CodeDocsSearchTool(RagTool):
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
docs_url: Optional[str] = None,
|
||||
docs_url: str | None = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
|
||||
@@ -8,15 +8,16 @@ potentially unsafe operations and importing restricted modules.
|
||||
import importlib.util
|
||||
import os
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from crewai_tools.printer import Printer
|
||||
from docker import DockerClient, from_env as docker_from_env
|
||||
from docker.errors import ImageNotFound, NotFound
|
||||
from docker.models.containers import Container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.printer import Printer
|
||||
|
||||
|
||||
class CodeInterpreterSchema(BaseModel):
|
||||
"""Schema for defining inputs to the CodeInterpreterTool.
|
||||
@@ -30,7 +31,7 @@ class CodeInterpreterSchema(BaseModel):
|
||||
description="Python3 code used to be interpreted in the Docker container. ALWAYS PRINT the final result and the output of the code",
|
||||
)
|
||||
|
||||
libraries_used: List[str] = Field(
|
||||
libraries_used: list[str] = Field(
|
||||
...,
|
||||
description="List of libraries used in the code with proper installing names separated by commas. Example: numpy,pandas,beautifulsoup4",
|
||||
)
|
||||
@@ -44,7 +45,7 @@ class SandboxPython:
|
||||
environment where harmful operations are blocked.
|
||||
"""
|
||||
|
||||
BLOCKED_MODULES = {
|
||||
BLOCKED_MODULES: ClassVar[set[str]] = {
|
||||
"os",
|
||||
"sys",
|
||||
"subprocess",
|
||||
@@ -56,7 +57,7 @@ class SandboxPython:
|
||||
"builtins",
|
||||
}
|
||||
|
||||
UNSAFE_BUILTINS = {
|
||||
UNSAFE_BUILTINS: ClassVar[set[str]] = {
|
||||
"exec",
|
||||
"eval",
|
||||
"open",
|
||||
@@ -72,9 +73,9 @@ class SandboxPython:
|
||||
@staticmethod
|
||||
def restricted_import(
|
||||
name: str,
|
||||
custom_globals: Optional[Dict[str, Any]] = None,
|
||||
custom_locals: Optional[Dict[str, Any]] = None,
|
||||
fromlist: Optional[List[str]] = None,
|
||||
custom_globals: dict[str, Any] | None = None,
|
||||
custom_locals: dict[str, Any] | None = None,
|
||||
fromlist: list[str] | None = None,
|
||||
level: int = 0,
|
||||
) -> ModuleType:
|
||||
"""A restricted import function that blocks importing of unsafe modules.
|
||||
@@ -97,7 +98,7 @@ class SandboxPython:
|
||||
return __import__(name, custom_globals, custom_locals, fromlist or (), level)
|
||||
|
||||
@staticmethod
|
||||
def safe_builtins() -> Dict[str, Any]:
|
||||
def safe_builtins() -> dict[str, Any]:
|
||||
"""Creates a dictionary of built-in functions with unsafe ones removed.
|
||||
|
||||
Returns:
|
||||
@@ -114,14 +115,14 @@ class SandboxPython:
|
||||
return safe_builtins
|
||||
|
||||
@staticmethod
|
||||
def exec(code: str, locals: Dict[str, Any]) -> None:
|
||||
def exec(code: str, locals: dict[str, Any]) -> None:
|
||||
"""Executes Python code in a restricted environment.
|
||||
|
||||
Args:
|
||||
code: The Python code to execute as a string.
|
||||
locals: A dictionary that will be used for local variable storage.
|
||||
"""
|
||||
exec(code, {"__builtins__": SandboxPython.safe_builtins()}, locals)
|
||||
exec(code, {"__builtins__": SandboxPython.safe_builtins()}, locals) # noqa: S102
|
||||
|
||||
|
||||
class CodeInterpreterTool(BaseTool):
|
||||
@@ -134,11 +135,11 @@ class CodeInterpreterTool(BaseTool):
|
||||
|
||||
name: str = "Code Interpreter"
|
||||
description: str = "Interprets Python3 code strings with a final print statement."
|
||||
args_schema: Type[BaseModel] = CodeInterpreterSchema
|
||||
args_schema: type[BaseModel] = CodeInterpreterSchema
|
||||
default_image_tag: str = "code-interpreter:latest"
|
||||
code: Optional[str] = None
|
||||
user_dockerfile_path: Optional[str] = None
|
||||
user_docker_base_url: Optional[str] = None
|
||||
code: str | None = None
|
||||
user_dockerfile_path: str | None = None
|
||||
user_docker_base_url: str | None = None
|
||||
unsafe_mode: bool = False
|
||||
|
||||
@staticmethod
|
||||
@@ -160,7 +161,6 @@ class CodeInterpreterTool(BaseTool):
|
||||
Raises:
|
||||
FileNotFoundError: If the Dockerfile cannot be found.
|
||||
"""
|
||||
|
||||
client = (
|
||||
docker_from_env()
|
||||
if self.user_docker_base_url is None
|
||||
@@ -181,7 +181,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
if not os.path.exists(dockerfile_path):
|
||||
raise FileNotFoundError(
|
||||
f"Dockerfile not found in {dockerfile_path}"
|
||||
)
|
||||
) from None
|
||||
|
||||
client.images.build(
|
||||
path=dockerfile_path,
|
||||
@@ -205,7 +205,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
return self.run_code_unsafe(code, libraries_used)
|
||||
return self.run_code_safety(code, libraries_used)
|
||||
|
||||
def _install_libraries(self, container: Container, libraries: List[str]) -> None:
|
||||
def _install_libraries(self, container: Container, libraries: list[str]) -> None:
|
||||
"""Installs required Python libraries in the Docker container.
|
||||
|
||||
Args:
|
||||
@@ -258,7 +258,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
["docker", "info"],
|
||||
["docker", "info"], # noqa: S607
|
||||
check=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
@@ -275,7 +275,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
Printer.print("Docker is not installed", color="bold_purple")
|
||||
return False
|
||||
|
||||
def run_code_safety(self, code: str, libraries_used: List[str]) -> str:
|
||||
def run_code_safety(self, code: str, libraries_used: list[str]) -> str:
|
||||
"""Runs code in the safest available environment.
|
||||
|
||||
Attempts to run code in Docker if available, falls back to a restricted
|
||||
@@ -292,7 +292,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
return self.run_code_in_docker(code, libraries_used)
|
||||
return self.run_code_in_restricted_sandbox(code)
|
||||
|
||||
def run_code_in_docker(self, code: str, libraries_used: List[str]) -> str:
|
||||
def run_code_in_docker(self, code: str, libraries_used: list[str]) -> str:
|
||||
"""Runs Python code in a Docker container for safe isolation.
|
||||
|
||||
Creates a Docker container, installs the required libraries, executes the code,
|
||||
@@ -340,7 +340,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
except Exception as e:
|
||||
return f"An error occurred: {e!s}"
|
||||
|
||||
def run_code_unsafe(self, code: str, libraries_used: List[str]) -> str:
|
||||
def run_code_unsafe(self, code: str, libraries_used: list[str]) -> str:
|
||||
"""Runs code directly on the host machine without any safety restrictions.
|
||||
|
||||
WARNING: This mode is unsafe and should only be used in trusted environments
|
||||
@@ -354,16 +354,15 @@ class CodeInterpreterTool(BaseTool):
|
||||
The value of the 'result' variable from the executed code,
|
||||
or an error message if execution failed.
|
||||
"""
|
||||
|
||||
Printer.print("WARNING: Running code in unsafe mode", color="bold_magenta")
|
||||
# Install libraries on the host machine
|
||||
for library in libraries_used:
|
||||
os.system(f"pip install {library}")
|
||||
os.system(f"pip install {library}") # noqa: S605
|
||||
|
||||
# Execute the code
|
||||
try:
|
||||
exec_locals = {}
|
||||
exec(code, {}, exec_locals)
|
||||
exec(code, {}, exec_locals) # noqa: S102
|
||||
return exec_locals.get("result", "No result variable found.")
|
||||
except Exception as e:
|
||||
return f"An error occurred: {e!s}"
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""
|
||||
Composio tools wrapper.
|
||||
"""
|
||||
"""Composio tools wrapper."""
|
||||
|
||||
import typing as t
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import Field
|
||||
import typing_extensions as te
|
||||
|
||||
|
||||
@@ -12,13 +11,15 @@ class ComposioTool(BaseTool):
|
||||
"""Wrapper for composio tools."""
|
||||
|
||||
composio_action: t.Callable
|
||||
env_vars: t.List[EnvVar] = [
|
||||
EnvVar(
|
||||
name="COMPOSIO_API_KEY",
|
||||
description="API key for Composio services",
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
name="COMPOSIO_API_KEY",
|
||||
description="API key for Composio services",
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
"""Run the composio action with given arguments."""
|
||||
@@ -35,7 +36,7 @@ class ComposioTool(BaseTool):
|
||||
return
|
||||
|
||||
connections = t.cast(
|
||||
t.List[ConnectedAccountModel],
|
||||
list[ConnectedAccountModel],
|
||||
toolset.client.connected_accounts.get(),
|
||||
)
|
||||
if tool.app not in [connection.appUniqueId for connection in connections]:
|
||||
@@ -51,7 +52,6 @@ class ComposioTool(BaseTool):
|
||||
**kwargs: t.Any,
|
||||
) -> te.Self:
|
||||
"""Wrap a composio tool as crewAI tool."""
|
||||
|
||||
from composio import Action, ComposioToolSet
|
||||
from composio.constants import DEFAULT_ENTITY_ID
|
||||
from composio.utils.shared import json_schema_to_model
|
||||
@@ -70,7 +70,7 @@ class ComposioTool(BaseTool):
|
||||
schema = action_schema.model_dump(exclude_none=True)
|
||||
entity_id = kwargs.pop("entity_id", DEFAULT_ENTITY_ID)
|
||||
|
||||
def function(**kwargs: t.Any) -> t.Dict:
|
||||
def function(**kwargs: t.Any) -> dict:
|
||||
"""Wrapper function for composio action."""
|
||||
return toolset.execute_action(
|
||||
action=Action(schema["name"]),
|
||||
@@ -97,10 +97,10 @@ class ComposioTool(BaseTool):
|
||||
def from_app(
|
||||
cls,
|
||||
*apps: t.Any,
|
||||
tags: t.Optional[t.List[str]] = None,
|
||||
use_case: t.Optional[str] = None,
|
||||
tags: list[str] | None = None,
|
||||
use_case: str | None = None,
|
||||
**kwargs: t.Any,
|
||||
) -> t.List[te.Self]:
|
||||
) -> list[te.Self]:
|
||||
"""Create toolset from an app."""
|
||||
if len(apps) == 0:
|
||||
raise ValueError("You need to provide at least one app name")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Type
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -10,7 +10,7 @@ class ContextualAICreateAgentSchema(BaseModel):
|
||||
agent_name: str = Field(..., description="Name for the new agent")
|
||||
agent_description: str = Field(..., description="Description for the new agent")
|
||||
datastore_name: str = Field(..., description="Name for the new datastore")
|
||||
document_paths: List[str] = Field(..., description="List of file paths to upload")
|
||||
document_paths: list[str] = Field(..., description="List of file paths to upload")
|
||||
|
||||
|
||||
class ContextualAICreateAgentTool(BaseTool):
|
||||
@@ -20,11 +20,13 @@ class ContextualAICreateAgentTool(BaseTool):
|
||||
description: str = (
|
||||
"Create a new Contextual AI RAG agent with documents and datastore"
|
||||
)
|
||||
args_schema: Type[BaseModel] = ContextualAICreateAgentSchema
|
||||
args_schema: type[BaseModel] = ContextualAICreateAgentSchema
|
||||
|
||||
api_key: str
|
||||
contextual_client: Any = None
|
||||
package_dependencies: List[str] = ["contextual-client"]
|
||||
package_dependencies: list[str] = Field(
|
||||
default_factory=lambda: ["contextual-client"]
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -32,17 +34,17 @@ class ContextualAICreateAgentTool(BaseTool):
|
||||
from contextual import ContextualAI
|
||||
|
||||
self.contextual_client = ContextualAI(api_key=self.api_key)
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"contextual-client package is required. Install it with: pip install contextual-client"
|
||||
)
|
||||
) from e
|
||||
|
||||
def _run(
|
||||
self,
|
||||
agent_name: str,
|
||||
agent_description: str,
|
||||
datastore_name: str,
|
||||
document_paths: List[str],
|
||||
document_paths: list[str],
|
||||
) -> str:
|
||||
"""Create a complete RAG pipeline with documents."""
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -15,10 +13,10 @@ class ContextualAIParseSchema(BaseModel):
|
||||
enable_document_hierarchy: bool = Field(
|
||||
default=True, description="Enable document hierarchy"
|
||||
)
|
||||
page_range: Optional[str] = Field(
|
||||
page_range: str | None = Field(
|
||||
default=None, description="Page range to parse (e.g., '0-5')"
|
||||
)
|
||||
output_types: List[str] = Field(
|
||||
output_types: list[str] = Field(
|
||||
default=["markdown-per-page"], description="List of output types"
|
||||
)
|
||||
|
||||
@@ -28,10 +26,12 @@ class ContextualAIParseTool(BaseTool):
|
||||
|
||||
name: str = "Contextual AI Document Parser"
|
||||
description: str = "Parse documents using Contextual AI's advanced document parser"
|
||||
args_schema: Type[BaseModel] = ContextualAIParseSchema
|
||||
args_schema: type[BaseModel] = ContextualAIParseSchema
|
||||
|
||||
api_key: str
|
||||
package_dependencies: List[str] = ["contextual-client"]
|
||||
package_dependencies: list[str] = Field(
|
||||
default_factory=lambda: ["contextual-client"]
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@@ -39,10 +39,12 @@ class ContextualAIParseTool(BaseTool):
|
||||
parse_mode: str = "standard",
|
||||
figure_caption_mode: str = "concise",
|
||||
enable_document_hierarchy: bool = True,
|
||||
page_range: Optional[str] = None,
|
||||
output_types: List[str] = ["markdown-per-page"],
|
||||
page_range: str | None = None,
|
||||
output_types: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Parse a document using Contextual AI's parser."""
|
||||
if output_types is None:
|
||||
output_types = ["markdown-per-page"]
|
||||
try:
|
||||
import json
|
||||
import os
|
||||
@@ -72,14 +74,16 @@ class ContextualAIParseTool(BaseTool):
|
||||
|
||||
with open(file_path, "rb") as fp:
|
||||
file = {"raw_file": fp}
|
||||
result = requests.post(url, headers=headers, data=config, files=file)
|
||||
result = requests.post(
|
||||
url, headers=headers, data=config, files=file, timeout=30
|
||||
)
|
||||
response = json.loads(result.text)
|
||||
job_id = response["job_id"]
|
||||
|
||||
# Monitor job status
|
||||
status_url = f"{base_url}/parse/jobs/{job_id}/status"
|
||||
while True:
|
||||
result = requests.get(status_url, headers=headers)
|
||||
result = requests.get(status_url, headers=headers, timeout=30)
|
||||
parse_response = json.loads(result.text)["status"]
|
||||
|
||||
if parse_response == "completed":
|
||||
@@ -95,6 +99,7 @@ class ContextualAIParseTool(BaseTool):
|
||||
results_url,
|
||||
headers=headers,
|
||||
params={"output_types": ",".join(output_types)},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
return json.dumps(json.loads(result.text), indent=2)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import Any, List, Optional, Type
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -11,7 +11,7 @@ class ContextualAIQuerySchema(BaseModel):
|
||||
|
||||
query: str = Field(..., description="Query to send to the Contextual AI agent.")
|
||||
agent_id: str = Field(..., description="ID of the Contextual AI agent to query")
|
||||
datastore_id: Optional[str] = Field(
|
||||
datastore_id: str | None = Field(
|
||||
None, description="Optional datastore ID for document readiness verification"
|
||||
)
|
||||
|
||||
@@ -23,11 +23,13 @@ class ContextualAIQueryTool(BaseTool):
|
||||
description: str = (
|
||||
"Use this tool to query a Contextual AI RAG agent with access to your documents"
|
||||
)
|
||||
args_schema: Type[BaseModel] = ContextualAIQuerySchema
|
||||
args_schema: type[BaseModel] = ContextualAIQuerySchema
|
||||
|
||||
api_key: str
|
||||
contextual_client: Any = None
|
||||
package_dependencies: List[str] = ["contextual-client"]
|
||||
package_dependencies: list[str] = Field(
|
||||
default_factory=lambda: ["contextual-client"]
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -35,16 +37,16 @@ class ContextualAIQueryTool(BaseTool):
|
||||
from contextual import ContextualAI
|
||||
|
||||
self.contextual_client = ContextualAI(api_key=self.api_key)
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"contextual-client package is required. Install it with: pip install contextual-client"
|
||||
)
|
||||
) from e
|
||||
|
||||
def _check_documents_ready(self, datastore_id: str) -> bool:
|
||||
"""Synchronous check if all documents are ready."""
|
||||
url = f"https://api.contextual.ai/v1/datastores/{datastore_id}/documents"
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
response = requests.get(url, headers=headers)
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
documents = data.get("documents", [])
|
||||
@@ -57,17 +59,14 @@ class ContextualAIQueryTool(BaseTool):
|
||||
self, datastore_id: str, max_attempts: int = 20, interval: float = 30.0
|
||||
) -> bool:
|
||||
"""Asynchronously poll until documents are ready, exiting early if possible."""
|
||||
for attempt in range(max_attempts):
|
||||
for _attempt in range(max_attempts):
|
||||
ready = await asyncio.to_thread(self._check_documents_ready, datastore_id)
|
||||
if ready:
|
||||
return True
|
||||
await asyncio.sleep(interval)
|
||||
print("Processing documents ...")
|
||||
return True # give up but don't fail hard
|
||||
|
||||
def _run(
|
||||
self, query: str, agent_id: str, datastore_id: Optional[str] = None
|
||||
) -> str:
|
||||
def _run(self, query: str, agent_id: str, datastore_id: str | None = None) -> str:
|
||||
if not agent_id:
|
||||
raise ValueError("Agent ID is required to query the Contextual AI agent")
|
||||
|
||||
@@ -89,14 +88,12 @@ class ContextualAIQueryTool(BaseTool):
|
||||
loop.run_until_complete(
|
||||
self._wait_for_documents_async(datastore_id)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to apply nest_asyncio: {e!s}")
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
else:
|
||||
asyncio.run(self._wait_for_documents_async(datastore_id))
|
||||
else:
|
||||
print(
|
||||
"Warning: No datastore_id provided. Document status checking disabled."
|
||||
)
|
||||
pass
|
||||
|
||||
try:
|
||||
response = self.contextual_client.agents.query.create(
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -8,11 +6,11 @@ class ContextualAIRerankSchema(BaseModel):
|
||||
"""Schema for contextual rerank tool."""
|
||||
|
||||
query: str = Field(..., description="The search query to rerank documents against")
|
||||
documents: List[str] = Field(..., description="List of document texts to rerank")
|
||||
instruction: Optional[str] = Field(
|
||||
documents: list[str] = Field(..., description="List of document texts to rerank")
|
||||
instruction: str | None = Field(
|
||||
default=None, description="Optional instruction for reranking behavior"
|
||||
)
|
||||
metadata: Optional[List[str]] = Field(
|
||||
metadata: list[str] | None = Field(
|
||||
default=None, description="Optional metadata for each document"
|
||||
)
|
||||
model: str = Field(
|
||||
@@ -27,17 +25,19 @@ class ContextualAIRerankTool(BaseTool):
|
||||
description: str = (
|
||||
"Rerank documents using Contextual AI's instruction-following reranker"
|
||||
)
|
||||
args_schema: Type[BaseModel] = ContextualAIRerankSchema
|
||||
args_schema: type[BaseModel] = ContextualAIRerankSchema
|
||||
|
||||
api_key: str
|
||||
package_dependencies: List[str] = ["contextual-client"]
|
||||
package_dependencies: list[str] = Field(
|
||||
default_factory=lambda: ["contextual-client"]
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
documents: List[str],
|
||||
instruction: Optional[str] = None,
|
||||
metadata: Optional[List[str]] = None,
|
||||
documents: list[str],
|
||||
instruction: str | None = None,
|
||||
metadata: list[str] | None = None,
|
||||
model: str = "ctxl-rerank-en-v1-instruct",
|
||||
) -> str:
|
||||
"""Rerank documents using Contextual AI's instruction-following reranker."""
|
||||
@@ -66,7 +66,9 @@ class ContextualAIRerankTool(BaseTool):
|
||||
payload["metadata"] = metadata
|
||||
|
||||
rerank_url = f"{base_url}/rerank"
|
||||
result = requests.post(rerank_url, json=payload, headers=headers)
|
||||
result = requests.post(
|
||||
rerank_url, json=payload, headers=headers, timeout=30
|
||||
)
|
||||
|
||||
if result.status_code != 200:
|
||||
raise RuntimeError(
|
||||
|
@@ -1,5 +1,6 @@
from collections.abc import Callable
import json
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any


try:

@@ -18,7 +19,7 @@ except ImportError:
VectorSearch = Any

from crewai.tools import BaseTool
from pydantic import BaseModel, Field, SkipValidation
from pydantic import BaseModel, ConfigDict, Field, SkipValidation


class CouchbaseToolSchema(BaseModel):
@@ -31,35 +32,35 @@ class CouchbaseToolSchema(BaseModel):


class CouchbaseFTSVectorSearchTool(BaseTool):
"""Tool to search the Couchbase database"""
"""Tool to search the Couchbase database."""

model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)
name: str = "CouchbaseFTSVectorSearchTool"
description: str = "A tool to search the Couchbase database for relevant information on internal documents."
args_schema: Type[BaseModel] = CouchbaseToolSchema
cluster: SkipValidation[Optional[Cluster]] = None
collection_name: Optional[str] = (None,)
scope_name: Optional[str] = (None,)
bucket_name: Optional[str] = (None,)
index_name: Optional[str] = (None,)
embedding_key: Optional[str] = Field(
args_schema: type[BaseModel] = CouchbaseToolSchema
cluster: SkipValidation[Cluster | None] = None
collection_name: str | None = (None,)
scope_name: str | None = (None,)
bucket_name: str | None = (None,)
index_name: str | None = (None,)
embedding_key: str | None = Field(
default="embedding",
description="Name of the field in the search index that stores the vector",
)
scoped_index: Optional[bool] = (
scoped_index: bool | None = (
Field(
default=True,
description="Specify whether the index is scoped. Is True by default.",
),
)
limit: Optional[int] = Field(default=3)
embedding_function: SkipValidation[Callable[[str], List[float]]] = Field(
limit: int | None = Field(default=3)
embedding_function: SkipValidation[Callable[[str], list[float]]] = Field(
default=None,
description="A function that takes a string and returns a list of floats. This is used to embed the query before searching the database.",
)

def _check_bucket_exists(self) -> bool:
"""Check if the bucket exists in the linked Couchbase cluster"""
"""Check if the bucket exists in the linked Couchbase cluster."""
bucket_manager = self.cluster.buckets()
try:
bucket_manager.get_bucket(self.bucket_name)
@@ -69,8 +70,9 @@ class CouchbaseFTSVectorSearchTool(BaseTool):

def _check_scope_and_collection_exists(self) -> bool:
"""Check if the scope and collection exists in the linked Couchbase bucket
Raises a ValueError if either is not found"""
scope_collection_map: Dict[str, Any] = {}
Raises a ValueError if either is not found.
"""
scope_collection_map: dict[str, Any] = {}

# Get a list of all scopes in the bucket
for scope in self._bucket.collections().get_all_scopes():
@@ -98,7 +100,8 @@ class CouchbaseFTSVectorSearchTool(BaseTool):

def _check_index_exists(self) -> bool:
"""Check if the Search index exists in the linked Couchbase cluster
Raises a ValueError if the index does not exist"""
Raises a ValueError if the index does not exist.
"""
if self.scoped_index:
all_indexes = [
index.name for index in self._scope.search_indexes().get_all_indexes()
@@ -182,7 +185,7 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
):
import subprocess

subprocess.run(["uv", "add", "couchbase"], check=True)
subprocess.run(["uv", "add", "couchbase"], check=True)  # noqa: S607
else:
raise ImportError(
"The 'couchbase' package is required to use the CouchbaseFTSVectorSearchTool. "
@@ -230,7 +233,7 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
json_response = []

for row in search_iter.rows():
json_response.append(row.fields)
json_response.append(row.fields)  # noqa: PERF401
except Exception as e:
return f"Search failed with error: {e}"
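Note: the Couchbase hunks show the two most common patterns in this commit: a dict literal assigned to model_config is replaced by pydantic's ConfigDict, and Optional[X] annotations become PEP 604 unions (X | None). A minimal sketch of the same pattern, with invented field names and no connection to the real tool:

from collections.abc import Callable

from pydantic import BaseModel, ConfigDict, Field


class SearchSettings(BaseModel):
    # ConfigDict is the typed replacement for a bare dict assigned to model_config.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    index_name: str | None = None  # PEP 604 union instead of Optional[str]
    limit: int | None = Field(default=3)
    embed: Callable[[str], list[float]] | None = None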
@@ -1,13 +1,11 @@
"""
Crewai Enterprise Tools
"""
"""Crewai Enterprise Tools."""

import json
import logging
import os
import typing as t

from crewai.tools import BaseTool

from crewai_tools.adapters.enterprise_adapter import EnterpriseActionKitToolAdapter
from crewai_tools.adapters.tool_collection import ToolCollection

@@ -15,11 +13,11 @@ from crewai_tools.adapters.tool_collection import ToolCollection
logger = logging.getLogger(__name__)


def CrewaiEnterpriseTools(
enterprise_token: t.Optional[str] = None,
actions_list: t.Optional[t.List[str]] = None,
enterprise_action_kit_project_id: t.Optional[str] = None,
enterprise_action_kit_project_url: t.Optional[str] = None,
def CrewaiEnterpriseTools(  # noqa: N802
enterprise_token: str | None = None,
actions_list: list[str] | None = None,
enterprise_action_kit_project_id: str | None = None,
enterprise_action_kit_project_url: str | None = None,
) -> ToolCollection[BaseTool]:
"""Factory function that returns crewai enterprise tools.

@@ -34,7 +32,6 @@ def CrewaiEnterpriseTools(
Returns:
A ToolCollection of BaseTool instances for enterprise actions
"""

import warnings

warnings.warn(
@@ -68,7 +65,7 @@ def CrewaiEnterpriseTools(


# ENTERPRISE INJECTION ONLY
def _parse_actions_list(actions_list: t.Optional[t.List[str]]) -> t.List[str] | None:
def _parse_actions_list(actions_list: list[str] | None) -> list[str] | None:
"""Parse a string representation of a list of tool names to a list of tool names.

Args:
@@ -1,4 +1,4 @@
"""CrewAI Platform Tools
"""CrewAI Platform Tools.

This module provides tools for integrating with various platform applications
through the CrewAI platform API.
@@ -1,10 +1,8 @@
"""
Crewai Enterprise Tools
"""
"""Crewai Enterprise Tools."""

import json
import re
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast, get_origin
from typing import Any, Literal, Optional, Union, cast, get_origin

from crewai.tools import BaseTool
from pydantic import Field, create_model

@@ -18,7 +16,7 @@ from crewai_tools.tools.crewai_platform_tools.misc import (

class CrewAIPlatformActionTool(BaseTool):
action_name: str = Field(default="", description="The name of the action")
action_schema: Dict[str, Any] = Field(
action_schema: dict[str, Any] = Field(
default_factory=dict, description="The schema of the action"
)

@@ -26,7 +24,7 @@ class CrewAIPlatformActionTool(BaseTool):
self,
description: str,
action_name: str,
action_schema: Dict[str, Any],
action_schema: dict[str, Any],
):
self._model_registry = {}
self._base_name = self._sanitize_name(action_name)
@@ -54,8 +52,7 @@ class CrewAIPlatformActionTool(BaseTool):
args_schema = create_model(
f"{self._base_name}Schema", **field_definitions
)
except Exception as e:
print(f"Warning: Could not create main schema model: {e}")
except Exception:
args_schema = create_model(
f"{self._base_name}Schema",
input_text=(str, Field(description="Input for the action")),
@@ -81,8 +78,8 @@ class CrewAIPlatformActionTool(BaseTool):
return "".join(word.capitalize() for word in parts if word)

def _extract_schema_info(
self, action_schema: Dict[str, Any]
) -> tuple[Dict[str, Any], List[str]]:
self, action_schema: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
schema_props = (
action_schema.get("function", {})
.get("parameters", {})
@@ -93,7 +90,7 @@ class CrewAIPlatformActionTool(BaseTool):
)
return schema_props, required

def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> Type[Any]:
def _process_schema_type(self, schema: dict[str, Any], type_name: str) -> type[Any]:
if "anyOf" in schema:
any_of_types = schema["anyOf"]
is_nullable = any(t.get("type") == "null" for t in any_of_types)
@@ -101,8 +98,8 @@ class CrewAIPlatformActionTool(BaseTool):

if non_null_types:
base_type = self._process_schema_type(non_null_types[0], type_name)
return Optional[base_type] if is_nullable else base_type
return cast(Type[Any], Optional[str])
return Optional[base_type] if is_nullable else base_type  # noqa: UP045
return cast(type[Any], Optional[str])  # noqa: UP045

if "oneOf" in schema:
return self._process_schema_type(schema["oneOf"][0], type_name)
@@ -121,7 +118,7 @@ class CrewAIPlatformActionTool(BaseTool):
if json_type == "array":
items_schema = schema.get("items", {"type": "string"})
item_type = self._process_schema_type(items_schema, f"{type_name}Item")
return List[item_type]
return list[item_type]

if json_type == "object":
return self._create_nested_model(schema, type_name)
@@ -129,8 +126,8 @@ class CrewAIPlatformActionTool(BaseTool):
return self._map_json_type_to_python(json_type)

def _create_nested_model(
self, schema: Dict[str, Any], model_name: str
) -> Type[Any]:
self, schema: dict[str, Any], model_name: str
) -> type[Any]:
full_model_name = f"{self._base_name}{model_name}"

if full_model_name in self._model_registry:
@@ -162,23 +159,22 @@ class CrewAIPlatformActionTool(BaseTool):
nested_model = create_model(full_model_name, **field_definitions)
self._model_registry[full_model_name] = nested_model
return nested_model
except Exception as e:
print(f"Warning: Could not create nested model {full_model_name}: {e}")
except Exception:
return dict

def _create_field_definition(
self, field_type: Type[Any], is_required: bool, description: str
self, field_type: type[Any], is_required: bool, description: str
) -> tuple:
if is_required:
return (field_type, Field(description=description))
if get_origin(field_type) is Union:
return (field_type, Field(default=None, description=description))
return (
Optional[field_type],
Optional[field_type],  # noqa: UP045
Field(default=None, description=description),
)

def _map_json_type_to_python(self, json_type: str) -> Type[Any]:
def _map_json_type_to_python(self, json_type: str) -> type[Any]:
type_mapping = {
"string": str,
"integer": int,
@@ -190,7 +186,7 @@ class CrewAIPlatformActionTool(BaseTool):
}
return type_mapping.get(json_type, str)

def _get_required_nullable_fields(self) -> List[str]:
def _get_required_nullable_fields(self) -> list[str]:
schema_props, required = self._extract_schema_info(self.action_schema)

required_nullable_fields = []
@@ -201,7 +197,7 @@ class CrewAIPlatformActionTool(BaseTool):

return required_nullable_fields

def _is_nullable_type(self, schema: Dict[str, Any]) -> bool:
def _is_nullable_type(self, schema: dict[str, Any]) -> bool:
if "anyOf" in schema:
return any(t.get("type") == "null" for t in schema["anyOf"])
return schema.get("type") == "null"
@@ -211,7 +207,7 @@ class CrewAIPlatformActionTool(BaseTool):
cleaned_kwargs = {}
for key, value in kwargs.items():
if value is not None:
cleaned_kwargs[key] = value
cleaned_kwargs[key] = value  # noqa: PERF403

required_nullable_fields = self._get_required_nullable_fields()
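Note: most of the edits in this file swap typing.Dict/List/Type for the built-in generics dict, list, and type (PEP 585). A hedged, self-contained example of the same annotation style; the function below is illustrative only and not part of the tool:

from typing import Any


def schema_properties(schema: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
    """Return the properties mapping and required field names from a JSON schema dict."""
    properties: dict[str, Any] = schema.get("properties", {})
    required: list[str] = schema.get("required", [])
    return properties, required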
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any

from crewai.tools import BaseTool
import requests
@@ -64,8 +64,8 @@ class CrewaiPlatformToolBuilder:
self._actions_schema[action_name] = action_schema

def _generate_detailed_description(
self, schema: Dict[str, Any], indent: int = 0
) -> List[str]:
self, schema: dict[str, Any], indent: int = 0
) -> list[str]:
descriptions = []
indent_str = "  " * indent
@@ -11,17 +11,17 @@ from crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder impor
logger = logging.getLogger(__name__)


def CrewaiPlatformTools(
def CrewaiPlatformTools(  # noqa: N802
apps: list[str],
) -> ToolCollection[BaseTool]:
"""Factory function that returns crewai platform tools.

Args:
apps: List of platform apps to get tools that are available on the platform.

Returns:
A list of BaseTool instances for platform actions
"""

builder = CrewaiPlatformToolBuilder(apps=apps)

return builder.tools()
@@ -1,7 +1,6 @@
from typing import Optional, Type
from pydantic import BaseModel, Field

from crewai_tools.rag.data_types import DataType
from pydantic import BaseModel, Field

from ..rag.rag_tool import RagTool

@@ -26,9 +25,9 @@ class CSVSearchTool(RagTool):
description: str = (
"A tool that can be used to semantic search a query from a CSV's content."
)
args_schema: Type[BaseModel] = CSVSearchToolSchema
args_schema: type[BaseModel] = CSVSearchToolSchema

def __init__(self, csv: Optional[str] = None, **kwargs):
def __init__(self, csv: str | None = None, **kwargs):
super().__init__(**kwargs)
if csv is not None:
self.add(csv)
@@ -42,7 +41,7 @@ class CSVSearchTool(RagTool):
def _run(
self,
search_query: str,
csv: Optional[str] = None,
csv: str | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str:
@@ -1,5 +1,4 @@
import json
from typing import List, Type

from crewai.tools import BaseTool, EnvVar
from openai import OpenAI
@@ -17,20 +16,22 @@ class ImagePromptSchema(BaseModel):
class DallETool(BaseTool):
name: str = "Dall-E Tool"
description: str = "Generates images using OpenAI's Dall-E model."
args_schema: Type[BaseModel] = ImagePromptSchema
args_schema: type[BaseModel] = ImagePromptSchema

model: str = "dall-e-3"
size: str = "1024x1024"
quality: str = "standard"
n: int = 1

env_vars: List[EnvVar] = [
EnvVar(
name="OPENAI_API_KEY",
description="API key for OpenAI services",
required=True,
),
]
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
name="OPENAI_API_KEY",
description="API key for OpenAI services",
required=True,
),
]
)

def _run(self, **kwargs) -> str:
client = OpenAI()
@@ -48,11 +49,9 @@ class DallETool(BaseTool):
n=self.n,
)

image_data = json.dumps(
return json.dumps(
{
"image_url": response.data[0].url,
"image_description": response.data[0].revised_prompt,
}
)

return image_data
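Note: here the list-typed env_vars field moves from a list literal, which is a mutable default built once at class-definition time, to Field(default_factory=...), which builds a fresh list per instance. A small sketch of the pattern, assuming only the stand-in model defined below (EnvVarSpec is invented for illustration, not crewai's EnvVar):

from pydantic import BaseModel, Field


class EnvVarSpec(BaseModel):
    name: str
    required: bool = True


class ToolConfig(BaseModel):
    # default_factory defers construction until each instance is created,
    # so instances never share one mutable list object.
    env_vars: list[EnvVarSpec] = Field(
        default_factory=lambda: [EnvVarSpec(name="OPENAI_API_KEY", required=True)]
    )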
@@ -1,5 +1,5 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Optional

from crewai.tools import BaseTool
from pydantic import BaseModel, Field, model_validator
@@ -15,19 +15,19 @@ class DatabricksQueryToolSchema(BaseModel):
query: str = Field(
..., description="SQL query to execute against the Databricks workspace table"
)
catalog: Optional[str] = Field(
catalog: str | None = Field(
None,
description="Databricks catalog name (optional, defaults to configured catalog)",
)
db_schema: Optional[str] = Field(
db_schema: str | None = Field(
None,
description="Databricks schema name (optional, defaults to configured schema)",
)
warehouse_id: Optional[str] = Field(
warehouse_id: str | None = Field(
None,
description="Databricks SQL warehouse ID (optional, defaults to configured warehouse)",
)
row_limit: Optional[int] = Field(
row_limit: int | None = Field(
1000, description="Maximum number of rows to return (default: 1000)"
)

@@ -46,8 +46,7 @@ class DatabricksQueryToolSchema(BaseModel):


class DatabricksQueryTool(BaseTool):
"""
A tool for querying Databricks workspace tables using SQL.
"""A tool for querying Databricks workspace tables using SQL.

This tool executes SQL queries against Databricks tables and returns the results.
It requires Databricks authentication credentials to be set as environment variables.
@@ -66,25 +65,24 @@ class DatabricksQueryTool(BaseTool):
"Execute SQL queries against Databricks workspace tables and return the results."
" Provide a 'query' parameter with the SQL query to execute."
)
args_schema: Type[BaseModel] = DatabricksQueryToolSchema
args_schema: type[BaseModel] = DatabricksQueryToolSchema

# Optional default parameters
default_catalog: Optional[str] = None
default_schema: Optional[str] = None
default_warehouse_id: Optional[str] = None
default_catalog: str | None = None
default_schema: str | None = None
default_warehouse_id: str | None = None

_workspace_client: Optional["WorkspaceClient"] = None
package_dependencies: List[str] = ["databricks-sdk"]
package_dependencies: list[str] = Field(default_factory=lambda: ["databricks-sdk"])

def __init__(
self,
default_catalog: Optional[str] = None,
default_schema: Optional[str] = None,
default_warehouse_id: Optional[str] = None,
default_catalog: str | None = None,
default_schema: str | None = None,
default_warehouse_id: str | None = None,
**kwargs: Any,
) -> None:
"""
Initialize the DatabricksQueryTool.
"""Initialize the DatabricksQueryTool.

Args:
default_catalog (Optional[str]): Default catalog to use for queries.
@@ -119,13 +117,13 @@ class DatabricksQueryTool(BaseTool):
from databricks.sdk import WorkspaceClient

self._workspace_client = WorkspaceClient()
except ImportError:
except ImportError as e:
raise ImportError(
"`databricks-sdk` package not found, please run `uv add databricks-sdk`"
)
) from e
return self._workspace_client

def _format_results(self, results: List[Dict[str, Any]]) -> str:
def _format_results(self, results: list[dict[str, Any]]) -> str:
"""Format query results as a readable string."""
if not results:
return "Query returned no results."
@@ -176,8 +174,7 @@ class DatabricksQueryTool(BaseTool):
self,
**kwargs: Any,
) -> str:
"""
Execute a SQL query against Databricks and return the results.
"""Execute a SQL query against Databricks and return the results.

Args:
query (str): SQL query to execute
@@ -337,9 +334,6 @@ class DatabricksQueryTool(BaseTool):
if hasattr(result.result, "data_array"):
# Add defensive check for None data_array
if result.result.data_array is None:
print(
"data_array is None - likely an empty result set or DDL query"
)
# Return empty result handling rather than trying to process null data
return "Query executed successfully (no data returned)"

@@ -418,9 +412,6 @@ class DatabricksQueryTool(BaseTool):
"is_likely_incorrect_row_structure" in locals()
and is_likely_incorrect_row_structure
):
print(
"Data appears to be malformed - will use special row reconstruction"
)
needs_special_string_handling = True
else:
needs_special_string_handling = False
@@ -431,7 +422,6 @@ class DatabricksQueryTool(BaseTool):
and needs_special_string_handling
):
# We're dealing with data where the rows may be incorrectly structured
print("Using row reconstruction processing mode")

# Collect all values into a flat list
all_values = []
@@ -568,7 +558,6 @@ class DatabricksQueryTool(BaseTool):
)

if title_idx >= 0:
print("Attempting title reconstruction method")
# Try to detect if title is split across multiple values
i = 0
while i < len(all_values):
@@ -609,7 +598,6 @@ class DatabricksQueryTool(BaseTool):

# If we still don't have rows, use simple chunking as fallback
if not reconstructed_rows:
print("Falling back to basic chunking approach")
chunks = [
all_values[i : i + expected_column_count]
for i in range(
@@ -637,7 +625,6 @@ class DatabricksQueryTool(BaseTool):

# Apply post-processing to fix known issues
if reconstructed_rows and "Title" in columns:
print("Applying post-processing to improve data quality")
for row in reconstructed_rows:
# Fix titles that might still have issues
if (
@@ -654,7 +641,6 @@ class DatabricksQueryTool(BaseTool):
chunk_results = reconstructed_rows
else:
# Process normal result structure as before
print("Using standard processing mode")

# Check different result structures
if (
@@ -662,7 +648,9 @@ class DatabricksQueryTool(BaseTool):
and result.result.data_array
):
# Check if data appears to be malformed within chunks
for chunk_idx, chunk in enumerate(result.result.data_array):
for _chunk_idx, chunk in enumerate(
result.result.data_array
):
# Check if chunk might actually contain individual columns of a single row
# This is another way data might be malformed - check the first few values
if len(chunk) > 0 and len(columns) > 1:
@@ -676,10 +664,6 @@ class DatabricksQueryTool(BaseTool):
if (
len(chunk) > len(columns) * 3
):  # Heuristic: if chunk has way more items than columns
print(
"Chunk appears to contain individual values rather than rows - switching to row reconstruction"
)

# This chunk might actually be values of multiple rows - try to reconstruct
values = chunk  # All values in this chunk
reconstructed_rows = []
@@ -697,7 +681,9 @@ class DatabricksQueryTool(BaseTool):
row_dict = {
col: val
for col, val in zip(
columns, row_values
columns,
row_values,
strict=False,
)
}
reconstructed_rows.append(row_dict)
@@ -726,7 +712,9 @@ class DatabricksQueryTool(BaseTool):
row_dict = {
col: val
for col, val in zip(
columns, row_values
columns,
row_values,
strict=False,
)
}
chunk_results.append(row_dict)
@@ -735,7 +723,7 @@ class DatabricksQueryTool(BaseTool):
continue

# Normal processing for typical row structure
for row_idx, row in enumerate(chunk):
for _row_idx, row in enumerate(chunk):
# Ensure row is actually a collection of values
if not isinstance(row, (list, tuple, dict)):
# This might be a single value; skip it or handle specially
@@ -771,7 +759,7 @@ class DatabricksQueryTool(BaseTool):
elif hasattr(result.result, "data") and result.result.data:
# Alternative data structure

for row_idx, row in enumerate(result.result.data):
for _row_idx, row in enumerate(result.result.data):
# Debug info

# Safely create dictionary matching column names to values
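Note: two recurring fixes in this file are exception chaining (raise ... from e, ruff's B904) and an explicit strict=False on zip(...). A minimal, generic sketch of the chaining pattern; the package name is only an example and the function is not part of the tool:

def load_sdk():
    try:
        import databricks.sdk as sdk  # any optional dependency works the same way
    except ImportError as e:
        # "from e" keeps the original traceback attached to the new error.
        raise ImportError("`databricks-sdk` package not found, please install it") from e
    return sdk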
@@ -1,5 +1,5 @@
import os
from typing import Any, Optional, Type
from typing import Any

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -20,10 +20,10 @@ class DirectoryReadTool(BaseTool):
description: str = (
"A tool that can be used to recursively list a directory's content."
)
args_schema: Type[BaseModel] = DirectoryReadToolSchema
directory: Optional[str] = None
args_schema: type[BaseModel] = DirectoryReadToolSchema
directory: str | None = None

def __init__(self, directory: Optional[str] = None, **kwargs):
def __init__(self, directory: str | None = None, **kwargs):
super().__init__(**kwargs)
if directory is not None:
self.directory = directory
@@ -1,7 +1,6 @@
from typing import Optional, Type
from pydantic import BaseModel, Field

from crewai_tools.rag.data_types import DataType
from pydantic import BaseModel, Field

from ..rag.rag_tool import RagTool

@@ -26,9 +25,9 @@ class DirectorySearchTool(RagTool):
description: str = (
"A tool that can be used to semantic search a query from a directory's content."
)
args_schema: Type[BaseModel] = DirectorySearchToolSchema
args_schema: type[BaseModel] = DirectorySearchToolSchema

def __init__(self, directory: Optional[str] = None, **kwargs):
def __init__(self, directory: str | None = None, **kwargs):
super().__init__(**kwargs)
if directory is not None:
self.add(directory)
@@ -42,7 +41,7 @@ class DirectorySearchTool(RagTool):
def _run(
self,
search_query: str,
directory: Optional[str] = None,
directory: str | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str:
@@ -1,7 +1,8 @@
from typing import Any, Optional, Type
from typing import Any

from pydantic import BaseModel, Field

from crewai_tools.rag.data_types import DataType
from pydantic import BaseModel, Field

from ..rag.rag_tool import RagTool

@@ -9,7 +10,7 @@ from ..rag.rag_tool import RagTool
class FixedDOCXSearchToolSchema(BaseModel):
"""Input for DOCXSearchTool."""

docx: Optional[str] = Field(
docx: str | None = Field(
..., description="File path or URL of a DOCX file to be searched"
)
search_query: str = Field(
@@ -32,9 +33,9 @@ class DOCXSearchTool(RagTool):
description: str = (
"A tool that can be used to semantic search a query from a DOCX's content."
)
args_schema: Type[BaseModel] = DOCXSearchToolSchema
args_schema: type[BaseModel] = DOCXSearchToolSchema

def __init__(self, docx: Optional[str] = None, **kwargs):
def __init__(self, docx: str | None = None, **kwargs):
super().__init__(**kwargs)
if docx is not None:
self.add(docx)
@@ -48,7 +49,7 @@ class DOCXSearchTool(RagTool):
def _run(
self,
search_query: str,
docx: Optional[str] = None,
docx: str | None = None,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> Any:
@@ -1,8 +1,8 @@
import os
from typing import Any, List, Optional, Type
from typing import Any, Optional

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field


try:
@@ -18,53 +18,55 @@ class EXABaseToolSchema(BaseModel):
search_query: str = Field(
..., description="Mandatory search query you want to use to search the internet"
)
start_published_date: Optional[str] = Field(
start_published_date: str | None = Field(
None, description="Start date for the search"
)
end_published_date: Optional[str] = Field(
None, description="End date for the search"
)
include_domains: Optional[list[str]] = Field(
end_published_date: str | None = Field(None, description="End date for the search")
include_domains: list[str] | None = Field(
None, description="List of domains to include in the search"
)


class EXASearchTool(BaseTool):
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)
name: str = "EXASearchTool"
description: str = "Search the internet using Exa"
args_schema: Type[BaseModel] = EXABaseToolSchema
args_schema: type[BaseModel] = EXABaseToolSchema
client: Optional["Exa"] = None
content: Optional[bool] = False
summary: Optional[bool] = False
type: Optional[str] = "auto"
package_dependencies: List[str] = ["exa_py"]
api_key: Optional[str] = Field(
content: bool | None = False
summary: bool | None = False
type: str | None = "auto"
package_dependencies: list[str] = Field(default_factory=lambda: ["exa_py"])
api_key: str | None = Field(
default_factory=lambda: os.getenv("EXA_API_KEY"),
description="API key for Exa services",
json_schema_extra={"required": False},
)
base_url: Optional[str] = Field(
base_url: str | None = Field(
default_factory=lambda: os.getenv("EXA_BASE_URL"),
description="API server url",
json_schema_extra={"required": False},
)
env_vars: List[EnvVar] = [
EnvVar(
name="EXA_API_KEY", description="API key for Exa services", required=False
),
EnvVar(
name="EXA_BASE_URL",
description="API url for the Exa services",
required=False,
),
]
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
name="EXA_API_KEY",
description="API key for Exa services",
required=False,
),
EnvVar(
name="EXA_BASE_URL",
description="API url for the Exa services",
required=False,
),
]
)

def __init__(
self,
content: Optional[bool] = False,
summary: Optional[bool] = False,
type: Optional[str] = "auto",
content: bool | None = False,
summary: bool | None = False,
type: str | None = "auto",
**kwargs,
):
super().__init__(
@@ -78,7 +80,7 @@ class EXASearchTool(BaseTool):
):
import subprocess

subprocess.run(["uv", "add", "exa_py"], check=True)
subprocess.run(["uv", "add", "exa_py"], check=True)  # noqa: S607

else:
raise ImportError(
@@ -95,9 +97,9 @@ class EXASearchTool(BaseTool):
def _run(
self,
search_query: str,
start_published_date: Optional[str] = None,
end_published_date: Optional[str] = None,
include_domains: Optional[list[str]] = None,
start_published_date: str | None = None,
end_published_date: str | None = None,
include_domains: list[str] | None = None,
) -> Any:
if self.client is None:
raise ValueError("Client not initialized")
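Note: the EXA fields read their API key and base URL through default_factory=lambda: os.getenv(...), so the environment is consulted when the tool is instantiated rather than when the class body is evaluated. A sketch under that assumption; the class and field names below are illustrative, not the tool's real API:

import os

from pydantic import BaseModel, Field


class ExaStyleSettings(BaseModel):
    # Evaluated per instance, so setting EXA_API_KEY after import still takes effect.
    api_key: str | None = Field(default_factory=lambda: os.getenv("EXA_API_KEY"))
    base_url: str | None = Field(default_factory=lambda: os.getenv("EXA_BASE_URL"))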
@@ -1,4 +1,4 @@
from typing import Any, Optional, Type
from typing import Any

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -8,10 +8,10 @@ class FileReadToolSchema(BaseModel):
"""Input for FileReadTool."""

file_path: str = Field(..., description="Mandatory file full path to read the file")
start_line: Optional[int] = Field(
start_line: int | None = Field(
1, description="Line number to start reading from (1-indexed)"
)
line_count: Optional[int] = Field(
line_count: int | None = Field(
None, description="Number of lines to read. If None, reads the entire file"
)

@@ -44,10 +44,10 @@ class FileReadTool(BaseTool):

name: str = "Read a file's content"
description: str = "A tool that reads the content of a file. To use this tool, provide a 'file_path' parameter with the path to the file you want to read. Optionally, provide 'start_line' to start reading from a specific line and 'line_count' to limit the number of lines read."
args_schema: Type[BaseModel] = FileReadToolSchema
file_path: Optional[str] = None
args_schema: type[BaseModel] = FileReadToolSchema
file_path: str | None = None

def __init__(self, file_path: Optional[str] = None, **kwargs: Any) -> None:
def __init__(self, file_path: str | None = None, **kwargs: Any) -> None:
"""Initialize the FileReadTool.

Args:
@@ -65,9 +65,9 @@ class FileReadTool(BaseTool):

def _run(
self,
file_path: Optional[str] = None,
start_line: Optional[int] = 1,
line_count: Optional[int] = None,
file_path: str | None = None,
start_line: int | None = 1,
line_count: int | None = None,
) -> str:
file_path = file_path or self.file_path
start_line = start_line or 1
@@ -1,5 +1,5 @@
import os
from typing import Any, Optional, Type
from typing import Any

from crewai.tools import BaseTool
from pydantic import BaseModel
@@ -18,7 +18,7 @@ def strtobool(val) -> bool:

class FileWriterToolInput(BaseModel):
filename: str
directory: Optional[str] = "./"
directory: str | None = "./"
overwrite: str | bool = False
content: str

@@ -26,7 +26,7 @@ class FileWriterToolInput(BaseModel):
class FileWriterTool(BaseTool):
name: str = "File Writer Tool"
description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
args_schema: Type[BaseModel] = FileWriterToolInput
args_schema: type[BaseModel] = FileWriterToolInput

def _run(self, **kwargs: Any) -> str:
try:
@@ -1,6 +1,5 @@
import os
import tarfile
from typing import Optional, Type
import zipfile

from crewai.tools import BaseTool
@@ -13,7 +12,7 @@ class FileCompressorToolInput(BaseModel):
input_path: str = Field(
..., description="Path to the file or directory to compress."
)
output_path: Optional[str] = Field(
output_path: str | None = Field(
default=None, description="Optional output archive filename."
)
overwrite: bool = Field(
@@ -32,12 +31,12 @@ class FileCompressorTool(BaseTool):
"Compresses a file or directory into an archive (.zip currently supported). "
"Useful for archiving logs, documents, or backups."
)
args_schema: Type[BaseModel] = FileCompressorToolInput
args_schema: type[BaseModel] = FileCompressorToolInput

def _run(
self,
input_path: str,
output_path: Optional[str] = None,
output_path: str | None = None,
overwrite: bool = False,
format: str = "zip",
) -> str:
@@ -47,7 +46,7 @@ class FileCompressorTool(BaseTool):
if not output_path:
output_path = self._generate_output_path(input_path, format)

FORMAT_EXTENSION = {
format_extension = {
"zip": ".zip",
"tar": ".tar",
"tar.gz": ".tar.gz",
@@ -55,10 +54,10 @@ class FileCompressorTool(BaseTool):
"tar.xz": ".tar.xz",
}

if format not in FORMAT_EXTENSION:
return f"Compression format '{format}' is not supported. Allowed formats: {', '.join(FORMAT_EXTENSION.keys())}"
if not output_path.endswith(FORMAT_EXTENSION[format]):
return f"Error: If '{format}' format is chosen, output file must have a '{FORMAT_EXTENSION[format]}' extension."
if format not in format_extension:
return f"Compression format '{format}' is not supported. Allowed formats: {', '.join(format_extension.keys())}"
if not output_path.endswith(format_extension[format]):
return f"Error: If '{format}' format is chosen, output file must have a '{format_extension[format]}' extension."
if not self._prepare_output(output_path, overwrite):
return (
f"Output '{output_path}' already exists and overwrite is set to False."
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, List, Optional, Type
from typing import TYPE_CHECKING, Any, Optional

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -20,8 +20,7 @@ class FirecrawlCrawlWebsiteToolSchema(BaseModel):


class FirecrawlCrawlWebsiteTool(BaseTool):
"""
Tool for crawling websites using Firecrawl. To run this tool, you need to have a Firecrawl API key.
"""Tool for crawling websites using Firecrawl. To run this tool, you need to have a Firecrawl API key.

Args:
api_key (str): Your Firecrawl API key.
@@ -44,9 +43,9 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
)
name: str = "Firecrawl web crawl tool"
description: str = "Crawl webpages using Firecrawl and return the contents"
args_schema: Type[BaseModel] = FirecrawlCrawlWebsiteToolSchema
api_key: Optional[str] = None
config: Optional[dict[str, Any]] = Field(
args_schema: type[BaseModel] = FirecrawlCrawlWebsiteToolSchema
api_key: str | None = None
config: dict[str, Any] | None = Field(
default_factory=lambda: {
"maxDepth": 2,
"ignoreSitemap": True,
@@ -61,16 +60,18 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
}
)
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"]
env_vars: List[EnvVar] = [
EnvVar(
name="FIRECRAWL_API_KEY",
description="API key for Firecrawl services",
required=True,
),
]
package_dependencies: list[str] = Field(default_factory=lambda: ["firecrawl-py"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
name="FIRECRAWL_API_KEY",
description="API key for Firecrawl services",
required=True,
),
]
)

def __init__(self, api_key: Optional[str] = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
self.api_key = api_key
self._initialize_firecrawl()
@@ -89,16 +90,16 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
import subprocess

try:
subprocess.run(["uv", "add", "firecrawl-py"], check=True)
subprocess.run(["uv", "add", "firecrawl-py"], check=True)  # noqa: S607
from firecrawl import FirecrawlApp

self._firecrawl = FirecrawlApp(api_key=self.api_key)
except subprocess.CalledProcessError:
raise ImportError("Failed to install firecrawl-py package")
except subprocess.CalledProcessError as e:
raise ImportError("Failed to install firecrawl-py package") from e
else:
raise ImportError(
"`firecrawl-py` package not found, please run `uv add firecrawl-py`"
)
) from None

def _run(self, url: str):
if not self._firecrawl:
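Note: the last hunk uses both chaining styles: "from e" when the installation subprocess fails (preserve the cause) and "from None" when the package is simply missing (suppress the unhelpful inner context). A generic sketch of that split, assuming the uv CLI is available; this helper is invented for illustration and is not the tool's code:

import importlib
import subprocess


def ensure_package(name: str):
    try:
        return importlib.import_module(name)
    except ImportError:
        try:
            subprocess.run(["uv", "add", name], check=True)  # assumes uv is installed
            return importlib.import_module(name)
        except subprocess.CalledProcessError as e:
            raise ImportError(f"Failed to install {name}") from e  # keep the cause
        except ImportError:
            raise ImportError(f"`{name}` is not importable") from None  # hide noisy context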
Some files were not shown because too many files have changed in this diff.