mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 13:48:09 +00:00
fix: keep usage listener on pause, guard cross-kickoff handler leakage
Two related races in flow.usage_metrics aggregation: 1. Paused kickoff dropped its listener. When kickoff_async returned a HumanFeedbackPending, the finally detached the listener even though the bus dispatches LLM event handlers on a thread pool that emit does not wait on. Any pre-pause LLM call whose handler future was still queued would silently lose its tokens. Fix: track a paused_for_feedback flag in kickoff_async and skip the detach when set. The listener stays attached on the instance so late events continue to accumulate. 2. Stale handlers from a prior kickoff could bleed into a later one. The handler closure captured flow_ref and wrote into flow._aggregated_usage_metrics. If a handler from kickoff #1 was still queued when kickoff #2 reset the accumulator, it would contaminate the new run's totals. Fix: snapshot a per-flow _usage_epoch in the handler closure at attach time. kickoff_async bumps the epoch before attaching, so any in-flight handler from a prior kickoff sees its stale snapshot and bails out.
This commit is contained in:
@@ -941,6 +941,12 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
_usage_metrics_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
_flow_match_id: str | None = PrivateAttr(default=None)
|
||||
_usage_aggregation_handler: Callable[..., Any] | None = PrivateAttr(default=None)
|
||||
# Incremented on every kickoff that takes ownership of usage aggregation.
|
||||
# The listener closure snapshots the epoch at attach time; a stale
|
||||
# handler still queued in the bus thread pool from a prior kickoff
|
||||
# compares its snapshot against the current value and bails out so it
|
||||
# cannot contaminate a later kickoff's accumulator.
|
||||
_usage_epoch: int = PrivateAttr(default=0)
|
||||
|
||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override]
|
||||
class _FlowGeneric(cls): # type: ignore[valid-type,misc]
|
||||
@@ -1011,8 +1017,16 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
return
|
||||
|
||||
flow_ref = self
|
||||
captured_epoch = self._usage_epoch
|
||||
|
||||
def _accumulate(source: Any, event: LLMCallCompletedEvent) -> None:
|
||||
# Stale-handler guard: the bus dispatches sync handlers on a
|
||||
# thread pool that `emit` does not wait on, so a handler from
|
||||
# a prior kickoff can still be queued when a later kickoff
|
||||
# bumps the epoch and resets the accumulator. Bail out so we
|
||||
# don't leak prior-run usage into the new accumulator.
|
||||
if captured_epoch != flow_ref._usage_epoch:
|
||||
return
|
||||
if current_flow_id.get() != flow_ref._flow_match_id:
|
||||
return
|
||||
metrics = _usage_dict_to_metrics(event.usage)
|
||||
@@ -1436,18 +1450,18 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
"No pending feedback context. Use from_pending() to restore a paused flow."
|
||||
)
|
||||
|
||||
# Re-attach the usage aggregation listener for the resume phase so
|
||||
# LLM calls during outcome collapsing and downstream crews continue
|
||||
# to roll up into `flow.usage_metrics`. The listener was detached
|
||||
# when `kickoff_async` returned on pause. We also restore
|
||||
# `current_flow_id` to the original kickoff's match id so the handler
|
||||
# filter passes.
|
||||
owns_usage_aggregation = self._usage_aggregation_handler is None
|
||||
# Wire usage aggregation for the resume phase. Two cases:
|
||||
# 1. We inherited an attached listener from a `kickoff_async`
|
||||
# that paused — keep counting into the same accumulator.
|
||||
# 2. The instance came from `from_pending` (fresh) — attach
|
||||
# a new listener.
|
||||
# In both cases we restore `current_flow_id` so the handler's
|
||||
# filter passes for LLM calls made during outcome collapsing and
|
||||
# downstream listener execution.
|
||||
flow_id_token = None
|
||||
if owns_usage_aggregation:
|
||||
if current_flow_id.get() is None and self._flow_match_id is not None:
|
||||
flow_id_token = current_flow_id.set(self._flow_match_id)
|
||||
self._attach_usage_aggregation_listener()
|
||||
if current_flow_id.get() is None and self._flow_match_id is not None:
|
||||
flow_id_token = current_flow_id.set(self._flow_match_id)
|
||||
self._attach_usage_aggregation_listener()
|
||||
|
||||
try:
|
||||
if get_current_parent_id() is None:
|
||||
@@ -1650,7 +1664,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
return final_result
|
||||
finally:
|
||||
if owns_usage_aggregation:
|
||||
# If we re-paused for human feedback, leave the listener
|
||||
# attached so the next `resume_async` can take over.
|
||||
# Otherwise (completion or unexpected error), release it.
|
||||
if self._pending_feedback_context is None:
|
||||
self._detach_usage_aggregation_listener()
|
||||
if flow_id_token is not None:
|
||||
current_flow_id.reset(flow_id_token)
|
||||
@@ -2168,8 +2185,19 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
# gets overwritten later by `inputs["id"]`.
|
||||
self._flow_match_id = current_flow_id.get()
|
||||
self._aggregated_usage_metrics = UsageMetrics()
|
||||
# Bump the epoch BEFORE attaching so any in-flight handler from
|
||||
# a prior kickoff queued in the bus thread pool sees its stale
|
||||
# snapshot and bails out instead of writing into the fresh
|
||||
# accumulator.
|
||||
self._usage_epoch += 1
|
||||
self._attach_usage_aggregation_listener()
|
||||
|
||||
# Flips in the `HumanFeedbackPending` branch so `finally` keeps the
|
||||
# listener attached. Late LLM events during the pause window and
|
||||
# the subsequent `resume_async` call continue to accumulate into
|
||||
# this run's `flow.usage_metrics`.
|
||||
paused_for_feedback = False
|
||||
|
||||
try:
|
||||
# Reset flow state for fresh execution unless restoring from persistence
|
||||
is_restoring = (
|
||||
@@ -2354,6 +2382,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
if isinstance(e, HumanFeedbackPending):
|
||||
paused_for_feedback = True
|
||||
# Auto-save pending feedback (create default persistence if needed)
|
||||
if self.persistence is None:
|
||||
from crewai.flow.persistence.factory import (
|
||||
@@ -2459,7 +2488,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
# Ensure all background memory saves complete before returning
|
||||
if self.memory is not None and hasattr(self.memory, "drain_writes"):
|
||||
self.memory.drain_writes()
|
||||
if owns_usage_aggregation:
|
||||
# On pause keep the listener attached so events during the
|
||||
# pause-to-resume window still count and `resume_async` can
|
||||
# take over the same accumulator. Otherwise (completion or
|
||||
# unexpected error) release it.
|
||||
if owns_usage_aggregation and not paused_for_feedback:
|
||||
self._detach_usage_aggregation_listener()
|
||||
if request_id_token is not None:
|
||||
current_flow_request_id.reset(request_id_token)
|
||||
|
||||
@@ -255,6 +255,94 @@ class TestFlowUsageAggregation:
|
||||
|
||||
assert handler_count() == before
|
||||
|
||||
def test_stale_handler_from_prior_kickoff_does_not_contaminate(self) -> None:
|
||||
"""The bus dispatches sync handlers on a thread pool that ``emit``
|
||||
does not wait on. A handler still queued from a prior kickoff
|
||||
must not write into a later kickoff's accumulator — the epoch
|
||||
snapshot in the handler closure bails out on mismatch."""
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def script(flow: Flow) -> None:
|
||||
_emit_llm_call(flow_id=flow._flow_match_id, prompt_tokens=10, completion_tokens=10)
|
||||
captured["handler"] = flow._usage_aggregation_handler
|
||||
captured["match_id"] = flow._flow_match_id
|
||||
|
||||
flow = _run(script)
|
||||
first_total = flow.usage_metrics.total_tokens
|
||||
assert first_total == 20
|
||||
|
||||
# A second kickoff bumps the epoch and resets the accumulator.
|
||||
flow._script = lambda f: None
|
||||
flow.kickoff()
|
||||
assert flow.usage_metrics.total_tokens == 0
|
||||
|
||||
stale_handler = captured["handler"]
|
||||
assert stale_handler is not None
|
||||
|
||||
stale_event = LLMCallCompletedEvent(
|
||||
call_id=str(uuid4()),
|
||||
model="gpt-4o-mini",
|
||||
response="ok",
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
usage={"prompt_tokens": 999, "completion_tokens": 999, "total_tokens": 1998},
|
||||
)
|
||||
ctx = contextvars.copy_context()
|
||||
ctx.run(lambda: (current_flow_id.set(captured["match_id"]), stale_handler(object(), stale_event)))
|
||||
|
||||
# Stale handler bailed: second kickoff's accumulator is still zero.
|
||||
assert flow.usage_metrics.total_tokens == 0
|
||||
|
||||
def test_listener_persists_after_pause(self) -> None:
|
||||
"""When ``kickoff_async`` pauses for human feedback, the listener
|
||||
must stay attached so late LLM events (queued in the bus thread
|
||||
pool by pre-pause LLM calls that emit but don't wait on their
|
||||
handler future) still count for this run. Otherwise the pause's
|
||||
``finally`` would detach the listener and silently drop them."""
|
||||
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackPending
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
class _PausingFlow(Flow):
|
||||
@start()
|
||||
def begin(self) -> None:
|
||||
_emit_llm_call(
|
||||
flow_id=self._flow_match_id,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=20,
|
||||
)
|
||||
captured["pre_pause_total"] = self.usage_metrics.total_tokens
|
||||
raise HumanFeedbackPending(
|
||||
context=PendingFeedbackContext(
|
||||
flow_id=self.flow_id,
|
||||
flow_class="_PausingFlow",
|
||||
method_name="begin",
|
||||
method_output="content",
|
||||
message="Review:",
|
||||
)
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
persistence = SQLiteFlowPersistence(os.path.join(tmpdir, "f.db"))
|
||||
flow = _PausingFlow(persistence=persistence)
|
||||
result = flow.kickoff()
|
||||
|
||||
assert isinstance(result, HumanFeedbackPending)
|
||||
assert captured["pre_pause_total"] == 30
|
||||
assert flow._usage_aggregation_handler is not None
|
||||
|
||||
# Simulate a late LLM event arriving after the pause — without
|
||||
# the keep-on-pause fix this would be dropped silently.
|
||||
_emit_llm_call(
|
||||
flow_id=flow._flow_match_id,
|
||||
prompt_tokens=2,
|
||||
completion_tokens=3,
|
||||
)
|
||||
assert flow.usage_metrics.total_tokens == 35
|
||||
|
||||
flow._detach_usage_aggregation_listener()
|
||||
|
||||
def test_aggregates_resume_after_from_pending(self) -> None:
|
||||
"""A flow restored via ``from_pending`` is a fresh instance with no
|
||||
``_flow_match_id``; without seeding it, the listener attached in
|
||||
|
||||
Reference in New Issue
Block a user