mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 13:48:09 +00:00
feat: add conversational flows documentation and chat session support
- Introduced a new guide for building multi-turn chat applications using , detailing session management and message handling. - Added class to facilitate chat interactions, including streaming support and event handling. - Implemented for class-level defaults and improved input normalization for conversational turns. - Enhanced event listeners to manage flow events and tracing more effectively, including support for nested crew executions. - Added tests for conversational flow helpers and kickoff parameters to ensure functionality and reliability.
This commit is contained in:
150
docs/en/guides/flows/conversational-flows.mdx
Normal file
150
docs/en/guides/flows/conversational-flows.mdx
Normal file
@@ -0,0 +1,150 @@
|
||||
---
|
||||
title: Conversational Flows
|
||||
description: Build multi-turn chat apps with kickoff per turn, message history, and intent routing.
|
||||
icon: comments
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
## One entry point: `kickoff`
|
||||
|
||||
Chat apps should use **`flow.kickoff(user_message=..., session_id=...)`** for each user line. Do not add a separate `chat()` method — session identity is `state.id` (same as `inputs["id"]` with `@persist`).
|
||||
|
||||
| API | Use for |
|
||||
|-----|---------|
|
||||
| `kickoff(user_message=..., session_id=...)` | Each new user message (API, WebSocket turn, CLI loop) |
|
||||
| `ask()` | Blocking prompt **inside** one step (wizard, clarification) |
|
||||
| `@human_feedback` | Approve/reject **a step output** before continuing — not the next chat line |
|
||||
|
||||
## Quick start
|
||||
|
||||
```python
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.flow import ChatState, Flow, listen, persist, router, start
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
|
||||
|
||||
@persist(SQLiteFlowPersistence())
|
||||
class SupportFlow(Flow[ChatState]):
|
||||
@start()
|
||||
def bootstrap(self):
|
||||
if self.state.session_ready:
|
||||
return "ready"
|
||||
# load permissions once per session
|
||||
self.state.session_ready = True
|
||||
return "ready"
|
||||
|
||||
@router(bootstrap)
|
||||
def route(self):
|
||||
if self.state.last_intent:
|
||||
return self.state.last_intent
|
||||
return self.classify_intent(
|
||||
self.state.last_user_message,
|
||||
outcomes=["order", "help", "goodbye"],
|
||||
llm="gpt-4o-mini",
|
||||
context=self.conversation_messages,
|
||||
)
|
||||
|
||||
@listen("order")
|
||||
def handle_order(self):
|
||||
reply = "Your order is on the way."
|
||||
self.append_message("assistant", reply)
|
||||
return reply
|
||||
|
||||
@listen("help")
|
||||
def handle_help(self):
|
||||
reply = "How can I help?"
|
||||
self.append_message("assistant", reply)
|
||||
return reply
|
||||
|
||||
|
||||
session_id = str(uuid4())
|
||||
flow = SupportFlow()
|
||||
|
||||
# Turn 1
|
||||
flow.kickoff(user_message="Where is my order?", session_id=session_id)
|
||||
|
||||
# Turn 2 — same session, flow finished after turn 1 is normal
|
||||
flow.kickoff(user_message="What about returns?", session_id=session_id)
|
||||
```
|
||||
|
||||
## When the flow finishes but the user keeps chatting
|
||||
|
||||
`FlowFinished` means **this graph run** completed. The conversation continues with **another `kickoff`** and the same `session_id`. `@persist` restores `messages`, permissions flags, and context.
|
||||
|
||||
For multi-turn chat, prefer **`@persist` on a single terminal step** (for example `finalize`) rather than on the whole `Flow` class. Class-level persist saves after every method; `load_state` uses the latest row, which is often a mid-run snapshot (for example right after `bootstrap`) and can omit handler updates from the same turn.
|
||||
|
||||
Do **not** use `@human_feedback` for follow-up questions unless a human must approve a specific payload before it is shown.
|
||||
|
||||
## Tracing across turns
|
||||
|
||||
By default, `ConversationalConfig(defer_trace_finalization=True)` keeps **one trace batch** for the whole chat session instead of finalizing after every `kickoff()`. Call `flow.finalize_session_traces()` when the user leaves (or use `ChatSession.close()`, which does this automatically).
|
||||
|
||||
```python
|
||||
flow.kickoff(user_message="Hello", session_id=session_id)
|
||||
flow.kickoff(user_message="Track my order", session_id=session_id)
|
||||
flow.finalize_session_traces() # one link for the full conversation
|
||||
```
|
||||
|
||||
Crews kicked off **inside** a flow method (for example a research step) append their events to the **parent flow batch**. `CrewKickoffCompleted` does not finalize that batch (even when crew worker threads lose `current_flow_id` context); finalization happens in `finalize_session_traces()`.
|
||||
|
||||
Per-turn `flow_finished` is also deferred: only `flow_started` opens the session scope on the first turn, and `flow_finished` runs once in `finalize_session_traces()` so the event bus does not warn about a missing `flow_started`.
|
||||
|
||||
## Kickoff parameters
|
||||
|
||||
| Parameter | Purpose |
|
||||
|-----------|---------|
|
||||
| `user_message` | This turn's user text (also `inputs["user_message"]`) |
|
||||
| `session_id` | Conversation UUID → `inputs["id"]` / `state.id` |
|
||||
| `intents` | Optional labels for pre-kickoff `classify_intent` |
|
||||
| `intent_llm` | LLM for classification (required with `intents`) |
|
||||
| `interactive=True` | CLI demo loop via `ask()` (not for production APIs) |
|
||||
|
||||
Class-level defaults:
|
||||
|
||||
```python
|
||||
from crewai.flow import ConversationalConfig, Flow
|
||||
|
||||
|
||||
class MyFlow(Flow[ChatState]):
|
||||
conversational_config = ConversationalConfig(
|
||||
default_intents=["order", "help"],
|
||||
intent_llm="gpt-4o-mini",
|
||||
)
|
||||
```
|
||||
|
||||
## Helpers on `Flow`
|
||||
|
||||
- `append_message(role, content)` — update `state.messages`
|
||||
- `conversation_messages` — history for LLM calls
|
||||
- `classify_intent(text, outcomes, llm=..., context=...)` — route labels (same logic as `@human_feedback` collapse)
|
||||
- `receive_user_message(text, ...)` — append user line + optional classify
|
||||
- `input_history` — audit trail from `ask()`
|
||||
|
||||
Recommended state shape: `ChatState` (`id`, `messages`, `last_user_message`, `last_intent`, `session_ready`).
|
||||
|
||||
## ChatSession (WebSocket / SSE bridge)
|
||||
|
||||
For UIs, use `ChatSession` to wrap kickoff and map events to `ChatMessage`:
|
||||
|
||||
```python
|
||||
from crewai.flow import ChatSession
|
||||
|
||||
|
||||
def on_event(msg):
|
||||
print(msg.type, msg.payload)
|
||||
|
||||
|
||||
session = ChatSession(flow, session_id="channel-1", on_event=on_event)
|
||||
turn = session.handle_turn("Hello")
|
||||
print(turn.output, turn.intent)
|
||||
session.close()
|
||||
```
|
||||
|
||||
`QueueInputProvider` supports blocking `ask()` fed by a WebSocket handler (`provider.push(session_id, text)`).
|
||||
|
||||
## Streaming
|
||||
|
||||
Enable `stream = True` on the Flow class and use `kickoff` / `ChatSession.handle_turn(..., stream=True)` to emit `assistant_delta` events through `ConversationEventBridge`.
|
||||
@@ -320,20 +320,28 @@ class EventListener(BaseEventListener):
|
||||
self._telemetry.flow_execution_span(
|
||||
event.flow_name, list(source._methods.keys())
|
||||
)
|
||||
self.formatter.handle_flow_created(event.flow_name, str(source.flow_id))
|
||||
self.formatter.handle_flow_started(event.flow_name, str(source.flow_id))
|
||||
if not getattr(source, "suppress_flow_events", False):
|
||||
self.formatter.handle_flow_created(
|
||||
event.flow_name, str(source.flow_id)
|
||||
)
|
||||
self.formatter.handle_flow_started(
|
||||
event.flow_name, str(source.flow_id)
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None:
|
||||
self.formatter.handle_flow_status(
|
||||
event.flow_name,
|
||||
source.flow_id,
|
||||
)
|
||||
if not getattr(source, "suppress_flow_events", False):
|
||||
self.formatter.handle_flow_status(
|
||||
event.flow_name,
|
||||
source.flow_id,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def on_method_execution_started(
|
||||
_: Any, event: MethodExecutionStartedEvent
|
||||
source: Any, event: MethodExecutionStartedEvent
|
||||
) -> None:
|
||||
if getattr(source, "suppress_flow_events", False):
|
||||
return
|
||||
self.formatter.handle_method_status(
|
||||
event.method_name,
|
||||
"running",
|
||||
@@ -341,8 +349,10 @@ class EventListener(BaseEventListener):
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def on_method_execution_finished(
|
||||
_: Any, event: MethodExecutionFinishedEvent
|
||||
source: Any, event: MethodExecutionFinishedEvent
|
||||
) -> None:
|
||||
if getattr(source, "suppress_flow_events", False):
|
||||
return
|
||||
self.formatter.handle_method_status(
|
||||
event.method_name,
|
||||
"completed",
|
||||
|
||||
@@ -222,6 +222,8 @@ To enable tracing later, do any one of these:
|
||||
return
|
||||
self.batch_manager.batch_owner_type = None
|
||||
self.batch_manager.batch_owner_id = None
|
||||
self.batch_manager.defer_session_finalization = False
|
||||
self.batch_manager._batch_finalized = False
|
||||
self.batch_manager.current_batch = None
|
||||
self.batch_manager.event_buffer.clear()
|
||||
self.batch_manager.trace_batch_id = None
|
||||
|
||||
@@ -70,6 +70,8 @@ class TraceBatchManager:
|
||||
self.execution_start_times: dict[str, datetime] = {}
|
||||
self.batch_owner_type: str | None = None
|
||||
self.batch_owner_id: str | None = None
|
||||
self.defer_session_finalization: bool = False
|
||||
self._batch_finalized: bool = False
|
||||
self.backend_initialized: bool = False
|
||||
self.ephemeral_trace_url: str | None = None
|
||||
try:
|
||||
@@ -101,6 +103,7 @@ class TraceBatchManager:
|
||||
user_context=user_context, execution_metadata=execution_metadata
|
||||
)
|
||||
self.is_current_batch_ephemeral = use_ephemeral
|
||||
self._batch_finalized = False
|
||||
|
||||
self.record_start_time("execution")
|
||||
|
||||
@@ -312,6 +315,9 @@ class TraceBatchManager:
|
||||
def finalize_batch(self) -> TraceBatch | None:
|
||||
"""Finalize batch and return it for sending"""
|
||||
|
||||
if self._batch_finalized:
|
||||
return None
|
||||
|
||||
if not self.current_batch or not is_tracing_enabled_in_context():
|
||||
return None
|
||||
|
||||
@@ -340,10 +346,8 @@ class TraceBatchManager:
|
||||
self.current_batch.events = sorted_events
|
||||
events_sent_count = len(sorted_events)
|
||||
if sorted_events:
|
||||
original_buffer = self.event_buffer
|
||||
self.event_buffer = sorted_events
|
||||
events_sent_to_backend_status = self._send_events_to_backend()
|
||||
self.event_buffer = original_buffer
|
||||
if events_sent_to_backend_status == 500 and self.trace_batch_id:
|
||||
self._mark_batch_as_failed(
|
||||
self.trace_batch_id, "Error sending events to backend"
|
||||
@@ -360,6 +364,7 @@ class TraceBatchManager:
|
||||
self.event_buffer.clear()
|
||||
self.trace_batch_id = None
|
||||
self.is_current_batch_ephemeral = False
|
||||
self._batch_finalized = True
|
||||
|
||||
self._cleanup_batch_data()
|
||||
|
||||
@@ -371,7 +376,7 @@ class TraceBatchManager:
|
||||
Args:
|
||||
events_count: Number of events that were successfully sent
|
||||
"""
|
||||
if not self.plus_api or not self.trace_batch_id:
|
||||
if self._batch_finalized or not self.plus_api or not self.trace_batch_id:
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -390,6 +395,7 @@ class TraceBatchManager:
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
self._batch_finalized = True
|
||||
access_code = response.json().get("access_code", None)
|
||||
console = Console()
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Trace collection listener for orchestrating trace collection."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import os
|
||||
from typing import Any, ClassVar
|
||||
import uuid
|
||||
@@ -264,18 +265,18 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
@event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
|
||||
if self.batch_manager.batch_owner_type != "flow":
|
||||
# Always call _initialize_crew_batch to claim ownership.
|
||||
# If batch was already initialized by a concurrent action event
|
||||
# (e.g. LLM/tool before crew_kickoff_started), initialize_batch()
|
||||
# returns early but batch_owner_type is still correctly set to "crew".
|
||||
# Skip only when a parent flow already owns the batch.
|
||||
# Nested crew inside Flow.kickoff: never claim an existing flow session batch.
|
||||
if not self._nested_in_flow_execution() and (
|
||||
not self.batch_manager.is_batch_initialized()
|
||||
):
|
||||
self._initialize_crew_batch(source, event)
|
||||
self._handle_trace_event("crew_kickoff_started", source, event)
|
||||
|
||||
@event_bus.on(CrewKickoffCompletedEvent)
|
||||
def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None:
|
||||
self._handle_trace_event("crew_kickoff_completed", source, event)
|
||||
if self._nested_in_flow_execution():
|
||||
return
|
||||
if self.batch_manager.batch_owner_type == "crew":
|
||||
if self.first_time_handler.is_first_time:
|
||||
self.first_time_handler.mark_events_collected()
|
||||
@@ -286,10 +287,12 @@ class TraceCollectionListener(BaseEventListener):
|
||||
@event_bus.on(CrewKickoffFailedEvent)
|
||||
def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None:
|
||||
self._handle_trace_event("crew_kickoff_failed", source, event)
|
||||
if self._nested_in_flow_execution():
|
||||
return
|
||||
if self.first_time_handler.is_first_time:
|
||||
self.first_time_handler.mark_events_collected()
|
||||
self.first_time_handler.handle_execution_completion()
|
||||
else:
|
||||
elif self.batch_manager.batch_owner_type == "crew":
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(TaskStartedEvent)
|
||||
@@ -708,8 +711,34 @@ class TraceCollectionListener(BaseEventListener):
|
||||
@on_signal
|
||||
def handle_signal(source: Any, event: SignalEvent) -> None:
|
||||
"""Flush trace batch on system signals to prevent data loss."""
|
||||
if self.batch_manager.is_batch_initialized():
|
||||
self.batch_manager.finalize_batch()
|
||||
if not self.batch_manager.is_batch_initialized():
|
||||
return
|
||||
# Multi-turn flows defer batch finalization to finalize_session_traces().
|
||||
if self.batch_manager.defer_session_finalization:
|
||||
return
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@staticmethod
|
||||
def _is_inside_active_flow_context() -> bool:
|
||||
"""True when ``kickoff_async`` has set ``current_flow_id`` (nested crew)."""
|
||||
from crewai.flow.flow_context import current_flow_id
|
||||
|
||||
return current_flow_id.get() is not None
|
||||
|
||||
def _flow_owns_trace_batch(self) -> bool:
|
||||
"""True when an in-flight conversational flow already owns the trace batch."""
|
||||
if self.batch_manager.batch_owner_type == "flow":
|
||||
return True
|
||||
batch = self.batch_manager.current_batch
|
||||
if batch is not None:
|
||||
return batch.execution_metadata.get("execution_type") == "flow"
|
||||
return False
|
||||
|
||||
def _nested_in_flow_execution(self) -> bool:
|
||||
"""True when a crew runs inside a flow session (context or batch ownership)."""
|
||||
return (
|
||||
self._is_inside_active_flow_context() or self._flow_owns_trace_batch()
|
||||
)
|
||||
|
||||
def _initialize_crew_batch(self, source: Any, event: BaseEvent) -> None:
|
||||
"""Initialize trace batch.
|
||||
@@ -730,6 +759,31 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
def _try_initialize_flow_batch_from_context(self, event: Any) -> bool:
|
||||
"""Claim a flow trace batch when an action event fires inside kickoff.
|
||||
|
||||
Flows with ``suppress_flow_events=True`` skip ``FlowStartedEvent``, so
|
||||
LLM/tool events must not fall back to implicit crew batches.
|
||||
"""
|
||||
from crewai.flow.flow_context import current_flow_id, current_flow_name
|
||||
|
||||
flow_id = current_flow_id.get()
|
||||
if flow_id is None:
|
||||
return False
|
||||
|
||||
started_at = getattr(event, "timestamp", None) or datetime.now(timezone.utc)
|
||||
user_context = self._get_user_context()
|
||||
execution_metadata = {
|
||||
"flow_name": current_flow_name.get() or "Unknown Flow",
|
||||
"execution_start": started_at,
|
||||
"crewai_version": get_crewai_version(),
|
||||
"execution_type": "flow",
|
||||
}
|
||||
self.batch_manager.batch_owner_type = "flow"
|
||||
self.batch_manager.batch_owner_id = flow_id
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
return True
|
||||
|
||||
def _initialize_flow_batch(self, source: Any, event: BaseEvent) -> None:
|
||||
"""Initialize trace batch for Flow execution.
|
||||
|
||||
@@ -794,12 +848,19 @@ class TraceCollectionListener(BaseEventListener):
|
||||
event: Event object.
|
||||
"""
|
||||
if not self.batch_manager.is_batch_initialized():
|
||||
user_context = self._get_user_context()
|
||||
execution_metadata = {
|
||||
"crew_name": getattr(source, "name", "Unknown Crew"),
|
||||
"crewai_version": get_crewai_version(),
|
||||
}
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
if self._try_initialize_flow_batch_from_context(event):
|
||||
pass
|
||||
elif not self._nested_in_flow_execution():
|
||||
user_context = self._get_user_context()
|
||||
execution_metadata = {
|
||||
"crew_name": getattr(source, "name", "Unknown Crew"),
|
||||
"crewai_version": get_crewai_version(),
|
||||
}
|
||||
self.batch_manager.batch_owner_type = "crew"
|
||||
self.batch_manager.batch_owner_id = getattr(
|
||||
source, "id", str(uuid.uuid4())
|
||||
)
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
self.batch_manager.begin_event_processing()
|
||||
try:
|
||||
|
||||
@@ -4,12 +4,24 @@ from crewai.flow.async_feedback import (
|
||||
HumanFeedbackProvider,
|
||||
PendingFeedbackContext,
|
||||
)
|
||||
from crewai.flow.chat import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
ConversationEventBridge,
|
||||
TurnResult,
|
||||
)
|
||||
from crewai.flow.conversation import (
|
||||
ChatState,
|
||||
ConversationalConfig,
|
||||
ConversationalInputs,
|
||||
)
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.flow.flow_config import flow_config
|
||||
from crewai.flow.flow_serializer import flow_structure
|
||||
from crewai.flow.human_feedback import HumanFeedbackResult, human_feedback
|
||||
from crewai.flow.input_provider import InputProvider, InputResponse
|
||||
from crewai.flow.persistence import persist
|
||||
from crewai.flow.providers import QueueInputProvider
|
||||
from crewai.flow.visualization import (
|
||||
FlowStructure,
|
||||
build_flow_structure,
|
||||
@@ -18,7 +30,13 @@ from crewai.flow.visualization import (
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChatMessage",
|
||||
"ChatSession",
|
||||
"ChatState",
|
||||
"ConsoleProvider",
|
||||
"ConversationEventBridge",
|
||||
"ConversationalConfig",
|
||||
"ConversationalInputs",
|
||||
"Flow",
|
||||
"FlowStructure",
|
||||
"HumanFeedbackPending",
|
||||
@@ -27,6 +45,8 @@ __all__ = [
|
||||
"InputProvider",
|
||||
"InputResponse",
|
||||
"PendingFeedbackContext",
|
||||
"QueueInputProvider",
|
||||
"TurnResult",
|
||||
"and_",
|
||||
"build_flow_structure",
|
||||
"flow_config",
|
||||
|
||||
307
lib/crewai/src/crewai/flow/chat.py
Normal file
307
lib/crewai/src/crewai/flow/chat.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""Transport-agnostic chat session bridge for conversational flows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.flow.conversation import get_conversation_messages, get_conversational_config
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.types.streaming import FlowStreamingOutput
|
||||
|
||||
|
||||
ChatMessageType = Literal[
|
||||
"user_message",
|
||||
"assistant_delta",
|
||||
"assistant_done",
|
||||
"turn_started",
|
||||
"turn_finished",
|
||||
"error",
|
||||
"tool_started",
|
||||
"tool_finished",
|
||||
]
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""Versioned wire format for chat UIs (WebSocket, SSE, webhooks)."""
|
||||
|
||||
version: str = "1"
|
||||
type: ChatMessageType
|
||||
session_id: str
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
seq: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TurnResult:
|
||||
"""Outcome of a single conversational turn."""
|
||||
|
||||
session_id: str
|
||||
output: Any
|
||||
intent: str | None = None
|
||||
messages: list[LLMMessage] = field(default_factory=list)
|
||||
streaming: FlowStreamingOutput | None = None
|
||||
|
||||
|
||||
class ChatSession:
|
||||
"""Wraps ``Flow.kickoff`` for one chat session (``state.id``)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
flow: Flow[Any],
|
||||
session_id: str | None = None,
|
||||
*,
|
||||
intents: Sequence[str] | None = None,
|
||||
intent_llm: str | BaseLLM | None = None,
|
||||
on_event: Callable[[ChatMessage], None] | None = None,
|
||||
) -> None:
|
||||
self._flow = flow
|
||||
self._session_id = session_id or str(uuid4())
|
||||
self._intents = list(intents) if intents else None
|
||||
self._intent_llm = intent_llm
|
||||
self._on_event = on_event
|
||||
self._seq = 0
|
||||
self._bridge: ConversationEventBridge | None = None
|
||||
config = get_conversational_config(flow)
|
||||
if config is not None and config.defer_trace_finalization:
|
||||
flow.defer_trace_finalization = True
|
||||
if on_event is not None:
|
||||
self._bridge = ConversationEventBridge(
|
||||
session_id=self._session_id,
|
||||
handler=on_event,
|
||||
)
|
||||
self._bridge.register()
|
||||
|
||||
@property
|
||||
def session_id(self) -> str:
|
||||
return self._session_id
|
||||
|
||||
def handle_turn(
|
||||
self,
|
||||
user_message: str,
|
||||
*,
|
||||
stream: bool | None = None,
|
||||
) -> TurnResult:
|
||||
"""Run one conversational turn and return output plus message history."""
|
||||
self._emit("turn_started", {"user_message": user_message})
|
||||
use_stream = stream if stream is not None else bool(self._flow.stream)
|
||||
|
||||
try:
|
||||
result = self._flow.kickoff(
|
||||
user_message=user_message,
|
||||
session_id=self._session_id,
|
||||
intents=self._intents,
|
||||
intent_llm=self._intent_llm,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._emit("error", {"message": str(exc)})
|
||||
raise
|
||||
|
||||
streaming = None
|
||||
output: Any = result
|
||||
if use_stream and hasattr(result, "__iter__"):
|
||||
from crewai.types.streaming import FlowStreamingOutput
|
||||
|
||||
if isinstance(result, FlowStreamingOutput):
|
||||
streaming = result
|
||||
for chunk in result:
|
||||
text = getattr(chunk, "content", None) or str(chunk)
|
||||
self._emit("assistant_delta", {"chunk": text})
|
||||
output = result.result
|
||||
else:
|
||||
for chunk in result:
|
||||
text = getattr(chunk, "content", None) or str(chunk)
|
||||
self._emit("assistant_delta", {"chunk": text})
|
||||
|
||||
intent = None
|
||||
state = self._flow.state
|
||||
if hasattr(state, "last_intent"):
|
||||
intent = getattr(state, "last_intent", None)
|
||||
elif isinstance(state, dict):
|
||||
intent = state.get("last_intent")
|
||||
|
||||
messages = get_conversation_messages(self._flow)
|
||||
self._emit(
|
||||
"assistant_done",
|
||||
{"output": output, "intent": intent},
|
||||
)
|
||||
self._emit("turn_finished", {"output": output})
|
||||
|
||||
return TurnResult(
|
||||
session_id=self._session_id,
|
||||
output=output,
|
||||
intent=intent,
|
||||
messages=messages,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
def iter_turn_stream(
|
||||
self,
|
||||
user_message: str,
|
||||
) -> Iterator[ChatMessage]:
|
||||
"""Run a streaming turn and yield ``ChatMessage`` events."""
|
||||
collected: list[ChatMessage] = []
|
||||
|
||||
def _collect(msg: ChatMessage) -> None:
|
||||
collected.append(msg)
|
||||
|
||||
prior = self._on_event
|
||||
self._on_event = _collect
|
||||
if self._bridge is None:
|
||||
self._bridge = ConversationEventBridge(
|
||||
session_id=self._session_id,
|
||||
handler=_collect,
|
||||
)
|
||||
self._bridge.register()
|
||||
try:
|
||||
self.handle_turn(user_message, stream=True)
|
||||
finally:
|
||||
self._on_event = prior
|
||||
yield from collected
|
||||
|
||||
def close(self) -> None:
|
||||
if self._bridge is not None:
|
||||
self._bridge.unregister()
|
||||
self._bridge = None
|
||||
if self._flow._should_defer_trace_finalization():
|
||||
self._flow.finalize_session_traces()
|
||||
|
||||
def _emit(self, msg_type: ChatMessageType, payload: dict[str, Any]) -> None:
|
||||
if self._on_event is None:
|
||||
return
|
||||
self._seq += 1
|
||||
self._on_event(
|
||||
ChatMessage(
|
||||
type=msg_type,
|
||||
session_id=self._session_id,
|
||||
payload=payload,
|
||||
seq=self._seq,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ConversationEventBridge:
|
||||
"""Maps CrewAI bus events to ``ChatMessage`` for a session."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
handler: Callable[[ChatMessage], None],
|
||||
) -> None:
|
||||
self._session_id = session_id
|
||||
self._handler = handler
|
||||
self._seq = 0
|
||||
self._handlers: list[Any] = []
|
||||
|
||||
def register(self) -> None:
|
||||
from crewai.events import crewai_event_bus
|
||||
from crewai.events.types.flow_events import FlowFinishedEvent
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMStreamChunkEvent,
|
||||
LLMThinkingChunkEvent,
|
||||
)
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
|
||||
bus = crewai_event_bus
|
||||
|
||||
@bus.on(LLMStreamChunkEvent)
|
||||
def _on_chunk(_source: Any, event: LLMStreamChunkEvent) -> None:
|
||||
if not self._matches(event):
|
||||
return
|
||||
chunk = getattr(event, "chunk", None)
|
||||
if chunk:
|
||||
self._dispatch(
|
||||
"assistant_delta",
|
||||
{"chunk": chunk, "agent_role": getattr(event, "agent_role", "")},
|
||||
)
|
||||
|
||||
@bus.on(LLMThinkingChunkEvent)
|
||||
def _on_thinking(_source: Any, event: LLMThinkingChunkEvent) -> None:
|
||||
if not self._matches(event):
|
||||
return
|
||||
chunk = getattr(event, "chunk", None)
|
||||
if chunk:
|
||||
self._dispatch(
|
||||
"assistant_delta",
|
||||
{
|
||||
"chunk": chunk,
|
||||
"thinking": True,
|
||||
"agent_role": getattr(event, "agent_role", ""),
|
||||
},
|
||||
)
|
||||
|
||||
@bus.on(ToolUsageStartedEvent)
|
||||
def _on_tool_start(_source: Any, event: ToolUsageStartedEvent) -> None:
|
||||
if not self._matches(event):
|
||||
return
|
||||
self._dispatch(
|
||||
"tool_started",
|
||||
{"tool_name": getattr(event, "tool_name", "")},
|
||||
)
|
||||
|
||||
@bus.on(ToolUsageFinishedEvent)
|
||||
def _on_tool_end(_source: Any, event: ToolUsageFinishedEvent) -> None:
|
||||
if not self._matches(event):
|
||||
return
|
||||
self._dispatch(
|
||||
"tool_finished",
|
||||
{"tool_name": getattr(event, "tool_name", "")},
|
||||
)
|
||||
|
||||
@bus.on(FlowFinishedEvent)
|
||||
def _on_finished(_source: Any, event: FlowFinishedEvent) -> None:
|
||||
if not self._matches(event):
|
||||
return
|
||||
self._dispatch("turn_finished", {"result": getattr(event, "result", None)})
|
||||
|
||||
self._handlers = [
|
||||
_on_chunk,
|
||||
_on_thinking,
|
||||
_on_tool_start,
|
||||
_on_tool_end,
|
||||
_on_finished,
|
||||
]
|
||||
|
||||
def unregister(self) -> None:
|
||||
self._handlers.clear()
|
||||
|
||||
def _matches(self, event: Any) -> bool:
|
||||
meta = getattr(event, "fingerprint_metadata", None) or {}
|
||||
if isinstance(meta, dict) and meta.get("conversation_id") == self._session_id:
|
||||
return True
|
||||
fp = getattr(event, "source_fingerprint", None)
|
||||
return fp == self._session_id
|
||||
|
||||
def _dispatch(self, msg_type: ChatMessageType, payload: dict[str, Any]) -> None:
|
||||
self._seq += 1
|
||||
self._handler(
|
||||
ChatMessage(
|
||||
type=msg_type,
|
||||
session_id=self._session_id,
|
||||
payload=payload,
|
||||
seq=self._seq,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def stamp_conversation_fingerprint(event: Any, session_id: str) -> None:
|
||||
"""Stamp ``conversation_id`` on an event before dispatch to external systems."""
|
||||
if not getattr(event, "source_fingerprint", None):
|
||||
event.source_fingerprint = session_id
|
||||
meta = getattr(event, "fingerprint_metadata", None)
|
||||
if meta is None:
|
||||
event.fingerprint_metadata = {"conversation_id": session_id}
|
||||
elif isinstance(meta, dict):
|
||||
meta.setdefault("conversation_id", session_id)
|
||||
243
lib/crewai/src/crewai/flow/conversation.py
Normal file
243
lib/crewai/src/crewai/flow/conversation.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""Conversational turn helpers for CrewAI Flows.
|
||||
|
||||
Provides message history utilities, kickoff input normalization, and optional
|
||||
class-level defaults via ``ConversationalConfig``. Session identity is ``state.id``
|
||||
(``inputs["id"]`` / ``kickoff(session_id=...)``), not a separate Flow field.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
|
||||
TurnMode = Literal["auto", "follow_up", "initial"]
|
||||
|
||||
_EXIT_COMMANDS_DEFAULT: tuple[str, ...] = ("exit", "quit")
|
||||
|
||||
|
||||
class ConversationalInputs(TypedDict, total=False):
|
||||
"""Conventional ``kickoff(inputs=...)`` keys for chat turns."""
|
||||
|
||||
id: str
|
||||
user_message: str | dict[str, Any]
|
||||
last_intent: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationalConfig:
|
||||
"""Optional class-level defaults for conversational flows.
|
||||
|
||||
Override per kickoff via ``user_message``, ``session_id``, ``intents``, etc.
|
||||
"""
|
||||
|
||||
default_intents: Sequence[str] | None = None
|
||||
intent_llm: str | None = None
|
||||
interactive_prompt: str = "You: "
|
||||
interactive_timeout: float | None = None
|
||||
exit_commands: Sequence[str] = field(default_factory=lambda: _EXIT_COMMANDS_DEFAULT)
|
||||
defer_trace_finalization: bool = True
|
||||
|
||||
|
||||
class ChatState(BaseModel):
|
||||
"""Recommended persisted state shape for multi-turn flows."""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
messages: list[LLMMessage] = Field(default_factory=list)
|
||||
last_user_message: str | None = None
|
||||
last_intent: str | None = None
|
||||
session_ready: bool = False
|
||||
|
||||
|
||||
def _coerce_user_message_text(user_message: str | dict[str, Any] | Any) -> str:
|
||||
if isinstance(user_message, str):
|
||||
return user_message
|
||||
if isinstance(user_message, dict):
|
||||
content = user_message.get("content")
|
||||
if content is not None:
|
||||
return str(content)
|
||||
return str(user_message)
|
||||
|
||||
|
||||
def normalize_kickoff_inputs(
|
||||
inputs: dict[str, Any] | None,
|
||||
*,
|
||||
user_message: str | dict[str, Any] | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Merge conversational kickoff kwargs into the inputs dict."""
|
||||
merged: dict[str, Any] = dict(inputs or {})
|
||||
|
||||
if session_id is not None:
|
||||
merged["id"] = session_id
|
||||
|
||||
if user_message is not None:
|
||||
merged["user_message"] = user_message
|
||||
elif "user_message" in merged and isinstance(merged["user_message"], str):
|
||||
pass
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def get_conversation_messages(flow: Flow[Any]) -> list[LLMMessage]:
|
||||
"""Read message history from flow state or the internal fallback buffer."""
|
||||
buffer: list[LLMMessage] = getattr(flow, "_conversation_messages", [])
|
||||
state = getattr(flow, "_state", None)
|
||||
if state is None:
|
||||
return list(buffer)
|
||||
|
||||
if isinstance(state, dict):
|
||||
messages = state.get("messages")
|
||||
if isinstance(messages, list):
|
||||
return cast(list[LLMMessage], messages)
|
||||
elif isinstance(state, BaseModel) and hasattr(state, "messages"):
|
||||
messages = getattr(state, "messages", None)
|
||||
if isinstance(messages, list):
|
||||
return cast(list[LLMMessage], messages)
|
||||
|
||||
return list(buffer)
|
||||
|
||||
|
||||
def append_message(
|
||||
flow: Flow[Any],
|
||||
role: Literal["user", "assistant", "system", "tool"],
|
||||
content: str,
|
||||
**extra: Any,
|
||||
) -> None:
|
||||
"""Append a message to ``state.messages`` or the flow fallback buffer."""
|
||||
message: LLMMessage = {"role": role, "content": content}
|
||||
for key, value in extra.items():
|
||||
if key in ("tool_call_id", "name", "tool_calls", "files"):
|
||||
message[key] = value # type: ignore[literal-required]
|
||||
|
||||
state = getattr(flow, "_state", None)
|
||||
if state is not None:
|
||||
if isinstance(state, dict):
|
||||
messages = state.get("messages")
|
||||
if isinstance(messages, list):
|
||||
messages.append(message)
|
||||
return
|
||||
elif isinstance(state, BaseModel) and hasattr(state, "messages"):
|
||||
messages = getattr(state, "messages", None)
|
||||
if messages is None:
|
||||
object.__setattr__(state, "messages", [])
|
||||
messages = getattr(state, "messages")
|
||||
if isinstance(messages, list):
|
||||
messages.append(message)
|
||||
return
|
||||
|
||||
if not hasattr(flow, "_conversation_messages"):
|
||||
object.__setattr__(flow, "_conversation_messages", [])
|
||||
flow._conversation_messages.append(message)
|
||||
|
||||
|
||||
def set_state_field(flow: Flow[Any], name: str, value: Any) -> None:
|
||||
"""Set a field on structured or dict flow state when present."""
|
||||
state = getattr(flow, "_state", None)
|
||||
if state is None:
|
||||
return
|
||||
if isinstance(state, dict):
|
||||
state[name] = value
|
||||
elif isinstance(state, BaseModel) and hasattr(state, name):
|
||||
object.__setattr__(state, name, value)
|
||||
|
||||
|
||||
def receive_user_message(
|
||||
flow: Flow[Any],
|
||||
text: str,
|
||||
*,
|
||||
outcomes: Sequence[str] | None = None,
|
||||
llm: str | BaseLLM | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Record a user turn: append message and optionally classify intent."""
|
||||
append_message(flow, "user", text)
|
||||
set_state_field(flow, "last_user_message", text)
|
||||
|
||||
if outcomes and llm is not None:
|
||||
intent = flow.classify_intent(
|
||||
text,
|
||||
outcomes,
|
||||
llm=llm,
|
||||
context=get_conversation_messages(flow),
|
||||
)
|
||||
set_state_field(flow, "last_intent", intent)
|
||||
return intent
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def prepare_conversational_turn(
|
||||
flow: Flow[Any],
|
||||
*,
|
||||
user_message: str | dict[str, Any] | None = None,
|
||||
intents: Sequence[str] | None = None,
|
||||
intent_llm: str | BaseLLM | None = None,
|
||||
config: ConversationalConfig | None = None,
|
||||
) -> None:
|
||||
"""Hydrate conversation state after inputs are merged into flow state."""
|
||||
if user_message is None:
|
||||
state = getattr(flow, "_state", None)
|
||||
if isinstance(state, dict) and "user_message" in state:
|
||||
user_message = state["user_message"]
|
||||
elif isinstance(state, BaseModel) and hasattr(state, "user_message"):
|
||||
user_message = getattr(state, "user_message", None)
|
||||
|
||||
if user_message is None:
|
||||
return
|
||||
|
||||
text = _coerce_user_message_text(user_message)
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
# Fresh classification each turn (do not reuse prior turn's route label).
|
||||
set_state_field(flow, "last_intent", None)
|
||||
|
||||
resolved_intents = intents
|
||||
if resolved_intents is None and config is not None:
|
||||
resolved_intents = config.default_intents
|
||||
|
||||
resolved_llm = intent_llm
|
||||
if resolved_llm is None and config is not None:
|
||||
resolved_llm = config.intent_llm
|
||||
|
||||
if resolved_intents:
|
||||
if resolved_llm is None:
|
||||
raise ValueError("intent_llm is required when intents are provided")
|
||||
receive_user_message(
|
||||
flow,
|
||||
text,
|
||||
outcomes=resolved_intents,
|
||||
llm=resolved_llm,
|
||||
)
|
||||
else:
|
||||
receive_user_message(flow, text)
|
||||
|
||||
|
||||
def input_history_to_messages(entries: Sequence[Any]) -> list[LLMMessage]:
|
||||
"""Convert ``Flow.input_history`` entries to LLM message format."""
|
||||
messages: list[LLMMessage] = []
|
||||
for entry in entries:
|
||||
prompt = entry.get("message") if isinstance(entry, dict) else None
|
||||
response = entry.get("response") if isinstance(entry, dict) else None
|
||||
if prompt:
|
||||
messages.append({"role": "assistant", "content": str(prompt)})
|
||||
if response:
|
||||
messages.append({"role": "user", "content": str(response)})
|
||||
return messages
|
||||
|
||||
|
||||
def get_conversational_config(flow: Flow[Any]) -> ConversationalConfig | None:
|
||||
"""Return class-level ``conversational_config`` if defined."""
|
||||
return getattr(type(flow), "conversational_config", None)
|
||||
@@ -83,7 +83,20 @@ from crewai.events.types.flow_events import (
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
|
||||
from crewai.flow.flow_context import current_flow_id, current_flow_request_id
|
||||
from crewai.flow.conversation import (
|
||||
ConversationalConfig,
|
||||
append_message as _append_conversation_message,
|
||||
get_conversation_messages,
|
||||
get_conversational_config,
|
||||
normalize_kickoff_inputs,
|
||||
prepare_conversational_turn,
|
||||
receive_user_message as _receive_user_message,
|
||||
)
|
||||
from crewai.flow.flow_context import (
|
||||
current_flow_id,
|
||||
current_flow_name,
|
||||
current_flow_request_id,
|
||||
)
|
||||
from crewai.flow.flow_wrappers import (
|
||||
FlowCondition,
|
||||
FlowConditions,
|
||||
@@ -952,6 +965,13 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
memory: Memory | MemoryScope | MemorySlice | None = Field(default=None)
|
||||
input_provider: InputProvider | None = Field(default=None)
|
||||
suppress_flow_events: bool = Field(default=False)
|
||||
defer_trace_finalization: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"When True, do not finalize the trace batch at the end of each kickoff. "
|
||||
"Call finalize_session_traces() when the chat session ends."
|
||||
),
|
||||
)
|
||||
human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list)
|
||||
last_human_feedback: HumanFeedbackResult | None = Field(default=None)
|
||||
|
||||
@@ -1073,8 +1093,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
_pending_feedback_context: PendingFeedbackContext | None = PrivateAttr(default=None)
|
||||
_human_feedback_method_outputs: dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
_input_history: list[InputHistoryEntry] = PrivateAttr(default_factory=list)
|
||||
_conversation_messages: list[dict[str, Any]] = PrivateAttr(default_factory=list)
|
||||
_pending_user_message: str | dict[str, Any] | None = PrivateAttr(default=None)
|
||||
_pending_intents: Sequence[str] | None = PrivateAttr(default=None)
|
||||
_pending_intent_llm: str | BaseLLM | None = PrivateAttr(default=None)
|
||||
_state: Any = PrivateAttr(default=None)
|
||||
|
||||
conversational_config: ClassVar[ConversationalConfig | None] = None
|
||||
|
||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override]
|
||||
class _FlowGeneric(cls): # type: ignore[valid-type,misc]
|
||||
pass
|
||||
@@ -1199,6 +1225,116 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
result: list[str] = self.memory.extract_memories(content)
|
||||
return result
|
||||
|
||||
@property
|
||||
def conversation_messages(self) -> list[dict[str, Any]]:
|
||||
"""Message history from state or the internal conversation buffer."""
|
||||
return get_conversation_messages(self)
|
||||
|
||||
@property
|
||||
def input_history(self) -> list[InputHistoryEntry]:
|
||||
"""Read-only view of prompts and responses from ``ask()``."""
|
||||
return list(self._input_history)
|
||||
|
||||
def append_message(
|
||||
self,
|
||||
role: Literal["user", "assistant", "system", "tool"],
|
||||
content: str,
|
||||
**extra: Any,
|
||||
) -> None:
|
||||
"""Append a message to conversation history on state or the fallback buffer."""
|
||||
_append_conversation_message(self, role, content, **extra)
|
||||
|
||||
def classify_intent(
|
||||
self,
|
||||
text: str,
|
||||
outcomes: Sequence[str],
|
||||
*,
|
||||
llm: str | BaseLLM,
|
||||
context: Sequence[dict[str, Any]] | None = None,
|
||||
) -> str:
|
||||
"""Map user text to one of the given outcomes using an LLM."""
|
||||
if context:
|
||||
context_blob = "\n".join(
|
||||
f"{m.get('role', 'user')}: {m.get('content', '')}" for m in context
|
||||
)
|
||||
feedback = f"{context_blob}\n\nLatest user message: {text}"
|
||||
else:
|
||||
feedback = text
|
||||
return self._collapse_to_outcome(feedback, outcomes, llm)
|
||||
|
||||
def receive_user_message(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
outcomes: Sequence[str] | None = None,
|
||||
llm: str | BaseLLM | None = None,
|
||||
) -> str:
|
||||
"""Append a user message and optionally set ``last_intent`` on state."""
|
||||
return _receive_user_message(
|
||||
self,
|
||||
text,
|
||||
outcomes=outcomes,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
def _configure_conversational_kickoff(
|
||||
self,
|
||||
*,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
user_message: str | dict[str, Any] | None = None,
|
||||
session_id: str | None = None,
|
||||
intents: Sequence[str] | None = None,
|
||||
intent_llm: str | BaseLLM | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Store pending conversational turn options for ``kickoff_async``."""
|
||||
config = get_conversational_config(self) or self.conversational_config
|
||||
resolved_intents = intents
|
||||
resolved_llm = intent_llm
|
||||
if config is not None:
|
||||
if resolved_intents is None:
|
||||
resolved_intents = config.default_intents
|
||||
if resolved_llm is None:
|
||||
resolved_llm = config.intent_llm
|
||||
|
||||
resolved_message = user_message
|
||||
if resolved_message is None and inputs and "user_message" in inputs:
|
||||
resolved_message = inputs["user_message"]
|
||||
|
||||
self._pending_user_message = resolved_message
|
||||
self._pending_intents = list(resolved_intents) if resolved_intents else None
|
||||
self._pending_intent_llm = resolved_llm
|
||||
|
||||
if config is not None and config.defer_trace_finalization:
|
||||
self.defer_trace_finalization = True
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
|
||||
TraceCollectionListener().batch_manager.defer_session_finalization = True
|
||||
|
||||
return normalize_kickoff_inputs(
|
||||
inputs,
|
||||
user_message=resolved_message,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _clear_conversational_kickoff(self) -> None:
|
||||
self._pending_user_message = None
|
||||
self._pending_intents = None
|
||||
self._pending_intent_llm = None
|
||||
|
||||
def _apply_pending_conversational_turn(self) -> None:
|
||||
if self._pending_user_message is None:
|
||||
return
|
||||
config = get_conversational_config(self) or self.conversational_config
|
||||
prepare_conversational_turn(
|
||||
self,
|
||||
user_message=self._pending_user_message,
|
||||
intents=self._pending_intents,
|
||||
intent_llm=self._pending_intent_llm,
|
||||
config=config,
|
||||
)
|
||||
|
||||
def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool:
|
||||
"""Mark an OR listener as fired atomically.
|
||||
|
||||
@@ -1532,20 +1668,19 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
reset_emission_counter()
|
||||
reset_last_event_id()
|
||||
|
||||
if not self.suppress_flow_events:
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
FlowStartedEvent(
|
||||
type="flow_started",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
inputs=None,
|
||||
),
|
||||
)
|
||||
if future and isinstance(future, Future):
|
||||
try:
|
||||
await asyncio.wrap_future(future)
|
||||
except Exception:
|
||||
logger.warning("FlowStartedEvent handler failed", exc_info=True)
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
FlowStartedEvent(
|
||||
type="flow_started",
|
||||
flow_name=self._flow_display_name(),
|
||||
inputs=None,
|
||||
),
|
||||
)
|
||||
if future and isinstance(future, Future):
|
||||
try:
|
||||
await asyncio.wrap_future(future)
|
||||
except Exception:
|
||||
logger.warning("FlowStartedEvent handler failed", exc_info=True)
|
||||
|
||||
get_env_context()
|
||||
|
||||
@@ -1698,29 +1833,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
self._event_futures.clear()
|
||||
|
||||
if not self.suppress_flow_events:
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
FlowFinishedEvent(
|
||||
type="flow_finished",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
result=final_result,
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
)
|
||||
if future and isinstance(future, Future):
|
||||
try:
|
||||
await asyncio.wrap_future(future)
|
||||
except Exception:
|
||||
logger.warning("FlowFinishedEvent handler failed", exc_info=True)
|
||||
if not self._should_defer_trace_finalization():
|
||||
await self._emit_flow_finished_async(final_result)
|
||||
|
||||
trace_listener = TraceCollectionListener()
|
||||
if trace_listener.batch_manager.batch_owner_type == "flow":
|
||||
if trace_listener.first_time_handler.is_first_time:
|
||||
trace_listener.first_time_handler.mark_events_collected()
|
||||
trace_listener.first_time_handler.handle_execution_completion()
|
||||
else:
|
||||
trace_listener.batch_manager.finalize_batch()
|
||||
self._finalize_flow_trace_batch()
|
||||
|
||||
return final_result
|
||||
|
||||
@@ -2033,6 +2149,15 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
restore_from_state_id: str | None = None,
|
||||
*,
|
||||
user_message: str | dict[str, Any] | None = None,
|
||||
session_id: str | None = None,
|
||||
intents: Sequence[str] | None = None,
|
||||
intent_llm: str | BaseLLM | None = None,
|
||||
interactive: bool = False,
|
||||
interactive_prompt: str | None = None,
|
||||
interactive_timeout: float | None = None,
|
||||
exit_commands: Sequence[str] | None = None,
|
||||
) -> Any | FlowStreamingOutput:
|
||||
"""Start the flow execution in a synchronous context.
|
||||
|
||||
@@ -2052,10 +2177,49 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
If the referenced state is not found, the kickoff falls back
|
||||
silently to baseline behavior. Cannot be combined with
|
||||
``from_checkpoint``; passing both raises ``ValueError``.
|
||||
user_message: Text or ``{"role": "user", "content": "..."}`` for this
|
||||
chat turn. Appended to ``state.messages`` before the graph runs.
|
||||
session_id: Conversation session UUID; merged into ``inputs["id"]``
|
||||
for ``@persist`` restoration.
|
||||
intents: Optional outcome labels for pre-kickoff intent classification.
|
||||
intent_llm: LLM used when ``intents`` is set.
|
||||
interactive: If True, run a CLI loop (``ask`` per line) until exit or
|
||||
timeout. For local demos only; APIs should pass ``user_message``.
|
||||
interactive_prompt: Prompt shown by ``ask()`` in interactive mode.
|
||||
interactive_timeout: Per-line timeout for interactive ``ask()``.
|
||||
exit_commands: Words that end interactive mode (default exit, quit).
|
||||
|
||||
Returns:
|
||||
The final output from the flow or FlowStreamingOutput if streaming.
|
||||
"""
|
||||
if interactive:
|
||||
if user_message is not None:
|
||||
raise ValueError(
|
||||
"Cannot pass user_message with interactive=True; "
|
||||
"messages are collected via ask()."
|
||||
)
|
||||
if self.stream:
|
||||
raise ValueError("interactive=True is not supported with stream=True")
|
||||
return self._kickoff_interactive(
|
||||
inputs=inputs,
|
||||
input_files=input_files,
|
||||
session_id=session_id,
|
||||
intents=intents,
|
||||
intent_llm=intent_llm,
|
||||
interactive_prompt=interactive_prompt,
|
||||
interactive_timeout=interactive_timeout,
|
||||
exit_commands=exit_commands,
|
||||
restore_from_state_id=restore_from_state_id,
|
||||
)
|
||||
|
||||
inputs = self._configure_conversational_kickoff(
|
||||
inputs=inputs,
|
||||
user_message=user_message,
|
||||
session_id=session_id,
|
||||
intents=intents,
|
||||
intent_llm=intent_llm,
|
||||
)
|
||||
|
||||
if from_checkpoint is not None and restore_from_state_id is not None:
|
||||
raise ValueError(
|
||||
"Cannot combine `from_checkpoint` and `restore_from_state_id`. "
|
||||
@@ -2064,7 +2228,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
restored = apply_checkpoint(self, from_checkpoint)
|
||||
if restored is not None:
|
||||
return restored.kickoff(inputs=inputs, input_files=input_files)
|
||||
return restored.kickoff(
|
||||
inputs=inputs,
|
||||
input_files=input_files,
|
||||
user_message=user_message,
|
||||
session_id=session_id,
|
||||
intents=intents,
|
||||
intent_llm=intent_llm,
|
||||
)
|
||||
if self.stream:
|
||||
result_holder: list[Any] = []
|
||||
current_task_info: TaskInfo = {
|
||||
@@ -2087,6 +2258,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
inputs=inputs,
|
||||
input_files=input_files,
|
||||
restore_from_state_id=restore_from_state_id,
|
||||
user_message=self._pending_user_message,
|
||||
session_id=inputs.get("id") if inputs else None,
|
||||
intents=self._pending_intents,
|
||||
intent_llm=self._pending_intent_llm,
|
||||
)
|
||||
result_holder.append(result)
|
||||
except Exception as e:
|
||||
@@ -2110,11 +2285,18 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
return streaming_output
|
||||
|
||||
async def _run_flow() -> Any:
|
||||
return await self.kickoff_async(
|
||||
inputs,
|
||||
input_files,
|
||||
restore_from_state_id=restore_from_state_id,
|
||||
)
|
||||
try:
|
||||
return await self.kickoff_async(
|
||||
inputs,
|
||||
input_files,
|
||||
restore_from_state_id=restore_from_state_id,
|
||||
user_message=self._pending_user_message,
|
||||
session_id=inputs.get("id") if inputs else None,
|
||||
intents=self._pending_intents,
|
||||
intent_llm=self._pending_intent_llm,
|
||||
)
|
||||
finally:
|
||||
self._clear_conversational_kickoff()
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
@@ -2124,12 +2306,65 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
except RuntimeError:
|
||||
return asyncio.run(_run_flow())
|
||||
|
||||
def _kickoff_interactive(
|
||||
self,
|
||||
*,
|
||||
inputs: dict[str, Any] | None,
|
||||
input_files: dict[str, FileInput] | None,
|
||||
session_id: str | None,
|
||||
intents: Sequence[str] | None,
|
||||
intent_llm: str | BaseLLM | None,
|
||||
interactive_prompt: str | None,
|
||||
interactive_timeout: float | None,
|
||||
exit_commands: Sequence[str] | None,
|
||||
restore_from_state_id: str | None,
|
||||
) -> Any:
|
||||
config = get_conversational_config(self) or self.conversational_config
|
||||
prompt = interactive_prompt or (config.interactive_prompt if config else "You: ")
|
||||
timeout = (
|
||||
interactive_timeout
|
||||
if interactive_timeout is not None
|
||||
else (config.interactive_timeout if config else None)
|
||||
)
|
||||
exits = {c.strip().lower() for c in (exit_commands or (config.exit_commands if config else ("exit", "quit")))}
|
||||
sid = session_id
|
||||
if sid is None and inputs and "id" in inputs:
|
||||
sid = str(inputs["id"])
|
||||
if sid is None:
|
||||
sid = str(uuid4())
|
||||
|
||||
last_result: Any = None
|
||||
while True:
|
||||
line = self.ask(prompt, timeout=timeout)
|
||||
if line is None or line.strip().lower() in exits:
|
||||
break
|
||||
turn_inputs = self._configure_conversational_kickoff(
|
||||
inputs=inputs,
|
||||
user_message=line,
|
||||
session_id=sid,
|
||||
intents=intents,
|
||||
intent_llm=intent_llm,
|
||||
)
|
||||
last_result = self.kickoff(
|
||||
inputs=turn_inputs,
|
||||
input_files=input_files,
|
||||
restore_from_state_id=restore_from_state_id,
|
||||
)
|
||||
restore_from_state_id = None
|
||||
self._clear_conversational_kickoff()
|
||||
return last_result
|
||||
|
||||
async def kickoff_async(
|
||||
self,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
restore_from_state_id: str | None = None,
|
||||
*,
|
||||
user_message: str | dict[str, Any] | None = None,
|
||||
session_id: str | None = None,
|
||||
intents: Sequence[str] | None = None,
|
||||
intent_llm: str | BaseLLM | None = None,
|
||||
) -> Any | FlowStreamingOutput:
|
||||
"""Start the flow execution asynchronously.
|
||||
|
||||
@@ -2150,10 +2385,22 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
separate persistence key. If the referenced state is not
|
||||
found, falls back silently to baseline. Cannot be combined
|
||||
with ``from_checkpoint``; passing both raises ``ValueError``.
|
||||
user_message: User text for this conversational turn.
|
||||
session_id: Session UUID (``inputs["id"]``).
|
||||
intents: Optional labels for pre-kickoff classification.
|
||||
intent_llm: LLM for classification when ``intents`` is set.
|
||||
|
||||
Returns:
|
||||
The final output from the flow, which is the result of the last executed method.
|
||||
"""
|
||||
inputs = self._configure_conversational_kickoff(
|
||||
inputs=inputs,
|
||||
user_message=user_message,
|
||||
session_id=session_id,
|
||||
intents=intents,
|
||||
intent_llm=intent_llm,
|
||||
)
|
||||
|
||||
if from_checkpoint is not None and restore_from_state_id is not None:
|
||||
raise ValueError(
|
||||
"Cannot combine `from_checkpoint` and `restore_from_state_id`. "
|
||||
@@ -2162,7 +2409,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
restored = apply_checkpoint(self, from_checkpoint)
|
||||
if restored is not None:
|
||||
return await restored.kickoff_async(inputs=inputs, input_files=input_files)
|
||||
return await restored.kickoff_async(
|
||||
inputs=inputs,
|
||||
input_files=input_files,
|
||||
user_message=user_message,
|
||||
session_id=session_id,
|
||||
intents=intents,
|
||||
intent_llm=intent_llm,
|
||||
)
|
||||
if self.stream:
|
||||
result_holder: list[Any] = []
|
||||
current_task_info: TaskInfo = {
|
||||
@@ -2215,10 +2469,13 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
flow_id_token = None
|
||||
request_id_token = None
|
||||
flow_name_token = None
|
||||
if current_flow_id.get() is None:
|
||||
flow_id_token = current_flow_id.set(self.flow_id)
|
||||
if current_flow_request_id.get() is None:
|
||||
request_id_token = current_flow_request_id.set(self.flow_id)
|
||||
if current_flow_name.get() is None:
|
||||
flow_name_token = current_flow_name.set(self._flow_display_name())
|
||||
|
||||
try:
|
||||
# Reset flow state for fresh execution unless restoring from persistence
|
||||
@@ -2301,8 +2558,12 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
f"No flow state found for UUID: {restore_uuid}", color="red"
|
||||
)
|
||||
|
||||
# Update state with any additional inputs (ignoring the 'id' key)
|
||||
filtered_inputs = {k: v for k, v in inputs.items() if k != "id"}
|
||||
# Update state with any additional inputs (ignoring conversational keys)
|
||||
filtered_inputs = {
|
||||
k: v
|
||||
for k, v in inputs.items()
|
||||
if k not in ("id", "user_message", "last_intent")
|
||||
}
|
||||
if filtered_inputs:
|
||||
self._initialize_state(filtered_inputs)
|
||||
|
||||
@@ -2310,31 +2571,51 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
reset_emission_counter()
|
||||
reset_last_event_id()
|
||||
|
||||
if not self.suppress_flow_events:
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
FlowStartedEvent(
|
||||
type="flow_started",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
inputs=inputs,
|
||||
),
|
||||
skip_flow_started = self._should_defer_trace_finalization() and getattr(
|
||||
self, "_conversation_trace_started", False
|
||||
)
|
||||
if not skip_flow_started:
|
||||
started_event = FlowStartedEvent(
|
||||
type="flow_started",
|
||||
flow_name=self._flow_display_name(),
|
||||
inputs=inputs,
|
||||
)
|
||||
future = crewai_event_bus.emit(self, started_event)
|
||||
if future:
|
||||
try:
|
||||
await asyncio.wrap_future(future)
|
||||
except Exception:
|
||||
logger.warning("FlowStartedEvent handler failed", exc_info=True)
|
||||
logger.warning(
|
||||
"FlowStartedEvent handler failed", exc_info=True
|
||||
)
|
||||
if self._should_defer_trace_finalization():
|
||||
object.__setattr__(self, "_conversation_trace_started", True)
|
||||
object.__setattr__(
|
||||
self,
|
||||
"_conversation_flow_started_event_id",
|
||||
started_event.event_id,
|
||||
)
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
|
||||
TraceCollectionListener().batch_manager.defer_session_finalization = (
|
||||
True
|
||||
)
|
||||
if not self.suppress_flow_events:
|
||||
self._log_flow_event(
|
||||
f"Flow started with ID: {self.flow_id}", color="bold magenta"
|
||||
)
|
||||
|
||||
# After FlowStarted (when not suppressed): env events must not pre-empt
|
||||
# trace batch init with implicit "crew" execution_type.
|
||||
# After FlowStarted: env events must not pre-empt trace batch init
|
||||
# with implicit "crew" execution_type.
|
||||
get_env_context()
|
||||
|
||||
if inputs is not None and "id" not in inputs:
|
||||
self._initialize_state(inputs)
|
||||
|
||||
self._apply_pending_conversational_turn()
|
||||
|
||||
if self._is_execution_resuming:
|
||||
await self._replay_recorded_events()
|
||||
|
||||
@@ -2428,35 +2709,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
self._event_futures.clear()
|
||||
|
||||
if not self.suppress_flow_events:
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
FlowFinishedEvent(
|
||||
type="flow_finished",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
result=final_output,
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
)
|
||||
if future:
|
||||
try:
|
||||
await asyncio.wrap_future(future)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"FlowFinishedEvent handler failed", exc_info=True
|
||||
)
|
||||
if not self._should_defer_trace_finalization():
|
||||
await self._emit_flow_finished_async(final_output)
|
||||
|
||||
if not self.suppress_flow_events:
|
||||
trace_listener = TraceCollectionListener()
|
||||
if trace_listener.batch_manager.batch_owner_type == "flow":
|
||||
if trace_listener.first_time_handler.is_first_time:
|
||||
trace_listener.first_time_handler.mark_events_collected()
|
||||
trace_listener.first_time_handler.handle_execution_completion()
|
||||
else:
|
||||
trace_listener.batch_manager.finalize_batch()
|
||||
self._finalize_flow_trace_batch()
|
||||
|
||||
return final_output
|
||||
finally:
|
||||
self._clear_conversational_kickoff()
|
||||
# Ensure all background memory saves complete before returning
|
||||
if self.memory is not None and hasattr(self.memory, "drain_writes"):
|
||||
self.memory.drain_writes()
|
||||
@@ -2464,6 +2724,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
current_flow_request_id.reset(request_id_token)
|
||||
if flow_id_token is not None:
|
||||
current_flow_id.reset(flow_id_token)
|
||||
if flow_name_token is not None:
|
||||
current_flow_name.reset(flow_name_token)
|
||||
detach(flow_token)
|
||||
|
||||
async def akickoff(
|
||||
@@ -2472,6 +2734,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
from_checkpoint: CheckpointConfig | None = None,
|
||||
restore_from_state_id: str | None = None,
|
||||
*,
|
||||
user_message: str | dict[str, Any] | None = None,
|
||||
session_id: str | None = None,
|
||||
intents: Sequence[str] | None = None,
|
||||
intent_llm: str | BaseLLM | None = None,
|
||||
) -> Any | FlowStreamingOutput:
|
||||
"""Native async method to start the flow execution. Alias for kickoff_async.
|
||||
|
||||
@@ -2483,6 +2750,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
restore_from_state_id: Optional UUID of a previously-persisted flow
|
||||
whose latest snapshot should hydrate this run's state. See
|
||||
``kickoff_async`` for full semantics.
|
||||
user_message: User text for this conversational turn.
|
||||
session_id: Session UUID (``inputs["id"]``).
|
||||
intents: Optional labels for pre-kickoff classification.
|
||||
intent_llm: LLM for classification when ``intents`` is set.
|
||||
|
||||
Returns:
|
||||
The final output from the flow, which is the result of the last executed method.
|
||||
@@ -2492,6 +2763,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
input_files,
|
||||
from_checkpoint,
|
||||
restore_from_state_id=restore_from_state_id,
|
||||
user_message=user_message,
|
||||
session_id=session_id,
|
||||
intents=intents,
|
||||
intent_llm=intent_llm,
|
||||
)
|
||||
|
||||
async def _replay_recorded_events(self) -> None:
|
||||
@@ -3354,6 +3629,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
),
|
||||
)
|
||||
|
||||
if response:
|
||||
_append_conversation_message(self, "user", response)
|
||||
|
||||
return response
|
||||
|
||||
def _request_human_feedback(
|
||||
@@ -3554,6 +3832,99 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
return outcomes[0]
|
||||
|
||||
def _flow_display_name(self) -> str:
|
||||
return self.name or self.__class__.__name__
|
||||
|
||||
def _should_defer_trace_finalization(self) -> bool:
|
||||
if self.defer_trace_finalization:
|
||||
return True
|
||||
config = get_conversational_config(self)
|
||||
return bool(config and config.defer_trace_finalization)
|
||||
|
||||
async def _emit_flow_finished_async(self, result: Any) -> None:
|
||||
"""Emit ``FlowFinishedEvent`` and await handlers."""
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
FlowFinishedEvent(
|
||||
type="flow_finished",
|
||||
flow_name=self._flow_display_name(),
|
||||
result=result,
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
)
|
||||
if not future:
|
||||
return
|
||||
try:
|
||||
if isinstance(future, Future):
|
||||
await asyncio.wrap_future(future)
|
||||
else:
|
||||
await future
|
||||
except Exception:
|
||||
logger.warning("FlowFinishedEvent handler failed", exc_info=True)
|
||||
|
||||
def _emit_flow_finished_sync(self, result: Any) -> None:
|
||||
"""Emit ``FlowFinishedEvent`` from synchronous session teardown."""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
asyncio.run(self._emit_flow_finished_async(result))
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Cannot emit flow_finished synchronously while an event loop is running"
|
||||
)
|
||||
|
||||
def finalize_session_traces(self) -> None:
|
||||
"""Finalize the trace batch after a multi-turn conversational session."""
|
||||
from crewai.events.event_context import restore_event_scope
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
|
||||
trace_listener = TraceCollectionListener()
|
||||
batch_manager = trace_listener.batch_manager
|
||||
|
||||
if batch_manager._batch_finalized or not batch_manager.is_batch_initialized():
|
||||
batch_manager.defer_session_finalization = False
|
||||
object.__setattr__(self, "_conversation_trace_started", False)
|
||||
object.__setattr__(self, "_conversation_flow_started_event_id", None)
|
||||
return
|
||||
|
||||
result = self._method_outputs[-1] if self._method_outputs else None
|
||||
if (
|
||||
self._should_defer_trace_finalization()
|
||||
and getattr(self, "_conversation_trace_started", False)
|
||||
):
|
||||
started_id = getattr(self, "_conversation_flow_started_event_id", None)
|
||||
if started_id:
|
||||
restore_event_scope(((started_id, "flow_started"),))
|
||||
try:
|
||||
self._emit_flow_finished_sync(result)
|
||||
finally:
|
||||
restore_event_scope(())
|
||||
object.__setattr__(self, "_conversation_flow_started_event_id", None)
|
||||
|
||||
self._finalize_flow_trace_batch(force=True)
|
||||
object.__setattr__(self, "_conversation_trace_started", False)
|
||||
batch_manager.defer_session_finalization = False
|
||||
|
||||
def _finalize_flow_trace_batch(self, *, force: bool = False) -> None:
|
||||
"""Finalize the active trace batch when this flow owns it."""
|
||||
if not force and self._should_defer_trace_finalization():
|
||||
return
|
||||
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
|
||||
trace_listener = TraceCollectionListener()
|
||||
if trace_listener.batch_manager.batch_owner_type != "flow":
|
||||
return
|
||||
if trace_listener.first_time_handler.is_first_time:
|
||||
trace_listener.first_time_handler.mark_events_collected()
|
||||
trace_listener.first_time_handler.handle_execution_completion()
|
||||
else:
|
||||
trace_listener.batch_manager.finalize_batch()
|
||||
|
||||
def _log_flow_event(
|
||||
self,
|
||||
message: str,
|
||||
|
||||
@@ -18,3 +18,7 @@ current_flow_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
||||
current_flow_method_name: contextvars.ContextVar[str] = contextvars.ContextVar(
|
||||
"flow_method_name", default="unknown"
|
||||
)
|
||||
|
||||
current_flow_name: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
||||
"flow_name", default=None
|
||||
)
|
||||
|
||||
4
lib/crewai/src/crewai/flow/providers/__init__.py
Normal file
4
lib/crewai/src/crewai/flow/providers/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from crewai.flow.providers.queue import QueueInputProvider
|
||||
|
||||
|
||||
__all__ = ["QueueInputProvider"]
|
||||
82
lib/crewai/src/crewai/flow/providers/queue.py
Normal file
82
lib/crewai/src/crewai/flow/providers/queue.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Queue-backed input provider for conversational flows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
|
||||
class QueueInputProvider:
|
||||
"""Blocks on a per-session queue until a user message is pushed.
|
||||
|
||||
Use for long-running workers where ``Flow.ask()`` should wait on WebSocket
|
||||
or another transport without blocking the event loop thread (flow runs ask
|
||||
in a worker thread).
|
||||
|
||||
Example:
|
||||
```python
|
||||
provider = QueueInputProvider()
|
||||
flow.input_provider = provider
|
||||
|
||||
# From a WebSocket handler:
|
||||
provider.push(session_id, "hello")
|
||||
|
||||
# Inside the flow:
|
||||
reply = flow.ask("You: ", metadata={"session_id": session_id})
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._queues: dict[str, queue.Queue[str | None]] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _get_queue(self, session_id: str) -> queue.Queue[str | None]:
|
||||
with self._lock:
|
||||
if session_id not in self._queues:
|
||||
self._queues[session_id] = queue.Queue()
|
||||
return self._queues[session_id]
|
||||
|
||||
def push(self, session_id: str, text: str) -> None:
|
||||
"""Enqueue a user message for the given session."""
|
||||
self._get_queue(session_id).put(text)
|
||||
|
||||
def close_session(self, session_id: str) -> None:
|
||||
"""Signal end of session (unblocks ``ask()`` with None)."""
|
||||
self._get_queue(session_id).put(None)
|
||||
|
||||
def request_input(
|
||||
self,
|
||||
message: str,
|
||||
flow: Flow[Any],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str | None:
|
||||
session_id = self._resolve_session_id(flow, metadata)
|
||||
if session_id is None:
|
||||
return None
|
||||
try:
|
||||
return self._get_queue(session_id).get()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _resolve_session_id(
|
||||
flow: Flow[Any],
|
||||
metadata: dict[str, Any] | None,
|
||||
) -> str | None:
|
||||
if metadata and metadata.get("session_id"):
|
||||
return str(metadata["session_id"])
|
||||
state = getattr(flow, "_state", None)
|
||||
if state is None:
|
||||
return None
|
||||
if isinstance(state, dict):
|
||||
value = state.get("id")
|
||||
return str(value) if value else None
|
||||
if hasattr(state, "id"):
|
||||
value = getattr(state, "id", None)
|
||||
return str(value) if value else None
|
||||
return None
|
||||
480
lib/crewai/tests/test_flow_conversation.py
Normal file
480
lib/crewai/tests/test_flow_conversation.py
Normal file
@@ -0,0 +1,480 @@
|
||||
"""Tests for conversational Flow helpers and kickoff parameters."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.listeners.tracing.trace_listener import TraceCollectionListener
|
||||
from crewai.events.types.flow_events import FlowStartedEvent
|
||||
from crewai.events.types.llm_events import LLMCallStartedEvent
|
||||
from crewai.flow import Flow, ChatState, listen, start
|
||||
from crewai.flow.flow_context import current_flow_id, current_flow_name
|
||||
from crewai.flow.conversation import (
|
||||
ConversationalConfig,
|
||||
append_message,
|
||||
get_conversation_messages,
|
||||
normalize_kickoff_inputs,
|
||||
prepare_conversational_turn,
|
||||
)
|
||||
from crewai.flow.chat import ChatMessage, ChatSession
|
||||
from crewai.flow.providers import QueueInputProvider
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class SimpleChatFlow(Flow[ChatState]):
|
||||
@start()
|
||||
def begin(self):
|
||||
return "done"
|
||||
|
||||
|
||||
class DictChatFlow(Flow):
|
||||
@start()
|
||||
def begin(self):
|
||||
return self.state.get("marker", "ok")
|
||||
|
||||
|
||||
class TestNormalizeKickoffInputs:
|
||||
def test_merges_session_and_user_message(self) -> None:
|
||||
merged = normalize_kickoff_inputs(
|
||||
{"foo": 1},
|
||||
user_message="hello",
|
||||
session_id="sess-1",
|
||||
)
|
||||
assert merged["id"] == "sess-1"
|
||||
assert merged["user_message"] == "hello"
|
||||
assert merged["foo"] == 1
|
||||
|
||||
|
||||
class TestMessageHelpers:
|
||||
def test_append_message_on_pydantic_state(self) -> None:
|
||||
flow = SimpleChatFlow()
|
||||
flow._state = ChatState()
|
||||
append_message(flow, "user", "hi")
|
||||
assert get_conversation_messages(flow) == [{"role": "user", "content": "hi"}]
|
||||
|
||||
def test_append_message_fallback_buffer(self) -> None:
|
||||
flow = DictChatFlow()
|
||||
|
||||
class _State:
|
||||
id = str(uuid4())
|
||||
|
||||
flow._state = _State()
|
||||
append_message(flow, "assistant", "reply")
|
||||
assert get_conversation_messages(flow) == [
|
||||
{"role": "assistant", "content": "reply"}
|
||||
]
|
||||
assert flow._conversation_messages == [
|
||||
{"role": "assistant", "content": "reply"}
|
||||
]
|
||||
|
||||
|
||||
class TestIntentPerTurn:
|
||||
def test_prepare_clears_stale_last_intent(self) -> None:
|
||||
flow = SimpleChatFlow()
|
||||
flow._state = ChatState(last_intent="ORDER", messages=[])
|
||||
prepare_conversational_turn(flow, user_message="hello")
|
||||
assert flow.state.last_intent is None
|
||||
|
||||
|
||||
class TestKickoffConversational:
|
||||
def test_kickoff_user_message_hydrates_state(self) -> None:
|
||||
flow = SimpleChatFlow()
|
||||
flow.kickoff(user_message="track my order", session_id="session-abc")
|
||||
|
||||
assert flow.state.last_user_message == "track my order"
|
||||
assert any(
|
||||
m.get("role") == "user" and m.get("content") == "track my order"
|
||||
for m in flow.state.messages
|
||||
)
|
||||
assert flow.state.id == "session-abc"
|
||||
|
||||
def test_kickoff_classifies_intent_when_configured(self) -> None:
|
||||
flow = SimpleChatFlow()
|
||||
|
||||
with patch.object(
|
||||
flow,
|
||||
"_collapse_to_outcome",
|
||||
return_value="order",
|
||||
) as mock_collapse:
|
||||
flow.kickoff(
|
||||
user_message="where is my package",
|
||||
session_id="s1",
|
||||
intents=["order", "help"],
|
||||
intent_llm="gpt-4o-mini",
|
||||
)
|
||||
|
||||
mock_collapse.assert_called_once()
|
||||
assert flow.state.last_intent == "order"
|
||||
|
||||
def test_ask_appends_to_messages(self) -> None:
|
||||
class AskFlow(Flow[ChatState]):
|
||||
input_provider = MagicMock()
|
||||
input_provider.request_input = MagicMock(return_value="user reply")
|
||||
|
||||
@start()
|
||||
def begin(self):
|
||||
self.ask("Prompt:")
|
||||
return "ok"
|
||||
|
||||
flow = AskFlow()
|
||||
flow._state = ChatState()
|
||||
flow.kickoff()
|
||||
|
||||
assert any(
|
||||
m.get("role") == "user" and m.get("content") == "user reply"
|
||||
for m in flow.state.messages
|
||||
)
|
||||
|
||||
|
||||
class TestClassifyIntent:
|
||||
def test_uses_collapse_with_context(self) -> None:
|
||||
flow = SimpleChatFlow()
|
||||
flow._state = ChatState(
|
||||
messages=[{"role": "user", "content": "prior"}],
|
||||
)
|
||||
|
||||
with patch.object(flow, "_collapse_to_outcome", return_value="help") as mock:
|
||||
outcome = flow.classify_intent(
|
||||
"I need help",
|
||||
["order", "help"],
|
||||
llm="gpt-4o-mini",
|
||||
context=flow.conversation_messages,
|
||||
)
|
||||
|
||||
assert outcome == "help"
|
||||
assert "I need help" in mock.call_args[0][0]
|
||||
|
||||
|
||||
class TestQueueInputProvider:
|
||||
def test_push_and_request_input(self) -> None:
|
||||
provider = QueueInputProvider()
|
||||
flow = SimpleChatFlow()
|
||||
flow._state = ChatState(id="sess-q")
|
||||
|
||||
provider.push("sess-q", "hello")
|
||||
result = provider.request_input(">", flow, metadata={"session_id": "sess-q"})
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
class TestChatSession:
|
||||
def test_handle_turn_returns_turn_result(self) -> None:
|
||||
flow = SimpleChatFlow()
|
||||
session = ChatSession(
|
||||
flow,
|
||||
session_id="chat-1",
|
||||
intents=["order", "help"],
|
||||
intent_llm="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with patch.object(flow, "_collapse_to_outcome", return_value="help"):
|
||||
turn = session.handle_turn("hi there")
|
||||
|
||||
assert turn.session_id == "chat-1"
|
||||
assert turn.output == "done"
|
||||
assert turn.intent == "help"
|
||||
assert any(m["role"] == "user" for m in turn.messages)
|
||||
session.close()
|
||||
|
||||
def test_chat_message_model(self) -> None:
|
||||
msg = ChatMessage(
|
||||
type="assistant_delta",
|
||||
session_id="x",
|
||||
payload={"chunk": "hi"},
|
||||
)
|
||||
assert msg.version == "1"
|
||||
assert msg.type == "assistant_delta"
|
||||
|
||||
|
||||
class TestFlowTracingWhenSuppressed:
|
||||
def test_flow_started_emitted_when_panel_events_suppressed(self) -> None:
|
||||
class QuietFlow(Flow[ChatState]):
|
||||
suppress_flow_events = True
|
||||
|
||||
@start()
|
||||
def begin(self) -> str:
|
||||
return "ok"
|
||||
|
||||
started: list[str] = []
|
||||
original_emit = crewai_event_bus.emit
|
||||
|
||||
def track_emit(source: Any, event: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
if isinstance(event, FlowStartedEvent):
|
||||
started.append(event.flow_name)
|
||||
return original_emit(source, event, *args, **kwargs)
|
||||
|
||||
with patch.object(crewai_event_bus, "emit", side_effect=track_emit):
|
||||
QuietFlow().kickoff()
|
||||
|
||||
assert started == ["QuietFlow"]
|
||||
|
||||
def test_llm_action_inside_flow_claims_flow_trace_batch(self) -> None:
|
||||
listener = TraceCollectionListener()
|
||||
listener.batch_manager.current_batch = None
|
||||
listener.batch_manager.batch_owner_type = None
|
||||
listener.batch_manager.batch_owner_id = None
|
||||
|
||||
flow_id_token = current_flow_id.set("flow-test-id")
|
||||
flow_name_token = current_flow_name.set("DemoSupportFlow")
|
||||
try:
|
||||
event = LLMCallStartedEvent(
|
||||
model="gpt-4o-mini",
|
||||
messages=[],
|
||||
call_id="call-test",
|
||||
)
|
||||
listener._handle_action_event("llm_call_started", object(), event)
|
||||
finally:
|
||||
current_flow_id.reset(flow_id_token)
|
||||
current_flow_name.reset(flow_name_token)
|
||||
|
||||
assert listener.batch_manager.batch_owner_type == "flow"
|
||||
assert listener.batch_manager.batch_owner_id == "flow-test-id"
|
||||
assert (
|
||||
listener.batch_manager.current_batch.execution_metadata["execution_type"]
|
||||
== "flow"
|
||||
)
|
||||
assert (
|
||||
listener.batch_manager.current_batch.execution_metadata["flow_name"]
|
||||
== "DemoSupportFlow"
|
||||
)
|
||||
|
||||
|
||||
class TestDeferTraceFinalization:
|
||||
def test_conversational_kickoff_enables_defer_flag(self) -> None:
|
||||
class ChatFlow(Flow[ChatState]):
|
||||
conversational_config = ConversationalConfig(
|
||||
defer_trace_finalization=True
|
||||
)
|
||||
|
||||
@start()
|
||||
def begin(self) -> str:
|
||||
return "ok"
|
||||
|
||||
flow = ChatFlow()
|
||||
flow._configure_conversational_kickoff(
|
||||
user_message="hi",
|
||||
session_id="sess-trace",
|
||||
)
|
||||
assert flow.defer_trace_finalization is True
|
||||
assert flow._should_defer_trace_finalization() is True
|
||||
|
||||
def test_finalize_skipped_until_forced(self) -> None:
|
||||
flow = SimpleChatFlow()
|
||||
flow.defer_trace_finalization = True
|
||||
|
||||
with patch(
|
||||
"crewai.events.listeners.tracing.trace_listener.TraceCollectionListener"
|
||||
) as mock_listener_cls:
|
||||
mock_listener_cls.return_value.batch_manager.batch_owner_type = "flow"
|
||||
mock_listener_cls.return_value.first_time_handler.is_first_time = False
|
||||
|
||||
flow._finalize_flow_trace_batch()
|
||||
mock_listener_cls.assert_not_called()
|
||||
|
||||
flow._finalize_flow_trace_batch(force=True)
|
||||
mock_listener_cls.assert_called_once()
|
||||
|
||||
|
||||
class TestDeferredFlowLifecycleEvents:
|
||||
def test_deferred_kickoff_skips_per_turn_flow_finished(self) -> None:
|
||||
class ChatFlow(Flow[ChatState]):
|
||||
conversational_config = ConversationalConfig(
|
||||
defer_trace_finalization=True
|
||||
)
|
||||
|
||||
@start()
|
||||
def begin(self) -> str:
|
||||
return "ok"
|
||||
|
||||
flow = ChatFlow()
|
||||
with patch.object(flow, "_emit_flow_finished_async") as mock_finished:
|
||||
flow.kickoff(user_message="hi", session_id="sess-lifecycle")
|
||||
mock_finished.assert_not_called()
|
||||
|
||||
def test_flow_finished_without_flow_started_warns(self, capsys) -> None:
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_context import restore_event_scope
|
||||
from crewai.events.types.flow_events import FlowFinishedEvent
|
||||
|
||||
class BareFlow(Flow[ChatState]):
|
||||
@start()
|
||||
def begin(self) -> str:
|
||||
return "ok"
|
||||
|
||||
restore_event_scope(())
|
||||
flow = BareFlow()
|
||||
crewai_event_bus.emit(
|
||||
flow,
|
||||
FlowFinishedEvent(
|
||||
type="flow_finished",
|
||||
flow_name="BareFlow",
|
||||
result="ok",
|
||||
state={},
|
||||
),
|
||||
)
|
||||
captured = capsys.readouterr().out
|
||||
assert "flow_finished" in captured
|
||||
assert "Missing starting event" in captured
|
||||
|
||||
def test_finalize_session_restores_flow_started_scope(self, capsys) -> None:
|
||||
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatch
|
||||
|
||||
class ChatFlow(Flow[ChatState]):
|
||||
conversational_config = ConversationalConfig(
|
||||
defer_trace_finalization=True
|
||||
)
|
||||
|
||||
@start()
|
||||
def begin(self) -> str:
|
||||
return "ok"
|
||||
|
||||
flow = ChatFlow()
|
||||
flow.defer_trace_finalization = True
|
||||
object.__setattr__(flow, "_conversation_trace_started", True)
|
||||
object.__setattr__(flow, "_conversation_flow_started_event_id", "start-evt-1")
|
||||
flow._method_outputs.append("ok")
|
||||
|
||||
listener = TraceCollectionListener()
|
||||
listener.batch_manager.batch_owner_type = "flow"
|
||||
listener.batch_manager.current_batch = TraceBatch(
|
||||
execution_metadata={"execution_type": "flow", "flow_name": "ChatFlow"},
|
||||
)
|
||||
listener.batch_manager.defer_session_finalization = True
|
||||
listener.batch_manager._batch_finalized = False
|
||||
|
||||
with patch.object(flow, "_finalize_flow_trace_batch") as mock_finalize:
|
||||
flow.finalize_session_traces()
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "Missing starting event" not in captured
|
||||
mock_finalize.assert_called_once_with(force=True)
|
||||
assert listener.batch_manager.defer_session_finalization is False
|
||||
|
||||
def test_finalize_batch_is_idempotent(self) -> None:
|
||||
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager
|
||||
|
||||
with patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
):
|
||||
bm = TraceBatchManager()
|
||||
bm.current_batch = bm.initialize_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={"execution_type": "flow", "flow_name": "ChatFlow"},
|
||||
)
|
||||
bm.trace_batch_id = "batch-idempotent"
|
||||
bm.backend_initialized = True
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"send_trace_events",
|
||||
return_value=MagicMock(status_code=200),
|
||||
),
|
||||
patch.object(
|
||||
bm.plus_api,
|
||||
"finalize_trace_batch",
|
||||
return_value=MagicMock(status_code=200, json=MagicMock(return_value={})),
|
||||
) as mock_finalize_api,
|
||||
):
|
||||
bm.finalize_batch()
|
||||
bm.finalize_batch()
|
||||
|
||||
assert mock_finalize_api.call_count == 1
|
||||
assert bm._batch_finalized is True
|
||||
|
||||
def test_finalize_session_is_idempotent_after_batch_cleared(self) -> None:
|
||||
class ChatFlow(Flow[ChatState]):
|
||||
@start()
|
||||
def begin(self) -> str:
|
||||
return "ok"
|
||||
|
||||
flow = ChatFlow()
|
||||
flow.defer_trace_finalization = True
|
||||
object.__setattr__(flow, "_conversation_trace_started", True)
|
||||
|
||||
listener = TraceCollectionListener()
|
||||
listener.batch_manager.current_batch = None
|
||||
listener.batch_manager.batch_owner_type = None
|
||||
listener.batch_manager.trace_batch_id = None
|
||||
listener.batch_manager._batch_finalized = True
|
||||
|
||||
with patch.object(flow, "_emit_flow_finished_sync") as mock_finished:
|
||||
with patch.object(flow, "_finalize_flow_trace_batch") as mock_finalize:
|
||||
flow.finalize_session_traces()
|
||||
flow.finalize_session_traces()
|
||||
|
||||
mock_finished.assert_not_called()
|
||||
mock_finalize.assert_not_called()
|
||||
|
||||
def test_sigint_skips_deferred_session_batch(self) -> None:
|
||||
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatch
|
||||
|
||||
listener = TraceCollectionListener()
|
||||
listener.batch_manager.current_batch = TraceBatch()
|
||||
listener.batch_manager.defer_session_finalization = True
|
||||
|
||||
with patch.object(listener.batch_manager, "finalize_batch") as mock_finalize:
|
||||
if listener.batch_manager.is_batch_initialized():
|
||||
if not listener.batch_manager.defer_session_finalization:
|
||||
listener.batch_manager.finalize_batch()
|
||||
mock_finalize.assert_not_called()
|
||||
|
||||
|
||||
class TestNestedCrewTracing:
|
||||
def test_is_inside_active_flow_context_when_kickoff_running(self) -> None:
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.flow.flow_context import current_flow_id
|
||||
|
||||
assert TraceCollectionListener._is_inside_active_flow_context() is False
|
||||
token = current_flow_id.set("parent-flow-id")
|
||||
try:
|
||||
assert TraceCollectionListener._is_inside_active_flow_context() is True
|
||||
finally:
|
||||
current_flow_id.reset(token)
|
||||
|
||||
def test_nested_crew_completion_skips_finalize(self) -> None:
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.flow.flow_context import current_flow_id
|
||||
|
||||
listener = TraceCollectionListener()
|
||||
listener.batch_manager.batch_owner_type = "crew"
|
||||
|
||||
token = current_flow_id.set("parent-flow-id")
|
||||
try:
|
||||
with patch.object(listener.batch_manager, "finalize_batch") as mock_finalize:
|
||||
if listener._nested_in_flow_execution():
|
||||
pass
|
||||
elif listener.batch_manager.batch_owner_type == "crew":
|
||||
listener.batch_manager.finalize_batch()
|
||||
mock_finalize.assert_not_called()
|
||||
finally:
|
||||
current_flow_id.reset(token)
|
||||
|
||||
def test_flow_owned_batch_skips_finalize_without_flow_context(self) -> None:
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatch
|
||||
|
||||
listener = TraceCollectionListener()
|
||||
listener.batch_manager.batch_owner_type = "flow"
|
||||
listener.batch_manager.current_batch = TraceBatch(
|
||||
execution_metadata={"execution_type": "flow", "flow_name": "Demo"},
|
||||
)
|
||||
|
||||
with patch.object(listener.batch_manager, "finalize_batch") as mock_finalize:
|
||||
if listener._nested_in_flow_execution():
|
||||
pass
|
||||
elif listener.batch_manager.batch_owner_type == "crew":
|
||||
listener.batch_manager.finalize_batch()
|
||||
mock_finalize.assert_not_called()
|
||||
@@ -884,6 +884,49 @@ class TestTraceListenerSetup:
|
||||
"test_batch_id_12345", "Internal Server Error"
|
||||
)
|
||||
|
||||
def test_finalize_batch_clears_buffer_after_successful_send(self) -> None:
|
||||
"""Successful send must not restore a stale event buffer (duplicate events)."""
|
||||
from crewai.events.listeners.tracing.types import TraceEvent
|
||||
|
||||
with patch(
|
||||
"crewai.events.listeners.tracing.trace_batch_manager.is_tracing_enabled_in_context",
|
||||
return_value=True,
|
||||
):
|
||||
batch_manager = TraceBatchManager()
|
||||
batch_manager.current_batch = batch_manager.initialize_batch(
|
||||
user_context={"privacy_level": "standard"},
|
||||
execution_metadata={
|
||||
"execution_type": "flow",
|
||||
"flow_name": "TestFlow",
|
||||
},
|
||||
)
|
||||
batch_manager.trace_batch_id = "batch-clear-test"
|
||||
batch_manager.backend_initialized = True
|
||||
batch_manager.event_buffer = [
|
||||
TraceEvent(
|
||||
type="llm_call_started",
|
||||
timestamp="2026-01-01T00:00:00",
|
||||
event_id="evt-1",
|
||||
emission_sequence=1,
|
||||
)
|
||||
]
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
batch_manager.plus_api,
|
||||
"send_trace_events",
|
||||
return_value=MagicMock(status_code=200),
|
||||
),
|
||||
patch.object(
|
||||
batch_manager.plus_api,
|
||||
"finalize_trace_batch",
|
||||
return_value=MagicMock(status_code=200, json=MagicMock(return_value={})),
|
||||
),
|
||||
):
|
||||
batch_manager.finalize_batch()
|
||||
|
||||
assert batch_manager.event_buffer == []
|
||||
|
||||
def test_ephemeral_batch_includes_anon_id(self):
|
||||
"""Test that ephemeral batch initialization sends anon_id from get_user_id()"""
|
||||
fake_user_id = "abc123def456"
|
||||
|
||||
Reference in New Issue
Block a user