From fae812ffb718a2b21fb192271565333558793f43 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Mon, 12 Jan 2026 09:19:16 +0000 Subject: [PATCH] feat: add ToolSearchTool for on-demand tool discovery MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Anthropic's Tool Search Tool pattern for on-demand tool loading, reducing token consumption when working with large tool libraries. Features: - ToolSearchTool class that searches through a catalog of tools - Keyword-based search with relevance scoring (default) - Regex-based search as alternative strategy - Support for custom search functions - Tool catalog management (add, remove, list tools) - Returns JSON with tool definitions including name, description, and args_schema Closes #4224 Co-Authored-By: João --- lib/crewai/src/crewai/tools/__init__.py | 3 + .../src/crewai/tools/tool_search_tool.py | 333 +++++++++++++++ .../tests/tools/test_tool_search_tool.py | 393 ++++++++++++++++++ 3 files changed, 729 insertions(+) create mode 100644 lib/crewai/src/crewai/tools/tool_search_tool.py create mode 100644 lib/crewai/tests/tools/test_tool_search_tool.py diff --git a/lib/crewai/src/crewai/tools/__init__.py b/lib/crewai/src/crewai/tools/__init__.py index ef698c90a..5b66695cf 100644 --- a/lib/crewai/src/crewai/tools/__init__.py +++ b/lib/crewai/src/crewai/tools/__init__.py @@ -1,9 +1,12 @@ from crewai.tools.base_tool import BaseTool, EnvVar, tool +from crewai.tools.tool_search_tool import SearchStrategy, ToolSearchTool __all__ = [ "BaseTool", "EnvVar", + "SearchStrategy", + "ToolSearchTool", "tool", ] diff --git a/lib/crewai/src/crewai/tools/tool_search_tool.py b/lib/crewai/src/crewai/tools/tool_search_tool.py new file mode 100644 index 000000000..70680c0b3 --- /dev/null +++ b/lib/crewai/src/crewai/tools/tool_search_tool.py @@ -0,0 +1,333 @@ +"""Tool Search Tool for on-demand tool discovery. + +This module implements a Tool Search Tool that allows agents to dynamically +discover and load tools on-demand, reducing token consumption when working +with large tool libraries. + +Inspired by Anthropic's Tool Search Tool approach for on-demand tool loading. +""" + +from __future__ import annotations + +from collections.abc import Callable, Sequence +from enum import Enum +import json +import re +from typing import Any + +from pydantic import BaseModel, Field + +from crewai.tools.base_tool import BaseTool +from crewai.tools.structured_tool import CrewStructuredTool +from crewai.utilities.pydantic_schema_utils import generate_model_description + + +class SearchStrategy(str, Enum): + """Search strategy for tool discovery.""" + + KEYWORD = "keyword" + REGEX = "regex" + + +class ToolSearchResult(BaseModel): + """Result from a tool search operation.""" + + name: str = Field(description="The name of the tool") + description: str = Field(description="The description of the tool") + args_schema: dict[str, Any] = Field( + description="The JSON schema for the tool's arguments" + ) + + +class ToolSearchToolSchema(BaseModel): + """Schema for the Tool Search Tool arguments.""" + + query: str = Field( + description="The search query to find relevant tools. Use keywords that describe the capability you need." + ) + max_results: int = Field( + default=5, + description="Maximum number of tools to return. Default is 5.", + ge=1, + le=20, + ) + + +class ToolSearchTool(BaseTool): + """A tool that searches through a catalog of tools to find relevant ones. + + This tool enables on-demand tool discovery, allowing agents to work with + large tool libraries without loading all tool definitions upfront. Instead + of consuming tokens with all tool definitions, the agent can search for + relevant tools when needed. + + Example: + ```python + from crewai.tools import BaseTool, ToolSearchTool + + # Create your tools + search_tool = MySearchTool() + scrape_tool = MyScrapeWebsiteTool() + database_tool = MyDatabaseTool() + + # Create a tool search tool with your tool catalog + tool_search = ToolSearchTool( + tool_catalog=[search_tool, scrape_tool, database_tool], + search_strategy=SearchStrategy.KEYWORD, + ) + + # Use with an agent - only the tool_search is loaded initially + agent = Agent( + role="Researcher", + tools=[tool_search], # Other tools discovered on-demand + ) + ``` + + Attributes: + tool_catalog: List of tools available for search. + search_strategy: Strategy to use for searching (keyword or regex). + custom_search_fn: Optional custom search function for advanced matching. + """ + + name: str = Field( + default="Tool Search", + description="The name of the tool search tool.", + ) + description: str = Field( + default="Search for available tools by describing the capability you need. Returns tool definitions that match your query.", + description="Description of what the tool search tool does.", + ) + args_schema: type[BaseModel] = Field( + default=ToolSearchToolSchema, + description="The schema for the tool search arguments.", + ) + tool_catalog: list[BaseTool | CrewStructuredTool] = Field( + default_factory=list, + description="List of tools available for search.", + ) + search_strategy: SearchStrategy = Field( + default=SearchStrategy.KEYWORD, + description="Strategy to use for searching tools.", + ) + custom_search_fn: Callable[ + [str, Sequence[BaseTool | CrewStructuredTool]], list[BaseTool | CrewStructuredTool] + ] | None = Field( + default=None, + description="Optional custom search function for advanced matching.", + ) + + def _run(self, query: str, max_results: int = 5) -> str: + """Search for tools matching the query. + + Args: + query: The search query to find relevant tools. + max_results: Maximum number of tools to return. + + Returns: + JSON string containing the matching tool definitions. + """ + if not self.tool_catalog: + return json.dumps( + { + "status": "error", + "message": "No tools available in the catalog.", + "tools": [], + } + ) + + if self.custom_search_fn: + matching_tools = self.custom_search_fn(query, self.tool_catalog) + elif self.search_strategy == SearchStrategy.REGEX: + matching_tools = self._regex_search(query) + else: + matching_tools = self._keyword_search(query) + + matching_tools = matching_tools[:max_results] + + if not matching_tools: + return json.dumps( + { + "status": "no_results", + "message": f"No tools found matching query: '{query}'. Try different keywords.", + "tools": [], + } + ) + + tool_results = [] + for tool in matching_tools: + tool_info = self._get_tool_info(tool) + tool_results.append(tool_info) + + return json.dumps( + { + "status": "success", + "message": f"Found {len(tool_results)} tool(s) matching your query.", + "tools": tool_results, + }, + indent=2, + ) + + def _keyword_search( + self, query: str + ) -> list[BaseTool | CrewStructuredTool]: + """Search tools using keyword matching. + + Args: + query: The search query. + + Returns: + List of matching tools sorted by relevance. + """ + query_lower = query.lower() + query_words = set(query_lower.split()) + + scored_tools: list[tuple[float, BaseTool | CrewStructuredTool]] = [] + + for tool in self.tool_catalog: + score = self._calculate_keyword_score(tool, query_lower, query_words) + if score > 0: + scored_tools.append((score, tool)) + + scored_tools.sort(key=lambda x: x[0], reverse=True) + return [tool for _, tool in scored_tools] + + def _calculate_keyword_score( + self, + tool: BaseTool | CrewStructuredTool, + query_lower: str, + query_words: set[str], + ) -> float: + """Calculate relevance score for a tool based on keyword matching. + + Args: + tool: The tool to score. + query_lower: Lowercase query string. + query_words: Set of query words. + + Returns: + Relevance score (higher is better). + """ + score = 0.0 + tool_name_lower = tool.name.lower() + tool_desc_lower = tool.description.lower() + + if query_lower in tool_name_lower: + score += 10.0 + if query_lower in tool_desc_lower: + score += 5.0 + + for word in query_words: + if len(word) < 2: + continue + if word in tool_name_lower: + score += 3.0 + if word in tool_desc_lower: + score += 1.0 + + return score + + def _regex_search( + self, query: str + ) -> list[BaseTool | CrewStructuredTool]: + """Search tools using regex pattern matching. + + Args: + query: The regex pattern to search for. + + Returns: + List of matching tools. + """ + try: + pattern = re.compile(query, re.IGNORECASE) + except re.error: + pattern = re.compile(re.escape(query), re.IGNORECASE) + + return [ + tool + for tool in self.tool_catalog + if pattern.search(tool.name) or pattern.search(tool.description) + ] + + def _get_tool_info(self, tool: BaseTool | CrewStructuredTool) -> dict[str, Any]: + """Get tool information as a dictionary. + + Args: + tool: The tool to get information from. + + Returns: + Dictionary containing tool name, description, and args schema. + """ + if isinstance(tool, BaseTool): + schema_dict = generate_model_description(tool.args_schema) + args_schema = schema_dict.get("json_schema", {}).get("schema", {}) + else: + args_schema = tool.args_schema.model_json_schema() + + return { + "name": tool.name, + "description": self._get_original_description(tool), + "args_schema": args_schema, + } + + def _get_original_description(self, tool: BaseTool | CrewStructuredTool) -> str: + """Get the original description of a tool without the generated schema. + + Args: + tool: The tool to get the description from. + + Returns: + The original tool description. + """ + description = tool.description + if "Tool Description:" in description: + parts = description.split("Tool Description:") + if len(parts) > 1: + return parts[1].strip() + return description + + def add_tool(self, tool: BaseTool | CrewStructuredTool) -> None: + """Add a tool to the catalog. + + Args: + tool: The tool to add. + """ + self.tool_catalog.append(tool) + + def add_tools(self, tools: Sequence[BaseTool | CrewStructuredTool]) -> None: + """Add multiple tools to the catalog. + + Args: + tools: The tools to add. + """ + self.tool_catalog.extend(tools) + + def remove_tool(self, tool_name: str) -> bool: + """Remove a tool from the catalog by name. + + Args: + tool_name: The name of the tool to remove. + + Returns: + True if the tool was removed, False if not found. + """ + for i, tool in enumerate(self.tool_catalog): + if tool.name == tool_name: + self.tool_catalog.pop(i) + return True + return False + + def get_catalog_size(self) -> int: + """Get the number of tools in the catalog. + + Returns: + The number of tools in the catalog. + """ + return len(self.tool_catalog) + + def list_tool_names(self) -> list[str]: + """List all tool names in the catalog. + + Returns: + List of tool names. + """ + return [tool.name for tool in self.tool_catalog] diff --git a/lib/crewai/tests/tools/test_tool_search_tool.py b/lib/crewai/tests/tools/test_tool_search_tool.py new file mode 100644 index 000000000..7ec54e3c7 --- /dev/null +++ b/lib/crewai/tests/tools/test_tool_search_tool.py @@ -0,0 +1,393 @@ +"""Tests for the ToolSearchTool functionality.""" + +import json + +import pytest +from pydantic import BaseModel + +from crewai.tools import BaseTool, SearchStrategy, ToolSearchTool + + +class MockSearchTool(BaseTool): + """A mock search tool for testing.""" + + name: str = "Web Search" + description: str = "Search the web for information on any topic." + + def _run(self, query: str) -> str: + return f"Search results for: {query}" + + +class MockDatabaseTool(BaseTool): + """A mock database tool for testing.""" + + name: str = "Database Query" + description: str = "Query a SQL database to retrieve data." + + def _run(self, query: str) -> str: + return f"Database results for: {query}" + + +class MockScrapeTool(BaseTool): + """A mock web scraping tool for testing.""" + + name: str = "Web Scraper" + description: str = "Scrape content from websites and extract text." + + def _run(self, url: str) -> str: + return f"Scraped content from: {url}" + + +class MockEmailTool(BaseTool): + """A mock email tool for testing.""" + + name: str = "Send Email" + description: str = "Send an email to a specified recipient." + + def _run(self, to: str, subject: str, body: str) -> str: + return f"Email sent to {to}" + + +class MockCalculatorTool(BaseTool): + """A mock calculator tool for testing.""" + + name: str = "Calculator" + description: str = "Perform mathematical calculations and arithmetic operations." + + def _run(self, expression: str) -> str: + return f"Result: {eval(expression)}" + + +@pytest.fixture +def sample_tools() -> list[BaseTool]: + """Create a list of sample tools for testing.""" + return [ + MockSearchTool(), + MockDatabaseTool(), + MockScrapeTool(), + MockEmailTool(), + MockCalculatorTool(), + ] + + +@pytest.fixture +def tool_search(sample_tools: list[BaseTool]) -> ToolSearchTool: + """Create a ToolSearchTool with sample tools.""" + return ToolSearchTool(tool_catalog=sample_tools) + + +class TestToolSearchToolCreation: + """Tests for ToolSearchTool creation and initialization.""" + + def test_create_tool_search_with_empty_catalog(self) -> None: + """Test creating a ToolSearchTool with an empty catalog.""" + tool_search = ToolSearchTool() + assert tool_search.name == "Tool Search" + assert tool_search.tool_catalog == [] + assert tool_search.search_strategy == SearchStrategy.KEYWORD + + def test_create_tool_search_with_tools(self, sample_tools: list[BaseTool]) -> None: + """Test creating a ToolSearchTool with a list of tools.""" + tool_search = ToolSearchTool(tool_catalog=sample_tools) + assert len(tool_search.tool_catalog) == 5 + assert tool_search.get_catalog_size() == 5 + + def test_create_tool_search_with_regex_strategy( + self, sample_tools: list[BaseTool] + ) -> None: + """Test creating a ToolSearchTool with regex search strategy.""" + tool_search = ToolSearchTool( + tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX + ) + assert tool_search.search_strategy == SearchStrategy.REGEX + + def test_create_tool_search_with_custom_name(self) -> None: + """Test creating a ToolSearchTool with a custom name.""" + tool_search = ToolSearchTool(name="My Tool Finder") + assert tool_search.name == "My Tool Finder" + + +class TestToolSearchKeywordSearch: + """Tests for keyword-based tool search.""" + + def test_search_by_exact_name(self, tool_search: ToolSearchTool) -> None: + """Test searching for a tool by its exact name.""" + result = tool_search._run("Web Search") + result_data = json.loads(result) + + assert result_data["status"] == "success" + assert len(result_data["tools"]) >= 1 + assert result_data["tools"][0]["name"] == "Web Search" + + def test_search_by_partial_name(self, tool_search: ToolSearchTool) -> None: + """Test searching for a tool by partial name.""" + result = tool_search._run("Search") + result_data = json.loads(result) + + assert result_data["status"] == "success" + assert len(result_data["tools"]) >= 1 + tool_names = [t["name"] for t in result_data["tools"]] + assert "Web Search" in tool_names + + def test_search_by_description_keyword(self, tool_search: ToolSearchTool) -> None: + """Test searching for a tool by keyword in description.""" + result = tool_search._run("database") + result_data = json.loads(result) + + assert result_data["status"] == "success" + assert len(result_data["tools"]) >= 1 + tool_names = [t["name"] for t in result_data["tools"]] + assert "Database Query" in tool_names + + def test_search_with_multiple_keywords(self, tool_search: ToolSearchTool) -> None: + """Test searching with multiple keywords.""" + result = tool_search._run("web scrape content") + result_data = json.loads(result) + + assert result_data["status"] == "success" + assert len(result_data["tools"]) >= 1 + tool_names = [t["name"] for t in result_data["tools"]] + assert "Web Scraper" in tool_names + + def test_search_no_results(self, tool_search: ToolSearchTool) -> None: + """Test searching with a query that returns no results.""" + result = tool_search._run("xyznonexistent123abc") + result_data = json.loads(result) + + assert result_data["status"] == "no_results" + assert len(result_data["tools"]) == 0 + + def test_search_max_results_limit(self, tool_search: ToolSearchTool) -> None: + """Test that max_results limits the number of returned tools.""" + result = tool_search._run("tool", max_results=2) + result_data = json.loads(result) + + assert result_data["status"] == "success" + assert len(result_data["tools"]) <= 2 + + def test_search_empty_catalog(self) -> None: + """Test searching with an empty tool catalog.""" + tool_search = ToolSearchTool() + result = tool_search._run("search") + result_data = json.loads(result) + + assert result_data["status"] == "error" + assert "No tools available" in result_data["message"] + + +class TestToolSearchRegexSearch: + """Tests for regex-based tool search.""" + + def test_regex_search_simple_pattern( + self, sample_tools: list[BaseTool] + ) -> None: + """Test regex search with a simple pattern.""" + tool_search = ToolSearchTool( + tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX + ) + result = tool_search._run("Web.*") + result_data = json.loads(result) + + assert result_data["status"] == "success" + tool_names = [t["name"] for t in result_data["tools"]] + assert "Web Search" in tool_names or "Web Scraper" in tool_names + + def test_regex_search_case_insensitive( + self, sample_tools: list[BaseTool] + ) -> None: + """Test that regex search is case insensitive.""" + tool_search = ToolSearchTool( + tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX + ) + result = tool_search._run("email") + result_data = json.loads(result) + + assert result_data["status"] == "success" + tool_names = [t["name"] for t in result_data["tools"]] + assert "Send Email" in tool_names + + def test_regex_search_invalid_pattern_fallback( + self, sample_tools: list[BaseTool] + ) -> None: + """Test that invalid regex patterns are escaped and still work.""" + tool_search = ToolSearchTool( + tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX + ) + result = tool_search._run("[invalid(regex") + result_data = json.loads(result) + + assert result_data["status"] in ["success", "no_results"] + + +class TestToolSearchCustomSearch: + """Tests for custom search function.""" + + def test_custom_search_function(self, sample_tools: list[BaseTool]) -> None: + """Test using a custom search function.""" + + def custom_search( + query: str, tools: list[BaseTool] + ) -> list[BaseTool]: + return [t for t in tools if "email" in t.name.lower()] + + tool_search = ToolSearchTool( + tool_catalog=sample_tools, custom_search_fn=custom_search + ) + result = tool_search._run("anything") + result_data = json.loads(result) + + assert result_data["status"] == "success" + assert len(result_data["tools"]) == 1 + assert result_data["tools"][0]["name"] == "Send Email" + + +class TestToolSearchCatalogManagement: + """Tests for tool catalog management.""" + + def test_add_tool(self, tool_search: ToolSearchTool) -> None: + """Test adding a tool to the catalog.""" + initial_size = tool_search.get_catalog_size() + + class NewTool(BaseTool): + name: str = "New Tool" + description: str = "A new tool for testing." + + def _run(self) -> str: + return "New tool result" + + tool_search.add_tool(NewTool()) + assert tool_search.get_catalog_size() == initial_size + 1 + + def test_add_tools(self, tool_search: ToolSearchTool) -> None: + """Test adding multiple tools to the catalog.""" + initial_size = tool_search.get_catalog_size() + + class NewTool1(BaseTool): + name: str = "New Tool 1" + description: str = "First new tool." + + def _run(self) -> str: + return "Result 1" + + class NewTool2(BaseTool): + name: str = "New Tool 2" + description: str = "Second new tool." + + def _run(self) -> str: + return "Result 2" + + tool_search.add_tools([NewTool1(), NewTool2()]) + assert tool_search.get_catalog_size() == initial_size + 2 + + def test_remove_tool(self, tool_search: ToolSearchTool) -> None: + """Test removing a tool from the catalog.""" + initial_size = tool_search.get_catalog_size() + result = tool_search.remove_tool("Web Search") + + assert result is True + assert tool_search.get_catalog_size() == initial_size - 1 + + def test_remove_nonexistent_tool(self, tool_search: ToolSearchTool) -> None: + """Test removing a tool that doesn't exist.""" + initial_size = tool_search.get_catalog_size() + result = tool_search.remove_tool("Nonexistent Tool") + + assert result is False + assert tool_search.get_catalog_size() == initial_size + + def test_list_tool_names(self, tool_search: ToolSearchTool) -> None: + """Test listing all tool names in the catalog.""" + names = tool_search.list_tool_names() + + assert len(names) == 5 + assert "Web Search" in names + assert "Database Query" in names + assert "Web Scraper" in names + assert "Send Email" in names + assert "Calculator" in names + + +class TestToolSearchResultFormat: + """Tests for the format of search results.""" + + def test_result_contains_tool_info(self, tool_search: ToolSearchTool) -> None: + """Test that search results contain complete tool information.""" + result = tool_search._run("Calculator") + result_data = json.loads(result) + + assert result_data["status"] == "success" + tool_info = result_data["tools"][0] + + assert "name" in tool_info + assert "description" in tool_info + assert "args_schema" in tool_info + assert tool_info["name"] == "Calculator" + + def test_result_args_schema_format(self, tool_search: ToolSearchTool) -> None: + """Test that args_schema is properly formatted.""" + result = tool_search._run("Email") + result_data = json.loads(result) + + assert result_data["status"] == "success" + tool_info = result_data["tools"][0] + + assert "args_schema" in tool_info + args_schema = tool_info["args_schema"] + assert isinstance(args_schema, dict) + + +class TestToolSearchIntegration: + """Integration tests for ToolSearchTool.""" + + def test_tool_search_as_base_tool(self, sample_tools: list[BaseTool]) -> None: + """Test that ToolSearchTool works as a BaseTool.""" + tool_search = ToolSearchTool(tool_catalog=sample_tools) + + assert isinstance(tool_search, BaseTool) + assert tool_search.name == "Tool Search" + assert "search" in tool_search.description.lower() + + def test_tool_search_to_structured_tool( + self, sample_tools: list[BaseTool] + ) -> None: + """Test converting ToolSearchTool to structured tool.""" + tool_search = ToolSearchTool(tool_catalog=sample_tools) + structured = tool_search.to_structured_tool() + + assert structured.name == "Tool Search" + assert structured.args_schema is not None + + def test_tool_search_run_method(self, tool_search: ToolSearchTool) -> None: + """Test the run method of ToolSearchTool.""" + result = tool_search.run(query="search", max_results=3) + + assert isinstance(result, str) + result_data = json.loads(result) + assert "status" in result_data + assert "tools" in result_data + + +class TestToolSearchScoring: + """Tests for the keyword scoring algorithm.""" + + def test_exact_name_match_scores_highest( + self, sample_tools: list[BaseTool] + ) -> None: + """Test that exact name matches score higher than partial matches.""" + tool_search = ToolSearchTool(tool_catalog=sample_tools) + result = tool_search._run("Web Search") + result_data = json.loads(result) + + assert result_data["status"] == "success" + assert result_data["tools"][0]["name"] == "Web Search" + + def test_name_match_scores_higher_than_description( + self, sample_tools: list[BaseTool] + ) -> None: + """Test that name matches score higher than description matches.""" + tool_search = ToolSearchTool(tool_catalog=sample_tools) + result = tool_search._run("Calculator") + result_data = json.loads(result) + + assert result_data["status"] == "success" + assert result_data["tools"][0]["name"] == "Calculator"