mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 05:38:12 +00:00
fix: address PR review on flow.usage_metrics
This commit is contained in:
@@ -941,6 +941,12 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
_usage_metrics_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
_flow_match_id: str | None = PrivateAttr(default=None)
|
||||
_usage_aggregation_handler: Callable[..., Any] | None = PrivateAttr(default=None)
|
||||
# Incremented on every kickoff that takes ownership of usage aggregation.
|
||||
# The listener closure snapshots the epoch at attach time; a stale
|
||||
# handler still queued in the bus thread pool from a prior kickoff
|
||||
# compares its snapshot against the current value and bails out so it
|
||||
# cannot contaminate a later kickoff's accumulator.
|
||||
_usage_epoch: int = PrivateAttr(default=0)
|
||||
|
||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override]
|
||||
class _FlowGeneric(cls): # type: ignore[valid-type,misc]
|
||||
@@ -1011,8 +1017,16 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
return
|
||||
|
||||
flow_ref = self
|
||||
captured_epoch = self._usage_epoch
|
||||
|
||||
def _accumulate(source: Any, event: LLMCallCompletedEvent) -> None:
|
||||
# Stale-handler guard: the bus dispatches sync handlers on a
|
||||
# thread pool that `emit` does not wait on, so a handler from
|
||||
# a prior kickoff can still be queued when a later kickoff
|
||||
# bumps the epoch and resets the accumulator. Bail out so we
|
||||
# don't leak prior-run usage into the new accumulator.
|
||||
if captured_epoch != flow_ref._usage_epoch:
|
||||
return
|
||||
if current_flow_id.get() != flow_ref._flow_match_id:
|
||||
return
|
||||
metrics = _usage_dict_to_metrics(event.usage)
|
||||
@@ -1033,6 +1047,19 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
@property
def usage_metrics(self) -> UsageMetrics:
    """Return a copy of the LLM token usage aggregated for this flow.

    The figures cover the most recent kickoff (or resume) of this flow
    instance.

    Events are correlated through the ``current_flow_id`` contextvar
    captured at kickoff time. Because that contextvar is inherited, a
    nested kickoff (a parent flow calling a child flow's ``kickoff``)
    deliberately rolls the child's tokens up into the parent.

    Sibling kickoffs running in parallel under the same parent
    contextvar share one correlation id and may therefore over-count
    each other. For strict per-flow isolation in that pattern, run the
    children in separate tasks that explicitly set their own
    ``current_flow_id`` before kickoff.
    """
    # Copy while holding the lock so the caller receives a snapshot
    # rather than the accumulator object itself.
    with self._usage_metrics_lock:
        snapshot = self._aggregated_usage_metrics.model_copy()
    return snapshot
|
||||
|
||||
@@ -1330,6 +1357,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
instance._initialize_state(state_data)
|
||||
instance._pending_feedback_context = pending_context
|
||||
instance._is_execution_resuming = True
|
||||
# Seed the usage-aggregation match id so `resume_async` can wire its
|
||||
# listener and restore `current_flow_id` correctly. Without this,
|
||||
# a restored flow has a None match id and the handler would either
|
||||
# ignore its own LLM calls or absorb unrelated ones from sibling
|
||||
# flows. The accumulator itself starts at zero — any usage from
|
||||
# before the pause was only observable on the original kickoff
|
||||
# instance.
|
||||
instance._flow_match_id = instance.flow_id
|
||||
|
||||
return instance
|
||||
|
||||
@@ -1428,18 +1463,17 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
"No pending feedback context. Use from_pending() to restore a paused flow."
|
||||
)
|
||||
|
||||
# Re-attach the usage aggregation listener for the resume phase so
|
||||
# LLM calls during outcome collapsing and downstream crews continue
|
||||
# to roll up into `flow.usage_metrics`. The listener was detached
|
||||
# when `kickoff_async` returned on pause. We also restore
|
||||
# `current_flow_id` to the original kickoff's match id so the handler
|
||||
# filter passes.
|
||||
owns_usage_aggregation = self._usage_aggregation_handler is None
|
||||
# Wire usage aggregation for the resume phase. The previous
|
||||
# `kickoff_async`/`resume_async` already detached its listener,
|
||||
# and `_attach_usage_aggregation_listener` is idempotent, so we
|
||||
# can always call it here. We also restore `current_flow_id`
|
||||
# when missing so the handler's filter passes for LLM calls
|
||||
# made during outcome collapsing and downstream listener
|
||||
# execution.
|
||||
flow_id_token = None
|
||||
if owns_usage_aggregation:
|
||||
if current_flow_id.get() is None and self._flow_match_id is not None:
|
||||
flow_id_token = current_flow_id.set(self._flow_match_id)
|
||||
self._attach_usage_aggregation_listener()
|
||||
if current_flow_id.get() is None and self._flow_match_id is not None:
|
||||
flow_id_token = current_flow_id.set(self._flow_match_id)
|
||||
self._attach_usage_aggregation_listener()
|
||||
|
||||
try:
|
||||
if get_current_parent_id() is None:
|
||||
@@ -1642,10 +1676,15 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
return final_result
|
||||
finally:
|
||||
if owns_usage_aggregation:
|
||||
self._detach_usage_aggregation_listener()
|
||||
if flow_id_token is not None:
|
||||
current_flow_id.reset(flow_id_token)
|
||||
# Always release the singleton-bus listener at the end of
|
||||
# this resume call. In-flight handlers the bus already
|
||||
# snapshotted continue updating `_aggregated_usage_metrics`;
|
||||
# only future emits stop reaching this flow. A subsequent
|
||||
# `resume_async` on the same instance re-attaches a fresh
|
||||
# listener.
|
||||
self._detach_usage_aggregation_listener()
|
||||
if flow_id_token is not None:
|
||||
current_flow_id.reset(flow_id_token)
|
||||
|
||||
def _create_initial_state(self) -> T:
|
||||
"""Create and initialize flow state with UUID and default values.
|
||||
@@ -2160,6 +2199,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
# gets overwritten later by `inputs["id"]`.
|
||||
self._flow_match_id = current_flow_id.get()
|
||||
self._aggregated_usage_metrics = UsageMetrics()
|
||||
# Bump the epoch BEFORE attaching so any in-flight handler from
|
||||
# a prior kickoff queued in the bus thread pool sees its stale
|
||||
# snapshot and bails out instead of writing into the fresh
|
||||
# accumulator.
|
||||
self._usage_epoch += 1
|
||||
self._attach_usage_aggregation_listener()
|
||||
|
||||
try:
|
||||
@@ -2451,6 +2495,12 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
# Ensure all background memory saves complete before returning
|
||||
if self.memory is not None and hasattr(self.memory, "drain_writes"):
|
||||
self.memory.drain_writes()
|
||||
# Detach the singleton-bus listener as soon as this kickoff
|
||||
# finishes (whether by completion, pause, or error). Handlers
|
||||
# the bus already snapshotted via `_sync_executor.submit` keep
|
||||
# running and updating `_aggregated_usage_metrics`; only
|
||||
# subsequent emits stop reaching this flow. Resume paths
|
||||
# re-attach a fresh listener via `resume_async`.
|
||||
if owns_usage_aggregation:
|
||||
self._detach_usage_aggregation_listener()
|
||||
if request_id_token is not None:
|
||||
|
||||
@@ -10,6 +10,8 @@ explicit contextvar control; no live LLM provider is required.
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Callable
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -17,8 +19,10 @@ import pytest
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallType
|
||||
from crewai.flow.flow import Flow, start
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
from crewai.flow.flow import Flow, listen, start
|
||||
from crewai.flow.flow_context import current_flow_id
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
from crewai.flow.runtime import _usage_dict_to_metrics
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
|
||||
@@ -250,3 +254,141 @@ class TestFlowUsageAggregation:
|
||||
failing.kickoff()
|
||||
|
||||
assert handler_count() == before
|
||||
|
||||
def test_stale_handler_from_prior_kickoff_does_not_contaminate(self) -> None:
    """A handler left over from an earlier kickoff must not touch a later run.

    ``emit`` does not wait on the thread pool the bus uses to dispatch
    sync handlers, so a handler from a prior kickoff can still be queued
    when a later kickoff starts. The epoch snapshot held in the handler
    closure makes it bail out on mismatch instead of writing into the
    later kickoff's accumulator.
    """
    recorded: dict[str, Any] = {}

    def script(flow: Flow) -> None:
        _emit_llm_call(
            flow_id=flow._flow_match_id,
            prompt_tokens=10,
            completion_tokens=10,
        )
        # Keep the first kickoff's handler and match id so the handler
        # can be replayed by hand after the second kickoff.
        recorded["handler"] = flow._usage_aggregation_handler
        recorded["match_id"] = flow._flow_match_id

    flow = _run(script)
    first_total = flow.usage_metrics.total_tokens
    assert first_total == 20

    # A second kickoff bumps the epoch and resets the accumulator.
    flow._script = lambda f: None
    flow.kickoff()
    assert flow.usage_metrics.total_tokens == 0

    stale_handler = recorded["handler"]
    assert stale_handler is not None

    stale_usage = {
        "prompt_tokens": 999,
        "completion_tokens": 999,
        "total_tokens": 1998,
    }
    stale_event = LLMCallCompletedEvent(
        call_id=str(uuid4()),
        model="gpt-4o-mini",
        response="ok",
        call_type=LLMCallType.LLM_CALL,
        usage=stale_usage,
    )

    def _replay_stale_handler() -> None:
        # Restore the first kickoff's correlation id before invoking the
        # old handler directly.
        current_flow_id.set(recorded["match_id"])
        stale_handler(object(), stale_event)

    # Run in a copied context so the contextvar change stays local.
    contextvars.copy_context().run(_replay_stale_handler)

    # Stale handler bailed: second kickoff's accumulator is still zero.
    assert flow.usage_metrics.total_tokens == 0
|
||||
|
||||
def test_pause_detaches_listener_and_does_not_leak(self) -> None:
    """Pausing for human feedback must release the usage listener.

    When ``kickoff_async`` pauses, the listener has to be detached from
    the singleton bus so abandoned paused instances do not leak
    handlers. LLM events emitted before the pause still count because
    the bus snapshots handlers at emit time. Events emitted after the
    pause returns do not count for this instance; resume paths
    re-attach a fresh listener.
    """
    from crewai.flow.async_feedback.types import HumanFeedbackPending

    observed: dict[str, Any] = {}

    class _PausingFlow(Flow):
        @start()
        def begin(self) -> None:
            _emit_llm_call(
                flow_id=self._flow_match_id,
                prompt_tokens=10,
                completion_tokens=20,
            )
            # Record the total as seen from inside the flow, before the
            # pause is raised.
            observed["pre_pause_total"] = self.usage_metrics.total_tokens
            pending = PendingFeedbackContext(
                flow_id=self.flow_id,
                flow_class="_PausingFlow",
                method_name="begin",
                method_output="content",
                message="Review:",
            )
            raise HumanFeedbackPending(context=pending)

    # NOTE(review): indentation is lost in this view of the file; the
    # assertions are kept inside the ``with`` block so the database
    # directory outlives them. Confirm against the upstream test.
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "f.db")
        persistence = SQLiteFlowPersistence(db_path)
        flow = _PausingFlow(persistence=persistence)
        result = flow.kickoff()

        assert isinstance(result, HumanFeedbackPending)
        assert observed["pre_pause_total"] == 30
        assert flow._usage_aggregation_handler is None

        # A late event emitted after the pause does not reach the
        # detached listener, so the running total is unchanged.
        _emit_llm_call(
            flow_id=flow._flow_match_id,
            prompt_tokens=2,
            completion_tokens=3,
        )
        assert flow.usage_metrics.total_tokens == 30
|
||||
|
||||
def test_aggregates_resume_after_from_pending(self) -> None:
    """Resume-phase usage is aggregated for a flow restored from storage.

    A flow restored via ``from_pending`` is a fresh instance with no
    ``_flow_match_id``. Unless it is seeded, the listener attached in
    ``resume_async`` either ignores the flow's own LLM calls or absorbs
    unrelated ones. ``from_pending`` must therefore seed the match id
    so the resume-phase aggregator counts our own calls and only our
    own calls.
    """

    class _ResumeFlow(Flow):
        @start()
        def begin(self) -> str:
            return "content"

        @listen(begin)
        def on_begin(self, _feedback: Any) -> str:
            # Our own call: expected to be counted.
            _emit_llm_call(
                flow_id=self._flow_match_id,
                prompt_tokens=100,
                completion_tokens=50,
            )
            # A call tagged with a different flow id: expected to be
            # filtered out.
            _emit_llm_call(
                flow_id="some-other-flow",
                prompt_tokens=9_999,
                completion_tokens=9_999,
            )
            return "done"

    # NOTE(review): indentation is lost in this view of the file; the
    # restore, resume and assertions are kept inside the ``with`` block
    # so the database outlives them. Confirm against the upstream test.
    with tempfile.TemporaryDirectory() as tmpdir:
        persistence = SQLiteFlowPersistence(os.path.join(tmpdir, "f.db"))
        flow_id = "usage-resume-test"
        pending = PendingFeedbackContext(
            flow_id=flow_id,
            flow_class="_ResumeFlow",
            method_name="begin",
            method_output="content",
            message="Review:",
        )
        persistence.save_pending_feedback(
            flow_uuid=flow_id,
            context=pending,
            state_data={"id": flow_id},
        )

        flow = _ResumeFlow.from_pending(flow_id, persistence)
        assert flow._flow_match_id == flow.flow_id

        flow.resume("ok")

        usage = flow.usage_metrics
        assert usage.total_tokens == 150
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.successful_requests == 1
|
||||
|
||||
Reference in New Issue
Block a user