mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 21:28:10 +00:00
* feat: aggregate LLM token usage at the flow level Introduces `flow.usage_metrics`, a snapshot of every LLMCallCompletedEvent emitted under the flow's `current_flow_id` for the duration of one kickoff (or resume) call. Aggregation happens on the singleton event bus so it covers crews, direct `LLM.call`s, and nested listener calls — solving the mismatch where the SDK reported only the last crew's usage while the Enterprise UI showed the correct full total. Co-authored-by: Cursor <cursoragent@cursor.com> * refactor: centralize provider key normalization in UsageMetrics Add UsageMetrics.from_provider_dict to normalize raw LLM usage dicts across providers (LiteLLM, native Anthropic, native Gemini, OpenAI nested cached). BaseLLM._track_token_usage_internal and the flow-level aggregator now share this single source of truth, so `flow.usage_metrics` agrees with per-LLM totals on every provider — including the native Anthropic path that emits `input_tokens`/`output_tokens` instead of `prompt_tokens`/`completion_tokens`. * fix: flush event bus before reading aggregated usage_metrics `crewai_event_bus.emit` dispatches LLMCallCompletedEvent handlers on a ThreadPoolExecutor (fire-and-forget), so a flow whose last LLM call completes right before kickoff_async/resume_async returns can detach the usage listener while that handler is still queued, leaving its tokens off `flow.usage_metrics`. Match `Crew.kickoff()` and call `crewai_event_bus.flush()` in both finally blocks so every handler drains before the listener is detached. --------- Co-authored-by: Cursor <cursoragent@cursor.com>
512 lines
18 KiB
Python
512 lines
18 KiB
Python
"""Tests for flow-level token usage aggregation
|
|
|
|
``flow.usage_metrics`` listens to ``LLMCallCompletedEvent`` for the duration
|
|
of ``kickoff_async`` so it covers every LLM call inside the flow — crew-led,
|
|
tool-led, AND bare ``LLM.call(...)`` from a flow method. We exercise the
|
|
aggregator end-to-end through the real event bus with fabricated events and
|
|
explicit contextvar control; no live LLM provider is required.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextvars
|
|
import os
|
|
import tempfile
|
|
from typing import Any, Callable
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from crewai.events.event_bus import crewai_event_bus
|
|
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallType
|
|
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
|
from crewai.flow.flow import Flow, listen, start
|
|
from crewai.flow.flow_context import current_flow_id
|
|
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
|
from crewai.flow.runtime import _usage_dict_to_metrics
|
|
from crewai.types.usage_metrics import UsageMetrics
|
|
|
|
|
|
def _emit_llm_call(
    *,
    flow_id: str | None,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
    cached_prompt_tokens: int = 0,
    reasoning_tokens: int = 0,
    cache_creation_tokens: int = 0,
) -> None:
    """Publish one synthetic ``LLMCallCompletedEvent`` with ``current_flow_id``
    set to ``flow_id``.

    The event's ``usage`` dict always carries ``prompt_tokens``,
    ``completion_tokens`` and their sum as ``total_tokens``. The extended
    counters (``cached_prompt_tokens``, ``reasoning_tokens``,
    ``cache_creation_tokens``) are only included when they are non-zero.

    The emit happens inside a copy of the current context, so setting
    ``current_flow_id`` here does not leak into the caller's context and the
    value the bus sees at emit time is exactly ``flow_id``. If the bus hands
    back a future, we block on it (up to five seconds) so the handlers have
    run by the time this helper returns.
    """
    extended_counts = {
        "cached_prompt_tokens": cached_prompt_tokens,
        "reasoning_tokens": reasoning_tokens,
        "cache_creation_tokens": cache_creation_tokens,
    }
    payload: dict[str, Any] = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
        # Zero-valued extended counters are left out of the payload entirely.
        **{name: count for name, count in extended_counts.items() if count},
    }
    completed = LLMCallCompletedEvent(
        call_id=str(uuid4()),
        model="gpt-4o-mini",
        response="ok",
        call_type=LLMCallType.LLM_CALL,
        usage=payload,
    )

    def _publish() -> None:
        current_flow_id.set(flow_id)
        pending = crewai_event_bus.emit(object(), completed)
        if pending is None:
            return
        pending.result(timeout=5.0)

    contextvars.copy_context().run(_publish)
|
|
|
|
|
|
class _ScriptedFlow(Flow):
    """Flow whose single ``@start`` step hands control to ``self._script``.

    Tests assign a callable to ``flow._script`` (it receives the flow) instead
    of declaring a new ``Flow`` subclass per scenario. When no script has been
    attached the start step does nothing.
    """

    @start()
    def run(self) -> None:
        def _noop(_flow: Flow) -> None:
            return None

        action: Callable[[Flow], None] = getattr(self, "_script", _noop)
        action(self)
|
|
|
|
|
|
def _run(script: Callable[[Flow], None] = lambda _f: None) -> Flow:
    """Create a ``_ScriptedFlow``, install ``script`` on it, and kick it off.

    Returns the flow after ``kickoff()`` so callers can inspect it. The
    default script is a no-op.
    """
    scripted = _ScriptedFlow()
    scripted._script = script
    scripted.kickoff()
    return scripted
|
|
|
|
|
|
class TestUsageDictToMetrics:
    """Unit tests for ``_usage_dict_to_metrics``, the dict-to-``UsageMetrics``
    normalizer."""

    @pytest.mark.parametrize(
        ("usage", "expected"),
        [
            pytest.param(None, None, id="none"),
            pytest.param({}, None, id="empty"),
            pytest.param(
                {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
                UsageMetrics(
                    prompt_tokens=10,
                    completion_tokens=20,
                    total_tokens=30,
                    successful_requests=1,
                ),
                id="all_keys",
            ),
            # total_tokens missing → derived from prompt + completion
            pytest.param(
                {"prompt_tokens": 4, "completion_tokens": 6},
                UsageMetrics(
                    prompt_tokens=4,
                    completion_tokens=6,
                    total_tokens=10,
                    successful_requests=1,
                ),
                id="no_total",
            ),
            # Extended provider-specific keys flow through normalization
            pytest.param(
                {
                    "prompt_tokens": 100,
                    "completion_tokens": 80,
                    "total_tokens": 180,
                    "cached_prompt_tokens": 40,
                    "reasoning_tokens": 25,
                    "cache_creation_tokens": 10,
                },
                UsageMetrics(
                    prompt_tokens=100,
                    completion_tokens=80,
                    total_tokens=180,
                    cached_prompt_tokens=40,
                    reasoning_tokens=25,
                    cache_creation_tokens=10,
                    successful_requests=1,
                ),
                id="extended_keys",
            ),
            # Garbage / non-int values coerce to 0 instead of crashing
            pytest.param(
                {"prompt_tokens": "n/a", "completion_tokens": None, "total_tokens": 7},
                UsageMetrics(
                    prompt_tokens=0,
                    completion_tokens=0,
                    total_tokens=0,
                    successful_requests=1,
                ),
                id="garbage",
            ),
            # Native Anthropic provider emits input_tokens/output_tokens
            pytest.param(
                {"input_tokens": 12, "output_tokens": 8},
                UsageMetrics(
                    prompt_tokens=12,
                    completion_tokens=8,
                    total_tokens=20,
                    successful_requests=1,
                ),
                id="anthropic_aliases",
            ),
            # Native Gemini provider emits prompt_token_count/candidates_token_count
            pytest.param(
                {
                    "prompt_token_count": 30,
                    "candidates_token_count": 20,
                    "reasoning_tokens": 5,
                },
                UsageMetrics(
                    prompt_tokens=30,
                    completion_tokens=20,
                    total_tokens=50,
                    reasoning_tokens=5,
                    successful_requests=1,
                ),
                id="gemini_aliases",
            ),
            # OpenAI nests cached_tokens under prompt_tokens_details
            pytest.param(
                {
                    "prompt_tokens": 100,
                    "completion_tokens": 50,
                    "prompt_tokens_details": {"cached_tokens": 30},
                },
                UsageMetrics(
                    prompt_tokens=100,
                    completion_tokens=50,
                    total_tokens=150,
                    cached_prompt_tokens=30,
                    successful_requests=1,
                ),
                id="openai_nested_cached",
            ),
        ],
    )
    def test_normalization(
        self, usage: dict[str, Any] | None, expected: UsageMetrics | None
    ) -> None:
        """Each raw usage payload must normalize to exactly ``expected``."""
        normalized = _usage_dict_to_metrics(usage)
        assert normalized == expected
|
|
|
|
|
|
class TestFlowUsageAggregation:
    """End-to-end tests that drive the usage listener through the real event
    bus using fabricated ``LLMCallCompletedEvent`` instances."""

    @staticmethod
    def _pending_persistence(tmpdir: str, flow_id: str) -> SQLiteFlowPersistence:
        """Create a SQLite store at ``<tmpdir>/f.db`` and save one
        pending-feedback record for ``flow_id`` whose context names the
        ``begin`` method of ``_ResumeFlow``."""
        store = SQLiteFlowPersistence(os.path.join(tmpdir, "f.db"))
        pending = PendingFeedbackContext(
            flow_id=flow_id,
            flow_class="_ResumeFlow",
            method_name="begin",
            method_output="content",
            message="Review:",
        )
        store.save_pending_feedback(
            flow_uuid=flow_id,
            context=pending,
            state_data={"id": flow_id},
        )
        return store

    def test_sums_every_llm_call_in_the_flow(self) -> None:
        """Every LLM call emitted under the flow's id is accumulated —
        including bare ``LLM.call(...)`` made outside any crew — and
        ``successful_requests`` equals the number of calls."""
        # (prompt_tokens, completion_tokens) for each fabricated call.
        call_sizes = ((300, 300), (200, 100), (20, 20))

        def script(flow: Flow) -> None:
            for prompt, completion in call_sizes:
                _emit_llm_call(
                    flow_id=flow._flow_match_id,
                    prompt_tokens=prompt,
                    completion_tokens=completion,
                )

        metrics = _run(script).usage_metrics

        assert metrics.total_tokens == 940
        assert metrics.prompt_tokens == 520
        assert metrics.completion_tokens == 420
        assert metrics.successful_requests == 3

    def test_returns_zero_when_no_calls_happen(self) -> None:
        """A flow that emits no LLM events reports default ``UsageMetrics``."""
        assert _run().usage_metrics == UsageMetrics()

    def test_ignores_events_from_other_flows(self) -> None:
        """The bus is a singleton shared by concurrent flow runs, so the
        listener must only count events whose contextvar matches its own
        flow."""

        def script(flow: Flow) -> None:
            _emit_llm_call(
                flow_id=flow._flow_match_id,
                prompt_tokens=50,
                completion_tokens=50,
            )
            _emit_llm_call(
                flow_id="some-other-flow",
                prompt_tokens=49_000,
                completion_tokens=50_999,
            )

        metrics = _run(script).usage_metrics

        assert metrics.total_tokens == 100
        assert metrics.successful_requests == 1

    def test_resets_between_kickoffs(self) -> None:
        """A second ``kickoff`` starts from zero rather than adding to the
        first run's totals."""

        def emit_once(flow: Flow) -> None:
            _emit_llm_call(
                flow_id=flow._flow_match_id,
                prompt_tokens=250,
                completion_tokens=250,
            )

        flow = _ScriptedFlow()
        flow._script = emit_once

        for _ in range(2):
            flow.kickoff()

        metrics = flow.usage_metrics
        assert metrics.total_tokens == 500
        assert metrics.successful_requests == 1

    def test_usage_metrics_returns_independent_copy(self) -> None:
        """``usage_metrics`` has to hand out a copy; returning the internal
        instance would let callers overwrite the live accumulator."""

        def emit_once(flow: Flow) -> None:
            _emit_llm_call(
                flow_id=flow._flow_match_id,
                prompt_tokens=50,
                completion_tokens=50,
            )

        flow = _run(emit_once)

        tampered = flow.usage_metrics
        tampered.total_tokens = 999_999

        # A fresh read must not see the mutation made on the earlier snapshot.
        assert flow.usage_metrics.total_tokens == 100

    def test_handler_is_unregistered_after_kickoff(self) -> None:
        """Long-lived workers (Celery, devkit) must not accumulate one handler
        per kickoff on the singleton bus — whether the kickoff succeeds or
        raises."""

        def registered() -> int:
            handlers = crewai_event_bus._sync_handlers.get(
                LLMCallCompletedEvent, frozenset()
            )
            return len(handlers)

        baseline = registered()

        def emit_small(flow: Flow) -> None:
            _emit_llm_call(
                flow_id=flow._flow_match_id,
                prompt_tokens=5,
                completion_tokens=5,
            )

        healthy = _ScriptedFlow()
        healthy._script = emit_small
        for _ in range(3):
            healthy.kickoff()

        assert registered() == baseline

        def boom(_f: Flow) -> None:
            raise RuntimeError("boom")

        failing = _ScriptedFlow()
        failing._script = boom

        with pytest.raises(RuntimeError, match="boom"):
            failing.kickoff()

        assert registered() == baseline

    def test_kickoff_flushes_event_bus_before_returning(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """``kickoff_async`` must drain pending LLMCallCompletedEvent handlers
        before detaching the listener; otherwise handlers still queued on the
        threadpool would be lost on short flows. This mirrors the flush
        ``Crew.kickoff()`` performs before reporting ``token_usage``."""
        flush_log: list[None] = []
        real_flush = crewai_event_bus.flush

        def recording_flush(*args: Any, **kwargs: Any) -> bool:
            # Record the call, then defer to the genuine flush.
            flush_log.append(None)
            return real_flush(*args, **kwargs)

        monkeypatch.setattr(crewai_event_bus, "flush", recording_flush)

        def emit_once(flow: Flow) -> None:
            _emit_llm_call(
                flow_id=flow._flow_match_id,
                prompt_tokens=3,
                completion_tokens=4,
            )

        flow = _ScriptedFlow()
        flow._script = emit_once
        flow.kickoff()

        assert flush_log, "kickoff did not flush the event bus before returning"
        assert flow.usage_metrics.total_tokens == 7

    def test_stale_handler_from_prior_kickoff_does_not_contaminate(self) -> None:
        """A handler left over from an earlier kickoff must not write into a
        later kickoff's accumulator. Each handler closes over its own
        accumulator, so a late invocation lands on an orphaned object and the
        live ``usage_metrics`` stays untouched."""
        seen: dict[str, Any] = {}

        def first_run(flow: Flow) -> None:
            _emit_llm_call(
                flow_id=flow._flow_match_id,
                prompt_tokens=10,
                completion_tokens=10,
            )
            # Grab the first kickoff's handler and match id while they are live.
            seen["handler"] = flow._usage_aggregation_handler
            seen["match_id"] = flow._flow_match_id

        flow = _run(first_run)
        assert flow.usage_metrics.total_tokens == 20

        def second_run(_flow: Flow) -> None:
            return None

        flow._script = second_run
        flow.kickoff()
        assert flow.usage_metrics.total_tokens == 0

        stale_handler = seen["handler"]
        assert stale_handler is not None

        stale_event = LLMCallCompletedEvent(
            call_id=str(uuid4()),
            model="gpt-4o-mini",
            response="ok",
            call_type=LLMCallType.LLM_CALL,
            usage={
                "prompt_tokens": 999,
                "completion_tokens": 999,
                "total_tokens": 1998,
            },
        )

        def replay_stale() -> None:
            # Invoke the old handler directly, under the first run's match id.
            current_flow_id.set(seen["match_id"])
            stale_handler(object(), stale_event)

        contextvars.copy_context().run(replay_stale)

        assert flow.usage_metrics.total_tokens == 0

    def test_pause_detaches_listener_and_does_not_leak(self) -> None:
        """When ``kickoff_async`` pauses for human feedback the listener must
        come off the singleton bus, so abandoned paused instances do not leak
        handlers. LLM events emitted before the pause still count because the
        bus snapshots handlers at emit time. Events emitted after the pause
        returns are not counted for this instance — resume paths attach a
        fresh listener."""
        from crewai.flow.async_feedback.types import HumanFeedbackPending

        observed: dict[str, Any] = {}

        class _PausingFlow(Flow):
            @start()
            def begin(self) -> None:
                _emit_llm_call(
                    flow_id=self._flow_match_id,
                    prompt_tokens=10,
                    completion_tokens=20,
                )
                observed["pre_pause_total"] = self.usage_metrics.total_tokens
                feedback_context = PendingFeedbackContext(
                    flow_id=self.flow_id,
                    flow_class="_PausingFlow",
                    method_name="begin",
                    method_output="content",
                    message="Review:",
                )
                raise HumanFeedbackPending(context=feedback_context)

        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = os.path.join(tmpdir, "f.db")
            flow = _PausingFlow(persistence=SQLiteFlowPersistence(db_path))
            result = flow.kickoff()

            assert isinstance(result, HumanFeedbackPending)
            assert observed["pre_pause_total"] == 30
            assert flow._usage_aggregation_handler is None

            # An event emitted after the pause never reaches the detached
            # listener, so the total stays where it was.
            _emit_llm_call(
                flow_id=flow._flow_match_id,
                prompt_tokens=2,
                completion_tokens=3,
            )
            assert flow.usage_metrics.total_tokens == 30

    def test_aggregates_resume_after_from_pending(self) -> None:
        """A flow restored via ``from_pending`` is a brand-new instance with
        no ``_flow_match_id``. Unless ``from_pending`` seeds it, the listener
        attached in ``resume_async`` would either miss the flow's own LLM
        calls or pick up unrelated ones. With the id seeded, the resume-phase
        aggregator counts this flow's calls and nothing else."""

        class _ResumeFlow(Flow):
            @start()
            def begin(self) -> str:
                return "content"

            @listen(begin)
            def on_begin(self, _feedback: Any) -> str:
                _emit_llm_call(
                    flow_id=self._flow_match_id,
                    prompt_tokens=100,
                    completion_tokens=50,
                )
                _emit_llm_call(
                    flow_id="some-other-flow",
                    prompt_tokens=9_999,
                    completion_tokens=9_999,
                )
                return "done"

        with tempfile.TemporaryDirectory() as tmpdir:
            flow_id = "usage-resume-test"
            persistence = self._pending_persistence(tmpdir, flow_id)

            flow = _ResumeFlow.from_pending(flow_id, persistence)
            assert flow._flow_match_id == flow.flow_id

            flow.resume("ok")

            metrics = flow.usage_metrics
            assert metrics.total_tokens == 150
            assert metrics.prompt_tokens == 100
            assert metrics.completion_tokens == 50
            assert metrics.successful_requests == 1

    def test_resume_aggregates_under_foreign_flow_context(self) -> None:
        """Resume has to override a ``current_flow_id`` that is already set,
        so the flow's own LLM events pass the listener's filter even when
        resume is invoked inside another flow's active context."""

        class _ResumeFlow(Flow):
            @start()
            def begin(self) -> str:
                return "content"

            @listen(begin)
            def on_begin(self, _feedback: Any) -> str:
                _emit_llm_call(
                    flow_id=self._flow_match_id,
                    prompt_tokens=42,
                    completion_tokens=8,
                )
                return "done"

        with tempfile.TemporaryDirectory() as tmpdir:
            flow_id = "resume-foreign-context"
            persistence = self._pending_persistence(tmpdir, flow_id)

            # Simulate being called from inside some other flow's context.
            outer_token = current_flow_id.set("some-parent-flow")
            try:
                flow = _ResumeFlow.from_pending(flow_id, persistence)
                flow.resume("ok")
            finally:
                current_flow_id.reset(outer_token)

            metrics = flow.usage_metrics
            assert metrics.total_tokens == 50
            assert metrics.successful_requests == 1
|