Compare commits


1 Commit

Author SHA1 Message Date
Devin AI
2ada7aba97 Fix race condition in LLM callback system
This commit fixes a race condition in the LLM callback system where
multiple LLM instances calling set_callbacks concurrently could cause
callbacks to be removed before they fire.

Changes:
- Add class-level RLock (_callback_lock) to LLM class to synchronize
  access to global litellm callbacks
- Wrap callback registration and LLM call execution in the lock for
  both call() and acall() methods
- Use RLock (reentrant lock) to handle recursive calls without deadlock
  (e.g., when retrying with unsupported 'stop' parameter)
- Remove sleep(5) workaround from test_llm_callback_replacement test
- Add new test_llm_callback_lock_prevents_race_condition test to verify
  concurrent callback access is properly synchronized

Fixes #4214

Co-Authored-By: João <joao@crewai.com>
2026-01-10 21:11:01 +00:00
15 changed files with 220 additions and 907 deletions
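Below is a minimal, self-contained sketch of the locking pattern this commit describes (standard library only; the class, function, and variable names are illustrative stand-ins, not the actual CrewAI or litellm code): a class-level `threading.RLock` serializes access to a process-global callback list so one instance's `set_callbacks` cannot clear another instance's callbacks mid-call, and because the lock is reentrant, the retry path can recurse from the same thread without deadlocking.

```python
import threading

# Stand-in for litellm's process-global callback list (illustrative only).
GLOBAL_CALLBACKS: list[object] = []


class FakeLLM:  # illustrative stand-in, not the real crewai LLM class
    # One lock shared by every instance: concurrent calls from different
    # instances cannot interleave set_callbacks() with an in-flight completion.
    _callback_lock: threading.RLock = threading.RLock()

    def set_callbacks(self, callbacks: list[object]) -> None:
        # Mutates shared global state, hence the need for synchronization.
        GLOBAL_CALLBACKS.clear()
        GLOBAL_CALLBACKS.extend(callbacks)

    def call(self, prompt: str, callbacks: list[object], retried: bool = False) -> str:
        with FakeLLM._callback_lock:
            if callbacks:
                self.set_callbacks(callbacks)
            try:
                if not retried:
                    # Simulate a provider rejecting an unsupported parameter.
                    raise ValueError("Unsupported parameter: 'stop'")
                return f"response to {prompt!r} via {len(GLOBAL_CALLBACKS)} callback(s)"
            except ValueError:
                # Same-thread recursive retry: an RLock can be re-acquired by
                # the thread that already holds it, so this does not deadlock.
                return self.call(prompt, callbacks, retried=True)


if __name__ == "__main__":
    print(FakeLLM().call("hello", callbacks=[object()]))
```

With a plain `threading.Lock`, the recursive retry inside the `with` block would hang on the second acquire; the reentrant lock is what makes the in-place retry described above safe.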

View File

@@ -574,10 +574,6 @@ When you run this Flow, the output will change based on the random boolean value
### Human in the Loop (human feedback)
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**.
</Note>
The `@human_feedback` decorator enables human-in-the-loop workflows by pausing flow execution to collect feedback from a human. This is useful for approval gates, quality review, and decision points that require human judgment.
```python Code

View File

@@ -7,10 +7,6 @@ mode: "wide"
## Overview
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**. Make sure to update your installation before using this feature.
</Note>
The `@human_feedback` decorator enables human-in-the-loop (HITL) workflows directly within CrewAI Flows. It allows you to pause flow execution, present output to a human for review, collect their feedback, and optionally route to different listeners based on the feedback outcome.
This is particularly valuable for:

View File

@@ -11,10 +11,10 @@ Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelli
CrewAI offers two main approaches for implementing human-in-the-loop workflows:
| Approach | Best For | Integration | Version |
|----------|----------|-------------|---------|
| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/en/learn/human-feedback-in-flows) | **1.8.0+** |
| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide | - |
| Approach | Best For | Integration |
|----------|----------|-------------|
| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/en/learn/human-feedback-in-flows) |
| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide |
<Tip>
If you're building flows and want to add human review steps with routing based on feedback, check out the [Human Feedback in Flows](/en/learn/human-feedback-in-flows) guide for the `@human_feedback` decorator.

View File

@@ -567,10 +567,6 @@ Fourth method running
### Human in the Loop (human feedback)
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**.
</Note>
The `@human_feedback` decorator enables human-in-the-loop workflows by pausing flow execution to collect feedback from a human. This is useful for approval gates, quality review, and decision points that require human judgment.
```python Code

View File

@@ -7,10 +7,6 @@ mode: "wide"
## Overview
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**. Make sure to update your installation before using this feature.
</Note>
The `@human_feedback` decorator enables human-in-the-loop (HITL) workflows directly within CrewAI Flows. It allows you to pause flow execution, present output to a human for review, collect their feedback, and optionally route to different listeners based on the feedback outcome.
This is particularly valuable for:

View File

@@ -5,22 +5,9 @@ icon: "user-check"
mode: "wide"
---
Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. CrewAI offers several ways to implement HITL depending on your needs.
Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. This guide walks you through implementing HITL within CrewAI.
## Choosing Your HITL Approach
CrewAI offers two main approaches for implementing human-in-the-loop workflows:
| Approach | Best For | Integration | Version |
|----------|----------|-------------|---------|
| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/ko/learn/human-feedback-in-flows) | **1.8.0+** |
| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide | - |
<Tip>
If you're building flows and want to add human review steps with routing based on feedback, check out the [Human Feedback in Flows](/ko/learn/human-feedback-in-flows) guide for the `@human_feedback` decorator.
</Tip>
## Setting Up Webhook-Based HITL Workflows
## Setting Up HITL Workflows
<Steps>
<Step title="Configure Your Task">

View File

@@ -309,10 +309,6 @@ Ao executar esse Flow, a saída será diferente dependendo do valor booleano ale
### Human in the Loop (human feedback)
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**.
</Note>
The `@human_feedback` decorator enables human-in-the-loop workflows by pausing flow execution to collect feedback from a human. This is useful for approval gates, quality review, and decision points that require human judgment.
```python Code

View File

@@ -7,10 +7,6 @@ mode: "wide"
## Overview
<Note>
The `@human_feedback` decorator requires **CrewAI version 1.8.0 or higher**. Make sure to update your installation before using this feature.
</Note>
The `@human_feedback` decorator enables human-in-the-loop (HITL) workflows directly within CrewAI Flows. It allows you to pause flow execution, present output to a human for review, collect their feedback, and optionally route to different listeners based on the feedback outcome.
This is particularly valuable for:

View File

@@ -5,22 +5,9 @@ icon: "user-check"
mode: "wide"
---
Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. CrewAI offers several ways to implement HITL depending on your needs.
Human-in-the-Loop (HITL) is a powerful approach that combines artificial intelligence with human expertise to enhance decision-making and improve task outcomes. This guide shows you how to implement HITL within CrewAI.
## Choosing Your HITL Approach
CrewAI offers two main approaches for implementing human-in-the-loop workflows:
| Approach | Best For | Integration | Version |
|----------|----------|-------------|---------|
| **Flow-based** (`@human_feedback` decorator) | Local development, console-based review, synchronous workflows | [Human Feedback in Flows](/pt-BR/learn/human-feedback-in-flows) | **1.8.0+** |
| **Webhook-based** (Enterprise) | Production deployments, async workflows, external integrations (Slack, Teams, etc.) | This guide | - |
<Tip>
If you're building flows and want to add human review steps with routing based on feedback, check out the [Human Feedback in Flows](/pt-BR/learn/human-feedback-in-flows) guide for the `@human_feedback` decorator.
</Tip>
## Setting Up Webhook-Based HITL Workflows
## Setting Up HITL Workflows
<Steps>
<Step title="Configure Your Task">

View File

@@ -209,9 +209,10 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(TaskCompletedEvent)
def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
# Handle telemetry
span = self.execution_spans.pop(source, None)
span = self.execution_spans.get(source)
if span:
self._telemetry.task_ended(span, source, source.agent.crew)
self.execution_spans[source] = None
# Pass task name if it exists
task_name = get_task_name(source)
@@ -221,10 +222,11 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(TaskFailedEvent)
def on_task_failed(source: Any, event: TaskFailedEvent) -> None:
span = self.execution_spans.pop(source, None)
span = self.execution_spans.get(source)
if span:
if source.agent and source.agent.crew:
self._telemetry.task_ended(span, source, source.agent.crew)
self.execution_spans[source] = None
# Pass task name if it exists
task_name = get_task_name(source)

View File

@@ -341,6 +341,7 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
completion_cost: float | None = None
_callback_lock: threading.RLock = threading.RLock()
def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
"""Factory method that routes to native SDK or falls back to LiteLLM.
@@ -1144,7 +1145,7 @@ class LLM(BaseLLM):
if response_model:
params["response_model"] = response_model
response = litellm.completion(**params)
if hasattr(response,"usage") and not isinstance(response.usage, type) and response.usage:
usage_info = response.usage
self._track_token_usage_internal(usage_info)
@@ -1363,7 +1364,7 @@ class LLM(BaseLLM):
"""
full_response = ""
chunk_count = 0
usage_info = None
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
@@ -1657,78 +1658,92 @@ class LLM(BaseLLM):
raise ValueError("LLM call blocked by before_llm_call hook")
# --- 5) Set up callbacks if provided
# Use a class-level lock to synchronize access to global litellm callbacks.
# This prevents race conditions when multiple LLM instances call set_callbacks
# concurrently, which could cause callbacks to be removed before they fire.
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# --- 6) Prepare parameters for the completion call
params = self._prepare_completion_params(messages, tools)
# --- 7) Make the completion call and handle response
if self.stream:
result = self._handle_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
else:
result = self._handle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
if isinstance(result, str):
result = self._invoke_after_llm_call_hooks(
messages, result, from_agent
)
return result
except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededError as it should be handled
# by the CrewAgentExecutor._invoke_loop method, which can then decide
# whether to summarize the content or abort based on the respect_context_window flag
raise
except Exception as e:
unsupported_stop = "Unsupported parameter" in str(
e
) and "'stop'" in str(e)
if unsupported_stop:
if (
"additional_drop_params" in self.additional_params
and isinstance(
self.additional_params["additional_drop_params"], list
with LLM._callback_lock:
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# --- 6) Prepare parameters for the completion call
params = self._prepare_completion_params(messages, tools)
# --- 7) Make the completion call and handle response
if self.stream:
result = self._handle_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
):
self.additional_params["additional_drop_params"].append("stop")
else:
self.additional_params = {"additional_drop_params": ["stop"]}
result = self._handle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
logging.info("Retrying LLM call without the unsupported 'stop'")
if isinstance(result, str):
result = self._invoke_after_llm_call_hooks(
messages, result, from_agent
)
return self.call(
messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
return result
except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededError as it should be handled
# by the CrewAgentExecutor._invoke_loop method, which can then decide
# whether to summarize the content or abort based on the respect_context_window flag
raise
except Exception as e:
unsupported_stop = "Unsupported parameter" in str(
e
) and "'stop'" in str(e)
if unsupported_stop:
if (
"additional_drop_params" in self.additional_params
and isinstance(
self.additional_params["additional_drop_params"], list
)
):
self.additional_params["additional_drop_params"].append(
"stop"
)
else:
self.additional_params = {
"additional_drop_params": ["stop"]
}
logging.info(
"Retrying LLM call without the unsupported 'stop'"
)
# Recursive call happens inside the lock since we're using
# a reentrant-safe pattern (the lock is released when we
# exit the with block, and the recursive call will acquire
# it again)
return self.call(
messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
),
)
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
),
)
raise
raise
async def acall(
self,
@@ -1790,14 +1805,27 @@ class LLM(BaseLLM):
msg_role: Literal["assistant"] = "assistant"
message["role"] = msg_role
# Use a class-level lock to synchronize access to global litellm callbacks.
# This prevents race conditions when multiple LLM instances call set_callbacks
# concurrently, which could cause callbacks to be removed before they fire.
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
params = self._prepare_completion_params(messages, tools)
with LLM._callback_lock:
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
params = self._prepare_completion_params(messages, tools)
if self.stream:
return await self._ahandle_streaming_response(
if self.stream:
return await self._ahandle_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
return await self._ahandle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
@@ -1805,52 +1833,49 @@ class LLM(BaseLLM):
from_agent=from_agent,
response_model=response_model,
)
except LLMContextLengthExceededError:
raise
except Exception as e:
unsupported_stop = "Unsupported parameter" in str(
e
) and "'stop'" in str(e)
return await self._ahandle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
except LLMContextLengthExceededError:
raise
except Exception as e:
unsupported_stop = "Unsupported parameter" in str(
e
) and "'stop'" in str(e)
if unsupported_stop:
if (
"additional_drop_params" in self.additional_params
and isinstance(
self.additional_params["additional_drop_params"], list
)
):
self.additional_params["additional_drop_params"].append(
"stop"
)
else:
self.additional_params = {
"additional_drop_params": ["stop"]
}
if unsupported_stop:
if (
"additional_drop_params" in self.additional_params
and isinstance(
self.additional_params["additional_drop_params"], list
logging.info(
"Retrying LLM call without the unsupported 'stop'"
)
):
self.additional_params["additional_drop_params"].append("stop")
else:
self.additional_params = {"additional_drop_params": ["stop"]}
logging.info("Retrying LLM call without the unsupported 'stop'")
return await self.acall(
messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
return await self.acall(
messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
),
)
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
),
)
raise
raise
def _handle_emit_call_events(
self,

View File

@@ -1,12 +1,9 @@
from crewai.tools.base_tool import BaseTool, EnvVar, tool
from crewai.tools.tool_search_tool import SearchStrategy, ToolSearchTool
__all__ = [
"BaseTool",
"EnvVar",
"SearchStrategy",
"ToolSearchTool",
"tool",
]

View File

@@ -1,333 +0,0 @@
"""Tool Search Tool for on-demand tool discovery.
This module implements a Tool Search Tool that allows agents to dynamically
discover and load tools on-demand, reducing token consumption when working
with large tool libraries.
Inspired by Anthropic's Tool Search Tool approach for on-demand tool loading.
"""
from __future__ import annotations
from collections.abc import Callable, Sequence
from enum import Enum
import json
import re
from typing import Any
from pydantic import BaseModel, Field
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities.pydantic_schema_utils import generate_model_description
class SearchStrategy(str, Enum):
"""Search strategy for tool discovery."""
KEYWORD = "keyword"
REGEX = "regex"
class ToolSearchResult(BaseModel):
"""Result from a tool search operation."""
name: str = Field(description="The name of the tool")
description: str = Field(description="The description of the tool")
args_schema: dict[str, Any] = Field(
description="The JSON schema for the tool's arguments"
)
class ToolSearchToolSchema(BaseModel):
"""Schema for the Tool Search Tool arguments."""
query: str = Field(
description="The search query to find relevant tools. Use keywords that describe the capability you need."
)
max_results: int = Field(
default=5,
description="Maximum number of tools to return. Default is 5.",
ge=1,
le=20,
)
class ToolSearchTool(BaseTool):
"""A tool that searches through a catalog of tools to find relevant ones.
This tool enables on-demand tool discovery, allowing agents to work with
large tool libraries without loading all tool definitions upfront. Instead
of consuming tokens with all tool definitions, the agent can search for
relevant tools when needed.
Example:
```python
from crewai.tools import BaseTool, ToolSearchTool
# Create your tools
search_tool = MySearchTool()
scrape_tool = MyScrapeWebsiteTool()
database_tool = MyDatabaseTool()
# Create a tool search tool with your tool catalog
tool_search = ToolSearchTool(
tool_catalog=[search_tool, scrape_tool, database_tool],
search_strategy=SearchStrategy.KEYWORD,
)
# Use with an agent - only the tool_search is loaded initially
agent = Agent(
role="Researcher",
tools=[tool_search], # Other tools discovered on-demand
)
```
Attributes:
tool_catalog: List of tools available for search.
search_strategy: Strategy to use for searching (keyword or regex).
custom_search_fn: Optional custom search function for advanced matching.
"""
name: str = Field(
default="Tool Search",
description="The name of the tool search tool.",
)
description: str = Field(
default="Search for available tools by describing the capability you need. Returns tool definitions that match your query.",
description="Description of what the tool search tool does.",
)
args_schema: type[BaseModel] = Field(
default=ToolSearchToolSchema,
description="The schema for the tool search arguments.",
)
tool_catalog: list[BaseTool | CrewStructuredTool] = Field(
default_factory=list,
description="List of tools available for search.",
)
search_strategy: SearchStrategy = Field(
default=SearchStrategy.KEYWORD,
description="Strategy to use for searching tools.",
)
custom_search_fn: Callable[
[str, Sequence[BaseTool | CrewStructuredTool]], list[BaseTool | CrewStructuredTool]
] | None = Field(
default=None,
description="Optional custom search function for advanced matching.",
)
def _run(self, query: str, max_results: int = 5) -> str:
"""Search for tools matching the query.
Args:
query: The search query to find relevant tools.
max_results: Maximum number of tools to return.
Returns:
JSON string containing the matching tool definitions.
"""
if not self.tool_catalog:
return json.dumps(
{
"status": "error",
"message": "No tools available in the catalog.",
"tools": [],
}
)
if self.custom_search_fn:
matching_tools = self.custom_search_fn(query, self.tool_catalog)
elif self.search_strategy == SearchStrategy.REGEX:
matching_tools = self._regex_search(query)
else:
matching_tools = self._keyword_search(query)
matching_tools = matching_tools[:max_results]
if not matching_tools:
return json.dumps(
{
"status": "no_results",
"message": f"No tools found matching query: '{query}'. Try different keywords.",
"tools": [],
}
)
tool_results = []
for tool in matching_tools:
tool_info = self._get_tool_info(tool)
tool_results.append(tool_info)
return json.dumps(
{
"status": "success",
"message": f"Found {len(tool_results)} tool(s) matching your query.",
"tools": tool_results,
},
indent=2,
)
def _keyword_search(
self, query: str
) -> list[BaseTool | CrewStructuredTool]:
"""Search tools using keyword matching.
Args:
query: The search query.
Returns:
List of matching tools sorted by relevance.
"""
query_lower = query.lower()
query_words = set(query_lower.split())
scored_tools: list[tuple[float, BaseTool | CrewStructuredTool]] = []
for tool in self.tool_catalog:
score = self._calculate_keyword_score(tool, query_lower, query_words)
if score > 0:
scored_tools.append((score, tool))
scored_tools.sort(key=lambda x: x[0], reverse=True)
return [tool for _, tool in scored_tools]
def _calculate_keyword_score(
self,
tool: BaseTool | CrewStructuredTool,
query_lower: str,
query_words: set[str],
) -> float:
"""Calculate relevance score for a tool based on keyword matching.
Args:
tool: The tool to score.
query_lower: Lowercase query string.
query_words: Set of query words.
Returns:
Relevance score (higher is better).
"""
score = 0.0
tool_name_lower = tool.name.lower()
tool_desc_lower = tool.description.lower()
if query_lower in tool_name_lower:
score += 10.0
if query_lower in tool_desc_lower:
score += 5.0
for word in query_words:
if len(word) < 2:
continue
if word in tool_name_lower:
score += 3.0
if word in tool_desc_lower:
score += 1.0
return score
def _regex_search(
self, query: str
) -> list[BaseTool | CrewStructuredTool]:
"""Search tools using regex pattern matching.
Args:
query: The regex pattern to search for.
Returns:
List of matching tools.
"""
try:
pattern = re.compile(query, re.IGNORECASE)
except re.error:
pattern = re.compile(re.escape(query), re.IGNORECASE)
return [
tool
for tool in self.tool_catalog
if pattern.search(tool.name) or pattern.search(tool.description)
]
def _get_tool_info(self, tool: BaseTool | CrewStructuredTool) -> dict[str, Any]:
"""Get tool information as a dictionary.
Args:
tool: The tool to get information from.
Returns:
Dictionary containing tool name, description, and args schema.
"""
if isinstance(tool, BaseTool):
schema_dict = generate_model_description(tool.args_schema)
args_schema = schema_dict.get("json_schema", {}).get("schema", {})
else:
args_schema = tool.args_schema.model_json_schema()
return {
"name": tool.name,
"description": self._get_original_description(tool),
"args_schema": args_schema,
}
def _get_original_description(self, tool: BaseTool | CrewStructuredTool) -> str:
"""Get the original description of a tool without the generated schema.
Args:
tool: The tool to get the description from.
Returns:
The original tool description.
"""
description = tool.description
if "Tool Description:" in description:
parts = description.split("Tool Description:")
if len(parts) > 1:
return parts[1].strip()
return description
def add_tool(self, tool: BaseTool | CrewStructuredTool) -> None:
"""Add a tool to the catalog.
Args:
tool: The tool to add.
"""
self.tool_catalog.append(tool)
def add_tools(self, tools: Sequence[BaseTool | CrewStructuredTool]) -> None:
"""Add multiple tools to the catalog.
Args:
tools: The tools to add.
"""
self.tool_catalog.extend(tools)
def remove_tool(self, tool_name: str) -> bool:
"""Remove a tool from the catalog by name.
Args:
tool_name: The name of the tool to remove.
Returns:
True if the tool was removed, False if not found.
"""
for i, tool in enumerate(self.tool_catalog):
if tool.name == tool_name:
self.tool_catalog.pop(i)
return True
return False
def get_catalog_size(self) -> int:
"""Get the number of tools in the catalog.
Returns:
The number of tools in the catalog.
"""
return len(self.tool_catalog)
def list_tool_names(self) -> list[str]:
"""List all tool names in the catalog.
Returns:
List of tool names.
"""
return [tool.name for tool in self.tool_catalog]

View File

@@ -1,6 +1,6 @@
import logging
import os
from time import sleep
import threading
from unittest.mock import MagicMock, patch
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -18,9 +18,15 @@ from pydantic import BaseModel
import pytest
# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
@pytest.mark.vcr()
def test_llm_callback_replacement():
"""Test that callbacks are properly isolated between LLM instances.
This test verifies that the race condition fix (using _callback_lock) works
correctly. Previously, this test required a sleep(5) workaround because
callbacks were being modified globally without synchronization, causing
one LLM instance's callbacks to interfere with another's.
"""
llm1 = LLM(model="gpt-4o-mini", is_litellm=True)
llm2 = LLM(model="gpt-4o-mini", is_litellm=True)
@@ -37,7 +43,6 @@ def test_llm_callback_replacement():
messages=[{"role": "user", "content": "Hello, world from another agent!"}],
callbacks=[calc_handler_2],
)
sleep(5)
usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
# The first handler should not have been updated
@@ -46,6 +51,66 @@ def test_llm_callback_replacement():
assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()
def test_llm_callback_lock_prevents_race_condition():
"""Test that the _callback_lock prevents race conditions in concurrent LLM calls.
This test verifies that multiple threads can safely call LLM.call() with
different callbacks without interfering with each other. The lock ensures
that callbacks are properly isolated between concurrent calls.
"""
num_threads = 5
results: list[int] = []
errors: list[Exception] = []
lock = threading.Lock()
def make_llm_call(thread_id: int, mock_completion: MagicMock) -> None:
try:
llm = LLM(model="gpt-4o-mini", is_litellm=True)
calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
mock_message = MagicMock()
mock_message.content = f"Response from thread {thread_id}"
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_response.usage = {
"prompt_tokens": 10,
"completion_tokens": 10,
"total_tokens": 20,
}
mock_completion.return_value = mock_response
llm.call(
messages=[{"role": "user", "content": f"Hello from thread {thread_id}"}],
callbacks=[calc_handler],
)
usage = calc_handler.token_cost_process.get_summary()
with lock:
results.append(usage.successful_requests)
except Exception as e:
with lock:
errors.append(e)
with patch("litellm.completion") as mock_completion:
threads = [
threading.Thread(target=make_llm_call, args=(i, mock_completion))
for i in range(num_threads)
]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(errors) == 0, f"Errors occurred: {errors}"
assert len(results) == num_threads
assert all(
r == 1 for r in results
), f"Expected all callbacks to have 1 successful request, got {results}"
@pytest.mark.vcr()
def test_llm_call_with_string_input():
llm = LLM(model="gpt-4o-mini")
@@ -989,4 +1054,4 @@ async def test_usage_info_streaming_with_acall():
assert llm._token_usage["completion_tokens"] > 0
assert llm._token_usage["total_tokens"] > 0
assert len(result) > 0
assert len(result) > 0

View File

@@ -1,393 +0,0 @@
"""Tests for the ToolSearchTool functionality."""
import json
import pytest
from pydantic import BaseModel
from crewai.tools import BaseTool, SearchStrategy, ToolSearchTool
class MockSearchTool(BaseTool):
"""A mock search tool for testing."""
name: str = "Web Search"
description: str = "Search the web for information on any topic."
def _run(self, query: str) -> str:
return f"Search results for: {query}"
class MockDatabaseTool(BaseTool):
"""A mock database tool for testing."""
name: str = "Database Query"
description: str = "Query a SQL database to retrieve data."
def _run(self, query: str) -> str:
return f"Database results for: {query}"
class MockScrapeTool(BaseTool):
"""A mock web scraping tool for testing."""
name: str = "Web Scraper"
description: str = "Scrape content from websites and extract text."
def _run(self, url: str) -> str:
return f"Scraped content from: {url}"
class MockEmailTool(BaseTool):
"""A mock email tool for testing."""
name: str = "Send Email"
description: str = "Send an email to a specified recipient."
def _run(self, to: str, subject: str, body: str) -> str:
return f"Email sent to {to}"
class MockCalculatorTool(BaseTool):
"""A mock calculator tool for testing."""
name: str = "Calculator"
description: str = "Perform mathematical calculations and arithmetic operations."
def _run(self, expression: str) -> str:
return f"Result: {eval(expression)}"
@pytest.fixture
def sample_tools() -> list[BaseTool]:
"""Create a list of sample tools for testing."""
return [
MockSearchTool(),
MockDatabaseTool(),
MockScrapeTool(),
MockEmailTool(),
MockCalculatorTool(),
]
@pytest.fixture
def tool_search(sample_tools: list[BaseTool]) -> ToolSearchTool:
"""Create a ToolSearchTool with sample tools."""
return ToolSearchTool(tool_catalog=sample_tools)
class TestToolSearchToolCreation:
"""Tests for ToolSearchTool creation and initialization."""
def test_create_tool_search_with_empty_catalog(self) -> None:
"""Test creating a ToolSearchTool with an empty catalog."""
tool_search = ToolSearchTool()
assert tool_search.name == "Tool Search"
assert tool_search.tool_catalog == []
assert tool_search.search_strategy == SearchStrategy.KEYWORD
def test_create_tool_search_with_tools(self, sample_tools: list[BaseTool]) -> None:
"""Test creating a ToolSearchTool with a list of tools."""
tool_search = ToolSearchTool(tool_catalog=sample_tools)
assert len(tool_search.tool_catalog) == 5
assert tool_search.get_catalog_size() == 5
def test_create_tool_search_with_regex_strategy(
self, sample_tools: list[BaseTool]
) -> None:
"""Test creating a ToolSearchTool with regex search strategy."""
tool_search = ToolSearchTool(
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
)
assert tool_search.search_strategy == SearchStrategy.REGEX
def test_create_tool_search_with_custom_name(self) -> None:
"""Test creating a ToolSearchTool with a custom name."""
tool_search = ToolSearchTool(name="My Tool Finder")
assert tool_search.name == "My Tool Finder"
class TestToolSearchKeywordSearch:
"""Tests for keyword-based tool search."""
def test_search_by_exact_name(self, tool_search: ToolSearchTool) -> None:
"""Test searching for a tool by its exact name."""
result = tool_search._run("Web Search")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) >= 1
assert result_data["tools"][0]["name"] == "Web Search"
def test_search_by_partial_name(self, tool_search: ToolSearchTool) -> None:
"""Test searching for a tool by partial name."""
result = tool_search._run("Search")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) >= 1
tool_names = [t["name"] for t in result_data["tools"]]
assert "Web Search" in tool_names
def test_search_by_description_keyword(self, tool_search: ToolSearchTool) -> None:
"""Test searching for a tool by keyword in description."""
result = tool_search._run("database")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) >= 1
tool_names = [t["name"] for t in result_data["tools"]]
assert "Database Query" in tool_names
def test_search_with_multiple_keywords(self, tool_search: ToolSearchTool) -> None:
"""Test searching with multiple keywords."""
result = tool_search._run("web scrape content")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) >= 1
tool_names = [t["name"] for t in result_data["tools"]]
assert "Web Scraper" in tool_names
def test_search_no_results(self, tool_search: ToolSearchTool) -> None:
"""Test searching with a query that returns no results."""
result = tool_search._run("xyznonexistent123abc")
result_data = json.loads(result)
assert result_data["status"] == "no_results"
assert len(result_data["tools"]) == 0
def test_search_max_results_limit(self, tool_search: ToolSearchTool) -> None:
"""Test that max_results limits the number of returned tools."""
result = tool_search._run("tool", max_results=2)
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) <= 2
def test_search_empty_catalog(self) -> None:
"""Test searching with an empty tool catalog."""
tool_search = ToolSearchTool()
result = tool_search._run("search")
result_data = json.loads(result)
assert result_data["status"] == "error"
assert "No tools available" in result_data["message"]
class TestToolSearchRegexSearch:
"""Tests for regex-based tool search."""
def test_regex_search_simple_pattern(
self, sample_tools: list[BaseTool]
) -> None:
"""Test regex search with a simple pattern."""
tool_search = ToolSearchTool(
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
)
result = tool_search._run("Web.*")
result_data = json.loads(result)
assert result_data["status"] == "success"
tool_names = [t["name"] for t in result_data["tools"]]
assert "Web Search" in tool_names or "Web Scraper" in tool_names
def test_regex_search_case_insensitive(
self, sample_tools: list[BaseTool]
) -> None:
"""Test that regex search is case insensitive."""
tool_search = ToolSearchTool(
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
)
result = tool_search._run("email")
result_data = json.loads(result)
assert result_data["status"] == "success"
tool_names = [t["name"] for t in result_data["tools"]]
assert "Send Email" in tool_names
def test_regex_search_invalid_pattern_fallback(
self, sample_tools: list[BaseTool]
) -> None:
"""Test that invalid regex patterns are escaped and still work."""
tool_search = ToolSearchTool(
tool_catalog=sample_tools, search_strategy=SearchStrategy.REGEX
)
result = tool_search._run("[invalid(regex")
result_data = json.loads(result)
assert result_data["status"] in ["success", "no_results"]
class TestToolSearchCustomSearch:
"""Tests for custom search function."""
def test_custom_search_function(self, sample_tools: list[BaseTool]) -> None:
"""Test using a custom search function."""
def custom_search(
query: str, tools: list[BaseTool]
) -> list[BaseTool]:
return [t for t in tools if "email" in t.name.lower()]
tool_search = ToolSearchTool(
tool_catalog=sample_tools, custom_search_fn=custom_search
)
result = tool_search._run("anything")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert len(result_data["tools"]) == 1
assert result_data["tools"][0]["name"] == "Send Email"
class TestToolSearchCatalogManagement:
"""Tests for tool catalog management."""
def test_add_tool(self, tool_search: ToolSearchTool) -> None:
"""Test adding a tool to the catalog."""
initial_size = tool_search.get_catalog_size()
class NewTool(BaseTool):
name: str = "New Tool"
description: str = "A new tool for testing."
def _run(self) -> str:
return "New tool result"
tool_search.add_tool(NewTool())
assert tool_search.get_catalog_size() == initial_size + 1
def test_add_tools(self, tool_search: ToolSearchTool) -> None:
"""Test adding multiple tools to the catalog."""
initial_size = tool_search.get_catalog_size()
class NewTool1(BaseTool):
name: str = "New Tool 1"
description: str = "First new tool."
def _run(self) -> str:
return "Result 1"
class NewTool2(BaseTool):
name: str = "New Tool 2"
description: str = "Second new tool."
def _run(self) -> str:
return "Result 2"
tool_search.add_tools([NewTool1(), NewTool2()])
assert tool_search.get_catalog_size() == initial_size + 2
def test_remove_tool(self, tool_search: ToolSearchTool) -> None:
"""Test removing a tool from the catalog."""
initial_size = tool_search.get_catalog_size()
result = tool_search.remove_tool("Web Search")
assert result is True
assert tool_search.get_catalog_size() == initial_size - 1
def test_remove_nonexistent_tool(self, tool_search: ToolSearchTool) -> None:
"""Test removing a tool that doesn't exist."""
initial_size = tool_search.get_catalog_size()
result = tool_search.remove_tool("Nonexistent Tool")
assert result is False
assert tool_search.get_catalog_size() == initial_size
def test_list_tool_names(self, tool_search: ToolSearchTool) -> None:
"""Test listing all tool names in the catalog."""
names = tool_search.list_tool_names()
assert len(names) == 5
assert "Web Search" in names
assert "Database Query" in names
assert "Web Scraper" in names
assert "Send Email" in names
assert "Calculator" in names
class TestToolSearchResultFormat:
"""Tests for the format of search results."""
def test_result_contains_tool_info(self, tool_search: ToolSearchTool) -> None:
"""Test that search results contain complete tool information."""
result = tool_search._run("Calculator")
result_data = json.loads(result)
assert result_data["status"] == "success"
tool_info = result_data["tools"][0]
assert "name" in tool_info
assert "description" in tool_info
assert "args_schema" in tool_info
assert tool_info["name"] == "Calculator"
def test_result_args_schema_format(self, tool_search: ToolSearchTool) -> None:
"""Test that args_schema is properly formatted."""
result = tool_search._run("Email")
result_data = json.loads(result)
assert result_data["status"] == "success"
tool_info = result_data["tools"][0]
assert "args_schema" in tool_info
args_schema = tool_info["args_schema"]
assert isinstance(args_schema, dict)
class TestToolSearchIntegration:
"""Integration tests for ToolSearchTool."""
def test_tool_search_as_base_tool(self, sample_tools: list[BaseTool]) -> None:
"""Test that ToolSearchTool works as a BaseTool."""
tool_search = ToolSearchTool(tool_catalog=sample_tools)
assert isinstance(tool_search, BaseTool)
assert tool_search.name == "Tool Search"
assert "search" in tool_search.description.lower()
def test_tool_search_to_structured_tool(
self, sample_tools: list[BaseTool]
) -> None:
"""Test converting ToolSearchTool to structured tool."""
tool_search = ToolSearchTool(tool_catalog=sample_tools)
structured = tool_search.to_structured_tool()
assert structured.name == "Tool Search"
assert structured.args_schema is not None
def test_tool_search_run_method(self, tool_search: ToolSearchTool) -> None:
"""Test the run method of ToolSearchTool."""
result = tool_search.run(query="search", max_results=3)
assert isinstance(result, str)
result_data = json.loads(result)
assert "status" in result_data
assert "tools" in result_data
class TestToolSearchScoring:
"""Tests for the keyword scoring algorithm."""
def test_exact_name_match_scores_highest(
self, sample_tools: list[BaseTool]
) -> None:
"""Test that exact name matches score higher than partial matches."""
tool_search = ToolSearchTool(tool_catalog=sample_tools)
result = tool_search._run("Web Search")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert result_data["tools"][0]["name"] == "Web Search"
def test_name_match_scores_higher_than_description(
self, sample_tools: list[BaseTool]
) -> None:
"""Test that name matches score higher than description matches."""
tool_search = ToolSearchTool(tool_catalog=sample_tools)
result = tool_search._run("Calculator")
result_data = json.loads(result)
assert result_data["status"] == "success"
assert result_data["tools"][0]["name"] == "Calculator"