refactor: convert Flow to Pydantic BaseModel

This commit is contained in:
Greyson LaLonde
2026-04-01 03:48:41 +08:00
committed by GitHub
parent 107bc7f7be
commit d6714a0e60
10 changed files with 373 additions and 304 deletions

View File

@@ -4,6 +4,8 @@ from typing import Any
import urllib.request
import warnings
from pydantic import PydanticUserError
from crewai.agent.core import Agent
from crewai.agent.planning_config import PlanningConfig
from crewai.crew import Crew
@@ -93,6 +95,38 @@ def __getattr__(name: str) -> Any:
raise AttributeError(f"module 'crewai' has no attribute {name!r}")
# Rebuild AgentExecutor's Pydantic schema at package import time so that its
# forward-referenced annotations can be resolved against the concrete classes
# listed below. The imports sit inside the try so that an ImportError is
# reported through the warning in the except branch instead of propagating.
try:
from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler
from crewai.experimental.agent_executor import AgentExecutor as _AgentExecutor
from crewai.hooks.llm_hooks import LLMCallHookContext as _LLMCallHookContext
from crewai.tools.tool_types import ToolResult as _ToolResult
from crewai.utilities.prompts import (
StandardPromptResult as _StandardPromptResult,
SystemPromptResult as _SystemPromptResult,
)
# force=True requests a rebuild even if a schema already exists;
# _types_namespace maps the annotation names to the classes they refer to.
_AgentExecutor.model_rebuild(
force=True,
_types_namespace={
"Agent": Agent,
"ToolsHandler": _ToolsHandler,
"Crew": Crew,
"BaseLLM": BaseLLM,
"Task": Task,
"StandardPromptResult": _StandardPromptResult,
"SystemPromptResult": _SystemPromptResult,
"LLMCallHookContext": _LLMCallHookContext,
"ToolResult": _ToolResult,
},
)
except (ImportError, PydanticUserError):
# Best effort: log a warning (with traceback via exc_info) rather than raise.
import logging as _logging
_logging.getLogger(__name__).warning(
"AgentExecutor.model_rebuild() failed; forward refs may be unresolved.",
exc_info=True,
)
__all__ = [
"LLM",
"Agent",

View File

@@ -1011,7 +1011,7 @@ class Agent(BaseAgent):
self.agent_executor.tools = tools
self.agent_executor.original_tools = raw_tools
self.agent_executor.prompt = prompt
self.agent_executor.stop = stop_words
self.agent_executor.stop_words = stop_words
self.agent_executor.tools_names = get_tool_names(tools)
self.agent_executor.tools_description = render_text_description_and_args(tools)
self.agent_executor.response_model = (

View File

@@ -11,10 +11,15 @@ import threading
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
from uuid import uuid4
from pydantic import BaseModel, Field, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from pydantic import (
BaseModel,
Field,
PrivateAttr,
model_validator,
)
from rich.console import Console
from rich.text import Text
from typing_extensions import Self
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import (
@@ -119,6 +124,7 @@ class AgentExecutorState(BaseModel):
(todos, observations, replan tracking) in a single validated model.
"""
id: str = Field(default_factory=lambda: str(uuid4()))
messages: list[LLMMessage] = Field(default_factory=list)
iterations: int = Field(default=0)
current_answer: AgentAction | AgentFinish | None = Field(default=None)
@@ -152,6 +158,9 @@ class AgentExecutorState(BaseModel):
class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
"""Agent Executor for both standalone agents and crew-bound agents.
_skip_auto_memory prevents Flow from eagerly allocating a Memory
instance — the executor uses agent/crew memory, not its own.
Inherits from:
- Flow[AgentExecutorState]: Provides flow orchestration capabilities
- CrewAgentExecutorMixin: Provides memory methods (short/long/external term)
@@ -159,136 +168,74 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
This executor can operate in two modes:
- Standalone mode: When crew and task are None (used by Agent.kickoff())
- Crew mode: When crew and task are provided (used by Agent.execute_task())
Note: Multiple instances may be created during agent initialization
(cache setup, RPM controller setup, etc.) but only the final instance
should execute tasks via invoke().
"""
def __init__(
self,
llm: BaseLLM,
agent: Agent,
prompt: SystemPromptResult | StandardPromptResult,
max_iter: int,
tools: list[CrewStructuredTool],
tools_names: str,
stop_words: list[str],
tools_description: str,
tools_handler: ToolsHandler,
task: Task | None = None,
crew: Crew | None = None,
step_callback: Any = None,
original_tools: list[BaseTool] | None = None,
function_calling_llm: BaseLLM | Any | None = None,
respect_context_window: bool = False,
request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: list[Any] | None = None,
response_model: type[BaseModel] | None = None,
i18n: I18N | None = None,
) -> None:
"""Initialize the flow-based agent executor.
_skip_auto_memory: bool = True
Args:
llm: Language model instance.
agent: Agent to execute.
prompt: Prompt templates.
max_iter: Maximum iterations.
tools: Available tools.
tools_names: Tool names string.
stop_words: Stop word list.
tools_description: Tool descriptions.
tools_handler: Tool handler instance.
task: Optional task to execute (None for standalone agent execution).
crew: Optional crew instance (None for standalone agent execution).
step_callback: Optional step callback.
original_tools: Original tool list.
function_calling_llm: Optional function calling LLM.
respect_context_window: Respect context limits.
request_within_rpm_limit: RPM limit check function.
callbacks: Optional callbacks list.
response_model: Optional Pydantic model for structured outputs.
"""
self._i18n: I18N = i18n or get_i18n()
self.llm = llm
self.task: Task | None = task
self.agent = agent
self.crew: Crew | None = crew
self.prompt = prompt
self.tools = tools
self.tools_names = tools_names
self.stop = stop_words
self.max_iter = max_iter
self.callbacks = callbacks or []
self._printer: Printer = Printer()
self.tools_handler = tools_handler
self.original_tools = original_tools or []
self.step_callback = step_callback
self.tools_description = tools_description
self.function_calling_llm = function_calling_llm
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.response_model = response_model
self.log_error_after = 3
self._console: Console = Console()
suppress_flow_events: bool = True # always suppress for executor
llm: BaseLLM = Field(exclude=True)
agent: Agent = Field(exclude=True)
prompt: SystemPromptResult | StandardPromptResult = Field(exclude=True)
max_iter: int = Field(default=25, exclude=True)
tools: list[CrewStructuredTool] = Field(default_factory=list, exclude=True)
tools_names: str = Field(default="", exclude=True)
stop_words: list[str] = Field(default_factory=list, exclude=True)
tools_description: str = Field(default="", exclude=True)
tools_handler: ToolsHandler | None = Field(default=None, exclude=True)
task: Task | None = Field(default=None, exclude=True)
crew: Crew | None = Field(default=None, exclude=True)
step_callback: Any = Field(default=None, exclude=True)
original_tools: list[BaseTool] = Field(default_factory=list, exclude=True)
function_calling_llm: BaseLLM | None = Field(default=None, exclude=True)
respect_context_window: bool = Field(default=False, exclude=True)
request_within_rpm_limit: Callable[[], bool] | None = Field(
default=None, exclude=True
)
callbacks: list[Any] = Field(default_factory=list, exclude=True)
response_model: type[BaseModel] | None = Field(default=None, exclude=True)
i18n: I18N | None = Field(default=None, exclude=True)
log_error_after: int = Field(default=3, exclude=True)
before_llm_call_hooks: list[BeforeLLMCallHookType | BeforeLLMCallHookCallable] = (
Field(default_factory=list, exclude=True)
)
after_llm_call_hooks: list[AfterLLMCallHookType | AfterLLMCallHookCallable] = Field(
default_factory=list, exclude=True
)
# Error context storage for recovery
self._last_parser_error: OutputParserError | None = None
self._last_context_error: Exception | None = None
_i18n: I18N = PrivateAttr(default_factory=get_i18n)
_printer: Printer = PrivateAttr(default_factory=Printer)
_console: Console = PrivateAttr(default_factory=Console)
_last_parser_error: OutputParserError | None = PrivateAttr(default=None)
_last_context_error: Exception | None = PrivateAttr(default=None)
_execution_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_finalize_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_finalize_called: bool = PrivateAttr(default=False)
_is_executing: bool = PrivateAttr(default=False)
_has_been_invoked: bool = PrivateAttr(default=False)
_instance_id: str = PrivateAttr(default_factory=lambda: str(uuid4())[:8])
_step_executor: Any = PrivateAttr(default=None)
_planner_observer: Any = PrivateAttr(default=None)
# Execution guard to prevent concurrent/duplicate executions
self._execution_lock = threading.Lock()
self._finalize_lock = threading.Lock()
self._finalize_called: bool = False
self._is_executing: bool = False
self._has_been_invoked: bool = False
self._flow_initialized: bool = False
self._instance_id = str(uuid4())[:8]
self.before_llm_call_hooks: list[
BeforeLLMCallHookType | BeforeLLMCallHookCallable
] = []
self.after_llm_call_hooks: list[
AfterLLMCallHookType | AfterLLMCallHookCallable
] = []
@model_validator(mode="after")
def _setup_executor(self) -> Self:
"""Configure executor after Pydantic field initialization."""
self._i18n = self.i18n or get_i18n()
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
if self.llm:
existing_stop = getattr(self.llm, "stop", [])
self.llm.stop = list(
set(
existing_stop + self.stop
if isinstance(existing_stop, list)
else self.stop
)
)
if not isinstance(existing_stop, list):
existing_stop = []
self.llm.stop = list(set(existing_stop + self.stop_words))
self._state = AgentExecutorState()
self.max_method_calls = self.max_iter * 10
# Plan-and-Execute components (Phase 2)
# Lazy-imported to avoid circular imports during module load
self._step_executor: Any = None
self._planner_observer: Any = None
def _ensure_flow_initialized(self) -> None:
"""Ensure Flow.__init__() has been called.
This is deferred from __init__ to prevent FlowCreatedEvent emission
during agent setup when multiple executor instances are created.
Only the instance that actually executes via invoke() will emit events.
"""
if not self._flow_initialized:
current_tracing = is_tracing_enabled_in_context()
# Now call Flow's __init__ which will replace self._state
# with Flow's managed state. Suppress flow events since this is
# an agent executor, not a user-facing flow.
super().__init__(
suppress_flow_events=True,
tracing=current_tracing if current_tracing else None,
max_method_calls=self.max_iter * 10,
)
self._flow_initialized = True
current_tracing = is_tracing_enabled_in_context()
self.tracing = current_tracing if current_tracing else None
self._flow_post_init()
return self
def _check_native_tool_support(self) -> bool:
"""Check if LLM supports native function calling."""
@@ -318,19 +265,13 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
@property
def state(self) -> AgentExecutorState:
"""Get state - returns temporary state if Flow not yet initialized.
Flow initialization is deferred to prevent event emission during agent setup.
Returns the temporary state until invoke() is called.
"""
if self._flow_initialized and hasattr(self, "_state_lock"):
return StateProxy(self._state, self._state_lock) # type: ignore[return-value]
return self._state
"""Get thread-safe state proxy."""
return StateProxy(self._state, self._state_lock) # type: ignore[return-value]
@property
def iterations(self) -> int:
"""Compatibility property for mixin - returns state iterations."""
return self._state.iterations
return self._state.iterations # type: ignore[no-any-return]
@iterations.setter
def iterations(self, value: int) -> None:
@@ -340,7 +281,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
@property
def messages(self) -> list[LLMMessage]:
"""Compatibility property - returns state messages."""
return self._state.messages
return self._state.messages # type: ignore[no-any-return]
@messages.setter
def messages(self, value: list[LLMMessage]) -> None:
@@ -1969,8 +1910,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
@listen("initialized")
def continue_iteration(self) -> Literal["check_iteration"]:
"""Bridge listener that connects iteration loop back to iteration check."""
if self._flow_initialized:
self._discard_or_listener(FlowMethodName("continue_iteration"))
self._discard_or_listener(FlowMethodName("continue_iteration"))
return "check_iteration"
@router(or_(initialize_reasoning, continue_iteration))
@@ -2598,8 +2538,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
if is_inside_event_loop():
return self.invoke_async(inputs)
self._ensure_flow_initialized()
with self._execution_lock:
if self._is_executing:
raise RuntimeError(
@@ -2690,8 +2628,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
Returns:
Dictionary with agent output.
"""
self._ensure_flow_initialized()
with self._execution_lock:
if self._is_executing:
raise RuntimeError(
@@ -3007,17 +2943,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
"""
return bool(self.crew and self.crew._train)
@classmethod
def __get_pydantic_core_schema__(
    cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
    """Return an "any" core schema for this executor class.

    Both arguments are ignored. Returning the permissive schema lets the
    executor appear as a field type in Pydantic models without those models
    having to set ``arbitrary_types_allowed=True``.
    """
    permissive_schema = core_schema.any_schema()
    return permissive_schema
# Backward compatibility alias (deprecated)
CrewAgentExecutorFlow = AgentExecutor

View File

@@ -39,7 +39,14 @@ from uuid import uuid4
from opentelemetry import baggage
from opentelemetry.context import attach, detach
from pydantic import BaseModel, Field, ValidationError
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
ValidationError,
)
from pydantic._internal._model_construction import ModelMetaclass
from rich.console import Console
from rich.panel import Panel
@@ -81,6 +88,7 @@ from crewai.flow.flow_wrappers import (
SimpleFlowCondition,
StartMethod,
)
from crewai.flow.human_feedback import HumanFeedbackResult
from crewai.flow.input_provider import InputProvider
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import (
@@ -108,7 +116,6 @@ if TYPE_CHECKING:
from crewai_files import FileInput
from crewai.flow.async_feedback.types import PendingFeedbackContext
from crewai.flow.human_feedback import HumanFeedbackResult
from crewai.llms.base_llm import BaseLLM
from crewai.flow.visualization import build_flow_structure, render_interactive
@@ -728,7 +735,7 @@ class StateProxy(Generic[T]):
return result
class FlowMeta(type):
class FlowMeta(ModelMetaclass):
def __new__(
mcs,
name: str,
@@ -736,6 +743,45 @@ class FlowMeta(type):
namespace: dict[str, Any],
**kwargs: Any,
) -> type:
parent_fields: set[str] = set()
for base in bases:
if hasattr(base, "model_fields"):
parent_fields.update(base.model_fields)
annotations = namespace.get("__annotations__", {})
_skip_types = (classmethod, staticmethod, property)
for base in bases:
if isinstance(base, ModelMetaclass):
continue
for attr_name in getattr(base, "__annotations__", {}):
if attr_name not in annotations and attr_name not in namespace:
annotations[attr_name] = ClassVar
for attr_name, attr_value in namespace.items():
if isinstance(attr_value, property) and attr_name not in annotations:
for base in bases:
base_ann = getattr(base, "__annotations__", {})
if attr_name in base_ann:
annotations[attr_name] = ClassVar
for attr_name, attr_value in list(namespace.items()):
if attr_name in annotations or attr_name.startswith("_"):
continue
if attr_name in parent_fields:
annotations[attr_name] = Any
if isinstance(attr_value, BaseModel):
namespace[attr_name] = Field(
default_factory=lambda v=attr_value: v, exclude=True
)
continue
if callable(attr_value) or isinstance(
attr_value, (*_skip_types, FlowMethod)
):
continue
annotations[attr_name] = ClassVar[type(attr_value)]
namespace["__annotations__"] = annotations
cls = super().__new__(mcs, name, bases, namespace)
start_methods = []
@@ -820,88 +866,90 @@ class FlowMeta(type):
return cls
class Flow(Generic[T], metaclass=FlowMeta):
class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
"""Base class for all flows.
type parameter T must be either dict[str, Any] or a subclass of BaseModel."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
ignored_types=(StartMethod, ListenMethod, RouterMethod),
revalidate_instances="never",
)
__hash__ = object.__hash__
_start_methods: ClassVar[list[FlowMethodName]] = []
_listeners: ClassVar[dict[FlowMethodName, SimpleFlowCondition | FlowCondition]] = {}
_routers: ClassVar[set[FlowMethodName]] = set()
_router_paths: ClassVar[dict[FlowMethodName, list[FlowMethodName]]] = {}
initial_state: type[T] | T | None = None
name: str | None = None
tracing: bool | None = None
stream: bool = False
memory: Memory | MemoryScope | MemorySlice | None = None
input_provider: InputProvider | None = None
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
class _FlowGeneric(cls): # type: ignore
_initial_state_t = item
initial_state: Any = Field(default=None)
name: str | None = Field(default=None)
tracing: bool | None = Field(default=None)
stream: bool = Field(default=False)
memory: Memory | MemoryScope | MemorySlice | None = Field(default=None)
input_provider: InputProvider | None = Field(default=None)
suppress_flow_events: bool = Field(default=False)
human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list)
last_human_feedback: HumanFeedbackResult | None = Field(default=None)
persistence: Any = Field(default=None, exclude=True)
max_method_calls: int = Field(default=100, exclude=True)
_methods: dict[FlowMethodName, FlowMethod[Any, Any]] = PrivateAttr(
default_factory=dict
)
_method_execution_counts: dict[FlowMethodName, int] = PrivateAttr(
default_factory=dict
)
_pending_and_listeners: dict[PendingListenerKey, set[FlowMethodName]] = PrivateAttr(
default_factory=dict
)
_fired_or_listeners: set[FlowMethodName] = PrivateAttr(default_factory=set)
_method_outputs: list[Any] = PrivateAttr(default_factory=list)
_state_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_or_listeners_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_completed_methods: set[FlowMethodName] = PrivateAttr(default_factory=set)
_method_call_counts: dict[FlowMethodName, int] = PrivateAttr(default_factory=dict)
_is_execution_resuming: bool = PrivateAttr(default=False)
_event_futures: list[Future[None]] = PrivateAttr(default_factory=list)
_pending_feedback_context: PendingFeedbackContext | None = PrivateAttr(default=None)
_human_feedback_method_outputs: dict[str, Any] = PrivateAttr(default_factory=dict)
_input_history: list[InputHistoryEntry] = PrivateAttr(default_factory=list)
_state: Any = PrivateAttr(default=None)
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override]
"""Return a subclass of ``cls`` that records ``item`` as its state type.

Subscripting the class (e.g. ``Flow[MyState]``) creates an empty subclass,
renames it to ``"<cls name>[<item name>]"`` and stores ``item`` on it as
``_initial_state_t``.
"""
class _FlowGeneric(cls): # type: ignore[valid-type,misc]
pass
_FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]"
# Assigned on the finished class rather than inside the class body.
_FlowGeneric._initial_state_t = item
return _FlowGeneric
def __init__(
self,
persistence: FlowPersistence | None = None,
tracing: bool | None = None,
suppress_flow_events: bool = False,
max_method_calls: int = 100,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.
def __setattr__(self, name: str, value: Any) -> None:
    """Assign an attribute, routing declared names through Pydantic.

    Declared model fields and private attributes are set via the parent
    ``__setattr__``. Any other name is written straight onto the instance,
    so arbitrary attribute assignment keeps working as it did when Flow was
    a plain class.

    Args:
        name: Name of the attribute being assigned.
        value: Value to store under ``name``.
    """
    # Look the tables up on the class, not the instance: reading
    # ``model_fields`` through an instance is deprecated in recent Pydantic
    # releases and would emit a warning on every single assignment.
    model_cls = type(self)
    if name in model_cls.model_fields or name in model_cls.__private_attributes__:
        super().__setattr__(name, value)
    else:
        object.__setattr__(self, name, value)
Args:
persistence: Optional persistence backend for storing flow states
tracing: Whether to enable tracing. True=always enable, False=always disable, None=check environment/user settings
suppress_flow_events: Whether to suppress flow event emissions (internal use)
max_method_calls: Maximum times a single method can be called per execution before raising RecursionError
**kwargs: Additional state values to initialize or override
"""
# Initialize basic instance attributes
self._methods: dict[FlowMethodName, FlowMethod[Any, Any]] = {}
self._method_execution_counts: dict[FlowMethodName, int] = {}
self._pending_and_listeners: dict[PendingListenerKey, set[FlowMethodName]] = {}
self._fired_or_listeners: set[FlowMethodName] = (
set()
) # Track OR listeners that already fired
self._method_outputs: list[Any] = [] # list to store all method outputs
self._state_lock = threading.Lock()
self._or_listeners_lock = threading.Lock()
self._completed_methods: set[FlowMethodName] = (
set()
) # Track completed methods for reload
self._method_call_counts: dict[FlowMethodName, int] = {}
self._max_method_calls = max_method_calls
self._persistence: FlowPersistence | None = persistence
self._is_execution_resuming: bool = False
self._event_futures: list[Future[None]] = []
def model_post_init(self, __context: Any) -> None:
"""Pydantic post-init hook: delegate to ``_flow_post_init``; ``__context`` is unused."""
self._flow_post_init()
# Human feedback storage
self.human_feedback_history: list[HumanFeedbackResult] = []
self.last_human_feedback: HumanFeedbackResult | None = None
self._pending_feedback_context: PendingFeedbackContext | None = None
# Per-method stash for real @human_feedback output (keyed by method name)
# Used to decouple routing outcome from method return value when emit is set
self._human_feedback_method_outputs: dict[str, Any] = {}
self.suppress_flow_events: bool = suppress_flow_events
def _flow_post_init(self) -> None:
"""Heavy initialization: state creation, events, memory, method registration."""
if getattr(self, "_flow_post_init_done", False):
return
object.__setattr__(self, "_flow_post_init_done", True)
# User input history (for self.ask())
self._input_history: list[InputHistoryEntry] = []
if self._state is None:
self._state = self._create_initial_state()
# Initialize state with initial values
self._state = self._create_initial_state()
self.tracing = tracing
tracing_enabled = should_enable_tracing(override=self.tracing)
set_tracing_enabled(tracing_enabled)
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)
if not self.suppress_flow_events:
crewai_event_bus.emit(
@@ -1385,8 +1433,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._pending_feedback_context = None
# Clear pending feedback from persistence
if self._persistence:
self._persistence.clear_pending_feedback(context.flow_id)
if self.persistence:
self.persistence.clear_pending_feedback(context.flow_id)
# Emit feedback received event
crewai_event_bus.emit(
@@ -1427,17 +1475,17 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(e, HumanFeedbackPending):
self._pending_feedback_context = e.context
if self._persistence is None:
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
self._persistence = SQLiteFlowPersistence()
self.persistence = SQLiteFlowPersistence()
state_data = (
self._state
if isinstance(self._state, dict)
else self._state.model_dump()
)
self._persistence.save_pending_feedback(
self.persistence.save_pending_feedback(
flow_uuid=e.context.flow_id,
context=e.context,
state_data=state_data,
@@ -1487,39 +1535,33 @@ class Flow(Generic[T], metaclass=FlowMeta):
"""
init_state = self.initial_state
# Handle case where initial_state is None but we have a type parameter
if init_state is None and hasattr(self, "_initial_state_t"):
state_type = self._initial_state_t
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
# Create instance - FlowState auto-generates id via default_factory
instance = state_type()
# Ensure id is set - generate UUID if empty
if not getattr(instance, "id", None):
object.__setattr__(instance, "id", str(uuid4()))
return cast(T, instance)
if issubclass(state_type, BaseModel):
# Create a new type with FlowState first for proper id default
class StateWithId(FlowState, state_type): # type: ignore
pass
instance = StateWithId()
# Ensure id is set - generate UUID if empty
if not getattr(instance, "id", None):
object.__setattr__(instance, "id", str(uuid4()))
return cast(T, instance)
if state_type is dict:
return cast(T, {"id": str(uuid4())})
# Handle case where no initial state is provided
if init_state is None:
return cast(T, {"id": str(uuid4())})
# Handle case where initial_state is a type (class)
if isinstance(init_state, type):
state_class = init_state
if issubclass(state_class, FlowState):
return state_class()
return cast(T, state_class())
if issubclass(state_class, BaseModel):
model_fields = getattr(state_class, "model_fields", None)
if not model_fields or "id" not in model_fields:
@@ -1527,7 +1569,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
model_instance = state_class()
if not getattr(model_instance, "id", None):
object.__setattr__(model_instance, "id", str(uuid4()))
return model_instance
return cast(T, model_instance)
if init_state is dict:
return cast(T, {"id": str(uuid4())})
@@ -1538,32 +1580,21 @@ class Flow(Generic[T], metaclass=FlowMeta):
new_state["id"] = str(uuid4())
return cast(T, new_state)
# Handle BaseModel instance case
if isinstance(init_state, BaseModel):
model = cast(BaseModel, init_state)
if not hasattr(model, "id"):
raise ValueError("Flow state model must have an 'id' field")
# Create new instance with same values to avoid mutations
if hasattr(model, "model_dump"):
# Pydantic v2
model = init_state
if hasattr(model, "id"):
state_dict = model.model_dump()
elif hasattr(model, "dict"):
# Pydantic v1
state_dict = model.dict()
else:
# Fallback for other BaseModel implementations
state_dict = {
k: v for k, v in model.__dict__.items() if not k.startswith("_")
}
if not state_dict.get("id"):
state_dict["id"] = str(uuid4())
model_class = type(model)
return cast(T, model_class(**state_dict))
# Ensure id is set - generate UUID if empty
if not state_dict.get("id"):
state_dict["id"] = str(uuid4())
class StateWithId(FlowState, type(model)): # type: ignore
pass
# Create new instance of the same class
model_class = type(model)
return cast(T, model_class(**state_dict))
state_dict = model.model_dump()
state_dict["id"] = str(uuid4())
return cast(T, StateWithId(**state_dict))
raise TypeError(
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)
@@ -1576,17 +1607,17 @@ class Flow(Generic[T], metaclass=FlowMeta):
"""
if isinstance(self._state, BaseModel):
try:
return self._state.model_copy(deep=True)
return cast(T, self._state.model_copy(deep=True))
except (TypeError, AttributeError):
try:
state_dict = self._state.model_dump()
model_class = type(self._state)
return model_class(**state_dict)
return cast(T, model_class(**state_dict))
except Exception:
return self._state.model_copy(deep=False)
return cast(T, self._state.model_copy(deep=False))
else:
try:
return copy.deepcopy(self._state)
return cast(T, copy.deepcopy(self._state))
except (TypeError, AttributeError):
return cast(T, self._state.copy())
@@ -1662,7 +1693,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
elif isinstance(self._state, BaseModel):
# For BaseModel states, preserve existing fields unless overridden
try:
model = cast(BaseModel, self._state)
model = self._state
# Get current state as dict
if hasattr(model, "model_dump"):
current_state = model.model_dump()
@@ -1713,7 +1744,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._state.update(stored_state)
elif isinstance(self._state, BaseModel):
# For BaseModel states, create new instance with stored values
model = cast(BaseModel, self._state)
model = self._state
if hasattr(model, "model_validate"):
# Pydantic v2
self._state = cast(T, type(model).model_validate(stored_state))
@@ -1938,7 +1969,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
try:
# Reset flow state for fresh execution unless restoring from persistence
is_restoring = inputs and "id" in inputs and self._persistence is not None
is_restoring = inputs and "id" in inputs and self.persistence is not None
if not is_restoring:
# Clear completed methods and outputs for a fresh start
self._completed_methods.clear()
@@ -1964,9 +1995,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
setattr(self._state, "id", inputs["id"]) # noqa: B010
# If persistence is enabled, attempt to restore the stored state using the provided id.
if "id" in inputs and self._persistence is not None:
if "id" in inputs and self.persistence is not None:
restore_uuid = inputs["id"]
stored_state = self._persistence.load_state(restore_uuid)
stored_state = self.persistence.load_state(restore_uuid)
if stored_state:
self._log_flow_event(
f"Loading flow state from memory for UUID: {restore_uuid}"
@@ -2036,17 +2067,17 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(e, HumanFeedbackPending):
# Auto-save pending feedback (create default persistence if needed)
if self._persistence is None:
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
self._persistence = SQLiteFlowPersistence()
self.persistence = SQLiteFlowPersistence()
state_data = (
self._state
if isinstance(self._state, dict)
else self._state.model_dump()
)
self._persistence.save_pending_feedback(
self.persistence.save_pending_feedback(
flow_uuid=e.context.flow_id,
context=e.context,
state_data=state_data,
@@ -2332,10 +2363,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(e, HumanFeedbackPending):
e.context.method_name = method_name
if self._persistence is None:
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
self._persistence = SQLiteFlowPersistence()
self.persistence = SQLiteFlowPersistence()
# Emit paused event (not failed)
if not self.suppress_flow_events:
@@ -2696,9 +2727,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
- Catches and logs any exceptions during execution, preventing individual listener failures from breaking the entire flow
"""
count = self._method_call_counts.get(listener_name, 0) + 1
if count > self._max_method_calls:
if count > self.max_method_calls:
raise RecursionError(
f"Method '{listener_name}' has been called {self._max_method_calls} times in "
f"Method '{listener_name}' has been called {self.max_method_calls} times in "
f"this flow execution, which indicates an infinite loop. "
f"This commonly happens when a @listen label matches the "
f"method's own name."
@@ -2805,7 +2836,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
This is best-effort: if persistence is not configured, this is a no-op.
"""
if self._persistence is None:
if self.persistence is None:
return
try:
state_data = (
@@ -2813,7 +2844,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self._state, dict)
else self._state.model_dump()
)
self._persistence.save_state(
self.persistence.save_state(
flow_uuid=self.flow_id,
method_name="_ask_checkpoint",
state_data=state_data,

View File

@@ -98,7 +98,7 @@ class EncodingFlow(Flow[EncodingState]):
_skip_auto_memory: bool = True
initial_state = EncodingState
initial_state: type[EncodingState] = EncodingState
def __init__(
self,

View File

@@ -65,7 +65,7 @@ class RecallFlow(Flow[RecallState]):
_skip_auto_memory: bool = True
initial_state = RecallState
initial_state: type[RecallState] = RecallState
def __init__(
self,

View File

@@ -148,6 +148,36 @@ class Memory(BaseModel):
_pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list)
_pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Memory:
    """Return a deep copy, recreating private attrs that cannot be copied.

    Public state (``__dict__`` and ``__pydantic_extra__``) is deep-copied and
    the set of explicitly-set fields is shallow-copied. Private attributes
    holding a ``ThreadPoolExecutor`` or a lock are replaced with a fresh
    default from their ``PrivateAttr`` declaration; every other private
    attribute is deep-copied, falling back to sharing the original object
    when it cannot be copied.

    Args:
        memo: Memo dictionary supplied by ``copy.deepcopy``; a new one is
            created when omitted.

    Returns:
        A new instance of the same class.
    """
    import copy as _copy

    # ``threading.Lock`` is a factory function, not a class, before Python
    # 3.13, so using it as an isinstance() target raises TypeError there.
    # Derive the concrete lock type from an instance instead.
    lock_type = type(threading.Lock())

    cls = type(self)
    new = cls.__new__(cls)
    if memo is None:
        memo = {}
    # Register the copy first so reference cycles resolve to it.
    memo[id(self)] = new
    object.__setattr__(new, "__dict__", _copy.deepcopy(self.__dict__, memo))
    object.__setattr__(
        new, "__pydantic_fields_set__", _copy.copy(self.__pydantic_fields_set__)
    )
    object.__setattr__(
        new, "__pydantic_extra__", _copy.deepcopy(self.__pydantic_extra__, memo)
    )
    # Private attrs: create fresh pool/lock instead of deepcopying
    private = {}
    for k, v in (self.__pydantic_private__ or {}).items():
        if isinstance(v, (ThreadPoolExecutor, lock_type)):
            attr = self.__private_attributes__[k]
            private[k] = attr.get_default()
        else:
            try:
                private[k] = _copy.deepcopy(v, memo)
            except Exception:
                # Deliberate best effort: share the value if it cannot be copied.
                private[k] = v
    object.__setattr__(new, "__pydantic_private__", private)
    return new
def model_post_init(self, __context: Any) -> None:
"""Initialize runtime state from field values."""
self._config = MemoryConfig(

View File

@@ -2,9 +2,10 @@
from __future__ import annotations
from typing import Annotated, Any, Literal, TypedDict
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from crewai.utilities.i18n import I18N, get_i18n

View File

@@ -4,13 +4,55 @@ Tests the Flow-based agent executor implementation including state management,
flow methods, routing logic, and error handling.
"""
from __future__ import annotations
import asyncio
import time
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
import pytest
from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler
from crewai.agents.step_executor import StepExecutor
def _build_executor(**kwargs: Any) -> AgentExecutor:
    """Build an unvalidated AgentExecutor for unit tests.

    The instance is created with ``model_construct`` so Pydantic validators
    are skipped and plain ``Mock()`` objects are accepted for typed fields
    such as llm, agent, crew and task. The private attributes are then
    assigned by hand, one by one, in a fixed order.

    Args:
        **kwargs: Field values forwarded to ``AgentExecutor.model_construct``.
            An ``i18n`` entry, when truthy, is also used for the private
            ``_i18n`` attribute; otherwise ``get_i18n()`` supplies it.

    Returns:
        The constructed executor with its private state populated.
    """
    import threading

    from crewai.utilities.i18n import get_i18n
    from crewai.utilities.printer import Printer

    instance = AgentExecutor.model_construct(**kwargs)

    # Insertion order mirrors the order in which the attributes are assigned.
    private_defaults: dict[str, Any] = {
        "_state": AgentExecutorState(),
        "_methods": {},
        "_method_outputs": [],
        "_completed_methods": set(),
        "_fired_or_listeners": set(),
        "_pending_and_listeners": {},
        "_method_execution_counts": {},
        "_method_call_counts": {},
        "_event_futures": [],
        "_human_feedback_method_outputs": {},
        "_input_history": [],
        "_is_execution_resuming": False,
        "_state_lock": threading.Lock(),
        "_or_listeners_lock": threading.Lock(),
        "_execution_lock": threading.Lock(),
        "_finalize_lock": threading.Lock(),
        "_finalize_called": False,
        "_is_executing": False,
        "_has_been_invoked": False,
        "_last_parser_error": None,
        "_last_context_error": None,
        "_step_executor": None,
        "_planner_observer": None,
        "_printer": Printer(),
        "_i18n": kwargs.get("i18n") or get_i18n(),
    }
    for attr_name, attr_value in private_defaults.items():
        setattr(instance, attr_name, attr_value)
    return instance
from crewai.agents.planner_observer import PlannerObserver
from crewai.experimental.agent_executor import (
AgentExecutorState,
@@ -75,6 +117,7 @@ class TestAgentExecutor:
"""Create mock dependencies for executor."""
llm = Mock()
llm.supports_stop_words.return_value = True
llm.stop = []
task = Mock()
task.description = "Test task"
@@ -94,7 +137,7 @@ class TestAgentExecutor:
prompt = {"prompt": "Test prompt with {input}, {tool_names}, {tools}"}
tools = []
tools_handler = Mock()
tools_handler = Mock(spec=_ToolsHandler)
return {
"llm": llm,
@@ -112,7 +155,7 @@ class TestAgentExecutor:
def test_executor_initialization(self, mock_dependencies):
"""Test AgentExecutor initialization."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
assert executor.llm == mock_dependencies["llm"]
assert executor.task == mock_dependencies["task"]
@@ -126,7 +169,7 @@ class TestAgentExecutor:
with patch.object(
AgentExecutor, "_show_start_logs"
) as mock_show_start:
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
result = executor.initialize_reasoning()
assert result == "initialized"
@@ -134,7 +177,7 @@ class TestAgentExecutor:
def test_check_max_iterations_not_reached(self, mock_dependencies):
"""Test routing when iterations < max."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.iterations = 5
result = executor.check_max_iterations()
@@ -142,7 +185,7 @@ class TestAgentExecutor:
def test_check_max_iterations_reached(self, mock_dependencies):
"""Test routing when iterations >= max."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.iterations = 10
result = executor.check_max_iterations()
@@ -150,7 +193,7 @@ class TestAgentExecutor:
def test_route_by_answer_type_action(self, mock_dependencies):
"""Test routing for AgentAction."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.current_answer = AgentAction(
thought="thinking", tool="search", tool_input="query", text="action text"
)
@@ -160,7 +203,7 @@ class TestAgentExecutor:
def test_route_by_answer_type_finish(self, mock_dependencies):
"""Test routing for AgentFinish."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.current_answer = AgentFinish(
thought="final thoughts", output="Final answer", text="complete"
)
@@ -170,7 +213,7 @@ class TestAgentExecutor:
def test_continue_iteration(self, mock_dependencies):
"""Test iteration continuation."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
result = executor.continue_iteration()
@@ -179,7 +222,7 @@ class TestAgentExecutor:
def test_finalize_success(self, mock_dependencies):
"""Test finalize with valid AgentFinish."""
with patch.object(AgentExecutor, "_show_logs") as mock_show_logs:
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.current_answer = AgentFinish(
thought="final thinking", output="Done", text="complete"
)
@@ -192,7 +235,7 @@ class TestAgentExecutor:
def test_finalize_failure(self, mock_dependencies):
"""Test finalize skips when given AgentAction instead of AgentFinish."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.current_answer = AgentAction(
thought="thinking", tool="search", tool_input="query", text="action text"
)
@@ -208,7 +251,7 @@ class TestAgentExecutor:
):
"""Finalize should skip synthesis when last todo is already a complete answer."""
with patch.object(AgentExecutor, "_show_logs") as mock_show_logs:
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.todos.items = [
TodoItem(
step_number=1,
@@ -252,7 +295,7 @@ class TestAgentExecutor:
):
"""Finalize should still synthesize when response_model is configured."""
with patch.object(AgentExecutor, "_show_logs"):
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.response_model = Mock()
executor.state.todos.items = [
TodoItem(
@@ -287,7 +330,7 @@ class TestAgentExecutor:
def test_format_prompt(self, mock_dependencies):
"""Test prompt formatting."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
inputs = {"input": "test input", "tool_names": "tool1, tool2", "tools": "desc"}
result = executor._format_prompt("Prompt {input} {tool_names} {tools}", inputs)
@@ -298,18 +341,18 @@ class TestAgentExecutor:
def test_is_training_mode_false(self, mock_dependencies):
"""Test training mode detection when not in training."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
assert executor._is_training_mode() is False
def test_is_training_mode_true(self, mock_dependencies):
"""Test training mode detection when in training."""
mock_dependencies["crew"]._train = True
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
assert executor._is_training_mode() is True
def test_append_message_to_state(self, mock_dependencies):
"""Test message appending to state."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
initial_count = len(executor.state.messages)
executor._append_message_to_state("test message")
@@ -322,7 +365,7 @@ class TestAgentExecutor:
callback = Mock()
mock_dependencies["step_callback"] = callback
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
answer = AgentFinish(thought="thinking", output="test", text="final")
executor._invoke_step_callback(answer)
@@ -332,7 +375,7 @@ class TestAgentExecutor:
def test_invoke_step_callback_none(self, mock_dependencies):
"""Test step callback when none provided."""
mock_dependencies["step_callback"] = None
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
# Should not raise error
executor._invoke_step_callback(
@@ -346,7 +389,7 @@ class TestAgentExecutor:
"""Test async step callback scheduling when already in an event loop."""
callback = AsyncMock()
mock_dependencies["step_callback"] = callback
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
answer = AgentFinish(thought="thinking", output="test", text="final")
with patch("crewai.experimental.agent_executor.asyncio.run") as mock_run:
@@ -364,6 +407,7 @@ class TestStepExecutorCriticalFixes:
def mock_dependencies(self):
"""Create mock dependencies for AgentExecutor tests in this class."""
llm = Mock()
llm.stop = []
llm.supports_stop_words.return_value = True
task = Mock()
@@ -393,6 +437,7 @@ class TestStepExecutorCriticalFixes:
@pytest.fixture
def step_executor(self):
llm = Mock()
llm.stop = []
llm.supports_stop_words.return_value = True
agent = Mock()
@@ -485,7 +530,7 @@ class TestStepExecutorCriticalFixes:
mock_handle_exception.return_value = None
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor._last_parser_error = OutputParserError("test error")
initial_iterations = executor.state.iterations
@@ -500,7 +545,7 @@ class TestStepExecutorCriticalFixes:
self, mock_handle_context, mock_dependencies
):
"""Test recovery from context length error."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor._last_context_error = Exception("context too long")
initial_iterations = executor.state.iterations
@@ -513,16 +558,16 @@ class TestStepExecutorCriticalFixes:
def test_use_stop_words_property(self, mock_dependencies):
"""Test use_stop_words property."""
mock_dependencies["llm"].supports_stop_words.return_value = True
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
assert executor.use_stop_words is True
mock_dependencies["llm"].supports_stop_words.return_value = False
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
assert executor.use_stop_words is False
def test_compatibility_properties(self, mock_dependencies):
"""Test compatibility properties for mixin."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.messages = [{"role": "user", "content": "test"}]
executor.state.iterations = 5
@@ -538,6 +583,7 @@ class TestFlowErrorHandling:
def mock_dependencies(self):
"""Create mock dependencies."""
llm = Mock()
llm.stop = []
llm.supports_stop_words.return_value = True
task = Mock()
@@ -575,7 +621,7 @@ class TestFlowErrorHandling:
mock_enforce_rpm.return_value = None
mock_get_llm.side_effect = OutputParserError("parse failed")
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
result = executor.call_llm_and_parse()
assert result == "parser_error"
@@ -596,7 +642,7 @@ class TestFlowErrorHandling:
mock_get_llm.side_effect = Exception("context length")
mock_is_context_exceeded.return_value = True
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
result = executor.call_llm_and_parse()
assert result == "context_error"
@@ -610,6 +656,7 @@ class TestFlowInvoke:
def mock_dependencies(self):
"""Create mock dependencies."""
llm = Mock()
llm.stop = []
task = Mock()
task.description = "Test"
task.human_input = False
@@ -646,7 +693,7 @@ class TestFlowInvoke:
mock_dependencies,
):
"""Test successful invoke without human feedback."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
# Mock kickoff to set the final answer in state
def mock_kickoff_side_effect():
@@ -666,7 +713,7 @@ class TestFlowInvoke:
@patch.object(AgentExecutor, "kickoff")
def test_invoke_failure_no_agent_finish(self, mock_kickoff, mock_dependencies):
"""Test invoke fails without AgentFinish."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.current_answer = AgentAction(
thought="thinking", tool="test", tool_input="test", text="action text"
)
@@ -689,7 +736,7 @@ class TestFlowInvoke:
"system": "System: {input}",
"user": "User: {input} {tool_names} {tools}",
}
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
def mock_kickoff_side_effect():
executor.state.current_answer = AgentFinish(
@@ -713,6 +760,7 @@ class TestNativeToolExecution:
@pytest.fixture
def mock_dependencies(self):
llm = Mock()
llm.stop = []
llm.supports_stop_words.return_value = True
task = Mock()
@@ -734,7 +782,7 @@ class TestNativeToolExecution:
prompt = {"prompt": "Test {input} {tool_names} {tools}"}
tools_handler = Mock()
tools_handler = Mock(spec=_ToolsHandler)
tools_handler.cache = None
return {
@@ -754,7 +802,7 @@ class TestNativeToolExecution:
def test_execute_native_tool_runs_parallel_for_multiple_calls(
self, mock_dependencies
):
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
def slow_one() -> str:
time.sleep(0.2)
@@ -790,7 +838,7 @@ class TestNativeToolExecution:
def test_execute_native_tool_falls_back_to_sequential_for_result_as_answer(
self, mock_dependencies
):
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
def slow_one() -> str:
time.sleep(0.2)
@@ -832,7 +880,7 @@ class TestNativeToolExecution:
def test_execute_native_tool_result_as_answer_short_circuits_remaining_calls(
self, mock_dependencies
):
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
call_counts = {"slow_one": 0, "slow_two": 0}
def slow_one() -> str:

View File

@@ -873,7 +873,7 @@ class TestAutoPersistence:
# Create flow WITHOUT persistence
flow = TestFlow()
assert flow._persistence is None # No persistence initially
assert flow.persistence is None # No persistence initially
# kickoff should auto-create persistence when HumanFeedbackPending is raised
result = flow.kickoff()
@@ -882,11 +882,11 @@ class TestAutoPersistence:
assert isinstance(result, HumanFeedbackPending)
# Persistence should have been auto-created
assert flow._persistence is not None
assert flow.persistence is not None
# The pending feedback should be saved
flow_id = result.context.flow_id
loaded = flow._persistence.load_pending_feedback(flow_id)
loaded = flow.persistence.load_pending_feedback(flow_id)
assert loaded is not None