Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-27 17:18:13 +00:00.

Compare commits: 1.8.1...devin/1767 (11 commits)

| SHA1 |
|---|
| e022caae6c |
| 5b8e42c028 |
| 563e2eccbd |
| 5dc87c04af |
| 0f0538cca7 |
| 314642f392 |
| 9bbf53e84a |
| afea8a505a |
| a0c2662ad9 |
| 85860610e9 |
| 56b538c37c |

The compared branch adds per-agent and per-task token usage tracking: new token-metrics models, token capture around task execution in both the asyncio and threaded paths of Crew, per-agent locks so concurrent tasks from one agent do not corrupt each other's token deltas, and tests for the new behavior.
Changes to the Crew class:

```diff
@@ -7,6 +7,7 @@ from copy import copy as shallow_copy
 from hashlib import md5
 import json
 import re
+import threading
 from typing import (
     Any,
     cast,
```
```diff
@@ -203,6 +204,10 @@ class Crew(FlowTrackable, BaseModel):
         default=None,
         description="Metrics for the LLM usage during all tasks execution.",
     )
+    workflow_token_metrics: Any | None = Field(
+        default=None,
+        description="Detailed per-agent and per-task token metrics.",
+    )
     manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
         description="Language model that will run the agent.", default=None
     )
```
```diff
@@ -944,17 +949,36 @@ class Crew(FlowTrackable, BaseModel):
                     continue

             if task.async_execution:
+                # Capture token usage before async task execution
+                tokens_before = self._get_agent_token_usage(exec_data.agent)
+
                 context = self._get_context(
                     task, [last_sync_output] if last_sync_output else []
                 )
-                async_task = asyncio.create_task(
-                    task.aexecute_sync(
-                        agent=exec_data.agent,
-                        context=context,
-                        tools=exec_data.tools,
+
+                # Wrap task execution to capture tokens immediately after completion
+                # Use default arguments to bind loop variables at definition time (fixes B023)
+                agent = exec_data.agent
+                tools = exec_data.tools
+
+                async def _wrapped_task_execution(
+                    _task=task,
+                    _agent=agent,
+                    _tools=tools,
+                    _context=context,
+                ):
+                    result = await _task.aexecute_sync(
+                        agent=_agent,
+                        context=_context,
+                        tools=_tools,
                     )
-                )
-                pending_tasks.append((task, async_task, task_index))
+                    # Capture tokens immediately after task completes
+                    # This reduces (but doesn't eliminate) race conditions
+                    tokens_after = self._get_agent_token_usage(_agent)
+                    return result, tokens_after
+
+                async_task = asyncio.create_task(_wrapped_task_execution())
+                pending_tasks.append((task, async_task, task_index, exec_data.agent, tokens_before))
             else:
                 if pending_tasks:
                     task_outputs = await self._aprocess_async_tasks(
```
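The B023 comment above refers to ruff's function-defined-in-loop rule: a closure created inside a loop sees the loop variable's final value unless the value is bound at definition time via a default argument. A minimal standalone sketch of the difference (plain asyncio, no crewAI types; `use_late` and `use_default` are illustrative names):

```python
import asyncio


async def main() -> None:
    late, early = [], []
    tasks = []
    for i in range(3):
        async def use_late():
            late.append(i)        # late-bound: resolved when awaited, so always the final i

        async def use_default(_i=i):
            early.append(_i)      # bound at definition time, like _task=task in the diff

        tasks.append(asyncio.create_task(use_late()))
        tasks.append(asyncio.create_task(use_default()))
    await asyncio.gather(*tasks)
    assert late == [2, 2, 2]      # every late closure saw the last loop value
    assert early == [0, 1, 2]     # each default-bound closure kept its own value


asyncio.run(main())
```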
```diff
@@ -962,12 +986,22 @@ class Crew(FlowTrackable, BaseModel):
                     )
                     pending_tasks.clear()

+                # Capture token usage before task execution
+                tokens_before = self._get_agent_token_usage(exec_data.agent)
+
                 context = self._get_context(task, task_outputs)
                 task_output = await task.aexecute_sync(
                     agent=exec_data.agent,
                     context=context,
                     tools=exec_data.tools,
                 )
+
+                # Capture token usage after task execution and attach to task output
+                tokens_after = self._get_agent_token_usage(exec_data.agent)
+                task_output = self._attach_task_token_metrics(
+                    task_output, task, exec_data.agent, tokens_before, tokens_after
+                )
+
                 task_outputs.append(task_output)
                 self._process_task_result(task, task_output)
                 self._store_execution_log(task, task_output, task_index, was_replayed)
```
```diff
@@ -981,7 +1015,7 @@ class Crew(FlowTrackable, BaseModel):
         self,
         task: ConditionalTask,
         task_outputs: list[TaskOutput],
-        pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
+        pending_tasks: list[tuple[Task, asyncio.Task[tuple[TaskOutput, Any]], int, Any, Any]],
         task_index: int,
         was_replayed: bool,
     ) -> TaskOutput | None:
```
```diff
@@ -996,13 +1030,20 @@ class Crew(FlowTrackable, BaseModel):

     async def _aprocess_async_tasks(
         self,
-        pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
+        pending_tasks: list[tuple[Task, asyncio.Task[tuple[TaskOutput, Any]], int, Any, Any]],
         was_replayed: bool = False,
     ) -> list[TaskOutput]:
         """Process pending async tasks and return their outputs."""
         task_outputs: list[TaskOutput] = []
-        for future_task, async_task, task_index in pending_tasks:
-            task_output = await async_task
+        for future_task, async_task, task_index, agent, tokens_before in pending_tasks:
+            # Unwrap the result which includes both output and tokens_after
+            task_output, tokens_after = await async_task
+
+            # Attach token metrics using the captured tokens_after
+            task_output = self._attach_task_token_metrics(
+                task_output, future_task, agent, tokens_before, tokens_after
+            )
+
             task_outputs.append(task_output)
             self._process_task_result(future_task, task_output)
             self._store_execution_log(
```
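In this asyncio path, tokens_before is snapshotted when the task is queued and tokens_after is returned by the wrapper coroutine, so the consumer computes the delta on await. A minimal self-contained sketch of that flow (stand-in names such as `llm_task` and `usage`, not crewAI APIs):

```python
import asyncio

usage = {"tokens": 0}  # stands in for the agent's cumulative token counter


async def llm_task(cost: int) -> str:
    usage["tokens"] += cost
    return f"used {cost}"


async def main() -> None:
    async def wrapped(cost: int):
        result = await llm_task(cost)
        return result, usage["tokens"]      # tokens_after, read at completion

    tokens_before = usage["tokens"]         # read when the task is queued
    pending = asyncio.create_task(wrapped(75))

    output, tokens_after = await pending    # unwrap the (output, tokens_after) pair
    assert tokens_after - tokens_before == 75
    print(output)


asyncio.run(main())
```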
The futures element type below is written to match the 3-tuple `(TaskOutput, tokens_before, tokens_after)` that `_execute_task_async` actually resolves with (the extracted page showed a 2-tuple here, inconsistent with the signatures of `_process_async_tasks` below):

```diff
@@ -1122,9 +1163,14 @@ class Crew(FlowTrackable, BaseModel):
         """

         task_outputs: list[TaskOutput] = []
-        futures: list[tuple[Task, Future[TaskOutput], int]] = []
+        futures: list[tuple[Task, Future[TaskOutput | tuple[TaskOutput, Any, Any]], int, Any, Any]] = []
         last_sync_output: TaskOutput | None = None

+        # Per-agent locks to serialize async task execution for accurate token tracking
+        # This ensures that when multiple async tasks from the same agent run,
+        # they execute one at a time so token deltas can be accurately attributed
+        agent_locks: dict[str, threading.Lock] = {}
+
         for task_index, task in enumerate(tasks):
             exec_data, task_outputs, last_sync_output = prepare_task_execution(
                 self, task, task_index, start_index, task_outputs, last_sync_output
```
```diff
@@ -1144,23 +1190,50 @@ class Crew(FlowTrackable, BaseModel):
                 context = self._get_context(
                     task, [last_sync_output] if last_sync_output else []
                 )
+
+                # Get or create a lock for this agent to serialize async task execution
+                # This ensures accurate per-task token tracking
+                agent_id = str(getattr(exec_data.agent, 'id', id(exec_data.agent)))
+                if agent_id not in agent_locks:
+                    agent_locks[agent_id] = threading.Lock()
+                agent_lock = agent_locks[agent_id]
+
+                # Create a token capture callback that will be called inside the thread
+                # after task completion (while still holding the lock)
+                def create_token_callback(agent: Any = exec_data.agent) -> Any:
+                    return self._get_agent_token_usage(agent)
+
                 future = task.execute_async(
                     agent=exec_data.agent,
                     context=context,
                     tools=exec_data.tools,
+                    token_capture_callback=create_token_callback,
+                    agent_execution_lock=agent_lock,
                 )
-                futures.append((task, future, task_index))
+                # Note: tokens_before is no longer captured here since it will be
+                # captured inside the thread after acquiring the lock
+                futures.append((task, future, task_index, exec_data.agent, None))
             else:
                 if futures:
                     task_outputs = self._process_async_tasks(futures, was_replayed)
                     futures.clear()

+                # Capture token usage before task execution
+                tokens_before = self._get_agent_token_usage(exec_data.agent)
+
                 context = self._get_context(task, task_outputs)
                 task_output = task.execute_sync(
                     agent=exec_data.agent,
                     context=context,
                     tools=exec_data.tools,
                 )
+
+                # Capture token usage after task execution and attach to task output
+                tokens_after = self._get_agent_token_usage(exec_data.agent)
+                task_output = self._attach_task_token_metrics(
+                    task_output, task, exec_data.agent, tokens_before, tokens_after
+                )
+
                 task_outputs.append(task_output)
                 self._process_task_result(task, task_output)
                 self._store_execution_log(task, task_output, task_index, was_replayed)
```
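A minimal sketch (plain Python threads, not the crewAI API) of the pattern in this hunk: one lock per agent serializes that agent's tasks, and the shared token counter is read inside the worker thread while the lock is held, so each task's delta is attributed only to its own work. `fake_llm_call` and `run_task` are illustrative names:

```python
import threading
from concurrent.futures import Future

counter = {"total_tokens": 0}       # stands in for the agent's LLM token counter
counter_lock = threading.Lock()


def fake_llm_call(cost: int) -> None:
    with counter_lock:
        counter["total_tokens"] += cost


def run_task(cost: int, agent_lock: threading.Lock, future: Future) -> None:
    with agent_lock:                # serialize tasks of the same agent
        before = counter["total_tokens"]
        fake_llm_call(cost)         # this task's only token usage
        after = counter["total_tokens"]
        future.set_result(after - before)


agent_lock = threading.Lock()
futures: list[Future] = []
for cost in (100, 250, 50):
    f: Future = Future()
    threading.Thread(target=run_task, args=(cost, agent_lock, f), daemon=True).start()
    futures.append(f)

deltas = sorted(f.result(timeout=5) for f in futures)
assert deltas == [50, 100, 250]     # each task sees exactly its own tokens
print(deltas)
```

Without the lock, two overlapping tasks would both read `before` at the same counter value and each delta would absorb the other task's usage.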
```diff
@@ -1174,7 +1247,7 @@ class Crew(FlowTrackable, BaseModel):
         self,
         task: ConditionalTask,
         task_outputs: list[TaskOutput],
-        futures: list[tuple[Task, Future[TaskOutput], int]],
+        futures: list[tuple[Task, Future[TaskOutput | tuple[TaskOutput, Any, Any]], int, Any, Any]],
         task_index: int,
         was_replayed: bool,
     ) -> TaskOutput | None:
```
```diff
@@ -1401,16 +1474,38 @@ class Crew(FlowTrackable, BaseModel):
             json_dict=final_task_output.json_dict,
             tasks_output=task_outputs,
             token_usage=self.token_usage,
+            token_metrics=getattr(self, 'workflow_token_metrics', None),
         )

     def _process_async_tasks(
         self,
-        futures: list[tuple[Task, Future[TaskOutput], int]],
+        futures: list[tuple[Task, Future[TaskOutput | tuple[TaskOutput, Any, Any]], int, Any, Any]],
         was_replayed: bool = False,
     ) -> list[TaskOutput]:
+        """Process completed async tasks and attach token metrics.
+
+        The futures contain either:
+        - TaskOutput (if no token tracking was enabled)
+        - tuple of (TaskOutput, tokens_before, tokens_after) (if token tracking was enabled)
+
+        Token tracking is enabled when the task was executed with a token_capture_callback
+        and agent_execution_lock, which ensures accurate per-task token attribution even
+        when multiple async tasks from the same agent run concurrently.
+        """
         task_outputs: list[TaskOutput] = []
-        for future_task, future, task_index in futures:
-            task_output = future.result()
+        for future_task, future, task_index, agent, _ in futures:
+            result = future.result()
+
+            # Check if result is a tuple (token tracking enabled) or just TaskOutput
+            if isinstance(result, tuple) and len(result) == 3:
+                task_output, tokens_before, tokens_after = result
+                task_output = self._attach_task_token_metrics(
+                    task_output, future_task, agent, tokens_before, tokens_after
+                )
+            else:
+                # No token tracking - result is just TaskOutput
+                task_output = result
+
             task_outputs.append(task_output)
             self._process_task_result(future_task, task_output)
             self._store_execution_log(
```
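The unwrapping logic documented in that docstring reduces to one isinstance check over two possible future payload shapes. A self-contained sketch (plain `concurrent.futures`, stand-in string outputs rather than TaskOutput):

```python
from concurrent.futures import Future
from typing import Any


def unwrap(result: Any) -> tuple[str, int | None]:
    """Return (output, token_delta), with None when tracking was disabled."""
    if isinstance(result, tuple) and len(result) == 3:
        output, tokens_before, tokens_after = result
        return output, tokens_after - tokens_before
    return result, None


tracked: Future = Future()
tracked.set_result(("report", 100, 160))    # token tracking enabled: 3-tuple
untracked: Future = Future()
untracked.set_result("report")              # no tracking: bare output

assert unwrap(tracked.result()) == ("report", 60)
assert unwrap(untracked.result()) == ("report", None)
```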
```diff
@@ -1616,12 +1711,83 @@ class Crew(FlowTrackable, BaseModel):

     def calculate_usage_metrics(self) -> UsageMetrics:
         """Calculates and returns the usage metrics."""
+        from crewai.types.usage_metrics import (
+            AgentTokenMetrics,
+            WorkflowTokenMetrics,
+        )
+
         total_usage_metrics = UsageMetrics()

+        # Preserve existing workflow_token_metrics if it exists (has per_task data)
+        if hasattr(self, 'workflow_token_metrics') and self.workflow_token_metrics:
+            workflow_metrics = self.workflow_token_metrics
+        else:
+            workflow_metrics = WorkflowTokenMetrics()
+
+        # Build per-agent metrics from per-task data (more accurate)
+        # This avoids the cumulative token issue where all agents show the same total
+        # Key by agent_id to handle multiple agents with the same role
+        agent_token_sums = {}
+        agent_info_map = {}  # Map agent_id to (agent_name, agent_id)
+
+        # First, build a map of all agents by their ID
         for agent in self.agents:
+            agent_role = getattr(agent, 'role', 'Unknown Agent')
+            agent_id = str(getattr(agent, 'id', ''))
+            agent_info_map[agent_id] = (agent_role, agent_id)
+
+        if workflow_metrics.per_task:
+            # Sum up tokens for each agent from their tasks
+            # We need to find which agent_id corresponds to each task's agent_name
+            for task_metrics in workflow_metrics.per_task.values():
+                agent_name = task_metrics.agent_name
+                # Find the agent_id for this agent_name from agent_info_map
+                # For now, we'll use the agent_name as a temporary key but this needs improvement
+                # TODO: Store agent_id in TaskTokenMetrics to avoid this lookup
+                matching_agent_ids = [aid for aid, (name, _) in agent_info_map.items() if name == agent_name]
+
+                # Use the first matching agent_id (limitation: can't distinguish between same-role agents)
+                # This is better than nothing but ideally we'd store agent_id in TaskTokenMetrics
+                for agent_id in matching_agent_ids:
+                    if agent_id not in agent_token_sums:
+                        agent_token_sums[agent_id] = {
+                            'total_tokens': 0,
+                            'prompt_tokens': 0,
+                            'cached_prompt_tokens': 0,
+                            'completion_tokens': 0,
+                            'successful_requests': 0
+                        }
+                    # Only add to the first matching agent (this is the limitation)
+                    agent_token_sums[agent_id]['total_tokens'] += task_metrics.total_tokens
+                    agent_token_sums[agent_id]['prompt_tokens'] += task_metrics.prompt_tokens
+                    agent_token_sums[agent_id]['cached_prompt_tokens'] += task_metrics.cached_prompt_tokens
+                    agent_token_sums[agent_id]['completion_tokens'] += task_metrics.completion_tokens
+                    agent_token_sums[agent_id]['successful_requests'] += task_metrics.successful_requests
+                    break  # Only add to first matching agent
+
+        # Create per-agent metrics from the summed task data, keyed by agent_id
+        for agent in self.agents:
+            agent_role = getattr(agent, 'role', 'Unknown Agent')
+            agent_id = str(getattr(agent, 'id', ''))
+
+            if agent_id in agent_token_sums:
+                # Use accurate per-task summed data
+                sums = agent_token_sums[agent_id]
+                agent_metrics = AgentTokenMetrics(
+                    agent_name=agent_role,
+                    agent_id=agent_id,
+                    total_tokens=sums['total_tokens'],
+                    prompt_tokens=sums['prompt_tokens'],
+                    cached_prompt_tokens=sums['cached_prompt_tokens'],
+                    completion_tokens=sums['completion_tokens'],
+                    successful_requests=sums['successful_requests']
+                )
+                # Key by agent_id to avoid collision for agents with same role
+                workflow_metrics.per_agent[agent_id] = agent_metrics
+
+            # Still get total usage for overall metrics
             if isinstance(agent.llm, BaseLLM):
                 llm_usage = agent.llm.get_token_usage_summary()

                 total_usage_metrics.add_usage_metrics(llm_usage)
             else:
                 # fallback litellm
@@ -1629,22 +1795,65 @@ class Crew(FlowTrackable, BaseModel):
                 token_sum = agent._token_process.get_summary()
                 total_usage_metrics.add_usage_metrics(token_sum)

-        if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
-            token_sum = self.manager_agent._token_process.get_summary()
-            total_usage_metrics.add_usage_metrics(token_sum)
-
-        if (
-            self.manager_agent
-            and hasattr(self.manager_agent, "llm")
-            and hasattr(self.manager_agent.llm, "get_token_usage_summary")
-        ):
-            if isinstance(self.manager_agent.llm, BaseLLM):
-                llm_usage = self.manager_agent.llm.get_token_usage_summary()
-            else:
-                llm_usage = self.manager_agent.llm._token_process.get_summary()
-
-            total_usage_metrics.add_usage_metrics(llm_usage)
+        if self.manager_agent:
+            manager_role = getattr(self.manager_agent, 'role', 'Manager Agent')
+            manager_id = str(getattr(self.manager_agent, 'id', ''))
+
+            if hasattr(self.manager_agent, "_token_process"):
+                token_sum = self.manager_agent._token_process.get_summary()
+                total_usage_metrics.add_usage_metrics(token_sum)
+
+                # Create per-agent metrics for manager
+                manager_metrics = AgentTokenMetrics(
+                    agent_name=manager_role,
+                    agent_id=manager_id,
+                    total_tokens=token_sum.total_tokens,
+                    prompt_tokens=token_sum.prompt_tokens,
+                    cached_prompt_tokens=token_sum.cached_prompt_tokens,
+                    completion_tokens=token_sum.completion_tokens,
+                    successful_requests=token_sum.successful_requests
+                )
+                workflow_metrics.per_agent[manager_role] = manager_metrics
+
+            if (
+                hasattr(self.manager_agent, "llm")
+                and hasattr(self.manager_agent.llm, "get_token_usage_summary")
+            ):
+                if isinstance(self.manager_agent.llm, BaseLLM):
+                    llm_usage = self.manager_agent.llm.get_token_usage_summary()
+                else:
+                    llm_usage = self.manager_agent.llm._token_process.get_summary()
+
+                total_usage_metrics.add_usage_metrics(llm_usage)
+
+                # Update or create manager metrics
+                if manager_role in workflow_metrics.per_agent:
+                    workflow_metrics.per_agent[manager_role].total_tokens += llm_usage.total_tokens
+                    workflow_metrics.per_agent[manager_role].prompt_tokens += llm_usage.prompt_tokens
+                    workflow_metrics.per_agent[manager_role].cached_prompt_tokens += llm_usage.cached_prompt_tokens
+                    workflow_metrics.per_agent[manager_role].completion_tokens += llm_usage.completion_tokens
+                    workflow_metrics.per_agent[manager_role].successful_requests += llm_usage.successful_requests
+                else:
+                    manager_metrics = AgentTokenMetrics(
+                        agent_name=manager_role,
+                        agent_id=manager_id,
+                        total_tokens=llm_usage.total_tokens,
+                        prompt_tokens=llm_usage.prompt_tokens,
+                        cached_prompt_tokens=llm_usage.cached_prompt_tokens,
+                        completion_tokens=llm_usage.completion_tokens,
+                        successful_requests=llm_usage.successful_requests
+                    )
+                    workflow_metrics.per_agent[manager_role] = manager_metrics
+
+        # Set workflow-level totals
+        workflow_metrics.total_tokens = total_usage_metrics.total_tokens
+        workflow_metrics.prompt_tokens = total_usage_metrics.prompt_tokens
+        workflow_metrics.cached_prompt_tokens = total_usage_metrics.cached_prompt_tokens
+        workflow_metrics.completion_tokens = total_usage_metrics.completion_tokens
+        workflow_metrics.successful_requests = total_usage_metrics.successful_requests
+
+        # Store workflow metrics (preserving per_task data)
+        self.workflow_token_metrics = workflow_metrics
+
         self.usage_metrics = total_usage_metrics
         return total_usage_metrics
```
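A minimal sketch of the rollup performed above, with plain dicts standing in for the pydantic models. Two agents sharing the role "Researcher" illustrate the limitation the comments call out: task metrics carry only agent_name, so every task credits the first agent with a matching role (the keys `a1`/`t1` and the token counts are made up for illustration):

```python
per_task = {
    "t1": {"agent_name": "Researcher", "total_tokens": 100},
    "t2": {"agent_name": "Researcher", "total_tokens": 50},
    "t3": {"agent_name": "Writer", "total_tokens": 30},
}
agent_info = {"a1": "Researcher", "a2": "Researcher", "a3": "Writer"}

sums: dict[str, int] = {}
for metrics in per_task.values():
    matching = [aid for aid, role in agent_info.items() if role == metrics["agent_name"]]
    for aid in matching:
        sums[aid] = sums.get(aid, 0) + metrics["total_tokens"]
        break  # only the first matching agent is credited

assert sums == {"a1": 150, "a3": 30}  # a2 gets nothing despite sharing the role
```

This is why the TODO in the diff suggests storing agent_id in TaskTokenMetrics: with the id on each task record, the lookup and the same-role ambiguity both disappear.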
```diff
@@ -1918,3 +2127,56 @@ To enable tracing, do any one of these:
             padding=(1, 2),
         )
         console.print(panel)
+
+
+    def _get_agent_token_usage(self, agent: BaseAgent | None) -> UsageMetrics:
+        """Get current token usage for an agent."""
+        if not agent:
+            return UsageMetrics()
+
+        if isinstance(agent.llm, BaseLLM):
+            return agent.llm.get_token_usage_summary()
+        if hasattr(agent, "_token_process"):
+            return agent._token_process.get_summary()
+
+        return UsageMetrics()
+
+    def _attach_task_token_metrics(
+        self,
+        task_output: TaskOutput,
+        task: Task,
+        agent: BaseAgent | None,
+        tokens_before: UsageMetrics,
+        tokens_after: UsageMetrics
+    ) -> TaskOutput:
+        """Attach per-task token metrics to the task output."""
+        from crewai.types.usage_metrics import TaskTokenMetrics
+
+        if not agent:
+            return task_output
+
+        # Calculate the delta (tokens used by this specific task)
+        task_tokens = TaskTokenMetrics(
+            task_name=getattr(task, 'name', None) or task.description[:50],
+            task_id=str(getattr(task, 'id', '')),
+            agent_name=getattr(agent, 'role', 'Unknown Agent'),
+            total_tokens=tokens_after.total_tokens - tokens_before.total_tokens,
+            prompt_tokens=tokens_after.prompt_tokens - tokens_before.prompt_tokens,
+            cached_prompt_tokens=tokens_after.cached_prompt_tokens - tokens_before.cached_prompt_tokens,
+            completion_tokens=tokens_after.completion_tokens - tokens_before.completion_tokens,
+            successful_requests=tokens_after.successful_requests - tokens_before.successful_requests
+        )
+
+        # Attach to task output
+        task_output.usage_metrics = task_tokens
+
+        # Store in workflow metrics
+        if not hasattr(self, 'workflow_token_metrics') or self.workflow_token_metrics is None:
+            from crewai.types.usage_metrics import WorkflowTokenMetrics
+            self.workflow_token_metrics = WorkflowTokenMetrics()
+
+        # Use task_id in the key to prevent collision when multiple tasks have the same name
+        task_key = f"{task_tokens.task_id}_{task_tokens.task_name}_{task_tokens.agent_name}"
+        self.workflow_token_metrics.per_task[task_key] = task_tokens
+
+        return task_output
```
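A worked example of the snapshot delta that `_attach_task_token_metrics` computes, using a stand-in dataclass instead of crewAI's UsageMetrics (the numbers mirror the mocked values used in the tests below):

```python
from dataclasses import dataclass


@dataclass
class Usage:
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0


before = Usage(total_tokens=150, prompt_tokens=120, completion_tokens=30, successful_requests=2)
after = Usage(total_tokens=200, prompt_tokens=160, completion_tokens=40, successful_requests=3)

# The task is charged only for what happened between the two snapshots.
delta = Usage(
    total_tokens=after.total_tokens - before.total_tokens,                      # 50
    prompt_tokens=after.prompt_tokens - before.prompt_tokens,                   # 40
    completion_tokens=after.completion_tokens - before.completion_tokens,       # 10
    successful_requests=after.successful_requests - before.successful_requests, # 1
)
assert delta == Usage(50, 40, 10, 1)
```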
Changes to the CrewOutput model:

```diff
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field

 from crewai.tasks.output_format import OutputFormat
 from crewai.tasks.task_output import TaskOutput
-from crewai.types.usage_metrics import UsageMetrics
+from crewai.types.usage_metrics import UsageMetrics, WorkflowTokenMetrics


 class CrewOutput(BaseModel):
@@ -26,6 +26,10 @@ class CrewOutput(BaseModel):
     token_usage: UsageMetrics = Field(
         description="Processed token summary", default_factory=UsageMetrics
     )
+    token_metrics: WorkflowTokenMetrics | None = Field(
+        description="Detailed per-agent and per-task token metrics",
+        default=None
+    )

     @property
     def json(self) -> str | None:  # type: ignore[override]
```
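A hypothetical usage sketch for the new field, assuming this branch is installed (the `CrewOutput` import path and its field defaults are assumptions, and `token_metrics` does not exist on released crewAI versions; it stays None unless per-task tracking populated it):

```python
from crewai.crews.crew_output import CrewOutput  # assumed import path
from crewai.types.usage_metrics import (
    AgentTokenMetrics,
    TaskTokenMetrics,
    WorkflowTokenMetrics,
)

# Build a CrewOutput by hand instead of via crew.kickoff(), purely to show the shape.
metrics = WorkflowTokenMetrics(
    total_tokens=180,
    per_agent={"agent-1": AgentTokenMetrics(agent_name="Researcher", total_tokens=180)},
    per_task={
        "task-1_research_Researcher": TaskTokenMetrics(
            task_name="research", agent_name="Researcher", total_tokens=180
        )
    },
)
output = CrewOutput(raw="final answer", token_metrics=metrics)

# token_usage keeps the existing aggregate summary; token_metrics adds the breakdown.
if output.token_metrics is not None:
    for agent_id, agent_metrics in output.token_metrics.per_agent.items():
        print(agent_id, agent_metrics.total_tokens)
    for task_key, task_metrics in output.token_metrics.per_task.items():
        print(task_key, task_metrics.total_tokens)
```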
Changes to the Task class:

```diff
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from collections.abc import Callable
 from concurrent.futures import Future
 from copy import copy as shallow_copy
 import datetime
@@ -476,13 +477,34 @@ class Task(BaseModel):
         agent: BaseAgent | None = None,
         context: str | None = None,
         tools: list[BaseTool] | None = None,
-    ) -> Future[TaskOutput]:
-        """Execute the task asynchronously."""
-        future: Future[TaskOutput] = Future()
+        token_capture_callback: Callable[[], Any] | None = None,
+        agent_execution_lock: threading.Lock | None = None,
+    ) -> Future[TaskOutput | tuple[TaskOutput, Any, Any]]:
+        """Execute the task asynchronously.
+
+        Args:
+            agent: The agent to execute the task.
+            context: Context for the task execution.
+            tools: Tools available for the task.
+            token_capture_callback: Optional callback to capture token usage.
+                If provided, the future will return a tuple of
+                (TaskOutput, tokens_before, tokens_after) instead of just TaskOutput.
+                The callback is called twice: once before task execution (after
+                acquiring the lock if one is provided) and once after task completion.
+            agent_execution_lock: Optional lock to serialize task execution for
+                the same agent. This is used to ensure accurate per-task token
+                tracking when multiple async tasks from the same agent run
+                concurrently.
+
+        Returns:
+            Future containing TaskOutput, or tuple of (TaskOutput, tokens_before, tokens_after)
+            if token_capture_callback is provided.
+        """
+        future: Future[TaskOutput | tuple[TaskOutput, Any, Any]] = Future()
         threading.Thread(
             daemon=True,
             target=self._execute_task_async,
-            args=(agent, context, tools, future),
+            args=(agent, context, tools, future, token_capture_callback, agent_execution_lock),
         ).start()
         return future

@@ -491,14 +513,45 @@ class Task(BaseModel):
         agent: BaseAgent | None,
         context: str | None,
         tools: list[Any] | None,
-        future: Future[TaskOutput],
+        future: Future[TaskOutput | tuple[TaskOutput, Any, Any]],
+        token_capture_callback: Callable[[], Any] | None = None,
+        agent_execution_lock: threading.Lock | None = None,
     ) -> None:
-        """Execute the task asynchronously with context handling."""
+        """Execute the task asynchronously with context handling.
+
+        If agent_execution_lock is provided, the task execution will be
+        serialized with other tasks using the same lock. This ensures
+        accurate per-task token tracking by:
+        1. Capturing tokens_before after acquiring the lock
+        2. Executing the task
+        3. Capturing tokens_after immediately after completion
+        4. Releasing the lock
+
+        If token_capture_callback is provided, it will be called twice:
+        once before task execution and once after, both while holding the lock.
+        """
         try:
-            result = self._execute_core(agent, context, tools)
-            future.set_result(result)
+            if agent_execution_lock:
+                with agent_execution_lock:
+                    if token_capture_callback:
+                        tokens_before = token_capture_callback()
+                    result = self._execute_core(agent, context, tools)
+                    if token_capture_callback:
+                        tokens_after = token_capture_callback()
+                        future.set_result((result, tokens_before, tokens_after))
+                    else:
+                        future.set_result(result)
+            else:
+                if token_capture_callback:
+                    tokens_before = token_capture_callback()
+                result = self._execute_core(agent, context, tools)
+                if token_capture_callback:
+                    tokens_after = token_capture_callback()
+                    future.set_result((result, tokens_before, tokens_after))
+                else:
+                    future.set_result(result)
         except Exception as e:
             future.set_exception(e)

     async def aexecute_sync(
         self,
```
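A minimal sketch (plain Python, not the crewAI API) of the control flow in `_execute_task_async` above: when a capture callback is supplied, the worker snapshots usage before and after the core call, both while holding the lock, and resolves the future with a 3-tuple; otherwise it resolves with the bare result. `get_usage` and `execute_core` are stand-ins:

```python
import threading
from concurrent.futures import Future

usage = {"tokens": 0}


def get_usage() -> int:
    return usage["tokens"]


def execute_core() -> str:
    usage["tokens"] += 42
    return "done"


def worker(future: Future, capture=None, lock: threading.Lock | None = None) -> None:
    try:
        if lock:
            with lock:                       # serialize with same-agent tasks
                if capture:
                    before = capture()       # snapshot while holding the lock
                result = execute_core()
                if capture:
                    future.set_result((result, before, capture()))
                else:
                    future.set_result(result)
        else:
            result = execute_core()
            future.set_result(result)
    except Exception as e:                   # mirror the except branch in the diff
        future.set_exception(e)


f: Future = Future()
threading.Thread(target=worker, args=(f, get_usage, threading.Lock()), daemon=True).start()
result, before, after = f.result(timeout=5)
assert (result, after - before) == ("done", 42)
```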
Changes to the TaskOutput model:

```diff
@@ -6,6 +6,7 @@ from typing import Any
 from pydantic import BaseModel, Field, model_validator

 from crewai.tasks.output_format import OutputFormat
+from crewai.types.usage_metrics import TaskTokenMetrics
 from crewai.utilities.types import LLMMessage


@@ -22,6 +23,7 @@ class TaskOutput(BaseModel):
         json_dict: JSON dictionary output of the task
         agent: Agent that executed the task
         output_format: Output format of the task (JSON, PYDANTIC, or RAW)
+        usage_metrics: Token usage metrics for this specific task
     """

     description: str = Field(description="Description of the task")
@@ -42,6 +44,10 @@ class TaskOutput(BaseModel):
         description="Output format of the task", default=OutputFormat.RAW
     )
     messages: list[LLMMessage] = Field(description="Messages of the task", default=[])
+    usage_metrics: TaskTokenMetrics | None = Field(
+        description="Token usage metrics for this task",
+        default=None
+    )

     @model_validator(mode="after")
     def set_summary(self):
```
Changes to the usage-metrics models:

```diff
@@ -44,3 +44,74 @@ class UsageMetrics(BaseModel):
         self.cached_prompt_tokens += usage_metrics.cached_prompt_tokens
         self.completion_tokens += usage_metrics.completion_tokens
         self.successful_requests += usage_metrics.successful_requests
+
+
+class AgentTokenMetrics(BaseModel):
+    """Token usage metrics for a specific agent.
+
+    Attributes:
+        agent_name: Name/role of the agent
+        agent_id: Unique identifier for the agent
+        total_tokens: Total tokens used by this agent
+        prompt_tokens: Prompt tokens used by this agent
+        completion_tokens: Completion tokens used by this agent
+        successful_requests: Number of successful LLM requests
+    """
+
+    agent_name: str = Field(description="Name/role of the agent")
+    agent_id: str | None = Field(default=None, description="Unique identifier for the agent")
+    total_tokens: int = Field(default=0, description="Total tokens used by this agent")
+    prompt_tokens: int = Field(default=0, description="Prompt tokens used by this agent")
+    cached_prompt_tokens: int = Field(default=0, description="Cached prompt tokens used by this agent")
+    completion_tokens: int = Field(default=0, description="Completion tokens used by this agent")
+    successful_requests: int = Field(default=0, description="Number of successful LLM requests")
+
+
+class TaskTokenMetrics(BaseModel):
+    """Token usage metrics for a specific task.
+
+    Attributes:
+        task_name: Name of the task
+        task_id: Unique identifier for the task
+        agent_name: Name of the agent that executed the task
+        total_tokens: Total tokens used for this task
+        prompt_tokens: Prompt tokens used for this task
+        completion_tokens: Completion tokens used for this task
+        successful_requests: Number of successful LLM requests
+    """
+
+    task_name: str = Field(description="Name of the task")
+    task_id: str | None = Field(default=None, description="Unique identifier for the task")
+    agent_name: str = Field(description="Name of the agent that executed the task")
+    total_tokens: int = Field(default=0, description="Total tokens used for this task")
+    prompt_tokens: int = Field(default=0, description="Prompt tokens used for this task")
+    cached_prompt_tokens: int = Field(default=0, description="Cached prompt tokens used for this task")
+    completion_tokens: int = Field(default=0, description="Completion tokens used for this task")
+    successful_requests: int = Field(default=0, description="Number of successful LLM requests")
+
+
+class WorkflowTokenMetrics(BaseModel):
+    """Complete token usage metrics for a crew workflow.
+
+    Attributes:
+        total_tokens: Total tokens used across entire workflow
+        prompt_tokens: Total prompt tokens used
+        completion_tokens: Total completion tokens used
+        successful_requests: Total successful requests
+        per_agent: Dictionary mapping agent names to their token metrics
+        per_task: Dictionary mapping task names to their token metrics
+    """
+
+    total_tokens: int = Field(default=0, description="Total tokens used across entire workflow")
+    prompt_tokens: int = Field(default=0, description="Total prompt tokens used")
+    cached_prompt_tokens: int = Field(default=0, description="Total cached prompt tokens used")
+    completion_tokens: int = Field(default=0, description="Total completion tokens used")
+    successful_requests: int = Field(default=0, description="Total successful requests")
+    per_agent: dict[str, AgentTokenMetrics] = Field(
+        default_factory=dict,
+        description="Token metrics per agent"
+    )
+    per_task: dict[str, TaskTokenMetrics] = Field(
+        default_factory=dict,
+        description="Token metrics per task"
+    )
```
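A self-contained pydantic sketch mirroring the nesting of the three models above (trimmed to two or three fields each) to show how per_agent and per_task hang off the workflow-level model; this is an illustration, not the crewAI code:

```python
from pydantic import BaseModel, Field


class AgentTokens(BaseModel):
    agent_name: str
    total_tokens: int = 0


class TaskTokens(BaseModel):
    task_name: str
    agent_name: str
    total_tokens: int = 0


class WorkflowTokens(BaseModel):
    total_tokens: int = 0
    per_agent: dict[str, AgentTokens] = Field(default_factory=dict)
    per_task: dict[str, TaskTokens] = Field(default_factory=dict)


wf = WorkflowTokens(total_tokens=180)
wf.per_agent["a1"] = AgentTokens(agent_name="Researcher", total_tokens=180)
wf.per_task["t1"] = TaskTokens(task_name="research", agent_name="Researcher", total_tokens=180)

# Per-agent totals should reconcile with the per-task records they summarize.
assert wf.per_agent["a1"].total_tokens == sum(t.total_tokens for t in wf.per_task.values())
```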
New tests:

```diff
@@ -4768,3 +4768,220 @@ def test_ensure_exchanged_messages_are_propagated_to_external_memory():
     assert "Researcher" in messages[0]["content"]
     assert messages[1]["role"] == "user"
     assert "Research a topic to teach a kid aged 6 about math" in messages[1]["content"]
+
+
+def test_async_task_token_tracking_uses_per_agent_lock():
+    """Test that async tasks from the same agent use per-agent locks for accurate token tracking.
+
+    This test verifies the fix for the race condition described in issue #4168:
+    When multiple tasks with async_execution=True are executed by the same agent,
+    the per-agent lock ensures that token tracking is accurate by serializing
+    task execution and capturing tokens_before/tokens_after inside the thread.
+    """
+    from crewai.types.usage_metrics import TaskTokenMetrics
+
+    agent = Agent(
+        role="Researcher",
+        goal="Research topics",
+        backstory="You are a researcher",
+        allow_delegation=False,
+    )
+
+    task1 = Task(
+        description="Research topic 1",
+        expected_output="Research output 1",
+        agent=agent,
+        async_execution=True,
+    )
+
+    task2 = Task(
+        description="Research topic 2",
+        expected_output="Research output 2",
+        agent=agent,
+        async_execution=True,
+    )
+
+    task3 = Task(
+        description="Summarize research",
+        expected_output="Summary",
+        agent=agent,
+        async_execution=False,
+    )
+
+    crew = Crew(agents=[agent], tasks=[task1, task2, task3])
+
+    mock_output = TaskOutput(
+        description="Test output",
+        raw="Test result",
+        agent="Researcher",
+    )
+
+    execution_order = []
+    lock_acquisitions = []
+
+    original_execute_core = Task._execute_core
+
+    def mock_execute_core(self, agent, context, tools):
+        execution_order.append(self.description)
+        return mock_output
+
+    with patch.object(Task, "_execute_core", mock_execute_core):
+        with patch.object(
+            crew,
+            "_get_agent_token_usage",
+            side_effect=[
+                UsageMetrics(total_tokens=100, prompt_tokens=80, completion_tokens=20, successful_requests=1),
+                UsageMetrics(total_tokens=150, prompt_tokens=120, completion_tokens=30, successful_requests=2),
+                UsageMetrics(total_tokens=150, prompt_tokens=120, completion_tokens=30, successful_requests=2),
+                UsageMetrics(total_tokens=200, prompt_tokens=160, completion_tokens=40, successful_requests=3),
+                UsageMetrics(total_tokens=200, prompt_tokens=160, completion_tokens=40, successful_requests=3),
+                UsageMetrics(total_tokens=250, prompt_tokens=200, completion_tokens=50, successful_requests=4),
+            ]
+        ):
+            result = crew.kickoff()
+
+    assert len(result.tasks_output) == 3
+
+    for task_output in result.tasks_output:
+        if hasattr(task_output, 'usage_metrics') and task_output.usage_metrics:
+            assert isinstance(task_output.usage_metrics, TaskTokenMetrics)
+
+
+def test_async_task_token_callback_captures_tokens_inside_thread():
+    """Test that token capture callback is called inside the thread for async tasks.
+
+    This verifies that tokens_before and tokens_after are captured inside the thread
+    (after acquiring the lock), not when the task is queued.
+    """
+    from concurrent.futures import Future
+    import time
+
+    agent = Agent(
+        role="Researcher",
+        goal="Research topics",
+        backstory="You are a researcher",
+        allow_delegation=False,
+    )
+
+    task = Task(
+        description="Research topic",
+        expected_output="Research output",
+        agent=agent,
+    )
+
+    callback_call_times = []
+    callback_thread_ids = []
+    main_thread_id = threading.current_thread().ident
+
+    def token_callback():
+        callback_call_times.append(time.time())
+        callback_thread_ids.append(threading.current_thread().ident)
+        return UsageMetrics(total_tokens=100, prompt_tokens=80, completion_tokens=20, successful_requests=1)
+
+    mock_output = TaskOutput(
+        description="Test output",
+        raw="Test result",
+        agent="Researcher",
+    )
+
+    with patch.object(Task, "_execute_core", return_value=mock_output):
+        lock = threading.Lock()
+        future = task.execute_async(
+            agent=agent,
+            context=None,
+            tools=None,
+            token_capture_callback=token_callback,
+            agent_execution_lock=lock,
+        )
+
+        result = future.result(timeout=10)
+
+    assert isinstance(result, tuple)
+    assert len(result) == 3
+    task_output, tokens_before, tokens_after = result
+
+    assert len(callback_call_times) == 2
+    assert len(callback_thread_ids) == 2
+
+    for thread_id in callback_thread_ids:
+        assert thread_id != main_thread_id
+
+    assert callback_thread_ids[0] == callback_thread_ids[1]
+
+
+def test_async_task_per_agent_lock_serializes_execution():
+    """Test that per-agent lock serializes async task execution for the same agent.
+
+    This test verifies that when multiple async tasks from the same agent are executed,
+    the per-agent lock ensures they run one at a time (serialized), not concurrently.
+    """
+    import time
+
+    agent = Agent(
+        role="Researcher",
+        goal="Research topics",
+        backstory="You are a researcher",
+        allow_delegation=False,
+    )
+
+    task1 = Task(
+        description="Research topic 1",
+        expected_output="Research output 1",
+        agent=agent,
+    )
+
+    task2 = Task(
+        description="Research topic 2",
+        expected_output="Research output 2",
+        agent=agent,
+    )
+
+    execution_times = []
+
+    mock_output = TaskOutput(
+        description="Test output",
+        raw="Test result",
+        agent="Researcher",
+    )
+
+    def slow_execute_core(self, agent, context, tools):
+        start_time = time.time()
+        time.sleep(0.1)
+        end_time = time.time()
+        execution_times.append((start_time, end_time))
+        return mock_output
+
+    with patch.object(Task, "_execute_core", slow_execute_core):
+        lock = threading.Lock()
+
+        def token_callback():
+            return UsageMetrics(total_tokens=100, prompt_tokens=80, completion_tokens=20, successful_requests=1)
+
+        future1 = task1.execute_async(
+            agent=agent,
+            context=None,
+            tools=None,
+            token_capture_callback=token_callback,
+            agent_execution_lock=lock,
+        )
+
+        future2 = task2.execute_async(
+            agent=agent,
+            context=None,
+            tools=None,
+            token_capture_callback=token_callback,
+            agent_execution_lock=lock,
+        )
+
+        result1 = future1.result(timeout=10)
+        result2 = future2.result(timeout=10)
+
+    assert len(execution_times) == 2
+
+    start1, end1 = execution_times[0]
+    start2, end2 = execution_times[1]
+
+    if start1 < start2:
+        assert end1 <= start2, "Tasks should not overlap when using the same lock"
+    else:
+        assert end2 <= start1, "Tasks should not overlap when using the same lock"
```