diff --git a/lib/crewai/src/crewai/flow/runtime.py b/lib/crewai/src/crewai/flow/runtime.py index 559484a1d..92b78d12d 100644 --- a/lib/crewai/src/crewai/flow/runtime.py +++ b/lib/crewai/src/crewai/flow/runtime.py @@ -941,6 +941,12 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): _usage_metrics_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) _flow_match_id: str | None = PrivateAttr(default=None) _usage_aggregation_handler: Callable[..., Any] | None = PrivateAttr(default=None) + # Incremented on every kickoff that takes ownership of usage aggregation. + # The listener closure snapshots the epoch at attach time; a stale + # handler still queued in the bus thread pool from a prior kickoff + # compares its snapshot against the current value and bails out so it + # cannot contaminate a later kickoff's accumulator. + _usage_epoch: int = PrivateAttr(default=0) def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override] class _FlowGeneric(cls): # type: ignore[valid-type,misc] @@ -1011,8 +1017,16 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): return flow_ref = self + captured_epoch = self._usage_epoch def _accumulate(source: Any, event: LLMCallCompletedEvent) -> None: + # Stale-handler guard: the bus dispatches sync handlers on a + # thread pool that `emit` does not wait on, so a handler from + # a prior kickoff can still be queued when a later kickoff + # bumps the epoch and resets the accumulator. Bail out so we + # don't leak prior-run usage into the new accumulator. + if captured_epoch != flow_ref._usage_epoch: + return if current_flow_id.get() != flow_ref._flow_match_id: return metrics = _usage_dict_to_metrics(event.usage) @@ -1033,6 +1047,19 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): @property def usage_metrics(self) -> UsageMetrics: + """Aggregated LLM token usage for the most recent kickoff (or + resume) of this flow instance. + + Aggregation is correlated by the ``current_flow_id`` contextvar + captured at kickoff time. Nested kickoffs (a parent flow calling + a child flow's ``kickoff``) intentionally roll the child's + tokens up into the parent because the contextvar is inherited. + Sibling kickoffs that run in parallel under the same parent + contextvar share the same correlation id and may therefore + over-count each other; if you need strict per-flow isolation + in that pattern, run the children in separate tasks that + explicitly set their own ``current_flow_id`` before kickoff. + """ with self._usage_metrics_lock: return self._aggregated_usage_metrics.model_copy() @@ -1330,6 +1357,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): instance._initialize_state(state_data) instance._pending_feedback_context = pending_context instance._is_execution_resuming = True + # Seed the usage-aggregation match id so `resume_async` can wire its + # listener and restore `current_flow_id` correctly. Without this, + # a restored flow has a None match id and the handler would either + # ignore its own LLM calls or absorb unrelated ones from sibling + # flows. The accumulator itself starts at zero — any usage from + # before the pause was only observable on the original kickoff + # instance. + instance._flow_match_id = instance.flow_id return instance @@ -1428,18 +1463,17 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): "No pending feedback context. Use from_pending() to restore a paused flow." ) - # Re-attach the usage aggregation listener for the resume phase so - # LLM calls during outcome collapsing and downstream crews continue - # to roll up into `flow.usage_metrics`. The listener was detached - # when `kickoff_async` returned on pause. We also restore - # `current_flow_id` to the original kickoff's match id so the handler - # filter passes. - owns_usage_aggregation = self._usage_aggregation_handler is None + # Wire usage aggregation for the resume phase. The previous + # `kickoff_async`/`resume_async` already detached its listener, + # and `_attach_usage_aggregation_listener` is idempotent, so we + # can always call it here. We also restore `current_flow_id` + # when missing so the handler's filter passes for LLM calls + # made during outcome collapsing and downstream listener + # execution. flow_id_token = None - if owns_usage_aggregation: - if current_flow_id.get() is None and self._flow_match_id is not None: - flow_id_token = current_flow_id.set(self._flow_match_id) - self._attach_usage_aggregation_listener() + if current_flow_id.get() is None and self._flow_match_id is not None: + flow_id_token = current_flow_id.set(self._flow_match_id) + self._attach_usage_aggregation_listener() try: if get_current_parent_id() is None: @@ -1642,10 +1676,15 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): return final_result finally: - if owns_usage_aggregation: - self._detach_usage_aggregation_listener() - if flow_id_token is not None: - current_flow_id.reset(flow_id_token) + # Always release the singleton-bus listener at the end of + # this resume call. In-flight handlers the bus already + # snapshotted continue updating `_aggregated_usage_metrics`; + # only future emits stop reaching this flow. A subsequent + # `resume_async` on the same instance re-attaches a fresh + # listener. + self._detach_usage_aggregation_listener() + if flow_id_token is not None: + current_flow_id.reset(flow_id_token) def _create_initial_state(self) -> T: """Create and initialize flow state with UUID and default values. @@ -2160,6 +2199,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): # gets overwritten later by `inputs["id"]`. self._flow_match_id = current_flow_id.get() self._aggregated_usage_metrics = UsageMetrics() + # Bump the epoch BEFORE attaching so any in-flight handler from + # a prior kickoff queued in the bus thread pool sees its stale + # snapshot and bails out instead of writing into the fresh + # accumulator. + self._usage_epoch += 1 self._attach_usage_aggregation_listener() try: @@ -2451,6 +2495,12 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): # Ensure all background memory saves complete before returning if self.memory is not None and hasattr(self.memory, "drain_writes"): self.memory.drain_writes() + # Detach the singleton-bus listener as soon as this kickoff + # finishes (whether by completion, pause, or error). Handlers + # the bus already snapshotted via `_sync_executor.submit` keep + # running and updating `_aggregated_usage_metrics`; only + # subsequent emits stop reaching this flow. Resume paths + # re-attach a fresh listener via `resume_async`. if owns_usage_aggregation: self._detach_usage_aggregation_listener() if request_id_token is not None: diff --git a/lib/crewai/tests/test_flow_usage_metrics.py b/lib/crewai/tests/test_flow_usage_metrics.py index 30b412398..ac6c4875d 100644 --- a/lib/crewai/tests/test_flow_usage_metrics.py +++ b/lib/crewai/tests/test_flow_usage_metrics.py @@ -10,6 +10,8 @@ explicit contextvar control; no live LLM provider is required. from __future__ import annotations import contextvars +import os +import tempfile from typing import Any, Callable from uuid import uuid4 @@ -17,8 +19,10 @@ import pytest from crewai.events.event_bus import crewai_event_bus from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallType -from crewai.flow.flow import Flow, start +from crewai.flow.async_feedback.types import PendingFeedbackContext +from crewai.flow.flow import Flow, listen, start from crewai.flow.flow_context import current_flow_id +from crewai.flow.persistence.sqlite import SQLiteFlowPersistence from crewai.flow.runtime import _usage_dict_to_metrics from crewai.types.usage_metrics import UsageMetrics @@ -250,3 +254,141 @@ class TestFlowUsageAggregation: failing.kickoff() assert handler_count() == before + + def test_stale_handler_from_prior_kickoff_does_not_contaminate(self) -> None: + """The bus dispatches sync handlers on a thread pool that ``emit`` + does not wait on. A handler still queued from a prior kickoff + must not write into a later kickoff's accumulator — the epoch + snapshot in the handler closure bails out on mismatch.""" + + captured: dict[str, Any] = {} + + def script(flow: Flow) -> None: + _emit_llm_call(flow_id=flow._flow_match_id, prompt_tokens=10, completion_tokens=10) + captured["handler"] = flow._usage_aggregation_handler + captured["match_id"] = flow._flow_match_id + + flow = _run(script) + first_total = flow.usage_metrics.total_tokens + assert first_total == 20 + + # A second kickoff bumps the epoch and resets the accumulator. + flow._script = lambda f: None + flow.kickoff() + assert flow.usage_metrics.total_tokens == 0 + + stale_handler = captured["handler"] + assert stale_handler is not None + + stale_event = LLMCallCompletedEvent( + call_id=str(uuid4()), + model="gpt-4o-mini", + response="ok", + call_type=LLMCallType.LLM_CALL, + usage={"prompt_tokens": 999, "completion_tokens": 999, "total_tokens": 1998}, + ) + ctx = contextvars.copy_context() + ctx.run(lambda: (current_flow_id.set(captured["match_id"]), stale_handler(object(), stale_event))) + + # Stale handler bailed: second kickoff's accumulator is still zero. + assert flow.usage_metrics.total_tokens == 0 + + def test_pause_detaches_listener_and_does_not_leak(self) -> None: + """When ``kickoff_async`` pauses for human feedback, the listener + must be detached from the singleton bus to avoid leaking handlers + across abandoned paused instances. Pre-pause LLM events still + count because the bus snapshots handlers at emit time. Late + events emitted after the pause returns do not count for this + instance — resume paths re-attach a fresh listener.""" + + from crewai.flow.async_feedback.types import HumanFeedbackPending + + captured: dict[str, Any] = {} + + class _PausingFlow(Flow): + @start() + def begin(self) -> None: + _emit_llm_call( + flow_id=self._flow_match_id, + prompt_tokens=10, + completion_tokens=20, + ) + captured["pre_pause_total"] = self.usage_metrics.total_tokens + raise HumanFeedbackPending( + context=PendingFeedbackContext( + flow_id=self.flow_id, + flow_class="_PausingFlow", + method_name="begin", + method_output="content", + message="Review:", + ) + ) + + with tempfile.TemporaryDirectory() as tmpdir: + persistence = SQLiteFlowPersistence(os.path.join(tmpdir, "f.db")) + flow = _PausingFlow(persistence=persistence) + result = flow.kickoff() + + assert isinstance(result, HumanFeedbackPending) + assert captured["pre_pause_total"] == 30 + assert flow._usage_aggregation_handler is None + + # A late event emitted after the pause does not reach the + # detached listener, so the running total is unchanged. + _emit_llm_call( + flow_id=flow._flow_match_id, + prompt_tokens=2, + completion_tokens=3, + ) + assert flow.usage_metrics.total_tokens == 30 + + def test_aggregates_resume_after_from_pending(self) -> None: + """A flow restored via ``from_pending`` is a fresh instance with no + ``_flow_match_id``; without seeding it, the listener attached in + ``resume_async`` either ignores its own LLM calls or absorbs unrelated + ones. ``from_pending`` must seed the match id so the resume-phase + aggregator counts our own calls and only our own calls.""" + + class _ResumeFlow(Flow): + @start() + def begin(self) -> str: + return "content" + + @listen(begin) + def on_begin(self, _feedback: Any) -> str: + _emit_llm_call( + flow_id=self._flow_match_id, + prompt_tokens=100, + completion_tokens=50, + ) + _emit_llm_call( + flow_id="some-other-flow", + prompt_tokens=9_999, + completion_tokens=9_999, + ) + return "done" + + with tempfile.TemporaryDirectory() as tmpdir: + persistence = SQLiteFlowPersistence(os.path.join(tmpdir, "f.db")) + flow_id = "usage-resume-test" + persistence.save_pending_feedback( + flow_uuid=flow_id, + context=PendingFeedbackContext( + flow_id=flow_id, + flow_class="_ResumeFlow", + method_name="begin", + method_output="content", + message="Review:", + ), + state_data={"id": flow_id}, + ) + + flow = _ResumeFlow.from_pending(flow_id, persistence) + assert flow._flow_match_id == flow.flow_id + + flow.resume("ok") + + assert flow.usage_metrics.total_tokens == 150 + assert flow.usage_metrics.prompt_tokens == 100 + assert flow.usage_metrics.completion_tokens == 50 + assert flow.usage_metrics.successful_requests == 1