Compare commits

3 Commits

Author SHA1 Message Date
Devin AI
fae812ffb7 feat: add ToolSearchTool for on-demand tool discovery
Implements Anthropic's Tool Search Tool pattern for on-demand tool loading,
reducing token consumption when working with large tool libraries.

Features:
- ToolSearchTool class that searches through a catalog of tools
- Keyword-based search with relevance scoring (default)
- Regex-based search as alternative strategy
- Support for custom search functions
- Tool catalog management (add, remove, list tools)
- Returns JSON with tool definitions including name, description, and args_schema

Closes #4224

Co-Authored-By: João <joao@crewai.com>
2026-01-12 09:19:16 +00:00
GininDenis
17e3fcbe1f fix: unlink task in execution spans
Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
2026-01-12 02:58:42 -05:00
Joao Moura
b858d705a8 updating docs
2026-01-11 16:02:55 -08:00
15 changed files with 807 additions and 298 deletions

@@ -574,6 +574,10 @@ When you run this Flow, the output will change based on the random boolean value
### Human in the Loop (human feedback)
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**.
</Note>
The `@human_feedback` decorator enables human-in-the-loop workflows by pausing flow execution to collect feedback from a human. This is useful for approval gates, quality review, and decision points that require human judgment.
```python Code

@@ -7,6 +7,10 @@ mode: "wide"
## Overview
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**. Make sure to update your installation before using this feature.
</Note>
The `@human_feedback` decorator enables human-in-the-loop (HITL) workflows directly within CrewAI Flows. It allows you to pause flow execution, present output to a human for review, collect their feedback, and optionally route to different listeners based on the feedback outcome.
This is particularly valuable for:

@@ -11,10 +11,10 @@ Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelli
CrewAI offers two main approaches for implementing human-in-the-loop workflows:
-| Approach | Best For | Integration |
-|----------|----------|-------------|
-| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/en/learn/human-feedback-in-flows) |
-| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide |
+| Approach | Best For | Integration | Version |
+|----------|----------|-------------|---------|
+| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/en/learn/human-feedback-in-flows) | **1.8.0+** |
+| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide | - |
<Tip>
If you're building flows and want to add human review steps with routing based on feedback, check out the [Human Feedback in Flows](/en/learn/human-feedback-in-flows) guide for the `@human_feedback` decorator.

@@ -567,6 +567,10 @@ Fourth method running
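### Human in the Loop (human feedback)
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**.
</Note>
The `@human_feedback` decorator enables human-in-the-loop workflows by pausing flow execution to collect feedback from a human. This is useful for approval gates, quality review, and decision points that require human judgment.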
```python Code

@@ -7,6 +7,10 @@ mode: "wide"
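## Overview
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**. Make sure to update your installation before using this feature.
</Note>
The `@human_feedback` decorator enables human-in-the-loop (HITL) workflows directly within CrewAI Flows. It lets you pause flow execution, present output to a human for review, collect their feedback, and optionally route to different listeners based on the feedback outcome.
This is particularly valuable for: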

@@ -5,9 +5,22 @@ icon: "user-check"
mode: "wide"
---
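-Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. This guide shows you how to implement HITL within CrewAI.
+Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. CrewAI offers several ways to implement HITL depending on your needs.
-## Setting Up HITL Workflows
+## Choosing Your HITL Approach
+CrewAI offers two main approaches for implementing human-in-the-loop workflows:
+| Approach | Best For | Integration | Version |
+|----------|----------|-------------|---------|
+| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/ko/learn/human-feedback-in-flows) | **1.8.0+** |
+| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide | - |
+<Tip>
+If you're building flows and want to add human review steps with routing based on feedback, see the [Human Feedback in Flows](/ko/learn/human-feedback-in-flows) guide for the `@human_feedback` decorator.
+</Tip>
+## Setting Up Webhook-Based HITL Workflows
<Steps>
<Step title="Configure Your Task">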

@@ -309,6 +309,10 @@ Ao executar esse Flow, a saída será diferente dependendo do valor booleano ale
### Human in the Loop (human feedback)
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**.
</Note>
The `@human_feedback` decorator enables human-in-the-loop workflows by pausing flow execution to collect feedback from a human. This is useful for approval gates, quality review, and decision points that require human judgment.
```python Code

@@ -7,6 +7,10 @@ mode: "wide"
## Overview
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**. Make sure to update your installation before using this feature.
</Note>
The `@human_feedback` decorator enables human-in-the-loop (HITL) workflows directly within CrewAI Flows. It lets you pause flow execution, present the output to a human for review, collect their feedback, and optionally route to different listeners based on the feedback outcome.
This is particularly valuable for:

@@ -5,9 +5,22 @@ icon: "user-check"
mode: "wide"
---
-Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. This guide shows how to implement HITL within CrewAI.
+Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. CrewAI offers several ways to implement HITL depending on your needs.
-## Setting Up HITL Workflows
+## Choosing Your HITL Approach
+CrewAI offers two main approaches for implementing human-in-the-loop workflows:
+| Approach | Best For | Integration | Version |
+|----------|----------|-------------|---------|
+| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/pt-BR/learn/human-feedback-in-flows) | **1.8.0+** |
+| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide | - |
+<Tip>
+If you're building flows and want to add human review steps with routing based on feedback, check out the [Human Feedback in Flows](/pt-BR/learn/human-feedback-in-flows) guide for the `@human_feedback` decorator.
+</Tip>
+## Setting Up Webhook-Based HITL Workflows
<Steps>
<Step title="Configure Your Task">

@@ -209,10 +209,9 @@ class EventListener(BaseEventListener):
         @crewai_event_bus.on(TaskCompletedEvent)
         def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
             # Handle telemetry
-            span = self.execution_spans.get(source)
+            span = self.execution_spans.pop(source, None)
             if span:
                 self._telemetry.task_ended(span, source, source.agent.crew)
-            self.execution_spans[source] = None

             # Pass task name if it exists
             task_name = get_task_name(source)
@@ -222,11 +221,10 @@ class EventListener(BaseEventListener):

         @crewai_event_bus.on(TaskFailedEvent)
         def on_task_failed(source: Any, event: TaskFailedEvent) -> None:
-            span = self.execution_spans.get(source)
+            span = self.execution_spans.pop(source, None)
             if span:
                 if source.agent and source.agent.crew:
                     self._telemetry.task_ended(span, source, source.agent.crew)
-                self.execution_spans[source] = None

             # Pass task name if it exists
             task_name = get_task_name(source)
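
Both hunks above swap a `get` followed by a `None` assignment for a single `pop`. A minimal sketch of the difference, assuming a plain dict with hypothetical string keys and values standing in for the real task and span objects: assigning `None` leaves the key (and so a reference to the task) in the mapping, while `pop(key, None)` removes the entry and tolerates a key that is already gone.

```python
# Illustrative only: strings stand in for the task objects and telemetry spans.
spans_old = {"task-1": "span-1"}
spans_new = {"task-1": "span-1"}

# Old behaviour: read the span, then overwrite the value with None.
span = spans_old.get("task-1")
spans_old["task-1"] = None  # the key is still present afterwards

# New behaviour: read and remove the entry in one step.
popped = spans_new.pop("task-1", None)

# A second pop with a default does not raise; it returns the default.
missing = spans_new.pop("task-1", None)

print(len(spans_old), len(spans_new))
```

With the old pattern the mapping grows by one entry per executed task; with `pop` it shrinks back once the task finishes or fails.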

@@ -1,9 +1,12 @@
 from crewai.tools.base_tool import BaseTool, EnvVar, tool
+from crewai.tools.tool_search_tool import SearchStrategy, ToolSearchTool


 __all__ = [
     "BaseTool",
     "EnvVar",
+    "SearchStrategy",
+    "ToolSearchTool",
     "tool",
 ]

@@ -0,0 +1,333 @@
"""Tool Search Tool for on-demand tool discovery.

This module implements a Tool Search Tool that allows agents to dynamically
discover and load tools on-demand, reducing token consumption when working
with large tool libraries.

Inspired by Anthropic's Tool Search Tool approach for on-demand tool loading.
"""

from __future__ import annotations

from collections.abc import Callable, Sequence
from enum import Enum
import json
import re
from typing import Any

from pydantic import BaseModel, Field

from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities.pydantic_schema_utils import generate_model_description


class SearchStrategy(str, Enum):
    """Search strategy for tool discovery."""

    KEYWORD = "keyword"
    REGEX = "regex"


class ToolSearchResult(BaseModel):
    """Result from a tool search operation."""

    name: str = Field(description="The name of the tool")
    description: str = Field(description="The description of the tool")
    args_schema: dict[str, Any] = Field(
        description="The JSON schema for the tool's arguments"
    )


class ToolSearchToolSchema(BaseModel):
    """Schema for the Tool Search Tool arguments."""

    query: str = Field(
        description="The search query to find relevant tools. Use keywords that describe the capability you need."
    )
    max_results: int = Field(
        default=5,
        description="Maximum number of tools to return. Default is 5.",
        ge=1,
        le=20,
    )


class ToolSearchTool(BaseTool):
    """A tool that searches through a catalog of tools to find relevant ones.

    This tool enables on-demand tool discovery, allowing agents to work with
    large tool libraries without loading all tool definitions upfront. Instead
    of consuming tokens with all tool definitions, the agent can search for
    relevant tools when needed.

    Example:
        ```python
        from crewai import Agent
        from crewai.tools import BaseTool, SearchStrategy, ToolSearchTool

        # Create your tools
        search_tool = MySearchTool()
        scrape_tool = MyScrapeWebsiteTool()
        database_tool = MyDatabaseTool()

        # Create a tool search tool with your tool catalog
        tool_search = ToolSearchTool(
            tool_catalog=[search_tool, scrape_tool, database_tool],
            search_strategy=SearchStrategy.KEYWORD,
        )

        # Use with an agent - only the tool_search is loaded initially
        agent = Agent(
            role="Researcher",
            tools=[tool_search],  # Other tools discovered on-demand
        )
        ```

    Attributes:
        tool_catalog: List of tools available for search.
        search_strategy: Strategy to use for searching (keyword or regex).
        custom_search_fn: Optional custom search function for advanced matching.
    """

    name: str = Field(
        default="Tool Search",
        description="The name of the tool search tool.",
    )
    description: str = Field(
        default="Search for available tools by describing the capability you need. Returns tool definitions that match your query.",
        description="Description of what the tool search tool does.",
    )
    args_schema: type[BaseModel] = Field(
        default=ToolSearchToolSchema,
        description="The schema for the tool search arguments.",
    )
    tool_catalog: list[BaseTool | CrewStructuredTool] = Field(
        default_factory=list,
        description="List of tools available for search.",
    )
    search_strategy: SearchStrategy = Field(
        default=SearchStrategy.KEYWORD,
        description="Strategy to use for searching tools.",
    )
    custom_search_fn: Callable[
        [str, Sequence[BaseTool | CrewStructuredTool]], list[BaseTool | CrewStructuredTool]
    ] | None = Field(
        default=None,
        description="Optional custom search function for advanced matching.",
    )

    def _run(self, query: str, max_results: int = 5) -> str:
        """Search for tools matching the query.

        Args:
            query: The search query to find relevant tools.
            max_results: Maximum number of tools to return.

        Returns:
            JSON string containing the matching tool definitions.
        """
        if not self.tool_catalog:
            return json.dumps(
                {
                    "status": "error",
                    "message": "No tools available in the catalog.",
                    "tools": [],
                }
            )

        if self.custom_search_fn:
            matching_tools = self.custom_search_fn(query, self.tool_catalog)
        elif self.search_strategy == SearchStrategy.REGEX:
            matching_tools = self._regex_search(query)
        else:
            matching_tools = self._keyword_search(query)

        matching_tools = matching_tools[:max_results]

        if not matching_tools:
            return json.dumps(
                {
                    "status": "no_results",
                    "message": f"No tools found matching query: '{query}'. Try different keywords.",
                    "tools": [],
                }
            )

        tool_results = []
        for tool in matching_tools:
            tool_info = self._get_tool_info(tool)
            tool_results.append(tool_info)

        return json.dumps(
            {
                "status": "success",
                "message": f"Found {len(tool_results)} tool(s) matching your query.",
                "tools": tool_results,
            },
            indent=2,
        )

    def _keyword_search(
        self, query: str
    ) -> list[BaseTool | CrewStructuredTool]:
        """Search tools using keyword matching.

        Args:
            query: The search query.

        Returns:
            List of matching tools sorted by relevance.
        """
        query_lower = query.lower()
        query_words = set(query_lower.split())

        scored_tools: list[tuple[float, BaseTool | CrewStructuredTool]] = []

        for tool in self.tool_catalog:
            score = self._calculate_keyword_score(tool, query_lower, query_words)
            if score > 0:
                scored_tools.append((score, tool))

        scored_tools.sort(key=lambda x: x[0], reverse=True)
        return [tool for _, tool in scored_tools]

    def _calculate_keyword_score(
        self,
        tool: BaseTool | CrewStructuredTool,
        query_lower: str,
        query_words: set[str],
    ) -> float:
        """Calculate relevance score for a tool based on keyword matching.

        Args:
            tool: The tool to score.
            query_lower: Lowercase query string.
            query_words: Set of query words.

        Returns:
            Relevance score (higher is better).
        """
        score = 0.0
        tool_name_lower = tool.name.lower()
        tool_desc_lower = tool.description.lower()

        if query_lower in tool_name_lower:
            score += 10.0
        if query_lower in tool_desc_lower:
            score += 5.0

        for word in query_words:
            if len(word) < 2:
                continue
            if word in tool_name_lower:
                score += 3.0
            if word in tool_desc_lower:
                score += 1.0

        return score

    def _regex_search(
        self, query: str
    ) -> list[BaseTool | CrewStructuredTool]:
        """Search tools using regex pattern matching.

        Args:
            query: The regex pattern to search for.

        Returns:
            List of matching tools.
        """
        try:
            pattern = re.compile(query, re.IGNORECASE)
        except re.error:
            pattern = re.compile(re.escape(query), re.IGNORECASE)

        return [
            tool
            for tool in self.tool_catalog
            if pattern.search(tool.name) or pattern.search(tool.description)
        ]

    def _get_tool_info(self, tool: BaseTool | CrewStructuredTool) -> dict[str, Any]:
        """Get tool information as a dictionary.

        Args:
            tool: The tool to get information from.

        Returns:
            Dictionary containing tool name, description, and args schema.
        """
        if isinstance(tool, BaseTool):
            schema_dict = generate_model_description(tool.args_schema)
            args_schema = schema_dict.get("json_schema", {}).get("schema", {})
        else:
            args_schema = tool.args_schema.model_json_schema()

        return {
            "name": tool.name,
            "description": self._get_original_description(tool),
            "args_schema": args_schema,
        }

    def _get_original_description(self, tool: BaseTool | CrewStructuredTool) -> str:
        """Get the original description of a tool without the generated schema.

        Args:
            tool: The tool to get the description from.

        Returns:
            The original tool description.
        """
        description = tool.description
        if "Tool Description:" in description:
            parts = description.split("Tool Description:")
            if len(parts) > 1:
                return parts[1].strip()
        return description

    def add_tool(self, tool: BaseTool | CrewStructuredTool) -> None:
        """Add a tool to the catalog.

        Args:
            tool: The tool to add.
        """
        self.tool_catalog.append(tool)

    def add_tools(self, tools: Sequence[BaseTool | CrewStructuredTool]) -> None:
        """Add multiple tools to the catalog.

        Args:
            tools: The tools to add.
        """
        self.tool_catalog.extend(tools)

    def remove_tool(self, tool_name: str) -> bool:
        """Remove a tool from the catalog by name.

        Args:
            tool_name: The name of the tool to remove.

        Returns:
            True if the tool was removed, False if not found.
        """
        for i, tool in enumerate(self.tool_catalog):
            if tool.name == tool_name:
                self.tool_catalog.pop(i)
                return True
        return False

    def get_catalog_size(self) -> int:
        """Get the number of tools in the catalog.

        Returns:
            The number of tools in the catalog.
        """
        return len(self.tool_catalog)

    def list_tool_names(self) -> list[str]:
        """List all tool names in the catalog.

        Returns:
            List of tool names.
        """
        return [tool.name for tool in self.tool_catalog]
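
To make the keyword weights in `_calculate_keyword_score` easier to follow, here is a standalone sketch that applies the same weights to plain strings. The function name `keyword_score` is hypothetical and does not exist in the module; the tool names and descriptions are borrowed from the mock tools in the test file of this diff.

```python
def keyword_score(name: str, description: str, query: str) -> float:
    """Score one tool against a query using the weights shown in the diff."""
    query_lower = query.lower()
    name_lower = name.lower()
    desc_lower = description.lower()

    score = 0.0
    # Whole-query substring matches carry the largest weights.
    if query_lower in name_lower:
        score += 10.0
    if query_lower in desc_lower:
        score += 5.0
    # Each distinct query word of two or more characters adds a smaller bonus.
    for word in set(query_lower.split()):
        if len(word) < 2:
            continue
        if word in name_lower:
            score += 3.0
        if word in desc_lower:
            score += 1.0
    return score


web_search = keyword_score(
    "Web Search", "Search the web for information on any topic.", "web search"
)
web_scraper = keyword_score(
    "Web Scraper", "Scrape content from websites and extract text.", "web search"
)
print(web_search, web_scraper)
```

Note that matching is by substring, not by token, so the query word `web` also counts against the description word `websites`. That is why the scraper still receives a non-zero score for this query and would appear in the results, ranked below the search tool.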

@@ -2,11 +2,8 @@ from datetime import datetime
 import json
 import os
 import pickle
-import tempfile
-import threading
 from typing import Any, TypedDict

-import portalocker
 from typing_extensions import Unpack
@@ -126,15 +123,10 @@ class FileHandler:
 class PickleHandler:
-    """Thread-safe handler for saving and loading data using pickle.
-
-    This class provides thread-safe file operations using portalocker for
-    cross-process file locking and atomic write operations to prevent
-    data corruption during concurrent access.
+    """Handler for saving and loading data using pickle.

     Attributes:
         file_path: The path to the pickle file.
-        _lock: Threading lock for thread-safe operations within the same process.
     """

     def __init__(self, file_name: str) -> None:
@@ -149,62 +141,34 @@ class PickleHandler:
             file_name += ".pkl"

         self.file_path = os.path.join(os.getcwd(), file_name)
-        self._lock = threading.Lock()

     def initialize_file(self) -> None:
         """Initialize the file with an empty dictionary and overwrite any existing data."""
         self.save({})

     def save(self, data: Any) -> None:
-        """Save the data to the specified file using pickle with thread-safe atomic writes.
-
-        This method uses a two-phase approach for thread safety:
-        1. Threading lock for same-process thread safety
-        2. Atomic write (write to temp file, then rename) for cross-process safety
-           and data integrity
+        """
+        Save the data to the specified file using pickle.

         Args:
-            data: The data to be saved to the file.
+          data: The data to be saved to the file.
         """
-        with self._lock:
-            dir_name = os.path.dirname(self.file_path) or os.getcwd()
-            fd, temp_path = tempfile.mkstemp(
-                suffix=".pkl.tmp", prefix="pickle_", dir=dir_name
-            )
-            try:
-                with os.fdopen(fd, "wb") as f:
-                    pickle.dump(obj=data, file=f)
-                    f.flush()
-                    os.fsync(f.fileno())
-                os.replace(temp_path, self.file_path)
-            except Exception:
-                if os.path.exists(temp_path):
-                    os.unlink(temp_path)
-                raise
+        with open(self.file_path, "wb") as f:
+            pickle.dump(obj=data, file=f)

     def load(self) -> Any:
-        """Load the data from the specified file using pickle with thread-safe locking.
-
-        This method uses portalocker for cross-process read locking to ensure
-        data consistency when multiple processes may be accessing the file.
+        """Load the data from the specified file using pickle.

         Returns:
-            The data loaded from the file, or an empty dictionary if the file
-            does not exist or is empty.
+            The data loaded from the file.
         """
-        with self._lock:
-            if (
-                not os.path.exists(self.file_path)
-                or os.path.getsize(self.file_path) == 0
-            ):
-                return {}
+        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
+            return {}  # Return an empty dictionary if the file does not exist or is empty

-            with portalocker.Lock(
-                self.file_path, "rb", flags=portalocker.LOCK_SH
-            ) as file:
-                try:
-                    return pickle.load(file)  # noqa: S301
-                except EOFError:
-                    return {}
-                except Exception:
-                    raise
+        with open(self.file_path, "rb") as file:
+            try:
+                return pickle.load(file)  # noqa: S301
+            except EOFError:
+                return {}  # Return an empty dictionary if the file is empty or corrupted
+            except Exception:
+                raise  # Raise any other exceptions that occur during loading
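
For readers comparing the two versions of `save`, the side of this comparison that uses a temp file relies on the write-then-rename pattern. A minimal standard-library sketch of that pattern follows, assuming a hypothetical helper name `atomic_pickle_dump` that is not part of `PickleHandler`; it omits the threading lock and the `portalocker` read lock shown in the diff.

```python
import os
import pickle
import tempfile
from typing import Any


def atomic_pickle_dump(data: Any, path: str) -> None:
    """Write `data` to `path` so readers never observe a half-written file."""
    dir_name = os.path.dirname(path) or os.getcwd()
    # The temp file lives in the target directory so the final rename
    # stays on one filesystem.
    fd, temp_path = tempfile.mkstemp(suffix=".tmp", dir=dir_name)
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(data, f)
            f.flush()
            os.fsync(f.fileno())
        # os.replace swaps the finished file into place in one step.
        os.replace(temp_path, path)
    except Exception:
        # Clean up the temp file if anything failed before the rename.
        if os.path.exists(temp_path):
            os.unlink(temp_path)
        raise


with tempfile.TemporaryDirectory() as workdir:
    target = os.path.join(workdir, "demo.pkl")
    atomic_pickle_dump({"answer": 42}, target)
    with open(target, "rb") as f:
        loaded = pickle.load(f)
    leftovers = [n for n in os.listdir(workdir) if n.endswith(".tmp")]
```

The plain `open(..., "wb")` version is shorter, but a reader that opens the file while a write is in progress can see a truncated pickle, which is the case the `EOFError` branch in `load` papers over.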

@@ -0,0 +1,393 @@
"""Tests for the ToolSearchTool functionality."""

import json

import pytest
from pydantic import BaseModel

from crewai.tools import BaseTool, SearchStrategy, ToolSearchTool


class MockSearchTool(BaseTool):
    """A mock search tool for testing."""

    name: str = "Web Search"
    description: str = "Search the web for information on any topic."

    def _run(self, query: str) -> str:
        return f"Search results for: {query}"


class MockDatabaseTool(BaseTool):
    """A mock database tool for testing."""

    name: str = "Database Query"
    description: str = "Query a SQL database to retrieve data."

    def _run(self, query: str) -> str:
        return f"Database results for: {query}"


class MockScrapeTool(BaseTool):
    """A mock web scraping tool for testing."""

    name: str = "Web Scraper"
    description: str = "Scrape content from websites and extract text."

    def _run(self, url: str) -> str:
        return f"Scraped content from: {url}"


class MockEmailTool(BaseTool):
    """A mock email tool for testing."""

    name: str = "Send Email"
    description: str = "Send an email to a specified recipient."

    def _run(self, to: str, subject: str, body: str) -> str:
        return f"Email sent to {to}"


class MockCalculatorTool(BaseTool):
    """A mock calculator tool for testing."""

    name: str = "Calculator"
    description: str = "Perform mathematical calculations and arithmetic operations."

    def _run(self, expression: str) -> str:
        return f"Result: {eval(expression)}"


@pytest.fixture
def sample_tools() -> list[BaseTool]:
    """Create a list of sample tools for testing."""
    return [
        MockSearchTool(),
        MockDatabaseTool(),
        MockScrapeTool(),
        MockEmailTool(),
        MockCalculatorTool(),
    ]


@pytest.fixture
def tool_search(sample_tools: list[BaseTool]) -> ToolSearchTool:
    """Create a ToolSearchTool with sample tools."""
    return ToolSearchTool(tool_catalog=sample_tools)


class TestToolSearchToolCreation:
    """Tests for ToolSearchTool creation and initialization."""

    def test_create_tool_search_with_empty_catalog(self) -> None:
        """Test creating a ToolSearchTool with an empty catalog."""
        tool_search = ToolSearchTool()

        assert tool_search.name == "Tool Search"
        assert tool_search.tool_catalog == []
        assert tool_search.search_strategy == SearchStrategy.KEYWORD

    def test_create_tool_search_with_tools(self, sample_tools: list[BaseTool]) -> None:
        """Test creating a ToolSearchTool with a list of tools."""
        tool_search = ToolSearchTool(tool_catalog=sample_tools)

        assert len(tool_search.tool_catalog) == 5
        assert tool_search.get_catalog_size() == 5

    def test_create_tool_search_with_regex_strategy(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Test creating a ToolSearchTool with regex search strategy."""
        tool_search = ToolSearchTool(
            tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
        )

        assert tool_search.search_strategy == SearchStrategy.REGEX

    def test_create_tool_search_with_custom_name(self) -> None:
        """Test creating a ToolSearchTool with a custom name."""
        tool_search = ToolSearchTool(name="My Tool Finder")

        assert tool_search.name == "My Tool Finder"


class TestToolSearchKeywordSearch:
    """Tests for keyword-based tool search."""

    def test_search_by_exact_name(self, tool_search: ToolSearchTool) -> None:
        """Test searching for a tool by its exact name."""
        result = tool_search._run("Web Search")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        assert len(result_data["tools"]) >= 1
        assert result_data["tools"][0]["name"] == "Web Search"

    def test_search_by_partial_name(self, tool_search: ToolSearchTool) -> None:
        """Test searching for a tool by partial name."""
        result = tool_search._run("Search")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        assert len(result_data["tools"]) >= 1
        tool_names = [t["name"] for t in result_data["tools"]]
        assert "Web Search" in tool_names

    def test_search_by_description_keyword(self, tool_search: ToolSearchTool) -> None:
        """Test searching for a tool by keyword in description."""
        result = tool_search._run("database")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        assert len(result_data["tools"]) >= 1
        tool_names = [t["name"] for t in result_data["tools"]]
        assert "Database Query" in tool_names

    def test_search_with_multiple_keywords(self, tool_search: ToolSearchTool) -> None:
        """Test searching with multiple keywords."""
        result = tool_search._run("web scrape content")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        assert len(result_data["tools"]) >= 1
        tool_names = [t["name"] for t in result_data["tools"]]
        assert "Web Scraper" in tool_names

    def test_search_no_results(self, tool_search: ToolSearchTool) -> None:
        """Test searching with a query that returns no results."""
        result = tool_search._run("xyznonexistent123abc")
        result_data = json.loads(result)

        assert result_data["status"] == "no_results"
        assert len(result_data["tools"]) == 0

    def test_search_max_results_limit(self, tool_search: ToolSearchTool) -> None:
        """Test that max_results limits the number of returned tools."""
        result = tool_search._run("tool", max_results=2)
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        assert len(result_data["tools"]) <= 2

    def test_search_empty_catalog(self) -> None:
        """Test searching with an empty tool catalog."""
        tool_search = ToolSearchTool()
        result = tool_search._run("search")
        result_data = json.loads(result)

        assert result_data["status"] == "error"
        assert "No tools available" in result_data["message"]


class TestToolSearchRegexSearch:
    """Tests for regex-based tool search."""

    def test_regex_search_simple_pattern(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Test regex search with a simple pattern."""
        tool_search = ToolSearchTool(
            tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
        )
        result = tool_search._run("Web.*")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        tool_names = [t["name"] for t in result_data["tools"]]
        assert "Web Search" in tool_names or "Web Scraper" in tool_names

    def test_regex_search_case_insensitive(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Test that regex search is case insensitive."""
        tool_search = ToolSearchTool(
            tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
        )
        result = tool_search._run("email")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        tool_names = [t["name"] for t in result_data["tools"]]
        assert "Send Email" in tool_names

    def test_regex_search_invalid_pattern_fallback(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Test that invalid regex patterns are escaped and still work."""
        tool_search = ToolSearchTool(
            tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
        )
        result = tool_search._run("[invalid(regex")
        result_data = json.loads(result)

        assert result_data["status"] in ["success", "no_results"]


class TestToolSearchCustomSearch:
    """Tests for custom search function."""

    def test_custom_search_function(self, sample_tools: list[BaseTool]) -> None:
        """Test using a custom search function."""

        def custom_search(
            query: str, tools: list[BaseTool]
        ) -> list[BaseTool]:
            return [t for t in tools if "email" in t.name.lower()]

        tool_search = ToolSearchTool(
            tool_catalog=sample_tools, custom_search_fn=custom_search
        )
        result = tool_search._run("anything")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        assert len(result_data["tools"]) == 1
        assert result_data["tools"][0]["name"] == "Send Email"


class TestToolSearchCatalogManagement:
    """Tests for tool catalog management."""

    def test_add_tool(self, tool_search: ToolSearchTool) -> None:
        """Test adding a tool to the catalog."""
        initial_size = tool_search.get_catalog_size()

        class NewTool(BaseTool):
            name: str = "New Tool"
            description: str = "A new tool for testing."

            def _run(self) -> str:
                return "New tool result"

        tool_search.add_tool(NewTool())

        assert tool_search.get_catalog_size() == initial_size + 1

    def test_add_tools(self, tool_search: ToolSearchTool) -> None:
        """Test adding multiple tools to the catalog."""
        initial_size = tool_search.get_catalog_size()

        class NewTool1(BaseTool):
            name: str = "New Tool 1"
            description: str = "First new tool."

            def _run(self) -> str:
                return "Result 1"

        class NewTool2(BaseTool):
            name: str = "New Tool 2"
            description: str = "Second new tool."

            def _run(self) -> str:
                return "Result 2"

        tool_search.add_tools([NewTool1(), NewTool2()])

        assert tool_search.get_catalog_size() == initial_size + 2

    def test_remove_tool(self, tool_search: ToolSearchTool) -> None:
        """Test removing a tool from the catalog."""
        initial_size = tool_search.get_catalog_size()

        result = tool_search.remove_tool("Web Search")

        assert result is True
        assert tool_search.get_catalog_size() == initial_size - 1

    def test_remove_nonexistent_tool(self, tool_search: ToolSearchTool) -> None:
        """Test removing a tool that doesn't exist."""
        initial_size = tool_search.get_catalog_size()

        result = tool_search.remove_tool("Nonexistent Tool")

        assert result is False
        assert tool_search.get_catalog_size() == initial_size

    def test_list_tool_names(self, tool_search: ToolSearchTool) -> None:
        """Test listing all tool names in the catalog."""
        names = tool_search.list_tool_names()

        assert len(names) == 5
        assert "Web Search" in names
        assert "Database Query" in names
        assert "Web Scraper" in names
        assert "Send Email" in names
        assert "Calculator" in names


class TestToolSearchResultFormat:
    """Tests for the format of search results."""

    def test_result_contains_tool_info(self, tool_search: ToolSearchTool) -> None:
        """Test that search results contain complete tool information."""
        result = tool_search._run("Calculator")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        tool_info = result_data["tools"][0]

        assert "name" in tool_info
        assert "description" in tool_info
        assert "args_schema" in tool_info
        assert tool_info["name"] == "Calculator"

    def test_result_args_schema_format(self, tool_search: ToolSearchTool) -> None:
        """Test that args_schema is properly formatted."""
        result = tool_search._run("Email")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        tool_info = result_data["tools"][0]

        assert "args_schema" in tool_info
        args_schema = tool_info["args_schema"]
        assert isinstance(args_schema, dict)


class TestToolSearchIntegration:
    """Integration tests for ToolSearchTool."""

    def test_tool_search_as_base_tool(self, sample_tools: list[BaseTool]) -> None:
        """Test that ToolSearchTool works as a BaseTool."""
        tool_search = ToolSearchTool(tool_catalog=sample_tools)

        assert isinstance(tool_search, BaseTool)
        assert tool_search.name == "Tool Search"
        assert "search" in tool_search.description.lower()

    def test_tool_search_to_structured_tool(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Test converting ToolSearchTool to structured tool."""
        tool_search = ToolSearchTool(tool_catalog=sample_tools)
        structured = tool_search.to_structured_tool()

        assert structured.name == "Tool Search"
        assert structured.args_schema is not None

    def test_tool_search_run_method(self, tool_search: ToolSearchTool) -> None:
        """Test the run method of ToolSearchTool."""
        result = tool_search.run(query="search", max_results=3)

        assert isinstance(result, str)
        result_data = json.loads(result)
        assert "status" in result_data
        assert "tools" in result_data


class TestToolSearchScoring:
    """Tests for the keyword scoring algorithm."""

    def test_exact_name_match_scores_highest(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Test that exact name matches score higher than partial matches."""
        tool_search = ToolSearchTool(tool_catalog=sample_tools)
        result = tool_search._run("Web Search")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        assert result_data["tools"][0]["name"] == "Web Search"

    def test_name_match_scores_higher_than_description(
        self, sample_tools: list[BaseTool]
    ) -> None:
        """Test that name matches score higher than description matches."""
        tool_search = ToolSearchTool(tool_catalog=sample_tools)
        result = tool_search._run("Calculator")
        result_data = json.loads(result)

        assert result_data["status"] == "success"
        assert result_data["tools"][0]["name"] == "Calculator"

View File

@@ -1,8 +1,6 @@
import os
import threading
import unittest
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from crewai.utilities.file_handler import PickleHandler
@@ -10,6 +8,7 @@ from crewai.utilities.file_handler import PickleHandler
class TestPickleHandler(unittest.TestCase):
    def setUp(self):
        # Use a unique file name for each test to avoid race conditions in parallel test execution
        unique_id = str(uuid.uuid4())
        self.file_name = f"test_data_{unique_id}.pkl"
        self.file_path = os.path.join(os.getcwd(), self.file_name)
@@ -48,234 +47,3 @@ class TestPickleHandler(unittest.TestCase):
        assert str(exc.value) == "pickle data was truncated"
        assert "<class '_pickle.UnpicklingError'>" == str(exc.type)


class TestPickleHandlerThreadSafety(unittest.TestCase):
    """Tests for thread-safety of PickleHandler operations."""

    def setUp(self):
        unique_id = str(uuid.uuid4())
        self.file_name = f"test_thread_safe_{unique_id}.pkl"
        self.file_path = os.path.join(os.getcwd(), self.file_name)
        self.handler = PickleHandler(self.file_name)

    def tearDown(self):
        if os.path.exists(self.file_path):
            os.remove(self.file_path)

    def test_concurrent_writes_same_handler(self):
        """Test that concurrent writes from multiple threads using the same handler don't corrupt data."""
        num_threads = 10
        num_writes_per_thread = 20
        errors: list[Exception] = []
        write_count = 0
        count_lock = threading.Lock()

        def write_data(thread_id: int) -> None:
            nonlocal write_count
            for i in range(num_writes_per_thread):
                try:
                    data = {"thread": thread_id, "iteration": i, "data": f"value_{thread_id}_{i}"}
                    self.handler.save(data)
                    with count_lock:
                        write_count += 1
                except Exception as e:
                    errors.append(e)

        threads = []
        for i in range(num_threads):
            t = threading.Thread(target=write_data, args=(i,))
            threads.append(t)
            t.start()
        for t in threads:
            t.join()

        assert len(errors) == 0, f"Errors occurred during concurrent writes: {errors}"
        assert write_count == num_threads * num_writes_per_thread
        loaded_data = self.handler.load()
        assert isinstance(loaded_data, dict)
        assert "thread" in loaded_data
        assert "iteration" in loaded_data

    def test_concurrent_reads_same_handler(self):
        """Test that concurrent reads from multiple threads don't cause issues."""
        test_data = {"key": "value", "nested": {"a": 1, "b": 2}}
        self.handler.save(test_data)
        num_threads = 20
        results: list[dict] = []
        errors: list[Exception] = []
        results_lock = threading.Lock()

        def read_data() -> None:
            try:
                data = self.handler.load()
                with results_lock:
                    results.append(data)
            except Exception as e:
                errors.append(e)

        threads = []
        for _ in range(num_threads):
            t = threading.Thread(target=read_data)
            threads.append(t)
            t.start()
        for t in threads:
            t.join()

        assert len(errors) == 0, f"Errors occurred during concurrent reads: {errors}"
        assert len(results) == num_threads
        for result in results:
            assert result == test_data

    def test_concurrent_read_write_same_handler(self):
        """Test that concurrent reads and writes don't corrupt data or cause errors."""
        initial_data = {"counter": 0}
        self.handler.save(initial_data)
        num_writers = 5
        num_readers = 10
        writes_per_thread = 10
        reads_per_thread = 20
        write_errors: list[Exception] = []
        read_errors: list[Exception] = []
        read_results: list[dict] = []
        results_lock = threading.Lock()

        def writer(thread_id: int) -> None:
            for i in range(writes_per_thread):
                try:
                    data = {"writer": thread_id, "write_num": i}
                    self.handler.save(data)
                except Exception as e:
                    write_errors.append(e)

        def reader() -> None:
            for _ in range(reads_per_thread):
                try:
                    data = self.handler.load()
                    with results_lock:
                        read_results.append(data)
                except Exception as e:
                    read_errors.append(e)

        threads = []
        for i in range(num_writers):
            t = threading.Thread(target=writer, args=(i,))
            threads.append(t)
        for _ in range(num_readers):
            t = threading.Thread(target=reader)
            threads.append(t)
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert len(write_errors) == 0, f"Write errors: {write_errors}"
        assert len(read_errors) == 0, f"Read errors: {read_errors}"
        for result in read_results:
            assert isinstance(result, dict)

    def test_atomic_write_no_partial_data(self):
        """Test that atomic writes prevent partial/corrupted data from being read."""
        large_data = {"key": "x" * 100000, "numbers": list(range(10000))}
        num_iterations = 50
        errors: list[Exception] = []
        corruption_detected = False
        corruption_lock = threading.Lock()

        def writer() -> None:
            for _ in range(num_iterations):
                try:
                    self.handler.save(large_data)
                except Exception as e:
                    errors.append(e)

        def reader() -> None:
            nonlocal corruption_detected
            for _ in range(num_iterations * 2):
                try:
                    data = self.handler.load()
                    if data and data != {} and data != large_data:
                        with corruption_lock:
                            corruption_detected = True
                except Exception as e:
                    errors.append(e)

        writer_thread = threading.Thread(target=writer)
        reader_thread = threading.Thread(target=reader)
        writer_thread.start()
        reader_thread.start()
        writer_thread.join()
        reader_thread.join()

        assert len(errors) == 0, f"Errors occurred: {errors}"
        assert not corruption_detected, "Partial/corrupted data was read"

    def test_thread_pool_concurrent_operations(self):
        """Test thread safety using ThreadPoolExecutor for more realistic concurrent access."""
        num_operations = 100
        errors: list[Exception] = []

        def operation(op_id: int) -> str:
            try:
                if op_id % 3 == 0:
                    self.handler.save({"op_id": op_id, "type": "write"})
                    return f"write_{op_id}"
                else:
                    data = self.handler.load()
                    return f"read_{op_id}_{type(data).__name__}"
            except Exception as e:
                errors.append(e)
                return f"error_{op_id}"

        with ThreadPoolExecutor(max_workers=20) as executor:
            futures = [executor.submit(operation, i) for i in range(num_operations)]
            results = [f.result() for f in as_completed(futures)]

        assert len(errors) == 0, f"Errors occurred: {errors}"
        assert len(results) == num_operations

    def test_multiple_handlers_same_file(self):
        """Test that multiple PickleHandler instances for the same file work correctly."""
        handler1 = PickleHandler(self.file_name)
        handler2 = PickleHandler(self.file_name)
        num_operations = 50
        errors: list[Exception] = []

        def use_handler1() -> None:
            for i in range(num_operations):
                try:
                    handler1.save({"handler": 1, "iteration": i})
                except Exception as e:
                    errors.append(e)

        def use_handler2() -> None:
            for i in range(num_operations):
                try:
                    handler2.save({"handler": 2, "iteration": i})
                except Exception as e:
                    errors.append(e)

        t1 = threading.Thread(target=use_handler1)
        t2 = threading.Thread(target=use_handler2)
        t1.start()
        t2.start()
        t1.join()
        t2.join()

        assert len(errors) == 0, f"Errors occurred: {errors}"
        final_data = self.handler.load()
        assert isinstance(final_data, dict)
        assert "handler" in final_data
        assert "iteration" in final_data
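
The thread-safety tests above refer to "atomic writes" but the diff does not show how `PickleHandler` achieves them. As an illustration only, the following sketch shows one common way to get the behavior those tests check: write to a temporary file in the same directory, then swap it into place with `os.replace`, with a lock serializing access inside the process. `AtomicPickleStore` is a hypothetical name and is not part of crewai; nothing here is claimed to match the actual `PickleHandler` implementation.

```python
import os
import pickle
import tempfile
import threading
from typing import Any


class AtomicPickleStore:
    """Hypothetical helper (NOT crewai's PickleHandler): write-then-rename pickle store."""

    def __init__(self, path: str) -> None:
        self.path = path
        # One lock per instance; it only protects threads sharing this object.
        self._lock = threading.Lock()

    def save(self, data: Any) -> None:
        directory = os.path.dirname(os.path.abspath(self.path))
        with self._lock:
            # The temp file lives in the target directory so the final
            # rename stays on one filesystem.
            fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
            try:
                with os.fdopen(fd, "wb") as tmp:
                    pickle.dump(data, tmp)
                    tmp.flush()
                    os.fsync(tmp.fileno())
                # Readers see either the old file or the new one, never a mix.
                os.replace(tmp_path, self.path)
            except BaseException:
                # Clean up the temp file if anything failed before the swap.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise

    def load(self) -> Any:
        with self._lock:
            # Assumption for this sketch: a missing or empty file loads as {}.
            if not os.path.exists(self.path) or os.path.getsize(self.path) == 0:
                return {}
            with open(self.path, "rb") as f:
                return pickle.load(f)
```

Note that the per-instance lock in this sketch would not coordinate two separate instances pointing at the same file, which is the case `test_multiple_handlers_same_file` exercises; in that situation only the rename step keeps readers from seeing a half-written file.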