ensure we respect max_usage_count

This commit is contained in:
lorenzejay
2026-02-18 16:22:53 -08:00
parent aad1ec1d8d
commit 6efb427b89
2 changed files with 57 additions and 13 deletions

View File

@@ -740,6 +740,31 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if not parsed_calls:
return None
original_tools_by_name: dict[str, Any] = {}
for tool in self.original_tools or []:
original_tools_by_name[sanitize_tool_name(tool.name)] = tool
# Reserve max-usage slots deterministically in call order.
# This prevents race conditions when multiple parallel calls target the same tool.
reserved_usage_by_tool: dict[str, int] = {}
execution_plan: list[tuple[str, str, str | dict[str, Any], Any | None, bool]] = []
for call_id, func_name, func_args in parsed_calls:
original_tool = original_tools_by_name.get(func_name)
should_execute = True
if (
original_tool
and getattr(original_tool, "max_usage_count", None) is not None
):
current_usage = getattr(original_tool, "current_usage_count", 0)
reserved = reserved_usage_by_tool.get(func_name, 0)
if current_usage + reserved >= original_tool.max_usage_count:
should_execute = False
else:
reserved_usage_by_tool[func_name] = reserved + 1
execution_plan.append(
(call_id, func_name, func_args, original_tool, should_execute)
)
assistant_message: LLMMessage = {
"role": "assistant",
"content": None,
@@ -754,13 +779,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
else json.dumps(func_args),
},
}
for call_id, func_name, func_args in parsed_calls
for call_id, func_name, func_args, _, _ in execution_plan
],
}
self.messages.append(assistant_message)
def _execute_one(
idx: int, call_id: str, func_name: str, func_args: str | dict[str, Any]
idx: int,
call_id: str,
func_name: str,
func_args: str | dict[str, Any],
original_tool: Any | None,
should_execute: bool,
) -> tuple[int, str, str, str, Any | None]:
if isinstance(func_args, str):
try:
@@ -787,15 +817,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
track_delegation_if_needed(func_name, args_dict, self.task)
original_tool = None
for tool in self.original_tools or []:
if sanitize_tool_name(tool.name) == func_name:
original_tool = tool
break
error_event_emitted = False
result: str = "Tool not found"
if func_name in available_functions:
if not should_execute and original_tool:
result = (
f"Tool '{func_name}' has reached its usage limit of "
f"{original_tool.max_usage_count} times and cannot be used anymore."
)
elif func_name in available_functions:
try:
raw_result = available_functions[func_name](**args_dict)
result = (
@@ -843,8 +872,22 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
] * len(parsed_calls)
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = {
pool.submit(_execute_one, idx, call_id, func_name, func_args): idx
for idx, (call_id, func_name, func_args) in enumerate(parsed_calls)
pool.submit(
_execute_one,
idx,
call_id,
func_name,
func_args,
original_tool,
should_execute,
): idx
for idx, (
call_id,
func_name,
func_args,
original_tool,
should_execute,
) in enumerate(execution_plan)
}
for future in as_completed(futures):
idx = futures[future]

View File

@@ -1017,5 +1017,6 @@ class TestMaxUsageCountWithNativeToolCalling:
result = crew.kickoff()
assert result is not None
# Verify usage count was incremented for each successful call
assert tool.current_usage_count == 2
# Verify the requested calls occurred while keeping usage bounded.
assert tool.current_usage_count >= 2
assert tool.current_usage_count <= tool.max_usage_count