mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-20 21:38:14 +00:00
Compare commits
1 Commits
lorenze/ex
...
devin/1768
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fae812ffb7 |
@@ -91,10 +91,6 @@ The `A2AConfig` class accepts the following parameters:
|
||||
Update mechanism for receiving task status. Options: `StreamingConfig`, `PollingConfig`, or `PushNotificationConfig`.
|
||||
</ParamField>
|
||||
|
||||
<ParamField path="transport_protocol" type="Literal['JSONRPC', 'GRPC', 'HTTP+JSON']" default="JSONRPC">
|
||||
Transport protocol for A2A communication. Options: `JSONRPC` (default), `GRPC`, or `HTTP+JSON`.
|
||||
</ParamField>
|
||||
|
||||
## Authentication
|
||||
|
||||
For A2A agents that require authentication, use one of the provided auth schemes:
|
||||
|
||||
@@ -5,7 +5,7 @@ This module is separate from experimental.a2a to avoid circular imports.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any, ClassVar, Literal
|
||||
from typing import Annotated, Any, ClassVar
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -53,7 +53,6 @@ class A2AConfig(BaseModel):
|
||||
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
|
||||
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
|
||||
updates: Update mechanism config.
|
||||
transport_protocol: A2A transport protocol (grpc, jsonrpc, http+json).
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
@@ -83,7 +82,3 @@ class A2AConfig(BaseModel):
|
||||
default_factory=_get_default_update_config,
|
||||
description="Update mechanism config",
|
||||
)
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field(
|
||||
default="JSONRPC",
|
||||
description="Specified mode of A2A transport protocol",
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, MutableMapping
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import lru_cache
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from a2a.client import A2AClientHTTPError, Client, ClientConfig, ClientFactory
|
||||
@@ -18,6 +18,7 @@ from a2a.types import (
|
||||
PushNotificationConfig as A2APushNotificationConfig,
|
||||
Role,
|
||||
TextPart,
|
||||
TransportProtocol,
|
||||
)
|
||||
from aiocache import cached # type: ignore[import-untyped]
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
@@ -258,7 +259,6 @@ async def _afetch_agent_card_impl(
|
||||
|
||||
def execute_a2a_delegation(
|
||||
endpoint: str,
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
|
||||
auth: AuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
@@ -282,23 +282,6 @@ def execute_a2a_delegation(
|
||||
use aexecute_a2a_delegation directly.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL)
|
||||
transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
|
||||
auth: Optional AuthScheme for authentication (Bearer, OAuth2, API Key, HTTP Basic/Digest)
|
||||
timeout: Request timeout in seconds
|
||||
task_description: The task to delegate
|
||||
context: Optional context information
|
||||
context_id: Context ID for correlating messages/tasks
|
||||
task_id: Specific task identifier
|
||||
reference_task_ids: List of related task IDs
|
||||
metadata: Additional metadata (external_id, request_id, etc.)
|
||||
extensions: Protocol extensions for custom fields
|
||||
conversation_history: Previous Message objects from conversation
|
||||
agent_id: Agent identifier for logging
|
||||
agent_role: Role of the CrewAI agent delegating the task
|
||||
agent_branch: Optional agent tree branch for logging
|
||||
response_model: Optional Pydantic model for structured outputs
|
||||
turn_number: Optional turn number for multi-turn conversations
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth: Optional AuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
@@ -340,7 +323,6 @@ def execute_a2a_delegation(
|
||||
agent_role=agent_role,
|
||||
agent_branch=agent_branch,
|
||||
response_model=response_model,
|
||||
transport_protocol=transport_protocol,
|
||||
turn_number=turn_number,
|
||||
updates=updates,
|
||||
)
|
||||
@@ -351,7 +333,6 @@ def execute_a2a_delegation(
|
||||
|
||||
async def aexecute_a2a_delegation(
|
||||
endpoint: str,
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
|
||||
auth: AuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
@@ -375,23 +356,6 @@ async def aexecute_a2a_delegation(
|
||||
in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()).
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL
|
||||
transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
|
||||
auth: Optional AuthScheme for authentication
|
||||
timeout: Request timeout in seconds
|
||||
task_description: Task to delegate
|
||||
context: Optional context
|
||||
context_id: Context ID for correlation
|
||||
task_id: Specific task identifier
|
||||
reference_task_ids: Related task IDs
|
||||
metadata: Additional metadata
|
||||
extensions: Protocol extensions
|
||||
conversation_history: Previous Message objects
|
||||
turn_number: Current turn number
|
||||
agent_branch: Agent tree branch for logging
|
||||
agent_id: Agent identifier for logging
|
||||
agent_role: Agent role for logging
|
||||
response_model: Optional Pydantic model for structured outputs
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth: Optional AuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
@@ -450,7 +414,6 @@ async def aexecute_a2a_delegation(
|
||||
agent_role=agent_role,
|
||||
response_model=response_model,
|
||||
updates=updates,
|
||||
transport_protocol=transport_protocol,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -468,7 +431,6 @@ async def aexecute_a2a_delegation(
|
||||
|
||||
async def _aexecute_a2a_delegation_impl(
|
||||
endpoint: str,
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
|
||||
auth: AuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
@@ -562,6 +524,7 @@ async def _aexecute_a2a_delegation_impl(
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
transport_protocol = TransportProtocol("JSONRPC")
|
||||
new_messages: list[Message] = [*conversation_history, message]
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
@@ -633,7 +596,7 @@ async def _aexecute_a2a_delegation_impl(
|
||||
@asynccontextmanager
|
||||
async def _create_a2a_client(
|
||||
agent_card: AgentCard,
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
|
||||
transport_protocol: TransportProtocol,
|
||||
timeout: int,
|
||||
headers: MutableMapping[str, str],
|
||||
streaming: bool,
|
||||
@@ -677,7 +640,7 @@ async def _create_a2a_client(
|
||||
|
||||
config = ClientConfig(
|
||||
httpx_client=httpx_client,
|
||||
supported_transports=[transport_protocol],
|
||||
supported_transports=[str(transport_protocol.value)],
|
||||
streaming=streaming and not use_polling,
|
||||
polling=use_polling,
|
||||
accepted_output_modes=["application/json"],
|
||||
|
||||
@@ -771,7 +771,6 @@ def _delegate_to_a2a(
|
||||
response_model=agent_config.response_model,
|
||||
turn_number=turn_num + 1,
|
||||
updates=agent_config.updates,
|
||||
transport_protocol=agent_config.transport_protocol,
|
||||
)
|
||||
|
||||
conversation_history = a2a_result.get("history", [])
|
||||
@@ -1086,7 +1085,6 @@ async def _adelegate_to_a2a(
|
||||
agent_branch=agent_branch,
|
||||
response_model=agent_config.response_model,
|
||||
turn_number=turn_num + 1,
|
||||
transport_protocol=agent_config.transport_protocol,
|
||||
updates=agent_config.updates,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,12 +1,4 @@
|
||||
from crewai.experimental.crew_agent_executor_flow import CrewAgentExecutorFlow
|
||||
from crewai.experimental.environment_tools import (
|
||||
BaseEnvironmentTool,
|
||||
EnvironmentTools,
|
||||
FileReadTool,
|
||||
FileSearchTool,
|
||||
GrepTool,
|
||||
ListDirTool,
|
||||
)
|
||||
from crewai.experimental.evaluation import (
|
||||
AgentEvaluationResult,
|
||||
AgentEvaluator,
|
||||
@@ -31,20 +23,14 @@ from crewai.experimental.evaluation import (
|
||||
__all__ = [
|
||||
"AgentEvaluationResult",
|
||||
"AgentEvaluator",
|
||||
"BaseEnvironmentTool",
|
||||
"BaseEvaluator",
|
||||
"CrewAgentExecutorFlow",
|
||||
"EnvironmentTools",
|
||||
"EvaluationScore",
|
||||
"EvaluationTraceCallback",
|
||||
"ExperimentResult",
|
||||
"ExperimentResults",
|
||||
"ExperimentRunner",
|
||||
"FileReadTool",
|
||||
"FileSearchTool",
|
||||
"GoalAlignmentEvaluator",
|
||||
"GrepTool",
|
||||
"ListDirTool",
|
||||
"MetricCategory",
|
||||
"ParameterExtractionEvaluator",
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Environment tools for file system operations.
|
||||
|
||||
These tools provide agents with the ability to explore and read from
|
||||
the filesystem for context engineering purposes.
|
||||
"""
|
||||
|
||||
from crewai.experimental.environment_tools.base_environment_tool import (
|
||||
BaseEnvironmentTool,
|
||||
)
|
||||
from crewai.experimental.environment_tools.environment_tools import EnvironmentTools
|
||||
from crewai.experimental.environment_tools.file_read_tool import FileReadTool
|
||||
from crewai.experimental.environment_tools.file_search_tool import FileSearchTool
|
||||
from crewai.experimental.environment_tools.grep_tool import GrepTool
|
||||
from crewai.experimental.environment_tools.list_dir_tool import ListDirTool
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseEnvironmentTool",
|
||||
"EnvironmentTools",
|
||||
"FileReadTool",
|
||||
"FileSearchTool",
|
||||
"GrepTool",
|
||||
"ListDirTool",
|
||||
]
|
||||
@@ -1,84 +0,0 @@
|
||||
"""Base class for environment tools with path security."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class BaseEnvironmentTool(BaseTool):
|
||||
"""Base class for environment/file system tools with path security.
|
||||
|
||||
Provides path validation to restrict file operations to allowed directories.
|
||||
This prevents path traversal attacks and enforces security sandboxing.
|
||||
|
||||
Attributes:
|
||||
allowed_paths: List of paths that operations are restricted to.
|
||||
Empty list means allow all paths (no restrictions).
|
||||
"""
|
||||
|
||||
allowed_paths: list[str] = Field(
|
||||
default_factory=lambda: ["."],
|
||||
description="Restrict operations to these paths. Defaults to current directory.",
|
||||
)
|
||||
|
||||
def _validate_path(self, path: str) -> tuple[bool, Path | str]:
|
||||
"""Validate and resolve a path against allowed_paths whitelist.
|
||||
|
||||
Args:
|
||||
path: The path to validate.
|
||||
|
||||
Returns:
|
||||
A tuple of (is_valid, result) where:
|
||||
- If valid: (True, resolved_path as Path)
|
||||
- If invalid: (False, error_message as str)
|
||||
"""
|
||||
try:
|
||||
resolved = Path(path).resolve()
|
||||
|
||||
# If no restrictions, allow all paths
|
||||
if not self.allowed_paths:
|
||||
return True, resolved
|
||||
|
||||
# Check if path is within any allowed path
|
||||
for allowed in self.allowed_paths:
|
||||
allowed_resolved = Path(allowed).resolve()
|
||||
try:
|
||||
# This will raise ValueError if resolved is not relative to allowed_resolved
|
||||
resolved.relative_to(allowed_resolved)
|
||||
return True, resolved
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return (
|
||||
False,
|
||||
f"Path '{path}' is outside allowed paths: {self.allowed_paths}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return False, f"Invalid path '{path}': {e}"
|
||||
|
||||
def _format_size(self, size: int) -> str:
|
||||
"""Format file size in human-readable format.
|
||||
|
||||
Args:
|
||||
size: Size in bytes.
|
||||
|
||||
Returns:
|
||||
Human-readable size string (e.g., "1.5KB", "2.3MB").
|
||||
"""
|
||||
if size < 1024:
|
||||
return f"{size}B"
|
||||
if size < 1024 * 1024:
|
||||
return f"{size / 1024:.1f}KB"
|
||||
if size < 1024 * 1024 * 1024:
|
||||
return f"{size / (1024 * 1024):.1f}MB"
|
||||
return f"{size / (1024 * 1024 * 1024):.1f}GB"
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Subclasses must implement this method."""
|
||||
raise NotImplementedError("Subclasses must implement _run method")
|
||||
@@ -1,77 +0,0 @@
|
||||
"""Manager class for environment tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.experimental.environment_tools.file_read_tool import FileReadTool
|
||||
from crewai.experimental.environment_tools.file_search_tool import FileSearchTool
|
||||
from crewai.experimental.environment_tools.grep_tool import GrepTool
|
||||
from crewai.experimental.environment_tools.list_dir_tool import ListDirTool
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class EnvironmentTools:
|
||||
"""Manager class for file system/environment tools.
|
||||
|
||||
Provides a convenient way to create a set of file system tools
|
||||
with shared security configuration (allowed_paths).
|
||||
|
||||
Similar to AgentTools but for file system operations. Use this to
|
||||
give agents the ability to explore and read files for context engineering.
|
||||
|
||||
Example:
|
||||
from crewai.experimental import EnvironmentTools
|
||||
|
||||
# Create tools with security sandbox
|
||||
env_tools = EnvironmentTools(
|
||||
allowed_paths=["./src", "./docs"],
|
||||
)
|
||||
|
||||
# Use with an agent
|
||||
agent = Agent(
|
||||
role="Code Analyst",
|
||||
tools=env_tools.tools(),
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
allowed_paths: list[str] | None = None,
|
||||
include_grep: bool = True,
|
||||
include_search: bool = True,
|
||||
) -> None:
|
||||
"""Initialize EnvironmentTools.
|
||||
|
||||
Args:
|
||||
allowed_paths: List of paths to restrict operations to.
|
||||
Defaults to current directory ["."] if None.
|
||||
Pass empty list [] to allow all paths (not recommended).
|
||||
include_grep: Whether to include GrepTool (requires grep installed).
|
||||
include_search: Whether to include FileSearchTool.
|
||||
"""
|
||||
self.allowed_paths = allowed_paths if allowed_paths is not None else ["."]
|
||||
self.include_grep = include_grep
|
||||
self.include_search = include_search
|
||||
|
||||
def tools(self) -> list[BaseTool]:
|
||||
"""Get all configured environment tools.
|
||||
|
||||
Returns:
|
||||
List of BaseTool instances with shared allowed_paths configuration.
|
||||
"""
|
||||
tool_list: list[BaseTool] = [
|
||||
FileReadTool(allowed_paths=self.allowed_paths),
|
||||
ListDirTool(allowed_paths=self.allowed_paths),
|
||||
]
|
||||
|
||||
if self.include_grep:
|
||||
tool_list.append(GrepTool(allowed_paths=self.allowed_paths))
|
||||
|
||||
if self.include_search:
|
||||
tool_list.append(FileSearchTool(allowed_paths=self.allowed_paths))
|
||||
|
||||
return tool_list
|
||||
@@ -1,124 +0,0 @@
|
||||
"""Tool for reading file contents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.experimental.environment_tools.base_environment_tool import (
|
||||
BaseEnvironmentTool,
|
||||
)
|
||||
|
||||
|
||||
class FileReadInput(BaseModel):
|
||||
"""Input schema for reading files."""
|
||||
|
||||
path: str = Field(..., description="Path to the file to read")
|
||||
start_line: int | None = Field(
|
||||
default=None,
|
||||
description="Line to start reading from (1-indexed). If None, starts from beginning.",
|
||||
)
|
||||
line_count: int | None = Field(
|
||||
default=None,
|
||||
description="Number of lines to read. If None, reads to end of file.",
|
||||
)
|
||||
|
||||
|
||||
class FileReadTool(BaseEnvironmentTool):
|
||||
"""Read contents of text files with optional line ranges.
|
||||
|
||||
Use this tool to:
|
||||
- Read configuration files, source code, logs
|
||||
- Inspect file contents before making decisions
|
||||
- Load reference documentation or data files
|
||||
|
||||
Supports reading entire files or specific line ranges for efficiency.
|
||||
"""
|
||||
|
||||
name: str = "read_file"
|
||||
description: str = """Read the contents of a text file.
|
||||
|
||||
Use this to read configuration files, source code, logs, or any text file.
|
||||
You can optionally specify start_line and line_count to read specific portions.
|
||||
|
||||
Examples:
|
||||
- Read entire file: path="config.yaml"
|
||||
- Read lines 100-149: path="large.log", start_line=100, line_count=50
|
||||
"""
|
||||
args_schema: type[BaseModel] = FileReadInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
path: str,
|
||||
start_line: int | None = None,
|
||||
line_count: int | None = None,
|
||||
) -> str:
|
||||
"""Read file contents with optional line range.
|
||||
|
||||
Args:
|
||||
path: Path to the file to read.
|
||||
start_line: Line to start reading from (1-indexed).
|
||||
line_count: Number of lines to read.
|
||||
|
||||
Returns:
|
||||
File contents with metadata header, or error message.
|
||||
"""
|
||||
# Validate path against allowed_paths
|
||||
valid, result = self._validate_path(path)
|
||||
if not valid:
|
||||
return f"Error: {result}"
|
||||
|
||||
assert isinstance(result, Path) # noqa: S101
|
||||
file_path = result
|
||||
|
||||
# Check file exists and is a file
|
||||
if not file_path.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
if not file_path.is_file():
|
||||
return f"Error: Not a file: {path}"
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
if start_line is None and line_count is None:
|
||||
# Read entire file
|
||||
content = f.read()
|
||||
else:
|
||||
# Read specific line range
|
||||
lines = f.readlines()
|
||||
start_idx = (start_line or 1) - 1 # Convert to 0-indexed
|
||||
start_idx = max(0, start_idx) # Ensure non-negative
|
||||
|
||||
if line_count is not None:
|
||||
end_idx = start_idx + line_count
|
||||
else:
|
||||
end_idx = len(lines)
|
||||
|
||||
content = "".join(lines[start_idx:end_idx])
|
||||
|
||||
# Get file metadata
|
||||
stat = file_path.stat()
|
||||
total_lines = content.count("\n") + (
|
||||
1 if content and not content.endswith("\n") else 0
|
||||
)
|
||||
|
||||
# Format output with metadata header
|
||||
header = f"File: {path}\n"
|
||||
header += f"Size: {self._format_size(stat.st_size)} | Lines: {total_lines}"
|
||||
|
||||
if start_line is not None or line_count is not None:
|
||||
header += (
|
||||
f" | Range: {start_line or 1}-{(start_line or 1) + total_lines - 1}"
|
||||
)
|
||||
|
||||
header += "\n" + "=" * 60 + "\n"
|
||||
|
||||
return header + content
|
||||
|
||||
except UnicodeDecodeError:
|
||||
return f"Error: File is not a text file or has encoding issues: {path}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied: {path}"
|
||||
except Exception as e:
|
||||
return f"Error reading file: {e}"
|
||||
@@ -1,127 +0,0 @@
|
||||
"""Tool for finding files by name pattern."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.experimental.environment_tools.base_environment_tool import (
|
||||
BaseEnvironmentTool,
|
||||
)
|
||||
|
||||
|
||||
class FileSearchInput(BaseModel):
|
||||
"""Input schema for file search."""
|
||||
|
||||
pattern: str = Field(
|
||||
...,
|
||||
description="Filename pattern to search for (glob syntax, e.g., '*.py', 'test_*.py')",
|
||||
)
|
||||
path: str = Field(
|
||||
default=".",
|
||||
description="Directory to search in",
|
||||
)
|
||||
file_type: Literal["file", "dir", "all"] | None = Field(
|
||||
default="all",
|
||||
description="Filter by type: 'file' for files only, 'dir' for directories only, 'all' for both",
|
||||
)
|
||||
|
||||
|
||||
class FileSearchTool(BaseEnvironmentTool):
|
||||
"""Find files by name pattern.
|
||||
|
||||
Use this tool to:
|
||||
- Find specific files in a codebase
|
||||
- Locate configuration files
|
||||
- Search for files matching a pattern
|
||||
"""
|
||||
|
||||
name: str = "find_files"
|
||||
description: str = """Find files by name pattern using glob syntax.
|
||||
|
||||
Searches recursively through directories to find matching files.
|
||||
|
||||
Examples:
|
||||
- Find Python files: pattern="*.py", path="src/"
|
||||
- Find test files: pattern="test_*.py"
|
||||
- Find configs: pattern="*.yaml", path="."
|
||||
- Find directories only: pattern="*", file_type="dir"
|
||||
"""
|
||||
args_schema: type[BaseModel] = FileSearchInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str = ".",
|
||||
file_type: Literal["file", "dir", "all"] | None = "all",
|
||||
) -> str:
|
||||
"""Find files matching a pattern.
|
||||
|
||||
Args:
|
||||
pattern: Glob pattern for filenames.
|
||||
path: Directory to search in.
|
||||
file_type: Filter by type ('file', 'dir', or 'all').
|
||||
|
||||
Returns:
|
||||
List of matching files or error message.
|
||||
"""
|
||||
# Validate path against allowed_paths
|
||||
valid, result = self._validate_path(path)
|
||||
if not valid:
|
||||
return f"Error: {result}"
|
||||
|
||||
search_path = result
|
||||
|
||||
# Check directory exists
|
||||
if not search_path.exists():
|
||||
return f"Error: Directory not found: {path}"
|
||||
|
||||
if not search_path.is_dir():
|
||||
return f"Error: Not a directory: {path}"
|
||||
|
||||
try:
|
||||
# Find matching entries recursively
|
||||
matches = list(search_path.rglob(pattern))
|
||||
|
||||
# Filter by type
|
||||
if file_type == "file":
|
||||
matches = [m for m in matches if m.is_file()]
|
||||
elif file_type == "dir":
|
||||
matches = [m for m in matches if m.is_dir()]
|
||||
|
||||
# Filter out hidden files
|
||||
matches = [
|
||||
m for m in matches if not any(part.startswith(".") for part in m.parts)
|
||||
]
|
||||
|
||||
# Sort alphabetically
|
||||
matches.sort(key=lambda x: str(x).lower())
|
||||
|
||||
if not matches:
|
||||
return f"No {file_type if file_type != 'all' else 'files'} matching '{pattern}' found in {path}"
|
||||
|
||||
# Format output
|
||||
result_lines = [f"Found {len(matches)} matches for '{pattern}' in {path}:"]
|
||||
result_lines.append("=" * 60)
|
||||
|
||||
for match in matches:
|
||||
# Get relative path from search directory
|
||||
rel_path = match.relative_to(search_path)
|
||||
|
||||
if match.is_dir():
|
||||
result_lines.append(f"📁 {rel_path}/")
|
||||
else:
|
||||
try:
|
||||
size = match.stat().st_size
|
||||
except (OSError, PermissionError):
|
||||
continue # Skip files we can't stat
|
||||
size_str = self._format_size(size)
|
||||
result_lines.append(f"📄 {rel_path} ({size_str})")
|
||||
|
||||
return "\n".join(result_lines)
|
||||
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied: {path}"
|
||||
except Exception as e:
|
||||
return f"Error searching files: {e}"
|
||||
@@ -1,149 +0,0 @@
|
||||
"""Tool for searching patterns in files using grep."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.experimental.environment_tools.base_environment_tool import (
|
||||
BaseEnvironmentTool,
|
||||
)
|
||||
|
||||
|
||||
class GrepInput(BaseModel):
|
||||
"""Input schema for grep search."""
|
||||
|
||||
pattern: str = Field(..., description="Search pattern (supports regex)")
|
||||
path: str = Field(..., description="File or directory to search in")
|
||||
recursive: bool = Field(
|
||||
default=True,
|
||||
description="Search recursively in directories",
|
||||
)
|
||||
ignore_case: bool = Field(
|
||||
default=False,
|
||||
description="Case-insensitive search",
|
||||
)
|
||||
context_lines: int = Field(
|
||||
default=2,
|
||||
description="Number of context lines to show before/after matches",
|
||||
)
|
||||
|
||||
|
||||
class GrepTool(BaseEnvironmentTool):
|
||||
"""Search for text patterns in files using grep.
|
||||
|
||||
Use this tool to:
|
||||
- Find where a function or class is defined
|
||||
- Search for error messages in logs
|
||||
- Locate configuration values
|
||||
- Find TODO comments or specific patterns
|
||||
"""
|
||||
|
||||
name: str = "grep_search"
|
||||
description: str = """Search for text patterns in files using grep.
|
||||
|
||||
Supports regex patterns. Returns matching lines with context.
|
||||
|
||||
Examples:
|
||||
- Find function: pattern="def process_data", path="src/"
|
||||
- Search logs: pattern="ERROR", path="logs/app.log"
|
||||
- Case-insensitive: pattern="todo", path=".", ignore_case=True
|
||||
"""
|
||||
args_schema: type[BaseModel] = GrepInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str,
|
||||
recursive: bool = True,
|
||||
ignore_case: bool = False,
|
||||
context_lines: int = 2,
|
||||
) -> str:
|
||||
"""Search for patterns in files.
|
||||
|
||||
Args:
|
||||
pattern: Search pattern (regex supported).
|
||||
path: File or directory to search in.
|
||||
recursive: Whether to search recursively.
|
||||
ignore_case: Whether to ignore case.
|
||||
context_lines: Lines of context around matches.
|
||||
|
||||
Returns:
|
||||
Search results or error message.
|
||||
"""
|
||||
# Validate path against allowed_paths
|
||||
valid, result = self._validate_path(path)
|
||||
if not valid:
|
||||
return f"Error: {result}"
|
||||
|
||||
search_path = result
|
||||
|
||||
# Check path exists
|
||||
if not search_path.exists():
|
||||
return f"Error: Path not found: {path}"
|
||||
|
||||
try:
|
||||
# Build grep command safely
|
||||
cmd = ["grep", "--color=never"]
|
||||
|
||||
# Add recursive flag if searching directory
|
||||
if recursive and search_path.is_dir():
|
||||
cmd.append("-r")
|
||||
|
||||
# Case insensitive
|
||||
if ignore_case:
|
||||
cmd.append("-i")
|
||||
|
||||
# Context lines
|
||||
if context_lines > 0:
|
||||
cmd.extend(["-C", str(context_lines)])
|
||||
|
||||
# Show line numbers
|
||||
cmd.append("-n")
|
||||
|
||||
# Use -- to prevent pattern from being interpreted as option
|
||||
cmd.append("--")
|
||||
cmd.append(pattern)
|
||||
cmd.append(str(search_path))
|
||||
|
||||
# Execute with timeout
|
||||
# Security: cmd is a list (no shell injection), path is validated above
|
||||
result = subprocess.run( # noqa: S603
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
# Found matches
|
||||
output = result.stdout
|
||||
# Count actual match lines (not context lines)
|
||||
match_lines = [
|
||||
line
|
||||
for line in output.split("\n")
|
||||
if line and not line.startswith("--")
|
||||
]
|
||||
match_count = len(match_lines)
|
||||
|
||||
header = f"Found {match_count} matches for '{pattern}' in {path}\n"
|
||||
header += "=" * 60 + "\n"
|
||||
return header + output
|
||||
|
||||
if result.returncode == 1:
|
||||
# No matches found (grep returns 1 for no matches)
|
||||
return f"No matches found for '{pattern}' in {path}"
|
||||
|
||||
# Error occurred
|
||||
error_msg = result.stderr.strip() if result.stderr else "Unknown error"
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return "Error: Search timed out (>30s). Try narrowing the search path."
|
||||
except FileNotFoundError:
|
||||
return (
|
||||
"Error: grep command not found. Ensure grep is installed on the system."
|
||||
)
|
||||
except Exception as e:
|
||||
return f"Error during search: {e}"
|
||||
@@ -1,147 +0,0 @@
|
||||
"""Tool for listing directory contents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.experimental.environment_tools.base_environment_tool import (
|
||||
BaseEnvironmentTool,
|
||||
)
|
||||
|
||||
|
||||
class ListDirInput(BaseModel):
|
||||
"""Input schema for listing directories."""
|
||||
|
||||
path: str = Field(default=".", description="Directory path to list")
|
||||
pattern: str | None = Field(
|
||||
default=None,
|
||||
description="Glob pattern to filter entries (e.g., '*.py', '*.md')",
|
||||
)
|
||||
recursive: bool = Field(
|
||||
default=False,
|
||||
description="If True, list contents recursively including subdirectories",
|
||||
)
|
||||
|
||||
|
||||
class ListDirTool(BaseEnvironmentTool):
|
||||
"""List contents of a directory with optional filtering.
|
||||
|
||||
Use this tool to:
|
||||
- Explore project structure
|
||||
- Find specific file types
|
||||
- Check what files exist in a directory
|
||||
- Navigate the file system
|
||||
"""
|
||||
|
||||
name: str = "list_directory"
|
||||
description: str = """List contents of a directory.
|
||||
|
||||
Use this to explore directories and find files. You can filter by pattern
|
||||
and optionally list recursively.
|
||||
|
||||
Examples:
|
||||
- List current dir: path="."
|
||||
- List src folder: path="src/"
|
||||
- Find Python files: path=".", pattern="*.py"
|
||||
- Recursive listing: path="src/", recursive=True
|
||||
"""
|
||||
args_schema: type[BaseModel] = ListDirInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
path: str = ".",
|
||||
pattern: str | None = None,
|
||||
recursive: bool = False,
|
||||
) -> str:
|
||||
"""List directory contents.
|
||||
|
||||
Args:
|
||||
path: Directory path to list.
|
||||
pattern: Glob pattern to filter entries.
|
||||
recursive: Whether to list recursively.
|
||||
|
||||
Returns:
|
||||
Formatted directory listing or error message.
|
||||
"""
|
||||
# Validate path against allowed_paths
|
||||
valid, result = self._validate_path(path)
|
||||
if not valid:
|
||||
return f"Error: {result}"
|
||||
|
||||
assert isinstance(result, Path) # noqa: S101
|
||||
dir_path = result
|
||||
|
||||
# Check directory exists
|
||||
if not dir_path.exists():
|
||||
return f"Error: Directory not found: {path}"
|
||||
|
||||
if not dir_path.is_dir():
|
||||
return f"Error: Not a directory: {path}"
|
||||
|
||||
try:
|
||||
# Get entries based on pattern and recursive flag
|
||||
if pattern:
|
||||
if recursive:
|
||||
entries = list(dir_path.rglob(pattern))
|
||||
else:
|
||||
entries = list(dir_path.glob(pattern))
|
||||
else:
|
||||
if recursive:
|
||||
entries = list(dir_path.rglob("*"))
|
||||
else:
|
||||
entries = list(dir_path.iterdir())
|
||||
|
||||
# Filter out hidden files (starting with .)
|
||||
entries = [e for e in entries if not e.name.startswith(".")]
|
||||
|
||||
# Sort: directories first, then files, alphabetically
|
||||
entries.sort(key=lambda x: (not x.is_dir(), x.name.lower()))
|
||||
|
||||
if not entries:
|
||||
if pattern:
|
||||
return f"No entries matching '{pattern}' in {path}"
|
||||
return f"Directory is empty: {path}"
|
||||
|
||||
# Format output
|
||||
result_lines = [f"Contents of {path}:"]
|
||||
result_lines.append("=" * 60)
|
||||
|
||||
dirs = []
|
||||
files = []
|
||||
|
||||
for entry in entries:
|
||||
# Get relative path for recursive listings
|
||||
if recursive:
|
||||
display_name = str(entry.relative_to(dir_path))
|
||||
else:
|
||||
display_name = entry.name
|
||||
|
||||
if entry.is_dir():
|
||||
dirs.append(f"📁 {display_name}/")
|
||||
else:
|
||||
try:
|
||||
size = entry.stat().st_size
|
||||
except (OSError, PermissionError):
|
||||
continue # Skip files we can't stat
|
||||
size_str = self._format_size(size)
|
||||
files.append(f"📄 {display_name} ({size_str})")
|
||||
|
||||
# Output directories first, then files
|
||||
if dirs:
|
||||
result_lines.extend(dirs)
|
||||
if files:
|
||||
if dirs:
|
||||
result_lines.append("") # Blank line between dirs and files
|
||||
result_lines.extend(files)
|
||||
|
||||
result_lines.append("")
|
||||
result_lines.append(f"Total: {len(dirs)} directories, {len(files)} files")
|
||||
|
||||
return "\n".join(result_lines)
|
||||
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied: {path}"
|
||||
except Exception as e:
|
||||
return f"Error listing directory: {e}"
|
||||
@@ -1,9 +1,12 @@
|
||||
from crewai.tools.base_tool import BaseTool, EnvVar, tool
|
||||
from crewai.tools.tool_search_tool import SearchStrategy, ToolSearchTool
|
||||
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"EnvVar",
|
||||
"SearchStrategy",
|
||||
"ToolSearchTool",
|
||||
"tool",
|
||||
]
|
||||
|
||||
333
lib/crewai/src/crewai/tools/tool_search_tool.py
Normal file
333
lib/crewai/src/crewai/tools/tool_search_tool.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""Tool Search Tool for on-demand tool discovery.
|
||||
|
||||
This module implements a Tool Search Tool that allows agents to dynamically
|
||||
discover and load tools on-demand, reducing token consumption when working
|
||||
with large tool libraries.
|
||||
|
||||
Inspired by Anthropic's Tool Search Tool approach for on-demand tool loading.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from enum import Enum
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
|
||||
|
||||
class SearchStrategy(str, Enum):
|
||||
"""Search strategy for tool discovery."""
|
||||
|
||||
KEYWORD = "keyword"
|
||||
REGEX = "regex"
|
||||
|
||||
|
||||
class ToolSearchResult(BaseModel):
|
||||
"""Result from a tool search operation."""
|
||||
|
||||
name: str = Field(description="The name of the tool")
|
||||
description: str = Field(description="The description of the tool")
|
||||
args_schema: dict[str, Any] = Field(
|
||||
description="The JSON schema for the tool's arguments"
|
||||
)
|
||||
|
||||
|
||||
class ToolSearchToolSchema(BaseModel):
|
||||
"""Schema for the Tool Search Tool arguments."""
|
||||
|
||||
query: str = Field(
|
||||
description="The search query to find relevant tools. Use keywords that describe the capability you need."
|
||||
)
|
||||
max_results: int = Field(
|
||||
default=5,
|
||||
description="Maximum number of tools to return. Default is 5.",
|
||||
ge=1,
|
||||
le=20,
|
||||
)
|
||||
|
||||
|
||||
class ToolSearchTool(BaseTool):
|
||||
"""A tool that searches through a catalog of tools to find relevant ones.
|
||||
|
||||
This tool enables on-demand tool discovery, allowing agents to work with
|
||||
large tool libraries without loading all tool definitions upfront. Instead
|
||||
of consuming tokens with all tool definitions, the agent can search for
|
||||
relevant tools when needed.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from crewai.tools import BaseTool, ToolSearchTool
|
||||
|
||||
# Create your tools
|
||||
search_tool = MySearchTool()
|
||||
scrape_tool = MyScrapeWebsiteTool()
|
||||
database_tool = MyDatabaseTool()
|
||||
|
||||
# Create a tool search tool with your tool catalog
|
||||
tool_search = ToolSearchTool(
|
||||
tool_catalog=[search_tool, scrape_tool, database_tool],
|
||||
search_strategy=SearchStrategy.KEYWORD,
|
||||
)
|
||||
|
||||
# Use with an agent - only the tool_search is loaded initially
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
tools=[tool_search], # Other tools discovered on-demand
|
||||
)
|
||||
```
|
||||
|
||||
Attributes:
|
||||
tool_catalog: List of tools available for search.
|
||||
search_strategy: Strategy to use for searching (keyword or regex).
|
||||
custom_search_fn: Optional custom search function for advanced matching.
|
||||
"""
|
||||
|
||||
name: str = Field(
|
||||
default="Tool Search",
|
||||
description="The name of the tool search tool.",
|
||||
)
|
||||
description: str = Field(
|
||||
default="Search for available tools by describing the capability you need. Returns tool definitions that match your query.",
|
||||
description="Description of what the tool search tool does.",
|
||||
)
|
||||
args_schema: type[BaseModel] = Field(
|
||||
default=ToolSearchToolSchema,
|
||||
description="The schema for the tool search arguments.",
|
||||
)
|
||||
tool_catalog: list[BaseTool | CrewStructuredTool] = Field(
|
||||
default_factory=list,
|
||||
description="List of tools available for search.",
|
||||
)
|
||||
search_strategy: SearchStrategy = Field(
|
||||
default=SearchStrategy.KEYWORD,
|
||||
description="Strategy to use for searching tools.",
|
||||
)
|
||||
custom_search_fn: Callable[
|
||||
[str, Sequence[BaseTool | CrewStructuredTool]], list[BaseTool | CrewStructuredTool]
|
||||
] | None = Field(
|
||||
default=None,
|
||||
description="Optional custom search function for advanced matching.",
|
||||
)
|
||||
|
||||
def _run(self, query: str, max_results: int = 5) -> str:
|
||||
"""Search for tools matching the query.
|
||||
|
||||
Args:
|
||||
query: The search query to find relevant tools.
|
||||
max_results: Maximum number of tools to return.
|
||||
|
||||
Returns:
|
||||
JSON string containing the matching tool definitions.
|
||||
"""
|
||||
if not self.tool_catalog:
|
||||
return json.dumps(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "No tools available in the catalog.",
|
||||
"tools": [],
|
||||
}
|
||||
)
|
||||
|
||||
if self.custom_search_fn:
|
||||
matching_tools = self.custom_search_fn(query, self.tool_catalog)
|
||||
elif self.search_strategy == SearchStrategy.REGEX:
|
||||
matching_tools = self._regex_search(query)
|
||||
else:
|
||||
matching_tools = self._keyword_search(query)
|
||||
|
||||
matching_tools = matching_tools[:max_results]
|
||||
|
||||
if not matching_tools:
|
||||
return json.dumps(
|
||||
{
|
||||
"status": "no_results",
|
||||
"message": f"No tools found matching query: '{query}'. Try different keywords.",
|
||||
"tools": [],
|
||||
}
|
||||
)
|
||||
|
||||
tool_results = []
|
||||
for tool in matching_tools:
|
||||
tool_info = self._get_tool_info(tool)
|
||||
tool_results.append(tool_info)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"status": "success",
|
||||
"message": f"Found {len(tool_results)} tool(s) matching your query.",
|
||||
"tools": tool_results,
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
def _keyword_search(
|
||||
self, query: str
|
||||
) -> list[BaseTool | CrewStructuredTool]:
|
||||
"""Search tools using keyword matching.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
List of matching tools sorted by relevance.
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
query_words = set(query_lower.split())
|
||||
|
||||
scored_tools: list[tuple[float, BaseTool | CrewStructuredTool]] = []
|
||||
|
||||
for tool in self.tool_catalog:
|
||||
score = self._calculate_keyword_score(tool, query_lower, query_words)
|
||||
if score > 0:
|
||||
scored_tools.append((score, tool))
|
||||
|
||||
scored_tools.sort(key=lambda x: x[0], reverse=True)
|
||||
return [tool for _, tool in scored_tools]
|
||||
|
||||
def _calculate_keyword_score(
|
||||
self,
|
||||
tool: BaseTool | CrewStructuredTool,
|
||||
query_lower: str,
|
||||
query_words: set[str],
|
||||
) -> float:
|
||||
"""Calculate relevance score for a tool based on keyword matching.
|
||||
|
||||
Args:
|
||||
tool: The tool to score.
|
||||
query_lower: Lowercase query string.
|
||||
query_words: Set of query words.
|
||||
|
||||
Returns:
|
||||
Relevance score (higher is better).
|
||||
"""
|
||||
score = 0.0
|
||||
tool_name_lower = tool.name.lower()
|
||||
tool_desc_lower = tool.description.lower()
|
||||
|
||||
if query_lower in tool_name_lower:
|
||||
score += 10.0
|
||||
if query_lower in tool_desc_lower:
|
||||
score += 5.0
|
||||
|
||||
for word in query_words:
|
||||
if len(word) < 2:
|
||||
continue
|
||||
if word in tool_name_lower:
|
||||
score += 3.0
|
||||
if word in tool_desc_lower:
|
||||
score += 1.0
|
||||
|
||||
return score
|
||||
|
||||
def _regex_search(
|
||||
self, query: str
|
||||
) -> list[BaseTool | CrewStructuredTool]:
|
||||
"""Search tools using regex pattern matching.
|
||||
|
||||
Args:
|
||||
query: The regex pattern to search for.
|
||||
|
||||
Returns:
|
||||
List of matching tools.
|
||||
"""
|
||||
try:
|
||||
pattern = re.compile(query, re.IGNORECASE)
|
||||
except re.error:
|
||||
pattern = re.compile(re.escape(query), re.IGNORECASE)
|
||||
|
||||
return [
|
||||
tool
|
||||
for tool in self.tool_catalog
|
||||
if pattern.search(tool.name) or pattern.search(tool.description)
|
||||
]
|
||||
|
||||
def _get_tool_info(self, tool: BaseTool | CrewStructuredTool) -> dict[str, Any]:
|
||||
"""Get tool information as a dictionary.
|
||||
|
||||
Args:
|
||||
tool: The tool to get information from.
|
||||
|
||||
Returns:
|
||||
Dictionary containing tool name, description, and args schema.
|
||||
"""
|
||||
if isinstance(tool, BaseTool):
|
||||
schema_dict = generate_model_description(tool.args_schema)
|
||||
args_schema = schema_dict.get("json_schema", {}).get("schema", {})
|
||||
else:
|
||||
args_schema = tool.args_schema.model_json_schema()
|
||||
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": self._get_original_description(tool),
|
||||
"args_schema": args_schema,
|
||||
}
|
||||
|
||||
def _get_original_description(self, tool: BaseTool | CrewStructuredTool) -> str:
|
||||
"""Get the original description of a tool without the generated schema.
|
||||
|
||||
Args:
|
||||
tool: The tool to get the description from.
|
||||
|
||||
Returns:
|
||||
The original tool description.
|
||||
"""
|
||||
description = tool.description
|
||||
if "Tool Description:" in description:
|
||||
parts = description.split("Tool Description:")
|
||||
if len(parts) > 1:
|
||||
return parts[1].strip()
|
||||
return description
|
||||
|
||||
def add_tool(self, tool: BaseTool | CrewStructuredTool) -> None:
|
||||
"""Add a tool to the catalog.
|
||||
|
||||
Args:
|
||||
tool: The tool to add.
|
||||
"""
|
||||
self.tool_catalog.append(tool)
|
||||
|
||||
def add_tools(self, tools: Sequence[BaseTool | CrewStructuredTool]) -> None:
|
||||
"""Add multiple tools to the catalog.
|
||||
|
||||
Args:
|
||||
tools: The tools to add.
|
||||
"""
|
||||
self.tool_catalog.extend(tools)
|
||||
|
||||
def remove_tool(self, tool_name: str) -> bool:
|
||||
"""Remove a tool from the catalog by name.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool to remove.
|
||||
|
||||
Returns:
|
||||
True if the tool was removed, False if not found.
|
||||
"""
|
||||
for i, tool in enumerate(self.tool_catalog):
|
||||
if tool.name == tool_name:
|
||||
self.tool_catalog.pop(i)
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_catalog_size(self) -> int:
|
||||
"""Get the number of tools in the catalog.
|
||||
|
||||
Returns:
|
||||
The number of tools in the catalog.
|
||||
"""
|
||||
return len(self.tool_catalog)
|
||||
|
||||
def list_tool_names(self) -> list[str]:
|
||||
"""List all tool names in the catalog.
|
||||
|
||||
Returns:
|
||||
List of tool names.
|
||||
"""
|
||||
return [tool.name for tool in self.tool_catalog]
|
||||
@@ -1,408 +0,0 @@
|
||||
"""Tests for experimental environment tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.experimental.environment_tools import (
|
||||
BaseEnvironmentTool,
|
||||
EnvironmentTools,
|
||||
FileReadTool,
|
||||
FileSearchTool,
|
||||
GrepTool,
|
||||
ListDirTool,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir() -> Generator[str, None, None]:
|
||||
"""Create a temporary directory with test files."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Create test files
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n")
|
||||
|
||||
python_file = Path(tmpdir) / "example.py"
|
||||
python_file.write_text("def hello():\n print('Hello World')\n")
|
||||
|
||||
# Create subdirectory with files
|
||||
subdir = Path(tmpdir) / "subdir"
|
||||
subdir.mkdir()
|
||||
(subdir / "nested.txt").write_text("Nested content\n")
|
||||
(subdir / "another.py").write_text("# Another Python file\n")
|
||||
|
||||
yield tmpdir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restricted_temp_dir() -> Generator[tuple[str, str], None, None]:
|
||||
"""Create two directories - one allowed, one not."""
|
||||
with tempfile.TemporaryDirectory() as allowed_dir:
|
||||
with tempfile.TemporaryDirectory() as forbidden_dir:
|
||||
# Create files in both
|
||||
(Path(allowed_dir) / "allowed.txt").write_text("Allowed content\n")
|
||||
(Path(forbidden_dir) / "forbidden.txt").write_text("Forbidden content\n")
|
||||
|
||||
yield allowed_dir, forbidden_dir
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# BaseEnvironmentTool Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestBaseEnvironmentTool:
|
||||
"""Tests for BaseEnvironmentTool path validation."""
|
||||
|
||||
def test_default_allowed_paths_is_current_directory(self) -> None:
|
||||
"""Default allowed_paths should be current directory for security."""
|
||||
tool = FileReadTool()
|
||||
|
||||
assert tool.allowed_paths == ["."]
|
||||
|
||||
def test_validate_path_explicit_no_restrictions(self, temp_dir: str) -> None:
|
||||
"""With explicit empty allowed_paths, all paths should be allowed."""
|
||||
tool = FileReadTool(allowed_paths=[])
|
||||
valid, result = tool._validate_path(temp_dir)
|
||||
|
||||
assert valid is True
|
||||
assert isinstance(result, Path)
|
||||
|
||||
def test_validate_path_within_allowed(self, temp_dir: str) -> None:
|
||||
"""Paths within allowed_paths should be valid."""
|
||||
tool = FileReadTool(allowed_paths=[temp_dir])
|
||||
test_file = os.path.join(temp_dir, "test.txt")
|
||||
|
||||
valid, result = tool._validate_path(test_file)
|
||||
|
||||
assert valid is True
|
||||
assert isinstance(result, Path)
|
||||
|
||||
def test_validate_path_outside_allowed(self, restricted_temp_dir: tuple[str, str]) -> None:
|
||||
"""Paths outside allowed_paths should be rejected."""
|
||||
allowed_dir, forbidden_dir = restricted_temp_dir
|
||||
tool = FileReadTool(allowed_paths=[allowed_dir])
|
||||
|
||||
forbidden_file = os.path.join(forbidden_dir, "forbidden.txt")
|
||||
valid, result = tool._validate_path(forbidden_file)
|
||||
|
||||
assert valid is False
|
||||
assert isinstance(result, str)
|
||||
assert "outside allowed paths" in result
|
||||
|
||||
def test_format_size(self) -> None:
|
||||
"""Test human-readable size formatting."""
|
||||
tool = FileReadTool()
|
||||
|
||||
assert tool._format_size(500) == "500B"
|
||||
assert tool._format_size(1024) == "1.0KB"
|
||||
assert tool._format_size(1536) == "1.5KB"
|
||||
assert tool._format_size(1024 * 1024) == "1.0MB"
|
||||
assert tool._format_size(1024 * 1024 * 1024) == "1.0GB"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FileReadTool Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestFileReadTool:
|
||||
"""Tests for FileReadTool."""
|
||||
|
||||
def test_read_entire_file(self, temp_dir: str) -> None:
|
||||
"""Should read entire file contents."""
|
||||
tool = FileReadTool(allowed_paths=[temp_dir])
|
||||
test_file = os.path.join(temp_dir, "test.txt")
|
||||
|
||||
result = tool._run(path=test_file)
|
||||
|
||||
assert "Line 1" in result
|
||||
assert "Line 2" in result
|
||||
assert "Line 5" in result
|
||||
assert "File:" in result # Metadata header
|
||||
|
||||
def test_read_with_line_range(self, temp_dir: str) -> None:
|
||||
"""Should read specific line range."""
|
||||
tool = FileReadTool(allowed_paths=[temp_dir])
|
||||
test_file = os.path.join(temp_dir, "test.txt")
|
||||
|
||||
result = tool._run(path=test_file, start_line=2, line_count=2)
|
||||
|
||||
assert "Line 2" in result
|
||||
assert "Line 3" in result
|
||||
# Should not include lines outside range
|
||||
assert "Line 1" not in result.split("=" * 60)[-1] # Check content after header
|
||||
|
||||
def test_read_file_not_found(self, temp_dir: str) -> None:
|
||||
"""Should return error for missing file."""
|
||||
tool = FileReadTool(allowed_paths=[temp_dir])
|
||||
missing_file = os.path.join(temp_dir, "nonexistent.txt")
|
||||
|
||||
result = tool._run(path=missing_file)
|
||||
|
||||
assert "Error: File not found" in result
|
||||
|
||||
def test_read_file_path_restricted(self, restricted_temp_dir: tuple[str, str]) -> None:
|
||||
"""Should reject paths outside allowed_paths."""
|
||||
allowed_dir, forbidden_dir = restricted_temp_dir
|
||||
tool = FileReadTool(allowed_paths=[allowed_dir])
|
||||
|
||||
forbidden_file = os.path.join(forbidden_dir, "forbidden.txt")
|
||||
result = tool._run(path=forbidden_file)
|
||||
|
||||
assert "Error:" in result
|
||||
assert "outside allowed paths" in result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ListDirTool Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestListDirTool:
|
||||
"""Tests for ListDirTool."""
|
||||
|
||||
def test_list_directory(self, temp_dir: str) -> None:
|
||||
"""Should list directory contents."""
|
||||
tool = ListDirTool(allowed_paths=[temp_dir])
|
||||
|
||||
result = tool._run(path=temp_dir)
|
||||
|
||||
assert "test.txt" in result
|
||||
assert "example.py" in result
|
||||
assert "subdir" in result
|
||||
assert "Total:" in result
|
||||
|
||||
def test_list_with_pattern(self, temp_dir: str) -> None:
|
||||
"""Should filter by pattern."""
|
||||
tool = ListDirTool(allowed_paths=[temp_dir])
|
||||
|
||||
result = tool._run(path=temp_dir, pattern="*.py")
|
||||
|
||||
assert "example.py" in result
|
||||
assert "test.txt" not in result
|
||||
|
||||
def test_list_recursive(self, temp_dir: str) -> None:
|
||||
"""Should list recursively when enabled."""
|
||||
tool = ListDirTool(allowed_paths=[temp_dir])
|
||||
|
||||
result = tool._run(path=temp_dir, recursive=True)
|
||||
|
||||
assert "nested.txt" in result
|
||||
assert "another.py" in result
|
||||
|
||||
def test_list_nonexistent_directory(self, temp_dir: str) -> None:
|
||||
"""Should return error for missing directory."""
|
||||
tool = ListDirTool(allowed_paths=[temp_dir])
|
||||
|
||||
result = tool._run(path=os.path.join(temp_dir, "nonexistent"))
|
||||
|
||||
assert "Error: Directory not found" in result
|
||||
|
||||
def test_list_path_restricted(self, restricted_temp_dir: tuple[str, str]) -> None:
|
||||
"""Should reject paths outside allowed_paths."""
|
||||
allowed_dir, forbidden_dir = restricted_temp_dir
|
||||
tool = ListDirTool(allowed_paths=[allowed_dir])
|
||||
|
||||
result = tool._run(path=forbidden_dir)
|
||||
|
||||
assert "Error:" in result
|
||||
assert "outside allowed paths" in result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# GrepTool Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestGrepTool:
|
||||
"""Tests for GrepTool."""
|
||||
|
||||
def test_grep_finds_pattern(self, temp_dir: str) -> None:
|
||||
"""Should find matching patterns."""
|
||||
tool = GrepTool(allowed_paths=[temp_dir])
|
||||
test_file = os.path.join(temp_dir, "test.txt")
|
||||
|
||||
result = tool._run(pattern="Line 2", path=test_file)
|
||||
|
||||
assert "Line 2" in result
|
||||
assert "matches" in result.lower() or "found" in result.lower()
|
||||
|
||||
def test_grep_no_matches(self, temp_dir: str) -> None:
|
||||
"""Should report when no matches found."""
|
||||
tool = GrepTool(allowed_paths=[temp_dir])
|
||||
test_file = os.path.join(temp_dir, "test.txt")
|
||||
|
||||
result = tool._run(pattern="nonexistent pattern xyz", path=test_file)
|
||||
|
||||
assert "No matches found" in result
|
||||
|
||||
def test_grep_recursive(self, temp_dir: str) -> None:
|
||||
"""Should search recursively in directories."""
|
||||
tool = GrepTool(allowed_paths=[temp_dir])
|
||||
|
||||
result = tool._run(pattern="Nested", path=temp_dir, recursive=True)
|
||||
|
||||
assert "Nested" in result
|
||||
|
||||
def test_grep_case_insensitive(self, temp_dir: str) -> None:
|
||||
"""Should support case-insensitive search."""
|
||||
tool = GrepTool(allowed_paths=[temp_dir])
|
||||
test_file = os.path.join(temp_dir, "test.txt")
|
||||
|
||||
result = tool._run(pattern="LINE", path=test_file, ignore_case=True)
|
||||
|
||||
assert "Line" in result or "matches" in result.lower()
|
||||
|
||||
def test_grep_path_restricted(self, restricted_temp_dir: tuple[str, str]) -> None:
|
||||
"""Should reject paths outside allowed_paths."""
|
||||
allowed_dir, forbidden_dir = restricted_temp_dir
|
||||
tool = GrepTool(allowed_paths=[allowed_dir])
|
||||
|
||||
result = tool._run(pattern="test", path=forbidden_dir)
|
||||
|
||||
assert "Error:" in result
|
||||
assert "outside allowed paths" in result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FileSearchTool Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestFileSearchTool:
|
||||
"""Tests for FileSearchTool."""
|
||||
|
||||
def test_find_files_by_pattern(self, temp_dir: str) -> None:
|
||||
"""Should find files matching pattern."""
|
||||
tool = FileSearchTool(allowed_paths=[temp_dir])
|
||||
|
||||
result = tool._run(pattern="*.py", path=temp_dir)
|
||||
|
||||
assert "example.py" in result
|
||||
assert "another.py" in result
|
||||
|
||||
def test_find_no_matches(self, temp_dir: str) -> None:
|
||||
"""Should report when no files match."""
|
||||
tool = FileSearchTool(allowed_paths=[temp_dir])
|
||||
|
||||
result = tool._run(pattern="*.xyz", path=temp_dir)
|
||||
|
||||
assert "No" in result and "found" in result
|
||||
|
||||
def test_find_files_only(self, temp_dir: str) -> None:
|
||||
"""Should filter to files only."""
|
||||
tool = FileSearchTool(allowed_paths=[temp_dir])
|
||||
|
||||
result = tool._run(pattern="*", path=temp_dir, file_type="file")
|
||||
|
||||
# Should include files
|
||||
assert "test.txt" in result or "example.py" in result
|
||||
# Directories should have trailing slash in output
|
||||
# Check that subdir is not listed as a file
|
||||
|
||||
def test_find_dirs_only(self, temp_dir: str) -> None:
|
||||
"""Should filter to directories only."""
|
||||
tool = FileSearchTool(allowed_paths=[temp_dir])
|
||||
|
||||
result = tool._run(pattern="*", path=temp_dir, file_type="dir")
|
||||
|
||||
assert "subdir" in result
|
||||
|
||||
def test_find_path_restricted(self, restricted_temp_dir: tuple[str, str]) -> None:
|
||||
"""Should reject paths outside allowed_paths."""
|
||||
allowed_dir, forbidden_dir = restricted_temp_dir
|
||||
tool = FileSearchTool(allowed_paths=[allowed_dir])
|
||||
|
||||
result = tool._run(pattern="*", path=forbidden_dir)
|
||||
|
||||
assert "Error:" in result
|
||||
assert "outside allowed paths" in result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# EnvironmentTools Manager Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestEnvironmentTools:
|
||||
"""Tests for EnvironmentTools manager class."""
|
||||
|
||||
def test_default_allowed_paths_is_current_directory(self) -> None:
|
||||
"""Default should restrict to current directory for security."""
|
||||
env_tools = EnvironmentTools()
|
||||
tools = env_tools.tools()
|
||||
|
||||
# All tools should default to current directory
|
||||
for tool in tools:
|
||||
assert isinstance(tool, BaseEnvironmentTool)
|
||||
assert tool.allowed_paths == ["."]
|
||||
|
||||
def test_explicit_empty_allowed_paths_allows_all(self) -> None:
|
||||
"""Passing empty list should allow all paths."""
|
||||
env_tools = EnvironmentTools(allowed_paths=[])
|
||||
tools = env_tools.tools()
|
||||
|
||||
for tool in tools:
|
||||
assert isinstance(tool, BaseEnvironmentTool)
|
||||
assert tool.allowed_paths == []
|
||||
|
||||
def test_returns_all_tools_by_default(self) -> None:
|
||||
"""Should return all four tools by default."""
|
||||
env_tools = EnvironmentTools()
|
||||
tools = env_tools.tools()
|
||||
|
||||
assert len(tools) == 4
|
||||
|
||||
tool_names = [t.name for t in tools]
|
||||
assert "read_file" in tool_names
|
||||
assert "list_directory" in tool_names
|
||||
assert "grep_search" in tool_names
|
||||
assert "find_files" in tool_names
|
||||
|
||||
def test_exclude_grep(self) -> None:
|
||||
"""Should exclude grep tool when disabled."""
|
||||
env_tools = EnvironmentTools(include_grep=False)
|
||||
tools = env_tools.tools()
|
||||
|
||||
assert len(tools) == 3
|
||||
tool_names = [t.name for t in tools]
|
||||
assert "grep_search" not in tool_names
|
||||
|
||||
def test_exclude_search(self) -> None:
|
||||
"""Should exclude search tool when disabled."""
|
||||
env_tools = EnvironmentTools(include_search=False)
|
||||
tools = env_tools.tools()
|
||||
|
||||
assert len(tools) == 3
|
||||
tool_names = [t.name for t in tools]
|
||||
assert "find_files" not in tool_names
|
||||
|
||||
def test_allowed_paths_propagated(self, temp_dir: str) -> None:
|
||||
"""Should propagate allowed_paths to all tools."""
|
||||
env_tools = EnvironmentTools(allowed_paths=[temp_dir])
|
||||
tools = env_tools.tools()
|
||||
|
||||
for tool in tools:
|
||||
assert isinstance(tool, BaseEnvironmentTool)
|
||||
assert tool.allowed_paths == [temp_dir]
|
||||
|
||||
def test_tools_are_base_tool_instances(self) -> None:
|
||||
"""All returned tools should be BaseTool instances."""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
env_tools = EnvironmentTools()
|
||||
tools = env_tools.tools()
|
||||
|
||||
for tool in tools:
|
||||
assert isinstance(tool, BaseTool)
|
||||
393
lib/crewai/tests/tools/test_tool_search_tool.py
Normal file
393
lib/crewai/tests/tools/test_tool_search_tool.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""Tests for the ToolSearchTool functionality."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.tools import BaseTool, SearchStrategy, ToolSearchTool
|
||||
|
||||
|
||||
class MockSearchTool(BaseTool):
|
||||
"""A mock search tool for testing."""
|
||||
|
||||
name: str = "Web Search"
|
||||
description: str = "Search the web for information on any topic."
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Search results for: {query}"
|
||||
|
||||
|
||||
class MockDatabaseTool(BaseTool):
|
||||
"""A mock database tool for testing."""
|
||||
|
||||
name: str = "Database Query"
|
||||
description: str = "Query a SQL database to retrieve data."
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Database results for: {query}"
|
||||
|
||||
|
||||
class MockScrapeTool(BaseTool):
|
||||
"""A mock web scraping tool for testing."""
|
||||
|
||||
name: str = "Web Scraper"
|
||||
description: str = "Scrape content from websites and extract text."
|
||||
|
||||
def _run(self, url: str) -> str:
|
||||
return f"Scraped content from: {url}"
|
||||
|
||||
|
||||
class MockEmailTool(BaseTool):
|
||||
"""A mock email tool for testing."""
|
||||
|
||||
name: str = "Send Email"
|
||||
description: str = "Send an email to a specified recipient."
|
||||
|
||||
def _run(self, to: str, subject: str, body: str) -> str:
|
||||
return f"Email sent to {to}"
|
||||
|
||||
|
||||
class MockCalculatorTool(BaseTool):
|
||||
"""A mock calculator tool for testing."""
|
||||
|
||||
name: str = "Calculator"
|
||||
description: str = "Perform mathematical calculations and arithmetic operations."
|
||||
|
||||
def _run(self, expression: str) -> str:
|
||||
return f"Result: {eval(expression)}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools() -> list[BaseTool]:
|
||||
"""Create a list of sample tools for testing."""
|
||||
return [
|
||||
MockSearchTool(),
|
||||
MockDatabaseTool(),
|
||||
MockScrapeTool(),
|
||||
MockEmailTool(),
|
||||
MockCalculatorTool(),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_search(sample_tools: list[BaseTool]) -> ToolSearchTool:
|
||||
"""Create a ToolSearchTool with sample tools."""
|
||||
return ToolSearchTool(tool_catalog=sample_tools)
|
||||
|
||||
|
||||
class TestToolSearchToolCreation:
|
||||
"""Tests for ToolSearchTool creation and initialization."""
|
||||
|
||||
def test_create_tool_search_with_empty_catalog(self) -> None:
|
||||
"""Test creating a ToolSearchTool with an empty catalog."""
|
||||
tool_search = ToolSearchTool()
|
||||
assert tool_search.name == "Tool Search"
|
||||
assert tool_search.tool_catalog == []
|
||||
assert tool_search.search_strategy == SearchStrategy.KEYWORD
|
||||
|
||||
def test_create_tool_search_with_tools(self, sample_tools: list[BaseTool]) -> None:
|
||||
"""Test creating a ToolSearchTool with a list of tools."""
|
||||
tool_search = ToolSearchTool(tool_catalog=sample_tools)
|
||||
assert len(tool_search.tool_catalog) == 5
|
||||
assert tool_search.get_catalog_size() == 5
|
||||
|
||||
def test_create_tool_search_with_regex_strategy(
|
||||
self, sample_tools: list[BaseTool]
|
||||
) -> None:
|
||||
"""Test creating a ToolSearchTool with regex search strategy."""
|
||||
tool_search = ToolSearchTool(
|
||||
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
||||
)
|
||||
assert tool_search.search_strategy == SearchStrategy.REGEX
|
||||
|
||||
def test_create_tool_search_with_custom_name(self) -> None:
|
||||
"""Test creating a ToolSearchTool with a custom name."""
|
||||
tool_search = ToolSearchTool(name="My Tool Finder")
|
||||
assert tool_search.name == "My Tool Finder"
|
||||
|
||||
|
||||
class TestToolSearchKeywordSearch:
|
||||
"""Tests for keyword-based tool search."""
|
||||
|
||||
def test_search_by_exact_name(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test searching for a tool by its exact name."""
|
||||
result = tool_search._run("Web Search")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert len(result_data["tools"]) >= 1
|
||||
assert result_data["tools"][0]["name"] == "Web Search"
|
||||
|
||||
def test_search_by_partial_name(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test searching for a tool by partial name."""
|
||||
result = tool_search._run("Search")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert len(result_data["tools"]) >= 1
|
||||
tool_names = [t["name"] for t in result_data["tools"]]
|
||||
assert "Web Search" in tool_names
|
||||
|
||||
def test_search_by_description_keyword(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test searching for a tool by keyword in description."""
|
||||
result = tool_search._run("database")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert len(result_data["tools"]) >= 1
|
||||
tool_names = [t["name"] for t in result_data["tools"]]
|
||||
assert "Database Query" in tool_names
|
||||
|
||||
def test_search_with_multiple_keywords(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test searching with multiple keywords."""
|
||||
result = tool_search._run("web scrape content")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert len(result_data["tools"]) >= 1
|
||||
tool_names = [t["name"] for t in result_data["tools"]]
|
||||
assert "Web Scraper" in tool_names
|
||||
|
||||
def test_search_no_results(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test searching with a query that returns no results."""
|
||||
result = tool_search._run("xyznonexistent123abc")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "no_results"
|
||||
assert len(result_data["tools"]) == 0
|
||||
|
||||
def test_search_max_results_limit(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test that max_results limits the number of returned tools."""
|
||||
result = tool_search._run("tool", max_results=2)
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert len(result_data["tools"]) <= 2
|
||||
|
||||
def test_search_empty_catalog(self) -> None:
|
||||
"""Test searching with an empty tool catalog."""
|
||||
tool_search = ToolSearchTool()
|
||||
result = tool_search._run("search")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "error"
|
||||
assert "No tools available" in result_data["message"]
|
||||
|
||||
|
||||
class TestToolSearchRegexSearch:
|
||||
"""Tests for regex-based tool search."""
|
||||
|
||||
def test_regex_search_simple_pattern(
|
||||
self, sample_tools: list[BaseTool]
|
||||
) -> None:
|
||||
"""Test regex search with a simple pattern."""
|
||||
tool_search = ToolSearchTool(
|
||||
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
||||
)
|
||||
result = tool_search._run("Web.*")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
tool_names = [t["name"] for t in result_data["tools"]]
|
||||
assert "Web Search" in tool_names or "Web Scraper" in tool_names
|
||||
|
||||
def test_regex_search_case_insensitive(
|
||||
self, sample_tools: list[BaseTool]
|
||||
) -> None:
|
||||
"""Test that regex search is case insensitive."""
|
||||
tool_search = ToolSearchTool(
|
||||
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
||||
)
|
||||
result = tool_search._run("email")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
tool_names = [t["name"] for t in result_data["tools"]]
|
||||
assert "Send Email" in tool_names
|
||||
|
||||
def test_regex_search_invalid_pattern_fallback(
|
||||
self, sample_tools: list[BaseTool]
|
||||
) -> None:
|
||||
"""Test that invalid regex patterns are escaped and still work."""
|
||||
tool_search = ToolSearchTool(
|
||||
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
||||
)
|
||||
result = tool_search._run("[invalid(regex")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] in ["success", "no_results"]
|
||||
|
||||
|
||||
class TestToolSearchCustomSearch:
|
||||
"""Tests for custom search function."""
|
||||
|
||||
def test_custom_search_function(self, sample_tools: list[BaseTool]) -> None:
|
||||
"""Test using a custom search function."""
|
||||
|
||||
def custom_search(
|
||||
query: str, tools: list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
return [t for t in tools if "email" in t.name.lower()]
|
||||
|
||||
tool_search = ToolSearchTool(
|
||||
tool_catalog=sample_tools, custom_search_fn=custom_search
|
||||
)
|
||||
result = tool_search._run("anything")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert len(result_data["tools"]) == 1
|
||||
assert result_data["tools"][0]["name"] == "Send Email"
|
||||
|
||||
|
||||
class TestToolSearchCatalogManagement:
|
||||
"""Tests for tool catalog management."""
|
||||
|
||||
def test_add_tool(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test adding a tool to the catalog."""
|
||||
initial_size = tool_search.get_catalog_size()
|
||||
|
||||
class NewTool(BaseTool):
|
||||
name: str = "New Tool"
|
||||
description: str = "A new tool for testing."
|
||||
|
||||
def _run(self) -> str:
|
||||
return "New tool result"
|
||||
|
||||
tool_search.add_tool(NewTool())
|
||||
assert tool_search.get_catalog_size() == initial_size + 1
|
||||
|
||||
def test_add_tools(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test adding multiple tools to the catalog."""
|
||||
initial_size = tool_search.get_catalog_size()
|
||||
|
||||
class NewTool1(BaseTool):
|
||||
name: str = "New Tool 1"
|
||||
description: str = "First new tool."
|
||||
|
||||
def _run(self) -> str:
|
||||
return "Result 1"
|
||||
|
||||
class NewTool2(BaseTool):
|
||||
name: str = "New Tool 2"
|
||||
description: str = "Second new tool."
|
||||
|
||||
def _run(self) -> str:
|
||||
return "Result 2"
|
||||
|
||||
tool_search.add_tools([NewTool1(), NewTool2()])
|
||||
assert tool_search.get_catalog_size() == initial_size + 2
|
||||
|
||||
def test_remove_tool(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test removing a tool from the catalog."""
|
||||
initial_size = tool_search.get_catalog_size()
|
||||
result = tool_search.remove_tool("Web Search")
|
||||
|
||||
assert result is True
|
||||
assert tool_search.get_catalog_size() == initial_size - 1
|
||||
|
||||
def test_remove_nonexistent_tool(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test removing a tool that doesn't exist."""
|
||||
initial_size = tool_search.get_catalog_size()
|
||||
result = tool_search.remove_tool("Nonexistent Tool")
|
||||
|
||||
assert result is False
|
||||
assert tool_search.get_catalog_size() == initial_size
|
||||
|
||||
def test_list_tool_names(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test listing all tool names in the catalog."""
|
||||
names = tool_search.list_tool_names()
|
||||
|
||||
assert len(names) == 5
|
||||
assert "Web Search" in names
|
||||
assert "Database Query" in names
|
||||
assert "Web Scraper" in names
|
||||
assert "Send Email" in names
|
||||
assert "Calculator" in names
|
||||
|
||||
|
||||
class TestToolSearchResultFormat:
|
||||
"""Tests for the format of search results."""
|
||||
|
||||
def test_result_contains_tool_info(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test that search results contain complete tool information."""
|
||||
result = tool_search._run("Calculator")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
tool_info = result_data["tools"][0]
|
||||
|
||||
assert "name" in tool_info
|
||||
assert "description" in tool_info
|
||||
assert "args_schema" in tool_info
|
||||
assert tool_info["name"] == "Calculator"
|
||||
|
||||
def test_result_args_schema_format(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test that args_schema is properly formatted."""
|
||||
result = tool_search._run("Email")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
tool_info = result_data["tools"][0]
|
||||
|
||||
assert "args_schema" in tool_info
|
||||
args_schema = tool_info["args_schema"]
|
||||
assert isinstance(args_schema, dict)
|
||||
|
||||
|
||||
class TestToolSearchIntegration:
|
||||
"""Integration tests for ToolSearchTool."""
|
||||
|
||||
def test_tool_search_as_base_tool(self, sample_tools: list[BaseTool]) -> None:
|
||||
"""Test that ToolSearchTool works as a BaseTool."""
|
||||
tool_search = ToolSearchTool(tool_catalog=sample_tools)
|
||||
|
||||
assert isinstance(tool_search, BaseTool)
|
||||
assert tool_search.name == "Tool Search"
|
||||
assert "search" in tool_search.description.lower()
|
||||
|
||||
def test_tool_search_to_structured_tool(
|
||||
self, sample_tools: list[BaseTool]
|
||||
) -> None:
|
||||
"""Test converting ToolSearchTool to structured tool."""
|
||||
tool_search = ToolSearchTool(tool_catalog=sample_tools)
|
||||
structured = tool_search.to_structured_tool()
|
||||
|
||||
assert structured.name == "Tool Search"
|
||||
assert structured.args_schema is not None
|
||||
|
||||
def test_tool_search_run_method(self, tool_search: ToolSearchTool) -> None:
|
||||
"""Test the run method of ToolSearchTool."""
|
||||
result = tool_search.run(query="search", max_results=3)
|
||||
|
||||
assert isinstance(result, str)
|
||||
result_data = json.loads(result)
|
||||
assert "status" in result_data
|
||||
assert "tools" in result_data
|
||||
|
||||
|
||||
class TestToolSearchScoring:
|
||||
"""Tests for the keyword scoring algorithm."""
|
||||
|
||||
def test_exact_name_match_scores_highest(
|
||||
self, sample_tools: list[BaseTool]
|
||||
) -> None:
|
||||
"""Test that exact name matches score higher than partial matches."""
|
||||
tool_search = ToolSearchTool(tool_catalog=sample_tools)
|
||||
result = tool_search._run("Web Search")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert result_data["tools"][0]["name"] == "Web Search"
|
||||
|
||||
def test_name_match_scores_higher_than_description(
|
||||
self, sample_tools: list[BaseTool]
|
||||
) -> None:
|
||||
"""Test that name matches score higher than description matches."""
|
||||
tool_search = ToolSearchTool(tool_catalog=sample_tools)
|
||||
result = tool_search._run("Calculator")
|
||||
result_data = json.loads(result)
|
||||
|
||||
assert result_data["status"] == "success"
|
||||
assert result_data["tools"][0]["name"] == "Calculator"
|
||||
Reference in New Issue
Block a user