mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fix: allow persist Flow state with BaseModel entries (#3276)
Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
This commit is contained in:
@@ -81,7 +81,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
|||||||
"""
|
"""
|
||||||
# Convert state_data to dict, handling both Pydantic and dict cases
|
# Convert state_data to dict, handling both Pydantic and dict cases
|
||||||
if isinstance(state_data, BaseModel):
|
if isinstance(state_data, BaseModel):
|
||||||
state_dict = dict(state_data) # Use dict() for better type compatibility
|
state_dict = state_data.model_dump()
|
||||||
elif isinstance(state_data, dict):
|
elif isinstance(state_data, dict):
|
||||||
state_dict = state_data
|
state_dict = state_data
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
"""Test flow state persistence functionality."""
|
"""Test flow state persistence functionality."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from crewai.flow.flow import Flow, FlowState, listen, start
|
from crewai.flow.flow import Flow, FlowState, listen, start
|
||||||
@@ -208,3 +207,44 @@ def test_persist_decorator_verbose_logging(tmp_path, caplog):
|
|||||||
flow = VerboseFlow(persistence=persistence)
|
flow = VerboseFlow(persistence=persistence)
|
||||||
flow.kickoff()
|
flow.kickoff()
|
||||||
assert "Saving flow state" in caplog.text
|
assert "Saving flow state" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_persistence_with_base_model(tmp_path):
|
||||||
|
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||||
|
persistence = SQLiteFlowPersistence(db_path)
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
role: str
|
||||||
|
type: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
class State(FlowState):
|
||||||
|
latest_message: Message | None = None
|
||||||
|
history: List[Message] = []
|
||||||
|
|
||||||
|
@persist(persistence)
|
||||||
|
class BaseModelFlow(Flow[State]):
|
||||||
|
initial_state = State(latest_message=None, history=[])
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def init_step(self):
|
||||||
|
self.state.latest_message = Message(role="user", type="text", content="Hello, World!")
|
||||||
|
self.state.history.append(self.state.latest_message)
|
||||||
|
|
||||||
|
flow = BaseModelFlow(persistence=persistence)
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
latest_message = flow.state.latest_message
|
||||||
|
message, = flow.state.history
|
||||||
|
|
||||||
|
assert latest_message is not None
|
||||||
|
assert latest_message.role == "user"
|
||||||
|
assert latest_message.type == "text"
|
||||||
|
assert latest_message.content == "Hello, World!"
|
||||||
|
|
||||||
|
assert len(flow.state.history) == 1
|
||||||
|
assert message.role == "user"
|
||||||
|
assert message.type == "text"
|
||||||
|
assert message.content == "Hello, World!"
|
||||||
|
assert isinstance(flow.state, State)
|
||||||
|
|||||||
Reference in New Issue
Block a user