mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
Fix token tracking race condition in threading-based async execution
This commit fixes the race condition described in issue #4168 where token tracking was inaccurate when multiple async tasks from the same agent ran concurrently. The fix introduces: 1. Per-agent locks to serialize async task execution for accurate token tracking when multiple async tasks from the same agent run concurrently 2. Token capture callback that captures both tokens_before and tokens_after inside the thread (after acquiring the lock), not when the task is queued 3. Updated _process_async_tasks to handle the new return type from execute_async which now returns (TaskOutput, tokens_before, tokens_after) This ensures that token deltas are accurately attributed to each task even when multiple async tasks from the same agent overlap in execution. Tests added: - test_async_task_token_tracking_uses_per_agent_lock - test_async_task_token_callback_captures_tokens_inside_thread - test_async_task_per_agent_lock_serializes_execution Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
from typing import (
|
||||
Any,
|
||||
cast,
|
||||
@@ -1153,8 +1154,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
|
||||
task_outputs: list[TaskOutput] = []
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]] = []
|
||||
futures: list[tuple[Task, Future[TaskOutput | tuple[TaskOutput, Any]], int, Any, Any]] = []
|
||||
last_sync_output: TaskOutput | None = None
|
||||
|
||||
# Per-agent locks to serialize async task execution for accurate token tracking
|
||||
# This ensures that when multiple async tasks from the same agent run,
|
||||
# they execute one at a time so token deltas can be accurately attributed
|
||||
agent_locks: dict[str, threading.Lock] = {}
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
exec_data, task_outputs, last_sync_output = prepare_task_execution(
|
||||
@@ -1172,18 +1178,32 @@ class Crew(FlowTrackable, BaseModel):
|
||||
continue
|
||||
|
||||
if task.async_execution:
|
||||
# Capture token usage before async task execution
|
||||
tokens_before = self._get_agent_token_usage(exec_data.agent)
|
||||
|
||||
context = self._get_context(
|
||||
task, [last_sync_output] if last_sync_output else []
|
||||
)
|
||||
|
||||
# Get or create a lock for this agent to serialize async task execution
|
||||
# This ensures accurate per-task token tracking
|
||||
agent_id = str(getattr(exec_data.agent, 'id', id(exec_data.agent)))
|
||||
if agent_id not in agent_locks:
|
||||
agent_locks[agent_id] = threading.Lock()
|
||||
agent_lock = agent_locks[agent_id]
|
||||
|
||||
# Create a token capture callback that will be called inside the thread
|
||||
# after task completion (while still holding the lock)
|
||||
def create_token_callback(agent: Any = exec_data.agent) -> Any:
|
||||
return self._get_agent_token_usage(agent)
|
||||
|
||||
future = task.execute_async(
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=exec_data.tools,
|
||||
token_capture_callback=create_token_callback,
|
||||
agent_execution_lock=agent_lock,
|
||||
)
|
||||
futures.append((task, future, task_index, exec_data.agent, tokens_before))
|
||||
# Note: tokens_before is no longer captured here since it will be
|
||||
# captured inside the thread after acquiring the lock
|
||||
futures.append((task, future, task_index, exec_data.agent, None))
|
||||
else:
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
@@ -1218,7 +1238,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
task_outputs: list[TaskOutput],
|
||||
futures: list[tuple[Task, Future[TaskOutput], int, Any, Any]],
|
||||
futures: list[tuple[Task, Future[TaskOutput | tuple[TaskOutput, Any, Any]], int, Any, Any]],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> TaskOutput | None:
|
||||
@@ -1450,18 +1470,32 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _process_async_tasks(
|
||||
self,
|
||||
futures: list[tuple[Task, Future[TaskOutput], int, Any, Any]],
|
||||
futures: list[tuple[Task, Future[TaskOutput | tuple[TaskOutput, Any, Any]], int, Any, Any]],
|
||||
was_replayed: bool = False,
|
||||
) -> list[TaskOutput]:
|
||||
"""Process completed async tasks and attach token metrics.
|
||||
|
||||
The futures contain either:
|
||||
- TaskOutput (if no token tracking was enabled)
|
||||
- tuple of (TaskOutput, tokens_before, tokens_after) (if token tracking was enabled)
|
||||
|
||||
Token tracking is enabled when the task was executed with a token_capture_callback
|
||||
and agent_execution_lock, which ensures accurate per-task token attribution even
|
||||
when multiple async tasks from the same agent run concurrently.
|
||||
"""
|
||||
task_outputs: list[TaskOutput] = []
|
||||
for future_task, future, task_index, agent, tokens_before in futures:
|
||||
task_output = future.result()
|
||||
for future_task, future, task_index, agent, _ in futures:
|
||||
result = future.result()
|
||||
|
||||
# Capture token usage after async task execution and attach to task output
|
||||
tokens_after = self._get_agent_token_usage(agent)
|
||||
task_output = self._attach_task_token_metrics(
|
||||
task_output, future_task, agent, tokens_before, tokens_after
|
||||
)
|
||||
# Check if result is a tuple (token tracking enabled) or just TaskOutput
|
||||
if isinstance(result, tuple) and len(result) == 3:
|
||||
task_output, tokens_before, tokens_after = result
|
||||
task_output = self._attach_task_token_metrics(
|
||||
task_output, future_task, agent, tokens_before, tokens_after
|
||||
)
|
||||
else:
|
||||
# No token tracking - result is just TaskOutput
|
||||
task_output = result
|
||||
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(future_task, task_output)
|
||||
|
||||
@@ -11,6 +11,7 @@ from pathlib import Path
|
||||
import threading
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
cast,
|
||||
get_args,
|
||||
@@ -476,13 +477,34 @@ class Task(BaseModel):
|
||||
agent: BaseAgent | None = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> Future[TaskOutput]:
|
||||
"""Execute the task asynchronously."""
|
||||
future: Future[TaskOutput] = Future()
|
||||
token_capture_callback: Callable[[], Any] | None = None,
|
||||
agent_execution_lock: threading.Lock | None = None,
|
||||
) -> Future[TaskOutput | tuple[TaskOutput, Any, Any]]:
|
||||
"""Execute the task asynchronously.
|
||||
|
||||
Args:
|
||||
agent: The agent to execute the task.
|
||||
context: Context for the task execution.
|
||||
tools: Tools available for the task.
|
||||
token_capture_callback: Optional callback to capture token usage.
|
||||
If provided, the future will return a tuple of
|
||||
(TaskOutput, tokens_before, tokens_after) instead of just TaskOutput.
|
||||
The callback is called twice: once before task execution (after
|
||||
acquiring the lock if one is provided) and once after task completion.
|
||||
agent_execution_lock: Optional lock to serialize task execution for
|
||||
the same agent. This is used to ensure accurate per-task token
|
||||
tracking when multiple async tasks from the same agent run
|
||||
concurrently.
|
||||
|
||||
Returns:
|
||||
Future containing TaskOutput, or tuple of (TaskOutput, tokens_before, tokens_after)
|
||||
if token_capture_callback is provided.
|
||||
"""
|
||||
future: Future[TaskOutput | tuple[TaskOutput, Any, Any]] = Future()
|
||||
threading.Thread(
|
||||
daemon=True,
|
||||
target=self._execute_task_async,
|
||||
args=(agent, context, tools, future),
|
||||
args=(agent, context, tools, future, token_capture_callback, agent_execution_lock),
|
||||
).start()
|
||||
return future
|
||||
|
||||
@@ -491,14 +513,45 @@ class Task(BaseModel):
|
||||
agent: BaseAgent | None,
|
||||
context: str | None,
|
||||
tools: list[Any] | None,
|
||||
future: Future[TaskOutput],
|
||||
future: Future[TaskOutput | tuple[TaskOutput, Any, Any]],
|
||||
token_capture_callback: Callable[[], Any] | None = None,
|
||||
agent_execution_lock: threading.Lock | None = None,
|
||||
) -> None:
|
||||
"""Execute the task asynchronously with context handling."""
|
||||
"""Execute the task asynchronously with context handling.
|
||||
|
||||
If agent_execution_lock is provided, the task execution will be
|
||||
serialized with other tasks using the same lock. This ensures
|
||||
accurate per-task token tracking by:
|
||||
1. Capturing tokens_before after acquiring the lock
|
||||
2. Executing the task
|
||||
3. Capturing tokens_after immediately after completion
|
||||
4. Releasing the lock
|
||||
|
||||
If token_capture_callback is provided, it will be called twice:
|
||||
once before task execution and once after, both while holding the lock.
|
||||
"""
|
||||
try:
|
||||
result = self._execute_core(agent, context, tools)
|
||||
future.set_result(result)
|
||||
if agent_execution_lock:
|
||||
with agent_execution_lock:
|
||||
if token_capture_callback:
|
||||
tokens_before = token_capture_callback()
|
||||
result = self._execute_core(agent, context, tools)
|
||||
if token_capture_callback:
|
||||
tokens_after = token_capture_callback()
|
||||
future.set_result((result, tokens_before, tokens_after))
|
||||
else:
|
||||
future.set_result(result)
|
||||
else:
|
||||
if token_capture_callback:
|
||||
tokens_before = token_capture_callback()
|
||||
result = self._execute_core(agent, context, tools)
|
||||
if token_capture_callback:
|
||||
tokens_after = token_capture_callback()
|
||||
future.set_result((result, tokens_before, tokens_after))
|
||||
else:
|
||||
future.set_result(result)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
future.set_exception(e)
|
||||
|
||||
async def aexecute_sync(
|
||||
self,
|
||||
|
||||
@@ -4768,3 +4768,220 @@ def test_ensure_exchanged_messages_are_propagated_to_external_memory():
|
||||
assert "Researcher" in messages[0]["content"]
|
||||
assert messages[1]["role"] == "user"
|
||||
assert "Research a topic to teach a kid aged 6 about math" in messages[1]["content"]
|
||||
|
||||
|
||||
def test_async_task_token_tracking_uses_per_agent_lock():
|
||||
"""Test that async tasks from the same agent use per-agent locks for accurate token tracking.
|
||||
|
||||
This test verifies the fix for the race condition described in issue #4168:
|
||||
When multiple tasks with async_execution=True are executed by the same agent,
|
||||
the per-agent lock ensures that token tracking is accurate by serializing
|
||||
task execution and capturing tokens_before/tokens_after inside the thread.
|
||||
"""
|
||||
from crewai.types.usage_metrics import TaskTokenMetrics
|
||||
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research topics",
|
||||
backstory="You are a researcher",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task1 = Task(
|
||||
description="Research topic 1",
|
||||
expected_output="Research output 1",
|
||||
agent=agent,
|
||||
async_execution=True,
|
||||
)
|
||||
|
||||
task2 = Task(
|
||||
description="Research topic 2",
|
||||
expected_output="Research output 2",
|
||||
agent=agent,
|
||||
async_execution=True,
|
||||
)
|
||||
|
||||
task3 = Task(
|
||||
description="Summarize research",
|
||||
expected_output="Summary",
|
||||
agent=agent,
|
||||
async_execution=False,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task1, task2, task3])
|
||||
|
||||
mock_output = TaskOutput(
|
||||
description="Test output",
|
||||
raw="Test result",
|
||||
agent="Researcher",
|
||||
)
|
||||
|
||||
execution_order = []
|
||||
lock_acquisitions = []
|
||||
|
||||
original_execute_core = Task._execute_core
|
||||
|
||||
def mock_execute_core(self, agent, context, tools):
|
||||
execution_order.append(self.description)
|
||||
return mock_output
|
||||
|
||||
with patch.object(Task, "_execute_core", mock_execute_core):
|
||||
with patch.object(
|
||||
crew,
|
||||
"_get_agent_token_usage",
|
||||
side_effect=[
|
||||
UsageMetrics(total_tokens=100, prompt_tokens=80, completion_tokens=20, successful_requests=1),
|
||||
UsageMetrics(total_tokens=150, prompt_tokens=120, completion_tokens=30, successful_requests=2),
|
||||
UsageMetrics(total_tokens=150, prompt_tokens=120, completion_tokens=30, successful_requests=2),
|
||||
UsageMetrics(total_tokens=200, prompt_tokens=160, completion_tokens=40, successful_requests=3),
|
||||
UsageMetrics(total_tokens=200, prompt_tokens=160, completion_tokens=40, successful_requests=3),
|
||||
UsageMetrics(total_tokens=250, prompt_tokens=200, completion_tokens=50, successful_requests=4),
|
||||
]
|
||||
):
|
||||
result = crew.kickoff()
|
||||
|
||||
assert len(result.tasks_output) == 3
|
||||
|
||||
for task_output in result.tasks_output:
|
||||
if hasattr(task_output, 'usage_metrics') and task_output.usage_metrics:
|
||||
assert isinstance(task_output.usage_metrics, TaskTokenMetrics)
|
||||
|
||||
|
||||
def test_async_task_token_callback_captures_tokens_inside_thread():
|
||||
"""Test that token capture callback is called inside the thread for async tasks.
|
||||
|
||||
This verifies that tokens_before and tokens_after are captured inside the thread
|
||||
(after acquiring the lock), not when the task is queued.
|
||||
"""
|
||||
from concurrent.futures import Future
|
||||
import time
|
||||
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research topics",
|
||||
backstory="You are a researcher",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Research topic",
|
||||
expected_output="Research output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
callback_call_times = []
|
||||
callback_thread_ids = []
|
||||
main_thread_id = threading.current_thread().ident
|
||||
|
||||
def token_callback():
|
||||
callback_call_times.append(time.time())
|
||||
callback_thread_ids.append(threading.current_thread().ident)
|
||||
return UsageMetrics(total_tokens=100, prompt_tokens=80, completion_tokens=20, successful_requests=1)
|
||||
|
||||
mock_output = TaskOutput(
|
||||
description="Test output",
|
||||
raw="Test result",
|
||||
agent="Researcher",
|
||||
)
|
||||
|
||||
with patch.object(Task, "_execute_core", return_value=mock_output):
|
||||
lock = threading.Lock()
|
||||
future = task.execute_async(
|
||||
agent=agent,
|
||||
context=None,
|
||||
tools=None,
|
||||
token_capture_callback=token_callback,
|
||||
agent_execution_lock=lock,
|
||||
)
|
||||
|
||||
result = future.result(timeout=10)
|
||||
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 3
|
||||
task_output, tokens_before, tokens_after = result
|
||||
|
||||
assert len(callback_call_times) == 2
|
||||
assert len(callback_thread_ids) == 2
|
||||
|
||||
for thread_id in callback_thread_ids:
|
||||
assert thread_id != main_thread_id
|
||||
|
||||
assert callback_thread_ids[0] == callback_thread_ids[1]
|
||||
|
||||
|
||||
def test_async_task_per_agent_lock_serializes_execution():
|
||||
"""Test that per-agent lock serializes async task execution for the same agent.
|
||||
|
||||
This test verifies that when multiple async tasks from the same agent are executed,
|
||||
the per-agent lock ensures they run one at a time (serialized), not concurrently.
|
||||
"""
|
||||
import time
|
||||
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research topics",
|
||||
backstory="You are a researcher",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task1 = Task(
|
||||
description="Research topic 1",
|
||||
expected_output="Research output 1",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
task2 = Task(
|
||||
description="Research topic 2",
|
||||
expected_output="Research output 2",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
execution_times = []
|
||||
|
||||
mock_output = TaskOutput(
|
||||
description="Test output",
|
||||
raw="Test result",
|
||||
agent="Researcher",
|
||||
)
|
||||
|
||||
def slow_execute_core(self, agent, context, tools):
|
||||
start_time = time.time()
|
||||
time.sleep(0.1)
|
||||
end_time = time.time()
|
||||
execution_times.append((start_time, end_time))
|
||||
return mock_output
|
||||
|
||||
with patch.object(Task, "_execute_core", slow_execute_core):
|
||||
lock = threading.Lock()
|
||||
|
||||
def token_callback():
|
||||
return UsageMetrics(total_tokens=100, prompt_tokens=80, completion_tokens=20, successful_requests=1)
|
||||
|
||||
future1 = task1.execute_async(
|
||||
agent=agent,
|
||||
context=None,
|
||||
tools=None,
|
||||
token_capture_callback=token_callback,
|
||||
agent_execution_lock=lock,
|
||||
)
|
||||
|
||||
future2 = task2.execute_async(
|
||||
agent=agent,
|
||||
context=None,
|
||||
tools=None,
|
||||
token_capture_callback=token_callback,
|
||||
agent_execution_lock=lock,
|
||||
)
|
||||
|
||||
result1 = future1.result(timeout=10)
|
||||
result2 = future2.result(timeout=10)
|
||||
|
||||
assert len(execution_times) == 2
|
||||
|
||||
start1, end1 = execution_times[0]
|
||||
start2, end2 = execution_times[1]
|
||||
|
||||
if start1 < start2:
|
||||
assert end1 <= start2, "Tasks should not overlap when using the same lock"
|
||||
else:
|
||||
assert end2 <= start1, "Tasks should not overlap when using the same lock"
|
||||
|
||||
Reference in New Issue
Block a user