Files
crewAI/lib/crewai/tests/tools/test_tool_search_tool.py
Devin AI fae812ffb7 feat: add ToolSearchTool for on-demand tool discovery
Implements Anthropic's Tool Search Tool pattern for on-demand tool loading,
reducing token consumption when working with large tool libraries.

Features:
- ToolSearchTool class that searches through a catalog of tools
- Keyword-based search with relevance scoring (default)
- Regex-based search as alternative strategy
- Support for custom search functions
- Tool catalog management (add, remove, list tools)
- Returns JSON with tool definitions including name, description, and args_schema

Closes #4224

Co-Authored-By: João <joao@crewai.com>
2026-01-12 09:19:16 +00:00

394 lines
14 KiB
Python

"""Tests for the ToolSearchTool functionality."""
import json
import pytest
from pydantic import BaseModel
from crewai.tools import BaseTool, SearchStrategy, ToolSearchTool
class MockSearchTool(BaseTool):
    """Stand-in web search tool used to populate a test catalog."""

    name: str = "Web Search"
    description: str = "Search the web for information on any topic."

    def _run(self, query: str) -> str:
        """Return a canned string that echoes ``query``."""
        response = f"Search results for: {query}"
        return response
class MockDatabaseTool(BaseTool):
    """Stand-in database tool used to populate a test catalog."""

    name: str = "Database Query"
    description: str = "Query a SQL database to retrieve data."

    def _run(self, query: str) -> str:
        """Return a canned string that echoes ``query``."""
        response = f"Database results for: {query}"
        return response
class MockScrapeTool(BaseTool):
    """Stand-in web scraping tool used to populate a test catalog."""

    name: str = "Web Scraper"
    description: str = "Scrape content from websites and extract text."

    def _run(self, url: str) -> str:
        """Return a canned string that echoes ``url``."""
        response = f"Scraped content from: {url}"
        return response
class MockEmailTool(BaseTool):
    """Stand-in email tool used to populate a test catalog."""

    name: str = "Send Email"
    description: str = "Send an email to a specified recipient."

    def _run(self, to: str, subject: str, body: str) -> str:
        """Return a canned confirmation naming the recipient ``to``.

        ``subject`` and ``body`` are accepted but not used in the reply.
        """
        confirmation = f"Email sent to {to}"
        return confirmation
class MockCalculatorTool(BaseTool):
    """A mock calculator tool for testing.

    The tool evaluates plain arithmetic expressions. It deliberately does
    NOT use ``eval``: a tool's input is supplied by the caller (in practice
    an LLM), and ``eval`` on such input allows arbitrary code execution.
    """

    name: str = "Calculator"
    description: str = "Perform mathematical calculations and arithmetic operations."

    def _run(self, expression: str) -> str:
        """Evaluate an arithmetic expression and format the outcome.

        Args:
            expression: Arithmetic made of numeric literals, parentheses,
                unary ``+``/``-`` and the binary operators
                ``+ - * / // % **``.

        Returns:
            The string ``"Result: <value>"``.

        Raises:
            SyntaxError: If ``expression`` is not a valid Python expression.
            ValueError: If ``expression`` contains anything other than the
                supported arithmetic (names, calls, attributes, ...).
            ZeroDivisionError: On division or modulo by zero.
        """
        # Imported locally so this block does not depend on file-level imports.
        import ast
        import operator

        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: operator.pow,
        }
        unary_ops = {
            ast.UAdd: operator.pos,
            ast.USub: operator.neg,
        }

        def evaluate(node):
            """Recursively compute the value of a whitelisted AST node."""
            if isinstance(node, ast.Constant):
                value = node.value
                # bool is a subclass of int; reject it so only real numbers pass.
                if isinstance(value, (int, float, complex)) and not isinstance(
                    value, bool
                ):
                    return value
            elif isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                apply = binary_ops[type(node.op)]
                return apply(evaluate(node.left), evaluate(node.right))
            elif isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                apply = unary_ops[type(node.op)]
                return apply(evaluate(node.operand))
            raise ValueError(f"Unsupported expression: {expression!r}")

        # eval() tolerates surrounding whitespace; strip to keep accepting it.
        tree = ast.parse(expression.strip(), mode="eval")
        return f"Result: {evaluate(tree.body)}"
@pytest.fixture
def sample_tools() -> list[BaseTool]:
    """Provide one fresh instance of each mock tool, in a fixed order."""
    mock_tool_classes = (
        MockSearchTool,
        MockDatabaseTool,
        MockScrapeTool,
        MockEmailTool,
        MockCalculatorTool,
    )
    return [tool_cls() for tool_cls in mock_tool_classes]
@pytest.fixture
def tool_search(sample_tools: list[BaseTool]) -> ToolSearchTool:
    """Provide a ToolSearchTool whose catalog is the ``sample_tools`` list."""
    searcher = ToolSearchTool(tool_catalog=sample_tools)
    return searcher
class TestToolSearchToolCreation:
    """Checks on how a ToolSearchTool is constructed and initialised."""

    def test_create_tool_search_with_empty_catalog(self) -> None:
        """A bare ToolSearchTool has the default name, no tools, keyword search."""
        finder = ToolSearchTool()

        assert finder.name == "Tool Search"
        assert finder.tool_catalog == []
        assert finder.search_strategy == SearchStrategy.KEYWORD

    def test_create_tool_search_with_tools(self, sample_tools: list[BaseTool]) -> None:
        """Passing five tools yields a catalog of five."""
        finder = ToolSearchTool(tool_catalog=sample_tools)

        assert len(finder.tool_catalog) == 5
        assert finder.get_catalog_size() == 5

    def test_create_tool_search_with_regex_strategy(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """The requested regex strategy is stored on the instance."""
        finder = ToolSearchTool(
            tool_catalog=sample_tools,
            search_strategy=SearchStrategy.REGEX,
        )

        assert finder.search_strategy == SearchStrategy.REGEX

    def test_create_tool_search_with_custom_name(self) -> None:
        """A caller-supplied name overrides the default one."""
        finder = ToolSearchTool(name="My Tool Finder")

        assert finder.name == "My Tool Finder"
class TestToolSearchKeywordSearch:
    """Checks on the keyword search strategy."""

    def test_search_by_exact_name(self, tool_search: ToolSearchTool) -> None:
        """Querying a tool's full name returns that tool first."""
        payload = json.loads(tool_search._run("Web Search"))

        assert payload["status"] == "success"
        assert len(payload["tools"]) >= 1
        assert payload["tools"][0]["name"] == "Web Search"

    def test_search_by_partial_name(self, tool_search: ToolSearchTool) -> None:
        """Querying part of a name still finds the tool."""
        payload = json.loads(tool_search._run("Search"))

        assert payload["status"] == "success"
        assert len(payload["tools"]) >= 1
        found_names = [entry["name"] for entry in payload["tools"]]
        assert "Web Search" in found_names

    def test_search_by_description_keyword(self, tool_search: ToolSearchTool) -> None:
        """Querying the word "database" finds the Database Query tool."""
        payload = json.loads(tool_search._run("database"))

        assert payload["status"] == "success"
        assert len(payload["tools"]) >= 1
        found_names = [entry["name"] for entry in payload["tools"]]
        assert "Database Query" in found_names

    def test_search_with_multiple_keywords(self, tool_search: ToolSearchTool) -> None:
        """A multi-word query finds the Web Scraper tool."""
        payload = json.loads(tool_search._run("web scrape content"))

        assert payload["status"] == "success"
        assert len(payload["tools"]) >= 1
        found_names = [entry["name"] for entry in payload["tools"]]
        assert "Web Scraper" in found_names

    def test_search_no_results(self, tool_search: ToolSearchTool) -> None:
        """A query matching nothing reports ``no_results`` with no tools."""
        payload = json.loads(tool_search._run("xyznonexistent123abc"))

        assert payload["status"] == "no_results"
        assert len(payload["tools"]) == 0

    def test_search_max_results_limit(self, tool_search: ToolSearchTool) -> None:
        """``max_results`` caps how many tools come back."""
        payload = json.loads(tool_search._run("tool", max_results=2))

        assert payload["status"] == "success"
        assert len(payload["tools"]) <= 2

    def test_search_empty_catalog(self) -> None:
        """Searching an empty catalog reports an error status and message."""
        empty_finder = ToolSearchTool()
        payload = json.loads(empty_finder._run("search"))

        assert payload["status"] == "error"
        assert "No tools available" in payload["message"]
class TestToolSearchRegexSearch:
    """Checks on the regex search strategy."""

    def test_regex_search_simple_pattern(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """The pattern ``Web.*`` finds at least one of the two Web tools."""
        regex_finder = ToolSearchTool(
            tool_catalog=sample_tools,
            search_strategy=SearchStrategy.REGEX,
        )
        payload = json.loads(regex_finder._run("Web.*"))

        assert payload["status"] == "success"
        found_names = [entry["name"] for entry in payload["tools"]]
        assert "Web Search" in found_names or "Web Scraper" in found_names

    def test_regex_search_case_insensitive(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """A lowercase pattern still finds the capitalised Send Email tool."""
        regex_finder = ToolSearchTool(
            tool_catalog=sample_tools,
            search_strategy=SearchStrategy.REGEX,
        )
        payload = json.loads(regex_finder._run("email"))

        assert payload["status"] == "success"
        found_names = [entry["name"] for entry in payload["tools"]]
        assert "Send Email" in found_names

    def test_regex_search_invalid_pattern_fallback(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """A malformed pattern yields a normal status, not an exception."""
        regex_finder = ToolSearchTool(
            tool_catalog=sample_tools,
            search_strategy=SearchStrategy.REGEX,
        )
        payload = json.loads(regex_finder._run("[invalid(regex"))

        assert payload["status"] in ["success", "no_results"]
class TestToolSearchCustomSearch:
    """Checks on supplying a caller-defined search function."""

    def test_custom_search_function(self, sample_tools: list[BaseTool]) -> None:
        """Results come from the custom function, whatever the query is."""

        # Parameter names are kept as-is in case the caller passes them by keyword.
        def only_email_tools(
            query: str, tools: list[BaseTool]
        ) -> list[BaseTool]:
            """Ignore ``query`` and keep tools whose name mentions "email"."""
            return [candidate for candidate in tools if "email" in candidate.name.lower()]

        finder = ToolSearchTool(
            tool_catalog=sample_tools,
            custom_search_fn=only_email_tools,
        )
        payload = json.loads(finder._run("anything"))

        assert payload["status"] == "success"
        assert len(payload["tools"]) == 1
        assert payload["tools"][0]["name"] == "Send Email"
class TestToolSearchCatalogManagement:
    """Checks on adding, removing and listing catalog entries."""

    def test_add_tool(self, tool_search: ToolSearchTool) -> None:
        """Adding one tool grows the catalog by one."""

        class NewTool(BaseTool):
            name: str = "New Tool"
            description: str = "A new tool for testing."

            def _run(self) -> str:
                return "New tool result"

        size_before = tool_search.get_catalog_size()
        tool_search.add_tool(NewTool())

        assert tool_search.get_catalog_size() == size_before + 1

    def test_add_tools(self, tool_search: ToolSearchTool) -> None:
        """Adding two tools at once grows the catalog by two."""

        class NewTool1(BaseTool):
            name: str = "New Tool 1"
            description: str = "First new tool."

            def _run(self) -> str:
                return "Result 1"

        class NewTool2(BaseTool):
            name: str = "New Tool 2"
            description: str = "Second new tool."

            def _run(self) -> str:
                return "Result 2"

        size_before = tool_search.get_catalog_size()
        extra_tools = [NewTool1(), NewTool2()]
        tool_search.add_tools(extra_tools)

        assert tool_search.get_catalog_size() == size_before + 2

    def test_remove_tool(self, tool_search: ToolSearchTool) -> None:
        """Removing an existing tool returns True and shrinks the catalog."""
        size_before = tool_search.get_catalog_size()
        was_removed = tool_search.remove_tool("Web Search")

        assert was_removed is True
        assert tool_search.get_catalog_size() == size_before - 1

    def test_remove_nonexistent_tool(self, tool_search: ToolSearchTool) -> None:
        """Removing an unknown name returns False and changes nothing."""
        size_before = tool_search.get_catalog_size()
        was_removed = tool_search.remove_tool("Nonexistent Tool")

        assert was_removed is False
        assert tool_search.get_catalog_size() == size_before

    def test_list_tool_names(self, tool_search: ToolSearchTool) -> None:
        """The name listing contains all five sample tool names."""
        names = tool_search.list_tool_names()

        assert len(names) == 5
        for expected_name in (
            "Web Search",
            "Database Query",
            "Web Scraper",
            "Send Email",
            "Calculator",
        ):
            assert expected_name in names
class TestToolSearchResultFormat:
    """Checks on the shape of the JSON returned by a search."""

    def test_result_contains_tool_info(self, tool_search: ToolSearchTool) -> None:
        """Each hit carries ``name``, ``description`` and ``args_schema``."""
        payload = json.loads(tool_search._run("Calculator"))

        assert payload["status"] == "success"
        top_hit = payload["tools"][0]
        for required_key in ("name", "description", "args_schema"):
            assert required_key in top_hit
        assert top_hit["name"] == "Calculator"

    def test_result_args_schema_format(self, tool_search: ToolSearchTool) -> None:
        """``args_schema`` is present on a hit and deserialises to a dict."""
        payload = json.loads(tool_search._run("Email"))

        assert payload["status"] == "success"
        top_hit = payload["tools"][0]
        assert "args_schema" in top_hit
        schema = top_hit["args_schema"]
        assert isinstance(schema, dict)
class TestToolSearchIntegration:
    """Checks that ToolSearchTool behaves like any other BaseTool."""

    def test_tool_search_as_base_tool(self, sample_tools: list[BaseTool]) -> None:
        """The tool is a BaseTool with the default name and a search-related description."""
        finder = ToolSearchTool(tool_catalog=sample_tools)

        assert isinstance(finder, BaseTool)
        assert finder.name == "Tool Search"
        assert "search" in finder.description.lower()

    def test_tool_search_to_structured_tool(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """The structured-tool conversion keeps the name and has an args schema."""
        finder = ToolSearchTool(tool_catalog=sample_tools)
        structured = finder.to_structured_tool()

        assert structured.name == "Tool Search"
        assert structured.args_schema is not None

    def test_tool_search_run_method(self, tool_search: ToolSearchTool) -> None:
        """The public ``run`` entry point returns a JSON string with status and tools."""
        raw_output = tool_search.run(query="search", max_results=3)

        assert isinstance(raw_output, str)
        payload = json.loads(raw_output)
        for required_key in ("status", "tools"):
            assert required_key in payload
class TestToolSearchScoring:
    """Checks on how keyword search ranks its results."""

    def test_exact_name_match_scores_highest(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Querying "Web Search" ranks the Web Search tool first."""
        finder = ToolSearchTool(tool_catalog=sample_tools)
        payload = json.loads(finder._run("Web Search"))

        assert payload["status"] == "success"
        best_match = payload["tools"][0]
        assert best_match["name"] == "Web Search"

    def test_name_match_scores_higher_than_description(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Querying "Calculator" ranks the Calculator tool first."""
        finder = ToolSearchTool(tool_catalog=sample_tools)
        payload = json.loads(finder._run("Calculator"))

        assert payload["status"] == "success"
        best_match = payload["tools"][0]
        assert best_match["name"] == "Calculator"