fix: make BaseTool usage count thread-safe for parallel step execution

Add _usage_lock and _claim_usage() to BaseTool for atomic
check-and-increment of current_usage_count. This prevents race
conditions when parallel plan steps invoke the same tool concurrently
via execute_todos_parallel. Remove the racy pre-check from
execute_single_native_tool_call since the limit is now enforced
atomically inside tool.run().
This commit is contained in:
Greyson LaLonde
2026-03-15 16:12:55 -04:00
parent 02f5d514f8
commit 8e5e17cdeb
2 changed files with 43 additions and 24 deletions

View File

@@ -5,6 +5,7 @@ import asyncio
from collections.abc import Awaitable, Callable
from inspect import Parameter, signature
import json
import threading
from typing import (
Any,
Generic,
@@ -18,6 +19,7 @@ from pydantic import (
BaseModel as PydanticBaseModel,
ConfigDict,
Field,
PrivateAttr,
create_model,
field_validator,
)
@@ -94,6 +96,7 @@ class BaseTool(BaseModel, ABC):
default=0,
description="Current number of times this tool has been used.",
)
_usage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
@field_validator("args_schema", mode="before")
@classmethod
@@ -173,6 +176,25 @@ class BaseTool(BaseModel, ABC):
) from e
return kwargs
def _claim_usage(self) -> str | None:
"""Atomically check max usage and increment the counter.
Returns:
None if usage was claimed successfully, or an error message
string if the tool has reached its usage limit.
"""
with self._usage_lock:
if (
self.max_usage_count is not None
and self.current_usage_count >= self.max_usage_count
):
return (
f"Tool '{self.name}' has reached its usage limit of "
f"{self.max_usage_count} times and cannot be used anymore."
)
self.current_usage_count += 1
return None
def run(
self,
*args: Any,
@@ -181,13 +203,15 @@ class BaseTool(BaseModel, ABC):
if not args:
kwargs = self._validate_kwargs(kwargs)
limit_error = self._claim_usage()
if limit_error:
return limit_error
result = self._run(*args, **kwargs)
if asyncio.iscoroutine(result):
result = asyncio.run(result)
self.current_usage_count += 1
return result
async def arun(
@@ -206,9 +230,12 @@ class BaseTool(BaseModel, ABC):
"""
if not args:
kwargs = self._validate_kwargs(kwargs)
result = await self._arun(*args, **kwargs)
self.current_usage_count += 1
return result
limit_error = self._claim_usage()
if limit_error:
return limit_error
return await self._arun(*args, **kwargs)
async def _arun(
self,
@@ -361,12 +388,15 @@ class Tool(BaseTool, Generic[P, R]):
if not args:
kwargs = self._validate_kwargs(kwargs) # type: ignore[assignment]
limit_error = self._claim_usage()
if limit_error:
return limit_error # type: ignore[return-value]
result = self.func(*args, **kwargs)
if asyncio.iscoroutine(result):
result = asyncio.run(result)
self.current_usage_count += 1
return result # type: ignore[return-value]
def _run(self, *args: P.args, **kwargs: P.kwargs) -> R:
@@ -393,9 +423,12 @@ class Tool(BaseTool, Generic[P, R]):
"""
if not args:
kwargs = self._validate_kwargs(kwargs) # type: ignore[assignment]
result = await self._arun(*args, **kwargs)
self.current_usage_count += 1
return result
limit_error = self._claim_usage()
if limit_error:
return limit_error # type: ignore[return-value]
return await self._arun(*args, **kwargs)
async def _arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the wrapped function asynchronously.

View File

@@ -1429,15 +1429,6 @@ def execute_single_native_tool_call(
original_tool = tool
break
# Check max usage count
max_usage_reached = False
if (
original_tool
and original_tool.max_usage_count is not None
and original_tool.current_usage_count >= original_tool.max_usage_count
):
max_usage_reached = True
# Check cache
from_cache = False
input_str = json.dumps(args_dict) if args_dict else ""
@@ -1496,7 +1487,7 @@ def execute_single_native_tool_call(
error_event_emitted = False
if hook_blocked:
result = f"Tool execution blocked by hook. Tool: {func_name}"
elif not from_cache and not max_usage_reached:
elif not from_cache:
if func_name in available_functions:
try:
tool_func = available_functions[func_name]
@@ -1533,11 +1524,6 @@ def execute_single_native_tool_call(
),
)
error_event_emitted = True
elif max_usage_reached and original_tool:
result = (
f"Tool '{func_name}' has reached its usage limit of "
f"{original_tool.max_usage_count} times and cannot be used anymore."
)
# After hooks
after_hook_context = ToolCallHookContext(