mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 13:48:09 +00:00
fix: make BaseTool usage count thread-safe for parallel step execution
Add _usage_lock and _claim_usage() to BaseTool for atomic check-and-increment of current_usage_count. This prevents race conditions when parallel plan steps invoke the same tool concurrently via execute_todos_parallel. Remove the racy pre-check from execute_single_native_tool_call since the limit is now enforced atomically inside tool.run().
This commit is contained in:
@@ -5,6 +5,7 @@ import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from inspect import Parameter, signature
|
||||
import json
|
||||
import threading
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
@@ -18,6 +19,7 @@ from pydantic import (
|
||||
BaseModel as PydanticBaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
create_model,
|
||||
field_validator,
|
||||
)
|
||||
@@ -94,6 +96,7 @@ class BaseTool(BaseModel, ABC):
|
||||
default=0,
|
||||
description="Current number of times this tool has been used.",
|
||||
)
|
||||
_usage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
|
||||
@field_validator("args_schema", mode="before")
|
||||
@classmethod
|
||||
@@ -173,6 +176,25 @@ class BaseTool(BaseModel, ABC):
|
||||
) from e
|
||||
return kwargs
|
||||
|
||||
def _claim_usage(self) -> str | None:
|
||||
"""Atomically check max usage and increment the counter.
|
||||
|
||||
Returns:
|
||||
None if usage was claimed successfully, or an error message
|
||||
string if the tool has reached its usage limit.
|
||||
"""
|
||||
with self._usage_lock:
|
||||
if (
|
||||
self.max_usage_count is not None
|
||||
and self.current_usage_count >= self.max_usage_count
|
||||
):
|
||||
return (
|
||||
f"Tool '{self.name}' has reached its usage limit of "
|
||||
f"{self.max_usage_count} times and cannot be used anymore."
|
||||
)
|
||||
self.current_usage_count += 1
|
||||
return None
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -181,13 +203,15 @@ class BaseTool(BaseModel, ABC):
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs)
|
||||
|
||||
limit_error = self._claim_usage()
|
||||
if limit_error:
|
||||
return limit_error
|
||||
|
||||
result = self._run(*args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.run(result)
|
||||
|
||||
self.current_usage_count += 1
|
||||
|
||||
return result
|
||||
|
||||
async def arun(
|
||||
@@ -206,9 +230,12 @@ class BaseTool(BaseModel, ABC):
|
||||
"""
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs)
|
||||
result = await self._arun(*args, **kwargs)
|
||||
self.current_usage_count += 1
|
||||
return result
|
||||
|
||||
limit_error = self._claim_usage()
|
||||
if limit_error:
|
||||
return limit_error
|
||||
|
||||
return await self._arun(*args, **kwargs)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
@@ -361,12 +388,15 @@ class Tool(BaseTool, Generic[P, R]):
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs) # type: ignore[assignment]
|
||||
|
||||
limit_error = self._claim_usage()
|
||||
if limit_error:
|
||||
return limit_error # type: ignore[return-value]
|
||||
|
||||
result = self.func(*args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.run(result)
|
||||
|
||||
self.current_usage_count += 1
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
def _run(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
@@ -393,9 +423,12 @@ class Tool(BaseTool, Generic[P, R]):
|
||||
"""
|
||||
if not args:
|
||||
kwargs = self._validate_kwargs(kwargs) # type: ignore[assignment]
|
||||
result = await self._arun(*args, **kwargs)
|
||||
self.current_usage_count += 1
|
||||
return result
|
||||
|
||||
limit_error = self._claim_usage()
|
||||
if limit_error:
|
||||
return limit_error # type: ignore[return-value]
|
||||
|
||||
return await self._arun(*args, **kwargs)
|
||||
|
||||
async def _arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Executes the wrapped function asynchronously.
|
||||
|
||||
@@ -1429,15 +1429,6 @@ def execute_single_native_tool_call(
|
||||
original_tool = tool
|
||||
break
|
||||
|
||||
# Check max usage count
|
||||
max_usage_reached = False
|
||||
if (
|
||||
original_tool
|
||||
and original_tool.max_usage_count is not None
|
||||
and original_tool.current_usage_count >= original_tool.max_usage_count
|
||||
):
|
||||
max_usage_reached = True
|
||||
|
||||
# Check cache
|
||||
from_cache = False
|
||||
input_str = json.dumps(args_dict) if args_dict else ""
|
||||
@@ -1496,7 +1487,7 @@ def execute_single_native_tool_call(
|
||||
error_event_emitted = False
|
||||
if hook_blocked:
|
||||
result = f"Tool execution blocked by hook. Tool: {func_name}"
|
||||
elif not from_cache and not max_usage_reached:
|
||||
elif not from_cache:
|
||||
if func_name in available_functions:
|
||||
try:
|
||||
tool_func = available_functions[func_name]
|
||||
@@ -1533,11 +1524,6 @@ def execute_single_native_tool_call(
|
||||
),
|
||||
)
|
||||
error_event_emitted = True
|
||||
elif max_usage_reached and original_tool:
|
||||
result = (
|
||||
f"Tool '{func_name}' has reached its usage limit of "
|
||||
f"{original_tool.max_usage_count} times and cannot be used anymore."
|
||||
)
|
||||
|
||||
# After hooks
|
||||
after_hook_context = ToolCallHookContext(
|
||||
|
||||
Reference in New Issue
Block a user