fix: address PR review on flow.usage_metrics

This commit is contained in:
Lucas Gomide
2026-06-11 16:39:07 -03:00
parent b720139eca
commit c48501ae38
2 changed files with 208 additions and 16 deletions

View File

@@ -941,6 +941,12 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
_usage_metrics_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_flow_match_id: str | None = PrivateAttr(default=None)
_usage_aggregation_handler: Callable[..., Any] | None = PrivateAttr(default=None)
# Incremented on every kickoff that takes ownership of usage aggregation.
# The listener closure snapshots the epoch at attach time; a stale
# handler still queued in the bus thread pool from a prior kickoff
# compares its snapshot against the current value and bails out so it
# cannot contaminate a later kickoff's accumulator.
_usage_epoch: int = PrivateAttr(default=0)
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override]
class _FlowGeneric(cls): # type: ignore[valid-type,misc]
@@ -1011,8 +1017,16 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
return
flow_ref = self
captured_epoch = self._usage_epoch
def _accumulate(source: Any, event: LLMCallCompletedEvent) -> None:
# Stale-handler guard: the bus dispatches sync handlers on a
# thread pool that `emit` does not wait on, so a handler from
# a prior kickoff can still be queued when a later kickoff
# bumps the epoch and resets the accumulator. Bail out so we
# don't leak prior-run usage into the new accumulator.
if captured_epoch != flow_ref._usage_epoch:
return
if current_flow_id.get() != flow_ref._flow_match_id:
return
metrics = _usage_dict_to_metrics(event.usage)
@@ -1033,6 +1047,19 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
@property
def usage_metrics(self) -> UsageMetrics:
"""Aggregated LLM token usage for the most recent kickoff (or
resume) of this flow instance.
Aggregation is correlated by the ``current_flow_id`` contextvar
captured at kickoff time. Nested kickoffs (a parent flow calling
a child flow's ``kickoff``) intentionally roll the child's
tokens up into the parent because the contextvar is inherited.
Sibling kickoffs that run in parallel under the same parent
contextvar share the same correlation id and may therefore
over-count each other; if you need strict per-flow isolation
in that pattern, run the children in separate tasks that
explicitly set their own ``current_flow_id`` before kickoff.
"""
with self._usage_metrics_lock:
return self._aggregated_usage_metrics.model_copy()
@@ -1330,6 +1357,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
instance._initialize_state(state_data)
instance._pending_feedback_context = pending_context
instance._is_execution_resuming = True
# Seed the usage-aggregation match id so `resume_async` can wire its
# listener and restore `current_flow_id` correctly. Without this,
# a restored flow has a None match id and the handler would either
# ignore its own LLM calls or absorb unrelated ones from sibling
# flows. The accumulator itself starts at zero — any usage from
# before the pause was only observable on the original kickoff
# instance.
instance._flow_match_id = instance.flow_id
return instance
@@ -1428,18 +1463,17 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
"No pending feedback context. Use from_pending() to restore a paused flow."
)
# Re-attach the usage aggregation listener for the resume phase so
# LLM calls during outcome collapsing and downstream crews continue
# to roll up into `flow.usage_metrics`. The listener was detached
# when `kickoff_async` returned on pause. We also restore
# `current_flow_id` to the original kickoff's match id so the handler
# filter passes.
owns_usage_aggregation = self._usage_aggregation_handler is None
# Wire usage aggregation for the resume phase. The previous
# `kickoff_async`/`resume_async` already detached its listener,
# and `_attach_usage_aggregation_listener` is idempotent, so we
# can always call it here. We also restore `current_flow_id`
# when missing so the handler's filter passes for LLM calls
# made during outcome collapsing and downstream listener
# execution.
flow_id_token = None
if owns_usage_aggregation:
if current_flow_id.get() is None and self._flow_match_id is not None:
flow_id_token = current_flow_id.set(self._flow_match_id)
self._attach_usage_aggregation_listener()
if current_flow_id.get() is None and self._flow_match_id is not None:
flow_id_token = current_flow_id.set(self._flow_match_id)
self._attach_usage_aggregation_listener()
try:
if get_current_parent_id() is None:
@@ -1642,10 +1676,15 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
return final_result
finally:
if owns_usage_aggregation:
self._detach_usage_aggregation_listener()
if flow_id_token is not None:
current_flow_id.reset(flow_id_token)
# Always release the singleton-bus listener at the end of
# this resume call. In-flight handlers the bus already
# snapshotted continue updating `_aggregated_usage_metrics`;
# only future emits stop reaching this flow. A subsequent
# `resume_async` on the same instance re-attaches a fresh
# listener.
self._detach_usage_aggregation_listener()
if flow_id_token is not None:
current_flow_id.reset(flow_id_token)
def _create_initial_state(self) -> T:
"""Create and initialize flow state with UUID and default values.
@@ -2160,6 +2199,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
# gets overwritten later by `inputs["id"]`.
self._flow_match_id = current_flow_id.get()
self._aggregated_usage_metrics = UsageMetrics()
# Bump the epoch BEFORE attaching so any in-flight handler from
# a prior kickoff queued in the bus thread pool sees its stale
# snapshot and bails out instead of writing into the fresh
# accumulator.
self._usage_epoch += 1
self._attach_usage_aggregation_listener()
try:
@@ -2451,6 +2495,12 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
# Ensure all background memory saves complete before returning
if self.memory is not None and hasattr(self.memory, "drain_writes"):
self.memory.drain_writes()
# Detach the singleton-bus listener as soon as this kickoff
# finishes (whether by completion, pause, or error). Handlers
# the bus already snapshotted via `_sync_executor.submit` keep
# running and updating `_aggregated_usage_metrics`; only
# subsequent emits stop reaching this flow. Resume paths
# re-attach a fresh listener via `resume_async`.
if owns_usage_aggregation:
self._detach_usage_aggregation_listener()
if request_id_token is not None:

View File

@@ -10,6 +10,8 @@ explicit contextvar control; no live LLM provider is required.
from __future__ import annotations
import contextvars
import os
import tempfile
from typing import Any, Callable
from uuid import uuid4
@@ -17,8 +19,10 @@ import pytest
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallType
from crewai.flow.flow import Flow, start
from crewai.flow.async_feedback.types import PendingFeedbackContext
from crewai.flow.flow import Flow, listen, start
from crewai.flow.flow_context import current_flow_id
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
from crewai.flow.runtime import _usage_dict_to_metrics
from crewai.types.usage_metrics import UsageMetrics
@@ -250,3 +254,141 @@ class TestFlowUsageAggregation:
failing.kickoff()
assert handler_count() == before
def test_stale_handler_from_prior_kickoff_does_not_contaminate(self) -> None:
"""The bus dispatches sync handlers on a thread pool that ``emit``
does not wait on. A handler still queued from a prior kickoff
must not write into a later kickoff's accumulator — the epoch
snapshot in the handler closure bails out on mismatch."""
captured: dict[str, Any] = {}
def script(flow: Flow) -> None:
_emit_llm_call(flow_id=flow._flow_match_id, prompt_tokens=10, completion_tokens=10)
captured["handler"] = flow._usage_aggregation_handler
captured["match_id"] = flow._flow_match_id
flow = _run(script)
first_total = flow.usage_metrics.total_tokens
assert first_total == 20
# A second kickoff bumps the epoch and resets the accumulator.
flow._script = lambda f: None
flow.kickoff()
assert flow.usage_metrics.total_tokens == 0
stale_handler = captured["handler"]
assert stale_handler is not None
stale_event = LLMCallCompletedEvent(
call_id=str(uuid4()),
model="gpt-4o-mini",
response="ok",
call_type=LLMCallType.LLM_CALL,
usage={"prompt_tokens": 999, "completion_tokens": 999, "total_tokens": 1998},
)
ctx = contextvars.copy_context()
ctx.run(lambda: (current_flow_id.set(captured["match_id"]), stale_handler(object(), stale_event)))
# Stale handler bailed: second kickoff's accumulator is still zero.
assert flow.usage_metrics.total_tokens == 0
def test_pause_detaches_listener_and_does_not_leak(self) -> None:
"""When ``kickoff_async`` pauses for human feedback, the listener
must be detached from the singleton bus to avoid leaking handlers
across abandoned paused instances. Pre-pause LLM events still
count because the bus snapshots handlers at emit time. Late
events emitted after the pause returns do not count for this
instance — resume paths re-attach a fresh listener."""
from crewai.flow.async_feedback.types import HumanFeedbackPending
captured: dict[str, Any] = {}
class _PausingFlow(Flow):
@start()
def begin(self) -> None:
_emit_llm_call(
flow_id=self._flow_match_id,
prompt_tokens=10,
completion_tokens=20,
)
captured["pre_pause_total"] = self.usage_metrics.total_tokens
raise HumanFeedbackPending(
context=PendingFeedbackContext(
flow_id=self.flow_id,
flow_class="_PausingFlow",
method_name="begin",
method_output="content",
message="Review:",
)
)
with tempfile.TemporaryDirectory() as tmpdir:
persistence = SQLiteFlowPersistence(os.path.join(tmpdir, "f.db"))
flow = _PausingFlow(persistence=persistence)
result = flow.kickoff()
assert isinstance(result, HumanFeedbackPending)
assert captured["pre_pause_total"] == 30
assert flow._usage_aggregation_handler is None
# A late event emitted after the pause does not reach the
# detached listener, so the running total is unchanged.
_emit_llm_call(
flow_id=flow._flow_match_id,
prompt_tokens=2,
completion_tokens=3,
)
assert flow.usage_metrics.total_tokens == 30
def test_aggregates_resume_after_from_pending(self) -> None:
"""A flow restored via ``from_pending`` is a fresh instance with no
``_flow_match_id``; without seeding it, the listener attached in
``resume_async`` either ignores its own LLM calls or absorbs unrelated
ones. ``from_pending`` must seed the match id so the resume-phase
aggregator counts our own calls and only our own calls."""
class _ResumeFlow(Flow):
@start()
def begin(self) -> str:
return "content"
@listen(begin)
def on_begin(self, _feedback: Any) -> str:
_emit_llm_call(
flow_id=self._flow_match_id,
prompt_tokens=100,
completion_tokens=50,
)
_emit_llm_call(
flow_id="some-other-flow",
prompt_tokens=9_999,
completion_tokens=9_999,
)
return "done"
with tempfile.TemporaryDirectory() as tmpdir:
persistence = SQLiteFlowPersistence(os.path.join(tmpdir, "f.db"))
flow_id = "usage-resume-test"
persistence.save_pending_feedback(
flow_uuid=flow_id,
context=PendingFeedbackContext(
flow_id=flow_id,
flow_class="_ResumeFlow",
method_name="begin",
method_output="content",
message="Review:",
),
state_data={"id": flow_id},
)
flow = _ResumeFlow.from_pending(flow_id, persistence)
assert flow._flow_match_id == flow.flow_id
flow.resume("ok")
assert flow.usage_metrics.total_tokens == 150
assert flow.usage_metrics.prompt_tokens == 100
assert flow.usage_metrics.completion_tokens == 50
assert flow.usage_metrics.successful_requests == 1