mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-15 19:18:30 +00:00
Implements Anthropic's Tool Search Tool pattern for on-demand tool loading, reducing token consumption when working with large tool libraries. Features: - ToolSearchTool class that searches through a catalog of tools - Keyword-based search with relevance scoring (default) - Regex-based search as alternative strategy - Support for custom search functions - Tool catalog management (add, remove, list tools) - Returns JSON with tool definitions including name, description, and args_schema Closes #4224 Co-Authored-By: João <joao@crewai.com>
394 lines
14 KiB
Python
394 lines
14 KiB
Python
"""Tests for the ToolSearchTool functionality."""
|
|
|
|
import json
|
|
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
|
|
from crewai.tools import BaseTool, SearchStrategy, ToolSearchTool
|
|
|
|
|
|
class MockSearchTool(BaseTool):
|
|
"""A mock search tool for testing."""
|
|
|
|
name: str = "Web Search"
|
|
description: str = "Search the web for information on any topic."
|
|
|
|
def _run(self, query: str) -> str:
|
|
return f"Search results for: {query}"
|
|
|
|
|
|
class MockDatabaseTool(BaseTool):
|
|
"""A mock database tool for testing."""
|
|
|
|
name: str = "Database Query"
|
|
description: str = "Query a SQL database to retrieve data."
|
|
|
|
def _run(self, query: str) -> str:
|
|
return f"Database results for: {query}"
|
|
|
|
|
|
class MockScrapeTool(BaseTool):
|
|
"""A mock web scraping tool for testing."""
|
|
|
|
name: str = "Web Scraper"
|
|
description: str = "Scrape content from websites and extract text."
|
|
|
|
def _run(self, url: str) -> str:
|
|
return f"Scraped content from: {url}"
|
|
|
|
|
|
class MockEmailTool(BaseTool):
|
|
"""A mock email tool for testing."""
|
|
|
|
name: str = "Send Email"
|
|
description: str = "Send an email to a specified recipient."
|
|
|
|
def _run(self, to: str, subject: str, body: str) -> str:
|
|
return f"Email sent to {to}"
|
|
|
|
|
|
class MockCalculatorTool(BaseTool):
|
|
"""A mock calculator tool for testing."""
|
|
|
|
name: str = "Calculator"
|
|
description: str = "Perform mathematical calculations and arithmetic operations."
|
|
|
|
def _run(self, expression: str) -> str:
|
|
return f"Result: {eval(expression)}"
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_tools() -> list[BaseTool]:
|
|
"""Create a list of sample tools for testing."""
|
|
return [
|
|
MockSearchTool(),
|
|
MockDatabaseTool(),
|
|
MockScrapeTool(),
|
|
MockEmailTool(),
|
|
MockCalculatorTool(),
|
|
]
|
|
|
|
|
|
@pytest.fixture
|
|
def tool_search(sample_tools: list[BaseTool]) -> ToolSearchTool:
|
|
"""Create a ToolSearchTool with sample tools."""
|
|
return ToolSearchTool(tool_catalog=sample_tools)
|
|
|
|
|
|
class TestToolSearchToolCreation:
|
|
"""Tests for ToolSearchTool creation and initialization."""
|
|
|
|
def test_create_tool_search_with_empty_catalog(self) -> None:
|
|
"""Test creating a ToolSearchTool with an empty catalog."""
|
|
tool_search = ToolSearchTool()
|
|
assert tool_search.name == "Tool Search"
|
|
assert tool_search.tool_catalog == []
|
|
assert tool_search.search_strategy == SearchStrategy.KEYWORD
|
|
|
|
def test_create_tool_search_with_tools(self, sample_tools: list[BaseTool]) -> None:
|
|
"""Test creating a ToolSearchTool with a list of tools."""
|
|
tool_search = ToolSearchTool(tool_catalog=sample_tools)
|
|
assert len(tool_search.tool_catalog) == 5
|
|
assert tool_search.get_catalog_size() == 5
|
|
|
|
def test_create_tool_search_with_regex_strategy(
|
|
self, sample_tools: list[BaseTool]
|
|
) -> None:
|
|
"""Test creating a ToolSearchTool with regex search strategy."""
|
|
tool_search = ToolSearchTool(
|
|
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
|
)
|
|
assert tool_search.search_strategy == SearchStrategy.REGEX
|
|
|
|
def test_create_tool_search_with_custom_name(self) -> None:
|
|
"""Test creating a ToolSearchTool with a custom name."""
|
|
tool_search = ToolSearchTool(name="My Tool Finder")
|
|
assert tool_search.name == "My Tool Finder"
|
|
|
|
|
|
class TestToolSearchKeywordSearch:
|
|
"""Tests for keyword-based tool search."""
|
|
|
|
def test_search_by_exact_name(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test searching for a tool by its exact name."""
|
|
result = tool_search._run("Web Search")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
assert len(result_data["tools"]) >= 1
|
|
assert result_data["tools"][0]["name"] == "Web Search"
|
|
|
|
def test_search_by_partial_name(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test searching for a tool by partial name."""
|
|
result = tool_search._run("Search")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
assert len(result_data["tools"]) >= 1
|
|
tool_names = [t["name"] for t in result_data["tools"]]
|
|
assert "Web Search" in tool_names
|
|
|
|
def test_search_by_description_keyword(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test searching for a tool by keyword in description."""
|
|
result = tool_search._run("database")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
assert len(result_data["tools"]) >= 1
|
|
tool_names = [t["name"] for t in result_data["tools"]]
|
|
assert "Database Query" in tool_names
|
|
|
|
def test_search_with_multiple_keywords(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test searching with multiple keywords."""
|
|
result = tool_search._run("web scrape content")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
assert len(result_data["tools"]) >= 1
|
|
tool_names = [t["name"] for t in result_data["tools"]]
|
|
assert "Web Scraper" in tool_names
|
|
|
|
def test_search_no_results(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test searching with a query that returns no results."""
|
|
result = tool_search._run("xyznonexistent123abc")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "no_results"
|
|
assert len(result_data["tools"]) == 0
|
|
|
|
def test_search_max_results_limit(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test that max_results limits the number of returned tools."""
|
|
result = tool_search._run("tool", max_results=2)
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
assert len(result_data["tools"]) <= 2
|
|
|
|
def test_search_empty_catalog(self) -> None:
|
|
"""Test searching with an empty tool catalog."""
|
|
tool_search = ToolSearchTool()
|
|
result = tool_search._run("search")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "error"
|
|
assert "No tools available" in result_data["message"]
|
|
|
|
|
|
class TestToolSearchRegexSearch:
|
|
"""Tests for regex-based tool search."""
|
|
|
|
def test_regex_search_simple_pattern(
|
|
self, sample_tools: list[BaseTool]
|
|
) -> None:
|
|
"""Test regex search with a simple pattern."""
|
|
tool_search = ToolSearchTool(
|
|
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
|
)
|
|
result = tool_search._run("Web.*")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
tool_names = [t["name"] for t in result_data["tools"]]
|
|
assert "Web Search" in tool_names or "Web Scraper" in tool_names
|
|
|
|
def test_regex_search_case_insensitive(
|
|
self, sample_tools: list[BaseTool]
|
|
) -> None:
|
|
"""Test that regex search is case insensitive."""
|
|
tool_search = ToolSearchTool(
|
|
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
|
)
|
|
result = tool_search._run("email")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
tool_names = [t["name"] for t in result_data["tools"]]
|
|
assert "Send Email" in tool_names
|
|
|
|
def test_regex_search_invalid_pattern_fallback(
|
|
self, sample_tools: list[BaseTool]
|
|
) -> None:
|
|
"""Test that invalid regex patterns are escaped and still work."""
|
|
tool_search = ToolSearchTool(
|
|
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
|
|
)
|
|
result = tool_search._run("[invalid(regex")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] in ["success", "no_results"]
|
|
|
|
|
|
class TestToolSearchCustomSearch:
|
|
"""Tests for custom search function."""
|
|
|
|
def test_custom_search_function(self, sample_tools: list[BaseTool]) -> None:
|
|
"""Test using a custom search function."""
|
|
|
|
def custom_search(
|
|
query: str, tools: list[BaseTool]
|
|
) -> list[BaseTool]:
|
|
return [t for t in tools if "email" in t.name.lower()]
|
|
|
|
tool_search = ToolSearchTool(
|
|
tool_catalog=sample_tools, custom_search_fn=custom_search
|
|
)
|
|
result = tool_search._run("anything")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
assert len(result_data["tools"]) == 1
|
|
assert result_data["tools"][0]["name"] == "Send Email"
|
|
|
|
|
|
class TestToolSearchCatalogManagement:
|
|
"""Tests for tool catalog management."""
|
|
|
|
def test_add_tool(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test adding a tool to the catalog."""
|
|
initial_size = tool_search.get_catalog_size()
|
|
|
|
class NewTool(BaseTool):
|
|
name: str = "New Tool"
|
|
description: str = "A new tool for testing."
|
|
|
|
def _run(self) -> str:
|
|
return "New tool result"
|
|
|
|
tool_search.add_tool(NewTool())
|
|
assert tool_search.get_catalog_size() == initial_size + 1
|
|
|
|
def test_add_tools(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test adding multiple tools to the catalog."""
|
|
initial_size = tool_search.get_catalog_size()
|
|
|
|
class NewTool1(BaseTool):
|
|
name: str = "New Tool 1"
|
|
description: str = "First new tool."
|
|
|
|
def _run(self) -> str:
|
|
return "Result 1"
|
|
|
|
class NewTool2(BaseTool):
|
|
name: str = "New Tool 2"
|
|
description: str = "Second new tool."
|
|
|
|
def _run(self) -> str:
|
|
return "Result 2"
|
|
|
|
tool_search.add_tools([NewTool1(), NewTool2()])
|
|
assert tool_search.get_catalog_size() == initial_size + 2
|
|
|
|
def test_remove_tool(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test removing a tool from the catalog."""
|
|
initial_size = tool_search.get_catalog_size()
|
|
result = tool_search.remove_tool("Web Search")
|
|
|
|
assert result is True
|
|
assert tool_search.get_catalog_size() == initial_size - 1
|
|
|
|
def test_remove_nonexistent_tool(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test removing a tool that doesn't exist."""
|
|
initial_size = tool_search.get_catalog_size()
|
|
result = tool_search.remove_tool("Nonexistent Tool")
|
|
|
|
assert result is False
|
|
assert tool_search.get_catalog_size() == initial_size
|
|
|
|
def test_list_tool_names(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test listing all tool names in the catalog."""
|
|
names = tool_search.list_tool_names()
|
|
|
|
assert len(names) == 5
|
|
assert "Web Search" in names
|
|
assert "Database Query" in names
|
|
assert "Web Scraper" in names
|
|
assert "Send Email" in names
|
|
assert "Calculator" in names
|
|
|
|
|
|
class TestToolSearchResultFormat:
|
|
"""Tests for the format of search results."""
|
|
|
|
def test_result_contains_tool_info(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test that search results contain complete tool information."""
|
|
result = tool_search._run("Calculator")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
tool_info = result_data["tools"][0]
|
|
|
|
assert "name" in tool_info
|
|
assert "description" in tool_info
|
|
assert "args_schema" in tool_info
|
|
assert tool_info["name"] == "Calculator"
|
|
|
|
def test_result_args_schema_format(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test that args_schema is properly formatted."""
|
|
result = tool_search._run("Email")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
tool_info = result_data["tools"][0]
|
|
|
|
assert "args_schema" in tool_info
|
|
args_schema = tool_info["args_schema"]
|
|
assert isinstance(args_schema, dict)
|
|
|
|
|
|
class TestToolSearchIntegration:
|
|
"""Integration tests for ToolSearchTool."""
|
|
|
|
def test_tool_search_as_base_tool(self, sample_tools: list[BaseTool]) -> None:
|
|
"""Test that ToolSearchTool works as a BaseTool."""
|
|
tool_search = ToolSearchTool(tool_catalog=sample_tools)
|
|
|
|
assert isinstance(tool_search, BaseTool)
|
|
assert tool_search.name == "Tool Search"
|
|
assert "search" in tool_search.description.lower()
|
|
|
|
def test_tool_search_to_structured_tool(
|
|
self, sample_tools: list[BaseTool]
|
|
) -> None:
|
|
"""Test converting ToolSearchTool to structured tool."""
|
|
tool_search = ToolSearchTool(tool_catalog=sample_tools)
|
|
structured = tool_search.to_structured_tool()
|
|
|
|
assert structured.name == "Tool Search"
|
|
assert structured.args_schema is not None
|
|
|
|
def test_tool_search_run_method(self, tool_search: ToolSearchTool) -> None:
|
|
"""Test the run method of ToolSearchTool."""
|
|
result = tool_search.run(query="search", max_results=3)
|
|
|
|
assert isinstance(result, str)
|
|
result_data = json.loads(result)
|
|
assert "status" in result_data
|
|
assert "tools" in result_data
|
|
|
|
|
|
class TestToolSearchScoring:
|
|
"""Tests for the keyword scoring algorithm."""
|
|
|
|
def test_exact_name_match_scores_highest(
|
|
self, sample_tools: list[BaseTool]
|
|
) -> None:
|
|
"""Test that exact name matches score higher than partial matches."""
|
|
tool_search = ToolSearchTool(tool_catalog=sample_tools)
|
|
result = tool_search._run("Web Search")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
assert result_data["tools"][0]["name"] == "Web Search"
|
|
|
|
def test_name_match_scores_higher_than_description(
|
|
self, sample_tools: list[BaseTool]
|
|
) -> None:
|
|
"""Test that name matches score higher than description matches."""
|
|
tool_search = ToolSearchTool(tool_catalog=sample_tools)
|
|
result = tool_search._run("Calculator")
|
|
result_data = json.loads(result)
|
|
|
|
assert result_data["status"] == "success"
|
|
assert result_data["tools"][0]["name"] == "Calculator"
|