mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 06:08:15 +00:00
fix(flow): serialize initial_state class refs as JSON schema
This commit is contained in:
@@ -45,6 +45,7 @@ from pydantic import (
|
||||
BeforeValidator,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PlainSerializer,
|
||||
PrivateAttr,
|
||||
SerializeAsAny,
|
||||
ValidationError,
|
||||
@@ -157,6 +158,37 @@ def _resolve_persistence(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
_INITIAL_STATE_CLASS_MARKER = "__crewai_pydantic_class_schema__"
|
||||
|
||||
|
||||
def _serialize_initial_state(value: Any) -> Any:
|
||||
"""Make ``initial_state`` safe for JSON checkpoint serialization.
|
||||
|
||||
``BaseModel`` class refs are emitted as their JSON schema under a sentinel
|
||||
marker key so deserialization can round-trip them back to a class.
|
||||
``BaseModel`` instances are dumped to JSON (round-trip as plain dicts,
|
||||
which ``_create_initial_state`` accepts). Bare ``type`` values that are
|
||||
not ``BaseModel`` subclasses (e.g. ``dict``) are dropped since they
|
||||
can't be represented in JSON.
|
||||
"""
|
||||
if isinstance(value, type):
|
||||
if issubclass(value, BaseModel):
|
||||
return {_INITIAL_STATE_CLASS_MARKER: value.model_json_schema()}
|
||||
return None
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(mode="json")
|
||||
return value
|
||||
|
||||
|
||||
def _deserialize_initial_state(value: Any) -> Any:
|
||||
"""Rehydrate a class ref serialized by :func:`_serialize_initial_state`."""
|
||||
if isinstance(value, dict) and _INITIAL_STATE_CLASS_MARKER in value:
|
||||
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
|
||||
|
||||
return create_model_from_schema(value[_INITIAL_STATE_CLASS_MARKER])
|
||||
return value
|
||||
|
||||
|
||||
class FlowState(BaseModel):
|
||||
"""Base model for all flow states, ensuring each state has a unique ID."""
|
||||
|
||||
@@ -908,7 +940,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
entity_type: Literal["flow"] = "flow"
|
||||
|
||||
initial_state: Any = Field(default=None)
|
||||
initial_state: Annotated[ # type: ignore[type-arg]
|
||||
type[BaseModel] | type[dict] | dict[str, Any] | BaseModel | None,
|
||||
BeforeValidator(_deserialize_initial_state),
|
||||
PlainSerializer(_serialize_initial_state, return_type=Any, when_used="json"),
|
||||
] = Field(default=None)
|
||||
name: str | None = Field(default=None)
|
||||
tracing: bool | None = Field(default=None)
|
||||
stream: bool = Field(default=False)
|
||||
|
||||
@@ -11,11 +11,12 @@ from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.crew import Crew
|
||||
from crewai.flow.flow import Flow, start
|
||||
from crewai.flow.flow import _INITIAL_STATE_CLASS_MARKER, Flow, start
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
from crewai.state.checkpoint_listener import (
|
||||
_find_checkpoint,
|
||||
@@ -310,6 +311,65 @@ class TestRuntimeStateLineage:
|
||||
assert state._branch != first
|
||||
|
||||
|
||||
class TestFlowInitialStateSerialization:
|
||||
"""Regression tests for checkpoint serialization of ``Flow.initial_state``."""
|
||||
|
||||
def test_class_ref_serializes_as_schema(self) -> None:
|
||||
class MyState(BaseModel):
|
||||
id: str = "x"
|
||||
foo: str = "bar"
|
||||
|
||||
flow = Flow(initial_state=MyState)
|
||||
state = RuntimeState(root=[flow])
|
||||
dumped = json.loads(state.model_dump_json())
|
||||
entity = dumped["entities"][0]
|
||||
wrapped = entity["initial_state"]
|
||||
assert isinstance(wrapped, dict)
|
||||
assert _INITIAL_STATE_CLASS_MARKER in wrapped
|
||||
assert wrapped[_INITIAL_STATE_CLASS_MARKER].get("title") == "MyState"
|
||||
|
||||
def test_class_ref_round_trips_to_basemodel_subclass(self) -> None:
|
||||
class MyState(BaseModel):
|
||||
id: str = "x"
|
||||
foo: str = "bar"
|
||||
|
||||
flow = Flow(initial_state=MyState)
|
||||
raw = RuntimeState(root=[flow]).model_dump_json()
|
||||
restored = RuntimeState.model_validate_json(
|
||||
raw, context={"from_checkpoint": True}
|
||||
)
|
||||
rehydrated = restored.root[0].initial_state
|
||||
assert isinstance(rehydrated, type)
|
||||
assert issubclass(rehydrated, BaseModel)
|
||||
assert set(rehydrated.model_fields.keys()) == {"id", "foo"}
|
||||
|
||||
def test_instance_serializes_as_values(self) -> None:
|
||||
class MyState(BaseModel):
|
||||
id: str = "x"
|
||||
foo: str = "bar"
|
||||
|
||||
flow = Flow(initial_state=MyState(foo="baz"))
|
||||
state = RuntimeState(root=[flow])
|
||||
dumped = json.loads(state.model_dump_json())
|
||||
entity = dumped["entities"][0]
|
||||
assert entity["initial_state"] == {"id": "x", "foo": "baz"}
|
||||
|
||||
def test_dict_passthrough(self) -> None:
|
||||
flow = Flow(initial_state={"id": "x", "foo": "bar"})
|
||||
state = RuntimeState(root=[flow])
|
||||
dumped = json.loads(state.model_dump_json())
|
||||
entity = dumped["entities"][0]
|
||||
assert entity["initial_state"] == {"id": "x", "foo": "bar"}
|
||||
|
||||
def test_dict_round_trips_as_dict(self) -> None:
|
||||
flow = Flow(initial_state={"id": "x", "foo": "bar"})
|
||||
raw = RuntimeState(root=[flow]).model_dump_json()
|
||||
restored = RuntimeState.model_validate_json(
|
||||
raw, context={"from_checkpoint": True}
|
||||
)
|
||||
assert restored.root[0].initial_state == {"id": "x", "foo": "bar"}
|
||||
|
||||
|
||||
# ---------- JsonProvider forking ----------
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user