fix: resolve all strict mypy errors across crewai-tools package

This commit is contained in:
Greyson LaLonde
2026-03-25 13:11:54 +08:00
committed by GitHub
parent 8a1424534e
commit 26953c88c2
109 changed files with 857 additions and 560 deletions

View File

@@ -136,7 +136,7 @@ class EnterpriseActionTool(BaseTool):
enum_values = schema["enum"]
if not enum_values:
return self._map_json_type_to_python(json_type)
return Literal[tuple(enum_values)] # type: ignore[return-value]
return Literal[tuple(enum_values)]
if json_type == "array":
items_schema = schema.get("items", {"type": "string"})
@@ -155,7 +155,7 @@ class EnterpriseActionTool(BaseTool):
full_model_name = f"{self._base_name}{model_name}"
if full_model_name in self._model_registry:
return self._model_registry[full_model_name]
return cast(type[Any], self._model_registry[full_model_name])
properties = schema.get("properties", {})
required_fields = schema.get("required", [])
@@ -178,19 +178,19 @@ class EnterpriseActionTool(BaseTool):
field_definitions[prop_name] = self._create_field_definition(
prop_type,
is_required,
prop_desc, # type: ignore[arg-type]
prop_desc,
)
try:
nested_model = create_model(full_model_name, **field_definitions) # type: ignore[call-overload]
self._model_registry[full_model_name] = nested_model
return nested_model
return cast(type[Any], nested_model)
except Exception:
return dict
def _create_field_definition(
self, field_type: type[Any] | _SpecialForm, is_required: bool, description: str
) -> tuple:
) -> tuple[type[Any] | _SpecialForm, Any]:
"""Create Pydantic field definition based on type and requirement."""
if is_required:
return (field_type, Field(description=description))
@@ -232,7 +232,7 @@ class EnterpriseActionTool(BaseTool):
return any(t.get("type") == "null" for t in schema["anyOf"])
return schema.get("type") == "null"
def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
"""Execute the specific enterprise action with validated parameters."""
try:
cleaned_kwargs = {}
@@ -280,8 +280,8 @@ class EnterpriseActionKitToolAdapter:
):
"""Initialize the adapter with an enterprise action token."""
self._set_enterprise_action_token(enterprise_action_token)
self._actions_schema = {} # type: ignore[var-annotated]
self._tools = None
self._actions_schema: dict[str, Any] = {}
self._tools: list[BaseTool] | None = None
self.enterprise_api_base_url = (
enterprise_api_base_url or get_enterprise_api_base_url()
)
@@ -293,7 +293,7 @@ class EnterpriseActionKitToolAdapter:
self._create_tools()
return self._tools or []
def _fetch_actions(self):
def _fetch_actions(self) -> None:
"""Fetch available actions from the API."""
try:
actions_url = f"{self.enterprise_api_base_url}/actions"
@@ -379,9 +379,9 @@ class EnterpriseActionKitToolAdapter:
return descriptions
def _create_tools(self):
def _create_tools(self) -> None:
"""Create BaseTool instances for each action."""
tools = []
tools: list[BaseTool] = []
for action_name, action_schema in self._actions_schema.items():
function_details = action_schema.get("function", {})
@@ -403,7 +403,7 @@ class EnterpriseActionKitToolAdapter:
description=full_description,
action_name=action_name,
action_schema=action_schema,
enterprise_action_token=self.enterprise_action_token,
enterprise_action_token=self.enterprise_action_token or "",
enterprise_api_base_url=self.enterprise_api_base_url,
)
@@ -411,7 +411,7 @@ class EnterpriseActionKitToolAdapter:
self._tools = tools
def _set_enterprise_action_token(self, enterprise_action_token: str | None):
def _set_enterprise_action_token(self, enterprise_action_token: str | None) -> None:
if enterprise_action_token and not enterprise_action_token.startswith("PK_"):
warnings.warn(
"Legacy token detected, please consider using the new Enterprise Action Auth token. Check out our docs for more information https://docs.crewai.com/en/enterprise/features/integrations.",
@@ -423,10 +423,15 @@ class EnterpriseActionKitToolAdapter:
"CREWAI_ENTERPRISE_TOOLS_TOKEN"
)
self.enterprise_action_token = token
self.enterprise_action_token: str | None = token
def __enter__(self):
def __enter__(self) -> list[BaseTool]:
return self.tools()
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> None:
pass

View File

@@ -5,20 +5,19 @@ from typing import Any
from crewai.utilities.lock_store import lock as store_lock
from lancedb import ( # type: ignore[import-untyped]
DBConnection as LanceDBConnection,
connect as lancedb_connect,
)
from lancedb.table import Table as LanceDBTable # type: ignore[import-untyped]
from openai import Client as OpenAIClient
from pydantic import Field, PrivateAttr
from crewai_tools.tools.rag.rag_tool import Adapter
def _default_embedding_function():
def _default_embedding_function() -> Callable[[list[str]], list[list[float]]]:
"""Create a default embedding function using OpenAI."""
client = OpenAIClient()
def _embedding_function(input):
def _embedding_function(input: list[str]) -> list[list[float]]:
rs = client.embeddings.create(input=input, model="text-embedding-ada-002")
return [record.embedding for record in rs.data]
@@ -28,13 +27,15 @@ def _default_embedding_function():
class LanceDBAdapter(Adapter):
uri: str | Path
table_name: str
embedding_function: Callable = Field(default_factory=_default_embedding_function)
embedding_function: Callable[[list[str]], list[list[float]]] = Field(
default_factory=_default_embedding_function
)
top_k: int = 3
vector_column_name: str = "vector"
text_column_name: str = "text"
_db: LanceDBConnection = PrivateAttr()
_table: LanceDBTable = PrivateAttr()
_db: Any = PrivateAttr()
_table: Any = PrivateAttr()
_lock_name: str = PrivateAttr(default="")
def model_post_init(self, __context: Any) -> None:

View File

@@ -12,7 +12,7 @@ class RAGAdapter(Adapter):
embedding_model: str = "text-embedding-3-small",
top_k: int = 5,
embedding_api_key: str | None = None,
**embedding_kwargs,
**embedding_kwargs: Any,
):
super().__init__()

View File

@@ -9,7 +9,7 @@ from crewai.tools import BaseTool
T = TypeVar("T", bound=BaseTool)
class ToolCollection(list, Generic[T]):
class ToolCollection(list[T], Generic[T]):
"""A collection of tools that can be accessed by index or name.
This class extends the built-in list to provide dictionary-like
@@ -34,7 +34,8 @@ class ToolCollection(list, Generic[T]):
def __getitem__(self, key: int | str) -> T: # type: ignore[override]
if isinstance(key, str):
return self._name_cache[key.lower()]
return super().__getitem__(key)
result: T = super().__getitem__(key)
return result
def append(self, tool: T) -> None:
super().append(tool)
@@ -54,7 +55,7 @@ class ToolCollection(list, Generic[T]):
del self._name_cache[tool.name.lower()]
def pop(self, index: int = -1) -> T: # type: ignore[override]
tool = super().pop(index)
tool: T = super().pop(index)
if tool.name.lower() in self._name_cache:
del self._name_cache[tool.name.lower()]
return tool

View File

@@ -1,6 +1,6 @@
import logging
import os
from typing import Final, Literal
from typing import Any, Final, Literal
from crewai.tools import BaseTool
from pydantic import Field, create_model
@@ -22,7 +22,7 @@ class ZapierActionTool(BaseTool):
action_id: str = Field(description="Zapier action ID")
api_key: str = Field(description="Zapier API key")
def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> Any:
"""Execute the Zapier action."""
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}
@@ -64,9 +64,9 @@ class ZapierActionsAdapter:
logger.error("Zapier Actions API key is required")
raise ValueError("Zapier Actions API key is required")
def get_zapier_actions(self):
def get_zapier_actions(self) -> Any:
headers = {
"x-api-key": self.api_key,
"x-api-key": self.api_key or "",
}
response = requests.request(
"GET",

View File

@@ -2,6 +2,7 @@ from datetime import datetime, timezone
import json
import os
import time
from typing import Any
from crewai.tools import BaseTool
from dotenv import load_dotenv
@@ -42,17 +43,17 @@ class BedrockInvokeAgentTool(BaseTool):
enable_trace: bool = False,
end_session: bool = False,
description: str | None = None,
**kwargs,
):
**kwargs: Any,
) -> None:
"""Initialize the BedrockInvokeAgentTool with agent configuration.
Args:
agent_id (str): The unique identifier of the Bedrock agent
agent_alias_id (str): The unique identifier of the agent alias
session_id (str): The unique identifier of the session
enable_trace (bool): Whether to enable trace for the agent invocation
end_session (bool): Whether to end the session with the agent
description (Optional[str]): Custom description for the tool
agent_id: The unique identifier of the Bedrock agent.
agent_alias_id: The unique identifier of the agent alias.
session_id: The unique identifier of the session.
enable_trace: Whether to enable trace for the agent invocation.
end_session: Whether to end the session with the agent.
description: Custom description for the tool.
"""
super().__init__(**kwargs)
@@ -72,7 +73,7 @@ class BedrockInvokeAgentTool(BaseTool):
# Validate parameters
self._validate_parameters()
def _validate_parameters(self):
def _validate_parameters(self) -> None:
"""Validate the parameters according to AWS API requirements."""
try:
# Validate agent_id

View File

@@ -3,7 +3,7 @@
import asyncio
import json
import logging
from typing import Any
from typing import Any, cast
from urllib.parse import urlparse
from crewai.tools import BaseTool
@@ -82,7 +82,7 @@ class CurrentWebPageToolInput(BaseModel):
class BrowserBaseTool(BaseTool):
"""Base class for browser tools."""
def __init__(self, session_manager: BrowserSessionManager): # type: ignore[call-arg]
def __init__(self, session_manager: BrowserSessionManager) -> None:
"""Initialize with a session manager."""
super().__init__() # type: ignore[call-arg]
self._session_manager = session_manager
@@ -90,16 +90,16 @@ class BrowserBaseTool(BaseTool):
if self._is_in_asyncio_loop() and hasattr(self, "_arun"):
self._original_run = self._run
# Override _run to use _arun when in an asyncio loop
def patched_run(*args, **kwargs):
def patched_run(*args: Any, **kwargs: Any) -> str:
try:
import nest_asyncio # type: ignore[import-untyped]
loop = asyncio.get_event_loop()
nest_asyncio.apply(loop)
return asyncio.get_event_loop().run_until_complete(
result: str = asyncio.get_event_loop().run_until_complete(
self._arun(*args, **kwargs)
)
return result
except Exception as e:
return f"Error in patched _run: {e!s}"
@@ -132,7 +132,7 @@ class NavigateTool(BrowserBaseTool):
description: str = "Navigate a browser to the specified URL"
args_schema: type[BaseModel] = NavigateToolInput
def _run(self, url: str, thread_id: str = "default", **kwargs) -> str:
def _run(self, url: str, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Get page for this thread
@@ -150,7 +150,7 @@ class NavigateTool(BrowserBaseTool):
except Exception as e:
return f"Error navigating to {url}: {e!s}"
async def _arun(self, url: str, thread_id: str = "default", **kwargs) -> str:
async def _arun(self, url: str, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the async tool."""
try:
# Get page for this thread
@@ -188,7 +188,7 @@ class ClickTool(BrowserBaseTool):
return selector
return f"{selector} >> visible=1"
def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str:
def _run(self, selector: str, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Get the current page
@@ -213,7 +213,9 @@ class ClickTool(BrowserBaseTool):
except Exception as e:
return f"Error clicking on element: {e!s}"
async def _arun(self, selector: str, thread_id: str = "default", **kwargs) -> str:
async def _arun(
self, selector: str, thread_id: str = "default", **kwargs: Any
) -> str:
"""Use the async tool."""
try:
# Get the current page
@@ -246,7 +248,7 @@ class NavigateBackTool(BrowserBaseTool):
description: str = "Navigate back to the previous page"
args_schema: type[BaseModel] = NavigateBackToolInput
def _run(self, thread_id: str = "default", **kwargs) -> str:
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Get the current page
@@ -261,7 +263,7 @@ class NavigateBackTool(BrowserBaseTool):
except Exception as e:
return f"Error navigating back: {e!s}"
async def _arun(self, thread_id: str = "default", **kwargs) -> str:
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the async tool."""
try:
# Get the current page
@@ -284,7 +286,7 @@ class ExtractTextTool(BrowserBaseTool):
description: str = "Extract all the text on the current webpage"
args_schema: type[BaseModel] = ExtractTextToolInput
def _run(self, thread_id: str = "default", **kwargs) -> str:
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Import BeautifulSoup
@@ -306,7 +308,7 @@ class ExtractTextTool(BrowserBaseTool):
except Exception as e:
return f"Error extracting text: {e!s}"
async def _arun(self, thread_id: str = "default", **kwargs) -> str:
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the async tool."""
try:
# Import BeautifulSoup
@@ -336,12 +338,12 @@ class ExtractHyperlinksTool(BrowserBaseTool):
description: str = "Extract all hyperlinks on the current webpage"
args_schema: type[BaseModel] = ExtractHyperlinksToolInput
def _run(self, thread_id: str = "default", **kwargs) -> str:
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Import BeautifulSoup
try:
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup, Tag
except ImportError:
return (
"The 'beautifulsoup4' package is required to use this tool."
@@ -356,9 +358,10 @@ class ExtractHyperlinksTool(BrowserBaseTool):
soup = BeautifulSoup(content, "html.parser")
links = []
for link in soup.find_all("a", href=True):
text = link.get_text().strip()
href = link["href"]
if href.startswith(("http", "https")): # type: ignore[union-attr]
tag = cast(Tag, link)
text = tag.get_text().strip()
href = str(tag.get("href", ""))
if href.startswith(("http", "https")):
links.append({"text": text, "url": href})
if not links:
@@ -368,12 +371,12 @@ class ExtractHyperlinksTool(BrowserBaseTool):
except Exception as e:
return f"Error extracting hyperlinks: {e!s}"
async def _arun(self, thread_id: str = "default", **kwargs) -> str:
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the async tool."""
try:
# Import BeautifulSoup
try:
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup, Tag
except ImportError:
return (
"The 'beautifulsoup4' package is required to use this tool."
@@ -388,9 +391,10 @@ class ExtractHyperlinksTool(BrowserBaseTool):
soup = BeautifulSoup(content, "html.parser")
links = []
for link in soup.find_all("a", href=True):
text = link.get_text().strip()
href = link["href"]
if href.startswith(("http", "https")): # type: ignore[union-attr]
tag = cast(Tag, link)
text = tag.get_text().strip()
href = str(tag.get("href", ""))
if href.startswith(("http", "https")):
links.append({"text": text, "url": href})
if not links:
@@ -408,7 +412,7 @@ class GetElementsTool(BrowserBaseTool):
description: str = "Get elements from the webpage using a CSS selector"
args_schema: type[BaseModel] = GetElementsToolInput
def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str:
def _run(self, selector: str, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Get the current page
@@ -428,7 +432,9 @@ class GetElementsTool(BrowserBaseTool):
except Exception as e:
return f"Error getting elements: {e!s}"
async def _arun(self, selector: str, thread_id: str = "default", **kwargs) -> str:
async def _arun(
self, selector: str, thread_id: str = "default", **kwargs: Any
) -> str:
"""Use the async tool."""
try:
# Get the current page
@@ -456,7 +462,7 @@ class CurrentWebPageTool(BrowserBaseTool):
description: str = "Get information about the current webpage"
args_schema: type[BaseModel] = CurrentWebPageToolInput
def _run(self, thread_id: str = "default", **kwargs) -> str:
def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the sync tool."""
try:
# Get the current page
@@ -469,7 +475,7 @@ class CurrentWebPageTool(BrowserBaseTool):
except Exception as e:
return f"Error getting current webpage info: {e!s}"
async def _arun(self, thread_id: str = "default", **kwargs) -> str:
async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
"""Use the async tool."""
try:
# Get the current page
@@ -535,7 +541,7 @@ class BrowserToolkit:
self._nest_current_loop()
self._setup_tools()
def _nest_current_loop(self):
def _nest_current_loop(self) -> None:
"""Apply nest_asyncio if we're in an asyncio loop."""
try:
loop = asyncio.get_event_loop()

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def extract_output_from_stream(response):
def extract_output_from_stream(response: dict[str, Any]) -> str:
"""Extract output from code interpreter response stream.
Args:
@@ -143,8 +143,8 @@ class ExecuteCodeTool(BaseTool):
args_schema: type[BaseModel] = ExecuteCodeInput
toolkit: Any = Field(default=None, exclude=True)
def __init__(self, toolkit):
super().__init__()
def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.toolkit = toolkit
def _run(
@@ -198,8 +198,8 @@ class ExecuteCommandTool(BaseTool):
args_schema: type[BaseModel] = ExecuteCommandInput
toolkit: Any = Field(default=None, exclude=True)
def __init__(self, toolkit):
super().__init__()
def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.toolkit = toolkit
def _run(self, command: str, thread_id: str = "default") -> str:
@@ -231,8 +231,8 @@ class ReadFilesTool(BaseTool):
args_schema: type[BaseModel] = ReadFilesInput
toolkit: Any = Field(default=None, exclude=True)
def __init__(self, toolkit):
super().__init__()
def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.toolkit = toolkit
def _run(self, paths: list[str], thread_id: str = "default") -> str:
@@ -264,8 +264,8 @@ class ListFilesTool(BaseTool):
args_schema: type[BaseModel] = ListFilesInput
toolkit: Any = Field(default=None, exclude=True)
def __init__(self, toolkit):
super().__init__()
def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.toolkit = toolkit
def _run(self, directory_path: str = "", thread_id: str = "default") -> str:
@@ -297,8 +297,8 @@ class DeleteFilesTool(BaseTool):
args_schema: type[BaseModel] = DeleteFilesInput
toolkit: Any = Field(default=None, exclude=True)
def __init__(self, toolkit):
super().__init__()
def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.toolkit = toolkit
def _run(self, paths: list[str], thread_id: str = "default") -> str:
@@ -330,8 +330,8 @@ class WriteFilesTool(BaseTool):
args_schema: type[BaseModel] = WriteFilesInput
toolkit: Any = Field(default=None, exclude=True)
def __init__(self, toolkit):
super().__init__()
def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.toolkit = toolkit
def _run(self, files: list[dict[str, str]], thread_id: str = "default") -> str:
@@ -365,8 +365,8 @@ class StartCommandTool(BaseTool):
args_schema: type[BaseModel] = StartCommandInput
toolkit: Any = Field(default=None, exclude=True)
def __init__(self, toolkit):
super().__init__()
def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.toolkit = toolkit
def _run(self, command: str, thread_id: str = "default") -> str:
@@ -398,8 +398,8 @@ class GetTaskTool(BaseTool):
args_schema: type[BaseModel] = GetTaskInput
toolkit: Any = Field(default=None, exclude=True)
def __init__(self, toolkit):
super().__init__()
def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.toolkit = toolkit
def _run(self, task_id: str, thread_id: str = "default") -> str:
@@ -431,8 +431,8 @@ class StopTaskTool(BaseTool):
args_schema: type[BaseModel] = StopTaskInput
toolkit: Any = Field(default=None, exclude=True)
def __init__(self, toolkit):
super().__init__()
def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.toolkit = toolkit
def _run(self, task_id: str, thread_id: str = "default") -> str:

View File

@@ -44,16 +44,16 @@ class BedrockKBRetrieverTool(BaseTool):
retrieval_configuration: dict[str, Any] | None = None,
guardrail_configuration: dict[str, Any] | None = None,
next_token: str | None = None,
**kwargs,
):
**kwargs: Any,
) -> None:
"""Initialize the BedrockKBRetrieverTool with knowledge base configuration.
Args:
knowledge_base_id (str): The unique identifier of the knowledge base to query
number_of_results (Optional[int], optional): The maximum number of results to return. Defaults to 5.
retrieval_configuration (Optional[Dict[str, Any]], optional): Configurations for the knowledge base query and retrieval process. Defaults to None.
guardrail_configuration (Optional[Dict[str, Any]], optional): Guardrail settings. Defaults to None.
next_token (Optional[str], optional): Token for retrieving the next batch of results. Defaults to None.
knowledge_base_id: The unique identifier of the knowledge base to query.
number_of_results: The maximum number of results to return.
retrieval_configuration: Configurations for the knowledge base query and retrieval process.
guardrail_configuration: Guardrail settings.
next_token: Token for retrieving the next batch of results.
"""
super().__init__(**kwargs)
@@ -89,7 +89,7 @@ class BedrockKBRetrieverTool(BaseTool):
return {"vectorSearchConfiguration": vector_search_config}
def _validate_parameters(self):
def _validate_parameters(self) -> None:
"""Validate the parameters according to AWS API requirements."""
try:
# Validate knowledge_base_id

View File

@@ -39,11 +39,12 @@ class S3ReaderTool(BaseTool):
# Read file content from S3
response = s3.get_object(Bucket=bucket_name, Key=object_key)
return response["Body"].read().decode("utf-8")
result: str = response["Body"].read().decode("utf-8")
return result
except ClientError as e:
return f"Error reading file from S3: {e!s}"
def _parse_s3_path(self, file_path: str) -> tuple:
def _parse_s3_path(self, file_path: str) -> tuple[str, str]:
parts = file_path.replace("s3://", "").split("/", 1)
return parts[0], parts[1]

View File

@@ -45,6 +45,6 @@ class S3WriterTool(BaseTool):
except ClientError as e:
return f"Error writing file to S3: {e!s}"
def _parse_s3_path(self, file_path: str) -> tuple:
def _parse_s3_path(self, file_path: str) -> tuple[str, str]:
parts = file_path.replace("s3://", "").split("/", 1)
return parts[0], parts[1]

View File

@@ -1,10 +1,10 @@
#!/usr/bin/env python3
from collections.abc import Mapping
from collections.abc import Callable, Mapping
import inspect
import json
from pathlib import Path
from typing import Any
from typing import Any, cast
from crewai.tools.base_tool import BaseTool, EnvVar
from pydantic import BaseModel
@@ -115,7 +115,8 @@ class ToolSpecExtractor:
default_value = field.default
if default_value is PydanticUndefined or default_value is None:
if field.default_factory:
return field.default_factory()
factory = cast(Callable[[], Any], field.default_factory)
return factory()
return None
return default_value

View File

@@ -1,5 +1,6 @@
from crewai_tools.rag.core import RAG, EmbeddingService
from crewai_tools.rag.core import RAG
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.embedding_service import EmbeddingService
__all__ = [

View File

@@ -21,7 +21,7 @@ class BaseLoader(ABC):
self.config = config or {}
@abstractmethod
def load(self, content: SourceContent, **kwargs) -> LoaderResult: ...
def load(self, content: SourceContent, **kwargs: Any) -> LoaderResult: ...
@staticmethod
def generate_doc_id(

View File

@@ -77,7 +77,7 @@ class RAG(Adapter):
super().model_post_init(__context)
def add(
def add( # type: ignore[override]
self,
content: str | Path,
data_type: str | DataType | None = None,

View File

@@ -9,6 +9,7 @@ import logging
import os
from typing import Any
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from pydantic import BaseModel, Field
@@ -81,7 +82,7 @@ class EmbeddingService:
**kwargs,
)
self._embedding_function = None
self._embedding_function: EmbeddingFunction[Any] | None = None
self._initialize_embedding_function()
@staticmethod
@@ -107,7 +108,7 @@ class EmbeddingService:
return os.getenv(env_key)
return None
def _initialize_embedding_function(self):
def _initialize_embedding_function(self) -> None:
"""Initialize the embedding function using CrewAI's factory."""
try:
from crewai.rag.embeddings.factory import build_embedder
@@ -264,7 +265,7 @@ class EmbeddingService:
try:
# Use ChromaDB's embedding function interface
embeddings = self._embedding_function([text]) # type: ignore
return embeddings[0] if embeddings else []
return list(embeddings[0]) if embeddings else []
except Exception as e:
logger.error(f"Error generating embedding for text: {e}")
@@ -294,12 +295,12 @@ class EmbeddingService:
try:
# Process in batches to avoid API limits
all_embeddings = []
all_embeddings: list[list[float]] = []
for i in range(0, len(valid_texts), self.config.batch_size):
batch = valid_texts[i : i + self.config.batch_size]
batch_embeddings = self._embedding_function(batch) # type: ignore
all_embeddings.extend(batch_embeddings)
all_embeddings.extend(list(e) for e in batch_embeddings)
return all_embeddings

View File

@@ -1,5 +1,6 @@
import csv
from io import StringIO
from typing import Any
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.loaders.utils import load_from_url
@@ -7,7 +8,7 @@ from crewai_tools.rag.source_content import SourceContent
class CSVLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
source_ref = source_content.source_ref
content_str = source_content.source

View File

@@ -1,12 +1,13 @@
import os
from pathlib import Path
from typing import Any
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class DirectoryLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
"""Load and process all files from a directory recursively.
Args:
@@ -32,7 +33,7 @@ class DirectoryLoader(BaseLoader):
return self._process_directory(source_ref, kwargs)
def _process_directory(self, dir_path: str, kwargs: dict) -> LoaderResult:
def _process_directory(self, dir_path: str, kwargs: dict[str, Any]) -> LoaderResult:
recursive: bool = kwargs.get("recursive", True)
include_extensions: list[str] | None = kwargs.get("include_extensions", None)
exclude_extensions: list[str] | None = kwargs.get("exclude_extensions", None)

View File

@@ -1,8 +1,9 @@
"""Documentation site loader."""
from typing import Any
from urllib.parse import urljoin, urlparse
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup, Tag
import requests
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
@@ -12,7 +13,7 @@ from crewai_tools.rag.source_content import SourceContent
class DocsSiteLoader(BaseLoader):
"""Loader for documentation websites."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
"""Load content from a documentation site.
Args:
@@ -53,7 +54,9 @@ class DocsSiteLoader(BaseLoader):
break
if not main_content:
main_content = soup.find("body")
body = soup.find("body")
if isinstance(body, Tag):
main_content = body
if not main_content:
raise ValueError(
@@ -66,6 +69,8 @@ class DocsSiteLoader(BaseLoader):
if headings:
text_parts.append("Table of Contents:")
for heading in headings[:15]:
if not isinstance(heading, Tag):
continue
level = int(heading.name[1])
indent = " " * (level - 1)
text_parts.append(f"{indent}- {heading.get_text(strip=True)}")
@@ -81,6 +86,8 @@ class DocsSiteLoader(BaseLoader):
if nav:
links = nav.find_all("a", href=True)
for link in links[:20]:
if not isinstance(link, Tag):
continue
href = link.get("href", "")
if isinstance(href, str) and not href.startswith(
("http://", "https://", "mailto:", "#")

View File

@@ -9,7 +9,7 @@ from crewai_tools.rag.source_content import SourceContent
class DOCXLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
try:
from docx import Document as DocxDocument
except ImportError as e:
@@ -33,7 +33,7 @@ class DOCXLoader(BaseLoader):
)
@staticmethod
def _download_from_url(url: str, kwargs: dict) -> str:
def _download_from_url(url: str, kwargs: dict[str, Any]) -> str:
headers = kwargs.get(
"headers",
{

View File

@@ -1,5 +1,7 @@
"""GitHub repository content loader."""
from typing import Any
from github import Github, GithubException
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
@@ -9,7 +11,7 @@ from crewai_tools.rag.source_content import SourceContent
class GithubLoader(BaseLoader):
"""Loader for GitHub repository content."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
"""Load content from a GitHub repository.
Args:

View File

@@ -1,4 +1,5 @@
import json
from typing import Any
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.loaders.utils import load_from_url
@@ -6,7 +7,7 @@ from crewai_tools.rag.source_content import SourceContent
class JSONLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
source_ref = source_content.source_ref
content = source_content.source

View File

@@ -1,5 +1,5 @@
import re
from typing import Final
from typing import Any, Final
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.loaders.utils import load_from_url
@@ -15,7 +15,7 @@ _EXTRA_NEWLINES_PATTERN: Final[re.Pattern[str]] = re.compile(r"\n\s*\n\s*\n")
class MDXLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
source_ref = source_content.source_ref
content = source_content.source

View File

@@ -1,8 +1,8 @@
"""PDF loader for extracting text from PDF files."""
import os
import tempfile
from pathlib import Path
import tempfile
from typing import Any
from urllib.parse import urlparse
@@ -25,7 +25,7 @@ class PDFLoader(BaseLoader):
return False
@staticmethod
def _download_from_url(url: str, kwargs: dict) -> str:
def _download_from_url(url: str, kwargs: dict[str, Any]) -> str:
"""Download PDF from a URL to a temporary file and return its path.
Args:

View File

@@ -1,5 +1,6 @@
"""PostgreSQL database loader."""
from typing import Any
from urllib.parse import urlparse
from psycopg2 import Error, connect
@@ -12,7 +13,7 @@ from crewai_tools.rag.source_content import SourceContent
class PostgresLoader(BaseLoader):
"""Loader for PostgreSQL database content."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
"""Load content from a PostgreSQL database table.
Args:

View File

@@ -1,9 +1,11 @@
from typing import Any
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class TextFileLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
source_ref = source_content.source_ref
if not source_content.path_exists():
raise FileNotFoundError(
@@ -21,7 +23,7 @@ class TextFileLoader(BaseLoader):
class TextLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
return LoaderResult(
content=source_content.source,
source=source_content.source_ref,

View File

@@ -1,8 +1,13 @@
"""Utility functions for RAG loaders."""
from typing import Any
def load_from_url(
url: str, kwargs: dict, accept_header: str = "*/*", loader_name: str = "Loader"
url: str,
kwargs: dict[str, Any],
accept_header: str = "*/*",
loader_name: str = "Loader",
) -> str:
"""Load content from a URL.

View File

@@ -1,5 +1,5 @@
import re
from typing import Final
from typing import Any, Final
from bs4 import BeautifulSoup
import requests
@@ -13,7 +13,7 @@ _NEWLINE_PATTERN: Final[re.Pattern[str]] = re.compile(r"\s+\n\s+")
class WebPageLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
url = source_content.source
headers = kwargs.get(
"headers",

View File

@@ -10,7 +10,7 @@ from crewai_tools.rag.source_content import SourceContent
class YoutubeChannelLoader(BaseLoader):
"""Loader for YouTube channels."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
"""Load and extract content from a YouTube channel.
Args:

View File

@@ -11,7 +11,7 @@ from crewai_tools.rag.source_content import SourceContent
class YoutubeVideoLoader(BaseLoader):
"""Loader for YouTube videos."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
"""Load and extract transcript from a YouTube video.
Args:

View File

@@ -42,7 +42,7 @@ class AIMindTool(BaseTool):
]
)
def __init__(self, api_key: str | None = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.api_key = api_key or os.getenv("MINDS_API_KEY")
if not self.api_key:
@@ -51,8 +51,10 @@ class AIMindTool(BaseTool):
)
try:
from minds.client import Client # type: ignore
from minds.datasources import DatabaseConfig # type: ignore
from minds.client import Client # type: ignore[import-not-found]
from minds.datasources import ( # type: ignore[import-not-found]
DatabaseConfig,
)
except ImportError as e:
raise ImportError(
"`minds_sdk` package not found, please run `pip install minds-sdk`"
@@ -81,7 +83,7 @@ class AIMindTool(BaseTool):
self.mind_name = mind.name
def _run(self, query: str):
def _run(self, query: str) -> str | None:
# Run the query on the AI-Mind.
# The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.
openai_client = OpenAI(

View File

@@ -2,7 +2,7 @@ import logging
from pathlib import Path
import re
import time
from typing import ClassVar
from typing import Any, ClassVar
import urllib.error
import urllib.parse
import urllib.request
@@ -75,7 +75,9 @@ class ArxivPaperTool(BaseTool):
logger.error(f"ArxivTool Error: {e!s}")
return f"Failed to fetch or download Arxiv papers: {e!s}"
def fetch_arxiv_data(self, search_query: str, max_results: int) -> list[dict]:
def fetch_arxiv_data(
self, search_query: str, max_results: int
) -> list[dict[str, Any]]:
api_url = f"{self.BASE_API_URL}?search_query={urllib.parse.quote(search_query)}&start=0&max_results={max_results}"
logger.info(f"Fetching data from Arxiv API: {api_url}")
@@ -135,7 +137,7 @@ class ArxivPaperTool(BaseTool):
return href
return None
def _format_paper_result(self, paper: dict) -> str:
def _format_paper_result(self, paper: dict[str, Any]) -> str:
summary = (
(paper["summary"][: self.SUMMARY_TRUNCATE_LENGTH] + "...")
if len(paper["summary"]) > self.SUMMARY_TRUNCATE_LENGTH
@@ -156,7 +158,7 @@ class ArxivPaperTool(BaseTool):
save_path.mkdir(parents=True, exist_ok=True)
return save_path
def download_pdf(self, pdf_url: str, save_path: str):
def download_pdf(self, pdf_url: str, save_path: str) -> None:
try:
logger.info(f"Downloading PDF from {pdf_url} to {save_path}")
urllib.request.urlretrieve(pdf_url, str(save_path)) # noqa: S310

View File

@@ -138,7 +138,7 @@ class BraveSearchToolBase(BaseTool, ABC):
self._rate_limit_lock = threading.Lock()
@property
def api_key(self) -> str:
def api_key(self) -> str | None:
return self._api_key
@property
@@ -214,7 +214,8 @@ class BraveSearchToolBase(BaseTool, ABC):
# Response was OK, return the JSON body
if resp.ok:
try:
return resp.json()
result: dict[str, Any] = resp.json()
return result
except ValueError as exc:
raise RuntimeError(
f"Brave Search API returned invalid JSON (HTTP {resp.status_code}): {exc}"
@@ -239,9 +240,9 @@ class BraveSearchToolBase(BaseTool, ABC):
# (e.g., 422 Unprocessable Entity, 400 Bad Request (OPTION_NOT_IN_PLAN))
_raise_for_error(resp)
# All retries exhausted
_raise_for_error(last_resp or resp) # type: ignore[possibly-undefined]
return {} # unreachable (here to satisfy the type checker and linter)
# All retries exhausted — last_resp is always set when we reach here
_raise_for_error(last_resp or resp)
return {} # unreachable; satisfies return type
def _run(self, q: str | None = None, **params: Any) -> Any:
# Allow positional usage: tool.run("latest Brave browser features")

View File

@@ -3,7 +3,6 @@ from typing import Any
from pydantic import BaseModel
from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
from crewai_tools.tools.brave_search_tool.response_types import LLMContext
from crewai_tools.tools.brave_search_tool.schemas import (
LLMContextHeaders,
LLMContextParams,
@@ -27,6 +26,6 @@ class BraveLLMContextTool(BraveSearchToolBase):
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: LLMContext.Response) -> LLMContext.Response:
def _refine_response(self, response: dict[str, Any]) -> Any:
"""The LLM Context response schema is fairly simple. Return as is."""
return response

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, cast
from pydantic import BaseModel
@@ -65,8 +65,8 @@ class BraveLocalPOIsTool(BraveSearchToolBase):
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
results = response.get("results", [])
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = response.get("results", [])
return [
{
"title": result.get("title"),
@@ -76,7 +76,7 @@ class BraveLocalPOIsTool(BraveSearchToolBase):
"contact": result.get("contact", {}).get("telephone")
or result.get("contact", {}).get("email")
or None,
"opening_hours": _simplify_opening_hours(result),
"opening_hours": _simplify_opening_hours(cast(LocationResult, result)),
}
for result in results
]
@@ -97,9 +97,8 @@ class BraveLocalPOIsDescriptionTool(BraveSearchToolBase):
def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
return params
def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
# Make the response more concise, and easier to consume
results = response.get("results", [])
def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = response.get("results", [])
return [
{
"id": result.get("id"),

View File

@@ -50,7 +50,7 @@ class BraveSearchTool(BaseTool):
_last_request_time: ClassVar[float] = 0
_min_request_interval: ClassVar[float] = 1.0 # seconds
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
if "BRAVE_API_KEY" not in os.environ:
raise ValueError(

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import os
from typing import Any
@@ -13,7 +15,7 @@ class BrightDataConfig(BaseModel):
DEFAULT_POLLING_INTERVAL: int = 1
@classmethod
def from_env(cls):
def from_env(cls) -> BrightDataConfig:
return cls(
API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com"),
DEFAULT_TIMEOUT=int(os.environ.get("BRIGHTDATA_DEFAULT_TIMEOUT", "600")),
@@ -26,12 +28,12 @@ class BrightDataConfig(BaseModel):
class BrightDataDatasetToolException(Exception): # noqa: N818
"""Exception raised for custom error in the application."""
def __init__(self, message, error_code):
def __init__(self, message: str, error_code: int) -> None:
self.message = message
super().__init__(message)
self.error_code = error_code
def __str__(self):
def __str__(self) -> str:
return f"{self.message} (Error Code: {self.error_code})"
@@ -62,7 +64,7 @@ config = BrightDataConfig.from_env()
BRIGHTDATA_API_URL = config.API_URL
timeout = config.DEFAULT_TIMEOUT
datasets = [
datasets: list[dict[str, Any]] = [
{
"id": "amazon_product",
"dataset_id": "gd_l7q7dkf244hwjntr0",
@@ -440,7 +442,7 @@ class BrightDataDatasetTool(BaseTool):
self.zipcode = zipcode
self.additional_params = additional_params
def filter_dataset_by_id(self, target_id):
def filter_dataset_by_id(self, target_id: str) -> list[dict[str, Any]]:
return [dataset for dataset in datasets if dataset["id"] == target_id]
async def get_dataset_data_async(

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import os
from typing import Any
import urllib.parse
@@ -11,7 +13,7 @@ class BrightDataConfig(BaseModel):
API_URL: str = "https://api.brightdata.com/request"
@classmethod
def from_env(cls):
def from_env(cls) -> BrightDataConfig:
return cls(
API_URL=os.environ.get(
"BRIGHTDATA_API_URL", "https://api.brightdata.com/request"
@@ -127,7 +129,7 @@ class BrightDataSearchTool(BaseTool):
if not self.zone:
raise ValueError("BRIGHT_DATA_ZONE environment variable is required.")
def get_search_url(self, engine: str, query: str):
def get_search_url(self, engine: str, query: str) -> str:
if engine == "yandex":
return f"https://yandex.com/search/?text=${query}"
if engine == "bing":
@@ -143,7 +145,7 @@ class BrightDataSearchTool(BaseTool):
search_type: str | None = None,
device_type: str | None = None,
parse_results: bool | None = None,
**kwargs,
**kwargs: Any,
) -> Any:
"""Executes a search query using Bright Data SERP API and returns results.

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import os
from typing import Any
@@ -10,7 +12,7 @@ class BrightDataConfig(BaseModel):
API_URL: str = "https://api.brightdata.com/request"
@classmethod
def from_env(cls):
def from_env(cls) -> BrightDataConfig:
return cls(
API_URL=os.environ.get(
"BRIGHTDATA_API_URL", "https://api.brightdata.com/request"

View File

@@ -42,15 +42,15 @@ class BrowserbaseLoadTool(BaseTool):
text_content: bool | None = False,
session_id: str | None = None,
proxy: bool | None = None,
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if not self.api_key:
raise EnvironmentError(
"BROWSERBASE_API_KEY environment variable is required for initialization"
)
try:
from browserbase import Browserbase # type: ignore
from browserbase import Browserbase
except ImportError:
import click
@@ -60,7 +60,7 @@ class BrowserbaseLoadTool(BaseTool):
import subprocess
subprocess.run(["uv", "add", "browserbase"], check=True) # noqa: S607
from browserbase import Browserbase # type: ignore
from browserbase import Browserbase
else:
raise ImportError(
"`browserbase` package not found, please run `uv add browserbase`"
@@ -71,7 +71,7 @@ class BrowserbaseLoadTool(BaseTool):
self.session_id = session_id
self.proxy = proxy
def _run(self, url: str):
def _run(self, url: str) -> Any:
return self.browserbase.load_url( # type: ignore[union-attr]
url, self.text_content, self.session_id, self.proxy
)

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from crewai_tools.rag.data_types import DataType
@@ -26,7 +28,7 @@ class CodeDocsSearchTool(RagTool):
)
args_schema: type[BaseModel] = CodeDocsSearchToolSchema
def __init__(self, docs_url: str | None = None, **kwargs):
def __init__(self, docs_url: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if docs_url is not None:
self.add(docs_url)
@@ -34,7 +36,7 @@ class CodeDocsSearchTool(RagTool):
self.args_schema = FixedCodeDocsSearchToolSchema
self._generate_description()
def add(self, docs_url: str) -> None:
def add(self, docs_url: str) -> None: # type: ignore[override]
super().add(docs_url, data_type=DataType.DOCS_SITE)
def _run( # type: ignore[override]

View File

@@ -18,7 +18,6 @@ from docker import ( # type: ignore[import-untyped]
from_env as docker_from_env,
)
from docker.errors import ImageNotFound, NotFound # type: ignore[import-untyped]
from docker.models.containers import Container # type: ignore[import-untyped]
from pydantic import BaseModel, Field
from typing_extensions import Unpack
@@ -232,7 +231,7 @@ class CodeInterpreterTool(BaseTool):
return self.run_code_safety(code, libraries_used)
@staticmethod
def _install_libraries(container: Container, libraries: list[str]) -> None:
def _install_libraries(container: Any, libraries: list[str]) -> None:
"""Installs required Python libraries in the Docker container.
Args:
@@ -242,7 +241,7 @@ class CodeInterpreterTool(BaseTool):
for library in libraries:
container.exec_run(["pip", "install", library])
def _init_docker_container(self) -> Container:
def _init_docker_container(self) -> Any:
"""Initializes and returns a Docker container for code execution.
Stops and removes any existing container with the same name before creating
@@ -269,7 +268,7 @@ class CodeInterpreterTool(BaseTool):
tty=True,
working_dir="/workspace",
name=container_name,
volumes={current_path: {"bind": "/workspace", "mode": "rw"}}, # type: ignore
volumes={current_path: {"bind": "/workspace", "mode": "rw"}},
)
@staticmethod
@@ -351,14 +350,14 @@ class CodeInterpreterTool(BaseTool):
container = self._init_docker_container()
self._install_libraries(container, libraries_used)
exec_result = container.exec_run(["python3", "-c", code])
exec_result: Any = container.exec_run(["python3", "-c", code])
container.stop()
container.remove()
if exec_result.exit_code != 0:
return f"Something went wrong while running the code: \n{exec_result.output.decode('utf-8')}"
return exec_result.output.decode("utf-8")
return str(exec_result.output.decode("utf-8"))
@staticmethod
def run_code_in_restricted_sandbox(code: str) -> str:
@@ -385,12 +384,12 @@ class CodeInterpreterTool(BaseTool):
"""
Printer.print(
"WARNING: Running code in INSECURE restricted sandbox (vulnerable to escape attacks)",
color="bold_red"
color="bold_red",
)
exec_locals: dict[str, Any] = {}
try:
SandboxPython.exec(code=code, locals_=exec_locals)
return exec_locals.get("result", "No result variable found.")
return exec_locals.get("result", "No result variable found.") # type: ignore[no-any-return]
except Exception as e:
return f"An error occurred: {e!s}"
@@ -412,12 +411,14 @@ class CodeInterpreterTool(BaseTool):
Printer.print("WARNING: Running code in unsafe mode", color="bold_magenta")
# Install libraries on the host machine
for library in libraries_used:
subprocess.run([sys.executable, "-m", "pip", "install", library], check=False) # noqa: S603
subprocess.run( # noqa: S603
[sys.executable, "-m", "pip", "install", library], check=False
)
# Execute the code
try:
exec_locals: dict[str, Any] = {}
exec(code, {}, exec_locals) # noqa: S102
return exec_locals.get("result", "No result variable found.")
return exec_locals.get("result", "No result variable found.") # type: ignore[no-any-return]
except Exception as e:
return f"An error occurred: {e!s}"

View File

@@ -10,7 +10,7 @@ import typing_extensions as te
class ComposioTool(BaseTool):
"""Wrapper for composio tools."""
composio_action: t.Callable
composio_action: t.Callable[..., t.Any]
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
@@ -70,7 +70,7 @@ class ComposioTool(BaseTool):
schema = action_schema.model_dump(exclude_none=True)
entity_id = kwargs.pop("entity_id", DEFAULT_ENTITY_ID)
def function(**kwargs: t.Any) -> dict:
def function(**kwargs: t.Any) -> dict[str, t.Any]:
"""Wrapper function for composio action."""
return toolset.execute_action(
action=Action(schema["name"]),

View File

@@ -28,7 +28,7 @@ class ContextualAICreateAgentTool(BaseTool):
default_factory=lambda: ["contextual-client"]
)
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
try:
from contextual import ContextualAI

View File

@@ -31,7 +31,7 @@ class ContextualAIQueryTool(BaseTool):
default_factory=lambda: ["contextual-client"]
)
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
try:
from contextual import ContextualAI
@@ -99,20 +99,19 @@ class ContextualAIQueryTool(BaseTool):
response = self.contextual_client.agents.query.create(
agent_id=agent_id, messages=[{"role": "user", "content": query}]
)
if hasattr(response, "content"):
return response.content
if hasattr(response, "message"):
content = getattr(response, "content", None)
if content is not None:
return str(content)
message = getattr(response, "message", None)
if message is not None:
msg_content = getattr(message, "content", None)
return str(msg_content) if msg_content is not None else str(message)
messages = getattr(response, "messages", None)
if messages and len(messages) > 0:
last_message = messages[-1]
last_content = getattr(last_message, "content", None)
return (
response.message.content
if hasattr(response.message, "content")
else str(response.message)
)
if hasattr(response, "messages") and len(response.messages) > 0:
last_message = response.messages[-1]
return (
last_message.content
if hasattr(last_message, "content")
else str(last_message)
str(last_content) if last_content is not None else str(last_message)
)
return str(response)
except Exception as e:

View File

@@ -15,11 +15,11 @@ try:
COUCHBASE_AVAILABLE = True
except ImportError:
COUCHBASE_AVAILABLE = False
search = Any
Cluster = Any
SearchOptions = Any
VectorQuery = Any
VectorSearch = Any
search = Any # type: ignore[assignment,unused-ignore]
Cluster = Any # type: ignore[assignment,unused-ignore]
SearchOptions = Any # type: ignore[assignment,unused-ignore]
VectorQuery = Any # type: ignore[assignment,unused-ignore]
VectorSearch = Any # type: ignore[assignment,unused-ignore]
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, SkipValidation
@@ -41,7 +41,7 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
name: str = "CouchbaseFTSVectorSearchTool"
description: str = "A tool to search the Couchbase database for relevant information on internal documents."
args_schema: type[BaseModel] = CouchbaseToolSchema
cluster: SkipValidation[Cluster] = Field(
cluster: SkipValidation[Any] = Field(
description="An instance of the Couchbase Cluster connected to the desired Couchbase server.",
)
collection_name: str = Field(
@@ -136,7 +136,7 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
return True
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
"""Initialize the CouchbaseFTSVectorSearchTool.
Args:

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from crewai_tools.rag.data_types import DataType
@@ -26,7 +28,7 @@ class CSVSearchTool(RagTool):
)
args_schema: type[BaseModel] = CSVSearchToolSchema
def __init__(self, csv: str | None = None, **kwargs):
def __init__(self, csv: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if csv is not None:
self.add(csv)
@@ -34,7 +36,7 @@ class CSVSearchTool(RagTool):
self.args_schema = FixedCSVSearchToolSchema
self._generate_description()
def add(self, csv: str) -> None:
def add(self, csv: str) -> None: # type: ignore[override]
super().add(csv, data_type=DataType.CSV)
def _run( # type: ignore[override]

View File

@@ -1,8 +1,8 @@
import json
from typing import Literal
from typing import Any, Literal
from crewai.tools import BaseTool, EnvVar
from openai import Omit, OpenAI
from openai import OpenAI
from pydantic import BaseModel, Field
@@ -33,9 +33,9 @@ class DallETool(BaseTool):
]
| None
) = "1024x1024"
quality: (
Literal["standard", "hd", "low", "medium", "high", "auto"] | None | Omit
) = "standard"
quality: Literal["standard", "hd", "low", "medium", "high", "auto"] | None = (
"standard"
)
n: int = 1
env_vars: list[EnvVar] = Field(
@@ -48,7 +48,7 @@ class DallETool(BaseTool):
]
)
def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
client = OpenAI()
image_description = kwargs.get("image_description")

View File

@@ -23,7 +23,7 @@ class DirectoryReadTool(BaseTool):
args_schema: type[BaseModel] = DirectoryReadToolSchema
directory: str | None = None
def __init__(self, directory: str | None = None, **kwargs):
def __init__(self, directory: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if directory is not None:
self.directory = directory

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from crewai_tools.rag.data_types import DataType
@@ -26,7 +28,7 @@ class DirectorySearchTool(RagTool):
)
args_schema: type[BaseModel] = DirectorySearchToolSchema
def __init__(self, directory: str | None = None, **kwargs):
def __init__(self, directory: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if directory is not None:
self.add(directory)
@@ -34,7 +36,7 @@ class DirectorySearchTool(RagTool):
self.args_schema = FixedDirectorySearchToolSchema
self._generate_description()
def add(self, directory: str) -> None:
def add(self, directory: str) -> None: # type: ignore[override]
super().add(directory, data_type=DataType.DIRECTORY)
def _run( # type: ignore[override]

View File

@@ -34,7 +34,7 @@ class DOCXSearchTool(RagTool):
)
args_schema: type[BaseModel] = DOCXSearchToolSchema
def __init__(self, docx: str | None = None, **kwargs):
def __init__(self, docx: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if docx is not None:
self.add(docx)
@@ -42,7 +42,7 @@ class DOCXSearchTool(RagTool):
self.args_schema = FixedDOCXSearchToolSchema
self._generate_description()
def add(self, docx: str) -> None:
def add(self, docx: str) -> None: # type: ignore[override]
super().add(docx, data_type=DataType.DOCX)
def _run( # type: ignore[override]

View File

@@ -71,8 +71,8 @@ class EXASearchTool(BaseTool):
content: bool | None = False,
summary: bool | None = False,
type: str | None = "auto",
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(
**kwargs,
)

View File

@@ -6,7 +6,7 @@ from crewai.tools import BaseTool
from pydantic import BaseModel
def strtobool(val) -> bool:
def strtobool(val: str | bool) -> bool:
if isinstance(val, bool):
return val
val = val.lower()
@@ -46,7 +46,10 @@ class FileWriterTool(BaseTool):
# itself, since that is not a valid file target.
real_directory = Path(directory).resolve()
real_filepath = Path(filepath).resolve()
if not real_filepath.is_relative_to(real_directory) or real_filepath == real_directory:
if (
not real_filepath.is_relative_to(real_directory)
or real_filepath == real_directory
):
return "Error: Invalid file path — the filename must not escape the target directory."
if kwargs.get("directory"):
@@ -62,9 +65,7 @@ class FileWriterTool(BaseTool):
file.write(kwargs["content"])
return f"Content successfully written to {real_filepath}"
except FileExistsError:
return (
f"File {real_filepath} already exists and overwrite option was not passed."
)
return f"File {real_filepath} already exists and overwrite option was not passed."
except KeyError as e:
return f"An error occurred while accessing key: {e!s}"
except Exception as e:

View File

@@ -106,7 +106,7 @@ class FileCompressorTool(BaseTool):
return True
@staticmethod
def _compress_zip(input_path: str, output_path: str):
def _compress_zip(input_path: str, output_path: str) -> None:
"""Compresses input into a zip archive."""
with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
if os.path.isfile(input_path):
@@ -119,7 +119,7 @@ class FileCompressorTool(BaseTool):
zipf.write(full_path, arcname)
@staticmethod
def _compress_tar(input_path: str, output_path: str, format: str):
def _compress_tar(input_path: str, output_path: str, format: str) -> None:
"""Compresses input into a tar archive with the given format."""
format_mode = {
"tar": "w",

View File

@@ -1,14 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing import Any
from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
if TYPE_CHECKING:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
try:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
@@ -63,7 +60,7 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
},
}
)
_firecrawl: FirecrawlApp | None = PrivateAttr(None)
_firecrawl: Any = PrivateAttr(None)
package_dependencies: list[str] = Field(default_factory=lambda: ["firecrawl-py"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
@@ -75,14 +72,14 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
]
)
def __init__(self, api_key: str | None = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.api_key = api_key
self._initialize_firecrawl()
def _initialize_firecrawl(self) -> None:
try:
from firecrawl import FirecrawlApp # type: ignore
from firecrawl import FirecrawlApp
self._firecrawl = FirecrawlApp(api_key=self.api_key)
except ImportError:
@@ -105,7 +102,7 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
"`firecrawl-py` package not found, please run `uv add firecrawl-py`"
) from None
def _run(self, url: str):
def _run(self, url: str) -> Any:
if not self._firecrawl:
raise RuntimeError("FirecrawlApp not properly initialized")
@@ -113,13 +110,10 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
try:
from firecrawl import FirecrawlApp
from firecrawl import FirecrawlApp # noqa: F401
# Only rebuild if the class hasn't been initialized yet
if not hasattr(FirecrawlCrawlWebsiteTool, "_model_rebuilt"):
if not getattr(FirecrawlCrawlWebsiteTool, "_model_rebuilt", False):
FirecrawlCrawlWebsiteTool.model_rebuild()
FirecrawlCrawlWebsiteTool._model_rebuilt = True # type: ignore[attr-defined]
except ImportError:
"""
When this tool is not used, then exception can be ignored.
"""
pass

View File

@@ -1,14 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing import Any
from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
if TYPE_CHECKING:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
try:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
@@ -70,7 +67,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
}
)
_firecrawl: FirecrawlApp | None = PrivateAttr(None)
_firecrawl: Any = PrivateAttr(None)
package_dependencies: list[str] = Field(default_factory=lambda: ["firecrawl-py"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
@@ -82,10 +79,10 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
]
)
def __init__(self, api_key: str | None = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
try:
from firecrawl import FirecrawlApp # type: ignore
from firecrawl import FirecrawlApp
except ImportError:
import click
@@ -105,7 +102,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
self._firecrawl = FirecrawlApp(api_key=api_key)
def _run(self, url: str):
def _run(self, url: str) -> Any:
if not self._firecrawl:
raise RuntimeError("FirecrawlApp not properly initialized")
@@ -113,13 +110,10 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
try:
from firecrawl import FirecrawlApp
from firecrawl import FirecrawlApp # noqa: F401
# Must rebuild model after class is defined
if not hasattr(FirecrawlScrapeWebsiteTool, "_model_rebuilt"):
if not getattr(FirecrawlScrapeWebsiteTool, "_model_rebuilt", False):
FirecrawlScrapeWebsiteTool.model_rebuild()
FirecrawlScrapeWebsiteTool._model_rebuilt = True # type: ignore[attr-defined]
except ImportError:
"""
When this tool is not used, then exception can be ignored.
"""
pass

View File

@@ -1,15 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing import Any
from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
if TYPE_CHECKING:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
try:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
@@ -65,7 +61,7 @@ class FirecrawlSearchTool(BaseTool):
},
}
)
_firecrawl: FirecrawlApp | None = PrivateAttr(None)
_firecrawl: Any = PrivateAttr(None)
package_dependencies: list[str] = Field(default_factory=lambda: ["firecrawl-py"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
@@ -77,14 +73,14 @@ class FirecrawlSearchTool(BaseTool):
]
)
def __init__(self, api_key: str | None = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.api_key = api_key
self._initialize_firecrawl()
def _initialize_firecrawl(self) -> None:
try:
from firecrawl import FirecrawlApp # type: ignore
from firecrawl import FirecrawlApp
self._firecrawl = FirecrawlApp(api_key=self.api_key)
except ImportError:
@@ -121,13 +117,10 @@ class FirecrawlSearchTool(BaseTool):
try:
from firecrawl import FirecrawlApp # type: ignore
from firecrawl import FirecrawlApp # noqa: F401
# Only rebuild if the class hasn't been initialized yet
if not hasattr(FirecrawlSearchTool, "_model_rebuilt"):
if not getattr(FirecrawlSearchTool, "_model_rebuilt", False):
FirecrawlSearchTool.model_rebuild()
FirecrawlSearchTool._model_rebuilt = True # type: ignore[attr-defined]
except ImportError:
"""
When this tool is not used, then exception can be ignored.
"""
pass

View File

@@ -1,4 +1,5 @@
import os
from typing import Any
from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field
@@ -46,7 +47,7 @@ class GenerateCrewaiAutomationTool(BaseTool):
]
)
def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
input_data = GenerateCrewaiAutomationToolSchema(**kwargs)
response = requests.post( # noqa: S113
f"{self.crewai_enterprise_url}/crewai_plus/api/v1/studio",
@@ -58,7 +59,7 @@ class GenerateCrewaiAutomationTool(BaseTool):
studio_project_url = response.json().get("url")
return f"Generated CrewAI Studio project URL: {studio_project_url}"
def _get_headers(self, organization_id: str | None = None) -> dict:
def _get_headers(self, organization_id: str | None = None) -> dict[str, str]:
headers = {
"Authorization": f"Bearer {self.personal_access_token}",
"Content-Type": "application/json",

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from crewai_tools.rag.data_types import DataType
@@ -38,8 +40,8 @@ class GithubSearchTool(RagTool):
self,
github_repo: str | None = None,
content_types: list[str] | None = None,
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if github_repo and content_types:
@@ -48,7 +50,7 @@ class GithubSearchTool(RagTool):
self.args_schema = FixedGithubSearchToolSchema
self._generate_description()
def add(
def add( # type: ignore[override]
self,
repo: str,
content_types: list[str] | None = None,

View File

@@ -10,7 +10,7 @@ class HyperbrowserLoadToolSchema(BaseModel):
operation: Literal["scrape", "crawl"] = Field(
description="Operation to perform on the website. Either 'scrape' or 'crawl'"
)
params: dict | None = Field(
params: dict[str, Any] | None = Field(
description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait"
)
@@ -42,7 +42,7 @@ class HyperbrowserLoadTool(BaseTool):
]
)
def __init__(self, api_key: str | None = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.api_key = api_key or os.getenv("HYPERBROWSER_API_KEY")
if not api_key:
@@ -65,7 +65,7 @@ class HyperbrowserLoadTool(BaseTool):
self.hyperbrowser = Hyperbrowser(api_key=self.api_key)
@staticmethod
def _prepare_params(params: dict) -> dict:
def _prepare_params(params: dict[str, Any]) -> dict[str, Any]:
"""Prepare session and scrape options parameters."""
try:
from hyperbrowser.models.scrape import ( # type: ignore[import-untyped]
@@ -91,7 +91,7 @@ class HyperbrowserLoadTool(BaseTool):
params["scrape_options"] = ScrapeOptions(**params["scrape_options"])
return params
def _extract_content(self, data: Any | None):
def _extract_content(self, data: Any | None) -> str:
"""Extract content from response data."""
content = ""
if data:
@@ -102,15 +102,15 @@ class HyperbrowserLoadTool(BaseTool):
self,
url: str,
operation: Literal["scrape", "crawl"] = "scrape",
params: dict | None = None,
):
params: dict[str, Any] | None = None,
) -> str:
if params is None:
params = {}
try:
from hyperbrowser.models.crawl import ( # type: ignore[import-untyped]
StartCrawlJobParams,
)
from hyperbrowser.models.scrape import ( # type: ignore[import-untyped]
from hyperbrowser.models.scrape import (
StartScrapeJobParams,
)
except ImportError as e:

View File

@@ -134,7 +134,8 @@ class InvokeCrewAIAutomationTool(BaseTool):
json={"inputs": inputs},
timeout=30,
)
return response.json()
result: dict[str, Any] = response.json()
return result
def _get_crew_status(self, crew_id: str) -> dict[str, Any]:
"""Get the status of a crew task.
@@ -153,9 +154,10 @@ class InvokeCrewAIAutomationTool(BaseTool):
},
timeout=30,
)
return response.json()
result: dict[str, Any] = response.json()
return result
def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
"""Execute the crew invocation tool."""
if kwargs is None:
kwargs = {}
@@ -172,7 +174,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
try:
status_response = self._get_crew_status(crew_id=kickoff_id)
if status_response.get("state", "").lower() == "success":
return status_response.get("result", "No result returned")
return str(status_response.get("result", "No result returned"))
if status_response.get("state", "").lower() == "failed":
return f"Error: Crew task failed. Response: {status_response}"
except Exception as e:

View File

@@ -1,3 +1,5 @@
from typing import Any
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
import requests
@@ -15,14 +17,14 @@ class JinaScrapeWebsiteTool(BaseTool):
args_schema: type[BaseModel] = JinaScrapeWebsiteToolInput
website_url: str | None = None
api_key: str | None = None
headers: dict = Field(default_factory=dict)
headers: dict[str, str] = Field(default_factory=dict)
def __init__(
self,
website_url: str | None = None,
api_key: str | None = None,
custom_headers: dict | None = None,
**kwargs,
custom_headers: dict[str, str] | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
if website_url is not None:

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from crewai_tools.tools.rag.rag_tool import RagTool
@@ -27,7 +29,7 @@ class JSONSearchTool(RagTool):
)
args_schema: type[BaseModel] = JSONSearchToolSchema
def __init__(self, json_path: str | None = None, **kwargs):
def __init__(self, json_path: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if json_path is not None:
self.add(json_path)

View File

@@ -20,7 +20,7 @@ class LinkupSearchTool(BaseTool):
description: str = (
"Performs an API call to Linkup to retrieve contextual information."
)
_client: LinkupClient = PrivateAttr() # type: ignore
_client: Any = PrivateAttr()
package_dependencies: list[str] = Field(default_factory=lambda: ["linkup-sdk"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
@@ -60,7 +60,7 @@ class LinkupSearchTool(BaseTool):
output_type: Literal[
"searchResults", "sourcedAnswer", "structured"
] = "searchResults",
) -> dict:
) -> dict[str, Any]:
"""Executes a search using the Linkup API.
:param query: The query to search for.

View File

@@ -17,11 +17,7 @@ class LlamaIndexTool(BaseTool):
**kwargs: Any,
) -> Any:
"""Run tool."""
from llama_index.core.tools import ( # type: ignore[import-not-found]
BaseTool as LlamaBaseTool,
)
tool = cast(LlamaBaseTool, self.llama_index_tool)
tool = self.llama_index_tool
if self.result_as_answer:
return tool(*args, **kwargs).content
@@ -36,7 +32,6 @@ class LlamaIndexTool(BaseTool):
if not isinstance(tool, LlamaBaseTool):
raise ValueError(f"Expected a LlamaBaseTool, got {type(tool)}")
tool = cast(LlamaBaseTool, tool)
if tool.metadata.fn_schema is None:
raise ValueError(
@@ -64,9 +59,7 @@ class LlamaIndexTool(BaseTool):
from llama_index.core.query_engine import ( # type: ignore[import-not-found]
BaseQueryEngine,
)
from llama_index.core.tools import ( # type: ignore[import-not-found]
QueryEngineTool,
)
from llama_index.core.tools import QueryEngineTool
if not isinstance(query_engine, BaseQueryEngine):
raise ValueError(f"Expected a BaseQueryEngine, got {type(query_engine)}")

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from crewai_tools.rag.data_types import DataType
@@ -26,7 +28,7 @@ class MDXSearchTool(RagTool):
)
args_schema: type[BaseModel] = MDXSearchToolSchema
def __init__(self, mdx: str | None = None, **kwargs):
def __init__(self, mdx: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if mdx is not None:
self.add(mdx)
@@ -34,7 +36,7 @@ class MDXSearchTool(RagTool):
self.args_schema = FixedMDXSearchToolSchema
self._generate_description()
def add(self, mdx: str) -> None:
def add(self, mdx: str) -> None: # type: ignore[override]
super().add(mdx, data_type=DataType.MDX)
def _run( # type: ignore[override]

View File

@@ -2,7 +2,7 @@
import json
import logging
from typing import Any
from typing import Any, cast
from uuid import uuid4
from crewai.tools import BaseTool, EnvVar
@@ -108,7 +108,7 @@ class MergeAgentHandlerTool(BaseTool):
)
raise MergeAgentHandlerToolError(f"API Error: {error_msg}")
return result
return cast(dict[str, Any], result)
except requests.exceptions.RequestException as e:
logger.error(f"Failed to call Agent Handler API: {e!s}")
@@ -219,8 +219,8 @@ class MergeAgentHandlerTool(BaseTool):
required = params.get("required", [])
for field_name, field_schema in properties.items():
field_type = Any # Default type
field_default = ... # Required by default
field_type: Any = Any
field_default: Any = ...
# Map JSON schema types to Python types
json_type = field_schema.get("type", "string")
@@ -256,7 +256,7 @@ class MergeAgentHandlerTool(BaseTool):
# Create the Pydantic model
if fields:
args_schema = create_model(
args_schema = create_model( # type: ignore[call-overload]
f"{tool_name.replace('__', '_').title()}Args",
**fields,
)

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pymongo.collection import Collection
from pymongo.collection import Collection as MongoCollection
def _vector_search_index_definition(
@@ -34,7 +34,7 @@ def _vector_search_index_definition(
def create_vector_search_index(
collection: Collection,
collection: MongoCollection[Any],
index_name: str,
dimensions: int,
path: str,
@@ -84,7 +84,7 @@ def create_vector_search_index(
)
def _is_index_ready(collection: Collection, index_name: str) -> bool:
def _is_index_ready(collection: MongoCollection[Any], index_name: str) -> bool:
"""Check for the index name in the list of available search indexes to see if the
specified index is of status READY.
@@ -102,7 +102,7 @@ def _is_index_ready(collection: Collection, index_name: str) -> bool:
def _wait_for_predicate(
predicate: Callable, err: str, timeout: float = 120, interval: float = 0.5
predicate: Callable[[], bool], err: str, timeout: float = 120, interval: float = 0.5
) -> None:
"""Generic to block until the predicate returns true.

View File

@@ -31,7 +31,7 @@ class MongoDBVectorSearchConfig(BaseModel):
default=None,
description="List of MQL match expressions comparing an indexed field",
)
post_filter_pipeline: list[dict] | None = Field(
post_filter_pipeline: list[dict[str, Any]] | None = Field(
default=None,
description="Pipeline of MongoDB aggregation stages to filter/process results after $vectorSearch.",
)
@@ -105,7 +105,7 @@ class MongoDBVectorSearchTool(BaseTool):
)
package_dependencies: list[str] = Field(default_factory=lambda: ["mongdb"])
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
if not MONGODB_AVAILABLE:
import click
@@ -120,6 +120,7 @@ class MongoDBVectorSearchTool(BaseTool):
else:
raise ImportError("You are missing the 'mongodb' crewai tool.")
self._openai_client: AzureOpenAI | Client
if "AZURE_OPENAI_ENDPOINT" in os.environ:
self._openai_client = AzureOpenAI()
elif "OPENAI_API_KEY" in os.environ:
@@ -132,7 +133,7 @@ class MongoDBVectorSearchTool(BaseTool):
from pymongo import MongoClient
from pymongo.driver_info import DriverInfo
self._client = MongoClient(
self._client: MongoClient[dict[str, Any]] = MongoClient(
self.connection_string,
driver=DriverInfo(name="CrewAI", version=version("crewai-tools")),
)
@@ -236,7 +237,7 @@ class MongoDBVectorSearchTool(BaseTool):
def _bulk_embed_and_insert_texts(
self,
texts: list[str],
metadatas: list[dict],
metadatas: list[dict[str, Any]],
ids: list[str],
) -> list[str]:
"""Bulk insert single batch of texts, embeddings, and ids."""
@@ -315,16 +316,18 @@ class MongoDBVectorSearchTool(BaseTool):
logger.error(f"Error: {e}")
return ""
def __del__(self):
def __del__(self) -> None:
"""Cleanup clients on deletion."""
try:
if hasattr(self, "_client") and self._client:
self._client.close()
client: Any = getattr(self, "_client", None)
if client is not None:
client.close()
except Exception as e:
logger.error(f"Error: {e}")
try:
if hasattr(self, "_openai_client") and self._openai_client:
self._openai_client.close()
openai_client: Any = getattr(self, "_openai_client", None)
if openai_client is not None:
openai_client.close()
except Exception as e:
logger.error(f"Error: {e}")

View File

@@ -31,11 +31,11 @@ class MultiOnTool(BaseTool):
def __init__(
self,
api_key: str | None = None,
**kwargs,
**kwargs: Any,
):
super().__init__(**kwargs)
try:
from multion.client import MultiOn # type: ignore
from multion.client import MultiOn
except ImportError:
import click
@@ -78,4 +78,4 @@ class MultiOnTool(BaseTool):
)
self.session_id = browse.session_id
return browse.message + "\n\n STATUS: " + browse.status
return str(browse.message) + "\n\n STATUS: " + str(browse.status)

View File

@@ -21,13 +21,13 @@ class MySQLSearchTool(RagTool):
args_schema: type[BaseModel] = MySQLSearchToolSchema
db_uri: str = Field(..., description="Mandatory database URI")
def __init__(self, table_name: str, **kwargs):
def __init__(self, table_name: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.add(table_name, data_type=DataType.MYSQL, metadata={"db_uri": self.db_uri})
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
self._generate_description()
def add(
def add( # type: ignore[override]
self,
table_name: str,
**kwargs: Any,

View File

@@ -27,8 +27,8 @@ class NL2SQLTool(BaseTool):
title="Database URI",
description="The URI of the database to connect to.",
)
tables: list = Field(default_factory=list)
columns: dict = Field(default_factory=dict)
tables: list[dict[str, Any]] = Field(default_factory=list)
columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
args_schema: type[BaseModel] = NL2SQLToolInput
def model_post_init(self, __context: Any) -> None:
@@ -37,8 +37,11 @@ class NL2SQLTool(BaseTool):
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
)
data = {}
tables = self._fetch_available_tables()
data: dict[str, list[dict[str, Any]] | str] = {}
result = self._fetch_available_tables()
if isinstance(result, str):
raise RuntimeError(f"Failed to fetch tables: {result}")
tables: list[dict[str, Any]] = result
for table in tables:
table_columns = self._fetch_all_available_columns(table["table_name"])
@@ -47,17 +50,19 @@ class NL2SQLTool(BaseTool):
self.tables = tables
self.columns = data
def _fetch_available_tables(self):
def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
return self.execute_sql(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
)
def _fetch_all_available_columns(self, table_name: str):
def _fetch_all_available_columns(
self, table_name: str
) -> list[dict[str, Any]] | str:
return self.execute_sql(
f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';" # noqa: S608
)
def _run(self, sql_query: str):
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
try:
data = self.execute_sql(sql_query)
except Exception as exc:
@@ -69,7 +74,7 @@ class NL2SQLTool(BaseTool):
return data
def execute_sql(self, sql_query: str) -> list | str:
def execute_sql(self, sql_query: str) -> list[dict[str, Any]] | str:
if not SQLALCHEMY_AVAILABLE:
raise ImportError(
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"

View File

@@ -4,6 +4,7 @@ This tool provides functionality for extracting text from images using supported
"""
import base64
from typing import Any
from crewai.llm import LLM
from crewai.tools.base_tool import BaseTool
@@ -43,7 +44,7 @@ class OCRTool(BaseTool):
llm: LLM = Field(default_factory=lambda: LLM(model="gpt-4o", temperature=0.7))
args_schema: type[BaseModel] = OCRToolSchema
def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
"""Execute the OCR operation on the provided image.
Args:
@@ -88,7 +89,7 @@ class OCRTool(BaseTool):
return self.llm.call(messages=messages)
@staticmethod
def _encode_image(image_path: str):
def _encode_image(image_path: str) -> str:
"""Encode an image file to base64 format.
Args:

View File

@@ -41,12 +41,12 @@ class OxylabsAmazonProductScraperConfig(BaseModel):
user_agent_type: str | None = Field(None, description="Device type and browser.")
render: str | None = Field(None, description="Enables JavaScript rendering.")
callback_url: str | None = Field(None, description="URL to your callback endpoint.")
context: list | None = Field(
context: list[Any] | None = Field(
None,
description="Additional advanced settings and controls for specialized requirements.",
)
parse: bool | None = Field(None, description="True will return structured data.")
parsing_instructions: dict | None = Field(
parsing_instructions: dict[str, Any] | None = Field(
None, description="Instructions for parsing the results."
)
@@ -71,7 +71,7 @@ class OxylabsAmazonProductScraperTool(BaseTool):
description: str = "Scrape Amazon product pages with Oxylabs Amazon Product Scraper"
args_schema: type[BaseModel] = OxylabsAmazonProductScraperArgs
oxylabs_api: RealtimeClient
oxylabs_api: Any
config: OxylabsAmazonProductScraperConfig
package_dependencies: list[str] = Field(default_factory=lambda: ["oxylabs"])
env_vars: list[EnvVar] = Field(
@@ -93,8 +93,8 @@ class OxylabsAmazonProductScraperTool(BaseTool):
self,
username: str | None = None,
password: str | None = None,
config: OxylabsAmazonProductScraperConfig | dict | None = None,
**kwargs,
config: OxylabsAmazonProductScraperConfig | dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
bits, _ = architecture()
sdk_type = (
@@ -164,4 +164,4 @@ class OxylabsAmazonProductScraperTool(BaseTool):
if isinstance(content, dict):
return json.dumps(content)
return content
return str(content)

View File

@@ -43,12 +43,12 @@ class OxylabsAmazonSearchScraperConfig(BaseModel):
user_agent_type: str | None = Field(None, description="Device type and browser.")
render: str | None = Field(None, description="Enables JavaScript rendering.")
callback_url: str | None = Field(None, description="URL to your callback endpoint.")
context: list | None = Field(
context: list[Any] | None = Field(
None,
description="Additional advanced settings and controls for specialized requirements.",
)
parse: bool | None = Field(None, description="True will return structured data.")
parsing_instructions: dict | None = Field(
parsing_instructions: dict[str, Any] | None = Field(
None, description="Instructions for parsing the results."
)
@@ -73,7 +73,7 @@ class OxylabsAmazonSearchScraperTool(BaseTool):
description: str = "Scrape Amazon search results with Oxylabs Amazon Search Scraper"
args_schema: type[BaseModel] = OxylabsAmazonSearchScraperArgs
oxylabs_api: RealtimeClient
oxylabs_api: Any
config: OxylabsAmazonSearchScraperConfig
package_dependencies: list[str] = Field(default_factory=lambda: ["oxylabs"])
env_vars: list[EnvVar] = Field(
@@ -95,9 +95,9 @@ class OxylabsAmazonSearchScraperTool(BaseTool):
self,
username: str | None = None,
password: str | None = None,
config: OxylabsAmazonSearchScraperConfig | dict | None = None,
**kwargs,
):
config: OxylabsAmazonSearchScraperConfig | dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
bits, _ = architecture()
sdk_type = (
f"oxylabs-crewai-sdk-python/"
@@ -166,4 +166,4 @@ class OxylabsAmazonSearchScraperTool(BaseTool):
if isinstance(content, dict):
return json.dumps(content)
return content
return str(content)

View File

@@ -46,12 +46,12 @@ class OxylabsGoogleSearchScraperConfig(BaseModel):
user_agent_type: str | None = Field(None, description="Device type and browser.")
render: str | None = Field(None, description="Enables JavaScript rendering.")
callback_url: str | None = Field(None, description="URL to your callback endpoint.")
context: list | None = Field(
context: list[Any] | None = Field(
None,
description="Additional advanced settings and controls for specialized requirements.",
)
parse: bool | None = Field(None, description="True will return structured data.")
parsing_instructions: dict | None = Field(
parsing_instructions: dict[str, Any] | None = Field(
None, description="Instructions for parsing the results."
)
@@ -76,7 +76,7 @@ class OxylabsGoogleSearchScraperTool(BaseTool):
description: str = "Scrape Google Search results with Oxylabs Google Search Scraper"
args_schema: type[BaseModel] = OxylabsGoogleSearchScraperArgs
oxylabs_api: RealtimeClient
oxylabs_api: Any
config: OxylabsGoogleSearchScraperConfig
package_dependencies: list[str] = Field(default_factory=lambda: ["oxylabs"])
env_vars: list[EnvVar] = Field(
@@ -98,9 +98,9 @@ class OxylabsGoogleSearchScraperTool(BaseTool):
self,
username: str | None = None,
password: str | None = None,
config: OxylabsGoogleSearchScraperConfig | dict | None = None,
**kwargs,
):
config: OxylabsGoogleSearchScraperConfig | dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
bits, _ = architecture()
sdk_type = (
f"oxylabs-crewai-sdk-python/"
@@ -158,7 +158,7 @@ class OxylabsGoogleSearchScraperTool(BaseTool):
)
return username, password
def _run(self, query: str, **kwargs) -> str:
def _run(self, query: str, **kwargs: Any) -> str:
response = self.oxylabs_api.google.scrape_search(
query,
**self.config.model_dump(exclude_none=True),
@@ -169,4 +169,4 @@ class OxylabsGoogleSearchScraperTool(BaseTool):
if isinstance(content, dict):
return json.dumps(content)
return content
return str(content)

View File

@@ -37,12 +37,12 @@ class OxylabsUniversalScraperConfig(BaseModel):
user_agent_type: str | None = Field(None, description="Device type and browser.")
render: str | None = Field(None, description="Enables JavaScript rendering.")
callback_url: str | None = Field(None, description="URL to your callback endpoint.")
context: list | None = Field(
context: list[Any] | None = Field(
None,
description="Additional advanced settings and controls for specialized requirements.",
)
parse: bool | None = Field(None, description="True will return structured data.")
parsing_instructions: dict | None = Field(
parsing_instructions: dict[str, Any] | None = Field(
None, description="Instructions for parsing the results."
)
@@ -67,7 +67,7 @@ class OxylabsUniversalScraperTool(BaseTool):
description: str = "Scrape any url with Oxylabs Universal Scraper"
args_schema: type[BaseModel] = OxylabsUniversalScraperArgs
oxylabs_api: RealtimeClient
oxylabs_api: Any
config: OxylabsUniversalScraperConfig
package_dependencies: list[str] = Field(default_factory=lambda: ["oxylabs"])
env_vars: list[EnvVar] = Field(
@@ -89,9 +89,9 @@ class OxylabsUniversalScraperTool(BaseTool):
self,
username: str | None = None,
password: str | None = None,
config: OxylabsUniversalScraperConfig | dict | None = None,
**kwargs,
):
config: OxylabsUniversalScraperConfig | dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
bits, _ = architecture()
sdk_type = (
f"oxylabs-crewai-sdk-python/"
@@ -160,4 +160,4 @@ class OxylabsUniversalScraperTool(BaseTool):
if isinstance(content, dict):
return json.dumps(content)
return content
return str(content)

View File

@@ -1,11 +1,12 @@
import random
from typing import Any
from crewai import Agent, Crew, Task
from patronus import ( # type: ignore[import-not-found,import-untyped]
from patronus import ( # type: ignore[import-untyped]
Client,
EvaluationResult,
)
from patronus_local_evaluator_tool import ( # type: ignore[import-not-found,import-untyped]
from patronus_local_evaluator_tool import ( # type: ignore[import-not-found]
PatronusLocalEvaluatorTool,
)
@@ -15,8 +16,8 @@ client = Client()
# Example of an evaluator that returns a random pass/fail result
@client.register_local_evaluator("random_evaluator")
def random_evaluator(**kwargs):
@client.register_local_evaluator("random_evaluator") # type: ignore[untyped-decorator]
def random_evaluator(**kwargs: Any) -> Any:
score = random.random() # noqa: S311
return EvaluationResult(
score_raw=score,

View File

@@ -34,7 +34,7 @@ class PatronusEvalTool(BaseTool):
stacklevel=2,
)
def _init_run(self):
def _init_run(self) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
evaluators_set = json.loads(
requests.get(
"https://api.patronus.ai/v1/evaluators",
@@ -136,8 +136,9 @@ class PatronusEvalTool(BaseTool):
"evaluators": evals,
}
api_key = os.getenv("PATRONUS_API_KEY", "")
headers = {
"X-API-KEY": os.getenv("PATRONUS_API_KEY"),
"X-API-KEY": api_key,
"accept": "application/json",
"content-type": "application/json",
}

View File

@@ -1,16 +1,13 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing import Any
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field
if TYPE_CHECKING:
from patronus import Client, EvaluationResult # type: ignore[import-untyped]
try:
import patronus # noqa: F401
import patronus # type: ignore[import-untyped] # noqa: F401
PYPATRONUS_AVAILABLE = True
except ImportError:
@@ -37,7 +34,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
name: str = "Patronus Local Evaluator Tool"
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
args_schema: type[BaseModel] = FixedLocalEvaluatorToolSchema
client: Client = None
client: Any = None
evaluator: str
evaluated_model_gold_answer: str
@@ -46,7 +43,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
def __init__(
self,
patronus_client: Client = None,
patronus_client: Any = None,
evaluator: str = "",
evaluated_model_gold_answer: str = "",
**kwargs: Any,
@@ -56,7 +53,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
self.evaluated_model_gold_answer = evaluated_model_gold_answer
self._initialize_patronus(patronus_client)
def _initialize_patronus(self, patronus_client: Client) -> None:
def _initialize_patronus(self, patronus_client: Any) -> None:
try:
if PYPATRONUS_AVAILABLE:
self.client = patronus_client
@@ -94,7 +91,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
evaluated_model_gold_answer = self.evaluated_model_gold_answer
evaluator = self.evaluator
result: EvaluationResult = self.client.evaluate(
result: Any = self.client.evaluate(
evaluator=evaluator,
evaluated_model_input=evaluated_model_input,
evaluated_model_output=evaluated_model_output,

View File

@@ -8,16 +8,16 @@ import requests
class FixedBaseToolSchema(BaseModel):
evaluated_model_input: dict = Field(
evaluated_model_input: dict[str, Any] = Field(
..., description="The agent's task description in simple text"
)
evaluated_model_output: dict = Field(
evaluated_model_output: dict[str, Any] = Field(
..., description="The agent's output of the task"
)
evaluated_model_retrieved_context: dict = Field(
evaluated_model_retrieved_context: dict[str, Any] = Field(
..., description="The agent's context"
)
evaluated_model_gold_answer: dict = Field(
evaluated_model_gold_answer: dict[str, Any] = Field(
..., description="The agent's gold answer only if available"
)
evaluators: list[dict[str, str]] = Field(
@@ -57,8 +57,9 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
evaluated_model_gold_answer = kwargs.get("evaluated_model_gold_answer")
evaluators = self.evaluators
api_key = os.getenv("PATRONUS_API_KEY", "")
headers = {
"X-API-KEY": os.getenv("PATRONUS_API_KEY"),
"X-API-KEY": api_key,
"accept": "application/json",
"content-type": "application/json",
}

View File

@@ -37,7 +37,7 @@ class PDFSearchTool(RagTool):
self._generate_description()
return self
def add(self, pdf: str) -> None:
def add(self, pdf: str) -> None: # type: ignore[override]
super().add(pdf, data_type=DataType.PDF_FILE)
def _run( # type: ignore[override]

View File

@@ -119,7 +119,7 @@ class QdrantVectorSearchTool(BaseTool):
)
)()
)
results = self.client.query_points(
results = self.client.query_points( # type: ignore[union-attr]
collection_name=self.qdrant_config.collection_name,
query=query_vector,
query_filter=search_filter,

View File

@@ -33,9 +33,9 @@ class ScrapeElementFromWebsiteTool(BaseTool):
description: str = "A tool that can be used to read a website content."
args_schema: type[BaseModel] = ScrapeElementFromWebsiteToolSchema
website_url: str | None = None
cookies: dict | None = None
cookies: dict[str, str] | None = None
css_element: str | None = None
headers: dict | None = Field(
headers: dict[str, str] | None = Field(
default_factory=lambda: {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
@@ -50,9 +50,9 @@ class ScrapeElementFromWebsiteTool(BaseTool):
def __init__(
self,
website_url: str | None = None,
cookies: dict | None = None,
cookies: dict[str, str] | None = None,
css_element: str | None = None,
**kwargs,
**kwargs: Any,
):
super().__init__(**kwargs)
if website_url is not None:
@@ -64,7 +64,7 @@ class ScrapeElementFromWebsiteTool(BaseTool):
self.args_schema = FixedScrapeElementFromWebsiteToolSchema
self._generate_description()
if cookies is not None:
self.cookies = {cookies["name"]: os.getenv(cookies["value"])}
self.cookies = {cookies["name"]: os.getenv(cookies["value"]) or ""}
def _run(
self,

View File

@@ -31,8 +31,8 @@ class ScrapeWebsiteTool(BaseTool):
description: str = "A tool that can be used to read a website content."
args_schema: type[BaseModel] = ScrapeWebsiteToolSchema
website_url: str | None = None
cookies: dict | None = None
headers: dict | None = Field(
cookies: dict[str, str] | None = None
headers: dict[str, str] | None = Field(
default_factory=lambda: {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
@@ -46,8 +46,8 @@ class ScrapeWebsiteTool(BaseTool):
def __init__(
self,
website_url: str | None = None,
cookies: dict | None = None,
**kwargs,
cookies: dict[str, str] | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
if not BEAUTIFULSOUP_AVAILABLE:
@@ -63,7 +63,7 @@ class ScrapeWebsiteTool(BaseTool):
self.args_schema = FixedScrapeWebsiteToolSchema
self._generate_description()
if cookies is not None:
self.cookies = {cookies["name"]: os.getenv(cookies["value"])}
self.cookies = {cookies["name"]: os.getenv(cookies["value"]) or ""}
def _run(
self,

View File

@@ -1,18 +1,13 @@
from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any
from typing import Any
from urllib.parse import urlparse
from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, field_validator
# Type checking import
if TYPE_CHECKING:
from scrapegraph_py import Client # type: ignore[import-untyped]
class ScrapegraphError(Exception):
"""Base exception for Scrapegraph-related errors."""
@@ -36,7 +31,7 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema):
@field_validator("website_url")
@classmethod
def validate_url(cls, v):
def validate_url(cls, v: str) -> str:
"""Validate URL format."""
try:
result = urlparse(v)
@@ -69,7 +64,7 @@ class ScrapegraphScrapeTool(BaseTool):
user_prompt: str | None = None
api_key: str | None = None
enable_logging: bool = False
_client: Client | None = None
_client: Any = None
package_dependencies: list[str] = Field(default_factory=lambda: ["scrapegraph-py"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
@@ -87,12 +82,12 @@ class ScrapegraphScrapeTool(BaseTool):
user_prompt: str | None = None,
api_key: str | None = None,
enable_logging: bool = False,
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
try:
from scrapegraph_py import Client # type: ignore[import-not-found]
from scrapegraph_py.logger import ( # type: ignore[import-not-found]
from scrapegraph_py import Client
from scrapegraph_py.logger import (
sgai_logger,
)
@@ -146,7 +141,7 @@ class ScrapegraphScrapeTool(BaseTool):
"Invalid URL format. URL must include scheme (http/https) and domain"
) from e
def _handle_api_response(self, response: dict) -> str:
def _handle_api_response(self, response: dict[str, Any]) -> str:
"""Handle and validate API response."""
if not response:
raise RuntimeError("Empty response from Scrapegraph API")
@@ -160,7 +155,7 @@ class ScrapegraphScrapeTool(BaseTool):
if "result" not in response:
raise RuntimeError("Invalid response format from Scrapegraph API")
return response["result"]
return str(response["result"])
def _run(
self,

View File

@@ -69,15 +69,16 @@ class ScrapflyScrapeWebsiteTool(BaseTool):
scrape_format: str = "markdown",
scrape_config: dict[str, Any] | None = None,
ignore_scrape_failures: bool | None = None,
):
from scrapfly import ScrapeApiResponse, ScrapeConfig
) -> str | None:
from scrapfly import ScrapeConfig
scrape_config = scrape_config if scrape_config is not None else {}
try:
response: ScrapeApiResponse = self.scrapfly.scrape( # type: ignore[union-attr]
response = self.scrapfly.scrape( # type: ignore[union-attr]
ScrapeConfig(url, format=scrape_format, **scrape_config)
)
return response.scrape_result["content"]
result: str = response.scrape_result["content"]
return result
except Exception as e:
if ignore_scrape_failures:
logger.error(f"Error fetching data from {url}, exception: {e}")

View File

@@ -25,7 +25,7 @@ class SeleniumScrapingToolSchema(FixedSeleniumScrapingToolSchema):
@field_validator("website_url")
@classmethod
def validate_website_url(cls, v):
def validate_website_url(cls, v: str) -> str:
if not v:
raise ValueError("Website URL cannot be empty")
@@ -54,7 +54,7 @@ class SeleniumScrapingTool(BaseTool):
args_schema: type[BaseModel] = SeleniumScrapingToolSchema
website_url: str | None = None
driver: Any | None = None
cookie: dict | None = None
cookie: dict[str, Any] | None = None
wait_time: int | None = 3
css_element: str | None = None
return_html: bool | None = False
@@ -66,17 +66,17 @@ class SeleniumScrapingTool(BaseTool):
def __init__(
self,
website_url: str | None = None,
cookie: dict | None = None,
cookie: dict[str, Any] | None = None,
css_element: str | None = None,
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
try:
from selenium import webdriver # type: ignore[import-not-found]
from selenium.webdriver.chrome.options import ( # type: ignore[import-not-found]
from selenium import webdriver
from selenium.webdriver.chrome.options import (
Options,
)
from selenium.webdriver.common.by import ( # type: ignore[import-not-found]
from selenium.webdriver.common.by import (
By,
)
except ImportError:
@@ -91,11 +91,11 @@ class SeleniumScrapingTool(BaseTool):
["uv", "pip", "install", "selenium", "webdriver-manager"], # noqa: S607
check=True,
)
from selenium import webdriver # type: ignore[import-not-found]
from selenium.webdriver.chrome.options import ( # type: ignore[import-not-found]
from selenium import webdriver
from selenium.webdriver.chrome.options import (
Options,
)
from selenium.webdriver.common.by import ( # type: ignore[import-not-found]
from selenium.webdriver.common.by import (
By,
)
else:
@@ -146,8 +146,10 @@ class SeleniumScrapingTool(BaseTool):
if self.driver is not None:
self.driver.close()
def _get_content(self, css_element, return_html):
content = []
def _get_content(
self, css_element: str | None, return_html: bool | None
) -> list[str]:
content: list[str] = []
if self._is_css_element_empty(css_element):
content.append(self._get_body_content(return_html))
@@ -156,20 +158,26 @@ class SeleniumScrapingTool(BaseTool):
return content
def _is_css_element_empty(self, css_element):
def _is_css_element_empty(self, css_element: str | None) -> bool:
return css_element is None or css_element.strip() == ""
def _get_body_content(self, return_html):
def _get_body_content(self, return_html: bool | None) -> str:
if self.driver is None or self._by is None:
raise RuntimeError("Driver not initialized. Call _run first.")
body_element = self.driver.find_element(self._by.TAG_NAME, "body")
return (
return str(
body_element.get_attribute("outerHTML")
if return_html
else body_element.text
)
def _get_elements_content(self, css_element, return_html):
elements_content = []
def _get_elements_content(
self, css_element: str | None, return_html: bool | None
) -> list[str]:
if self.driver is None or self._by is None:
raise RuntimeError("Driver not initialized. Call _run first.")
elements_content: list[str] = []
for element in self.driver.find_elements(self._by.CSS_SELECTOR, css_element):
elements_content.append( # noqa: PERF401
@@ -178,7 +186,9 @@ class SeleniumScrapingTool(BaseTool):
return elements_content
def _make_request(self, url, cookie, wait_time):
def _make_request(
self, url: str | None, cookie: dict[str, Any] | None, wait_time: int | None
) -> None:
if not url:
raise ValueError("URL cannot be empty")
@@ -186,13 +196,17 @@ class SeleniumScrapingTool(BaseTool):
if not re.match(r"^https?://", url):
raise ValueError("URL must start with http:// or https://")
if self.driver is None:
raise RuntimeError("Driver not initialized. Call _run first.")
sleep_time = wait_time or 0
self.driver.get(url)
time.sleep(wait_time)
time.sleep(sleep_time)
if cookie:
self.driver.add_cookie(cookie)
time.sleep(wait_time)
time.sleep(sleep_time)
self.driver.get(url)
time.sleep(wait_time)
time.sleep(sleep_time)
def close(self):
self.driver.close()
def close(self) -> None:
if self.driver is not None:
self.driver.close()

View File

@@ -22,11 +22,11 @@ class SerpApiBaseTool(BaseTool):
client: Any | None = None
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
try:
from serpapi import Client # type: ignore
from serpapi import Client
except ImportError:
import click
@@ -48,7 +48,9 @@ class SerpApiBaseTool(BaseTool):
)
self.client = Client(api_key=api_key)
def _omit_fields(self, data: dict | list, omit_patterns: list[str]) -> None:
def _omit_fields(
self, data: dict[str, Any] | list[Any], omit_patterns: list[str]
) -> None:
if isinstance(data, dict):
for field in list(data.keys()):
if any(re.compile(p).match(field) for p in omit_patterns):

View File

@@ -160,7 +160,7 @@ class SerperDevTool(BaseTool):
processed_results: list[OrganicResult] = []
for result in organic_results[: self.n_results]:
try:
result_data: OrganicResult = { # type: ignore[typeddict-item]
result_data: OrganicResult = {
"title": result["title"],
"link": result["link"],
"snippet": result.get("snippet", ""),
@@ -168,7 +168,7 @@ class SerperDevTool(BaseTool):
}
if "sitelinks" in result:
result_data["sitelinks"] = [ # type: ignore[typeddict-unknown-key]
result_data["sitelinks"] = [
{
"title": sitelink.get("title", ""),
"link": sitelink.get("link", ""),
@@ -180,7 +180,7 @@ class SerperDevTool(BaseTool):
except KeyError: # noqa: PERF203
logger.warning(f"Skipping malformed organic result: {result}")
continue
return processed_results # type: ignore[return-value]
return processed_results
def _process_people_also_ask(
self, paa_results: list[dict[str, Any]]
@@ -189,7 +189,7 @@ class SerperDevTool(BaseTool):
processed_results: list[PeopleAlsoAskResult] = []
for result in paa_results[: self.n_results]:
try:
result_data: PeopleAlsoAskResult = { # type: ignore[typeddict-item]
result_data: PeopleAlsoAskResult = {
"question": result["question"],
"snippet": result.get("snippet", ""),
"title": result.get("title", ""),
@@ -199,7 +199,7 @@ class SerperDevTool(BaseTool):
except KeyError: # noqa: PERF203
logger.warning(f"Skipping malformed PAA result: {result}")
continue
return processed_results # type: ignore[return-value]
return processed_results
def _process_related_searches(
self, related_results: list[dict[str, Any]]
@@ -208,11 +208,11 @@ class SerperDevTool(BaseTool):
processed_results: list[RelatedSearchResult] = []
for result in related_results[: self.n_results]:
try:
processed_results.append({"query": result["query"]}) # type: ignore[typeddict-item]
processed_results.append({"query": result["query"]})
except KeyError: # noqa: PERF203
logger.warning(f"Skipping malformed related search result: {result}")
continue
return processed_results # type: ignore[return-value]
return processed_results
def _process_news_results(
self, news_results: list[dict[str, Any]]
@@ -221,7 +221,7 @@ class SerperDevTool(BaseTool):
processed_results: list[NewsResult] = []
for result in news_results[: self.n_results]:
try:
result_data: NewsResult = { # type: ignore[typeddict-item]
result_data: NewsResult = {
"title": result["title"],
"link": result["link"],
"snippet": result.get("snippet", ""),
@@ -233,7 +233,7 @@ class SerperDevTool(BaseTool):
except KeyError: # noqa: PERF203
logger.warning(f"Skipping malformed news result: {result}")
continue
return processed_results # type: ignore[return-value]
return processed_results
def _make_api_request(self, search_query: str, search_type: str) -> dict[str, Any]:
"""Make API request to Serper."""
@@ -262,7 +262,7 @@ class SerperDevTool(BaseTool):
if not results:
logger.error("Empty response from Serper API")
raise ValueError("Empty response from Serper API")
return results
return dict(results)
except requests.exceptions.RequestException as e:
error_msg = f"Error making request to Serper API: {e}"
if response is not None and hasattr(response, "content"):

View File

@@ -53,7 +53,7 @@ class SerperScrapeWebsiteTool(BaseTool):
payload = json.dumps({"url": url, "includeMarkdown": include_markdown})
# Set headers
headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
headers = {"X-API-KEY": api_key or "", "Content-Type": "application/json"}
# Make the API request
response = requests.post(
@@ -69,7 +69,7 @@ class SerperScrapeWebsiteTool(BaseTool):
# Extract the scraped content
if "text" in result:
return result["text"]
return str(result["text"])
return f"Successfully scraped {url}, but no text content found in response: {response.text}"
return (
f"Error scraping {url}: HTTP {response.status_code} - {response.text}"

View File

@@ -1,4 +1,5 @@
import os
from typing import Any
from urllib.parse import urlencode
from crewai.tools import EnvVar
@@ -29,7 +30,7 @@ class SerplyJobSearchTool(RagTool):
proxy_location: (str): Where to get jobs, specifically for a specific country results.
- Currently only supports US
"""
headers: dict | None = Field(default_factory=dict)
headers: dict[str, str] | None = Field(default_factory=dict)
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
@@ -40,12 +41,12 @@ class SerplyJobSearchTool(RagTool):
]
)
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.headers = {
"X-API-KEY": os.environ["SERPLY_API_KEY"],
"User-Agent": "crew-tools",
"X-Proxy-Location": self.proxy_location,
"X-Proxy-Location": self.proxy_location or "US",
}
def _run( # type: ignore[override]

View File

@@ -21,7 +21,7 @@ class SerplyNewsSearchTool(BaseTool):
args_schema: type[BaseModel] = SerplyNewsSearchToolSchema
search_url: str = "https://api.serply.io/v1/news/"
proxy_location: str | None = "US"
headers: dict | None = Field(default_factory=dict)
headers: dict[str, str] | None = Field(default_factory=dict)
limit: int | None = 10
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
@@ -34,8 +34,8 @@ class SerplyNewsSearchTool(BaseTool):
)
def __init__(
self, limit: int | None = 10, proxy_location: str | None = "US", **kwargs
):
self, limit: int | None = 10, proxy_location: str | None = "US", **kwargs: Any
) -> None:
"""param: limit (int): The maximum number of results to return [10-100, defaults to 10]
proxy_location: (str): Where to get news, specifically for a specific country results.
['US', 'CA', 'IE', 'GB', 'FR', 'DE', 'SE', 'IN', 'JP', 'KR', 'SG', 'AU', 'BR'] (defaults to US).
@@ -46,7 +46,7 @@ class SerplyNewsSearchTool(BaseTool):
self.headers = {
"X-API-KEY": os.environ["SERPLY_API_KEY"],
"User-Agent": "crew-tools",
"X-Proxy-Location": proxy_location,
"X-Proxy-Location": proxy_location or "US",
}
def _run(

View File

@@ -25,7 +25,7 @@ class SerplyScholarSearchTool(BaseTool):
search_url: str = "https://api.serply.io/v1/scholar/"
hl: str | None = "us"
proxy_location: str | None = "US"
headers: dict | None = Field(default_factory=dict)
headers: dict[str, str] | None = Field(default_factory=dict)
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
@@ -36,7 +36,9 @@ class SerplyScholarSearchTool(BaseTool):
]
)
def __init__(self, hl: str = "us", proxy_location: str | None = "US", **kwargs):
def __init__(
self, hl: str = "us", proxy_location: str | None = "US", **kwargs: Any
) -> None:
"""param: hl (str): host Language code to display results in
(reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
proxy_location: (str): Specify the proxy location for the search, specifically for a specific country results.
@@ -48,7 +50,7 @@ class SerplyScholarSearchTool(BaseTool):
self.headers = {
"X-API-KEY": os.environ["SERPLY_API_KEY"],
"User-Agent": "crew-tools",
"X-Proxy-Location": proxy_location,
"X-Proxy-Location": proxy_location or "US",
}
def _run(

View File

@@ -24,8 +24,8 @@ class SerplyWebSearchTool(BaseTool):
limit: int | None = 10
device_type: str | None = "desktop"
proxy_location: str | None = "US"
query_payload: dict | None = Field(default_factory=dict)
headers: dict | None = Field(default_factory=dict)
query_payload: dict[str, Any] | None = Field(default_factory=dict)
headers: dict[str, str] | None = Field(default_factory=dict)
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
@@ -42,8 +42,8 @@ class SerplyWebSearchTool(BaseTool):
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
**kwargs,
):
**kwargs: Any,
) -> None:
"""param: query (str): The query to search for
param: hl (str): host Language code to display results in
(reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
try:
from singlestoredb import connect
from singlestoredb import connect # type: ignore[attr-defined]
from sqlalchemy.pool import QueuePool
SINGLSTORE_AVAILABLE = True
@@ -117,7 +117,7 @@ class SingleStoreSearchTool(BaseTool):
]
)
connection_args: dict = Field(default_factory=dict)
connection_args: dict[str, Any] = Field(default_factory=dict)
connection_pool: Any | None = None
def __init__(
@@ -169,8 +169,8 @@ class SingleStoreSearchTool(BaseTool):
pool_size: int | None = 5,
max_overflow: int | None = 10,
timeout: float | None = 30,
**kwargs,
):
**kwargs: Any,
) -> None:
"""Initialize the SingleStore search tool.
Args:
@@ -274,7 +274,7 @@ class SingleStoreSearchTool(BaseTool):
# Initialize connection pool for efficient connection management
self.connection_pool = QueuePool(
creator=self._create_connection, # type: ignore[arg-type]
creator=self._create_connection,
pool_size=pool_size or 5,
max_overflow=max_overflow or 10,
timeout=timeout or 30.0,

View File

@@ -12,10 +12,10 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr
if TYPE_CHECKING:
# Import types for type checking only
from snowflake.connector.connection import ( # type: ignore[import-not-found]
from snowflake.connector.connection import (
SnowflakeConnection,
)
from snowflake.connector.errors import ( # type: ignore[import-not-found]
from snowflake.connector.errors import (
DatabaseError,
OperationalError,
)
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
try:
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
import snowflake.connector # type: ignore[import-not-found]
import snowflake.connector
SNOWFLAKE_AVAILABLE = True
except ImportError:
@@ -60,7 +60,7 @@ class SnowflakeConfig(BaseModel):
def has_auth(self) -> bool:
return bool(self.password or self.private_key_path)
def model_post_init(self, *args, **kwargs):
def model_post_init(self, *args: Any, **kwargs: Any) -> None:
if not self.has_auth:
raise ValueError("Either password or private_key_path must be provided")
@@ -115,7 +115,7 @@ class SnowflakeSearchTool(BaseTool):
]
)
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
"""Initialize SnowflakeSearchTool."""
super().__init__(**data)
self._initialize_snowflake()
@@ -268,7 +268,7 @@ class SnowflakeSearchTool(BaseTool):
logger.error(f"Error executing query: {e!s}")
raise
def __del__(self):
def __del__(self) -> None:
"""Cleanup connections on deletion."""
try:
if self._connection_pool:

View File

@@ -72,8 +72,8 @@ class SpiderTool(BaseTool):
website_url: str | None = None,
custom_params: dict[str, Any] | None = None,
log_failures: bool = True,
**kwargs,
):
**kwargs: Any,
) -> None:
"""Initialize SpiderTool for web scraping and crawling.
Args:
@@ -96,7 +96,7 @@ class SpiderTool(BaseTool):
self.custom_params = custom_params
try:
from spider import Spider # type: ignore
from spider import Spider
except ImportError:
import click
@@ -191,7 +191,8 @@ class SpiderTool(BaseTool):
action = (
self.spider.scrape_url if mode == "scrape" else self.spider.crawl_url
)
return action(url=url, params=params)
result: str | None = action(url=url, params=params)
return result
except ValueError as ve:
if self.log_failures:

View File

@@ -20,7 +20,7 @@ import os
from crewai import Agent, Crew, Process, Task
from crewai.utilities.printer import Printer
from dotenv import load_dotenv
from stagehand.schemas import AvailableModel # type: ignore[import-untyped]
from stagehand.schemas import AvailableModel # type: ignore[import-not-found]
from crewai_tools import StagehandTool

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import contextvars
import json
@@ -13,13 +15,13 @@ from pydantic import BaseModel, Field
_HAS_STAGEHAND = False
try:
from stagehand import ( # type: ignore[import-untyped]
from stagehand import ( # type: ignore[attr-defined]
Stagehand,
StagehandConfig,
StagehandPage,
configure_logging,
)
from stagehand.schemas import ( # type: ignore[import-untyped]
from stagehand.schemas import ( # type: ignore[import-not-found]
ActOptions,
AvailableModel,
ExtractOptions,
@@ -28,8 +30,7 @@ try:
_HAS_STAGEHAND = True
except ImportError:
# Define type stubs for when stagehand is not installed
Stagehand = Any
Stagehand = Any # type: ignore[assignment, misc]
StagehandPage = Any
StagehandConfig = Any
ActOptions = Any
@@ -37,7 +38,11 @@ except ImportError:
ObserveOptions = Any
# Mock configure_logging function
def configure_logging(level=None, remove_logger_name=None, quiet_dependencies=None):
def configure_logging(
level: str | None = None,
remove_logger_name: bool | None = None,
quiet_dependencies: bool | None = None,
) -> None:
pass
# Define only what's needed for class defaults
@@ -57,7 +62,7 @@ class StagehandResult(BaseModel):
success: bool = Field(
..., description="Whether the operation completed successfully"
)
data: str | dict | list = Field(
data: str | dict[str, Any] | list[Any] = Field(
..., description="The result data from the operation"
)
error: str | None = Field(
@@ -160,7 +165,7 @@ class StagehandTool(BaseTool):
api_key: str | None = None
project_id: str | None = None
model_api_key: str | None = None
model_name: AvailableModel | None = AvailableModel.CLAUDE_3_7_SONNET_LATEST
model_name: Any = AvailableModel.CLAUDE_3_7_SONNET_LATEST
server_url: str | None = "https://api.stagehand.browserbase.com/v1"
headless: bool = False
dom_settle_timeout_ms: int = 3000
@@ -173,8 +178,8 @@ class StagehandTool(BaseTool):
use_simplified_dom: bool = True
# Instance variables
_stagehand: Stagehand | None = None
_page: StagehandPage | None = None
_stagehand: Any = None
_page: Any = None
_session_id: str | None = None
_testing: bool = False
@@ -192,8 +197,8 @@ class StagehandTool(BaseTool):
wait_for_captcha_solves: bool | None = None,
verbose: int | None = None,
_testing: bool = False,
**kwargs,
):
**kwargs: Any,
) -> None:
# Set testing flag early so that other init logic can rely on it
self._testing = _testing
super().__init__(**kwargs)
@@ -235,7 +240,7 @@ class StagehandTool(BaseTool):
self._check_required_credentials()
def _check_required_credentials(self):
def _check_required_credentials(self) -> None:
"""Validate that required credentials are present."""
if not self._testing and not _HAS_STAGEHAND:
raise ImportError(
@@ -249,14 +254,14 @@ class StagehandTool(BaseTool):
"project_id is required (or set BROWSERBASE_PROJECT_ID in env)."
)
def __del__(self):
def __del__(self) -> None:
"""Ensure cleanup on deletion."""
try:
self.close()
except Exception: # noqa: S110
pass
def _get_model_api_key(self):
def _get_model_api_key(self) -> str | None:
"""Get the appropriate API key based on the model being used."""
# Check model type and get appropriate key
model_str = str(self.model_name)
@@ -273,29 +278,29 @@ class StagehandTool(BaseTool):
or os.getenv("ANTHROPIC_API_KEY")
)
async def _setup_stagehand(self, session_id: str | None = None):
async def _setup_stagehand(self, session_id: str | None = None) -> tuple[Any, Any]:
"""Initialize Stagehand if not already set up."""
# If we're in testing mode, return mock objects
if self._testing:
if not self._stagehand:
# Create mock objects for testing
class MockPage:
async def act(self, options):
async def act(self, options: Any) -> Any:
mock_result = type("MockResult", (), {})()
mock_result.model_dump = lambda: {
"message": "Action completed successfully"
}
return mock_result
async def goto(self, url):
async def goto(self, url: str) -> None:
return None
async def extract(self, options):
async def extract(self, options: Any) -> Any:
mock_result = type("MockResult", (), {})()
mock_result.model_dump = lambda: {"data": "Extracted content"}
return mock_result
async def observe(self, options):
async def observe(self, options: Any) -> list[Any]:
mock_result1 = type(
"MockResult",
(),
@@ -303,18 +308,18 @@ class StagehandTool(BaseTool):
)()
return [mock_result1]
async def wait_for_load_state(self, state):
async def wait_for_load_state(self, state: str) -> None:
return None
class MockStagehand:
def __init__(self):
def __init__(self) -> None:
self.page = MockPage()
self.session_id = "test-session-id"
async def init(self):
async def init(self) -> None:
return None
async def close(self):
async def close(self) -> None:
return None
self._stagehand = MockStagehand()
@@ -352,7 +357,7 @@ class StagehandTool(BaseTool):
)
# Initialize Stagehand with config
self._stagehand = Stagehand(config=config)
self._stagehand = Stagehand(config=config) # type: ignore[call-arg]
# Initialize the Stagehand instance
await self._stagehand.init()
@@ -404,7 +409,7 @@ class StagehandTool(BaseTool):
instruction: str | None = None,
url: str | None = None,
command_type: str = "act",
):
) -> StagehandResult:
"""Override _async_run with improved atomic action handling."""
# Handle missing instruction based on command type
if not instruction:
@@ -419,7 +424,7 @@ class StagehandTool(BaseTool):
# For testing mode, return mock result directly without calling parent
if self._testing:
mock_data = {
mock_data: dict[str, str] = {
"message": f"Mock {command_type} completed successfully",
"instruction": instruction,
}
@@ -436,7 +441,7 @@ class StagehandTool(BaseTool):
# Get the API key to pass to model operations
model_api_key = self._get_model_api_key()
model_client_options = {"apiKey": model_api_key}
model_client_options: dict[str, Any] = {"apiKey": model_api_key}
# Always navigate first if URL is provided and we're doing actions
if url and command_type.lower() == "act":
@@ -452,7 +457,7 @@ class StagehandTool(BaseTool):
steps = self._extract_steps(instruction)
self._logger.info(f"Extracted {len(steps)} steps: {steps}")
results = []
results: list[dict[str, Any]] = []
for i, step in enumerate(steps):
self._logger.info(f"Executing step {i + 1}/{len(steps)}: {step}")
@@ -559,16 +564,16 @@ class StagehandTool(BaseTool):
modelClientOptions=model_client_options, # Add API key here
)
results = await page.observe(observe_options)
observe_results = await page.observe(observe_options)
# Format the observation results
formatted_results = []
for i, result in enumerate(results):
formatted_results: list[dict[str, Any]] = []
for i, obs_result in enumerate(observe_results):
formatted_results.append(
{
"index": i + 1,
"description": result.description,
"method": result.method,
"description": obs_result.description,
"method": obs_result.method,
}
)
@@ -586,7 +591,12 @@ class StagehandTool(BaseTool):
self._logger.error(f"Operation failed: {error_msg}")
return self._format_result(False, {}, error_msg)
def _format_result(self, success, data, error=None):
def _format_result(
self,
success: bool,
data: str | dict[str, Any] | list[Any],
error: str | None = None,
) -> StagehandResult:
"""Helper to format results consistently."""
return StagehandResult(success=success, data=data, error=error)
@@ -653,10 +663,14 @@ class StagehandTool(BaseTool):
f"Step {i + 1}: {step.get('message', 'Completed')}"
)
return "\n".join(step_messages)
return f"Action result: {result.data.get('message', 'Completed')}"
if isinstance(result.data, dict):
return (
f"Action result: {result.data.get('message', 'Completed')}"
)
return f"Action result: {result.data}"
if command_type.lower() == "extract":
return f"Extracted data: {json.dumps(result.data, indent=2)}"
if command_type.lower() == "observe":
if command_type.lower() == "observe" and isinstance(result.data, list):
formatted_results = []
for element in result.data:
formatted_results.append(
@@ -680,7 +694,7 @@ class StagehandTool(BaseTool):
return str(result.data)
return f"Error: {result.error}"
async def _async_close(self):
async def _async_close(self) -> None:
"""Asynchronously clean up Stagehand resources."""
# Skip for test mode
if self._testing:
@@ -694,7 +708,7 @@ class StagehandTool(BaseTool):
if self._page:
self._page = None
def close(self):
def close(self) -> None:
"""Clean up Stagehand resources."""
# Skip actual closing for testing mode
if self._testing:
@@ -704,9 +718,9 @@ class StagehandTool(BaseTool):
if self._stagehand:
try:
# Handle both synchronous and asynchronous cases
if hasattr(self._stagehand, "close"):
if asyncio.iscoroutinefunction(self._stagehand.close):
close_method: Any = getattr(self._stagehand, "close", None)
if close_method is not None:
if asyncio.iscoroutinefunction(close_method):
try:
loop = asyncio.get_event_loop()
if loop.is_running():
@@ -725,8 +739,7 @@ class StagehandTool(BaseTool):
except RuntimeError:
asyncio.run(self._async_close())
else:
# Handle non-async close method (for mocks)
self._stagehand.close()
close_method()
except Exception: # noqa: S110
# Log but don't raise - we're cleaning up
pass
@@ -736,10 +749,15 @@ class StagehandTool(BaseTool):
if self._page:
self._page = None
def __enter__(self):
def __enter__(self) -> StagehandTool:
"""Enter the context manager."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> None:
"""Exit the context manager and clean up resources."""
self.close()

Some files were not shown because too many files have changed in this diff Show More