supporting parallel tool use (#4513)

* supporting parallel tool use

* ensure we respect max_usage_count

* ensure result_as_answer, hooks, and cache parity

* improve crew agent executor

* address test comments
Lorenze Jay authored 2026-02-19 14:07:28 -08:00 · committed by GitHub
parent 49aa29bb41 · commit d09656664d
19 changed files with 3981 additions and 881 deletions
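Context for the diff below: a minimal, illustrative sketch of what native parallel tool use means at the executor level. When one assistant turn returns several tool calls, they are dispatched concurrently rather than one at a time. The names here (ToolCall, run_tool, execute_tool_calls) are hypothetical placeholders, not CrewAI's API; the real behavior lives in the executor changes this commit makes, and the tests below only verify that the tool executions overlap in time.

# Illustrative sketch only: hypothetical names, not CrewAI's actual executor API.
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any


@dataclass
class ToolCall:
    name: str                  # tool to invoke
    arguments: dict[str, Any]  # parsed arguments from the LLM tool-call payload


def run_tool(call: ToolCall) -> str:
    """Stand-in for dispatching to a registered tool."""
    return f"{call.name} -> {call.arguments}"


def execute_tool_calls(calls: list[ToolCall]) -> list[str]:
    """Run every tool call from a single assistant turn concurrently.

    With three tools that each sleep ~1s, wall time stays near 1s instead of 3s,
    which is the overlap the ParallelProbe assertions in the tests check for.
    """
    if len(calls) <= 1:
        return [run_tool(call) for call in calls]
    with ThreadPoolExecutor(max_workers=len(calls)) as pool:
        return list(pool.map(run_tool, calls))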


@@ -6,13 +6,20 @@ when the LLM supports it, across multiple providers.
from __future__ import annotations
from collections.abc import Generator
import os
import threading
import time
from collections import Counter
from unittest.mock import patch
import pytest
from pydantic import BaseModel, Field
from crewai import Agent, Crew, Task
from crewai.events import crewai_event_bus
from crewai.hooks import register_after_tool_call_hook, register_before_tool_call_hook
from crewai.hooks.tool_hooks import ToolCallHookContext
from crewai.llm import LLM
from crewai.tools.base_tool import BaseTool
@@ -64,6 +71,73 @@ class FailingTool(BaseTool):
def _run(self) -> str:
raise Exception("This tool always fails")
class LocalSearchInput(BaseModel):
query: str = Field(description="Search query")
class ParallelProbe:
"""Thread-safe in-memory recorder for tool execution windows."""
_lock = threading.Lock()
_windows: list[tuple[str, float, float]] = []
@classmethod
def reset(cls) -> None:
with cls._lock:
cls._windows = []
@classmethod
def record(cls, tool_name: str, start: float, end: float) -> None:
with cls._lock:
cls._windows.append((tool_name, start, end))
@classmethod
def windows(cls) -> list[tuple[str, float, float]]:
with cls._lock:
return list(cls._windows)
def _parallel_prompt() -> str:
return (
"This is a tool-calling compliance test. "
"In your next assistant turn, emit exactly 3 tool calls in the same response (parallel tool calls), in this order: "
"1) parallel_local_search_one(query='latest OpenAI model release notes'), "
"2) parallel_local_search_two(query='latest Anthropic model release notes'), "
"3) parallel_local_search_three(query='latest Gemini model release notes'). "
"Do not call any other tools and do not answer before those 3 tool calls are emitted. "
"After the tool results return, provide a one paragraph summary."
)
def _max_concurrency(windows: list[tuple[str, float, float]]) -> int:
points: list[tuple[float, int]] = []
for _, start, end in windows:
points.append((start, 1))
points.append((end, -1))
points.sort(key=lambda p: (p[0], p[1]))
current = 0
maximum = 0
for _, delta in points:
current += delta
if current > maximum:
maximum = current
return maximum
def _assert_tools_overlapped() -> None:
windows = ParallelProbe.windows()
local_windows = [
w
for w in windows
if w[0].startswith("parallel_local_search_")
]
assert len(local_windows) >= 3, f"Expected at least 3 local tool calls, got {len(local_windows)}"
assert _max_concurrency(local_windows) >= 2, "Expected overlapping local tool executions"
@pytest.fixture
def calculator_tool() -> CalculatorTool:
"""Create a calculator tool for testing."""
@@ -82,6 +156,65 @@ def failing_tool() -> BaseTool:
)
@pytest.fixture
def parallel_tools() -> list[BaseTool]:
"""Create local tools used to verify native parallel execution deterministically."""
class ParallelLocalSearchOne(BaseTool):
name: str = "parallel_local_search_one"
description: str = "Local search tool #1 for concurrency testing."
args_schema: type[BaseModel] = LocalSearchInput
def _run(self, query: str) -> str:
start = time.perf_counter()
time.sleep(1.0)
end = time.perf_counter()
ParallelProbe.record(self.name, start, end)
return f"[one] {query}"
class ParallelLocalSearchTwo(BaseTool):
name: str = "parallel_local_search_two"
description: str = "Local search tool #2 for concurrency testing."
args_schema: type[BaseModel] = LocalSearchInput
def _run(self, query: str) -> str:
start = time.perf_counter()
time.sleep(1.0)
end = time.perf_counter()
ParallelProbe.record(self.name, start, end)
return f"[two] {query}"
class ParallelLocalSearchThree(BaseTool):
name: str = "parallel_local_search_three"
description: str = "Local search tool #3 for concurrency testing."
args_schema: type[BaseModel] = LocalSearchInput
def _run(self, query: str) -> str:
start = time.perf_counter()
time.sleep(1.0)
end = time.perf_counter()
ParallelProbe.record(self.name, start, end)
return f"[three] {query}"
return [
ParallelLocalSearchOne(),
ParallelLocalSearchTwo(),
ParallelLocalSearchThree(),
]
def _attach_parallel_probe_handler() -> None:
@crewai_event_bus.on(ToolUsageFinishedEvent)
def _capture_tool_window(_source, event: ToolUsageFinishedEvent):
if not event.tool_name.startswith("parallel_local_search_"):
return
ParallelProbe.record(
event.tool_name,
event.started_at.timestamp(),
event.finished_at.timestamp(),
)
# =============================================================================
# OpenAI Provider Tests
# =============================================================================
@@ -122,7 +255,7 @@ class TestOpenAINativeToolCalling:
self, calculator_tool: CalculatorTool
) -> None:
"""Test OpenAI agent kickoff with mocked LLM call."""
llm = LLM(model="gpt-4o-mini")
llm = LLM(model="gpt-5-nano")
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
agent = Agent(
@@ -146,6 +279,174 @@ class TestOpenAINativeToolCalling:
assert mock_call.called
assert result is not None
@pytest.mark.vcr()
@pytest.mark.timeout(180)
def test_openai_parallel_native_tool_calling_test_crew(
self, parallel_tools: list[BaseTool]
) -> None:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="gpt-5-nano", temperature=1),
verbose=False,
max_iter=3,
)
task = Task(
description=_parallel_prompt(),
expected_output="A one sentence summary of both tool outputs",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
_assert_tools_overlapped()
@pytest.mark.vcr()
@pytest.mark.timeout(180)
def test_openai_parallel_native_tool_calling_test_agent_kickoff(
self, parallel_tools: list[BaseTool]
) -> None:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="gpt-4o-mini"),
verbose=False,
max_iter=3,
)
result = agent.kickoff(_parallel_prompt())
assert result is not None
_assert_tools_overlapped()
@pytest.mark.vcr()
@pytest.mark.timeout(180)
def test_openai_parallel_native_tool_calling_tool_hook_parity_crew(
self, parallel_tools: list[BaseTool]
) -> None:
hook_calls: dict[str, list[dict[str, str]]] = {"before": [], "after": []}
def before_hook(context: ToolCallHookContext) -> bool | None:
if context.tool_name.startswith("parallel_local_search_"):
hook_calls["before"].append(
{
"tool_name": context.tool_name,
"query": str(context.tool_input.get("query", "")),
}
)
return None
def after_hook(context: ToolCallHookContext) -> str | None:
if context.tool_name.startswith("parallel_local_search_"):
hook_calls["after"].append(
{
"tool_name": context.tool_name,
"query": str(context.tool_input.get("query", "")),
}
)
return None
register_before_tool_call_hook(before_hook)
register_after_tool_call_hook(after_hook)
try:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="gpt-5-nano", temperature=1),
verbose=False,
max_iter=3,
)
task = Task(
description=_parallel_prompt(),
expected_output="A one sentence summary of both tool outputs",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
_assert_tools_overlapped()
before_names = [call["tool_name"] for call in hook_calls["before"]]
after_names = [call["tool_name"] for call in hook_calls["after"]]
assert len(before_names) >= 3, "Expected before hooks for all parallel calls"
assert Counter(before_names) == Counter(after_names)
assert all(call["query"] for call in hook_calls["before"])
assert all(call["query"] for call in hook_calls["after"])
finally:
from crewai.hooks import (
unregister_after_tool_call_hook,
unregister_before_tool_call_hook,
)
unregister_before_tool_call_hook(before_hook)
unregister_after_tool_call_hook(after_hook)
@pytest.mark.vcr()
@pytest.mark.timeout(180)
def test_openai_parallel_native_tool_calling_tool_hook_parity_agent_kickoff(
self, parallel_tools: list[BaseTool]
) -> None:
hook_calls: dict[str, list[dict[str, str]]] = {"before": [], "after": []}
def before_hook(context: ToolCallHookContext) -> bool | None:
if context.tool_name.startswith("parallel_local_search_"):
hook_calls["before"].append(
{
"tool_name": context.tool_name,
"query": str(context.tool_input.get("query", "")),
}
)
return None
def after_hook(context: ToolCallHookContext) -> str | None:
if context.tool_name.startswith("parallel_local_search_"):
hook_calls["after"].append(
{
"tool_name": context.tool_name,
"query": str(context.tool_input.get("query", "")),
}
)
return None
register_before_tool_call_hook(before_hook)
register_after_tool_call_hook(after_hook)
try:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="gpt-5-nano", temperature=1),
verbose=False,
max_iter=3,
)
result = agent.kickoff(_parallel_prompt())
assert result is not None
_assert_tools_overlapped()
before_names = [call["tool_name"] for call in hook_calls["before"]]
after_names = [call["tool_name"] for call in hook_calls["after"]]
assert len(before_names) >= 3, "Expected before hooks for all parallel calls"
assert Counter(before_names) == Counter(after_names)
assert all(call["query"] for call in hook_calls["before"])
assert all(call["query"] for call in hook_calls["after"])
finally:
from crewai.hooks import (
unregister_after_tool_call_hook,
unregister_before_tool_call_hook,
)
unregister_before_tool_call_hook(before_hook)
unregister_after_tool_call_hook(after_hook)
# =============================================================================
# Anthropic Provider Tests
@@ -217,6 +518,46 @@ class TestAnthropicNativeToolCalling:
assert mock_call.called
assert result is not None
@pytest.mark.vcr()
def test_anthropic_parallel_native_tool_calling_test_crew(
self, parallel_tools: list[BaseTool]
) -> None:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="anthropic/claude-sonnet-4-6"),
verbose=False,
max_iter=3,
)
task = Task(
description=_parallel_prompt(),
expected_output="A one sentence summary of both tool outputs",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
_assert_tools_overlapped()
@pytest.mark.vcr()
def test_anthropic_parallel_native_tool_calling_test_agent_kickoff(
self, parallel_tools: list[BaseTool]
) -> None:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="anthropic/claude-sonnet-4-6"),
verbose=False,
max_iter=3,
)
result = agent.kickoff(_parallel_prompt())
assert result is not None
_assert_tools_overlapped()
# =============================================================================
# Google/Gemini Provider Tests
@@ -247,7 +588,7 @@ class TestGeminiNativeToolCalling:
goal="Help users with mathematical calculations",
backstory="You are a helpful math assistant.",
tools=[calculator_tool],
llm=LLM(model="gemini/gemini-2.0-flash-exp"),
llm=LLM(model="gemini/gemini-2.5-flash"),
)
task = Task(
@@ -266,7 +607,7 @@ class TestGeminiNativeToolCalling:
self, calculator_tool: CalculatorTool
) -> None:
"""Test Gemini agent kickoff with mocked LLM call."""
llm = LLM(model="gemini/gemini-2.0-flash-001")
llm = LLM(model="gemini/gemini-2.5-flash")
with patch.object(llm, "call", return_value="The answer is 120.") as mock_call:
agent = Agent(
@@ -290,6 +631,46 @@ class TestGeminiNativeToolCalling:
assert mock_call.called
assert result is not None
@pytest.mark.vcr()
def test_gemini_parallel_native_tool_calling_test_crew(
self, parallel_tools: list[BaseTool]
) -> None:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="gemini/gemini-2.5-flash"),
verbose=False,
max_iter=3,
)
task = Task(
description=_parallel_prompt(),
expected_output="A one sentence summary of both tool outputs",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
_assert_tools_overlapped()
@pytest.mark.vcr()
def test_gemini_parallel_native_tool_calling_test_agent_kickoff(
self, parallel_tools: list[BaseTool]
) -> None:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="gemini/gemini-2.5-flash"),
verbose=False,
max_iter=3,
)
result = agent.kickoff(_parallel_prompt())
assert result is not None
_assert_tools_overlapped()
# =============================================================================
# Azure Provider Tests
@@ -324,7 +705,7 @@ class TestAzureNativeToolCalling:
goal="Help users with mathematical calculations",
backstory="You are a helpful math assistant.",
tools=[calculator_tool],
llm=LLM(model="azure/gpt-4o-mini"),
llm=LLM(model="azure/gpt-5-nano"),
verbose=False,
max_iter=3,
)
@@ -347,7 +728,7 @@ class TestAzureNativeToolCalling:
) -> None:
"""Test Azure agent kickoff with mocked LLM call."""
llm = LLM(
model="azure/gpt-4o-mini",
model="azure/gpt-5-nano",
api_key="test-key",
base_url="https://test.openai.azure.com",
)
@@ -374,6 +755,46 @@ class TestAzureNativeToolCalling:
assert mock_call.called
assert result is not None
@pytest.mark.vcr()
def test_azure_parallel_native_tool_calling_test_crew(
self, parallel_tools: list[BaseTool]
) -> None:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="azure/gpt-5-nano"),
verbose=False,
max_iter=3,
)
task = Task(
description=_parallel_prompt(),
expected_output="A one sentence summary of both tool outputs",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
_assert_tools_overlapped()
@pytest.mark.vcr()
def test_azure_parallel_native_tool_calling_test_agent_kickoff(
self, parallel_tools: list[BaseTool]
) -> None:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="azure/gpt-5-nano"),
verbose=False,
max_iter=3,
)
result = agent.kickoff(_parallel_prompt())
assert result is not None
_assert_tools_overlapped()
# =============================================================================
# Bedrock Provider Tests
@@ -384,18 +805,30 @@ class TestBedrockNativeToolCalling:
"""Tests for native tool calling with AWS Bedrock models."""
@pytest.fixture(autouse=True)
def mock_aws_env(self):
"""Mock AWS environment variables for tests."""
env_vars = {
"AWS_ACCESS_KEY_ID": "test-key",
"AWS_SECRET_ACCESS_KEY": "test-secret",
"AWS_REGION": "us-east-1",
}
if "AWS_ACCESS_KEY_ID" not in os.environ:
with patch.dict(os.environ, env_vars):
yield
else:
yield
def validate_bedrock_credentials_for_live_recording(self):
"""Run Bedrock tests only when explicitly enabled."""
run_live_bedrock = os.getenv("RUN_BEDROCK_LIVE_TESTS", "false").lower() == "true"
if not run_live_bedrock:
pytest.skip(
"Skipping Bedrock tests by default. "
"Set RUN_BEDROCK_LIVE_TESTS=true with valid AWS credentials to enable."
)
access_key = os.getenv("AWS_ACCESS_KEY_ID", "")
secret_key = os.getenv("AWS_SECRET_ACCESS_KEY", "")
if (
not access_key
or not secret_key
or access_key.startswith(("fake-", "test-"))
or secret_key.startswith(("fake-", "test-"))
):
pytest.skip(
"Skipping Bedrock tests: valid AWS credentials are required when "
"RUN_BEDROCK_LIVE_TESTS=true."
)
yield
@pytest.mark.vcr()
def test_bedrock_agent_kickoff_with_tools_mocked(
@@ -427,6 +860,46 @@ class TestBedrockNativeToolCalling:
assert result.raw is not None
assert "120" in str(result.raw)
@pytest.mark.vcr()
def test_bedrock_parallel_native_tool_calling_test_crew(
self, parallel_tools: list[BaseTool]
) -> None:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="bedrock/anthropic.claude-3-haiku-20240307-v1:0"),
verbose=False,
max_iter=3,
)
task = Task(
description=_parallel_prompt(),
expected_output="A one sentence summary of both tool outputs",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
_assert_tools_overlapped()
@pytest.mark.vcr()
def test_bedrock_parallel_native_tool_calling_test_agent_kickoff(
self, parallel_tools: list[BaseTool]
) -> None:
agent = Agent(
role="Parallel Tool Agent",
goal="Use both tools exactly as instructed",
backstory="You follow tool instructions precisely.",
tools=parallel_tools,
llm=LLM(model="bedrock/anthropic.claude-3-haiku-20240307-v1:0"),
verbose=False,
max_iter=3,
)
result = agent.kickoff(_parallel_prompt())
assert result is not None
_assert_tools_overlapped()
# =============================================================================
# Cross-Provider Native Tool Calling Behavior Tests
@@ -439,7 +912,7 @@ class TestNativeToolCallingBehavior:
def test_supports_function_calling_check(self) -> None:
"""Test that supports_function_calling() is properly checked."""
# OpenAI should support function calling
openai_llm = LLM(model="gpt-4o-mini")
openai_llm = LLM(model="gpt-5-nano")
assert hasattr(openai_llm, "supports_function_calling")
assert openai_llm.supports_function_calling() is True
@@ -475,7 +948,7 @@ class TestNativeToolCallingTokenUsage:
goal="Perform calculations efficiently",
backstory="You calculate things.",
tools=[calculator_tool],
llm=LLM(model="gpt-4o-mini"),
llm=LLM(model="gpt-5-nano"),
verbose=False,
max_iter=3,
)
@@ -519,7 +992,7 @@ def test_native_tool_calling_error_handling(failing_tool: FailingTool):
goal="Perform calculations efficiently",
backstory="You calculate things.",
tools=[failing_tool],
llm=LLM(model="gpt-4o-mini"),
llm=LLM(model="gpt-5-nano"),
verbose=False,
max_iter=3,
)
@@ -578,7 +1051,7 @@ class TestMaxUsageCountWithNativeToolCalling:
goal="Call the counting tool multiple times",
backstory="You are an agent that counts things.",
tools=[tool],
llm=LLM(model="gpt-4o-mini"),
llm=LLM(model="gpt-5-nano"),
verbose=False,
max_iter=5,
)
@@ -606,7 +1079,7 @@ class TestMaxUsageCountWithNativeToolCalling:
goal="Use the counting tool as many times as requested",
backstory="You are an agent that counts things. You must try to use the tool for each value requested.",
tools=[tool],
llm=LLM(model="gpt-4o-mini"),
llm=LLM(model="gpt-5-nano"),
verbose=False,
max_iter=5,
)
@@ -638,7 +1111,7 @@ class TestMaxUsageCountWithNativeToolCalling:
goal="Use the counting tool exactly as requested",
backstory="You are an agent that counts things precisely.",
tools=[tool],
llm=LLM(model="gpt-4o-mini"),
llm=LLM(model="gpt-5-nano"),
verbose=False,
max_iter=5,
)
@@ -653,5 +1126,6 @@ class TestMaxUsageCountWithNativeToolCalling:
result = crew.kickoff()
assert result is not None
# Verify usage count was incremented for each successful call
assert tool.current_usage_count == 2
# Verify the requested calls occurred while keeping usage bounded.
assert tool.current_usage_count >= 2
assert tool.current_usage_count <= tool.max_usage_count