fix: resolve all strict mypy errors across crewai-tools package
@@ -136,7 +136,7 @@ class EnterpriseActionTool(BaseTool):
             enum_values = schema["enum"]
             if not enum_values:
                 return self._map_json_type_to_python(json_type)
-            return Literal[tuple(enum_values)]  # type: ignore[return-value]
+            return Literal[tuple(enum_values)]

         if json_type == "array":
             items_schema = schema.get("items", {"type": "string"})
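A standalone sketch of the Literal-from-enum pattern in this hunk (the helper name and the str fallback are assumptions, not the package's code): typing.Literal accepts a tuple of runtime values and expands it to Literal[v1, v2, ...], but mypy cannot check parameters that only exist at runtime, which is why the surrounding method needs either a broad return annotation or an ignore.

from typing import Any, Literal

def enum_to_type(schema: dict[str, Any]) -> Any:
    # Hypothetical helper: map a JSON-schema enum to a typing construct.
    enum_values = schema.get("enum", [])
    if not enum_values:
        return str  # assumed fallback for the no-enum case
    # Literal[("red", "green")] == Literal["red", "green"] at runtime.
    return Literal[tuple(enum_values)]

print(enum_to_type({"enum": ["red", "green"]}))  # typing.Literal['red', 'green']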
@@ -155,7 +155,7 @@ class EnterpriseActionTool(BaseTool):
         full_model_name = f"{self._base_name}{model_name}"

         if full_model_name in self._model_registry:
-            return self._model_registry[full_model_name]
+            return cast(type[Any], self._model_registry[full_model_name])

         properties = schema.get("properties", {})
         required_fields = schema.get("required", [])
@@ -178,19 +178,19 @@ class EnterpriseActionTool(BaseTool):
             field_definitions[prop_name] = self._create_field_definition(
                 prop_type,
                 is_required,
-                prop_desc,  # type: ignore[arg-type]
+                prop_desc,
             )

         try:
             nested_model = create_model(full_model_name, **field_definitions)  # type: ignore[call-overload]
             self._model_registry[full_model_name] = nested_model
-            return nested_model
+            return cast(type[Any], nested_model)
         except Exception:
             return dict

     def _create_field_definition(
         self, field_type: type[Any] | _SpecialForm, is_required: bool, description: str
-    ) -> tuple:
+    ) -> tuple[type[Any] | _SpecialForm, Any]:
         """Create Pydantic field definition based on type and requirement."""
         if is_required:
             return (field_type, Field(description=description))
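For context, pydantic's create_model takes (type, FieldInfo) tuples exactly like the ones _create_field_definition returns; a minimal, self-contained sketch (the field names here are invented):

from typing import Any

from pydantic import BaseModel, Field, create_model

fields: dict[str, Any] = {
    "name": (str, Field(description="Required name")),
    "age": (int | None, Field(default=None, description="Optional age")),
}
# mypy cannot match this dynamic call against create_model's overloads,
# hence the call-overload ignore in the hunk above.
DynamicModel: type[BaseModel] = create_model("DynamicModel", **fields)
print(DynamicModel(name="Ada"))  # name='Ada' age=None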
@@ -232,7 +232,7 @@ class EnterpriseActionTool(BaseTool):
             return any(t.get("type") == "null" for t in schema["anyOf"])
         return schema.get("type") == "null"

-    def _run(self, **kwargs) -> str:
+    def _run(self, **kwargs: Any) -> str:
         """Execute the specific enterprise action with validated parameters."""
         try:
             cleaned_kwargs = {}
@@ -280,8 +280,8 @@ class EnterpriseActionKitToolAdapter:
     ):
         """Initialize the adapter with an enterprise action token."""
         self._set_enterprise_action_token(enterprise_action_token)
-        self._actions_schema = {}  # type: ignore[var-annotated]
-        self._tools = None
+        self._actions_schema: dict[str, Any] = {}
+        self._tools: list[BaseTool] | None = None
         self.enterprise_api_base_url = (
             enterprise_api_base_url or get_enterprise_api_base_url()
         )
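The two annotated assignments above are the standard fix for mypy's var-annotated error: an empty container reveals nothing about its element types, so the annotation must be declared at first assignment. A reduced sketch (the class is a stand-in, not the real adapter):

from typing import Any

class AdapterSketch:
    def __init__(self) -> None:
        self._actions_schema: dict[str, Any] = {}  # was: {} plus an ignore
        self._tools: list[str] | None = None       # element type declared up front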
@@ -293,7 +293,7 @@ class EnterpriseActionKitToolAdapter:
             self._create_tools()
         return self._tools or []

-    def _fetch_actions(self):
+    def _fetch_actions(self) -> None:
         """Fetch available actions from the API."""
         try:
             actions_url = f"{self.enterprise_api_base_url}/actions"
@@ -379,9 +379,9 @@ class EnterpriseActionKitToolAdapter:

         return descriptions

-    def _create_tools(self):
+    def _create_tools(self) -> None:
         """Create BaseTool instances for each action."""
-        tools = []
+        tools: list[BaseTool] = []

         for action_name, action_schema in self._actions_schema.items():
             function_details = action_schema.get("function", {})
@@ -403,7 +403,7 @@ class EnterpriseActionKitToolAdapter:
                 description=full_description,
                 action_name=action_name,
                 action_schema=action_schema,
-                enterprise_action_token=self.enterprise_action_token,
+                enterprise_action_token=self.enterprise_action_token or "",
                 enterprise_api_base_url=self.enterprise_api_base_url,
             )

@@ -411,7 +411,7 @@ class EnterpriseActionKitToolAdapter:

         self._tools = tools

-    def _set_enterprise_action_token(self, enterprise_action_token: str | None):
+    def _set_enterprise_action_token(self, enterprise_action_token: str | None) -> None:
         if enterprise_action_token and not enterprise_action_token.startswith("PK_"):
             warnings.warn(
                 "Legacy token detected, please consider using the new Enterprise Action Auth token. Check out our docs for more information https://docs.crewai.com/en/enterprise/features/integrations.",
@@ -423,10 +423,15 @@ class EnterpriseActionKitToolAdapter:
             "CREWAI_ENTERPRISE_TOOLS_TOKEN"
         )

-        self.enterprise_action_token = token
+        self.enterprise_action_token: str | None = token

-    def __enter__(self):
+    def __enter__(self) -> list[BaseTool]:
         return self.tools()

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: Any,
+    ) -> None:
         pass
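The typed __exit__ above follows the standard context-manager protocol; in isolation it looks like this (a sketch with an invented class; types.TracebackType is the more precise alternative to the Any used for exc_tb in the diff):

from types import TracebackType

class ManagedResource:
    def __enter__(self) -> "ManagedResource":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        pass  # returning None (falsy) lets exceptions propagate

with ManagedResource() as resource:
    print(resource)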
@@ -5,20 +5,19 @@ from typing import Any

 from crewai.utilities.lock_store import lock as store_lock
 from lancedb import (  # type: ignore[import-untyped]
     DBConnection as LanceDBConnection,
     connect as lancedb_connect,
 )
 from lancedb.table import Table as LanceDBTable  # type: ignore[import-untyped]
 from openai import Client as OpenAIClient
 from pydantic import Field, PrivateAttr

 from crewai_tools.tools.rag.rag_tool import Adapter


-def _default_embedding_function():
+def _default_embedding_function() -> Callable[[list[str]], list[list[float]]]:
     """Create a default embedding function using OpenAI."""
     client = OpenAIClient()

-    def _embedding_function(input):
+    def _embedding_function(input: list[str]) -> list[list[float]]:
         rs = client.embeddings.create(input=input, model="text-embedding-ada-002")
         return [record.embedding for record in rs.data]
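Reduced to a standalone sketch, the typed factory pattern above looks like this (assumes the openai package and an OPENAI_API_KEY in the environment; the model name is the one from the hunk). Annotating the inner closure is what lets the outer function promise a precise Callable[[list[str]], list[list[float]]] without casts.

from collections.abc import Callable

from openai import OpenAI

def make_embedder() -> Callable[[list[str]], list[list[float]]]:
    client = OpenAI()

    def embed(texts: list[str]) -> list[list[float]]:
        rs = client.embeddings.create(input=texts, model="text-embedding-ada-002")
        return [record.embedding for record in rs.data]

    return embed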
@@ -28,13 +27,15 @@ def _default_embedding_function():
 class LanceDBAdapter(Adapter):
     uri: str | Path
     table_name: str
-    embedding_function: Callable = Field(default_factory=_default_embedding_function)
+    embedding_function: Callable[[list[str]], list[list[float]]] = Field(
+        default_factory=_default_embedding_function
+    )
     top_k: int = 3
     vector_column_name: str = "vector"
     text_column_name: str = "text"

-    _db: LanceDBConnection = PrivateAttr()
-    _table: LanceDBTable = PrivateAttr()
+    _db: Any = PrivateAttr()
+    _table: Any = PrivateAttr()
     _lock_name: str = PrivateAttr(default="")

     def model_post_init(self, __context: Any) -> None:

@@ -12,7 +12,7 @@ class RAGAdapter(Adapter):
         embedding_model: str = "text-embedding-3-small",
         top_k: int = 5,
         embedding_api_key: str | None = None,
-        **embedding_kwargs,
+        **embedding_kwargs: Any,
     ):
         super().__init__()


@@ -9,7 +9,7 @@ from crewai.tools import BaseTool
 T = TypeVar("T", bound=BaseTool)


-class ToolCollection(list, Generic[T]):
+class ToolCollection(list[T], Generic[T]):
     """A collection of tools that can be accessed by index or name.

     This class extends the built-in list to provide dictionary-like
@@ -34,7 +34,8 @@ class ToolCollection(list, Generic[T]):
     def __getitem__(self, key: int | str) -> T:  # type: ignore[override]
         if isinstance(key, str):
             return self._name_cache[key.lower()]
-        return super().__getitem__(key)
+        result: T = super().__getitem__(key)
+        return result

     def append(self, tool: T) -> None:
         super().append(tool)
@@ -54,7 +55,7 @@ class ToolCollection(list, Generic[T]):
         del self._name_cache[tool.name.lower()]

     def pop(self, index: int = -1) -> T:  # type: ignore[override]
-        tool = super().pop(index)
+        tool: T = super().pop(index)
         if tool.name.lower() in self._name_cache:
             del self._name_cache[tool.name.lower()]
         return tool
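Subclassing list[T] instead of bare list is what makes the inherited methods come back typed; with a plain list base, super().__getitem__ and super().pop return Any under strict mypy. A self-contained sketch of the same idea (Named is an invented stand-in for BaseTool):

from typing import Generic, TypeVar

class Named:
    def __init__(self, name: str) -> None:
        self.name = name

T = TypeVar("T", bound=Named)

class Collection(list[T], Generic[T]):
    """List with dict-style access by item name."""

    def __getitem__(self, key: int | str) -> T:  # type: ignore[override]
        if isinstance(key, str):
            return next(item for item in self if item.name == key)
        return super().__getitem__(key)  # already T thanks to the list[T] base

tools = Collection([Named("search"), Named("scrape")])
print(tools["search"].name, tools[1].name)  # search scrape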
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Final, Literal
+from typing import Any, Final, Literal

 from crewai.tools import BaseTool
 from pydantic import Field, create_model
@@ -22,7 +22,7 @@ class ZapierActionTool(BaseTool):
     action_id: str = Field(description="Zapier action ID")
     api_key: str = Field(description="Zapier API key")

-    def _run(self, **kwargs) -> str:
+    def _run(self, **kwargs: Any) -> Any:
         """Execute the Zapier action."""
         headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

@@ -64,9 +64,9 @@ class ZapierActionsAdapter:
         logger.error("Zapier Actions API key is required")
         raise ValueError("Zapier Actions API key is required")

-    def get_zapier_actions(self):
+    def get_zapier_actions(self) -> Any:
         headers = {
-            "x-api-key": self.api_key,
+            "x-api-key": self.api_key or "",
         }
         response = requests.request(
             "GET",

@@ -2,6 +2,7 @@ from datetime import datetime, timezone
 import json
 import os
 import time
+from typing import Any

 from crewai.tools import BaseTool
 from dotenv import load_dotenv
@@ -42,17 +43,17 @@ class BedrockInvokeAgentTool(BaseTool):
         enable_trace: bool = False,
         end_session: bool = False,
         description: str | None = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         """Initialize the BedrockInvokeAgentTool with agent configuration.

         Args:
-            agent_id (str): The unique identifier of the Bedrock agent
-            agent_alias_id (str): The unique identifier of the agent alias
-            session_id (str): The unique identifier of the session
-            enable_trace (bool): Whether to enable trace for the agent invocation
-            end_session (bool): Whether to end the session with the agent
-            description (Optional[str]): Custom description for the tool
+            agent_id: The unique identifier of the Bedrock agent.
+            agent_alias_id: The unique identifier of the agent alias.
+            session_id: The unique identifier of the session.
+            enable_trace: Whether to enable trace for the agent invocation.
+            end_session: Whether to end the session with the agent.
+            description: Custom description for the tool.
         """
         super().__init__(**kwargs)

@@ -72,7 +73,7 @@ class BedrockInvokeAgentTool(BaseTool):
         # Validate parameters
         self._validate_parameters()

-    def _validate_parameters(self):
+    def _validate_parameters(self) -> None:
         """Validate the parameters according to AWS API requirements."""
         try:
             # Validate agent_id

@@ -3,7 +3,7 @@
 import asyncio
 import json
 import logging
-from typing import Any
+from typing import Any, cast
 from urllib.parse import urlparse

 from crewai.tools import BaseTool
@@ -82,7 +82,7 @@ class CurrentWebPageToolInput(BaseModel):
 class BrowserBaseTool(BaseTool):
     """Base class for browser tools."""

-    def __init__(self, session_manager: BrowserSessionManager):  # type: ignore[call-arg]
+    def __init__(self, session_manager: BrowserSessionManager) -> None:
         """Initialize with a session manager."""
         super().__init__()  # type: ignore[call-arg]
         self._session_manager = session_manager
@@ -90,16 +90,16 @@ class BrowserBaseTool(BaseTool):
         if self._is_in_asyncio_loop() and hasattr(self, "_arun"):
             self._original_run = self._run

             # Override _run to use _arun when in an asyncio loop
-            def patched_run(*args, **kwargs):
+            def patched_run(*args: Any, **kwargs: Any) -> str:
                 try:
                     import nest_asyncio  # type: ignore[import-untyped]

                     loop = asyncio.get_event_loop()
                     nest_asyncio.apply(loop)
-                    return asyncio.get_event_loop().run_until_complete(
+                    result: str = asyncio.get_event_loop().run_until_complete(
                         self._arun(*args, **kwargs)
                     )
+                    return result
                 except Exception as e:
                     return f"Error in patched _run: {e!s}"
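The patched_run pattern above, as a standalone sketch: it wraps an async callable so it can be invoked synchronously even while an event loop is already running, relying on the third-party nest_asyncio package to make the loop re-entrant (make_sync is an invented name; the error handling mirrors the hunk).

import asyncio
from collections.abc import Awaitable, Callable
from typing import Any

def make_sync(async_fn: Callable[..., Awaitable[str]]) -> Callable[..., str]:
    def wrapper(*args: Any, **kwargs: Any) -> str:
        try:
            import nest_asyncio  # pip install nest_asyncio

            loop = asyncio.get_event_loop()
            nest_asyncio.apply(loop)  # allow run_until_complete inside a running loop
            result: str = loop.run_until_complete(async_fn(*args, **kwargs))
            return result
        except Exception as e:
            return f"Error in patched run: {e!s}"

    return wrapper

async def greet(name: str) -> str:
    return f"hello {name}"

print(make_sync(greet)("world"))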
@@ -132,7 +132,7 @@ class NavigateTool(BrowserBaseTool):
     description: str = "Navigate a browser to the specified URL"
     args_schema: type[BaseModel] = NavigateToolInput

-    def _run(self, url: str, thread_id: str = "default", **kwargs) -> str:
+    def _run(self, url: str, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the sync tool."""
         try:
             # Get page for this thread
@@ -150,7 +150,7 @@ class NavigateTool(BrowserBaseTool):
         except Exception as e:
             return f"Error navigating to {url}: {e!s}"

-    async def _arun(self, url: str, thread_id: str = "default", **kwargs) -> str:
+    async def _arun(self, url: str, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the async tool."""
         try:
             # Get page for this thread
@@ -188,7 +188,7 @@ class ClickTool(BrowserBaseTool):
             return selector
         return f"{selector} >> visible=1"

-    def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str:
+    def _run(self, selector: str, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the sync tool."""
         try:
             # Get the current page
@@ -213,7 +213,9 @@ class ClickTool(BrowserBaseTool):
         except Exception as e:
             return f"Error clicking on element: {e!s}"

-    async def _arun(self, selector: str, thread_id: str = "default", **kwargs) -> str:
+    async def _arun(
+        self, selector: str, thread_id: str = "default", **kwargs: Any
+    ) -> str:
         """Use the async tool."""
         try:
             # Get the current page
@@ -246,7 +248,7 @@ class NavigateBackTool(BrowserBaseTool):
     description: str = "Navigate back to the previous page"
     args_schema: type[BaseModel] = NavigateBackToolInput

-    def _run(self, thread_id: str = "default", **kwargs) -> str:
+    def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the sync tool."""
         try:
             # Get the current page
@@ -261,7 +263,7 @@ class NavigateBackTool(BrowserBaseTool):
         except Exception as e:
             return f"Error navigating back: {e!s}"

-    async def _arun(self, thread_id: str = "default", **kwargs) -> str:
+    async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the async tool."""
         try:
             # Get the current page
@@ -284,7 +286,7 @@ class ExtractTextTool(BrowserBaseTool):
     description: str = "Extract all the text on the current webpage"
     args_schema: type[BaseModel] = ExtractTextToolInput

-    def _run(self, thread_id: str = "default", **kwargs) -> str:
+    def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the sync tool."""
         try:
             # Import BeautifulSoup
@@ -306,7 +308,7 @@ class ExtractTextTool(BrowserBaseTool):
         except Exception as e:
             return f"Error extracting text: {e!s}"

-    async def _arun(self, thread_id: str = "default", **kwargs) -> str:
+    async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the async tool."""
         try:
             # Import BeautifulSoup
@@ -336,12 +338,12 @@ class ExtractHyperlinksTool(BrowserBaseTool):
     description: str = "Extract all hyperlinks on the current webpage"
     args_schema: type[BaseModel] = ExtractHyperlinksToolInput

-    def _run(self, thread_id: str = "default", **kwargs) -> str:
+    def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the sync tool."""
         try:
             # Import BeautifulSoup
             try:
-                from bs4 import BeautifulSoup
+                from bs4 import BeautifulSoup, Tag
             except ImportError:
                 return (
                     "The 'beautifulsoup4' package is required to use this tool."
@@ -356,9 +358,10 @@ class ExtractHyperlinksTool(BrowserBaseTool):
             soup = BeautifulSoup(content, "html.parser")
             links = []
             for link in soup.find_all("a", href=True):
-                text = link.get_text().strip()
-                href = link["href"]
-                if href.startswith(("http", "https")):  # type: ignore[union-attr]
+                tag = cast(Tag, link)
+                text = tag.get_text().strip()
+                href = str(tag.get("href", ""))
+                if href.startswith(("http", "https")):
                     links.append({"text": text, "url": href})

             if not links:
@@ -368,12 +371,12 @@ class ExtractHyperlinksTool(BrowserBaseTool):
         except Exception as e:
             return f"Error extracting hyperlinks: {e!s}"

-    async def _arun(self, thread_id: str = "default", **kwargs) -> str:
+    async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the async tool."""
         try:
             # Import BeautifulSoup
             try:
-                from bs4 import BeautifulSoup
+                from bs4 import BeautifulSoup, Tag
             except ImportError:
                 return (
                     "The 'beautifulsoup4' package is required to use this tool."
@@ -388,9 +391,10 @@ class ExtractHyperlinksTool(BrowserBaseTool):
             soup = BeautifulSoup(content, "html.parser")
             links = []
             for link in soup.find_all("a", href=True):
-                text = link.get_text().strip()
-                href = link["href"]
-                if href.startswith(("http", "https")):  # type: ignore[union-attr]
+                tag = cast(Tag, link)
+                text = tag.get_text().strip()
+                href = str(tag.get("href", ""))
+                if href.startswith(("http", "https")):
                     links.append({"text": text, "url": href})

             if not links:
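The cast(Tag, ...) lines above exist because bs4's find_all is typed to yield generic page elements, so strict mypy will not allow Tag-only attribute access without narrowing. Standalone sketch (inline HTML so it runs offline):

from typing import cast

from bs4 import BeautifulSoup, Tag

html = '<a href="https://example.com">Example</a><a name="no-href">x</a>'
soup = BeautifulSoup(html, "html.parser")

links: list[dict[str, str]] = []
for link in soup.find_all("a", href=True):
    tag = cast(Tag, link)            # narrow to Tag for mypy
    href = str(tag.get("href", ""))  # .get may return str | list[str] | None
    if href.startswith(("http", "https")):
        links.append({"text": tag.get_text().strip(), "url": href})
print(links)

An isinstance(link, Tag) check, as used in the DocsSiteLoader hunks further down, narrows the same way without trusting a cast.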
@@ -408,7 +412,7 @@ class GetElementsTool(BrowserBaseTool):
     description: str = "Get elements from the webpage using a CSS selector"
     args_schema: type[BaseModel] = GetElementsToolInput

-    def _run(self, selector: str, thread_id: str = "default", **kwargs) -> str:
+    def _run(self, selector: str, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the sync tool."""
         try:
             # Get the current page
@@ -428,7 +432,9 @@ class GetElementsTool(BrowserBaseTool):
         except Exception as e:
             return f"Error getting elements: {e!s}"

-    async def _arun(self, selector: str, thread_id: str = "default", **kwargs) -> str:
+    async def _arun(
+        self, selector: str, thread_id: str = "default", **kwargs: Any
+    ) -> str:
         """Use the async tool."""
         try:
             # Get the current page
@@ -456,7 +462,7 @@ class CurrentWebPageTool(BrowserBaseTool):
     description: str = "Get information about the current webpage"
     args_schema: type[BaseModel] = CurrentWebPageToolInput

-    def _run(self, thread_id: str = "default", **kwargs) -> str:
+    def _run(self, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the sync tool."""
         try:
             # Get the current page
@@ -469,7 +475,7 @@ class CurrentWebPageTool(BrowserBaseTool):
         except Exception as e:
             return f"Error getting current webpage info: {e!s}"

-    async def _arun(self, thread_id: str = "default", **kwargs) -> str:
+    async def _arun(self, thread_id: str = "default", **kwargs: Any) -> str:
         """Use the async tool."""
         try:
             # Get the current page
@@ -535,7 +541,7 @@ class BrowserToolkit:
         self._nest_current_loop()
         self._setup_tools()

-    def _nest_current_loop(self):
+    def _nest_current_loop(self) -> None:
         """Apply nest_asyncio if we're in an asyncio loop."""
         try:
             loop = asyncio.get_event_loop()

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


-def extract_output_from_stream(response):
+def extract_output_from_stream(response: dict[str, Any]) -> str:
     """Extract output from code interpreter response stream.

     Args:
@@ -143,8 +143,8 @@ class ExecuteCodeTool(BaseTool):
     args_schema: type[BaseModel] = ExecuteCodeInput
     toolkit: Any = Field(default=None, exclude=True)

-    def __init__(self, toolkit):
-        super().__init__()
+    def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
         self.toolkit = toolkit

     def _run(
@@ -198,8 +198,8 @@ class ExecuteCommandTool(BaseTool):
     args_schema: type[BaseModel] = ExecuteCommandInput
     toolkit: Any = Field(default=None, exclude=True)

-    def __init__(self, toolkit):
-        super().__init__()
+    def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
         self.toolkit = toolkit

     def _run(self, command: str, thread_id: str = "default") -> str:
@@ -231,8 +231,8 @@ class ReadFilesTool(BaseTool):
     args_schema: type[BaseModel] = ReadFilesInput
     toolkit: Any = Field(default=None, exclude=True)

-    def __init__(self, toolkit):
-        super().__init__()
+    def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
         self.toolkit = toolkit

     def _run(self, paths: list[str], thread_id: str = "default") -> str:
@@ -264,8 +264,8 @@ class ListFilesTool(BaseTool):
     args_schema: type[BaseModel] = ListFilesInput
     toolkit: Any = Field(default=None, exclude=True)

-    def __init__(self, toolkit):
-        super().__init__()
+    def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
         self.toolkit = toolkit

     def _run(self, directory_path: str = "", thread_id: str = "default") -> str:
@@ -297,8 +297,8 @@ class DeleteFilesTool(BaseTool):
     args_schema: type[BaseModel] = DeleteFilesInput
     toolkit: Any = Field(default=None, exclude=True)

-    def __init__(self, toolkit):
-        super().__init__()
+    def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
         self.toolkit = toolkit

     def _run(self, paths: list[str], thread_id: str = "default") -> str:
@@ -330,8 +330,8 @@ class WriteFilesTool(BaseTool):
     args_schema: type[BaseModel] = WriteFilesInput
     toolkit: Any = Field(default=None, exclude=True)

-    def __init__(self, toolkit):
-        super().__init__()
+    def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
         self.toolkit = toolkit

     def _run(self, files: list[dict[str, str]], thread_id: str = "default") -> str:
@@ -365,8 +365,8 @@ class StartCommandTool(BaseTool):
     args_schema: type[BaseModel] = StartCommandInput
     toolkit: Any = Field(default=None, exclude=True)

-    def __init__(self, toolkit):
-        super().__init__()
+    def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
         self.toolkit = toolkit

     def _run(self, command: str, thread_id: str = "default") -> str:
@@ -398,8 +398,8 @@ class GetTaskTool(BaseTool):
     args_schema: type[BaseModel] = GetTaskInput
     toolkit: Any = Field(default=None, exclude=True)

-    def __init__(self, toolkit):
-        super().__init__()
+    def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
         self.toolkit = toolkit

     def _run(self, task_id: str, thread_id: str = "default") -> str:
@@ -431,8 +431,8 @@ class StopTaskTool(BaseTool):
     args_schema: type[BaseModel] = StopTaskInput
     toolkit: Any = Field(default=None, exclude=True)

-    def __init__(self, toolkit):
-        super().__init__()
+    def __init__(self, toolkit: CodeInterpreterToolkit, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
         self.toolkit = toolkit

     def _run(self, task_id: str, thread_id: str = "default") -> str:
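The repeated __init__ change in this file, in miniature: accepting and forwarding **kwargs: Any keeps a pydantic model subclass compatible with the base constructor under strict mypy, instead of silently swallowing field arguments. Sketch (Toolkit and ToolWithToolkit are invented stand-ins):

from typing import Any

from pydantic import BaseModel, Field

class Toolkit:
    """Stand-in for the real toolkit object."""

class ToolWithToolkit(BaseModel):
    name: str = "tool"
    toolkit: Any = Field(default=None, exclude=True)

    def __init__(self, toolkit: Toolkit, **kwargs: Any) -> None:
        super().__init__(**kwargs)  # pydantic still validates the other fields
        self.toolkit = toolkit

tool = ToolWithToolkit(Toolkit(), name="reader")
print(tool.name, type(tool.toolkit).__name__)  # reader Toolkit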
@@ -44,16 +44,16 @@ class BedrockKBRetrieverTool(BaseTool):
         retrieval_configuration: dict[str, Any] | None = None,
         guardrail_configuration: dict[str, Any] | None = None,
         next_token: str | None = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         """Initialize the BedrockKBRetrieverTool with knowledge base configuration.

         Args:
-            knowledge_base_id (str): The unique identifier of the knowledge base to query
-            number_of_results (Optional[int], optional): The maximum number of results to return. Defaults to 5.
-            retrieval_configuration (Optional[Dict[str, Any]], optional): Configurations for the knowledge base query and retrieval process. Defaults to None.
-            guardrail_configuration (Optional[Dict[str, Any]], optional): Guardrail settings. Defaults to None.
-            next_token (Optional[str], optional): Token for retrieving the next batch of results. Defaults to None.
+            knowledge_base_id: The unique identifier of the knowledge base to query.
+            number_of_results: The maximum number of results to return.
+            retrieval_configuration: Configurations for the knowledge base query and retrieval process.
+            guardrail_configuration: Guardrail settings.
+            next_token: Token for retrieving the next batch of results.
         """
         super().__init__(**kwargs)

@@ -89,7 +89,7 @@ class BedrockKBRetrieverTool(BaseTool):

         return {"vectorSearchConfiguration": vector_search_config}

-    def _validate_parameters(self):
+    def _validate_parameters(self) -> None:
         """Validate the parameters according to AWS API requirements."""
         try:
             # Validate knowledge_base_id

@@ -39,11 +39,12 @@ class S3ReaderTool(BaseTool):

             # Read file content from S3
             response = s3.get_object(Bucket=bucket_name, Key=object_key)
-            return response["Body"].read().decode("utf-8")
+            result: str = response["Body"].read().decode("utf-8")
+            return result

         except ClientError as e:
             return f"Error reading file from S3: {e!s}"

-    def _parse_s3_path(self, file_path: str) -> tuple:
+    def _parse_s3_path(self, file_path: str) -> tuple[str, str]:
         parts = file_path.replace("s3://", "").split("/", 1)
         return parts[0], parts[1]

@@ -45,6 +45,6 @@ class S3WriterTool(BaseTool):
         except ClientError as e:
             return f"Error writing file to S3: {e!s}"

-    def _parse_s3_path(self, file_path: str) -> tuple:
+    def _parse_s3_path(self, file_path: str) -> tuple[str, str]:
         parts = file_path.replace("s3://", "").split("/", 1)
         return parts[0], parts[1]
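Tightening the return annotation from tuple to tuple[str, str] lets callers unpack the result without further casts; here is the same parse as a standalone function (note that, like the original, it assumes the path has a key segment after the bucket):

def parse_s3_path(file_path: str) -> tuple[str, str]:
    # "s3://bucket/key/with/slashes" -> ("bucket", "key/with/slashes")
    parts = file_path.replace("s3://", "").split("/", 1)
    return parts[0], parts[1]

bucket, key = parse_s3_path("s3://my-bucket/data/2024/report.csv")
print(bucket, key)  # my-bucket data/2024/report.csv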
@@ -1,10 +1,10 @@
 #!/usr/bin/env python3

-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 import inspect
 import json
 from pathlib import Path
-from typing import Any
+from typing import Any, cast

 from crewai.tools.base_tool import BaseTool, EnvVar
 from pydantic import BaseModel
@@ -115,7 +115,8 @@ class ToolSpecExtractor:
         default_value = field.default
         if default_value is PydanticUndefined or default_value is None:
             if field.default_factory:
-                return field.default_factory()
+                factory = cast(Callable[[], Any], field.default_factory)
+                return factory()
             return None

         return default_value
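The cast above exists because pydantic types FieldInfo.default_factory as an optional callable that may also take the validated data, so strict mypy rejects a bare zero-argument call. A runnable sketch of reading a field default the same way:

from collections.abc import Callable
from typing import Any, cast

from pydantic import BaseModel, Field
from pydantic_core import PydanticUndefined

class Example(BaseModel):
    tags: list[str] = Field(default_factory=list)

field = Example.model_fields["tags"]
if field.default is PydanticUndefined and field.default_factory:
    factory = cast(Callable[[], Any], field.default_factory)
    print(factory())  # []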
@@ -1,5 +1,6 @@
-from crewai_tools.rag.core import RAG, EmbeddingService
+from crewai_tools.rag.core import RAG
 from crewai_tools.rag.data_types import DataType
+from crewai_tools.rag.embedding_service import EmbeddingService


 __all__ = [

@@ -21,7 +21,7 @@ class BaseLoader(ABC):
         self.config = config or {}

     @abstractmethod
-    def load(self, content: SourceContent, **kwargs) -> LoaderResult: ...
+    def load(self, content: SourceContent, **kwargs: Any) -> LoaderResult: ...

     @staticmethod
     def generate_doc_id(

@@ -77,7 +77,7 @@ class RAG(Adapter):

         super().model_post_init(__context)

-    def add(
+    def add(  # type: ignore[override]
         self,
         content: str | Path,
         data_type: str | DataType | None = None,

@@ -9,6 +9,7 @@ import logging
 import os
 from typing import Any

+from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
 from pydantic import BaseModel, Field


@@ -81,7 +82,7 @@ class EmbeddingService:
             **kwargs,
         )

-        self._embedding_function = None
+        self._embedding_function: EmbeddingFunction[Any] | None = None
         self._initialize_embedding_function()

     @staticmethod
@@ -107,7 +108,7 @@ class EmbeddingService:
             return os.getenv(env_key)
         return None

-    def _initialize_embedding_function(self):
+    def _initialize_embedding_function(self) -> None:
         """Initialize the embedding function using CrewAI's factory."""
         try:
             from crewai.rag.embeddings.factory import build_embedder
@@ -264,7 +265,7 @@ class EmbeddingService:
         try:
             # Use ChromaDB's embedding function interface
             embeddings = self._embedding_function([text])  # type: ignore
-            return embeddings[0] if embeddings else []
+            return list(embeddings[0]) if embeddings else []

         except Exception as e:
             logger.error(f"Error generating embedding for text: {e}")
@@ -294,12 +295,12 @@ class EmbeddingService:

         try:
             # Process in batches to avoid API limits
-            all_embeddings = []
+            all_embeddings: list[list[float]] = []

             for i in range(0, len(valid_texts), self.config.batch_size):
                 batch = valid_texts[i : i + self.config.batch_size]
                 batch_embeddings = self._embedding_function(batch)  # type: ignore
-                all_embeddings.extend(batch_embeddings)
+                all_embeddings.extend(list(e) for e in batch_embeddings)

             return all_embeddings
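The typed accumulator pattern from this hunk, standalone (the embedding call is stubbed with a fake so the sketch runs offline): annotating all_embeddings tells mypy what extend receives, and wrapping each row in list() normalizes provider-specific sequence types such as arrays.

from collections.abc import Sequence

def fake_embed(batch: list[str]) -> Sequence[Sequence[float]]:
    return [[float(len(text)), 0.0] for text in batch]

texts = ["alpha", "beta", "gamma", "delta", "epsilon"]
batch_size = 2

all_embeddings: list[list[float]] = []
for i in range(0, len(texts), batch_size):
    batch = texts[i : i + batch_size]
    all_embeddings.extend(list(e) for e in fake_embed(batch))
print(all_embeddings)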
@@ -1,5 +1,6 @@
 import csv
 from io import StringIO
+from typing import Any

 from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
 from crewai_tools.rag.loaders.utils import load_from_url
@@ -7,7 +8,7 @@ from crewai_tools.rag.source_content import SourceContent


 class CSVLoader(BaseLoader):
-    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         source_ref = source_content.source_ref

         content_str = source_content.source

@@ -1,12 +1,13 @@
 import os
 from pathlib import Path
+from typing import Any

 from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
 from crewai_tools.rag.source_content import SourceContent


 class DirectoryLoader(BaseLoader):
-    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         """Load and process all files from a directory recursively.

         Args:
@@ -32,7 +33,7 @@ class DirectoryLoader(BaseLoader):

         return self._process_directory(source_ref, kwargs)

-    def _process_directory(self, dir_path: str, kwargs: dict) -> LoaderResult:
+    def _process_directory(self, dir_path: str, kwargs: dict[str, Any]) -> LoaderResult:
         recursive: bool = kwargs.get("recursive", True)
         include_extensions: list[str] | None = kwargs.get("include_extensions", None)
         exclude_extensions: list[str] | None = kwargs.get("exclude_extensions", None)

@@ -1,8 +1,9 @@
 """Documentation site loader."""

+from typing import Any
 from urllib.parse import urljoin, urlparse

-from bs4 import BeautifulSoup
+from bs4 import BeautifulSoup, Tag
 import requests

 from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
@@ -12,7 +13,7 @@ from crewai_tools.rag.source_content import SourceContent
 class DocsSiteLoader(BaseLoader):
     """Loader for documentation websites."""

-    def load(self, source: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         """Load content from a documentation site.

         Args:
@@ -53,7 +54,9 @@ class DocsSiteLoader(BaseLoader):
                 break

         if not main_content:
-            main_content = soup.find("body")
+            body = soup.find("body")
+            if isinstance(body, Tag):
+                main_content = body

         if not main_content:
             raise ValueError(
@@ -66,6 +69,8 @@ class DocsSiteLoader(BaseLoader):
         if headings:
             text_parts.append("Table of Contents:")
             for heading in headings[:15]:
+                if not isinstance(heading, Tag):
+                    continue
                 level = int(heading.name[1])
                 indent = " " * (level - 1)
                 text_parts.append(f"{indent}- {heading.get_text(strip=True)}")
@@ -81,6 +86,8 @@ class DocsSiteLoader(BaseLoader):
         if nav:
             links = nav.find_all("a", href=True)
             for link in links[:20]:
+                if not isinstance(link, Tag):
+                    continue
                 href = link.get("href", "")
                 if isinstance(href, str) and not href.startswith(
                     ("http://", "https://", "mailto:", "#")
@@ -9,7 +9,7 @@ from crewai_tools.rag.source_content import SourceContent


 class DOCXLoader(BaseLoader):
-    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         try:
             from docx import Document as DocxDocument
         except ImportError as e:
@@ -33,7 +33,7 @@ class DOCXLoader(BaseLoader):
         )

     @staticmethod
-    def _download_from_url(url: str, kwargs: dict) -> str:
+    def _download_from_url(url: str, kwargs: dict[str, Any]) -> str:
         headers = kwargs.get(
             "headers",
             {

@@ -1,5 +1,7 @@
 """GitHub repository content loader."""

+from typing import Any
+
 from github import Github, GithubException

 from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
@@ -9,7 +11,7 @@ from crewai_tools.rag.source_content import SourceContent
 class GithubLoader(BaseLoader):
     """Loader for GitHub repository content."""

-    def load(self, source: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         """Load content from a GitHub repository.

         Args:

@@ -1,4 +1,5 @@
 import json
+from typing import Any

 from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
 from crewai_tools.rag.loaders.utils import load_from_url
@@ -6,7 +7,7 @@ from crewai_tools.rag.source_content import SourceContent


 class JSONLoader(BaseLoader):
-    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         source_ref = source_content.source_ref
         content = source_content.source

@@ -1,5 +1,5 @@
 import re
-from typing import Final
+from typing import Any, Final

 from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
 from crewai_tools.rag.loaders.utils import load_from_url
@@ -15,7 +15,7 @@ _EXTRA_NEWLINES_PATTERN: Final[re.Pattern[str]] = re.compile(r"\n\s*\n\s*\n")


 class MDXLoader(BaseLoader):
-    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         source_ref = source_content.source_ref
         content = source_content.source

@@ -1,8 +1,8 @@
 """PDF loader for extracting text from PDF files."""

 import os
-import tempfile
 from pathlib import Path
+import tempfile
 from typing import Any
 from urllib.parse import urlparse

@@ -25,7 +25,7 @@ class PDFLoader(BaseLoader):
         return False

     @staticmethod
-    def _download_from_url(url: str, kwargs: dict) -> str:
+    def _download_from_url(url: str, kwargs: dict[str, Any]) -> str:
         """Download PDF from a URL to a temporary file and return its path.

         Args:

@@ -1,5 +1,6 @@
 """PostgreSQL database loader."""

+from typing import Any
 from urllib.parse import urlparse

 from psycopg2 import Error, connect
@@ -12,7 +13,7 @@ from crewai_tools.rag.source_content import SourceContent
 class PostgresLoader(BaseLoader):
     """Loader for PostgreSQL database content."""

-    def load(self, source: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         """Load content from a PostgreSQL database table.

         Args:

@@ -1,9 +1,11 @@
+from typing import Any
+
 from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
 from crewai_tools.rag.source_content import SourceContent


 class TextFileLoader(BaseLoader):
-    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         source_ref = source_content.source_ref
         if not source_content.path_exists():
             raise FileNotFoundError(
@@ -21,7 +23,7 @@ class TextFileLoader(BaseLoader):


 class TextLoader(BaseLoader):
-    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         return LoaderResult(
             content=source_content.source,
             source=source_content.source_ref,
@@ -1,8 +1,13 @@
 """Utility functions for RAG loaders."""

+from typing import Any
+

 def load_from_url(
-    url: str, kwargs: dict, accept_header: str = "*/*", loader_name: str = "Loader"
+    url: str,
+    kwargs: dict[str, Any],
+    accept_header: str = "*/*",
+    loader_name: str = "Loader",
 ) -> str:
     """Load content from a URL.

@@ -1,5 +1,5 @@
 import re
-from typing import Final
+from typing import Any, Final

 from bs4 import BeautifulSoup
 import requests
@@ -13,7 +13,7 @@ _NEWLINE_PATTERN: Final[re.Pattern[str]] = re.compile(r"\s+\n\s+")


 class WebPageLoader(BaseLoader):
-    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         url = source_content.source
         headers = kwargs.get(
             "headers",

@@ -10,7 +10,7 @@ from crewai_tools.rag.source_content import SourceContent
 class YoutubeChannelLoader(BaseLoader):
     """Loader for YouTube channels."""

-    def load(self, source: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         """Load and extract content from a YouTube channel.

         Args:

@@ -11,7 +11,7 @@ from crewai_tools.rag.source_content import SourceContent
 class YoutubeVideoLoader(BaseLoader):
     """Loader for YouTube videos."""

-    def load(self, source: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
+    def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
         """Load and extract transcript from a YouTube video.

         Args:
@@ -42,7 +42,7 @@ class AIMindTool(BaseTool):
         ]
     )

-    def __init__(self, api_key: str | None = None, **kwargs):
+    def __init__(self, api_key: str | None = None, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         self.api_key = api_key or os.getenv("MINDS_API_KEY")
         if not self.api_key:
@@ -51,8 +51,10 @@ class AIMindTool(BaseTool):
             )

         try:
-            from minds.client import Client  # type: ignore
-            from minds.datasources import DatabaseConfig  # type: ignore
+            from minds.client import Client  # type: ignore[import-not-found]
+            from minds.datasources import (  # type: ignore[import-not-found]
+                DatabaseConfig,
+            )
         except ImportError as e:
             raise ImportError(
                 "`minds_sdk` package not found, please run `pip install minds-sdk`"
@@ -81,7 +83,7 @@ class AIMindTool(BaseTool):

         self.mind_name = mind.name

-    def _run(self, query: str):
+    def _run(self, query: str) -> str | None:
         # Run the query on the AI-Mind.
         # The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.
         openai_client = OpenAI(

@@ -2,7 +2,7 @@ import logging
 from pathlib import Path
 import re
 import time
-from typing import ClassVar
+from typing import Any, ClassVar
 import urllib.error
 import urllib.parse
 import urllib.request
@@ -75,7 +75,9 @@ class ArxivPaperTool(BaseTool):
             logger.error(f"ArxivTool Error: {e!s}")
             return f"Failed to fetch or download Arxiv papers: {e!s}"

-    def fetch_arxiv_data(self, search_query: str, max_results: int) -> list[dict]:
+    def fetch_arxiv_data(
+        self, search_query: str, max_results: int
+    ) -> list[dict[str, Any]]:
         api_url = f"{self.BASE_API_URL}?search_query={urllib.parse.quote(search_query)}&start=0&max_results={max_results}"
         logger.info(f"Fetching data from Arxiv API: {api_url}")

@@ -135,7 +137,7 @@ class ArxivPaperTool(BaseTool):
                 return href
         return None

-    def _format_paper_result(self, paper: dict) -> str:
+    def _format_paper_result(self, paper: dict[str, Any]) -> str:
         summary = (
             (paper["summary"][: self.SUMMARY_TRUNCATE_LENGTH] + "...")
             if len(paper["summary"]) > self.SUMMARY_TRUNCATE_LENGTH
@@ -156,7 +158,7 @@ class ArxivPaperTool(BaseTool):
         save_path.mkdir(parents=True, exist_ok=True)
         return save_path

-    def download_pdf(self, pdf_url: str, save_path: str):
+    def download_pdf(self, pdf_url: str, save_path: str) -> None:
         try:
             logger.info(f"Downloading PDF from {pdf_url} to {save_path}")
             urllib.request.urlretrieve(pdf_url, str(save_path))  # noqa: S310

@@ -138,7 +138,7 @@ class BraveSearchToolBase(BaseTool, ABC):
         self._rate_limit_lock = threading.Lock()

     @property
-    def api_key(self) -> str:
+    def api_key(self) -> str | None:
         return self._api_key

     @property
@@ -214,7 +214,8 @@ class BraveSearchToolBase(BaseTool, ABC):
             # Response was OK, return the JSON body
             if resp.ok:
                 try:
-                    return resp.json()
+                    result: dict[str, Any] = resp.json()
+                    return result
                 except ValueError as exc:
                     raise RuntimeError(
                         f"Brave Search API returned invalid JSON (HTTP {resp.status_code}): {exc}"
@@ -239,9 +240,9 @@ class BraveSearchToolBase(BaseTool, ABC):
                 # (e.g., 422 Unprocessable Entity, 400 Bad Request (OPTION_NOT_IN_PLAN))
                 _raise_for_error(resp)

-        # All retries exhausted
-        _raise_for_error(last_resp or resp)  # type: ignore[possibly-undefined]
-        return {}  # unreachable (here to satisfy the type checker and linter)
+        # All retries exhausted — last_resp is always set when we reach here
+        _raise_for_error(last_resp or resp)
+        return {}  # unreachable; satisfies return type

     def _run(self, q: str | None = None, **params: Any) -> Any:
         # Allow positional usage: tool.run("latest Brave browser features")
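Several hunks above (the Brave resp.json() and the earlier S3 read) replace a bare return expression with an annotated local plus return; that is the usual fix for mypy's no-any-return error when the expression is typed Any, as requests.Response.json() is. Sketch (the URL is a placeholder):

from typing import Any

import requests

def fetch_json(url: str) -> dict[str, Any]:
    resp = requests.get(url, timeout=10)
    # resp.json() is Any; binding it to an annotated local gives mypy
    # a declared type to verify against the return annotation.
    result: dict[str, Any] = resp.json()
    return result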
@@ -3,7 +3,6 @@ from typing import Any
 from pydantic import BaseModel

 from crewai_tools.tools.brave_search_tool.base import BraveSearchToolBase
-from crewai_tools.tools.brave_search_tool.response_types import LLMContext
 from crewai_tools.tools.brave_search_tool.schemas import (
     LLMContextHeaders,
     LLMContextParams,
@@ -27,6 +26,6 @@ class BraveLLMContextTool(BraveSearchToolBase):
     def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
         return params

-    def _refine_response(self, response: LLMContext.Response) -> LLMContext.Response:
+    def _refine_response(self, response: dict[str, Any]) -> Any:
         """The LLM Context response schema is fairly simple. Return as is."""
         return response

@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, cast

 from pydantic import BaseModel

@@ -65,8 +65,8 @@ class BraveLocalPOIsTool(BraveSearchToolBase):
     def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
         return params

-    def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
-        results = response.get("results", [])
+    def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
+        results: list[dict[str, Any]] = response.get("results", [])
         return [
             {
                 "title": result.get("title"),
@@ -76,7 +76,7 @@ class BraveLocalPOIsTool(BraveSearchToolBase):
                 "contact": result.get("contact", {}).get("telephone")
                 or result.get("contact", {}).get("email")
                 or None,
-                "opening_hours": _simplify_opening_hours(result),
+                "opening_hours": _simplify_opening_hours(cast(LocationResult, result)),
             }
             for result in results
         ]
@@ -97,9 +97,8 @@ class BraveLocalPOIsDescriptionTool(BraveSearchToolBase):
     def _refine_request_payload(self, params: dict[str, Any]) -> dict[str, Any]:
         return params

-    def _refine_response(self, response: LocalPOIsResponse) -> list[dict[str, Any]]:
-        # Make the response more concise, and easier to consume
-        results = response.get("results", [])
+    def _refine_response(self, response: dict[str, Any]) -> list[dict[str, Any]]:
+        results: list[dict[str, Any]] = response.get("results", [])
         return [
             {
                 "id": result.get("id"),

@@ -50,7 +50,7 @@ class BraveSearchTool(BaseTool):
     _last_request_time: ClassVar[float] = 0
     _min_request_interval: ClassVar[float] = 1.0  # seconds

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         if "BRAVE_API_KEY" not in os.environ:
             raise ValueError(

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import os
 from typing import Any
@@ -13,7 +15,7 @@ class BrightDataConfig(BaseModel):
     DEFAULT_POLLING_INTERVAL: int = 1

     @classmethod
-    def from_env(cls):
+    def from_env(cls) -> BrightDataConfig:
         return cls(
             API_URL=os.environ.get("BRIGHTDATA_API_URL", "https://api.brightdata.com"),
             DEFAULT_TIMEOUT=int(os.environ.get("BRIGHTDATA_DEFAULT_TIMEOUT", "600")),
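The -> BrightDataConfig annotation inside the class body works because the file now starts with from __future__ import annotations, which makes annotations lazily evaluated strings, so the class can name itself before its definition completes. Minimal sketch (Config and the URL are placeholders):

from __future__ import annotations

import os

from pydantic import BaseModel

class Config(BaseModel):
    api_url: str = "https://api.example.com"

    @classmethod
    def from_env(cls) -> Config:  # forward reference is fine here
        return cls(api_url=os.environ.get("EXAMPLE_API_URL", "https://api.example.com"))

print(Config.from_env().api_url)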
@@ -26,12 +28,12 @@ class BrightDataConfig(BaseModel):
 class BrightDataDatasetToolException(Exception):  # noqa: N818
     """Exception raised for custom error in the application."""

-    def __init__(self, message, error_code):
+    def __init__(self, message: str, error_code: int) -> None:
         self.message = message
         super().__init__(message)
         self.error_code = error_code

-    def __str__(self):
+    def __str__(self) -> str:
         return f"{self.message} (Error Code: {self.error_code})"


@@ -62,7 +64,7 @@ config = BrightDataConfig.from_env()
 BRIGHTDATA_API_URL = config.API_URL
 timeout = config.DEFAULT_TIMEOUT

-datasets = [
+datasets: list[dict[str, Any]] = [
     {
         "id": "amazon_product",
         "dataset_id": "gd_l7q7dkf244hwjntr0",
@@ -440,7 +442,7 @@ class BrightDataDatasetTool(BaseTool):
         self.zipcode = zipcode
         self.additional_params = additional_params

-    def filter_dataset_by_id(self, target_id):
+    def filter_dataset_by_id(self, target_id: str) -> list[dict[str, Any]]:
         return [dataset for dataset in datasets if dataset["id"] == target_id]

     async def get_dataset_data_async(
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 from typing import Any
 import urllib.parse
@@ -11,7 +13,7 @@ class BrightDataConfig(BaseModel):
     API_URL: str = "https://api.brightdata.com/request"

     @classmethod
-    def from_env(cls):
+    def from_env(cls) -> BrightDataConfig:
         return cls(
             API_URL=os.environ.get(
                 "BRIGHTDATA_API_URL", "https://api.brightdata.com/request"
@@ -127,7 +129,7 @@ class BrightDataSearchTool(BaseTool):
         if not self.zone:
             raise ValueError("BRIGHT_DATA_ZONE environment variable is required.")

-    def get_search_url(self, engine: str, query: str):
+    def get_search_url(self, engine: str, query: str) -> str:
         if engine == "yandex":
             return f"https://yandex.com/search/?text=${query}"
         if engine == "bing":
@@ -143,7 +145,7 @@ class BrightDataSearchTool(BaseTool):
         search_type: str | None = None,
         device_type: str | None = None,
         parse_results: bool | None = None,
-        **kwargs,
+        **kwargs: Any,
     ) -> Any:
         """Executes a search query using Bright Data SERP API and returns results.


@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 from typing import Any

@@ -10,7 +12,7 @@ class BrightDataConfig(BaseModel):
     API_URL: str = "https://api.brightdata.com/request"

     @classmethod
-    def from_env(cls):
+    def from_env(cls) -> BrightDataConfig:
         return cls(
             API_URL=os.environ.get(
                 "BRIGHTDATA_API_URL", "https://api.brightdata.com/request"

@@ -42,15 +42,15 @@ class BrowserbaseLoadTool(BaseTool):
         text_content: bool | None = False,
         session_id: str | None = None,
         proxy: bool | None = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         super().__init__(**kwargs)
         if not self.api_key:
             raise EnvironmentError(
                 "BROWSERBASE_API_KEY environment variable is required for initialization"
             )
         try:
-            from browserbase import Browserbase  # type: ignore
+            from browserbase import Browserbase
         except ImportError:
             import click

@@ -60,7 +60,7 @@ class BrowserbaseLoadTool(BaseTool):
                 import subprocess

                 subprocess.run(["uv", "add", "browserbase"], check=True)  # noqa: S607
-                from browserbase import Browserbase  # type: ignore
+                from browserbase import Browserbase
             else:
                 raise ImportError(
                     "`browserbase` package not found, please run `uv add browserbase`"
@@ -71,7 +71,7 @@ class BrowserbaseLoadTool(BaseTool):
         self.session_id = session_id
         self.proxy = proxy

-    def _run(self, url: str):
+    def _run(self, url: str) -> Any:
         return self.browserbase.load_url(  # type: ignore[union-attr]
             url, self.text_content, self.session_id, self.proxy
         )
@@ -1,3 +1,5 @@
+from typing import Any
+
 from pydantic import BaseModel, Field

 from crewai_tools.rag.data_types import DataType
@@ -26,7 +28,7 @@ class CodeDocsSearchTool(RagTool):
     )
     args_schema: type[BaseModel] = CodeDocsSearchToolSchema

-    def __init__(self, docs_url: str | None = None, **kwargs):
+    def __init__(self, docs_url: str | None = None, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         if docs_url is not None:
             self.add(docs_url)
@@ -34,7 +36,7 @@ class CodeDocsSearchTool(RagTool):
             self.args_schema = FixedCodeDocsSearchToolSchema
             self._generate_description()

-    def add(self, docs_url: str) -> None:
+    def add(self, docs_url: str) -> None:  # type: ignore[override]
         super().add(docs_url, data_type=DataType.DOCS_SITE)

     def _run(  # type: ignore[override]

@@ -18,7 +18,6 @@ from docker import (  # type: ignore[import-untyped]
     from_env as docker_from_env,
 )
 from docker.errors import ImageNotFound, NotFound  # type: ignore[import-untyped]
-from docker.models.containers import Container  # type: ignore[import-untyped]
 from pydantic import BaseModel, Field
 from typing_extensions import Unpack

@@ -232,7 +231,7 @@ class CodeInterpreterTool(BaseTool):
         return self.run_code_safety(code, libraries_used)

     @staticmethod
-    def _install_libraries(container: Container, libraries: list[str]) -> None:
+    def _install_libraries(container: Any, libraries: list[str]) -> None:
         """Installs required Python libraries in the Docker container.

         Args:
@@ -242,7 +241,7 @@ class CodeInterpreterTool(BaseTool):
         for library in libraries:
             container.exec_run(["pip", "install", library])

-    def _init_docker_container(self) -> Container:
+    def _init_docker_container(self) -> Any:
         """Initializes and returns a Docker container for code execution.

         Stops and removes any existing container with the same name before creating
@@ -269,7 +268,7 @@ class CodeInterpreterTool(BaseTool):
             tty=True,
             working_dir="/workspace",
             name=container_name,
-            volumes={current_path: {"bind": "/workspace", "mode": "rw"}},  # type: ignore
+            volumes={current_path: {"bind": "/workspace", "mode": "rw"}},
         )

     @staticmethod
@@ -351,14 +350,14 @@ class CodeInterpreterTool(BaseTool):
         container = self._init_docker_container()
         self._install_libraries(container, libraries_used)

-        exec_result = container.exec_run(["python3", "-c", code])
+        exec_result: Any = container.exec_run(["python3", "-c", code])

         container.stop()
         container.remove()

         if exec_result.exit_code != 0:
             return f"Something went wrong while running the code: \n{exec_result.output.decode('utf-8')}"
-        return exec_result.output.decode("utf-8")
+        return str(exec_result.output.decode("utf-8"))

     @staticmethod
     def run_code_in_restricted_sandbox(code: str) -> str:
@@ -385,12 +384,12 @@ class CodeInterpreterTool(BaseTool):
         """
         Printer.print(
             "WARNING: Running code in INSECURE restricted sandbox (vulnerable to escape attacks)",
-            color="bold_red"
+            color="bold_red",
         )
         exec_locals: dict[str, Any] = {}
         try:
             SandboxPython.exec(code=code, locals_=exec_locals)
-            return exec_locals.get("result", "No result variable found.")
+            return exec_locals.get("result", "No result variable found.")  # type: ignore[no-any-return]
         except Exception as e:
             return f"An error occurred: {e!s}"

@@ -412,12 +411,14 @@ class CodeInterpreterTool(BaseTool):
         Printer.print("WARNING: Running code in unsafe mode", color="bold_magenta")
         # Install libraries on the host machine
         for library in libraries_used:
-            subprocess.run([sys.executable, "-m", "pip", "install", library], check=False)  # noqa: S603
+            subprocess.run(  # noqa: S603
+                [sys.executable, "-m", "pip", "install", library], check=False
+            )

         # Execute the code
         try:
             exec_locals: dict[str, Any] = {}
             exec(code, {}, exec_locals)  # noqa: S102
-            return exec_locals.get("result", "No result variable found.")
+            return exec_locals.get("result", "No result variable found.")  # type: ignore[no-any-return]
         except Exception as e:
             return f"An error occurred: {e!s}"
@@ -10,7 +10,7 @@ import typing_extensions as te
class ComposioTool(BaseTool):
"""Wrapper for composio tools."""

composio_action: t.Callable
composio_action: t.Callable[..., t.Any]
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
@@ -70,7 +70,7 @@ class ComposioTool(BaseTool):
schema = action_schema.model_dump(exclude_none=True)
entity_id = kwargs.pop("entity_id", DEFAULT_ENTITY_ID)

def function(**kwargs: t.Any) -> dict:
def function(**kwargs: t.Any) -> dict[str, t.Any]:
"""Wrapper function for composio action."""
return toolset.execute_action(
action=Action(schema["name"]),

@@ -28,7 +28,7 @@ class ContextualAICreateAgentTool(BaseTool):
default_factory=lambda: ["contextual-client"]
)

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
try:
from contextual import ContextualAI

@@ -31,7 +31,7 @@ class ContextualAIQueryTool(BaseTool):
default_factory=lambda: ["contextual-client"]
)

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
try:
from contextual import ContextualAI
@@ -99,20 +99,19 @@ class ContextualAIQueryTool(BaseTool):
response = self.contextual_client.agents.query.create(
agent_id=agent_id, messages=[{"role": "user", "content": query}]
)
if hasattr(response, "content"):
return response.content
if hasattr(response, "message"):
content = getattr(response, "content", None)
if content is not None:
return str(content)
message = getattr(response, "message", None)
if message is not None:
msg_content = getattr(message, "content", None)
return str(msg_content) if msg_content is not None else str(message)
messages = getattr(response, "messages", None)
if messages and len(messages) > 0:
last_message = messages[-1]
last_content = getattr(last_message, "content", None)
return (
response.message.content
if hasattr(response.message, "content")
else str(response.message)
)
if hasattr(response, "messages") and len(response.messages) > 0:
last_message = response.messages[-1]
return (
last_message.content
if hasattr(last_message, "content")
else str(last_message)
str(last_content) if last_content is not None else str(last_message)
)
return str(response)
except Exception as e:

@@ -15,11 +15,11 @@ try:
COUCHBASE_AVAILABLE = True
except ImportError:
COUCHBASE_AVAILABLE = False
search = Any
Cluster = Any
SearchOptions = Any
VectorQuery = Any
VectorSearch = Any
search = Any # type: ignore[assignment,unused-ignore]
Cluster = Any # type: ignore[assignment,unused-ignore]
SearchOptions = Any # type: ignore[assignment,unused-ignore]
VectorQuery = Any # type: ignore[assignment,unused-ignore]
VectorSearch = Any # type: ignore[assignment,unused-ignore]

from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, SkipValidation
@@ -41,7 +41,7 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
name: str = "CouchbaseFTSVectorSearchTool"
description: str = "A tool to search the Couchbase database for relevant information on internal documents."
args_schema: type[BaseModel] = CouchbaseToolSchema
cluster: SkipValidation[Cluster] = Field(
cluster: SkipValidation[Any] = Field(
description="An instance of the Couchbase Cluster connected to the desired Couchbase server.",
)
collection_name: str = Field(
@@ -136,7 +136,7 @@ class CouchbaseFTSVectorSearchTool(BaseTool):

return True

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
"""Initialize the CouchbaseFTSVectorSearchTool.

Args:

@@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel, Field

from crewai_tools.rag.data_types import DataType
@@ -26,7 +28,7 @@ class CSVSearchTool(RagTool):
)
args_schema: type[BaseModel] = CSVSearchToolSchema

def __init__(self, csv: str | None = None, **kwargs):
def __init__(self, csv: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if csv is not None:
self.add(csv)
@@ -34,7 +36,7 @@ class CSVSearchTool(RagTool):
self.args_schema = FixedCSVSearchToolSchema
self._generate_description()

def add(self, csv: str) -> None:
def add(self, csv: str) -> None: # type: ignore[override]
super().add(csv, data_type=DataType.CSV)

def _run( # type: ignore[override]

@@ -1,8 +1,8 @@
import json
from typing import Literal
from typing import Any, Literal

from crewai.tools import BaseTool, EnvVar
from openai import Omit, OpenAI
from openai import OpenAI
from pydantic import BaseModel, Field


@@ -33,9 +33,9 @@ class DallETool(BaseTool):
]
| None
) = "1024x1024"
quality: (
Literal["standard", "hd", "low", "medium", "high", "auto"] | None | Omit
) = "standard"
quality: Literal["standard", "hd", "low", "medium", "high", "auto"] | None = (
"standard"
)
n: int = 1

env_vars: list[EnvVar] = Field(
@@ -48,7 +48,7 @@ class DallETool(BaseTool):
]
)

def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
client = OpenAI()

image_description = kwargs.get("image_description")

@@ -23,7 +23,7 @@ class DirectoryReadTool(BaseTool):
args_schema: type[BaseModel] = DirectoryReadToolSchema
directory: str | None = None

def __init__(self, directory: str | None = None, **kwargs):
def __init__(self, directory: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if directory is not None:
self.directory = directory

@@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel, Field

from crewai_tools.rag.data_types import DataType
@@ -26,7 +28,7 @@ class DirectorySearchTool(RagTool):
)
args_schema: type[BaseModel] = DirectorySearchToolSchema

def __init__(self, directory: str | None = None, **kwargs):
def __init__(self, directory: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if directory is not None:
self.add(directory)
@@ -34,7 +36,7 @@ class DirectorySearchTool(RagTool):
self.args_schema = FixedDirectorySearchToolSchema
self._generate_description()

def add(self, directory: str) -> None:
def add(self, directory: str) -> None: # type: ignore[override]
super().add(directory, data_type=DataType.DIRECTORY)

def _run( # type: ignore[override]

@@ -34,7 +34,7 @@ class DOCXSearchTool(RagTool):
)
args_schema: type[BaseModel] = DOCXSearchToolSchema

def __init__(self, docx: str | None = None, **kwargs):
def __init__(self, docx: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if docx is not None:
self.add(docx)
@@ -42,7 +42,7 @@ class DOCXSearchTool(RagTool):
self.args_schema = FixedDOCXSearchToolSchema
self._generate_description()

def add(self, docx: str) -> None:
def add(self, docx: str) -> None: # type: ignore[override]
super().add(docx, data_type=DataType.DOCX)

def _run( # type: ignore[override]

@@ -71,8 +71,8 @@ class EXASearchTool(BaseTool):
content: bool | None = False,
summary: bool | None = False,
type: str | None = "auto",
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(
**kwargs,
)

@@ -6,7 +6,7 @@ from crewai.tools import BaseTool
from pydantic import BaseModel


def strtobool(val) -> bool:
def strtobool(val: str | bool) -> bool:
if isinstance(val, bool):
return val
val = val.lower()
@@ -46,7 +46,10 @@ class FileWriterTool(BaseTool):
# itself, since that is not a valid file target.
real_directory = Path(directory).resolve()
real_filepath = Path(filepath).resolve()
if not real_filepath.is_relative_to(real_directory) or real_filepath == real_directory:
if (
not real_filepath.is_relative_to(real_directory)
or real_filepath == real_directory
):
return "Error: Invalid file path — the filename must not escape the target directory."

if kwargs.get("directory"):
@@ -62,9 +65,7 @@ class FileWriterTool(BaseTool):
file.write(kwargs["content"])
return f"Content successfully written to {real_filepath}"
except FileExistsError:
return (
f"File {real_filepath} already exists and overwrite option was not passed."
)
return f"File {real_filepath} already exists and overwrite option was not passed."
except KeyError as e:
return f"An error occurred while accessing key: {e!s}"
except Exception as e:

@@ -106,7 +106,7 @@ class FileCompressorTool(BaseTool):
return True

@staticmethod
def _compress_zip(input_path: str, output_path: str):
def _compress_zip(input_path: str, output_path: str) -> None:
"""Compresses input into a zip archive."""
with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
if os.path.isfile(input_path):
@@ -119,7 +119,7 @@ class FileCompressorTool(BaseTool):
zipf.write(full_path, arcname)

@staticmethod
def _compress_tar(input_path: str, output_path: str, format: str):
def _compress_tar(input_path: str, output_path: str, format: str) -> None:
"""Compresses input into a tar archive with the given format."""
format_mode = {
"tar": "w",

@@ -1,14 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import Any

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr


if TYPE_CHECKING:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]

try:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]

@@ -63,7 +60,7 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
},
}
)
_firecrawl: FirecrawlApp | None = PrivateAttr(None)
_firecrawl: Any = PrivateAttr(None)
package_dependencies: list[str] = Field(default_factory=lambda: ["firecrawl-py"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
@@ -75,14 +72,14 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
]
)

def __init__(self, api_key: str | None = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.api_key = api_key
self._initialize_firecrawl()

def _initialize_firecrawl(self) -> None:
try:
from firecrawl import FirecrawlApp # type: ignore
from firecrawl import FirecrawlApp

self._firecrawl = FirecrawlApp(api_key=self.api_key)
except ImportError:
@@ -105,7 +102,7 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
"`firecrawl-py` package not found, please run `uv add firecrawl-py`"
) from None

def _run(self, url: str):
def _run(self, url: str) -> Any:
if not self._firecrawl:
raise RuntimeError("FirecrawlApp not properly initialized")

@@ -113,13 +110,10 @@ class FirecrawlCrawlWebsiteTool(BaseTool):


try:
from firecrawl import FirecrawlApp
from firecrawl import FirecrawlApp # noqa: F401

# Only rebuild if the class hasn't been initialized yet
if not hasattr(FirecrawlCrawlWebsiteTool, "_model_rebuilt"):
if not getattr(FirecrawlCrawlWebsiteTool, "_model_rebuilt", False):
FirecrawlCrawlWebsiteTool.model_rebuild()
FirecrawlCrawlWebsiteTool._model_rebuilt = True # type: ignore[attr-defined]
except ImportError:
"""
When this tool is not used, then exception can be ignored.
"""
pass

@@ -1,14 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import Any

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr


if TYPE_CHECKING:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]

try:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]

@@ -70,7 +67,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
}
)

_firecrawl: FirecrawlApp | None = PrivateAttr(None)
_firecrawl: Any = PrivateAttr(None)
package_dependencies: list[str] = Field(default_factory=lambda: ["firecrawl-py"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
@@ -82,10 +79,10 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
]
)

def __init__(self, api_key: str | None = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
try:
from firecrawl import FirecrawlApp # type: ignore
from firecrawl import FirecrawlApp
except ImportError:
import click

@@ -105,7 +102,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool):

self._firecrawl = FirecrawlApp(api_key=api_key)

def _run(self, url: str):
def _run(self, url: str) -> Any:
if not self._firecrawl:
raise RuntimeError("FirecrawlApp not properly initialized")

@@ -113,13 +110,10 @@ class FirecrawlScrapeWebsiteTool(BaseTool):


try:
from firecrawl import FirecrawlApp
from firecrawl import FirecrawlApp # noqa: F401

# Must rebuild model after class is defined
if not hasattr(FirecrawlScrapeWebsiteTool, "_model_rebuilt"):
if not getattr(FirecrawlScrapeWebsiteTool, "_model_rebuilt", False):
FirecrawlScrapeWebsiteTool.model_rebuild()
FirecrawlScrapeWebsiteTool._model_rebuilt = True # type: ignore[attr-defined]
except ImportError:
"""
When this tool is not used, then exception can be ignored.
"""
pass

@@ -1,15 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import Any

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr


if TYPE_CHECKING:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]


try:
from firecrawl import FirecrawlApp # type: ignore[import-untyped]

@@ -65,7 +61,7 @@ class FirecrawlSearchTool(BaseTool):
},
}
)
_firecrawl: FirecrawlApp | None = PrivateAttr(None)
_firecrawl: Any = PrivateAttr(None)
package_dependencies: list[str] = Field(default_factory=lambda: ["firecrawl-py"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
@@ -77,14 +73,14 @@ class FirecrawlSearchTool(BaseTool):
]
)

def __init__(self, api_key: str | None = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.api_key = api_key
self._initialize_firecrawl()

def _initialize_firecrawl(self) -> None:
try:
from firecrawl import FirecrawlApp # type: ignore
from firecrawl import FirecrawlApp

self._firecrawl = FirecrawlApp(api_key=self.api_key)
except ImportError:
@@ -121,13 +117,10 @@ class FirecrawlSearchTool(BaseTool):


try:
from firecrawl import FirecrawlApp # type: ignore
from firecrawl import FirecrawlApp # noqa: F401

# Only rebuild if the class hasn't been initialized yet
if not hasattr(FirecrawlSearchTool, "_model_rebuilt"):
if not getattr(FirecrawlSearchTool, "_model_rebuilt", False):
FirecrawlSearchTool.model_rebuild()
FirecrawlSearchTool._model_rebuilt = True # type: ignore[attr-defined]
except ImportError:
"""
When this tool is not used, then exception can be ignored.
"""
pass

@@ -1,4 +1,5 @@
import os
from typing import Any

from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field
@@ -46,7 +47,7 @@ class GenerateCrewaiAutomationTool(BaseTool):
]
)

def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
input_data = GenerateCrewaiAutomationToolSchema(**kwargs)
response = requests.post( # noqa: S113
f"{self.crewai_enterprise_url}/crewai_plus/api/v1/studio",
@@ -58,7 +59,7 @@ class GenerateCrewaiAutomationTool(BaseTool):
studio_project_url = response.json().get("url")
return f"Generated CrewAI Studio project URL: {studio_project_url}"

def _get_headers(self, organization_id: str | None = None) -> dict:
def _get_headers(self, organization_id: str | None = None) -> dict[str, str]:
headers = {
"Authorization": f"Bearer {self.personal_access_token}",
"Content-Type": "application/json",

@@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel, Field

from crewai_tools.rag.data_types import DataType
@@ -38,8 +40,8 @@ class GithubSearchTool(RagTool):
self,
github_repo: str | None = None,
content_types: list[str] | None = None,
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

if github_repo and content_types:
@@ -48,7 +50,7 @@ class GithubSearchTool(RagTool):
self.args_schema = FixedGithubSearchToolSchema
self._generate_description()

def add(
def add( # type: ignore[override]
self,
repo: str,
content_types: list[str] | None = None,

@@ -10,7 +10,7 @@ class HyperbrowserLoadToolSchema(BaseModel):
operation: Literal["scrape", "crawl"] = Field(
description="Operation to perform on the website. Either 'scrape' or 'crawl'"
)
params: dict | None = Field(
params: dict[str, Any] | None = Field(
description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait"
)

@@ -42,7 +42,7 @@ class HyperbrowserLoadTool(BaseTool):
]
)

def __init__(self, api_key: str | None = None, **kwargs):
def __init__(self, api_key: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.api_key = api_key or os.getenv("HYPERBROWSER_API_KEY")
if not api_key:
@@ -65,7 +65,7 @@ class HyperbrowserLoadTool(BaseTool):
self.hyperbrowser = Hyperbrowser(api_key=self.api_key)

@staticmethod
def _prepare_params(params: dict) -> dict:
def _prepare_params(params: dict[str, Any]) -> dict[str, Any]:
"""Prepare session and scrape options parameters."""
try:
from hyperbrowser.models.scrape import ( # type: ignore[import-untyped]
@@ -91,7 +91,7 @@ class HyperbrowserLoadTool(BaseTool):
params["scrape_options"] = ScrapeOptions(**params["scrape_options"])
return params

def _extract_content(self, data: Any | None):
def _extract_content(self, data: Any | None) -> str:
"""Extract content from response data."""
content = ""
if data:
@@ -102,15 +102,15 @@ class HyperbrowserLoadTool(BaseTool):
self,
url: str,
operation: Literal["scrape", "crawl"] = "scrape",
params: dict | None = None,
):
params: dict[str, Any] | None = None,
) -> str:
if params is None:
params = {}
try:
from hyperbrowser.models.crawl import ( # type: ignore[import-untyped]
StartCrawlJobParams,
)
from hyperbrowser.models.scrape import ( # type: ignore[import-untyped]
from hyperbrowser.models.scrape import (
StartScrapeJobParams,
)
except ImportError as e:

@@ -134,7 +134,8 @@ class InvokeCrewAIAutomationTool(BaseTool):
json={"inputs": inputs},
timeout=30,
)
return response.json()
result: dict[str, Any] = response.json()
return result

def _get_crew_status(self, crew_id: str) -> dict[str, Any]:
"""Get the status of a crew task.
@@ -153,9 +154,10 @@ class InvokeCrewAIAutomationTool(BaseTool):
},
timeout=30,
)
return response.json()
result: dict[str, Any] = response.json()
return result

def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
"""Execute the crew invocation tool."""
if kwargs is None:
kwargs = {}
@@ -172,7 +174,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
try:
status_response = self._get_crew_status(crew_id=kickoff_id)
if status_response.get("state", "").lower() == "success":
return status_response.get("result", "No result returned")
return str(status_response.get("result", "No result returned"))
if status_response.get("state", "").lower() == "failed":
return f"Error: Crew task failed. Response: {status_response}"
except Exception as e:

@@ -1,3 +1,5 @@
from typing import Any

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
import requests
@@ -15,14 +17,14 @@ class JinaScrapeWebsiteTool(BaseTool):
args_schema: type[BaseModel] = JinaScrapeWebsiteToolInput
website_url: str | None = None
api_key: str | None = None
headers: dict = Field(default_factory=dict)
headers: dict[str, str] = Field(default_factory=dict)

def __init__(
self,
website_url: str | None = None,
api_key: str | None = None,
custom_headers: dict | None = None,
**kwargs,
custom_headers: dict[str, str] | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
if website_url is not None:

@@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel, Field

from crewai_tools.tools.rag.rag_tool import RagTool
@@ -27,7 +29,7 @@ class JSONSearchTool(RagTool):
)
args_schema: type[BaseModel] = JSONSearchToolSchema

def __init__(self, json_path: str | None = None, **kwargs):
def __init__(self, json_path: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if json_path is not None:
self.add(json_path)

@@ -20,7 +20,7 @@ class LinkupSearchTool(BaseTool):
description: str = (
"Performs an API call to Linkup to retrieve contextual information."
)
_client: LinkupClient = PrivateAttr() # type: ignore
_client: Any = PrivateAttr()
package_dependencies: list[str] = Field(default_factory=lambda: ["linkup-sdk"])
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
@@ -60,7 +60,7 @@ class LinkupSearchTool(BaseTool):
output_type: Literal[
"searchResults", "sourcedAnswer", "structured"
] = "searchResults",
) -> dict:
) -> dict[str, Any]:
"""Executes a search using the Linkup API.

:param query: The query to search for.

@@ -17,11 +17,7 @@ class LlamaIndexTool(BaseTool):
**kwargs: Any,
) -> Any:
"""Run tool."""
from llama_index.core.tools import ( # type: ignore[import-not-found]
BaseTool as LlamaBaseTool,
)

tool = cast(LlamaBaseTool, self.llama_index_tool)
tool = self.llama_index_tool

if self.result_as_answer:
return tool(*args, **kwargs).content
@@ -36,7 +32,6 @@ class LlamaIndexTool(BaseTool):

if not isinstance(tool, LlamaBaseTool):
raise ValueError(f"Expected a LlamaBaseTool, got {type(tool)}")
tool = cast(LlamaBaseTool, tool)

if tool.metadata.fn_schema is None:
raise ValueError(
@@ -64,9 +59,7 @@ class LlamaIndexTool(BaseTool):
from llama_index.core.query_engine import ( # type: ignore[import-not-found]
BaseQueryEngine,
)
from llama_index.core.tools import ( # type: ignore[import-not-found]
QueryEngineTool,
)
from llama_index.core.tools import QueryEngineTool

if not isinstance(query_engine, BaseQueryEngine):
raise ValueError(f"Expected a BaseQueryEngine, got {type(query_engine)}")

@@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel, Field

from crewai_tools.rag.data_types import DataType
@@ -26,7 +28,7 @@ class MDXSearchTool(RagTool):
)
args_schema: type[BaseModel] = MDXSearchToolSchema

def __init__(self, mdx: str | None = None, **kwargs):
def __init__(self, mdx: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
if mdx is not None:
self.add(mdx)
@@ -34,7 +36,7 @@ class MDXSearchTool(RagTool):
self.args_schema = FixedMDXSearchToolSchema
self._generate_description()

def add(self, mdx: str) -> None:
def add(self, mdx: str) -> None: # type: ignore[override]
super().add(mdx, data_type=DataType.MDX)

def _run( # type: ignore[override]

@@ -2,7 +2,7 @@

import json
import logging
from typing import Any
from typing import Any, cast
from uuid import uuid4

from crewai.tools import BaseTool, EnvVar
@@ -108,7 +108,7 @@ class MergeAgentHandlerTool(BaseTool):
)
raise MergeAgentHandlerToolError(f"API Error: {error_msg}")

return result
return cast(dict[str, Any], result)

except requests.exceptions.RequestException as e:
logger.error(f"Failed to call Agent Handler API: {e!s}")
@@ -219,8 +219,8 @@ class MergeAgentHandlerTool(BaseTool):
required = params.get("required", [])

for field_name, field_schema in properties.items():
field_type = Any # Default type
field_default = ... # Required by default
field_type: Any = Any
field_default: Any = ...

# Map JSON schema types to Python types
json_type = field_schema.get("type", "string")
@@ -256,7 +256,7 @@ class MergeAgentHandlerTool(BaseTool):

# Create the Pydantic model
if fields:
args_schema = create_model(
args_schema = create_model( # type: ignore[call-overload]
f"{tool_name.replace('__', '_').title()}Args",
**fields,
)

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any


if TYPE_CHECKING:
from pymongo.collection import Collection
from pymongo.collection import Collection as MongoCollection


def _vector_search_index_definition(
@@ -34,7 +34,7 @@ def _vector_search_index_definition(


def create_vector_search_index(
collection: Collection,
collection: MongoCollection[Any],
index_name: str,
dimensions: int,
path: str,
@@ -84,7 +84,7 @@ def create_vector_search_index(
)


def _is_index_ready(collection: Collection, index_name: str) -> bool:
def _is_index_ready(collection: MongoCollection[Any], index_name: str) -> bool:
"""Check for the index name in the list of available search indexes to see if the
specified index is of status READY.

@@ -102,7 +102,7 @@ def _is_index_ready(collection: Collection, index_name: str) -> bool:


def _wait_for_predicate(
predicate: Callable, err: str, timeout: float = 120, interval: float = 0.5
predicate: Callable[[], bool], err: str, timeout: float = 120, interval: float = 0.5
) -> None:
"""Generic to block until the predicate returns true.


@@ -31,7 +31,7 @@ class MongoDBVectorSearchConfig(BaseModel):
default=None,
description="List of MQL match expressions comparing an indexed field",
)
post_filter_pipeline: list[dict] | None = Field(
post_filter_pipeline: list[dict[str, Any]] | None = Field(
default=None,
description="Pipeline of MongoDB aggregation stages to filter/process results after $vectorSearch.",
)
@@ -105,7 +105,7 @@ class MongoDBVectorSearchTool(BaseTool):
)
package_dependencies: list[str] = Field(default_factory=lambda: ["mongdb"])

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
if not MONGODB_AVAILABLE:
import click
@@ -120,6 +120,7 @@ class MongoDBVectorSearchTool(BaseTool):
else:
raise ImportError("You are missing the 'mongodb' crewai tool.")

self._openai_client: AzureOpenAI | Client
if "AZURE_OPENAI_ENDPOINT" in os.environ:
self._openai_client = AzureOpenAI()
elif "OPENAI_API_KEY" in os.environ:
@@ -132,7 +133,7 @@ class MongoDBVectorSearchTool(BaseTool):
from pymongo import MongoClient
from pymongo.driver_info import DriverInfo

self._client = MongoClient(
self._client: MongoClient[dict[str, Any]] = MongoClient(
self.connection_string,
driver=DriverInfo(name="CrewAI", version=version("crewai-tools")),
)
@@ -236,7 +237,7 @@ class MongoDBVectorSearchTool(BaseTool):
def _bulk_embed_and_insert_texts(
self,
texts: list[str],
metadatas: list[dict],
metadatas: list[dict[str, Any]],
ids: list[str],
) -> list[str]:
"""Bulk insert single batch of texts, embeddings, and ids."""
@@ -315,16 +316,18 @@ class MongoDBVectorSearchTool(BaseTool):
logger.error(f"Error: {e}")
return ""

def __del__(self):
def __del__(self) -> None:
"""Cleanup clients on deletion."""
try:
if hasattr(self, "_client") and self._client:
self._client.close()
client: Any = getattr(self, "_client", None)
if client is not None:
client.close()
except Exception as e:
logger.error(f"Error: {e}")

try:
if hasattr(self, "_openai_client") and self._openai_client:
self._openai_client.close()
openai_client: Any = getattr(self, "_openai_client", None)
if openai_client is not None:
openai_client.close()
except Exception as e:
logger.error(f"Error: {e}")

@@ -31,11 +31,11 @@ class MultiOnTool(BaseTool):
def __init__(
self,
api_key: str | None = None,
**kwargs,
**kwargs: Any,
):
super().__init__(**kwargs)
try:
from multion.client import MultiOn # type: ignore
from multion.client import MultiOn
except ImportError:
import click

@@ -78,4 +78,4 @@ class MultiOnTool(BaseTool):
)
self.session_id = browse.session_id

return browse.message + "\n\n STATUS: " + browse.status
return str(browse.message) + "\n\n STATUS: " + str(browse.status)

@@ -21,13 +21,13 @@ class MySQLSearchTool(RagTool):
args_schema: type[BaseModel] = MySQLSearchToolSchema
db_uri: str = Field(..., description="Mandatory database URI")

def __init__(self, table_name: str, **kwargs):
def __init__(self, table_name: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.add(table_name, data_type=DataType.MYSQL, metadata={"db_uri": self.db_uri})
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
self._generate_description()

def add(
def add( # type: ignore[override]
self,
table_name: str,
**kwargs: Any,

@@ -27,8 +27,8 @@ class NL2SQLTool(BaseTool):
title="Database URI",
description="The URI of the database to connect to.",
)
tables: list = Field(default_factory=list)
columns: dict = Field(default_factory=dict)
tables: list[dict[str, Any]] = Field(default_factory=list)
columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
args_schema: type[BaseModel] = NL2SQLToolInput

def model_post_init(self, __context: Any) -> None:
@@ -37,8 +37,11 @@ class NL2SQLTool(BaseTool):
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
)

data = {}
tables = self._fetch_available_tables()
data: dict[str, list[dict[str, Any]] | str] = {}
result = self._fetch_available_tables()
if isinstance(result, str):
raise RuntimeError(f"Failed to fetch tables: {result}")
tables: list[dict[str, Any]] = result

for table in tables:
table_columns = self._fetch_all_available_columns(table["table_name"])
@@ -47,17 +50,19 @@ class NL2SQLTool(BaseTool):
self.tables = tables
self.columns = data

def _fetch_available_tables(self):
def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
return self.execute_sql(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
)

def _fetch_all_available_columns(self, table_name: str):
def _fetch_all_available_columns(
self, table_name: str
) -> list[dict[str, Any]] | str:
return self.execute_sql(
f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';" # noqa: S608
)

def _run(self, sql_query: str):
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
try:
data = self.execute_sql(sql_query)
except Exception as exc:
@@ -69,7 +74,7 @@ class NL2SQLTool(BaseTool):

return data

def execute_sql(self, sql_query: str) -> list | str:
def execute_sql(self, sql_query: str) -> list[dict[str, Any]] | str:
if not SQLALCHEMY_AVAILABLE:
raise ImportError(
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"

@@ -4,6 +4,7 @@ This tool provides functionality for extracting text from images using supported
"""

import base64
from typing import Any

from crewai.llm import LLM
from crewai.tools.base_tool import BaseTool
@@ -43,7 +44,7 @@ class OCRTool(BaseTool):
llm: LLM = Field(default_factory=lambda: LLM(model="gpt-4o", temperature=0.7))
args_schema: type[BaseModel] = OCRToolSchema

def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
"""Execute the OCR operation on the provided image.

Args:
@@ -88,7 +89,7 @@ class OCRTool(BaseTool):
return self.llm.call(messages=messages)

@staticmethod
def _encode_image(image_path: str):
def _encode_image(image_path: str) -> str:
"""Encode an image file to base64 format.

Args:

@@ -41,12 +41,12 @@ class OxylabsAmazonProductScraperConfig(BaseModel):
user_agent_type: str | None = Field(None, description="Device type and browser.")
render: str | None = Field(None, description="Enables JavaScript rendering.")
callback_url: str | None = Field(None, description="URL to your callback endpoint.")
context: list | None = Field(
context: list[Any] | None = Field(
None,
description="Additional advanced settings and controls for specialized requirements.",
)
parse: bool | None = Field(None, description="True will return structured data.")
parsing_instructions: dict | None = Field(
parsing_instructions: dict[str, Any] | None = Field(
None, description="Instructions for parsing the results."
)

@@ -71,7 +71,7 @@ class OxylabsAmazonProductScraperTool(BaseTool):
description: str = "Scrape Amazon product pages with Oxylabs Amazon Product Scraper"
args_schema: type[BaseModel] = OxylabsAmazonProductScraperArgs

oxylabs_api: RealtimeClient
oxylabs_api: Any
config: OxylabsAmazonProductScraperConfig
package_dependencies: list[str] = Field(default_factory=lambda: ["oxylabs"])
env_vars: list[EnvVar] = Field(
@@ -93,8 +93,8 @@ class OxylabsAmazonProductScraperTool(BaseTool):
self,
username: str | None = None,
password: str | None = None,
config: OxylabsAmazonProductScraperConfig | dict | None = None,
**kwargs,
config: OxylabsAmazonProductScraperConfig | dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
bits, _ = architecture()
sdk_type = (
@@ -164,4 +164,4 @@ class OxylabsAmazonProductScraperTool(BaseTool):
if isinstance(content, dict):
return json.dumps(content)

return content
return str(content)

@@ -43,12 +43,12 @@ class OxylabsAmazonSearchScraperConfig(BaseModel):
user_agent_type: str | None = Field(None, description="Device type and browser.")
render: str | None = Field(None, description="Enables JavaScript rendering.")
callback_url: str | None = Field(None, description="URL to your callback endpoint.")
context: list | None = Field(
context: list[Any] | None = Field(
None,
description="Additional advanced settings and controls for specialized requirements.",
)
parse: bool | None = Field(None, description="True will return structured data.")
parsing_instructions: dict | None = Field(
parsing_instructions: dict[str, Any] | None = Field(
None, description="Instructions for parsing the results."
)

@@ -73,7 +73,7 @@ class OxylabsAmazonSearchScraperTool(BaseTool):
description: str = "Scrape Amazon search results with Oxylabs Amazon Search Scraper"
args_schema: type[BaseModel] = OxylabsAmazonSearchScraperArgs

oxylabs_api: RealtimeClient
oxylabs_api: Any
config: OxylabsAmazonSearchScraperConfig
package_dependencies: list[str] = Field(default_factory=lambda: ["oxylabs"])
env_vars: list[EnvVar] = Field(
@@ -95,9 +95,9 @@ class OxylabsAmazonSearchScraperTool(BaseTool):
self,
username: str | None = None,
password: str | None = None,
config: OxylabsAmazonSearchScraperConfig | dict | None = None,
**kwargs,
):
config: OxylabsAmazonSearchScraperConfig | dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
bits, _ = architecture()
sdk_type = (
f"oxylabs-crewai-sdk-python/"
@@ -166,4 +166,4 @@ class OxylabsAmazonSearchScraperTool(BaseTool):
if isinstance(content, dict):
return json.dumps(content)

return content
return str(content)

@@ -46,12 +46,12 @@ class OxylabsGoogleSearchScraperConfig(BaseModel):
user_agent_type: str | None = Field(None, description="Device type and browser.")
render: str | None = Field(None, description="Enables JavaScript rendering.")
callback_url: str | None = Field(None, description="URL to your callback endpoint.")
context: list | None = Field(
context: list[Any] | None = Field(
None,
description="Additional advanced settings and controls for specialized requirements.",
)
parse: bool | None = Field(None, description="True will return structured data.")
parsing_instructions: dict | None = Field(
parsing_instructions: dict[str, Any] | None = Field(
None, description="Instructions for parsing the results."
)

@@ -76,7 +76,7 @@ class OxylabsGoogleSearchScraperTool(BaseTool):
description: str = "Scrape Google Search results with Oxylabs Google Search Scraper"
args_schema: type[BaseModel] = OxylabsGoogleSearchScraperArgs

oxylabs_api: RealtimeClient
oxylabs_api: Any
config: OxylabsGoogleSearchScraperConfig
package_dependencies: list[str] = Field(default_factory=lambda: ["oxylabs"])
env_vars: list[EnvVar] = Field(
@@ -98,9 +98,9 @@ class OxylabsGoogleSearchScraperTool(BaseTool):
self,
username: str | None = None,
password: str | None = None,
config: OxylabsGoogleSearchScraperConfig | dict | None = None,
**kwargs,
):
config: OxylabsGoogleSearchScraperConfig | dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
bits, _ = architecture()
sdk_type = (
f"oxylabs-crewai-sdk-python/"
@@ -158,7 +158,7 @@ class OxylabsGoogleSearchScraperTool(BaseTool):
)
return username, password

def _run(self, query: str, **kwargs) -> str:
def _run(self, query: str, **kwargs: Any) -> str:
response = self.oxylabs_api.google.scrape_search(
query,
**self.config.model_dump(exclude_none=True),
@@ -169,4 +169,4 @@ class OxylabsGoogleSearchScraperTool(BaseTool):
if isinstance(content, dict):
return json.dumps(content)

return content
return str(content)

@@ -37,12 +37,12 @@ class OxylabsUniversalScraperConfig(BaseModel):
user_agent_type: str | None = Field(None, description="Device type and browser.")
render: str | None = Field(None, description="Enables JavaScript rendering.")
callback_url: str | None = Field(None, description="URL to your callback endpoint.")
context: list | None = Field(
context: list[Any] | None = Field(
None,
description="Additional advanced settings and controls for specialized requirements.",
)
parse: bool | None = Field(None, description="True will return structured data.")
parsing_instructions: dict | None = Field(
parsing_instructions: dict[str, Any] | None = Field(
None, description="Instructions for parsing the results."
)

@@ -67,7 +67,7 @@ class OxylabsUniversalScraperTool(BaseTool):
description: str = "Scrape any url with Oxylabs Universal Scraper"
args_schema: type[BaseModel] = OxylabsUniversalScraperArgs

oxylabs_api: RealtimeClient
oxylabs_api: Any
config: OxylabsUniversalScraperConfig
package_dependencies: list[str] = Field(default_factory=lambda: ["oxylabs"])
env_vars: list[EnvVar] = Field(
@@ -89,9 +89,9 @@ class OxylabsUniversalScraperTool(BaseTool):
self,
username: str | None = None,
password: str | None = None,
config: OxylabsUniversalScraperConfig | dict | None = None,
**kwargs,
):
config: OxylabsUniversalScraperConfig | dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
bits, _ = architecture()
sdk_type = (
f"oxylabs-crewai-sdk-python/"
@@ -160,4 +160,4 @@ class OxylabsUniversalScraperTool(BaseTool):
if isinstance(content, dict):
return json.dumps(content)

return content
return str(content)

@@ -1,11 +1,12 @@
import random
from typing import Any

from crewai import Agent, Crew, Task
from patronus import ( # type: ignore[import-not-found,import-untyped]
from patronus import ( # type: ignore[import-untyped]
Client,
EvaluationResult,
)
from patronus_local_evaluator_tool import ( # type: ignore[import-not-found,import-untyped]
from patronus_local_evaluator_tool import ( # type: ignore[import-not-found]
PatronusLocalEvaluatorTool,
)

@@ -15,8 +16,8 @@ client = Client()


# Example of an evaluator that returns a random pass/fail result
@client.register_local_evaluator("random_evaluator")
def random_evaluator(**kwargs):
@client.register_local_evaluator("random_evaluator") # type: ignore[untyped-decorator]
def random_evaluator(**kwargs: Any) -> Any:
score = random.random() # noqa: S311
return EvaluationResult(
score_raw=score,

@@ -34,7 +34,7 @@ class PatronusEvalTool(BaseTool):
stacklevel=2,
)

def _init_run(self):
def _init_run(self) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
evaluators_set = json.loads(
requests.get(
"https://api.patronus.ai/v1/evaluators",
@@ -136,8 +136,9 @@ class PatronusEvalTool(BaseTool):
"evaluators": evals,
}

api_key = os.getenv("PATRONUS_API_KEY", "")
headers = {
"X-API-KEY": os.getenv("PATRONUS_API_KEY"),
"X-API-KEY": api_key,
"accept": "application/json",
"content-type": "application/json",
}

@@ -1,16 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import Any

from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field


if TYPE_CHECKING:
from patronus import Client, EvaluationResult # type: ignore[import-untyped]

try:
import patronus # noqa: F401
import patronus # type: ignore[import-untyped] # noqa: F401

PYPATRONUS_AVAILABLE = True
except ImportError:
@@ -37,7 +34,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
name: str = "Patronus Local Evaluator Tool"
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
args_schema: type[BaseModel] = FixedLocalEvaluatorToolSchema
client: Client = None
client: Any = None
evaluator: str
evaluated_model_gold_answer: str

@@ -46,7 +43,7 @@ class PatronusLocalEvaluatorTool(BaseTool):

def __init__(
self,
patronus_client: Client = None,
patronus_client: Any = None,
evaluator: str = "",
evaluated_model_gold_answer: str = "",
**kwargs: Any,
@@ -56,7 +53,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
self.evaluated_model_gold_answer = evaluated_model_gold_answer
self._initialize_patronus(patronus_client)

def _initialize_patronus(self, patronus_client: Client) -> None:
def _initialize_patronus(self, patronus_client: Any) -> None:
try:
if PYPATRONUS_AVAILABLE:
self.client = patronus_client
@@ -94,7 +91,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
evaluated_model_gold_answer = self.evaluated_model_gold_answer
evaluator = self.evaluator

result: EvaluationResult = self.client.evaluate(
result: Any = self.client.evaluate(
evaluator=evaluator,
evaluated_model_input=evaluated_model_input,
evaluated_model_output=evaluated_model_output,

@@ -8,16 +8,16 @@ import requests


class FixedBaseToolSchema(BaseModel):
evaluated_model_input: dict = Field(
evaluated_model_input: dict[str, Any] = Field(
..., description="The agent's task description in simple text"
)
evaluated_model_output: dict = Field(
evaluated_model_output: dict[str, Any] = Field(
..., description="The agent's output of the task"
)
evaluated_model_retrieved_context: dict = Field(
evaluated_model_retrieved_context: dict[str, Any] = Field(
..., description="The agent's context"
)
evaluated_model_gold_answer: dict = Field(
evaluated_model_gold_answer: dict[str, Any] = Field(
..., description="The agent's gold answer only if available"
)
evaluators: list[dict[str, str]] = Field(
@@ -57,8 +57,9 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
evaluated_model_gold_answer = kwargs.get("evaluated_model_gold_answer")
evaluators = self.evaluators

api_key = os.getenv("PATRONUS_API_KEY", "")
headers = {
"X-API-KEY": os.getenv("PATRONUS_API_KEY"),
"X-API-KEY": api_key,
"accept": "application/json",
"content-type": "application/json",
}

@@ -37,7 +37,7 @@ class PDFSearchTool(RagTool):
self._generate_description()
return self

def add(self, pdf: str) -> None:
def add(self, pdf: str) -> None: # type: ignore[override]
super().add(pdf, data_type=DataType.PDF_FILE)

def _run( # type: ignore[override]

@@ -119,7 +119,7 @@ class QdrantVectorSearchTool(BaseTool):
)
)()
)
results = self.client.query_points(
results = self.client.query_points( # type: ignore[union-attr]
collection_name=self.qdrant_config.collection_name,
query=query_vector,
query_filter=search_filter,

@@ -33,9 +33,9 @@ class ScrapeElementFromWebsiteTool(BaseTool):
description: str = "A tool that can be used to read a website content."
args_schema: type[BaseModel] = ScrapeElementFromWebsiteToolSchema
website_url: str | None = None
cookies: dict | None = None
cookies: dict[str, str] | None = None
css_element: str | None = None
headers: dict | None = Field(
headers: dict[str, str] | None = Field(
default_factory=lambda: {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
@@ -50,9 +50,9 @@ class ScrapeElementFromWebsiteTool(BaseTool):
def __init__(
self,
website_url: str | None = None,
cookies: dict | None = None,
cookies: dict[str, str] | None = None,
css_element: str | None = None,
**kwargs,
**kwargs: Any,
):
super().__init__(**kwargs)
if website_url is not None:
@@ -64,7 +64,7 @@ class ScrapeElementFromWebsiteTool(BaseTool):
self.args_schema = FixedScrapeElementFromWebsiteToolSchema
self._generate_description()
if cookies is not None:
self.cookies = {cookies["name"]: os.getenv(cookies["value"])}
self.cookies = {cookies["name"]: os.getenv(cookies["value"]) or ""}

def _run(
self,

@@ -31,8 +31,8 @@ class ScrapeWebsiteTool(BaseTool):
description: str = "A tool that can be used to read a website content."
args_schema: type[BaseModel] = ScrapeWebsiteToolSchema
website_url: str | None = None
cookies: dict | None = None
headers: dict | None = Field(
cookies: dict[str, str] | None = None
headers: dict[str, str] | None = Field(
default_factory=lambda: {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
@@ -46,8 +46,8 @@ class ScrapeWebsiteTool(BaseTool):
def __init__(
self,
website_url: str | None = None,
cookies: dict | None = None,
**kwargs,
cookies: dict[str, str] | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
if not BEAUTIFULSOUP_AVAILABLE:
@@ -63,7 +63,7 @@ class ScrapeWebsiteTool(BaseTool):
self.args_schema = FixedScrapeWebsiteToolSchema
self._generate_description()
if cookies is not None:
self.cookies = {cookies["name"]: os.getenv(cookies["value"])}
self.cookies = {cookies["name"]: os.getenv(cookies["value"]) or ""}

def _run(
self,

@@ -1,18 +1,13 @@
 from __future__ import annotations

 import os
-from typing import TYPE_CHECKING, Any
+from typing import Any
 from urllib.parse import urlparse

 from crewai.tools import BaseTool, EnvVar
 from pydantic import BaseModel, ConfigDict, Field, field_validator


-# Type checking import
-if TYPE_CHECKING:
-    from scrapegraph_py import Client  # type: ignore[import-untyped]
-
-
 class ScrapegraphError(Exception):
     """Base exception for Scrapegraph-related errors."""
@@ -36,7 +31,7 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema):

     @field_validator("website_url")
     @classmethod
-    def validate_url(cls, v):
+    def validate_url(cls, v: str) -> str:
         """Validate URL format."""
         try:
             result = urlparse(v)
@@ -69,7 +64,7 @@ class ScrapegraphScrapeTool(BaseTool):
     user_prompt: str | None = None
     api_key: str | None = None
     enable_logging: bool = False
-    _client: Client | None = None
+    _client: Any = None
     package_dependencies: list[str] = Field(default_factory=lambda: ["scrapegraph-py"])
     env_vars: list[EnvVar] = Field(
         default_factory=lambda: [
@@ -87,12 +82,12 @@ class ScrapegraphScrapeTool(BaseTool):
         user_prompt: str | None = None,
         api_key: str | None = None,
         enable_logging: bool = False,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         super().__init__(**kwargs)
         try:
-            from scrapegraph_py import Client  # type: ignore[import-not-found]
-            from scrapegraph_py.logger import (  # type: ignore[import-not-found]
+            from scrapegraph_py import Client
+            from scrapegraph_py.logger import (
                 sgai_logger,
             )

@@ -146,7 +141,7 @@ class ScrapegraphScrapeTool(BaseTool):
                 "Invalid URL format. URL must include scheme (http/https) and domain"
             ) from e

-    def _handle_api_response(self, response: dict) -> str:
+    def _handle_api_response(self, response: dict[str, Any]) -> str:
         """Handle and validate API response."""
         if not response:
             raise RuntimeError("Empty response from Scrapegraph API")
@@ -160,7 +155,7 @@ class ScrapegraphScrapeTool(BaseTool):
         if "result" not in response:
             raise RuntimeError("Invalid response format from Scrapegraph API")

-        return response["result"]
+        return str(response["result"])

     def _run(
         self,
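The `str(response["result"])` change is the usual fix for `warn_return_any` under strict mode: indexing a `dict[str, Any]` yields `Any`, and returning an `Any` expression from a function declared `-> str` is flagged. A minimal sketch, assuming a JSON-like payload:

```python
from typing import Any


def handle_api_response(response: dict[str, Any]) -> str:
    if "result" not in response:
        raise RuntimeError("Invalid response format")
    # response["result"] is Any; str() pins it to the declared return type
    return str(response["result"])
```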
@@ -69,15 +69,16 @@ class ScrapflyScrapeWebsiteTool(BaseTool):
         scrape_format: str = "markdown",
         scrape_config: dict[str, Any] | None = None,
         ignore_scrape_failures: bool | None = None,
-    ):
-        from scrapfly import ScrapeApiResponse, ScrapeConfig
+    ) -> str | None:
+        from scrapfly import ScrapeConfig

         scrape_config = scrape_config if scrape_config is not None else {}
         try:
-            response: ScrapeApiResponse = self.scrapfly.scrape(  # type: ignore[union-attr]
+            response = self.scrapfly.scrape(  # type: ignore[union-attr]
                 ScrapeConfig(url, format=scrape_format, **scrape_config)
             )
-            return response.scrape_result["content"]
+            result: str = response.scrape_result["content"]
+            return result
         except Exception as e:
             if ignore_scrape_failures:
                 logger.error(f"Error fetching data from {url}, exception: {e}")
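The Scrapfly hunk uses the companion idiom: bind the `Any`-typed expression to an annotated local (`result: str = ...`) and return the local, so the checked conversion happens at one assignment site. A sketch under the assumption of an untyped scrape result:

```python
from typing import Any


def extract_content(scrape_result: dict[str, Any]) -> str | None:
    # The annotation asserts the type at the assignment; the declared
    # return type str | None then checks without a cast or ignore.
    result: str = scrape_result["content"]
    return result
```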
@@ -25,7 +25,7 @@ class SeleniumScrapingToolSchema(FixedSeleniumScrapingToolSchema):

     @field_validator("website_url")
     @classmethod
-    def validate_website_url(cls, v):
+    def validate_website_url(cls, v: str) -> str:
         if not v:
             raise ValueError("Website URL cannot be empty")

@@ -54,7 +54,7 @@ class SeleniumScrapingTool(BaseTool):
     args_schema: type[BaseModel] = SeleniumScrapingToolSchema
     website_url: str | None = None
     driver: Any | None = None
-    cookie: dict | None = None
+    cookie: dict[str, Any] | None = None
     wait_time: int | None = 3
     css_element: str | None = None
     return_html: bool | None = False
@@ -66,17 +66,17 @@
     def __init__(
         self,
         website_url: str | None = None,
-        cookie: dict | None = None,
+        cookie: dict[str, Any] | None = None,
         css_element: str | None = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         super().__init__(**kwargs)
         try:
-            from selenium import webdriver  # type: ignore[import-not-found]
-            from selenium.webdriver.chrome.options import (  # type: ignore[import-not-found]
+            from selenium import webdriver
+            from selenium.webdriver.chrome.options import (
                 Options,
             )
-            from selenium.webdriver.common.by import (  # type: ignore[import-not-found]
+            from selenium.webdriver.common.by import (
                 By,
             )
         except ImportError:
@@ -91,11 +91,11 @@ class SeleniumScrapingTool(BaseTool):
                 ["uv", "pip", "install", "selenium", "webdriver-manager"],  # noqa: S607
                 check=True,
             )
-            from selenium import webdriver  # type: ignore[import-not-found]
-            from selenium.webdriver.chrome.options import (  # type: ignore[import-not-found]
+            from selenium import webdriver
+            from selenium.webdriver.chrome.options import (
                 Options,
             )
-            from selenium.webdriver.common.by import (  # type: ignore[import-not-found]
+            from selenium.webdriver.common.by import (
                 By,
             )
         else:
@@ -146,8 +146,10 @@ class SeleniumScrapingTool(BaseTool):
         if self.driver is not None:
             self.driver.close()

-    def _get_content(self, css_element, return_html):
-        content = []
+    def _get_content(
+        self, css_element: str | None, return_html: bool | None
+    ) -> list[str]:
+        content: list[str] = []

         if self._is_css_element_empty(css_element):
             content.append(self._get_body_content(return_html))
@@ -156,20 +158,26 @@ class SeleniumScrapingTool(BaseTool):

         return content

-    def _is_css_element_empty(self, css_element):
+    def _is_css_element_empty(self, css_element: str | None) -> bool:
         return css_element is None or css_element.strip() == ""

-    def _get_body_content(self, return_html):
+    def _get_body_content(self, return_html: bool | None) -> str:
+        if self.driver is None or self._by is None:
+            raise RuntimeError("Driver not initialized. Call _run first.")
         body_element = self.driver.find_element(self._by.TAG_NAME, "body")

-        return (
+        return str(
             body_element.get_attribute("outerHTML")
             if return_html
             else body_element.text
         )

-    def _get_elements_content(self, css_element, return_html):
-        elements_content = []
+    def _get_elements_content(
+        self, css_element: str | None, return_html: bool | None
+    ) -> list[str]:
+        if self.driver is None or self._by is None:
+            raise RuntimeError("Driver not initialized. Call _run first.")
+        elements_content: list[str] = []

         for element in self.driver.find_elements(self._by.CSS_SELECTOR, css_element):
             elements_content.append(  # noqa: PERF401
@@ -178,7 +186,9 @@ class SeleniumScrapingTool(BaseTool):

         return elements_content

-    def _make_request(self, url, cookie, wait_time):
+    def _make_request(
+        self, url: str | None, cookie: dict[str, Any] | None, wait_time: int | None
+    ) -> None:
         if not url:
             raise ValueError("URL cannot be empty")

@@ -186,13 +196,17 @@ class SeleniumScrapingTool(BaseTool):
         if not re.match(r"^https?://", url):
             raise ValueError("URL must start with http:// or https://")

+        if self.driver is None:
+            raise RuntimeError("Driver not initialized. Call _run first.")
+        sleep_time = wait_time or 0
         self.driver.get(url)
-        time.sleep(wait_time)
+        time.sleep(sleep_time)
         if cookie:
             self.driver.add_cookie(cookie)
-            time.sleep(wait_time)
+            time.sleep(sleep_time)
             self.driver.get(url)
-            time.sleep(wait_time)
+            time.sleep(sleep_time)

-    def close(self):
-        self.driver.close()
+    def close(self) -> None:
+        if self.driver is not None:
+            self.driver.close()
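Two Optional-safety patterns recur in the Selenium hunks: raise on a `None` attribute before use (mypy narrows it for the rest of the method), and replace `time.sleep(wait_time)` on an `int | None` with a defaulted local. A minimal sketch with a stand-in driver attribute:

```python
import time
from typing import Any


class Scraper:
    def __init__(self) -> None:
        self.driver: Any | None = None

    def make_request(self, url: str, wait_time: int | None) -> None:
        if self.driver is None:
            # After this raise, mypy narrows self.driver to non-None below
            raise RuntimeError("Driver not initialized.")
        sleep_time = wait_time or 0  # int | None -> int for time.sleep
        self.driver.get(url)
        time.sleep(sleep_time)
```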
@@ -22,11 +22,11 @@ class SerpApiBaseTool(BaseTool):

     client: Any | None = None

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)

         try:
-            from serpapi import Client  # type: ignore
+            from serpapi import Client
         except ImportError:
             import click

@@ -48,7 +48,9 @@ class SerpApiBaseTool(BaseTool):
             )
         self.client = Client(api_key=api_key)

-    def _omit_fields(self, data: dict | list, omit_patterns: list[str]) -> None:
+    def _omit_fields(
+        self, data: dict[str, Any] | list[Any], omit_patterns: list[str]
+    ) -> None:
         if isinstance(data, dict):
             for field in list(data.keys()):
                 if any(re.compile(p).match(field) for p in omit_patterns):
@@ -160,7 +160,7 @@ class SerperDevTool(BaseTool):
         processed_results: list[OrganicResult] = []
         for result in organic_results[: self.n_results]:
             try:
-                result_data: OrganicResult = {  # type: ignore[typeddict-item]
+                result_data: OrganicResult = {
                     "title": result["title"],
                     "link": result["link"],
                     "snippet": result.get("snippet", ""),
@@ -168,7 +168,7 @@ class SerperDevTool(BaseTool):
                 }

                 if "sitelinks" in result:
-                    result_data["sitelinks"] = [  # type: ignore[typeddict-unknown-key]
+                    result_data["sitelinks"] = [
                         {
                             "title": sitelink.get("title", ""),
                             "link": sitelink.get("link", ""),
@@ -180,7 +180,7 @@ class SerperDevTool(BaseTool):
             except KeyError:  # noqa: PERF203
                 logger.warning(f"Skipping malformed organic result: {result}")
                 continue
-        return processed_results  # type: ignore[return-value]
+        return processed_results

     def _process_people_also_ask(
         self, paa_results: list[dict[str, Any]]
@@ -189,7 +189,7 @@ class SerperDevTool(BaseTool):
         processed_results: list[PeopleAlsoAskResult] = []
         for result in paa_results[: self.n_results]:
             try:
-                result_data: PeopleAlsoAskResult = {  # type: ignore[typeddict-item]
+                result_data: PeopleAlsoAskResult = {
                     "question": result["question"],
                     "snippet": result.get("snippet", ""),
                     "title": result.get("title", ""),
@@ -199,7 +199,7 @@ class SerperDevTool(BaseTool):
             except KeyError:  # noqa: PERF203
                 logger.warning(f"Skipping malformed PAA result: {result}")
                 continue
-        return processed_results  # type: ignore[return-value]
+        return processed_results

     def _process_related_searches(
         self, related_results: list[dict[str, Any]]
@@ -208,11 +208,11 @@ class SerperDevTool(BaseTool):
         processed_results: list[RelatedSearchResult] = []
         for result in related_results[: self.n_results]:
             try:
-                processed_results.append({"query": result["query"]})  # type: ignore[typeddict-item]
+                processed_results.append({"query": result["query"]})
             except KeyError:  # noqa: PERF203
                 logger.warning(f"Skipping malformed related search result: {result}")
                 continue
-        return processed_results  # type: ignore[return-value]
+        return processed_results

     def _process_news_results(
         self, news_results: list[dict[str, Any]]
@@ -221,7 +221,7 @@ class SerperDevTool(BaseTool):
         processed_results: list[NewsResult] = []
         for result in news_results[: self.n_results]:
             try:
-                result_data: NewsResult = {  # type: ignore[typeddict-item]
+                result_data: NewsResult = {
                     "title": result["title"],
                     "link": result["link"],
                     "snippet": result.get("snippet", ""),
@@ -233,7 +233,7 @@ class SerperDevTool(BaseTool):
             except KeyError:  # noqa: PERF203
                 logger.warning(f"Skipping malformed news result: {result}")
                 continue
-        return processed_results  # type: ignore[return-value]
+        return processed_results

     def _make_api_request(self, search_query: str, search_type: str) -> dict[str, Any]:
         """Make API request to Serper."""
@@ -262,7 +262,7 @@ class SerperDevTool(BaseTool):
             if not results:
                 logger.error("Empty response from Serper API")
                 raise ValueError("Empty response from Serper API")
-            return results
+            return dict(results)
         except requests.exceptions.RequestException as e:
             error_msg = f"Error making request to Serper API: {e}"
             if response is not None and hasattr(response, "content"):
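The Serper hunks remove `# type: ignore[typeddict-item]` rather than keeping it; that works once each TypedDict literal supplies only declared keys, with optional ones marked via `total=False`. A hedged sketch of the shape involved (these class definitions are assumptions; only the field names mirror the diff):

```python
from typing import Any, TypedDict


class Sitelink(TypedDict, total=False):
    title: str
    link: str


class OrganicResult(TypedDict, total=False):
    title: str
    link: str
    snippet: str
    sitelinks: list[Sitelink]


def process(result: dict[str, Any]) -> OrganicResult:
    # All keys written here are declared on the TypedDict, so no
    # typeddict-item / typeddict-unknown-key ignores are needed.
    data: OrganicResult = {
        "title": result["title"],
        "link": result["link"],
        "snippet": result.get("snippet", ""),
    }
    if "sitelinks" in result:
        data["sitelinks"] = [
            {"title": s.get("title", ""), "link": s.get("link", "")}
            for s in result["sitelinks"]
        ]
    return data
```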
@@ -53,7 +53,7 @@ class SerperScrapeWebsiteTool(BaseTool):
         payload = json.dumps({"url": url, "includeMarkdown": include_markdown})

         # Set headers
-        headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
+        headers = {"X-API-KEY": api_key or "", "Content-Type": "application/json"}

         # Make the API request
         response = requests.post(
@@ -69,7 +69,7 @@ class SerperScrapeWebsiteTool(BaseTool):

             # Extract the scraped content
             if "text" in result:
-                return result["text"]
+                return str(result["text"])
             return f"Successfully scraped {url}, but no text content found in response: {response.text}"
         return (
             f"Error scraping {url}: HTTP {response.status_code} - {response.text}"
@@ -1,4 +1,5 @@
 import os
+from typing import Any
 from urllib.parse import urlencode

 from crewai.tools import EnvVar
@@ -29,7 +30,7 @@ class SerplyJobSearchTool(RagTool):
        proxy_location: (str): Where to get jobs, specifically for a specific country results.
             - Currently only supports US
     """
-    headers: dict | None = Field(default_factory=dict)
+    headers: dict[str, str] | None = Field(default_factory=dict)
     env_vars: list[EnvVar] = Field(
         default_factory=lambda: [
             EnvVar(
@@ -40,12 +41,12 @@ class SerplyJobSearchTool(RagTool):
         ]
     )

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         self.headers = {
             "X-API-KEY": os.environ["SERPLY_API_KEY"],
             "User-Agent": "crew-tools",
-            "X-Proxy-Location": self.proxy_location,
+            "X-Proxy-Location": self.proxy_location or "US",
         }

     def _run(  # type: ignore[override]
@@ -21,7 +21,7 @@ class SerplyNewsSearchTool(BaseTool):
     args_schema: type[BaseModel] = SerplyNewsSearchToolSchema
     search_url: str = "https://api.serply.io/v1/news/"
     proxy_location: str | None = "US"
-    headers: dict | None = Field(default_factory=dict)
+    headers: dict[str, str] | None = Field(default_factory=dict)
     limit: int | None = 10
     env_vars: list[EnvVar] = Field(
         default_factory=lambda: [
@@ -34,8 +34,8 @@ class SerplyNewsSearchTool(BaseTool):
     )

     def __init__(
-        self, limit: int | None = 10, proxy_location: str | None = "US", **kwargs
-    ):
+        self, limit: int | None = 10, proxy_location: str | None = "US", **kwargs: Any
+    ) -> None:
         """param: limit (int): The maximum number of results to return [10-100, defaults to 10]
         proxy_location: (str): Where to get news, specifically for a specific country results.
             ['US', 'CA', 'IE', 'GB', 'FR', 'DE', 'SE', 'IN', 'JP', 'KR', 'SG', 'AU', 'BR'] (defaults to US).
@@ -46,7 +46,7 @@ class SerplyNewsSearchTool(BaseTool):
         self.headers = {
             "X-API-KEY": os.environ["SERPLY_API_KEY"],
             "User-Agent": "crew-tools",
-            "X-Proxy-Location": proxy_location,
+            "X-Proxy-Location": proxy_location or "US",
         }

     def _run(
@@ -25,7 +25,7 @@ class SerplyScholarSearchTool(BaseTool):
     search_url: str = "https://api.serply.io/v1/scholar/"
     hl: str | None = "us"
     proxy_location: str | None = "US"
-    headers: dict | None = Field(default_factory=dict)
+    headers: dict[str, str] | None = Field(default_factory=dict)
     env_vars: list[EnvVar] = Field(
         default_factory=lambda: [
             EnvVar(
@@ -36,7 +36,9 @@ class SerplyScholarSearchTool(BaseTool):
         ]
     )

-    def __init__(self, hl: str = "us", proxy_location: str | None = "US", **kwargs):
+    def __init__(
+        self, hl: str = "us", proxy_location: str | None = "US", **kwargs: Any
+    ) -> None:
         """param: hl (str): host Language code to display results in
         (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
         proxy_location: (str): Specify the proxy location for the search, specifically for a specific country results.
@@ -48,7 +50,7 @@ class SerplyScholarSearchTool(BaseTool):
         self.headers = {
             "X-API-KEY": os.environ["SERPLY_API_KEY"],
             "User-Agent": "crew-tools",
-            "X-Proxy-Location": proxy_location,
+            "X-Proxy-Location": proxy_location or "US",
         }

     def _run(
@@ -24,8 +24,8 @@ class SerplyWebSearchTool(BaseTool):
     limit: int | None = 10
     device_type: str | None = "desktop"
     proxy_location: str | None = "US"
-    query_payload: dict | None = Field(default_factory=dict)
-    headers: dict | None = Field(default_factory=dict)
+    query_payload: dict[str, Any] | None = Field(default_factory=dict)
+    headers: dict[str, str] | None = Field(default_factory=dict)
     env_vars: list[EnvVar] = Field(
         default_factory=lambda: [
             EnvVar(
@@ -42,8 +42,8 @@ class SerplyWebSearchTool(BaseTool):
         limit: int = 10,
         device_type: str = "desktop",
         proxy_location: str = "US",
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         """param: query (str): The query to search for
         param: hl (str): host Language code to display results in
         (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field


 try:
-    from singlestoredb import connect
+    from singlestoredb import connect  # type: ignore[attr-defined]
     from sqlalchemy.pool import QueuePool

     SINGLSTORE_AVAILABLE = True
@@ -117,7 +117,7 @@ class SingleStoreSearchTool(BaseTool):
         ]
     )

-    connection_args: dict = Field(default_factory=dict)
+    connection_args: dict[str, Any] = Field(default_factory=dict)
     connection_pool: Any | None = None

     def __init__(
@@ -169,8 +169,8 @@ class SingleStoreSearchTool(BaseTool):
         pool_size: int | None = 5,
         max_overflow: int | None = 10,
         timeout: float | None = 30,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         """Initialize the SingleStore search tool.

         Args:
@@ -274,7 +274,7 @@ class SingleStoreSearchTool(BaseTool):

         # Initialize connection pool for efficient connection management
         self.connection_pool = QueuePool(
-            creator=self._create_connection,  # type: ignore[arg-type]
+            creator=self._create_connection,
             pool_size=pool_size or 5,
             max_overflow=max_overflow or 10,
             timeout=timeout or 30.0,
@@ -12,10 +12,10 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr

 if TYPE_CHECKING:
     # Import types for type checking only
-    from snowflake.connector.connection import (  # type: ignore[import-not-found]
+    from snowflake.connector.connection import (
         SnowflakeConnection,
     )
-    from snowflake.connector.errors import (  # type: ignore[import-not-found]
+    from snowflake.connector.errors import (
         DatabaseError,
         OperationalError,
     )
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
 try:
     from cryptography.hazmat.backends import default_backend
     from cryptography.hazmat.primitives import serialization
-    import snowflake.connector  # type: ignore[import-not-found]
+    import snowflake.connector

     SNOWFLAKE_AVAILABLE = True
 except ImportError:
@@ -60,7 +60,7 @@ class SnowflakeConfig(BaseModel):
     def has_auth(self) -> bool:
         return bool(self.password or self.private_key_path)

-    def model_post_init(self, *args, **kwargs):
+    def model_post_init(self, *args: Any, **kwargs: Any) -> None:
         if not self.has_auth:
             raise ValueError("Either password or private_key_path must be provided")

@@ -115,7 +115,7 @@ class SnowflakeSearchTool(BaseTool):
         ]
     )

-    def __init__(self, **data):
+    def __init__(self, **data: Any) -> None:
         """Initialize SnowflakeSearchTool."""
         super().__init__(**data)
         self._initialize_snowflake()
@@ -268,7 +268,7 @@ class SnowflakeSearchTool(BaseTool):
             logger.error(f"Error executing query: {e!s}")
             raise

-    def __del__(self):
+    def __del__(self) -> None:
         """Cleanup connections on deletion."""
         try:
             if self._connection_pool:
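The `model_post_init(self, *args: Any, **kwargs: Any) -> None` signature fixes `no-untyped-def` on the pydantic hook. Pydantic v2 invokes the hook with a single context argument, which the variadic form absorbs. A minimal sketch (the model here is illustrative):

```python
from typing import Any

from pydantic import BaseModel


class AuthConfig(BaseModel):
    password: str | None = None
    private_key_path: str | None = None

    # Fully annotated override; pydantic calls this after validation with
    # one context argument, absorbed by *args.
    def model_post_init(self, *args: Any, **kwargs: Any) -> None:
        if not (self.password or self.private_key_path):
            raise ValueError("Either password or private_key_path must be provided")
```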
@@ -72,8 +72,8 @@ class SpiderTool(BaseTool):
         website_url: str | None = None,
         custom_params: dict[str, Any] | None = None,
         log_failures: bool = True,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         """Initialize SpiderTool for web scraping and crawling.

         Args:
@@ -96,7 +96,7 @@ class SpiderTool(BaseTool):
         self.custom_params = custom_params

         try:
-            from spider import Spider  # type: ignore
+            from spider import Spider

         except ImportError:
             import click
@@ -191,7 +191,8 @@ class SpiderTool(BaseTool):
             action = (
                 self.spider.scrape_url if mode == "scrape" else self.spider.crawl_url
             )
-            return action(url=url, params=params)
+            result: str | None = action(url=url, params=params)
+            return result

         except ValueError as ve:
             if self.log_failures:
@@ -20,7 +20,7 @@ import os
 from crewai import Agent, Crew, Process, Task
 from crewai.utilities.printer import Printer
 from dotenv import load_dotenv
-from stagehand.schemas import AvailableModel  # type: ignore[import-untyped]
+from stagehand.schemas import AvailableModel  # type: ignore[import-not-found]

 from crewai_tools import StagehandTool
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import contextvars
 import json
@@ -13,13 +15,13 @@ from pydantic import BaseModel, Field
 _HAS_STAGEHAND = False

 try:
-    from stagehand import (  # type: ignore[import-untyped]
+    from stagehand import (  # type: ignore[attr-defined]
         Stagehand,
         StagehandConfig,
         StagehandPage,
         configure_logging,
     )
-    from stagehand.schemas import (  # type: ignore[import-untyped]
+    from stagehand.schemas import (  # type: ignore[import-not-found]
         ActOptions,
         AvailableModel,
         ExtractOptions,
@@ -28,8 +30,7 @@ try:

     _HAS_STAGEHAND = True
 except ImportError:
-    # Define type stubs for when stagehand is not installed
-    Stagehand = Any
+    Stagehand = Any  # type: ignore[assignment, misc]
     StagehandPage = Any
     StagehandConfig = Any
     ActOptions = Any
@@ -37,7 +38,11 @@ except ImportError:
     ObserveOptions = Any

     # Mock configure_logging function
-    def configure_logging(level=None, remove_logger_name=None, quiet_dependencies=None):
+    def configure_logging(
+        level: str | None = None,
+        remove_logger_name: bool | None = None,
+        quiet_dependencies: bool | None = None,
+    ) -> None:
         pass

     # Define only what's needed for class defaults
@@ -57,7 +62,7 @@ class StagehandResult(BaseModel):
     success: bool = Field(
         ..., description="Whether the operation completed successfully"
     )
-    data: str | dict | list = Field(
+    data: str | dict[str, Any] | list[Any] = Field(
         ..., description="The result data from the operation"
     )
     error: str | None = Field(
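The Stagehand import block shows the strict-mode idiom for optional dependencies without stubs: attempt the real import, and on `ImportError` rebind the names to `Any`, silencing only the rebinding itself. A minimal sketch with a hypothetical package name:

```python
from typing import Any

_HAS_SOMEPKG = False

try:
    from somepkg import RealClient  # hypothetical optional dependency

    _HAS_SOMEPKG = True
except ImportError:
    # Rebinding a class name to Any needs these ignores under --strict,
    # but annotations and call sites downstream still type-check.
    RealClient = Any  # type: ignore[assignment, misc]
```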
@@ -160,7 +165,7 @@ class StagehandTool(BaseTool):
     api_key: str | None = None
     project_id: str | None = None
     model_api_key: str | None = None
-    model_name: AvailableModel | None = AvailableModel.CLAUDE_3_7_SONNET_LATEST
+    model_name: Any = AvailableModel.CLAUDE_3_7_SONNET_LATEST
     server_url: str | None = "https://api.stagehand.browserbase.com/v1"
     headless: bool = False
     dom_settle_timeout_ms: int = 3000
@@ -173,8 +178,8 @@ class StagehandTool(BaseTool):
     use_simplified_dom: bool = True

     # Instance variables
-    _stagehand: Stagehand | None = None
-    _page: StagehandPage | None = None
+    _stagehand: Any = None
+    _page: Any = None
     _session_id: str | None = None
     _testing: bool = False

@@ -192,8 +197,8 @@ class StagehandTool(BaseTool):
         wait_for_captcha_solves: bool | None = None,
         verbose: int | None = None,
         _testing: bool = False,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         # Set testing flag early so that other init logic can rely on it
         self._testing = _testing
         super().__init__(**kwargs)
@@ -235,7 +240,7 @@ class StagehandTool(BaseTool):

         self._check_required_credentials()

-    def _check_required_credentials(self):
+    def _check_required_credentials(self) -> None:
         """Validate that required credentials are present."""
         if not self._testing and not _HAS_STAGEHAND:
             raise ImportError(
@@ -249,14 +254,14 @@ class StagehandTool(BaseTool):
                 "project_id is required (or set BROWSERBASE_PROJECT_ID in env)."
             )

-    def __del__(self):
+    def __del__(self) -> None:
         """Ensure cleanup on deletion."""
         try:
             self.close()
         except Exception:  # noqa: S110
             pass

-    def _get_model_api_key(self):
+    def _get_model_api_key(self) -> str | None:
         """Get the appropriate API key based on the model being used."""
         # Check model type and get appropriate key
         model_str = str(self.model_name)
@@ -273,29 +278,29 @@ class StagehandTool(BaseTool):
             or os.getenv("ANTHROPIC_API_KEY")
         )

-    async def _setup_stagehand(self, session_id: str | None = None):
+    async def _setup_stagehand(self, session_id: str | None = None) -> tuple[Any, Any]:
         """Initialize Stagehand if not already set up."""
         # If we're in testing mode, return mock objects
         if self._testing:
             if not self._stagehand:
                 # Create mock objects for testing
                 class MockPage:
-                    async def act(self, options):
+                    async def act(self, options: Any) -> Any:
                         mock_result = type("MockResult", (), {})()
                         mock_result.model_dump = lambda: {
                             "message": "Action completed successfully"
                         }
                         return mock_result

-                    async def goto(self, url):
+                    async def goto(self, url: str) -> None:
                         return None

-                    async def extract(self, options):
+                    async def extract(self, options: Any) -> Any:
                         mock_result = type("MockResult", (), {})()
                         mock_result.model_dump = lambda: {"data": "Extracted content"}
                         return mock_result

-                    async def observe(self, options):
+                    async def observe(self, options: Any) -> list[Any]:
                         mock_result1 = type(
                             "MockResult",
                             (),
@@ -303,18 +308,18 @@ class StagehandTool(BaseTool):
                         )()
                         return [mock_result1]

-                    async def wait_for_load_state(self, state):
+                    async def wait_for_load_state(self, state: str) -> None:
                         return None

                 class MockStagehand:
-                    def __init__(self):
+                    def __init__(self) -> None:
                         self.page = MockPage()
                         self.session_id = "test-session-id"

-                    async def init(self):
+                    async def init(self) -> None:
                         return None

-                    async def close(self):
+                    async def close(self) -> None:
                         return None

                 self._stagehand = MockStagehand()
@@ -352,7 +357,7 @@ class StagehandTool(BaseTool):
             )

         # Initialize Stagehand with config
-        self._stagehand = Stagehand(config=config)
+        self._stagehand = Stagehand(config=config)  # type: ignore[call-arg]

         # Initialize the Stagehand instance
         await self._stagehand.init()
@@ -404,7 +409,7 @@ class StagehandTool(BaseTool):
         instruction: str | None = None,
         url: str | None = None,
         command_type: str = "act",
-    ):
+    ) -> StagehandResult:
         """Override _async_run with improved atomic action handling."""
         # Handle missing instruction based on command type
         if not instruction:
@@ -419,7 +424,7 @@ class StagehandTool(BaseTool):

         # For testing mode, return mock result directly without calling parent
         if self._testing:
-            mock_data = {
+            mock_data: dict[str, str] = {
                 "message": f"Mock {command_type} completed successfully",
                 "instruction": instruction,
             }
@@ -436,7 +441,7 @@ class StagehandTool(BaseTool):

         # Get the API key to pass to model operations
         model_api_key = self._get_model_api_key()
-        model_client_options = {"apiKey": model_api_key}
+        model_client_options: dict[str, Any] = {"apiKey": model_api_key}

         # Always navigate first if URL is provided and we're doing actions
         if url and command_type.lower() == "act":
@@ -452,7 +457,7 @@ class StagehandTool(BaseTool):
             steps = self._extract_steps(instruction)
             self._logger.info(f"Extracted {len(steps)} steps: {steps}")

-            results = []
+            results: list[dict[str, Any]] = []
             for i, step in enumerate(steps):
                 self._logger.info(f"Executing step {i + 1}/{len(steps)}: {step}")

@@ -559,16 +564,16 @@ class StagehandTool(BaseTool):
                     modelClientOptions=model_client_options,  # Add API key here
                 )

-                results = await page.observe(observe_options)
+                observe_results = await page.observe(observe_options)

                 # Format the observation results
-                formatted_results = []
-                for i, result in enumerate(results):
+                formatted_results: list[dict[str, Any]] = []
+                for i, obs_result in enumerate(observe_results):
                     formatted_results.append(
                         {
                             "index": i + 1,
-                            "description": result.description,
-                            "method": result.method,
+                            "description": obs_result.description,
+                            "method": obs_result.method,
                         }
                     )

@@ -586,7 +591,12 @@ class StagehandTool(BaseTool):
             self._logger.error(f"Operation failed: {error_msg}")
             return self._format_result(False, {}, error_msg)

-    def _format_result(self, success, data, error=None):
+    def _format_result(
+        self,
+        success: bool,
+        data: str | dict[str, Any] | list[Any],
+        error: str | None = None,
+    ) -> StagehandResult:
         """Helper to format results consistently."""
         return StagehandResult(success=success, data=data, error=error)
@@ -653,10 +663,14 @@ class StagehandTool(BaseTool):
                         f"Step {i + 1}: {step.get('message', 'Completed')}"
                     )
                 return "\n".join(step_messages)
-            return f"Action result: {result.data.get('message', 'Completed')}"
+            if isinstance(result.data, dict):
+                return (
+                    f"Action result: {result.data.get('message', 'Completed')}"
+                )
+            return f"Action result: {result.data}"
         if command_type.lower() == "extract":
             return f"Extracted data: {json.dumps(result.data, indent=2)}"
-        if command_type.lower() == "observe":
+        if command_type.lower() == "observe" and isinstance(result.data, list):
             formatted_results = []
             for element in result.data:
                 formatted_results.append(
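Because `StagehandResult.data` is now the union `str | dict[str, Any] | list[Any]`, calling `.get()` or iterating on it directly no longer type-checks; the hunk above adds `isinstance` guards so mypy narrows the union first. A minimal sketch:

```python
from typing import Any


def describe(data: str | dict[str, Any] | list[Any]) -> str:
    if isinstance(data, dict):
        # Narrowed to dict[str, Any]; .get() is allowed here
        return f"Action result: {data.get('message', 'Completed')}"
    if isinstance(data, list):
        # Narrowed to list[Any]; len()/iteration are allowed here
        return f"Observed {len(data)} elements"
    return f"Action result: {data}"
```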
@@ -680,7 +694,7 @@ class StagehandTool(BaseTool):
             return str(result.data)
         return f"Error: {result.error}"

-    async def _async_close(self):
+    async def _async_close(self) -> None:
         """Asynchronously clean up Stagehand resources."""
         # Skip for test mode
         if self._testing:
@@ -694,7 +708,7 @@ class StagehandTool(BaseTool):
         if self._page:
             self._page = None

-    def close(self):
+    def close(self) -> None:
         """Clean up Stagehand resources."""
         # Skip actual closing for testing mode
         if self._testing:
@@ -704,9 +718,9 @@ class StagehandTool(BaseTool):

         if self._stagehand:
             try:
-                # Handle both synchronous and asynchronous cases
-                if hasattr(self._stagehand, "close"):
-                    if asyncio.iscoroutinefunction(self._stagehand.close):
+                close_method: Any = getattr(self._stagehand, "close", None)
+                if close_method is not None:
+                    if asyncio.iscoroutinefunction(close_method):
                         try:
                             loop = asyncio.get_event_loop()
                             if loop.is_running():
@@ -725,8 +739,7 @@ class StagehandTool(BaseTool):
                         except RuntimeError:
                             asyncio.run(self._async_close())
                     else:
-                        # Handle non-async close method (for mocks)
-                        self._stagehand.close()
+                        close_method()
             except Exception:  # noqa: S110
                 # Log but don't raise - we're cleaning up
                 pass
@@ -736,10 +749,15 @@ class StagehandTool(BaseTool):
         if self._page:
             self._page = None

-    def __enter__(self):
+    def __enter__(self) -> StagehandTool:
         """Enter the context manager."""
         return self

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: Any,
+    ) -> None:
         """Exit the context manager and clean up resources."""
         self.close()
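Two closing patterns from the hunk above: fetching `close` once with `getattr(..., None)` yields a single local that can be probed with `asyncio.iscoroutinefunction` and then invoked either way, and the `__exit__` annotation is the conventional one for a context manager that does not suppress exceptions. A hedged sketch:

```python
from __future__ import annotations

import asyncio
from typing import Any


class Resource:
    def __init__(self, client: Any = None) -> None:
        self._client = client

    def close(self) -> None:
        close_method: Any = getattr(self._client, "close", None)
        if close_method is not None:
            if asyncio.iscoroutinefunction(close_method):
                asyncio.run(close_method())  # async close, outside a running loop
            else:
                close_method()  # sync close (also covers test mocks)

    def __enter__(self) -> Resource:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        self.close()
```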
Some files were not shown because too many files have changed in this diff.