fix: address PR review on flow.usage_metrics

- Protect _aggregated_usage_metrics with a lock so concurrent
  LLMCallCompletedEvent handlers can't race the read-modify-write
  inside add_usage_metrics, and so usage_metrics snapshots are
  consistent.
- Wire the usage aggregation listener into resume_async so LLM
  calls during outcome collapsing and downstream crews continue
  to roll up into flow.usage_metrics after a paused-then-resumed
  kickoff. Restores current_flow_id to the original kickoff's
  match id when none is set, and detaches in finally.
- Guard against reentrant kickoff on the same Flow instance:
  only the outer kickoff captures _flow_match_id, resets the
  accumulator, and owns the listener lifecycle. Inner reentrant
  calls pass through and no longer wipe outer state or detach
  the shared handler.
- Rename test_snapshot_is_immutable to
  test_usage_metrics_returns_independent_copy to reflect that
  the property returns a copy of a (still-mutable) UsageMetrics.
- Extend test_handler_is_unregistered_after_kickoff to also
  cover the failure path, confirming the handler is removed
  when kickoff raises.
This commit is contained in:
Lucas Gomide
2026-06-11 13:38:55 -03:00
parent c4476366ff
commit 540f5df767
2 changed files with 252 additions and 193 deletions

View File

@@ -934,6 +934,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
_state: Any = PrivateAttr(default=None)
_deferred_flow_started_event_id: str | None = PrivateAttr(default=None)
_aggregated_usage_metrics: UsageMetrics = PrivateAttr(default_factory=UsageMetrics)
# Serializes mutations and snapshot reads on `_aggregated_usage_metrics`.
# The bus dispatches sync handlers from a `ThreadPoolExecutor`, so two
# concurrent `LLMCallCompletedEvent`s can race the read-modify-write
# inside `add_usage_metrics`.
_usage_metrics_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_flow_match_id: str | None = PrivateAttr(default=None)
_usage_aggregation_handler: Callable[..., Any] | None = PrivateAttr(default=None)
@@ -1011,7 +1016,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
if current_flow_id.get() != flow_ref._flow_match_id:
return
metrics = _usage_dict_to_metrics(event.usage)
if metrics is not None:
if metrics is None:
return
with flow_ref._usage_metrics_lock:
flow_ref._aggregated_usage_metrics.add_usage_metrics(metrics)
crewai_event_bus.on(LLMCallCompletedEvent)(_accumulate)
@@ -1026,7 +1033,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
@property
def usage_metrics(self) -> UsageMetrics:
return self._aggregated_usage_metrics.model_copy()
with self._usage_metrics_lock:
return self._aggregated_usage_metrics.model_copy()
def recall(self, query: str, **kwargs: Any) -> Any:
"""Recall relevant memories. Delegates to this flow's memory.
@@ -1420,201 +1428,231 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
"No pending feedback context. Use from_pending() to restore a paused flow."
)
if get_current_parent_id() is None:
reset_emission_counter()
reset_last_event_id()
if not self.suppress_flow_events:
future = crewai_event_bus.emit(
self,
FlowStartedEvent(
type="flow_started",
flow_name=self.name or self.__class__.__name__,
inputs=None,
),
)
if future and isinstance(future, Future):
try:
await asyncio.wrap_future(future)
except Exception:
logger.warning("FlowStartedEvent handler failed", exc_info=True)
get_env_context()
context = self._pending_feedback_context
emit = context.emit
default_outcome = context.default_outcome
# Try to get the live LLM from the re-imported decorator first.
# This preserves the fully-configured object (credentials, safety_settings, etc.)
# for same-process resume. For cross-process resume, fall back to the
# serialized context.llm which is now a dict with full config (or a legacy string).
from crewai.flow.human_feedback import _deserialize_llm_from_context
llm = None
method = self._methods.get(FlowMethodName(context.method_name))
if method is not None:
live_llm = getattr(method, "_human_feedback_llm", None)
if live_llm is not None:
from crewai.llms.base_llm import BaseLLM as BaseLLMClass
if isinstance(live_llm, BaseLLMClass):
llm = live_llm
if llm is None:
llm = _deserialize_llm_from_context(context.llm)
collapsed_outcome: str | None = None
if not feedback.strip():
if default_outcome:
collapsed_outcome = default_outcome
elif emit:
collapsed_outcome = emit[0]
elif emit:
if llm is not None:
collapsed_outcome = self._collapse_to_outcome(
feedback=feedback,
outcomes=emit,
llm=llm,
)
else:
collapsed_outcome = emit[0]
result = HumanFeedbackResult(
output=context.method_output,
feedback=feedback,
outcome=collapsed_outcome,
timestamp=datetime.now(),
method_name=context.method_name,
metadata=context.metadata,
)
self.human_feedback_history.append(result)
self.last_human_feedback = result
self._completed_methods.add(FlowMethodName(context.method_name))
self._pending_feedback_context = None
if self.persistence is not None:
self.persistence.clear_pending_feedback(context.flow_id)
if not self.suppress_flow_events:
crewai_event_bus.emit(
self,
MethodExecutionFinishedEvent(
type="method_execution_finished",
flow_name=self.name or self.__class__.__name__,
method_name=context.method_name,
result=collapsed_outcome if emit else result,
state=self._state,
),
)
# Clear resumption flag before triggering listeners
# This allows methods to re-execute in loops (e.g., implement_changes → suggest_changes → implement_changes)
self._is_execution_resuming = False
if emit and collapsed_outcome is None:
collapsed_outcome = default_outcome or emit[0]
result.outcome = collapsed_outcome
# Re-attach the usage aggregation listener for the resume phase so
# LLM calls during outcome collapsing and downstream crews continue
# to roll up into `flow.usage_metrics`. The listener was detached
# when `kickoff_async` returned on pause. We also restore
# `current_flow_id` to the original kickoff's match id so the handler
# filter passes.
owns_usage_aggregation = self._usage_aggregation_handler is None
flow_id_token = None
if owns_usage_aggregation:
if (
current_flow_id.get() is None
and self._flow_match_id is not None
):
flow_id_token = current_flow_id.set(self._flow_match_id)
self._attach_usage_aggregation_listener()
try:
if emit and collapsed_outcome:
self._method_outputs.append(collapsed_outcome)
await self._execute_listeners(
FlowMethodName(collapsed_outcome),
result,
)
else:
await self._execute_listeners(
FlowMethodName(context.method_name),
result,
)
except Exception as e:
# Check if flow was paused again for human feedback (loop case)
from crewai.flow.async_feedback.types import HumanFeedbackPending
if get_current_parent_id() is None:
reset_emission_counter()
reset_last_event_id()
if isinstance(e, HumanFeedbackPending):
self._pending_feedback_context = e.context
if self.persistence is None:
from crewai.flow.persistence.factory import default_flow_persistence
self.persistence = default_flow_persistence()
state_data = (
self._state
if isinstance(self._state, dict)
else self._state.model_dump()
)
self.persistence.save_pending_feedback(
flow_uuid=e.context.flow_id,
context=e.context,
state_data=state_data,
)
crewai_event_bus.emit(
if not self.suppress_flow_events:
future = crewai_event_bus.emit(
self,
FlowPausedEvent(
type="flow_paused",
FlowStartedEvent(
type="flow_started",
flow_name=self.name or self.__class__.__name__,
flow_id=e.context.flow_id,
method_name=e.context.method_name,
state=self._copy_and_serialize_state(),
message=e.context.message,
emit=e.context.emit,
inputs=None,
),
)
return e
raise
if future and isinstance(future, Future):
try:
await asyncio.wrap_future(future)
except Exception:
logger.warning(
"FlowStartedEvent handler failed", exc_info=True
)
final_result = self._method_outputs[-1] if self._method_outputs else result
get_env_context()
if self._event_futures:
await asyncio.gather(
*[
asyncio.wrap_future(f)
for f in self._event_futures
if isinstance(f, Future)
]
)
self._event_futures.clear()
context = self._pending_feedback_context
emit = context.emit
default_outcome = context.default_outcome
if (
not self.suppress_flow_events
and not self._should_defer_trace_finalization()
):
future = crewai_event_bus.emit(
self,
FlowFinishedEvent(
type="flow_finished",
flow_name=self.name or self.__class__.__name__,
result=final_result,
state=self._copy_and_serialize_state(),
),
)
if future and isinstance(future, Future):
try:
await asyncio.wrap_future(future)
except Exception:
logger.warning("FlowFinishedEvent handler failed", exc_info=True)
# Try to get the live LLM from the re-imported decorator first.
# This preserves the fully-configured object (credentials, safety_settings, etc.)
# for same-process resume. For cross-process resume, fall back to the
# serialized context.llm which is now a dict with full config (or a legacy string).
from crewai.flow.human_feedback import _deserialize_llm_from_context
trace_listener = TraceCollectionListener()
if (
trace_listener.batch_manager.batch_owner_type == "flow"
and current_flow_id.get() == self.flow_id
and not trace_listener.batch_manager.defer_session_finalization
and not current_flow_defer_trace_finalization.get()
):
if trace_listener.first_time_handler.is_first_time:
trace_listener.first_time_handler.mark_events_collected()
trace_listener.first_time_handler.handle_execution_completion()
llm = None
method = self._methods.get(FlowMethodName(context.method_name))
if method is not None:
live_llm = getattr(method, "_human_feedback_llm", None)
if live_llm is not None:
from crewai.llms.base_llm import BaseLLM as BaseLLMClass
if isinstance(live_llm, BaseLLMClass):
llm = live_llm
if llm is None:
llm = _deserialize_llm_from_context(context.llm)
collapsed_outcome: str | None = None
if not feedback.strip():
if default_outcome:
collapsed_outcome = default_outcome
elif emit:
collapsed_outcome = emit[0]
elif emit:
if llm is not None:
collapsed_outcome = self._collapse_to_outcome(
feedback=feedback,
outcomes=emit,
llm=llm,
)
else:
trace_listener.batch_manager.finalize_batch()
collapsed_outcome = emit[0]
return final_result
result = HumanFeedbackResult(
output=context.method_output,
feedback=feedback,
outcome=collapsed_outcome,
timestamp=datetime.now(),
method_name=context.method_name,
metadata=context.metadata,
)
self.human_feedback_history.append(result)
self.last_human_feedback = result
self._completed_methods.add(FlowMethodName(context.method_name))
self._pending_feedback_context = None
if self.persistence is not None:
self.persistence.clear_pending_feedback(context.flow_id)
if not self.suppress_flow_events:
crewai_event_bus.emit(
self,
MethodExecutionFinishedEvent(
type="method_execution_finished",
flow_name=self.name or self.__class__.__name__,
method_name=context.method_name,
result=collapsed_outcome if emit else result,
state=self._state,
),
)
# Clear resumption flag before triggering listeners
# This allows methods to re-execute in loops (e.g., implement_changes → suggest_changes → implement_changes)
self._is_execution_resuming = False
if emit and collapsed_outcome is None:
collapsed_outcome = default_outcome or emit[0]
result.outcome = collapsed_outcome
try:
if emit and collapsed_outcome:
self._method_outputs.append(collapsed_outcome)
await self._execute_listeners(
FlowMethodName(collapsed_outcome),
result,
)
else:
await self._execute_listeners(
FlowMethodName(context.method_name),
result,
)
except Exception as e:
# Check if flow was paused again for human feedback (loop case)
from crewai.flow.async_feedback.types import HumanFeedbackPending
if isinstance(e, HumanFeedbackPending):
self._pending_feedback_context = e.context
if self.persistence is None:
from crewai.flow.persistence.factory import (
default_flow_persistence,
)
self.persistence = default_flow_persistence()
state_data = (
self._state
if isinstance(self._state, dict)
else self._state.model_dump()
)
self.persistence.save_pending_feedback(
flow_uuid=e.context.flow_id,
context=e.context,
state_data=state_data,
)
crewai_event_bus.emit(
self,
FlowPausedEvent(
type="flow_paused",
flow_name=self.name or self.__class__.__name__,
flow_id=e.context.flow_id,
method_name=e.context.method_name,
state=self._copy_and_serialize_state(),
message=e.context.message,
emit=e.context.emit,
),
)
return e
raise
final_result = (
self._method_outputs[-1] if self._method_outputs else result
)
if self._event_futures:
await asyncio.gather(
*[
asyncio.wrap_future(f)
for f in self._event_futures
if isinstance(f, Future)
]
)
self._event_futures.clear()
if (
not self.suppress_flow_events
and not self._should_defer_trace_finalization()
):
future = crewai_event_bus.emit(
self,
FlowFinishedEvent(
type="flow_finished",
flow_name=self.name or self.__class__.__name__,
result=final_result,
state=self._copy_and_serialize_state(),
),
)
if future and isinstance(future, Future):
try:
await asyncio.wrap_future(future)
except Exception:
logger.warning(
"FlowFinishedEvent handler failed", exc_info=True
)
trace_listener = TraceCollectionListener()
if (
trace_listener.batch_manager.batch_owner_type == "flow"
and current_flow_id.get() == self.flow_id
and not trace_listener.batch_manager.defer_session_finalization
and not current_flow_defer_trace_finalization.get()
):
if trace_listener.first_time_handler.is_first_time:
trace_listener.first_time_handler.mark_events_collected()
trace_listener.first_time_handler.handle_execution_completion()
else:
trace_listener.batch_manager.finalize_batch()
return final_result
finally:
if owns_usage_aggregation:
self._detach_usage_aggregation_listener()
if flow_id_token is not None:
current_flow_id.reset(flow_id_token)
def _create_initial_state(self) -> T:
"""Create and initialize flow state with UUID and default values.
@@ -2118,12 +2156,18 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
runtime_scope = crewai_event_bus._enter_runtime_scope()
# Capture the flow id seen by `FlowTrackable._set_flow_context` so we
# can match LLM call events back to this flow even if `state.id` gets
# overwritten later by `inputs["id"]`.
self._flow_match_id = current_flow_id.get()
self._aggregated_usage_metrics = UsageMetrics()
self._attach_usage_aggregation_listener()
# Guard against a reentrant kickoff on the same Flow instance: only
# the outermost call captures `_flow_match_id`, resets the accumulator,
# and owns the listener lifecycle. An inner reentrant call passes
# through so it doesn't wipe outer's state or detach the shared handler.
owns_usage_aggregation = self._usage_aggregation_handler is None
if owns_usage_aggregation:
# Capture the flow id seen by `FlowTrackable._set_flow_context` so
# we can match LLM call events back to this flow even if `state.id`
# gets overwritten later by `inputs["id"]`.
self._flow_match_id = current_flow_id.get()
self._aggregated_usage_metrics = UsageMetrics()
self._attach_usage_aggregation_listener()
try:
# Reset flow state for fresh execution unless restoring from persistence
@@ -2414,7 +2458,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
# Ensure all background memory saves complete before returning
if self.memory is not None and hasattr(self.memory, "drain_writes"):
self.memory.drain_writes()
self._detach_usage_aggregation_listener()
if owns_usage_aggregation:
self._detach_usage_aggregation_listener()
if request_id_token is not None:
current_flow_request_id.reset(request_id_token)
if flow_defer_trace_finalization_token is not None:

View File

@@ -205,7 +205,10 @@ class TestFlowUsageAggregation:
assert flow.usage_metrics.total_tokens == 500
assert flow.usage_metrics.successful_requests == 1
def test_snapshot_is_immutable(self) -> None:
def test_usage_metrics_returns_independent_copy(self) -> None:
"""``usage_metrics`` must return a copy, not the internal instance —
otherwise callers can clobber the in-flight accumulator."""
flow = _run(
lambda f: _emit_llm_call(
flow_id=f._flow_match_id, prompt_tokens=50, completion_tokens=50
@@ -219,7 +222,7 @@ class TestFlowUsageAggregation:
def test_handler_is_unregistered_after_kickoff(self) -> None:
"""Long-lived workers (Celery, devkit) must not leak one handler per
kickoff on the singleton bus."""
kickoff on the singleton bus, on either the success or failure path."""
def handler_count() -> int:
return len(
@@ -236,3 +239,14 @@ class TestFlowUsageAggregation:
flow.kickoff()
assert handler_count() == before
def boom(_f: Flow) -> None:
raise RuntimeError("boom")
failing = _ScriptedFlow()
failing._script = boom
with pytest.raises(RuntimeError, match="boom"):
failing.kickoff()
assert handler_count() == before