fix: keep usage listener on pause, guard cross-kickoff handler leakage

Two related races in flow.usage_metrics aggregation:

1. Paused kickoff dropped its listener. When kickoff_async returned a
   HumanFeedbackPending, its `finally` block detached the listener, even
   though the bus dispatches LLM event handlers on a thread pool that
   `emit` does not wait for. The tokens of any pre-pause LLM call whose
   handler future was still queued were silently lost.

   Fix: track a paused_for_feedback flag in kickoff_async and skip the
   detach when set. The listener stays attached on the instance so late
   events continue to accumulate.

2. Stale handlers from a prior kickoff could bleed into a later one.
   The handler closure captured flow_ref and wrote into
   flow._aggregated_usage_metrics. If a handler from kickoff #1 was
   still queued when kickoff #2 reset the accumulator, it would
   contaminate the new run's totals.

   Fix: snapshot a per-flow _usage_epoch in the handler closure at
   attach time. kickoff_async bumps the epoch before attaching, so any
   in-flight handler from a prior kickoff sees its stale snapshot and
   bails out.
This commit is contained in:
Lucas Gomide
2026-06-11 14:29:08 -03:00
parent 8565713a1a
commit a64b41dd42
2 changed files with 134 additions and 13 deletions

View File

@@ -941,6 +941,12 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
_usage_metrics_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_flow_match_id: str | None = PrivateAttr(default=None)
_usage_aggregation_handler: Callable[..., Any] | None = PrivateAttr(default=None)
# Incremented on every kickoff that takes ownership of usage aggregation.
# The listener closure snapshots the epoch at attach time; a stale
# handler still queued in the bus thread pool from a prior kickoff
# compares its snapshot against the current value and bails out so it
# cannot contaminate a later kickoff's accumulator.
_usage_epoch: int = PrivateAttr(default=0)
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override]
class _FlowGeneric(cls): # type: ignore[valid-type,misc]
@@ -1011,8 +1017,16 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
return
flow_ref = self
captured_epoch = self._usage_epoch
def _accumulate(source: Any, event: LLMCallCompletedEvent) -> None:
# Stale-handler guard: the bus dispatches sync handlers on a
# thread pool that `emit` does not wait on, so a handler from
# a prior kickoff can still be queued when a later kickoff
# bumps the epoch and resets the accumulator. Bail out so we
# don't leak prior-run usage into the new accumulator.
if captured_epoch != flow_ref._usage_epoch:
return
if current_flow_id.get() != flow_ref._flow_match_id:
return
metrics = _usage_dict_to_metrics(event.usage)
@@ -1436,18 +1450,18 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
"No pending feedback context. Use from_pending() to restore a paused flow."
)
# Re-attach the usage aggregation listener for the resume phase so
# LLM calls during outcome collapsing and downstream crews continue
# to roll up into `flow.usage_metrics`. The listener was detached
# when `kickoff_async` returned on pause. We also restore
# `current_flow_id` to the original kickoff's match id so the handler
# filter passes.
owns_usage_aggregation = self._usage_aggregation_handler is None
# Wire usage aggregation for the resume phase. Two cases:
# 1. We inherited an attached listener from a `kickoff_async`
# that paused — keep counting into the same accumulator.
# 2. The instance came from `from_pending` (fresh) — attach
# a new listener.
# In both cases we restore `current_flow_id` so the handler's
# filter passes for LLM calls made during outcome collapsing and
# downstream listener execution.
flow_id_token = None
if owns_usage_aggregation:
if current_flow_id.get() is None and self._flow_match_id is not None:
flow_id_token = current_flow_id.set(self._flow_match_id)
self._attach_usage_aggregation_listener()
if current_flow_id.get() is None and self._flow_match_id is not None:
flow_id_token = current_flow_id.set(self._flow_match_id)
self._attach_usage_aggregation_listener()
try:
if get_current_parent_id() is None:
@@ -1650,7 +1664,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
return final_result
finally:
if owns_usage_aggregation:
# If we re-paused for human feedback, leave the listener
# attached so the next `resume_async` can take over.
# Otherwise (completion or unexpected error), release it.
if self._pending_feedback_context is None:
self._detach_usage_aggregation_listener()
if flow_id_token is not None:
current_flow_id.reset(flow_id_token)
@@ -2168,8 +2185,19 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
# gets overwritten later by `inputs["id"]`.
self._flow_match_id = current_flow_id.get()
self._aggregated_usage_metrics = UsageMetrics()
# Bump the epoch BEFORE attaching so any in-flight handler from
# a prior kickoff queued in the bus thread pool sees its stale
# snapshot and bails out instead of writing into the fresh
# accumulator.
self._usage_epoch += 1
self._attach_usage_aggregation_listener()
# Flips in the `HumanFeedbackPending` branch so `finally` keeps the
# listener attached. Late LLM events during the pause window and
# the subsequent `resume_async` call continue to accumulate into
# this run's `flow.usage_metrics`.
paused_for_feedback = False
try:
# Reset flow state for fresh execution unless restoring from persistence
is_restoring = (
@@ -2354,6 +2382,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
from crewai.flow.async_feedback.types import HumanFeedbackPending
if isinstance(e, HumanFeedbackPending):
paused_for_feedback = True
# Auto-save pending feedback (create default persistence if needed)
if self.persistence is None:
from crewai.flow.persistence.factory import (
@@ -2459,7 +2488,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
# Ensure all background memory saves complete before returning
if self.memory is not None and hasattr(self.memory, "drain_writes"):
self.memory.drain_writes()
if owns_usage_aggregation:
# On pause keep the listener attached so events during the
# pause-to-resume window still count and `resume_async` can
# take over the same accumulator. Otherwise (completion or
# unexpected error) release it.
if owns_usage_aggregation and not paused_for_feedback:
self._detach_usage_aggregation_listener()
if request_id_token is not None:
current_flow_request_id.reset(request_id_token)

View File

@@ -255,6 +255,94 @@ class TestFlowUsageAggregation:
assert handler_count() == before
def test_stale_handler_from_prior_kickoff_does_not_contaminate(self) -> None:
    """A handler left over from an earlier kickoff must be a no-op later.

    The bus dispatches sync handlers on a thread pool that ``emit`` does
    not wait on, so a handler from kickoff #1 may still run after
    kickoff #2 has started. The epoch snapshot held by the handler
    closure must make it bail out instead of writing into the later
    kickoff's accumulator.
    """
    first_run: dict[str, Any] = {}

    def script(flow: Flow) -> None:
        _emit_llm_call(flow_id=flow._flow_match_id, prompt_tokens=10, completion_tokens=10)
        # Keep the first kickoff's handler and match id so the handler can
        # be replayed by hand once a later kickoff owns the accumulator.
        first_run["handler"] = flow._usage_aggregation_handler
        first_run["match_id"] = flow._flow_match_id

    flow = _run(script)
    assert flow.usage_metrics.total_tokens == 20

    # A second kickoff bumps the epoch and resets the accumulator.
    flow._script = lambda f: None
    flow.kickoff()
    assert flow.usage_metrics.total_tokens == 0

    leftover_handler = first_run["handler"]
    assert leftover_handler is not None

    late_event = LLMCallCompletedEvent(
        call_id=str(uuid4()),
        model="gpt-4o-mini",
        response="ok",
        call_type=LLMCallType.LLM_CALL,
        usage={"prompt_tokens": 999, "completion_tokens": 999, "total_tokens": 1998},
    )

    def _replay_under_first_match_id() -> None:
        # Restore the first kickoff's flow id so the handler's flow-id
        # filter would pass; only the epoch guard is left to reject it.
        current_flow_id.set(first_run["match_id"])
        leftover_handler(object(), late_event)

    # Run in a copied context so the flow-id override does not leak out.
    contextvars.copy_context().run(_replay_under_first_match_id)

    # Stale handler bailed: second kickoff's accumulator is still zero.
    assert flow.usage_metrics.total_tokens == 0
def test_listener_persists_after_pause(self) -> None:
    """The usage listener must survive a pause for human feedback.

    When ``kickoff_async`` pauses for human feedback, the listener has to
    stay attached so late LLM events (queued in the bus thread pool by
    pre-pause LLM calls that emit but don't wait on their handler future)
    still count for this run. Otherwise the pause's ``finally`` would
    detach the listener and silently drop them.
    """
    from crewai.flow.async_feedback.types import HumanFeedbackPending

    captured: dict[str, Any] = {}

    class _PausingFlow(Flow):
        @start()
        def begin(self) -> None:
            _emit_llm_call(
                flow_id=self._flow_match_id,
                prompt_tokens=10,
                completion_tokens=20,
            )
            captured["pre_pause_total"] = self.usage_metrics.total_tokens
            raise HumanFeedbackPending(
                context=PendingFeedbackContext(
                    flow_id=self.flow_id,
                    flow_class="_PausingFlow",
                    method_name="begin",
                    method_output="content",
                    message="Review:",
                )
            )

    with tempfile.TemporaryDirectory() as tmpdir:
        persistence = SQLiteFlowPersistence(os.path.join(tmpdir, "f.db"))
        flow = _PausingFlow(persistence=persistence)
        try:
            result = flow.kickoff()
            assert isinstance(result, HumanFeedbackPending)
            assert captured["pre_pause_total"] == 30
            assert flow._usage_aggregation_handler is not None

            # Simulate a late LLM event arriving after the pause — without
            # the keep-on-pause fix this would be dropped silently.
            _emit_llm_call(
                flow_id=flow._flow_match_id,
                prompt_tokens=2,
                completion_tokens=3,
            )
            assert flow.usage_metrics.total_tokens == 35
        finally:
            # The paused flow deliberately keeps its listener attached, so
            # release it even when an assertion above fails; otherwise the
            # handler would stay attached after this test ends.
            if flow._usage_aggregation_handler is not None:
                flow._detach_usage_aggregation_listener()
def test_aggregates_resume_after_from_pending(self) -> None:
"""A flow restored via ``from_pending`` is a fresh instance with no
``_flow_match_id``; without seeding it, the listener attached in