Compare commits

..

2 Commits

Author SHA1 Message Date
Greyson LaLonde
c1776aca8a feat: add utils for deferred imports 2026-01-13 11:27:02 -05:00
Koushiv
8f99fa76ed feat: additional a2a transports
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Co-authored-by: Koushiv Sadhukhan <koushiv.777@gmail.com>
Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
2026-01-12 12:03:06 -05:00
8 changed files with 424 additions and 735 deletions

View File

@@ -91,6 +91,10 @@ The `A2AConfig` class accepts the following parameters:
Update mechanism for receiving task status. Options: `StreamingConfig`, `PollingConfig`, or `PushNotificationConfig`.
</ParamField>
<ParamField path="transport_protocol" type="Literal['JSONRPC', 'GRPC', 'HTTP+JSON']" default="JSONRPC">
Transport protocol for A2A communication. Options: `JSONRPC` (default), `GRPC`, or `HTTP+JSON`.
</ParamField>
## Authentication
For A2A agents that require authentication, use one of the provided auth schemes:

View File

@@ -5,7 +5,7 @@ This module is separate from experimental.a2a to avoid circular imports.
from __future__ import annotations
from typing import Annotated, Any, ClassVar
from typing import Annotated, Any, ClassVar, Literal
from pydantic import (
BaseModel,
@@ -53,6 +53,7 @@ class A2AConfig(BaseModel):
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
transport_protocol: A2A transport protocol (grpc, jsonrpc, http+json).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@@ -82,3 +83,7 @@ class A2AConfig(BaseModel):
default_factory=_get_default_update_config,
description="Update mechanism config",
)
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field(
default="JSONRPC",
description="Specified mode of A2A transport protocol",
)

View File

@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, MutableMapping
from contextlib import asynccontextmanager
from functools import lru_cache
import time
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
import uuid
from a2a.client import A2AClientHTTPError, Client, ClientConfig, ClientFactory
@@ -18,7 +18,6 @@ from a2a.types import (
PushNotificationConfig as A2APushNotificationConfig,
Role,
TextPart,
TransportProtocol,
)
from aiocache import cached # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
@@ -259,6 +258,7 @@ async def _afetch_agent_card_impl(
def execute_a2a_delegation(
endpoint: str,
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
auth: AuthScheme | None,
timeout: int,
task_description: str,
@@ -282,6 +282,23 @@ def execute_a2a_delegation(
use aexecute_a2a_delegation directly.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL)
transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
auth: Optional AuthScheme for authentication (Bearer, OAuth2, API Key, HTTP Basic/Digest)
timeout: Request timeout in seconds
task_description: The task to delegate
context: Optional context information
context_id: Context ID for correlating messages/tasks
task_id: Specific task identifier
reference_task_ids: List of related task IDs
metadata: Additional metadata (external_id, request_id, etc.)
extensions: Protocol extensions for custom fields
conversation_history: Previous Message objects from conversation
agent_id: Agent identifier for logging
agent_role: Role of the CrewAI agent delegating the task
agent_branch: Optional agent tree branch for logging
response_model: Optional Pydantic model for structured outputs
turn_number: Optional turn number for multi-turn conversations
endpoint: A2A agent endpoint URL.
auth: Optional AuthScheme for authentication.
timeout: Request timeout in seconds.
@@ -323,6 +340,7 @@ def execute_a2a_delegation(
agent_role=agent_role,
agent_branch=agent_branch,
response_model=response_model,
transport_protocol=transport_protocol,
turn_number=turn_number,
updates=updates,
)
@@ -333,6 +351,7 @@ def execute_a2a_delegation(
async def aexecute_a2a_delegation(
endpoint: str,
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
auth: AuthScheme | None,
timeout: int,
task_description: str,
@@ -356,6 +375,23 @@ async def aexecute_a2a_delegation(
in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()).
Args:
endpoint: A2A agent endpoint URL
transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
auth: Optional AuthScheme for authentication
timeout: Request timeout in seconds
task_description: Task to delegate
context: Optional context
context_id: Context ID for correlation
task_id: Specific task identifier
reference_task_ids: Related task IDs
metadata: Additional metadata
extensions: Protocol extensions
conversation_history: Previous Message objects
turn_number: Current turn number
agent_branch: Agent tree branch for logging
agent_id: Agent identifier for logging
agent_role: Agent role for logging
response_model: Optional Pydantic model for structured outputs
endpoint: A2A agent endpoint URL.
auth: Optional AuthScheme for authentication.
timeout: Request timeout in seconds.
@@ -414,6 +450,7 @@ async def aexecute_a2a_delegation(
agent_role=agent_role,
response_model=response_model,
updates=updates,
transport_protocol=transport_protocol,
)
crewai_event_bus.emit(
@@ -431,6 +468,7 @@ async def aexecute_a2a_delegation(
async def _aexecute_a2a_delegation_impl(
endpoint: str,
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
auth: AuthScheme | None,
timeout: int,
task_description: str,
@@ -524,7 +562,6 @@ async def _aexecute_a2a_delegation_impl(
extensions=extensions,
)
transport_protocol = TransportProtocol("JSONRPC")
new_messages: list[Message] = [*conversation_history, message]
crewai_event_bus.emit(
None,
@@ -596,7 +633,7 @@ async def _aexecute_a2a_delegation_impl(
@asynccontextmanager
async def _create_a2a_client(
agent_card: AgentCard,
transport_protocol: TransportProtocol,
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
timeout: int,
headers: MutableMapping[str, str],
streaming: bool,
@@ -640,7 +677,7 @@ async def _create_a2a_client(
config = ClientConfig(
httpx_client=httpx_client,
supported_transports=[str(transport_protocol.value)],
supported_transports=[transport_protocol],
streaming=streaming and not use_polling,
polling=use_polling,
accepted_output_modes=["application/json"],

View File

@@ -771,6 +771,7 @@ def _delegate_to_a2a(
response_model=agent_config.response_model,
turn_number=turn_num + 1,
updates=agent_config.updates,
transport_protocol=agent_config.transport_protocol,
)
conversation_history = a2a_result.get("history", [])
@@ -1085,6 +1086,7 @@ async def _adelegate_to_a2a(
agent_branch=agent_branch,
response_model=agent_config.response_model,
turn_number=turn_num + 1,
transport_protocol=agent_config.transport_protocol,
updates=agent_config.updates,
)

View File

@@ -1,12 +1,9 @@
from crewai.tools.base_tool import BaseTool, EnvVar, tool
from crewai.tools.tool_search_tool import SearchStrategy, ToolSearchTool
__all__ = [
"BaseTool",
"EnvVar",
"SearchStrategy",
"ToolSearchTool",
"tool",
]

View File

@@ -1,333 +0,0 @@
"""Tool Search Tool for on-demand tool discovery.
This module implements a Tool Search Tool that allows agents to dynamically
discover and load tools on-demand, reducing token consumption when working
with large tool libraries.
Inspired by Anthropic's Tool Search Tool approach for on-demand tool loading.
"""
from __future__ import annotations
from collections.abc import Callable, Sequence
from enum import Enum
import json
import re
from typing import Any
from pydantic import BaseModel, Field
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities.pydantic_schema_utils import generate_model_description
class SearchStrategy(str, Enum):
"""Search strategy for tool discovery."""
KEYWORD = "keyword"
REGEX = "regex"
class ToolSearchResult(BaseModel):
"""Result from a tool search operation."""
name: str = Field(description="The name of the tool")
description: str = Field(description="The description of the tool")
args_schema: dict[str, Any] = Field(
description="The JSON schema for the tool's arguments"
)
class ToolSearchToolSchema(BaseModel):
"""Schema for the Tool Search Tool arguments."""
query: str = Field(
description="The search query to find relevant tools. Use keywords that describe the capability you need."
)
max_results: int = Field(
default=5,
description="Maximum number of tools to return. Default is 5.",
ge=1,
le=20,
)
class ToolSearchTool(BaseTool):
"""A tool that searches through a catalog of tools to find relevant ones.
This tool enables on-demand tool discovery, allowing agents to work with
large tool libraries without loading all tool definitions upfront. Instead
of consuming tokens with all tool definitions, the agent can search for
relevant tools when needed.
Example:
```python
from crewai.tools import BaseTool, ToolSearchTool
# Create your tools
search_tool = MySearchTool()
scrape_tool = MyScrapeWebsiteTool()
database_tool = MyDatabaseTool()
# Create a tool search tool with your tool catalog
tool_search = ToolSearchTool(
tool_catalog=[search_tool, scrape_tool, database_tool],
search_strategy=SearchStrategy.KEYWORD,
)
# Use with an agent - only the tool_search is loaded initially
agent = Agent(
role="Researcher",
tools=[tool_search], # Other tools discovered on-demand
)
```
Attributes:
tool_catalog: List of tools available for search.
search_strategy: Strategy to use for searching (keyword or regex).
custom_search_fn: Optional custom search function for advanced matching.
"""
name: str = Field(
default="Tool Search",
description="The name of the tool search tool.",
)
description: str = Field(
default="Search for available tools by describing the capability you need. Returns tool definitions that match your query.",
description="Description of what the tool search tool does.",
)
args_schema: type[BaseModel] = Field(
default=ToolSearchToolSchema,
description="The schema for the tool search arguments.",
)
tool_catalog: list[BaseTool | CrewStructuredTool] = Field(
default_factory=list,
description="List of tools available for search.",
)
search_strategy: SearchStrategy = Field(
default=SearchStrategy.KEYWORD,
description="Strategy to use for searching tools.",
)
custom_search_fn: Callable[
[str, Sequence[BaseTool | CrewStructuredTool]], list[BaseTool | CrewStructuredTool]
] | None = Field(
default=None,
description="Optional custom search function for advanced matching.",
)
def _run(self, query: str, max_results: int = 5) -> str:
"""Search for tools matching the query.
Args:
query: The search query to find relevant tools.
max_results: Maximum number of tools to return.
Returns:
JSON string containing the matching tool definitions.
"""
if not self.tool_catalog:
return json.dumps(
{
"status": "error",
"message": "No tools available in the catalog.",
"tools": [],
}
)
if self.custom_search_fn:
matching_tools = self.custom_search_fn(query, self.tool_catalog)
elif self.search_strategy == SearchStrategy.REGEX:
matching_tools = self._regex_search(query)
else:
matching_tools = self._keyword_search(query)
matching_tools = matching_tools[:max_results]
if not matching_tools:
return json.dumps(
{
"status": "no_results",
"message": f"No tools found matching query: '{query}'. Try different keywords.",
"tools": [],
}
)
tool_results = []
for tool in matching_tools:
tool_info = self._get_tool_info(tool)
tool_results.append(tool_info)
return json.dumps(
{
"status": "success",
"message": f"Found {len(tool_results)} tool(s) matching your query.",
"tools": tool_results,
},
indent=2,
)
def _keyword_search(
self, query: str
) -> list[BaseTool | CrewStructuredTool]:
"""Search tools using keyword matching.
Args:
query: The search query.
Returns:
List of matching tools sorted by relevance.
"""
query_lower = query.lower()
query_words = set(query_lower.split())
scored_tools: list[tuple[float, BaseTool | CrewStructuredTool]] = []
for tool in self.tool_catalog:
score = self._calculate_keyword_score(tool, query_lower, query_words)
if score > 0:
scored_tools.append((score, tool))
scored_tools.sort(key=lambda x: x[0], reverse=True)
return [tool for _, tool in scored_tools]
def _calculate_keyword_score(
self,
tool: BaseTool | CrewStructuredTool,
query_lower: str,
query_words: set[str],
) -> float:
"""Calculate relevance score for a tool based on keyword matching.
Args:
tool: The tool to score.
query_lower: Lowercase query string.
query_words: Set of query words.
Returns:
Relevance score (higher is better).
"""
score = 0.0
tool_name_lower = tool.name.lower()
tool_desc_lower = tool.description.lower()
if query_lower in tool_name_lower:
score += 10.0
if query_lower in tool_desc_lower:
score += 5.0
for word in query_words:
if len(word) < 2:
continue
if word in tool_name_lower:
score += 3.0
if word in tool_desc_lower:
score += 1.0
return score
def _regex_search(
self, query: str
) -> list[BaseTool | CrewStructuredTool]:
"""Search tools using regex pattern matching.
Args:
query: The regex pattern to search for.
Returns:
List of matching tools.
"""
try:
pattern = re.compile(query, re.IGNORECASE)
except re.error:
pattern = re.compile(re.escape(query), re.IGNORECASE)
return [
tool
for tool in self.tool_catalog
if pattern.search(tool.name) or pattern.search(tool.description)
]
def _get_tool_info(self, tool: BaseTool | CrewStructuredTool) -> dict[str, Any]:
"""Get tool information as a dictionary.
Args:
tool: The tool to get information from.
Returns:
Dictionary containing tool name, description, and args schema.
"""
if isinstance(tool, BaseTool):
schema_dict = generate_model_description(tool.args_schema)
args_schema = schema_dict.get("json_schema", {}).get("schema", {})
else:
args_schema = tool.args_schema.model_json_schema()
return {
"name": tool.name,
"description": self._get_original_description(tool),
"args_schema": args_schema,
}
def _get_original_description(self, tool: BaseTool | CrewStructuredTool) -> str:
"""Get the original description of a tool without the generated schema.
Args:
tool: The tool to get the description from.
Returns:
The original tool description.
"""
description = tool.description
if "Tool Description:" in description:
parts = description.split("Tool Description:")
if len(parts) > 1:
return parts[1].strip()
return description
def add_tool(self, tool: BaseTool | CrewStructuredTool) -> None:
"""Add a tool to the catalog.
Args:
tool: The tool to add.
"""
self.tool_catalog.append(tool)
def add_tools(self, tools: Sequence[BaseTool | CrewStructuredTool]) -> None:
"""Add multiple tools to the catalog.
Args:
tools: The tools to add.
"""
self.tool_catalog.extend(tools)
def remove_tool(self, tool_name: str) -> bool:
"""Remove a tool from the catalog by name.
Args:
tool_name: The name of the tool to remove.
Returns:
True if the tool was removed, False if not found.
"""
for i, tool in enumerate(self.tool_catalog):
if tool.name == tool_name:
self.tool_catalog.pop(i)
return True
return False
def get_catalog_size(self) -> int:
"""Get the number of tools in the catalog.
Returns:
The number of tools in the catalog.
"""
return len(self.tool_catalog)
def list_tool_names(self) -> list[str]:
"""List all tool names in the catalog.
Returns:
List of tool names.
"""
return [tool.name for tool in self.tool_catalog]

View File

@@ -0,0 +1,370 @@
"""Lazy loader for Python packages.
Makes it easy to load subpackages and functions on demand.
Pulled from https://github.com/scientific-python/lazy-loader/blob/main/src/lazy_loader/__init__.py,
modernized a little.
"""
import ast
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
import importlib
import importlib.metadata
import importlib.util
import inspect
import os
from pathlib import Path
import sys
import threading
import types
from typing import Any, NoReturn
import warnings
import packaging.requirements
_threadlock = threading.Lock()
@dataclass(frozen=True, slots=True)
class _FrameData:
"""Captured stack frame information for delayed error reporting."""
filename: str
lineno: int
function: str
code_context: Sequence[str] | None
def attach(
package_name: str,
submodules: set[str] | None = None,
submod_attrs: dict[str, list[str]] | None = None,
) -> tuple[Callable[[str], Any], Callable[[], list[str]], list[str]]:
"""Attach lazily loaded submodules, functions, or other attributes.
Replaces a package's `__getattr__`, `__dir__`, and `__all__` such that
imports work normally but occur upon first use.
Example:
__getattr__, __dir__, __all__ = lazy.attach(
__name__, ["mysubmodule"], {"foo": ["someattr"]}
)
Args:
package_name: The package name, typically ``__name__``.
submodules: Set of submodule names to attach.
submod_attrs: Mapping of submodule names to lists of attributes.
These attributes are imported as they are used.
Returns:
A tuple of (__getattr__, __dir__, __all__) to assign in the package.
"""
submod_attrs = submod_attrs or {}
submodules = set(submodules) if submodules else set()
attr_to_modules = {
attr: mod for mod, attrs in submod_attrs.items() for attr in attrs
}
__all__ = sorted(submodules | attr_to_modules.keys())
def __getattr__(name: str) -> Any: # noqa: N807
if name in submodules:
return importlib.import_module(f"{package_name}.{name}")
if name in attr_to_modules:
submod_path = f"{package_name}.{attr_to_modules[name]}"
submod = importlib.import_module(submod_path)
attr = getattr(submod, name)
# If the attribute lives in a file (module) with the same
# name as the attribute, ensure that the attribute and *not*
# the module is accessible on the package.
if name == attr_to_modules[name]:
pkg = sys.modules[package_name]
pkg.__dict__[name] = attr
return attr
raise AttributeError(f"No {package_name} attribute {name}")
def __dir__() -> list[str]: # noqa: N807
return __all__.copy()
if os.environ.get("EAGER_IMPORT"):
for attr in set(attr_to_modules.keys()) | submodules:
__getattr__(attr)
return __getattr__, __dir__, __all__.copy()
class DelayedImportErrorModule(types.ModuleType):
"""Module type that delays raising ModuleNotFoundError until attribute access.
Captures stack frame data to provide helpful error messages showing where
the original import was attempted.
"""
def __init__(
self,
frame_data: _FrameData,
*args: Any,
message: str,
**kwargs: Any,
) -> None:
"""Initialize the delayed error module.
Args:
frame_data: Captured frame information for error reporting.
*args: Positional arguments passed to ModuleType.
message: The error message to display when accessed.
**kwargs: Keyword arguments passed to ModuleType.
"""
self._frame_data = frame_data
self._message = message
super().__init__(*args, **kwargs)
def __getattr__(self, name: str) -> NoReturn:
"""Raise ModuleNotFoundError with detailed context on any attribute access."""
frame = self._frame_data
code = "".join(frame.code_context) if frame.code_context else ""
raise ModuleNotFoundError(
f"{self._message}\n\n"
"This error is lazily reported, having originally occurred in\n"
f" File {frame.filename}, line {frame.lineno}, in {frame.function}\n\n"
f"----> {code.strip()}"
)
def load(
fullname: str,
*,
require: str | None = None,
error_on_import: bool = False,
suppress_warning: bool = False,
) -> types.ModuleType:
"""Return a lazily imported proxy for a module.
The proxy module delays actual import until first attribute access.
Example:
np = lazy.load("numpy")
def myfunc():
np.norm(...)
Warning:
Lazily loading subpackages causes the parent package to be eagerly
loaded. Use `lazy_loader.attach` instead for subpackages.
Args:
fullname: The full name of the module to import (e.g., "scipy").
require: A PEP-508 dependency requirement (e.g., "numpy >=1.24").
If specified, raises an error if the installed version doesn't match.
error_on_import: If True, raise import errors immediately.
If False (default), delay errors until module is accessed.
suppress_warning: If True, suppress the warning when loading subpackages.
Returns:
A proxy module that loads on first attribute access.
"""
with _threadlock:
module = sys.modules.get(fullname)
# Most common, short-circuit
if module is not None and require is None:
return module
have_module = module is not None
if not suppress_warning and "." in fullname:
msg = (
"subpackages can technically be lazily loaded, but it causes the "
"package to be eagerly loaded even if it is already lazily loaded. "
"So, you probably shouldn't use subpackages with this lazy feature."
)
warnings.warn(msg, RuntimeWarning, stacklevel=2)
spec = None
if not have_module:
spec = importlib.util.find_spec(fullname)
have_module = spec is not None
if not have_module:
not_found_message = f"No module named '{fullname}'"
elif require is not None:
try:
have_module = _check_requirement(require)
except ModuleNotFoundError as e:
raise ValueError(
f"Found module '{fullname}' but cannot test "
"requirement '{require}'. "
"Requirements must match distribution name, not module name."
) from e
not_found_message = f"No distribution can be found matching '{require}'"
if not have_module:
if error_on_import:
raise ModuleNotFoundError(not_found_message)
parent = inspect.stack()[1]
frame_data = _FrameData(
filename=parent.filename,
lineno=parent.lineno,
function=parent.function,
code_context=parent.code_context,
)
del parent
return DelayedImportErrorModule(
frame_data,
"DelayedImportErrorModule",
message=not_found_message,
)
if spec is not None:
module = importlib.util.module_from_spec(spec)
sys.modules[fullname] = module
if spec.loader is not None:
loader = importlib.util.LazyLoader(spec.loader)
loader.exec_module(module)
if module is None:
raise ModuleNotFoundError(f"No module named '{fullname}'")
return module
def _check_requirement(require: str) -> bool:
"""Verify that a package requirement is satisfied.
Args:
require: A dependency requirement as defined in PEP-508.
Returns:
True if the installed version matches the requirement, False otherwise.
Raises:
ModuleNotFoundError: If the package is not installed.
"""
req = packaging.requirements.Requirement(require)
return req.specifier.contains(
importlib.metadata.version(req.name),
prereleases=True,
)
@dataclass
class _StubVisitor(ast.NodeVisitor):
"""AST visitor to parse a stub file for submodules and submod_attrs."""
_submodules: set[str] = field(default_factory=set)
_submod_attrs: dict[str, list[str]] = field(default_factory=dict)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Visit an ImportFrom node and extract submodule/attribute information.
Args:
node: The AST ImportFrom node to visit.
Raises:
ValueError: If the import is not a relative import or uses star import.
"""
if node.level != 1:
raise ValueError(
"Only within-module imports are supported (`from .* import`)"
)
names = [alias.name for alias in node.names]
if node.module:
if "*" in names:
raise ValueError(
f"lazy stub loader does not support star import "
f"`from {node.module} import *`"
)
self._submod_attrs.setdefault(node.module, []).extend(names)
else:
self._submodules.update(names)
def attach_stub(
package_name: str,
filename: str,
) -> tuple[Callable[[str], Any], Callable[[], list[str]], list[str]]:
"""Attach lazily loaded submodules and functions from a type stub.
Parses a `.pyi` stub file to infer submodules and attributes. This allows
static type checkers to find imports while providing lazy loading at runtime.
Args:
package_name: The package name, typically ``__name__``.
filename: Path to `.py` file with an adjacent `.pyi` file.
Typically use ``__file__``.
Returns:
A tuple of (__getattr__, __dir__, __all__) to assign in the package.
Raises:
ValueError: If stub file is not found or contains invalid imports.
"""
path = Path(filename)
stubfile = path if path.suffix == ".pyi" else path.with_suffix(".pyi")
if not stubfile.exists():
raise ValueError(f"Cannot load imports from non-existent stub {stubfile!r}")
visitor = _StubVisitor()
visitor.visit(ast.parse(stubfile.read_text()))
return attach(package_name, visitor._submodules, visitor._submod_attrs)
def lazy_exports_stub(package_name: str, filename: str) -> None:
"""Install lazy loading on a module based on its .pyi stub file.
Parses the adjacent `.pyi` stub file to determine what to export lazily.
Type checkers see the stub, runtime gets lazy loading.
Example:
# __init__.py
from crewai.utilities.lazy import lazy_exports_stub
lazy_exports_stub(__name__, __file__)
# __init__.pyi
from .config import ChromaDBConfig, ChromaDBSettings
from .types import EmbeddingType
Args:
package_name: The package name, typically ``__name__``.
filename: Path to the module file, typically ``__file__``.
"""
__getattr__, __dir__, __all__ = attach_stub(package_name, filename)
module = sys.modules[package_name]
module.__getattr__ = __getattr__ # type: ignore[method-assign]
module.__dir__ = __dir__ # type: ignore[method-assign]
module.__dict__["__all__"] = __all__
def lazy_exports(
package_name: str,
submod_attrs: dict[str, list[str]],
submodules: set[str] | None = None,
) -> None:
"""Install lazy loading on a module.
Example:
from crewai.utilities.lazy import lazy_exports
lazy_exports(__name__, {
'config': ['ChromaDBConfig', 'ChromaDBSettings'],
'types': ['EmbeddingType'],
})
Args:
package_name: The package name, typically ``__name__``.
submod_attrs: Mapping of submodule names to lists of attributes.
submodules: Optional set of submodule names to expose directly.
"""
__getattr__, __dir__, __all__ = attach(package_name, submodules, submod_attrs)
module = sys.modules[package_name]
module.__getattr__ = __getattr__ # type: ignore[method-assign]
module.__dir__ = __dir__ # type: ignore[method-assign]
module.__dict__["__all__"] = __all__

View File

@@ -1,393 +0,0 @@
"""Tests for the ToolSearchTool functionality."""
import json
import pytest
from pydantic import BaseModel
from crewai.tools import BaseTool, SearchStrategy, ToolSearchTool
class MockSearchTool(BaseTool):
"""A mock search tool for testing."""
name: str = "Web Search"
description: str = "Search the web for information on any topic."
def _run(self, query: str) -> str:
return f"Search results for: {query}"
class MockDatabaseTool(BaseTool):
"""A mock database tool for testing."""
name: str = "Database Query"
description: str = "Query a SQL database to retrieve data."
def _run(self, query: str) -> str:
return f"Database results for: {query}"
class MockScrapeTool(BaseTool):
"""A mock web scraping tool for testing."""
name: str = "Web Scraper"
description: str = "Scrape content from websites and extract text."
def _run(self, url: str) -> str:
return f"Scraped content from: {url}"
class MockEmailTool(BaseTool):
"""A mock email tool for testing."""
name: str = "Send Email"
description: str = "Send an email to a specified recipient."
def _run(self, to: str, subject: str, body: str) -> str:
return f"Email sent to {to}"
class MockCalculatorTool(BaseTool):
"""A mock calculator tool for testing."""
name: str = "Calculator"
description: str = "Perform mathematical calculations and arithmetic operations."
def _run(self, expression: str) -> str:
return f"Result: {eval(expression)}"
@pytest.fixture
def sample_tools() -> list[BaseTool]:
"""Create a list of sample tools for testing."""
return [
MockSearchTool(),
MockDatabaseTool(),
MockScrapeTool(),
MockEmailTool(),
MockCalculatorTool(),
]
@pytest.fixture
def tool_search(sample_tools: list[BaseTool]) -> ToolSearchTool:
"""Create a ToolSearchTool with sample tools."""
return ToolSearchTool(tool_catalog=sample_tools)
class TestToolSearchToolCreation:
"""Tests for ToolSearchTool creation and initialization."""
def test_create_tool_search_with_empty_catalog(self) -> None:
"""Test creating a ToolSearchTool with an empty catalog."""
tool_search = ToolSearchTool()
assert tool_search.name == "Tool Search"
assert tool_search.tool_catalog == []
assert tool_search.search_strategy == SearchStrategy.KEYWORD
def test_create_tool_search_with_tools(self, sample_tools: list[BaseTool]) -> None:
"""Test creating a ToolSearchTool with a list of tools."""
tool_search = ToolSearchTool(tool_catalog=sample_tools)
assert len(tool_search.tool_catalog) == 5
assert tool_search.get_catalog_size() == 5
def test_create_tool_search_with_regex_strategy(
self, sample_tools: list[BaseTool]
) -> None:
"""Test creating a ToolSearchTool with regex search strategy."""
tool_search = ToolSearchTool(
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
)
assert tool_search.search_strategy == SearchStrategy.REGEX
def test_create_tool_search_with_custom_name(self) -> None:
"""Test creating a ToolSearchTool with a custom name."""
tool_search = ToolSearchTool(name="My Tool Finder")
assert tool_search.name == "My Tool Finder"
class TestToolSearchKeywordSearch:
"""Tests for keyword-based tool search."""
def test_search_by_exact_name(self, tool_search: ToolSearchTool) -> None:
"""Test searching for a tool by its exact name."""
result = tool_search._run("Web Search")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) >= 1
assert result_data["tools"][0]["name"] == "Web Search"
def test_search_by_partial_name(self, tool_search: ToolSearchTool) -> None:
"""Test searching for a tool by partial name."""
result = tool_search._run("Search")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) >= 1
tool_names = [t["name"] for t in result_data["tools"]]
assert "Web Search" in tool_names
def test_search_by_description_keyword(self, tool_search: ToolSearchTool) -> None:
"""Test searching for a tool by keyword in description."""
result = tool_search._run("database")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) >= 1
tool_names = [t["name"] for t in result_data["tools"]]
assert "Database Query" in tool_names
def test_search_with_multiple_keywords(self, tool_search: ToolSearchTool) -> None:
"""Test searching with multiple keywords."""
result = tool_search._run("web scrape content")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) >= 1
tool_names = [t["name"] for t in result_data["tools"]]
assert "Web Scraper" in tool_names
def test_search_no_results(self, tool_search: ToolSearchTool) -> None:
"""Test searching with a query that returns no results."""
result = tool_search._run("xyznonexistent123abc")
result_data = json.loads(result)
assert result_data["status"] == "no_results"
assert len(result_data["tools"]) == 0
def test_search_max_results_limit(self, tool_search: ToolSearchTool) -> None:
"""Test that max_results limits the number of returned tools."""
result = tool_search._run("tool", max_results=2)
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) <= 2
def test_search_empty_catalog(self) -> None:
"""Test searching with an empty tool catalog."""
tool_search = ToolSearchTool()
result = tool_search._run("search")
result_data = json.loads(result)
assert result_data["status"] == "error"
assert "No tools available" in result_data["message"]
class TestToolSearchRegexSearch:
"""Tests for regex-based tool search."""
def test_regex_search_simple_pattern(
self, sample_tools: list[BaseTool]
) -> None:
"""Test regex search with a simple pattern."""
tool_search = ToolSearchTool(
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
)
result = tool_search._run("Web.*")
result_data = json.loads(result)
assert result_data["status"] == "success"
tool_names = [t["name"] for t in result_data["tools"]]
assert "Web Search" in tool_names or "Web Scraper" in tool_names
def test_regex_search_case_insensitive(
self, sample_tools: list[BaseTool]
) -> None:
"""Test that regex search is case insensitive."""
tool_search = ToolSearchTool(
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
)
result = tool_search._run("email")
result_data = json.loads(result)
assert result_data["status"] == "success"
tool_names = [t["name"] for t in result_data["tools"]]
assert "Send Email" in tool_names
def test_regex_search_invalid_pattern_fallback(
self, sample_tools: list[BaseTool]
) -> None:
"""Test that invalid regex patterns are escaped and still work."""
tool_search = ToolSearchTool(
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
)
result = tool_search._run("[invalid(regex")
result_data = json.loads(result)
assert result_data["status"] in ["success", "no_results"]
class TestToolSearchCustomSearch:
"""Tests for custom search function."""
def test_custom_search_function(self, sample_tools: list[BaseTool]) -> None:
"""Test using a custom search function."""
def custom_search(
query: str, tools: list[BaseTool]
) -> list[BaseTool]:
return [t for t in tools if "email" in t.name.lower()]
tool_search = ToolSearchTool(
tool_catalog=sample_tools, custom_search_fn=custom_search
)
result = tool_search._run("anything")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) == 1
assert result_data["tools"][0]["name"] == "Send Email"
class TestToolSearchCatalogManagement:
"""Tests for tool catalog management."""
def test_add_tool(self, tool_search: ToolSearchTool) -> None:
"""Test adding a tool to the catalog."""
initial_size = tool_search.get_catalog_size()
class NewTool(BaseTool):
name: str = "New Tool"
description: str = "A new tool for testing."
def _run(self) -> str:
return "New tool result"
tool_search.add_tool(NewTool())
assert tool_search.get_catalog_size() == initial_size + 1
def test_add_tools(self, tool_search: ToolSearchTool) -> None:
"""Test adding multiple tools to the catalog."""
initial_size = tool_search.get_catalog_size()
class NewTool1(BaseTool):
name: str = "New Tool 1"
description: str = "First new tool."
def _run(self) -> str:
return "Result 1"
class NewTool2(BaseTool):
name: str = "New Tool 2"
description: str = "Second new tool."
def _run(self) -> str:
return "Result 2"
tool_search.add_tools([NewTool1(), NewTool2()])
assert tool_search.get_catalog_size() == initial_size + 2
def test_remove_tool(self, tool_search: ToolSearchTool) -> None:
"""Test removing a tool from the catalog."""
initial_size = tool_search.get_catalog_size()
result = tool_search.remove_tool("Web Search")
assert result is True
assert tool_search.get_catalog_size() == initial_size - 1
def test_remove_nonexistent_tool(self, tool_search: ToolSearchTool) -> None:
"""Test removing a tool that doesn't exist."""
initial_size = tool_search.get_catalog_size()
result = tool_search.remove_tool("Nonexistent Tool")
assert result is False
assert tool_search.get_catalog_size() == initial_size
def test_list_tool_names(self, tool_search: ToolSearchTool) -> None:
"""Test listing all tool names in the catalog."""
names = tool_search.list_tool_names()
assert len(names) == 5
assert "Web Search" in names
assert "Database Query" in names
assert "Web Scraper" in names
assert "Send Email" in names
assert "Calculator" in names
class TestToolSearchResultFormat:
"""Tests for the format of search results."""
def test_result_contains_tool_info(self, tool_search: ToolSearchTool) -> None:
"""Test that search results contain complete tool information."""
result = tool_search._run("Calculator")
result_data = json.loads(result)
assert result_data["status"] == "success"
tool_info = result_data["tools"][0]
assert "name" in tool_info
assert "description" in tool_info
assert "args_schema" in tool_info
assert tool_info["name"] == "Calculator"
def test_result_args_schema_format(self, tool_search: ToolSearchTool) -> None:
"""Test that args_schema is properly formatted."""
result = tool_search._run("Email")
result_data = json.loads(result)
assert result_data["status"] == "success"
tool_info = result_data["tools"][0]
assert "args_schema" in tool_info
args_schema = tool_info["args_schema"]
assert isinstance(args_schema, dict)
class TestToolSearchIntegration:
"""Integration tests for ToolSearchTool."""
def test_tool_search_as_base_tool(self, sample_tools: list[BaseTool]) -> None:
"""Test that ToolSearchTool works as a BaseTool."""
tool_search = ToolSearchTool(tool_catalog=sample_tools)
assert isinstance(tool_search, BaseTool)
assert tool_search.name == "Tool Search"
assert "search" in tool_search.description.lower()
def test_tool_search_to_structured_tool(
self, sample_tools: list[BaseTool]
) -> None:
"""Test converting ToolSearchTool to structured tool."""
tool_search = ToolSearchTool(tool_catalog=sample_tools)
structured = tool_search.to_structured_tool()
assert structured.name == "Tool Search"
assert structured.args_schema is not None
def test_tool_search_run_method(self, tool_search: ToolSearchTool) -> None:
"""Test the run method of ToolSearchTool."""
result = tool_search.run(query="search", max_results=3)
assert isinstance(result, str)
result_data = json.loads(result)
assert "status" in result_data
assert "tools" in result_data
class TestToolSearchScoring:
"""Tests for the keyword scoring algorithm."""
def test_exact_name_match_scores_highest(
self, sample_tools: list[BaseTool]
) -> None:
"""Test that exact name matches score higher than partial matches."""
tool_search = ToolSearchTool(tool_catalog=sample_tools)
result = tool_search._run("Web Search")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert result_data["tools"][0]["name"] == "Web Search"
def test_name_match_scores_higher_than_description(
self, sample_tools: list[BaseTool]
) -> None:
"""Test that name matches score higher than description matches."""
tool_search = ToolSearchTool(tool_catalog=sample_tools)
result = tool_search._run("Calculator")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert result_data["tools"][0]["name"] == "Calculator"