diff --git a/docs/en/guides/flows/conversational-flows.mdx b/docs/en/guides/flows/conversational-flows.mdx new file mode 100644 index 000000000..f06e1e6b9 --- /dev/null +++ b/docs/en/guides/flows/conversational-flows.mdx @@ -0,0 +1,150 @@ +--- +title: Conversational Flows +description: Build multi-turn chat apps with kickoff per turn, message history, and intent routing. +icon: comments +mode: "wide" +--- + +## One entry point: `kickoff` + +Chat apps should use **`flow.kickoff(user_message=..., session_id=...)`** for each user line. Do not add a separate `chat()` method — session identity is `state.id` (same as `inputs["id"]` with `@persist`). + +| API | Use for | +|-----|---------| +| `kickoff(user_message=..., session_id=...)` | Each new user message (API, WebSocket turn, CLI loop) | +| `ask()` | Blocking prompt **inside** one step (wizard, clarification) | +| `@human_feedback` | Approve/reject **a step output** before continuing — not the next chat line | + +## Quick start + +```python +from uuid import uuid4 + +from pydantic import Field + +from crewai.flow import ChatState, Flow, listen, persist, router, start +from crewai.flow.persistence import SQLiteFlowPersistence + + +@persist(SQLiteFlowPersistence()) +class SupportFlow(Flow[ChatState]): + @start() + def bootstrap(self): + if self.state.session_ready: + return "ready" + # load permissions once per session + self.state.session_ready = True + return "ready" + + @router(bootstrap) + def route(self): + if self.state.last_intent: + return self.state.last_intent + return self.classify_intent( + self.state.last_user_message, + outcomes=["order", "help", "goodbye"], + llm="gpt-4o-mini", + context=self.conversation_messages, + ) + + @listen("order") + def handle_order(self): + reply = "Your order is on the way." + self.append_message("assistant", reply) + return reply + + @listen("help") + def handle_help(self): + reply = "How can I help?" 
+ self.append_message("assistant", reply) + return reply + + +session_id = str(uuid4()) +flow = SupportFlow() + +# Turn 1 +flow.kickoff(user_message="Where is my order?", session_id=session_id) + +# Turn 2 — same session; it is normal that the flow already finished after turn 1 +flow.kickoff(user_message="What about returns?", session_id=session_id) +``` + +## When the flow finishes but the user keeps chatting + +`FlowFinished` means **this graph run** completed. The conversation continues with **another `kickoff`** and the same `session_id`. `@persist` restores `messages`, permission flags, and context. + +For multi-turn chat, prefer **`@persist` on a single terminal step** (for example `finalize`) rather than on the whole `Flow` class. Class-level `@persist` saves state after every method; `load_state` uses the latest row, which is often a mid-run snapshot (for example right after `bootstrap`) and can omit handler updates from the same turn. + +Do **not** use `@human_feedback` for follow-up questions unless a human must approve a specific payload before it is shown. + +## Tracing across turns + +By default, `ConversationalConfig(defer_trace_finalization=True)` keeps **one trace batch** for the whole chat session instead of finalizing after every `kickoff()`. Call `flow.finalize_session_traces()` when the user leaves (or use `ChatSession.close()`, which does this automatically). + +```python +flow.kickoff(user_message="Hello", session_id=session_id) +flow.kickoff(user_message="Track my order", session_id=session_id) +flow.finalize_session_traces() # one link for the full conversation +``` + +Crews kicked off **inside** a flow method (for example a research step) append their events to the **parent flow batch**. `CrewKickoffCompleted` does not finalize that batch (even when crew worker threads lose `current_flow_id` context); finalization happens in `finalize_session_traces()`. 
+ +Per-turn `flow_finished` is also deferred: only `flow_started` opens the session scope on the first turn, and `flow_finished` runs once in `finalize_session_traces()` so the event bus does not warn about a missing `flow_started`. + +## Kickoff parameters + +| Parameter | Purpose | +|-----------|---------| +| `user_message` | This turn's user text (also `inputs["user_message"]`) | +| `session_id` | Conversation UUID → `inputs["id"]` / `state.id` | +| `intents` | Optional labels for pre-kickoff `classify_intent` | +| `intent_llm` | LLM for classification (required with `intents`) | +| `interactive=True` | CLI demo loop via `ask()` (not for production APIs) | + +Class-level defaults: + +```python +from crewai.flow import ConversationalConfig, Flow + + +class MyFlow(Flow[ChatState]): + conversational_config = ConversationalConfig( + default_intents=["order", "help"], + intent_llm="gpt-4o-mini", + ) +``` + +## Helpers on `Flow` + +- `append_message(role, content)` — update `state.messages` +- `conversation_messages` — history for LLM calls +- `classify_intent(text, outcomes, llm=..., context=...)` — route labels (same logic as `@human_feedback` collapse) +- `receive_user_message(text, ...)` — append user line + optional classify +- `input_history` — audit trail from `ask()` + +Recommended state shape: `ChatState` (`id`, `messages`, `last_user_message`, `last_intent`, `session_ready`). + +## ChatSession (WebSocket / SSE bridge) + +For UIs, use `ChatSession` to wrap kickoff and map events to `ChatMessage`: + +```python +from crewai.flow import ChatSession + + +def on_event(msg): + print(msg.type, msg.payload) + + +session = ChatSession(flow, session_id="channel-1", on_event=on_event) +turn = session.handle_turn("Hello") +print(turn.output, turn.intent) +session.close() +``` + +`QueueInputProvider` supports blocking `ask()` fed by a WebSocket handler (`provider.push(session_id, text)`). 
+ +## Streaming + +Enable `stream = True` on the Flow class and use `kickoff` / `ChatSession.handle_turn(..., stream=True)` to emit `assistant_delta` events through `ConversationEventBridge`. diff --git a/lib/crewai/src/crewai/events/event_listener.py b/lib/crewai/src/crewai/events/event_listener.py index e63b6d4bf..959bdd769 100644 --- a/lib/crewai/src/crewai/events/event_listener.py +++ b/lib/crewai/src/crewai/events/event_listener.py @@ -320,20 +320,28 @@ class EventListener(BaseEventListener): self._telemetry.flow_execution_span( event.flow_name, list(source._methods.keys()) ) - self.formatter.handle_flow_created(event.flow_name, str(source.flow_id)) - self.formatter.handle_flow_started(event.flow_name, str(source.flow_id)) + if not getattr(source, "suppress_flow_events", False): + self.formatter.handle_flow_created( + event.flow_name, str(source.flow_id) + ) + self.formatter.handle_flow_started( + event.flow_name, str(source.flow_id) + ) @crewai_event_bus.on(FlowFinishedEvent) def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None: - self.formatter.handle_flow_status( - event.flow_name, - source.flow_id, - ) + if not getattr(source, "suppress_flow_events", False): + self.formatter.handle_flow_status( + event.flow_name, + source.flow_id, + ) @crewai_event_bus.on(MethodExecutionStartedEvent) def on_method_execution_started( - _: Any, event: MethodExecutionStartedEvent + source: Any, event: MethodExecutionStartedEvent ) -> None: + if getattr(source, "suppress_flow_events", False): + return self.formatter.handle_method_status( event.method_name, "running", @@ -341,8 +349,10 @@ class EventListener(BaseEventListener): @crewai_event_bus.on(MethodExecutionFinishedEvent) def on_method_execution_finished( - _: Any, event: MethodExecutionFinishedEvent + source: Any, event: MethodExecutionFinishedEvent ) -> None: + if getattr(source, "suppress_flow_events", False): + return self.formatter.handle_method_status( event.method_name, "completed", diff --git 
a/lib/crewai/src/crewai/events/listeners/tracing/first_time_trace_handler.py b/lib/crewai/src/crewai/events/listeners/tracing/first_time_trace_handler.py index 436d50c27..e6fb4b32e 100644 --- a/lib/crewai/src/crewai/events/listeners/tracing/first_time_trace_handler.py +++ b/lib/crewai/src/crewai/events/listeners/tracing/first_time_trace_handler.py @@ -222,6 +222,8 @@ To enable tracing later, do any one of these: return self.batch_manager.batch_owner_type = None self.batch_manager.batch_owner_id = None + self.batch_manager.defer_session_finalization = False + self.batch_manager._batch_finalized = False self.batch_manager.current_batch = None self.batch_manager.event_buffer.clear() self.batch_manager.trace_batch_id = None diff --git a/lib/crewai/src/crewai/events/listeners/tracing/trace_batch_manager.py b/lib/crewai/src/crewai/events/listeners/tracing/trace_batch_manager.py index 0cfe227ac..282d2f1b5 100644 --- a/lib/crewai/src/crewai/events/listeners/tracing/trace_batch_manager.py +++ b/lib/crewai/src/crewai/events/listeners/tracing/trace_batch_manager.py @@ -70,6 +70,8 @@ class TraceBatchManager: self.execution_start_times: dict[str, datetime] = {} self.batch_owner_type: str | None = None self.batch_owner_id: str | None = None + self.defer_session_finalization: bool = False + self._batch_finalized: bool = False self.backend_initialized: bool = False self.ephemeral_trace_url: str | None = None try: @@ -101,6 +103,7 @@ class TraceBatchManager: user_context=user_context, execution_metadata=execution_metadata ) self.is_current_batch_ephemeral = use_ephemeral + self._batch_finalized = False self.record_start_time("execution") @@ -312,6 +315,9 @@ class TraceBatchManager: def finalize_batch(self) -> TraceBatch | None: """Finalize batch and return it for sending""" + if self._batch_finalized: + return None + if not self.current_batch or not is_tracing_enabled_in_context(): return None @@ -340,10 +346,8 @@ class TraceBatchManager: self.current_batch.events = sorted_events 
events_sent_count = len(sorted_events) if sorted_events: - original_buffer = self.event_buffer self.event_buffer = sorted_events events_sent_to_backend_status = self._send_events_to_backend() - self.event_buffer = original_buffer if events_sent_to_backend_status == 500 and self.trace_batch_id: self._mark_batch_as_failed( self.trace_batch_id, "Error sending events to backend" @@ -360,6 +364,7 @@ class TraceBatchManager: self.event_buffer.clear() self.trace_batch_id = None self.is_current_batch_ephemeral = False + self._batch_finalized = True self._cleanup_batch_data() @@ -371,7 +376,7 @@ class TraceBatchManager: Args: events_count: Number of events that were successfully sent """ - if not self.plus_api or not self.trace_batch_id: + if self._batch_finalized or not self.plus_api or not self.trace_batch_id: return try: @@ -390,6 +395,7 @@ class TraceBatchManager: ) if response.status_code == 200: + self._batch_finalized = True access_code = response.json().get("access_code", None) console = Console() settings = Settings() diff --git a/lib/crewai/src/crewai/events/listeners/tracing/trace_listener.py b/lib/crewai/src/crewai/events/listeners/tracing/trace_listener.py index 8bac1518e..3867ccb4d 100644 --- a/lib/crewai/src/crewai/events/listeners/tracing/trace_listener.py +++ b/lib/crewai/src/crewai/events/listeners/tracing/trace_listener.py @@ -1,5 +1,6 @@ """Trace collection listener for orchestrating trace collection.""" +from datetime import datetime, timezone import os from typing import Any, ClassVar import uuid @@ -264,18 +265,18 @@ class TraceCollectionListener(BaseEventListener): @event_bus.on(CrewKickoffStartedEvent) def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None: - if self.batch_manager.batch_owner_type != "flow": - # Always call _initialize_crew_batch to claim ownership. - # If batch was already initialized by a concurrent action event - # (e.g. 
LLM/tool before crew_kickoff_started), initialize_batch() - # returns early but batch_owner_type is still correctly set to "crew". - # Skip only when a parent flow already owns the batch. + # Nested crew inside Flow.kickoff: never claim an existing flow session batch. + if not self._nested_in_flow_execution() and ( + not self.batch_manager.is_batch_initialized() + ): self._initialize_crew_batch(source, event) self._handle_trace_event("crew_kickoff_started", source, event) @event_bus.on(CrewKickoffCompletedEvent) def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None: self._handle_trace_event("crew_kickoff_completed", source, event) + if self._nested_in_flow_execution(): + return if self.batch_manager.batch_owner_type == "crew": if self.first_time_handler.is_first_time: self.first_time_handler.mark_events_collected() @@ -286,10 +287,12 @@ class TraceCollectionListener(BaseEventListener): @event_bus.on(CrewKickoffFailedEvent) def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None: self._handle_trace_event("crew_kickoff_failed", source, event) + if self._nested_in_flow_execution(): + return if self.first_time_handler.is_first_time: self.first_time_handler.mark_events_collected() self.first_time_handler.handle_execution_completion() - else: + elif self.batch_manager.batch_owner_type == "crew": self.batch_manager.finalize_batch() @event_bus.on(TaskStartedEvent) @@ -708,8 +711,34 @@ class TraceCollectionListener(BaseEventListener): @on_signal def handle_signal(source: Any, event: SignalEvent) -> None: """Flush trace batch on system signals to prevent data loss.""" - if self.batch_manager.is_batch_initialized(): - self.batch_manager.finalize_batch() + if not self.batch_manager.is_batch_initialized(): + return + # Multi-turn flows defer batch finalization to finalize_session_traces(). 
def _flow_owns_trace_batch(self) -> bool:
    """Report whether the current trace batch belongs to a flow.

    Ownership is recognised in two ways: the batch manager's
    ``batch_owner_type`` is ``"flow"``, or the live batch's
    ``execution_metadata`` records ``execution_type == "flow"``.

    Returns:
        True when either marker identifies a flow-owned batch, False
        otherwise (including when there is no current batch).
    """
    manager = self.batch_manager
    if manager.batch_owner_type == "flow":
        return True

    active_batch = manager.current_batch
    if active_batch is None:
        return False

    # Fall back to the metadata recorded when the batch was created.
    return active_batch.execution_metadata.get("execution_type") == "flow"
+ """ + from crewai.flow.flow_context import current_flow_id, current_flow_name + + flow_id = current_flow_id.get() + if flow_id is None: + return False + + started_at = getattr(event, "timestamp", None) or datetime.now(timezone.utc) + user_context = self._get_user_context() + execution_metadata = { + "flow_name": current_flow_name.get() or "Unknown Flow", + "execution_start": started_at, + "crewai_version": get_crewai_version(), + "execution_type": "flow", + } + self.batch_manager.batch_owner_type = "flow" + self.batch_manager.batch_owner_id = flow_id + self._initialize_batch(user_context, execution_metadata) + return True + def _initialize_flow_batch(self, source: Any, event: BaseEvent) -> None: """Initialize trace batch for Flow execution. @@ -794,12 +848,19 @@ class TraceCollectionListener(BaseEventListener): event: Event object. """ if not self.batch_manager.is_batch_initialized(): - user_context = self._get_user_context() - execution_metadata = { - "crew_name": getattr(source, "name", "Unknown Crew"), - "crewai_version": get_crewai_version(), - } - self._initialize_batch(user_context, execution_metadata) + if self._try_initialize_flow_batch_from_context(event): + pass + elif not self._nested_in_flow_execution(): + user_context = self._get_user_context() + execution_metadata = { + "crew_name": getattr(source, "name", "Unknown Crew"), + "crewai_version": get_crewai_version(), + } + self.batch_manager.batch_owner_type = "crew" + self.batch_manager.batch_owner_id = getattr( + source, "id", str(uuid.uuid4()) + ) + self._initialize_batch(user_context, execution_metadata) self.batch_manager.begin_event_processing() try: diff --git a/lib/crewai/src/crewai/flow/__init__.py b/lib/crewai/src/crewai/flow/__init__.py index 6922725fa..bd26860a1 100644 --- a/lib/crewai/src/crewai/flow/__init__.py +++ b/lib/crewai/src/crewai/flow/__init__.py @@ -4,12 +4,24 @@ from crewai.flow.async_feedback import ( HumanFeedbackProvider, PendingFeedbackContext, ) +from crewai.flow.chat 
import ( + ChatMessage, + ChatSession, + ConversationEventBridge, + TurnResult, +) +from crewai.flow.conversation import ( + ChatState, + ConversationalConfig, + ConversationalInputs, +) from crewai.flow.flow import Flow, and_, listen, or_, router, start from crewai.flow.flow_config import flow_config from crewai.flow.flow_serializer import flow_structure from crewai.flow.human_feedback import HumanFeedbackResult, human_feedback from crewai.flow.input_provider import InputProvider, InputResponse from crewai.flow.persistence import persist +from crewai.flow.providers import QueueInputProvider from crewai.flow.visualization import ( FlowStructure, build_flow_structure, @@ -18,7 +30,13 @@ from crewai.flow.visualization import ( __all__ = [ + "ChatMessage", + "ChatSession", + "ChatState", "ConsoleProvider", + "ConversationEventBridge", + "ConversationalConfig", + "ConversationalInputs", "Flow", "FlowStructure", "HumanFeedbackPending", @@ -27,6 +45,8 @@ __all__ = [ "InputProvider", "InputResponse", "PendingFeedbackContext", + "QueueInputProvider", + "TurnResult", "and_", "build_flow_structure", "flow_config", diff --git a/lib/crewai/src/crewai/flow/chat.py b/lib/crewai/src/crewai/flow/chat.py new file mode 100644 index 000000000..044068b5c --- /dev/null +++ b/lib/crewai/src/crewai/flow/chat.py @@ -0,0 +1,307 @@ +"""Transport-agnostic chat session bridge for conversational flows.""" + +from __future__ import annotations + +from collections.abc import Callable, Iterator, Sequence +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from crewai.flow.conversation import get_conversation_messages, get_conversational_config +from crewai.utilities.types import LLMMessage + + +if TYPE_CHECKING: + from crewai.flow.flow import Flow + from crewai.llms.base_llm import BaseLLM + from crewai.types.streaming import FlowStreamingOutput + + +ChatMessageType = Literal[ + 
"user_message", + "assistant_delta", + "assistant_done", + "turn_started", + "turn_finished", + "error", + "tool_started", + "tool_finished", +] + + +class ChatMessage(BaseModel): + """Versioned wire format for chat UIs (WebSocket, SSE, webhooks).""" + + version: str = "1" + type: ChatMessageType + session_id: str + payload: dict[str, Any] = Field(default_factory=dict) + seq: int | None = None + + +@dataclass +class TurnResult: + """Outcome of a single conversational turn.""" + + session_id: str + output: Any + intent: str | None = None + messages: list[LLMMessage] = field(default_factory=list) + streaming: FlowStreamingOutput | None = None + + +class ChatSession: + """Wraps ``Flow.kickoff`` for one chat session (``state.id``).""" + + def __init__( + self, + flow: Flow[Any], + session_id: str | None = None, + *, + intents: Sequence[str] | None = None, + intent_llm: str | BaseLLM | None = None, + on_event: Callable[[ChatMessage], None] | None = None, + ) -> None: + self._flow = flow + self._session_id = session_id or str(uuid4()) + self._intents = list(intents) if intents else None + self._intent_llm = intent_llm + self._on_event = on_event + self._seq = 0 + self._bridge: ConversationEventBridge | None = None + config = get_conversational_config(flow) + if config is not None and config.defer_trace_finalization: + flow.defer_trace_finalization = True + if on_event is not None: + self._bridge = ConversationEventBridge( + session_id=self._session_id, + handler=on_event, + ) + self._bridge.register() + + @property + def session_id(self) -> str: + return self._session_id + + def handle_turn( + self, + user_message: str, + *, + stream: bool | None = None, + ) -> TurnResult: + """Run one conversational turn and return output plus message history.""" + self._emit("turn_started", {"user_message": user_message}) + use_stream = stream if stream is not None else bool(self._flow.stream) + + try: + result = self._flow.kickoff( + user_message=user_message, + 
def iter_turn_stream(
    self,
    user_message: str,
) -> Iterator[ChatMessage]:
    """Run a streaming turn and yield the ``ChatMessage`` events it produced.

    Events are buffered while ``handle_turn(..., stream=True)`` runs and are
    yielded once it returns. Because this is a generator, nothing executes
    until the first item is requested.

    Args:
        user_message: The user's text for this turn.

    Yields:
        Every ``ChatMessage`` collected during the turn, in emission order.

    Raises:
        Exception: Whatever ``handle_turn`` raises; buffered events are
            discarded in that case.
    """
    collected: list[ChatMessage] = []

    def _collect(msg: ChatMessage) -> None:
        collected.append(msg)

    prior = self._on_event
    temp_bridge: ConversationEventBridge | None = None
    self._on_event = _collect
    try:
        if self._bridge is None:
            # The session has no long-lived bridge, so attach one for this
            # turn only. It is bound to this call's ``_collect`` closure and
            # must not outlive the call: a leftover bridge would keep feeding
            # later turns' bus events into this (dead) buffer.
            temp_bridge = ConversationEventBridge(
                session_id=self._session_id,
                handler=_collect,
            )
            self._bridge = temp_bridge
            temp_bridge.register()
        self.handle_turn(user_message, stream=True)
    finally:
        # Always restore the caller's handler, even if bridge setup failed.
        self._on_event = prior
        if temp_bridge is not None:
            temp_bridge.unregister()
            # Only clear the slot if nobody replaced the bridge meanwhile.
            if self._bridge is temp_bridge:
                self._bridge = None
    yield from collected
class ConversationEventBridge:
    """Maps CrewAI bus events to ``ChatMessage`` for a session.

    ``register()`` subscribes to LLM chunk, LLM thinking-chunk, tool-usage
    and flow-finished events on the global event bus. Events that match the
    session (see ``_matches``) are forwarded to ``handler`` as
    ``ChatMessage`` objects carrying a per-bridge sequence number.
    ``unregister()`` removes those subscriptions again.
    """

    def __init__(
        self,
        session_id: str,
        handler: Callable[[ChatMessage], None],
    ) -> None:
        """Store the session id and the callback; nothing is subscribed yet.

        Args:
            session_id: Conversation id that events are matched against.
            handler: Callback invoked with each translated ``ChatMessage``.
        """
        self._session_id = session_id
        self._handler = handler
        self._seq = 0
        # (event_type, handler) pairs currently attached to the bus, kept so
        # unregister() can detach exactly what register() attached.
        self._handlers: list[Any] = []

    def register(self) -> None:
        """Subscribe this bridge's handlers on the global event bus.

        Calling it again while already registered is a no-op, so events are
        never delivered twice.
        """
        if self._handlers:
            return

        from crewai.events import crewai_event_bus
        from crewai.events.types.flow_events import FlowFinishedEvent
        from crewai.events.types.llm_events import (
            LLMStreamChunkEvent,
            LLMThinkingChunkEvent,
        )
        from crewai.events.types.tool_usage_events import (
            ToolUsageFinishedEvent,
            ToolUsageStartedEvent,
        )

        bus = crewai_event_bus

        @bus.on(LLMStreamChunkEvent)
        def _on_chunk(_source: Any, event: LLMStreamChunkEvent) -> None:
            if not self._matches(event):
                return
            chunk = getattr(event, "chunk", None)
            if chunk:
                self._dispatch(
                    "assistant_delta",
                    {"chunk": chunk, "agent_role": getattr(event, "agent_role", "")},
                )

        @bus.on(LLMThinkingChunkEvent)
        def _on_thinking(_source: Any, event: LLMThinkingChunkEvent) -> None:
            if not self._matches(event):
                return
            chunk = getattr(event, "chunk", None)
            if chunk:
                self._dispatch(
                    "assistant_delta",
                    {
                        "chunk": chunk,
                        "thinking": True,
                        "agent_role": getattr(event, "agent_role", ""),
                    },
                )

        @bus.on(ToolUsageStartedEvent)
        def _on_tool_start(_source: Any, event: ToolUsageStartedEvent) -> None:
            if not self._matches(event):
                return
            self._dispatch(
                "tool_started",
                {"tool_name": getattr(event, "tool_name", "")},
            )

        @bus.on(ToolUsageFinishedEvent)
        def _on_tool_end(_source: Any, event: ToolUsageFinishedEvent) -> None:
            if not self._matches(event):
                return
            self._dispatch(
                "tool_finished",
                {"tool_name": getattr(event, "tool_name", "")},
            )

        @bus.on(FlowFinishedEvent)
        def _on_finished(_source: Any, event: FlowFinishedEvent) -> None:
            if not self._matches(event):
                return
            self._dispatch("turn_finished", {"result": getattr(event, "result", None)})

        # Remember the event type next to each handler: detaching from the
        # bus needs both.
        self._handlers = [
            (LLMStreamChunkEvent, _on_chunk),
            (LLMThinkingChunkEvent, _on_thinking),
            (ToolUsageStartedEvent, _on_tool_start),
            (ToolUsageFinishedEvent, _on_tool_end),
            (FlowFinishedEvent, _on_finished),
        ]

    def unregister(self) -> None:
        """Detach every handler added by ``register()`` from the event bus.

        Safe to call when nothing is registered and safe to call repeatedly.
        """
        if not self._handlers:
            return

        from crewai.events import crewai_event_bus

        # Swap the list out first so a failure part-way through cannot leave
        # already-detached handlers recorded as attached.
        attached, self._handlers = self._handlers, []
        for event_type, handler in attached:
            # NOTE(review): relies on the bus exposing
            # ``off(event_type, handler)`` — confirm against the event bus API.
            crewai_event_bus.off(event_type, handler)

    def _matches(self, event: Any) -> bool:
        """Return True when ``event`` is stamped with this bridge's session id.

        An event matches when its ``fingerprint_metadata`` dict carries
        ``conversation_id == session_id``, or when its ``source_fingerprint``
        equals the session id.
        """
        meta = getattr(event, "fingerprint_metadata", None) or {}
        if isinstance(meta, dict) and meta.get("conversation_id") == self._session_id:
            return True
        fp = getattr(event, "source_fingerprint", None)
        return fp == self._session_id

    def _dispatch(self, msg_type: ChatMessageType, payload: dict[str, Any]) -> None:
        """Wrap ``payload`` in a sequenced ``ChatMessage`` and hand it to the handler."""
        self._seq += 1
        self._handler(
            ChatMessage(
                type=msg_type,
                session_id=self._session_id,
                payload=payload,
                seq=self._seq,
            )
        )
def normalize_kickoff_inputs(
    inputs: dict[str, Any] | None,
    *,
    user_message: str | dict[str, Any] | None = None,
    session_id: str | None = None,
) -> dict[str, Any]:
    """Merge conversational kickoff kwargs into the inputs dict.

    Args:
        inputs: Caller-supplied kickoff inputs, or ``None``. Never mutated;
            a shallow copy is returned.
        user_message: This turn's user message. When given, it is stored
            under ``"user_message"`` and overrides any value in ``inputs``.
        session_id: Conversation id. When given, it is stored under ``"id"``
            and overrides any value in ``inputs``.

    Returns:
        A new dict with the explicit keyword arguments applied on top of
        ``inputs``. A ``"user_message"`` already present in ``inputs`` is
        kept as-is when no ``user_message`` argument is passed.
    """
    merged: dict[str, Any] = dict(inputs or {})

    if session_id is not None:
        merged["id"] = session_id

    if user_message is not None:
        merged["user_message"] = user_message

    return merged
"_conversation_messages", []) + flow._conversation_messages.append(message) + + +def set_state_field(flow: Flow[Any], name: str, value: Any) -> None: + """Set a field on structured or dict flow state when present.""" + state = getattr(flow, "_state", None) + if state is None: + return + if isinstance(state, dict): + state[name] = value + elif isinstance(state, BaseModel) and hasattr(state, name): + object.__setattr__(state, name, value) + + +def receive_user_message( + flow: Flow[Any], + text: str, + *, + outcomes: Sequence[str] | None = None, + llm: str | BaseLLM | None = None, + metadata: dict[str, Any] | None = None, +) -> str: + """Record a user turn: append message and optionally classify intent.""" + append_message(flow, "user", text) + set_state_field(flow, "last_user_message", text) + + if outcomes and llm is not None: + intent = flow.classify_intent( + text, + outcomes, + llm=llm, + context=get_conversation_messages(flow), + ) + set_state_field(flow, "last_intent", intent) + return intent + + return text + + +def prepare_conversational_turn( + flow: Flow[Any], + *, + user_message: str | dict[str, Any] | None = None, + intents: Sequence[str] | None = None, + intent_llm: str | BaseLLM | None = None, + config: ConversationalConfig | None = None, +) -> None: + """Hydrate conversation state after inputs are merged into flow state.""" + if user_message is None: + state = getattr(flow, "_state", None) + if isinstance(state, dict) and "user_message" in state: + user_message = state["user_message"] + elif isinstance(state, BaseModel) and hasattr(state, "user_message"): + user_message = getattr(state, "user_message", None) + + if user_message is None: + return + + text = _coerce_user_message_text(user_message) + if not text.strip(): + return + + # Fresh classification each turn (do not reuse prior turn's route label). 
+ set_state_field(flow, "last_intent", None) + + resolved_intents = intents + if resolved_intents is None and config is not None: + resolved_intents = config.default_intents + + resolved_llm = intent_llm + if resolved_llm is None and config is not None: + resolved_llm = config.intent_llm + + if resolved_intents: + if resolved_llm is None: + raise ValueError("intent_llm is required when intents are provided") + receive_user_message( + flow, + text, + outcomes=resolved_intents, + llm=resolved_llm, + ) + else: + receive_user_message(flow, text) + + +def input_history_to_messages(entries: Sequence[Any]) -> list[LLMMessage]: + """Convert ``Flow.input_history`` entries to LLM message format.""" + messages: list[LLMMessage] = [] + for entry in entries: + prompt = entry.get("message") if isinstance(entry, dict) else None + response = entry.get("response") if isinstance(entry, dict) else None + if prompt: + messages.append({"role": "assistant", "content": str(prompt)}) + if response: + messages.append({"role": "user", "content": str(response)}) + return messages + + +def get_conversational_config(flow: Flow[Any]) -> ConversationalConfig | None: + """Return class-level ``conversational_config`` if defined.""" + return getattr(type(flow), "conversational_config", None) diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index d22794873..bf2cc2c18 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -83,7 +83,20 @@ from crewai.events.types.flow_events import ( MethodExecutionStartedEvent, ) from crewai.flow.constants import AND_CONDITION, OR_CONDITION -from crewai.flow.flow_context import current_flow_id, current_flow_request_id +from crewai.flow.conversation import ( + ConversationalConfig, + append_message as _append_conversation_message, + get_conversation_messages, + get_conversational_config, + normalize_kickoff_inputs, + prepare_conversational_turn, + receive_user_message as _receive_user_message, +) 
+from crewai.flow.flow_context import ( + current_flow_id, + current_flow_name, + current_flow_request_id, +) from crewai.flow.flow_wrappers import ( FlowCondition, FlowConditions, @@ -952,6 +965,13 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): memory: Memory | MemoryScope | MemorySlice | None = Field(default=None) input_provider: InputProvider | None = Field(default=None) suppress_flow_events: bool = Field(default=False) + defer_trace_finalization: bool = Field( + default=False, + description=( + "When True, do not finalize the trace batch at the end of each kickoff. " + "Call finalize_session_traces() when the chat session ends." + ), + ) human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list) last_human_feedback: HumanFeedbackResult | None = Field(default=None) @@ -1073,8 +1093,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): _pending_feedback_context: PendingFeedbackContext | None = PrivateAttr(default=None) _human_feedback_method_outputs: dict[str, Any] = PrivateAttr(default_factory=dict) _input_history: list[InputHistoryEntry] = PrivateAttr(default_factory=list) + _conversation_messages: list[dict[str, Any]] = PrivateAttr(default_factory=list) + _pending_user_message: str | dict[str, Any] | None = PrivateAttr(default=None) + _pending_intents: Sequence[str] | None = PrivateAttr(default=None) + _pending_intent_llm: str | BaseLLM | None = PrivateAttr(default=None) _state: Any = PrivateAttr(default=None) + conversational_config: ClassVar[ConversationalConfig | None] = None + def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override] class _FlowGeneric(cls): # type: ignore[valid-type,misc] pass @@ -1199,6 +1225,116 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): result: list[str] = self.memory.extract_memories(content) return result + @property + def conversation_messages(self) -> list[dict[str, Any]]: + """Message history from state or the internal conversation 
buffer.""" + return get_conversation_messages(self) + + @property + def input_history(self) -> list[InputHistoryEntry]: + """Read-only view of prompts and responses from ``ask()``.""" + return list(self._input_history) + + def append_message( + self, + role: Literal["user", "assistant", "system", "tool"], + content: str, + **extra: Any, + ) -> None: + """Append a message to conversation history on state or the fallback buffer.""" + _append_conversation_message(self, role, content, **extra) + + def classify_intent( + self, + text: str, + outcomes: Sequence[str], + *, + llm: str | BaseLLM, + context: Sequence[dict[str, Any]] | None = None, + ) -> str: + """Map user text to one of the given outcomes using an LLM.""" + if context: + context_blob = "\n".join( + f"{m.get('role', 'user')}: {m.get('content', '')}" for m in context + ) + feedback = f"{context_blob}\n\nLatest user message: {text}" + else: + feedback = text + return self._collapse_to_outcome(feedback, outcomes, llm) + + def receive_user_message( + self, + text: str, + *, + outcomes: Sequence[str] | None = None, + llm: str | BaseLLM | None = None, + ) -> str: + """Append a user message and optionally set ``last_intent`` on state.""" + return _receive_user_message( + self, + text, + outcomes=outcomes, + llm=llm, + ) + + def _configure_conversational_kickoff( + self, + *, + inputs: dict[str, Any] | None = None, + user_message: str | dict[str, Any] | None = None, + session_id: str | None = None, + intents: Sequence[str] | None = None, + intent_llm: str | BaseLLM | None = None, + ) -> dict[str, Any]: + """Store pending conversational turn options for ``kickoff_async``.""" + config = get_conversational_config(self) or self.conversational_config + resolved_intents = intents + resolved_llm = intent_llm + if config is not None: + if resolved_intents is None: + resolved_intents = config.default_intents + if resolved_llm is None: + resolved_llm = config.intent_llm + + resolved_message = user_message + if 
resolved_message is None and inputs and "user_message" in inputs: + resolved_message = inputs["user_message"] + + self._pending_user_message = resolved_message + self._pending_intents = list(resolved_intents) if resolved_intents else None + self._pending_intent_llm = resolved_llm + + if config is not None and config.defer_trace_finalization: + self.defer_trace_finalization = True + from crewai.events.listeners.tracing.trace_listener import ( + TraceCollectionListener, + ) + + TraceCollectionListener().batch_manager.defer_session_finalization = True + + return normalize_kickoff_inputs( + inputs, + user_message=resolved_message, + session_id=session_id, + ) + + def _clear_conversational_kickoff(self) -> None: + self._pending_user_message = None + self._pending_intents = None + self._pending_intent_llm = None + + def _apply_pending_conversational_turn(self) -> None: + if self._pending_user_message is None: + return + config = get_conversational_config(self) or self.conversational_config + prepare_conversational_turn( + self, + user_message=self._pending_user_message, + intents=self._pending_intents, + intent_llm=self._pending_intent_llm, + config=config, + ) + def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool: """Mark an OR listener as fired atomically. 
@@ -1532,20 +1668,19 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): reset_emission_counter() reset_last_event_id() - if not self.suppress_flow_events: - future = crewai_event_bus.emit( - self, - FlowStartedEvent( - type="flow_started", - flow_name=self.name or self.__class__.__name__, - inputs=None, - ), - ) - if future and isinstance(future, Future): - try: - await asyncio.wrap_future(future) - except Exception: - logger.warning("FlowStartedEvent handler failed", exc_info=True) + future = crewai_event_bus.emit( + self, + FlowStartedEvent( + type="flow_started", + flow_name=self._flow_display_name(), + inputs=None, + ), + ) + if future and isinstance(future, Future): + try: + await asyncio.wrap_future(future) + except Exception: + logger.warning("FlowStartedEvent handler failed", exc_info=True) get_env_context() @@ -1698,29 +1833,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): ) self._event_futures.clear() - if not self.suppress_flow_events: - future = crewai_event_bus.emit( - self, - FlowFinishedEvent( - type="flow_finished", - flow_name=self.name or self.__class__.__name__, - result=final_result, - state=self._copy_and_serialize_state(), - ), - ) - if future and isinstance(future, Future): - try: - await asyncio.wrap_future(future) - except Exception: - logger.warning("FlowFinishedEvent handler failed", exc_info=True) + if not self._should_defer_trace_finalization(): + await self._emit_flow_finished_async(final_result) - trace_listener = TraceCollectionListener() - if trace_listener.batch_manager.batch_owner_type == "flow": - if trace_listener.first_time_handler.is_first_time: - trace_listener.first_time_handler.mark_events_collected() - trace_listener.first_time_handler.handle_execution_completion() - else: - trace_listener.batch_manager.finalize_batch() + self._finalize_flow_trace_batch() return final_result @@ -2033,6 +2149,15 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): input_files: dict[str, FileInput] | None = None, 
from_checkpoint: CheckpointConfig | None = None, restore_from_state_id: str | None = None, + *, + user_message: str | dict[str, Any] | None = None, + session_id: str | None = None, + intents: Sequence[str] | None = None, + intent_llm: str | BaseLLM | None = None, + interactive: bool = False, + interactive_prompt: str | None = None, + interactive_timeout: float | None = None, + exit_commands: Sequence[str] | None = None, ) -> Any | FlowStreamingOutput: """Start the flow execution in a synchronous context. @@ -2052,10 +2177,49 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): If the referenced state is not found, the kickoff falls back silently to baseline behavior. Cannot be combined with ``from_checkpoint``; passing both raises ``ValueError``. + user_message: Text or ``{"role": "user", "content": "..."}`` for this + chat turn. Appended to ``state.messages`` before the graph runs. + session_id: Conversation session UUID; merged into ``inputs["id"]`` + for ``@persist`` restoration. + intents: Optional outcome labels for pre-kickoff intent classification. + intent_llm: LLM used when ``intents`` is set. + interactive: If True, run a CLI loop (``ask`` per line) until exit or + timeout. For local demos only; APIs should pass ``user_message``. + interactive_prompt: Prompt shown by ``ask()`` in interactive mode. + interactive_timeout: Per-line timeout for interactive ``ask()``. + exit_commands: Words that end interactive mode (default exit, quit). Returns: The final output from the flow or FlowStreamingOutput if streaming. """ + if interactive: + if user_message is not None: + raise ValueError( + "Cannot pass user_message with interactive=True; " + "messages are collected via ask()." 
+ ) + if self.stream: + raise ValueError("interactive=True is not supported with stream=True") + return self._kickoff_interactive( + inputs=inputs, + input_files=input_files, + session_id=session_id, + intents=intents, + intent_llm=intent_llm, + interactive_prompt=interactive_prompt, + interactive_timeout=interactive_timeout, + exit_commands=exit_commands, + restore_from_state_id=restore_from_state_id, + ) + + inputs = self._configure_conversational_kickoff( + inputs=inputs, + user_message=user_message, + session_id=session_id, + intents=intents, + intent_llm=intent_llm, + ) + if from_checkpoint is not None and restore_from_state_id is not None: raise ValueError( "Cannot combine `from_checkpoint` and `restore_from_state_id`. " @@ -2064,7 +2228,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): ) restored = apply_checkpoint(self, from_checkpoint) if restored is not None: - return restored.kickoff(inputs=inputs, input_files=input_files) + return restored.kickoff( + inputs=inputs, + input_files=input_files, + user_message=user_message, + session_id=session_id, + intents=intents, + intent_llm=intent_llm, + ) if self.stream: result_holder: list[Any] = [] current_task_info: TaskInfo = { @@ -2087,6 +2258,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): inputs=inputs, input_files=input_files, restore_from_state_id=restore_from_state_id, + user_message=self._pending_user_message, + session_id=inputs.get("id") if inputs else None, + intents=self._pending_intents, + intent_llm=self._pending_intent_llm, ) result_holder.append(result) except Exception as e: @@ -2110,11 +2285,18 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): return streaming_output async def _run_flow() -> Any: - return await self.kickoff_async( - inputs, - input_files, - restore_from_state_id=restore_from_state_id, - ) + try: + return await self.kickoff_async( + inputs, + input_files, + restore_from_state_id=restore_from_state_id, + user_message=self._pending_user_message, + 
session_id=inputs.get("id") if inputs else None, + intents=self._pending_intents, + intent_llm=self._pending_intent_llm, + ) + finally: + self._clear_conversational_kickoff() try: asyncio.get_running_loop() @@ -2124,12 +2306,65 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): except RuntimeError: return asyncio.run(_run_flow()) + def _kickoff_interactive( + self, + *, + inputs: dict[str, Any] | None, + input_files: dict[str, FileInput] | None, + session_id: str | None, + intents: Sequence[str] | None, + intent_llm: str | BaseLLM | None, + interactive_prompt: str | None, + interactive_timeout: float | None, + exit_commands: Sequence[str] | None, + restore_from_state_id: str | None, + ) -> Any: + config = get_conversational_config(self) or self.conversational_config + prompt = interactive_prompt or (config.interactive_prompt if config else "You: ") + timeout = ( + interactive_timeout + if interactive_timeout is not None + else (config.interactive_timeout if config else None) + ) + exits = {c.strip().lower() for c in (exit_commands or (config.exit_commands if config else ("exit", "quit")))} + sid = session_id + if sid is None and inputs and "id" in inputs: + sid = str(inputs["id"]) + if sid is None: + sid = str(uuid4()) + + last_result: Any = None + while True: + line = self.ask(prompt, timeout=timeout) + if line is None or line.strip().lower() in exits: + break + turn_inputs = self._configure_conversational_kickoff( + inputs=inputs, + user_message=line, + session_id=sid, + intents=intents, + intent_llm=intent_llm, + ) + last_result = self.kickoff( + inputs=turn_inputs, + input_files=input_files, + restore_from_state_id=restore_from_state_id, + ) + restore_from_state_id = None + self._clear_conversational_kickoff() + return last_result + async def kickoff_async( self, inputs: dict[str, Any] | None = None, input_files: dict[str, FileInput] | None = None, from_checkpoint: CheckpointConfig | None = None, restore_from_state_id: str | None = None, + *, + 
user_message: str | dict[str, Any] | None = None, + session_id: str | None = None, + intents: Sequence[str] | None = None, + intent_llm: str | BaseLLM | None = None, ) -> Any | FlowStreamingOutput: """Start the flow execution asynchronously. @@ -2150,10 +2385,22 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): separate persistence key. If the referenced state is not found, falls back silently to baseline. Cannot be combined with ``from_checkpoint``; passing both raises ``ValueError``. + user_message: User text for this conversational turn. + session_id: Session UUID (``inputs["id"]``). + intents: Optional labels for pre-kickoff classification. + intent_llm: LLM for classification when ``intents`` is set. Returns: The final output from the flow, which is the result of the last executed method. """ + inputs = self._configure_conversational_kickoff( + inputs=inputs, + user_message=user_message, + session_id=session_id, + intents=intents, + intent_llm=intent_llm, + ) + if from_checkpoint is not None and restore_from_state_id is not None: raise ValueError( "Cannot combine `from_checkpoint` and `restore_from_state_id`. 
" @@ -2162,7 +2409,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): ) restored = apply_checkpoint(self, from_checkpoint) if restored is not None: - return await restored.kickoff_async(inputs=inputs, input_files=input_files) + return await restored.kickoff_async( + inputs=inputs, + input_files=input_files, + user_message=user_message, + session_id=session_id, + intents=intents, + intent_llm=intent_llm, + ) if self.stream: result_holder: list[Any] = [] current_task_info: TaskInfo = { @@ -2215,10 +2469,13 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): flow_id_token = None request_id_token = None + flow_name_token = None if current_flow_id.get() is None: flow_id_token = current_flow_id.set(self.flow_id) if current_flow_request_id.get() is None: request_id_token = current_flow_request_id.set(self.flow_id) + if current_flow_name.get() is None: + flow_name_token = current_flow_name.set(self._flow_display_name()) try: # Reset flow state for fresh execution unless restoring from persistence @@ -2301,8 +2558,12 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): f"No flow state found for UUID: {restore_uuid}", color="red" ) - # Update state with any additional inputs (ignoring the 'id' key) - filtered_inputs = {k: v for k, v in inputs.items() if k != "id"} + # Update state with any additional inputs (ignoring conversational keys) + filtered_inputs = { + k: v + for k, v in inputs.items() + if k not in ("id", "user_message", "last_intent") + } if filtered_inputs: self._initialize_state(filtered_inputs) @@ -2310,31 +2571,51 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): reset_emission_counter() reset_last_event_id() - if not self.suppress_flow_events: - future = crewai_event_bus.emit( - self, - FlowStartedEvent( - type="flow_started", - flow_name=self.name or self.__class__.__name__, - inputs=inputs, - ), + skip_flow_started = self._should_defer_trace_finalization() and getattr( + self, "_conversation_trace_started", False + ) + if not 
skip_flow_started: + started_event = FlowStartedEvent( + type="flow_started", + flow_name=self._flow_display_name(), + inputs=inputs, ) + future = crewai_event_bus.emit(self, started_event) if future: try: await asyncio.wrap_future(future) except Exception: - logger.warning("FlowStartedEvent handler failed", exc_info=True) + logger.warning( + "FlowStartedEvent handler failed", exc_info=True + ) + if self._should_defer_trace_finalization(): + object.__setattr__(self, "_conversation_trace_started", True) + object.__setattr__( + self, + "_conversation_flow_started_event_id", + started_event.event_id, + ) + from crewai.events.listeners.tracing.trace_listener import ( + TraceCollectionListener, + ) + + TraceCollectionListener().batch_manager.defer_session_finalization = ( + True + ) + if not self.suppress_flow_events: self._log_flow_event( f"Flow started with ID: {self.flow_id}", color="bold magenta" ) - # After FlowStarted (when not suppressed): env events must not pre-empt - # trace batch init with implicit "crew" execution_type. + # After FlowStarted: env events must not pre-empt trace batch init + # with implicit "crew" execution_type. 
get_env_context() if inputs is not None and "id" not in inputs: self._initialize_state(inputs) + self._apply_pending_conversational_turn() + if self._is_execution_resuming: await self._replay_recorded_events() @@ -2428,35 +2709,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): ) self._event_futures.clear() - if not self.suppress_flow_events: - future = crewai_event_bus.emit( - self, - FlowFinishedEvent( - type="flow_finished", - flow_name=self.name or self.__class__.__name__, - result=final_output, - state=self._copy_and_serialize_state(), - ), - ) - if future: - try: - await asyncio.wrap_future(future) - except Exception: - logger.warning( - "FlowFinishedEvent handler failed", exc_info=True - ) + if not self._should_defer_trace_finalization(): + await self._emit_flow_finished_async(final_output) - if not self.suppress_flow_events: - trace_listener = TraceCollectionListener() - if trace_listener.batch_manager.batch_owner_type == "flow": - if trace_listener.first_time_handler.is_first_time: - trace_listener.first_time_handler.mark_events_collected() - trace_listener.first_time_handler.handle_execution_completion() - else: - trace_listener.batch_manager.finalize_batch() + self._finalize_flow_trace_batch() return final_output finally: + self._clear_conversational_kickoff() # Ensure all background memory saves complete before returning if self.memory is not None and hasattr(self.memory, "drain_writes"): self.memory.drain_writes() @@ -2464,6 +2724,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): current_flow_request_id.reset(request_id_token) if flow_id_token is not None: current_flow_id.reset(flow_id_token) + if flow_name_token is not None: + current_flow_name.reset(flow_name_token) detach(flow_token) async def akickoff( @@ -2472,6 +2734,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): input_files: dict[str, FileInput] | None = None, from_checkpoint: CheckpointConfig | None = None, restore_from_state_id: str | None = None, + *, + 
user_message: str | dict[str, Any] | None = None, + session_id: str | None = None, + intents: Sequence[str] | None = None, + intent_llm: str | BaseLLM | None = None, ) -> Any | FlowStreamingOutput: """Native async method to start the flow execution. Alias for kickoff_async. @@ -2483,6 +2750,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): restore_from_state_id: Optional UUID of a previously-persisted flow whose latest snapshot should hydrate this run's state. See ``kickoff_async`` for full semantics. + user_message: User text for this conversational turn. + session_id: Session UUID (``inputs["id"]``). + intents: Optional labels for pre-kickoff classification. + intent_llm: LLM for classification when ``intents`` is set. Returns: The final output from the flow, which is the result of the last executed method. @@ -2492,6 +2763,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): input_files, from_checkpoint, restore_from_state_id=restore_from_state_id, + user_message=user_message, + session_id=session_id, + intents=intents, + intent_llm=intent_llm, ) async def _replay_recorded_events(self) -> None: @@ -3354,6 +3629,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): ), ) + if response: + _append_conversation_message(self, "user", response) + return response def _request_human_feedback( @@ -3554,6 +3832,99 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): ) return outcomes[0] + def _flow_display_name(self) -> str: + return self.name or self.__class__.__name__ + + def _should_defer_trace_finalization(self) -> bool: + if self.defer_trace_finalization: + return True + config = get_conversational_config(self) + return bool(config and config.defer_trace_finalization) + + async def _emit_flow_finished_async(self, result: Any) -> None: + """Emit ``FlowFinishedEvent`` and await handlers.""" + future = crewai_event_bus.emit( + self, + FlowFinishedEvent( + type="flow_finished", + flow_name=self._flow_display_name(), + result=result, + 
state=self._copy_and_serialize_state(), + ), + ) + if not future: + return + try: + if isinstance(future, Future): + await asyncio.wrap_future(future) + else: + await future + except Exception: + logger.warning("FlowFinishedEvent handler failed", exc_info=True) + + def _emit_flow_finished_sync(self, result: Any) -> None: + """Emit ``FlowFinishedEvent`` from synchronous session teardown.""" + try: + asyncio.get_running_loop() + except RuntimeError: + asyncio.run(self._emit_flow_finished_async(result)) + else: + raise RuntimeError( + "Cannot emit flow_finished synchronously while an event loop is running" + ) + + def finalize_session_traces(self) -> None: + """Finalize the trace batch after a multi-turn conversational session.""" + from crewai.events.event_context import restore_event_scope + from crewai.events.listeners.tracing.trace_listener import ( + TraceCollectionListener, + ) + + trace_listener = TraceCollectionListener() + batch_manager = trace_listener.batch_manager + + if batch_manager._batch_finalized or not batch_manager.is_batch_initialized(): + batch_manager.defer_session_finalization = False + object.__setattr__(self, "_conversation_trace_started", False) + object.__setattr__(self, "_conversation_flow_started_event_id", None) + return + + result = self._method_outputs[-1] if self._method_outputs else None + if ( + self._should_defer_trace_finalization() + and getattr(self, "_conversation_trace_started", False) + ): + started_id = getattr(self, "_conversation_flow_started_event_id", None) + if started_id: + restore_event_scope(((started_id, "flow_started"),)) + try: + self._emit_flow_finished_sync(result) + finally: + restore_event_scope(()) + object.__setattr__(self, "_conversation_flow_started_event_id", None) + + self._finalize_flow_trace_batch(force=True) + object.__setattr__(self, "_conversation_trace_started", False) + batch_manager.defer_session_finalization = False + + def _finalize_flow_trace_batch(self, *, force: bool = False) -> None: + 
"""Finalize the active trace batch when this flow owns it.""" + if not force and self._should_defer_trace_finalization(): + return + + from crewai.events.listeners.tracing.trace_listener import ( + TraceCollectionListener, + ) + + trace_listener = TraceCollectionListener() + if trace_listener.batch_manager.batch_owner_type != "flow": + return + if trace_listener.first_time_handler.is_first_time: + trace_listener.first_time_handler.mark_events_collected() + trace_listener.first_time_handler.handle_execution_completion() + else: + trace_listener.batch_manager.finalize_batch() + def _log_flow_event( self, message: str, diff --git a/lib/crewai/src/crewai/flow/flow_context.py b/lib/crewai/src/crewai/flow/flow_context.py index 0ff6cf973..474360aa3 100644 --- a/lib/crewai/src/crewai/flow/flow_context.py +++ b/lib/crewai/src/crewai/flow/flow_context.py @@ -18,3 +18,7 @@ current_flow_id: contextvars.ContextVar[str | None] = contextvars.ContextVar( current_flow_method_name: contextvars.ContextVar[str] = contextvars.ContextVar( "flow_method_name", default="unknown" ) + +current_flow_name: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "flow_name", default=None +) diff --git a/lib/crewai/src/crewai/flow/providers/__init__.py b/lib/crewai/src/crewai/flow/providers/__init__.py new file mode 100644 index 000000000..9d63f5160 --- /dev/null +++ b/lib/crewai/src/crewai/flow/providers/__init__.py @@ -0,0 +1,4 @@ +from crewai.flow.providers.queue import QueueInputProvider + + +__all__ = ["QueueInputProvider"] diff --git a/lib/crewai/src/crewai/flow/providers/queue.py b/lib/crewai/src/crewai/flow/providers/queue.py new file mode 100644 index 000000000..4b2b9a12c --- /dev/null +++ b/lib/crewai/src/crewai/flow/providers/queue.py @@ -0,0 +1,82 @@ +"""Queue-backed input provider for conversational flows.""" + +from __future__ import annotations + +import queue +import threading +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from crewai.flow.flow import Flow 
class QueueInputProvider:
    """Input provider that blocks on a per-session queue until a message is pushed.

    Use for long-running workers where ``Flow.ask()`` should wait on WebSocket
    or another transport without blocking the event loop thread (flow runs ask
    in a worker thread).

    Queues are created lazily per session id and are dropped again once the
    end-of-session sentinel has been consumed and nothing else is waiting, so
    the provider does not accumulate one queue per finished session.

    Example:
        ```python
        provider = QueueInputProvider()
        flow.input_provider = provider

        # From a WebSocket handler:
        provider.push(session_id, "hello")

        # Inside the flow:
        reply = flow.ask("You: ", metadata={"session_id": session_id})
        ```
    """

    def __init__(self) -> None:
        # session id -> FIFO of pending user lines; ``None`` is the
        # end-of-session sentinel enqueued by ``close_session``.
        self._queues: dict[str, queue.Queue[str | None]] = {}
        # Guards every read/write of ``_queues`` (lookup, creation, pruning).
        self._lock = threading.Lock()

    def _get_queue(self, session_id: str) -> queue.Queue[str | None]:
        """Return the queue for ``session_id``, creating it on first use."""
        with self._lock:
            if session_id not in self._queues:
                self._queues[session_id] = queue.Queue()
            return self._queues[session_id]

    def _put(self, session_id: str, item: str | None) -> None:
        """Enqueue ``item`` for ``session_id`` while holding the registry lock.

        Lookup and ``put`` happen under the same lock so that a concurrent
        prune in ``request_input`` can never discard the queue between the
        two steps (which would silently lose the item). The queues are
        unbounded, so ``put`` does not block while the lock is held.
        """
        with self._lock:
            pending = self._queues.get(session_id)
            if pending is None:
                pending = self._queues[session_id] = queue.Queue()
            pending.put(item)

    def push(self, session_id: str, text: str) -> None:
        """Enqueue a user message for the given session."""
        self._put(session_id, text)

    def close_session(self, session_id: str) -> None:
        """Signal end of session (unblocks ``ask()`` with None).

        The session's queue is released once the sentinel has been consumed
        and no further messages are pending; a later ``push`` for the same
        id simply starts a fresh queue.
        """
        self._put(session_id, None)

    def request_input(
        self,
        message: str,
        flow: Flow[Any],
        metadata: dict[str, Any] | None = None,
    ) -> str | None:
        """Block until the next message for the flow's session is available.

        Args:
            message: Prompt text; unused here because the transport that
                calls ``push`` is responsible for showing it.
            flow: Flow whose state supplies the session id when ``metadata``
                does not.
            metadata: Optional mapping; a truthy ``"session_id"`` entry takes
                precedence over the flow state.

        Returns:
            The pushed text, or ``None`` when no session id can be resolved
            or the session was closed via ``close_session``.
        """
        session_id = self._resolve_session_id(flow, metadata)
        if session_id is None:
            return None

        pending = self._get_queue(session_id)
        # ``Queue.get()`` without a timeout blocks until an item arrives and
        # does not raise ``queue.Empty``, so no exception handler is needed.
        item = pending.get()
        if item is None:
            self._discard_if_drained(session_id, pending)
        return item

    def _discard_if_drained(
        self, session_id: str, pending: queue.Queue[str | None]
    ) -> None:
        """Forget a closed session's queue unless messages are still waiting."""
        with self._lock:
            # Only drop the exact queue we just read from, and only when it is
            # empty, so a message pushed after ``close_session`` is still
            # delivered to the next ``request_input`` call.
            if self._queues.get(session_id) is pending and pending.empty():
                del self._queues[session_id]

    @staticmethod
    def _resolve_session_id(
        flow: Flow[Any],
        metadata: dict[str, Any] | None,
    ) -> str | None:
        """Pick the session id from ``metadata`` first, then from flow state.

        Returns ``None`` when neither source yields a truthy id.
        """
        if metadata and metadata.get("session_id"):
            return str(metadata["session_id"])

        state = getattr(flow, "_state", None)
        if state is None:
            return None

        # Dict state keeps the id under the "id" key; any other state object
        # is expected to expose it as an ``id`` attribute.
        if isinstance(state, dict):
            value = state.get("id")
        else:
            value = getattr(state, "id", None)
        return str(value) if value else None
a/lib/crewai/tests/test_flow_conversation.py b/lib/crewai/tests/test_flow_conversation.py new file mode 100644 index 000000000..b613fc096 --- /dev/null +++ b/lib/crewai/tests/test_flow_conversation.py @@ -0,0 +1,480 @@ +"""Tests for conversational Flow helpers and kickoff parameters.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from pydantic import BaseModel, Field + +from crewai.events.event_bus import crewai_event_bus +from crewai.events.listeners.tracing.trace_listener import TraceCollectionListener +from crewai.events.types.flow_events import FlowStartedEvent +from crewai.events.types.llm_events import LLMCallStartedEvent +from crewai.flow import Flow, ChatState, listen, start +from crewai.flow.flow_context import current_flow_id, current_flow_name +from crewai.flow.conversation import ( + ConversationalConfig, + append_message, + get_conversation_messages, + normalize_kickoff_inputs, + prepare_conversational_turn, +) +from crewai.flow.chat import ChatMessage, ChatSession +from crewai.flow.providers import QueueInputProvider +from crewai.utilities.types import LLMMessage + + +class SimpleChatFlow(Flow[ChatState]): + @start() + def begin(self): + return "done" + + +class DictChatFlow(Flow): + @start() + def begin(self): + return self.state.get("marker", "ok") + + +class TestNormalizeKickoffInputs: + def test_merges_session_and_user_message(self) -> None: + merged = normalize_kickoff_inputs( + {"foo": 1}, + user_message="hello", + session_id="sess-1", + ) + assert merged["id"] == "sess-1" + assert merged["user_message"] == "hello" + assert merged["foo"] == 1 + + +class TestMessageHelpers: + def test_append_message_on_pydantic_state(self) -> None: + flow = SimpleChatFlow() + flow._state = ChatState() + append_message(flow, "user", "hi") + assert get_conversation_messages(flow) == [{"role": "user", "content": "hi"}] + + def 
test_append_message_fallback_buffer(self) -> None: + flow = DictChatFlow() + + class _State: + id = str(uuid4()) + + flow._state = _State() + append_message(flow, "assistant", "reply") + assert get_conversation_messages(flow) == [ + {"role": "assistant", "content": "reply"} + ] + assert flow._conversation_messages == [ + {"role": "assistant", "content": "reply"} + ] + + +class TestIntentPerTurn: + def test_prepare_clears_stale_last_intent(self) -> None: + flow = SimpleChatFlow() + flow._state = ChatState(last_intent="ORDER", messages=[]) + prepare_conversational_turn(flow, user_message="hello") + assert flow.state.last_intent is None + + +class TestKickoffConversational: + def test_kickoff_user_message_hydrates_state(self) -> None: + flow = SimpleChatFlow() + flow.kickoff(user_message="track my order", session_id="session-abc") + + assert flow.state.last_user_message == "track my order" + assert any( + m.get("role") == "user" and m.get("content") == "track my order" + for m in flow.state.messages + ) + assert flow.state.id == "session-abc" + + def test_kickoff_classifies_intent_when_configured(self) -> None: + flow = SimpleChatFlow() + + with patch.object( + flow, + "_collapse_to_outcome", + return_value="order", + ) as mock_collapse: + flow.kickoff( + user_message="where is my package", + session_id="s1", + intents=["order", "help"], + intent_llm="gpt-4o-mini", + ) + + mock_collapse.assert_called_once() + assert flow.state.last_intent == "order" + + def test_ask_appends_to_messages(self) -> None: + class AskFlow(Flow[ChatState]): + input_provider = MagicMock() + input_provider.request_input = MagicMock(return_value="user reply") + + @start() + def begin(self): + self.ask("Prompt:") + return "ok" + + flow = AskFlow() + flow._state = ChatState() + flow.kickoff() + + assert any( + m.get("role") == "user" and m.get("content") == "user reply" + for m in flow.state.messages + ) + + +class TestClassifyIntent: + def test_uses_collapse_with_context(self) -> None: + flow = 
SimpleChatFlow() + flow._state = ChatState( + messages=[{"role": "user", "content": "prior"}], + ) + + with patch.object(flow, "_collapse_to_outcome", return_value="help") as mock: + outcome = flow.classify_intent( + "I need help", + ["order", "help"], + llm="gpt-4o-mini", + context=flow.conversation_messages, + ) + + assert outcome == "help" + assert "I need help" in mock.call_args[0][0] + + +class TestQueueInputProvider: + def test_push_and_request_input(self) -> None: + provider = QueueInputProvider() + flow = SimpleChatFlow() + flow._state = ChatState(id="sess-q") + + provider.push("sess-q", "hello") + result = provider.request_input(">", flow, metadata={"session_id": "sess-q"}) + assert result == "hello" + + +class TestChatSession: + def test_handle_turn_returns_turn_result(self) -> None: + flow = SimpleChatFlow() + session = ChatSession( + flow, + session_id="chat-1", + intents=["order", "help"], + intent_llm="gpt-4o-mini", + ) + + with patch.object(flow, "_collapse_to_outcome", return_value="help"): + turn = session.handle_turn("hi there") + + assert turn.session_id == "chat-1" + assert turn.output == "done" + assert turn.intent == "help" + assert any(m["role"] == "user" for m in turn.messages) + session.close() + + def test_chat_message_model(self) -> None: + msg = ChatMessage( + type="assistant_delta", + session_id="x", + payload={"chunk": "hi"}, + ) + assert msg.version == "1" + assert msg.type == "assistant_delta" + + +class TestFlowTracingWhenSuppressed: + def test_flow_started_emitted_when_panel_events_suppressed(self) -> None: + class QuietFlow(Flow[ChatState]): + suppress_flow_events = True + + @start() + def begin(self) -> str: + return "ok" + + started: list[str] = [] + original_emit = crewai_event_bus.emit + + def track_emit(source: Any, event: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(event, FlowStartedEvent): + started.append(event.flow_name) + return original_emit(source, event, *args, **kwargs) + + with 
patch.object(crewai_event_bus, "emit", side_effect=track_emit): + QuietFlow().kickoff() + + assert started == ["QuietFlow"] + + def test_llm_action_inside_flow_claims_flow_trace_batch(self) -> None: + listener = TraceCollectionListener() + listener.batch_manager.current_batch = None + listener.batch_manager.batch_owner_type = None + listener.batch_manager.batch_owner_id = None + + flow_id_token = current_flow_id.set("flow-test-id") + flow_name_token = current_flow_name.set("DemoSupportFlow") + try: + event = LLMCallStartedEvent( + model="gpt-4o-mini", + messages=[], + call_id="call-test", + ) + listener._handle_action_event("llm_call_started", object(), event) + finally: + current_flow_id.reset(flow_id_token) + current_flow_name.reset(flow_name_token) + + assert listener.batch_manager.batch_owner_type == "flow" + assert listener.batch_manager.batch_owner_id == "flow-test-id" + assert ( + listener.batch_manager.current_batch.execution_metadata["execution_type"] + == "flow" + ) + assert ( + listener.batch_manager.current_batch.execution_metadata["flow_name"] + == "DemoSupportFlow" + ) + + +class TestDeferTraceFinalization: + def test_conversational_kickoff_enables_defer_flag(self) -> None: + class ChatFlow(Flow[ChatState]): + conversational_config = ConversationalConfig( + defer_trace_finalization=True + ) + + @start() + def begin(self) -> str: + return "ok" + + flow = ChatFlow() + flow._configure_conversational_kickoff( + user_message="hi", + session_id="sess-trace", + ) + assert flow.defer_trace_finalization is True + assert flow._should_defer_trace_finalization() is True + + def test_finalize_skipped_until_forced(self) -> None: + flow = SimpleChatFlow() + flow.defer_trace_finalization = True + + with patch( + "crewai.events.listeners.tracing.trace_listener.TraceCollectionListener" + ) as mock_listener_cls: + mock_listener_cls.return_value.batch_manager.batch_owner_type = "flow" + mock_listener_cls.return_value.first_time_handler.is_first_time = False + + 
flow._finalize_flow_trace_batch() + mock_listener_cls.assert_not_called() + + flow._finalize_flow_trace_batch(force=True) + mock_listener_cls.assert_called_once() + + +class TestDeferredFlowLifecycleEvents: + def test_deferred_kickoff_skips_per_turn_flow_finished(self) -> None: + class ChatFlow(Flow[ChatState]): + conversational_config = ConversationalConfig( + defer_trace_finalization=True + ) + + @start() + def begin(self) -> str: + return "ok" + + flow = ChatFlow() + with patch.object(flow, "_emit_flow_finished_async") as mock_finished: + flow.kickoff(user_message="hi", session_id="sess-lifecycle") + mock_finished.assert_not_called() + + def test_flow_finished_without_flow_started_warns(self, capsys) -> None: + from crewai.events.event_bus import crewai_event_bus + from crewai.events.event_context import restore_event_scope + from crewai.events.types.flow_events import FlowFinishedEvent + + class BareFlow(Flow[ChatState]): + @start() + def begin(self) -> str: + return "ok" + + restore_event_scope(()) + flow = BareFlow() + crewai_event_bus.emit( + flow, + FlowFinishedEvent( + type="flow_finished", + flow_name="BareFlow", + result="ok", + state={}, + ), + ) + captured = capsys.readouterr().out + assert "flow_finished" in captured + assert "Missing starting event" in captured + + def test_finalize_session_restores_flow_started_scope(self, capsys) -> None: + from crewai.events.listeners.tracing.trace_batch_manager import TraceBatch + + class ChatFlow(Flow[ChatState]): + conversational_config = ConversationalConfig( + defer_trace_finalization=True + ) + + @start() + def begin(self) -> str: + return "ok" + + flow = ChatFlow() + flow.defer_trace_finalization = True + object.__setattr__(flow, "_conversation_trace_started", True) + object.__setattr__(flow, "_conversation_flow_started_event_id", "start-evt-1") + flow._method_outputs.append("ok") + + listener = TraceCollectionListener() + listener.batch_manager.batch_owner_type = "flow" + 
listener.batch_manager.current_batch = TraceBatch( + execution_metadata={"execution_type": "flow", "flow_name": "ChatFlow"}, + ) + listener.batch_manager.defer_session_finalization = True + listener.batch_manager._batch_finalized = False + + with patch.object(flow, "_finalize_flow_trace_batch") as mock_finalize: + flow.finalize_session_traces() + + captured = capsys.readouterr().out + assert "Missing starting event" not in captured + mock_finalize.assert_called_once_with(force=True) + assert listener.batch_manager.defer_session_finalization is False + + def test_finalize_batch_is_idempotent(self) -> None: + from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager + + with patch( + "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context", + return_value=True, + ): + bm = TraceBatchManager() + bm.current_batch = bm.initialize_batch( + user_context={"privacy_level": "standard"}, + execution_metadata={"execution_type": "flow", "flow_name": "ChatFlow"}, + ) + bm.trace_batch_id = "batch-idempotent" + bm.backend_initialized = True + + with ( + patch.object( + bm.plus_api, + "send_trace_events", + return_value=MagicMock(status_code=200), + ), + patch.object( + bm.plus_api, + "finalize_trace_batch", + return_value=MagicMock(status_code=200, json=MagicMock(return_value={})), + ) as mock_finalize_api, + ): + bm.finalize_batch() + bm.finalize_batch() + + assert mock_finalize_api.call_count == 1 + assert bm._batch_finalized is True + + def test_finalize_session_is_idempotent_after_batch_cleared(self) -> None: + class ChatFlow(Flow[ChatState]): + @start() + def begin(self) -> str: + return "ok" + + flow = ChatFlow() + flow.defer_trace_finalization = True + object.__setattr__(flow, "_conversation_trace_started", True) + + listener = TraceCollectionListener() + listener.batch_manager.current_batch = None + listener.batch_manager.batch_owner_type = None + listener.batch_manager.trace_batch_id = None + 
listener.batch_manager._batch_finalized = True + + with patch.object(flow, "_emit_flow_finished_sync") as mock_finished: + with patch.object(flow, "_finalize_flow_trace_batch") as mock_finalize: + flow.finalize_session_traces() + flow.finalize_session_traces() + + mock_finished.assert_not_called() + mock_finalize.assert_not_called() + + def test_sigint_skips_deferred_session_batch(self) -> None: + from crewai.events.listeners.tracing.trace_batch_manager import TraceBatch + + listener = TraceCollectionListener() + listener.batch_manager.current_batch = TraceBatch() + listener.batch_manager.defer_session_finalization = True + + with patch.object(listener.batch_manager, "finalize_batch") as mock_finalize: + if listener.batch_manager.is_batch_initialized(): + if not listener.batch_manager.defer_session_finalization: + listener.batch_manager.finalize_batch() + mock_finalize.assert_not_called() + + +class TestNestedCrewTracing: + def test_is_inside_active_flow_context_when_kickoff_running(self) -> None: + from crewai.events.listeners.tracing.trace_listener import ( + TraceCollectionListener, + ) + from crewai.flow.flow_context import current_flow_id + + assert TraceCollectionListener._is_inside_active_flow_context() is False + token = current_flow_id.set("parent-flow-id") + try: + assert TraceCollectionListener._is_inside_active_flow_context() is True + finally: + current_flow_id.reset(token) + + def test_nested_crew_completion_skips_finalize(self) -> None: + from crewai.events.listeners.tracing.trace_listener import ( + TraceCollectionListener, + ) + from crewai.flow.flow_context import current_flow_id + + listener = TraceCollectionListener() + listener.batch_manager.batch_owner_type = "crew" + + token = current_flow_id.set("parent-flow-id") + try: + with patch.object(listener.batch_manager, "finalize_batch") as mock_finalize: + if listener._nested_in_flow_execution(): + pass + elif listener.batch_manager.batch_owner_type == "crew": + 
listener.batch_manager.finalize_batch() + mock_finalize.assert_not_called() + finally: + current_flow_id.reset(token) + + def test_flow_owned_batch_skips_finalize_without_flow_context(self) -> None: + from crewai.events.listeners.tracing.trace_listener import ( + TraceCollectionListener, + ) + from crewai.events.listeners.tracing.trace_batch_manager import TraceBatch + + listener = TraceCollectionListener() + listener.batch_manager.batch_owner_type = "flow" + listener.batch_manager.current_batch = TraceBatch( + execution_metadata={"execution_type": "flow", "flow_name": "Demo"}, + ) + + with patch.object(listener.batch_manager, "finalize_batch") as mock_finalize: + if listener._nested_in_flow_execution(): + pass + elif listener.batch_manager.batch_owner_type == "crew": + listener.batch_manager.finalize_batch() + mock_finalize.assert_not_called() diff --git a/lib/crewai/tests/tracing/test_tracing.py b/lib/crewai/tests/tracing/test_tracing.py index 723904a8f..8c8f18c13 100644 --- a/lib/crewai/tests/tracing/test_tracing.py +++ b/lib/crewai/tests/tracing/test_tracing.py @@ -884,6 +884,49 @@ class TestTraceListenerSetup: "test_batch_id_12345", "Internal Server Error" ) + def test_finalize_batch_clears_buffer_after_successful_send(self) -> None: + """Successful send must not restore a stale event buffer (duplicate events).""" + from crewai.events.listeners.tracing.types import TraceEvent + + with patch( + "crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context", + return_value=True, + ): + batch_manager = TraceBatchManager() + batch_manager.current_batch = batch_manager.initialize_batch( + user_context={"privacy_level": "standard"}, + execution_metadata={ + "execution_type": "flow", + "flow_name": "TestFlow", + }, + ) + batch_manager.trace_batch_id = "batch-clear-test" + batch_manager.backend_initialized = True + batch_manager.event_buffer = [ + TraceEvent( + type="llm_call_started", + timestamp="2026-01-01T00:00:00", + event_id="evt-1", + 
emission_sequence=1, + ) + ] + + with ( + patch.object( + batch_manager.plus_api, + "send_trace_events", + return_value=MagicMock(status_code=200), + ), + patch.object( + batch_manager.plus_api, + "finalize_trace_batch", + return_value=MagicMock(status_code=200, json=MagicMock(return_value={})), + ), + ): + batch_manager.finalize_batch() + + assert batch_manager.event_buffer == [] + def test_ephemeral_batch_includes_anon_id(self): """Test that ephemeral batch initialization sends anon_id from get_user_id()""" fake_user_id = "abc123def456"