nested models in flow persist

This commit is contained in:
Brandon Hancock
2025-03-05 16:14:50 -05:00
parent 00eede0d5d
commit 6677c9c192
3 changed files with 184 additions and 36 deletions

91
chat.py Normal file
View File

@@ -0,0 +1,91 @@
#!/usr/bin/env python
import json
from typing import Any, Optional
from pydantic import BaseModel, Field
from crewai.flow import Flow, start
from crewai.flow.persistence.decorators import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
from crewai.llm import LLM
class Message(BaseModel):
    # One chat message. Both fields are required strings; the descriptions
    # below are attached to the fields through pydantic's Field metadata.
    role: str = Field(
        ...,
        description="The role of the message sender (e.g., 'user', 'assistant')",
    )
    content: str = Field(
        ...,
        description="The actual content/text of the message",
    )
class ChatState(BaseModel):
    # State carried by the chat flow.
    # message: the message to answer on the next run; None until one is set.
    message: Optional[Message] = Field(default=None)
    # history: messages accumulated so far; each instance gets its own list.
    history: list[Message] = Field(default_factory=list)
@persist(SQLiteFlowPersistence(), verbose=True)
class PersonalAssistantFlow(Flow[ChatState]):
    """Single-step chat flow decorated with SQLite-backed persistence.

    The flow reads the pending message from ``self.state.message``, asks the
    LLM for a reply, and appends both the message and the reply to
    ``self.state.history``.
    """

    @start()
    def chat(self):
        """Answer the pending user message and record the exchange.

        Returns:
            The text of the assistant's reply, or the literal string
            "No message provided" when no message is set on the state.
        """
        user_message = self.state.message

        # Nothing to answer: bail out before calling the LLM.
        if not user_message:
            return "No message provided"

        # Render the prior conversation as "role: content" lines.
        history_formatted = "\n".join(
            f"{msg.role}: {msg.content}" for msg in self.state.history
        )

        prompt = (
            "\n"
            "You are a helpful assistant.\n"
            f"Answer the user's question: {user_message.content}\n"
            "Just for the sake of being context-aware, this is the entire conversation history:\n"
            f"{history_formatted}\n"
            "Be friendly and helpful, yet to the point.\n"
        )

        response = LLM(model="gemini/gemini-2.0-flash", response_format=Message).call(
            prompt
        )

        if isinstance(response, str):
            try:
                llm_message = Message(**json.loads(response))
            except (TypeError, ValueError):
                # ValueError covers json.JSONDecodeError (invalid JSON) and
                # pydantic's ValidationError (missing/mistyped fields);
                # TypeError covers valid JSON that is not an object, which
                # cannot be unpacked as keyword arguments.
                llm_message = Message(
                    role="assistant",
                    content="I'm sorry, I encountered an error processing your request.",
                )
        else:
            # NOTE(review): assumes a non-string response is already a
            # Message instance -- confirm against LLM.call's return contract.
            llm_message = response

        # The early return above guarantees user_message is set here, so both
        # sides of the exchange are always recorded together.
        self.state.history.extend([user_message, llm_message])
        print("History", self.state.history)

        return llm_message.content
if __name__ == "__main__":
    # Example usage: take the message from the command-line arguments when
    # any are given, otherwise prompt for it interactively.
    import sys

    cli_words = sys.argv[1:]
    user_input = " ".join(cli_words) if cli_words else input("> ")

    flow = PersonalAssistantFlow()
    flow.state.message = Message(role="user", content=user_input)
    print(flow.kickoff())

View File

@@ -8,46 +8,84 @@ from pydantic import BaseModel
class FlowPersistence(abc.ABC):
    """Abstract base class for flow state persistence.

    This class defines the interface that all persistence implementations must
    follow. It supports both structured (Pydantic BaseModel) and unstructured
    (dict) states.
    """

    @abc.abstractmethod
    def init_db(self) -> None:
        """Initialize the persistence backend.

        This method should handle any necessary setup, such as:
        - Creating tables
        - Establishing connections
        - Setting up indexes
        """
        pass

    @abc.abstractmethod
    def save_state(
        self,
        flow_uuid: str,
        method_name: str,
        state_data: Union[Dict[str, Any], BaseModel],
    ) -> None:
        """Persist the flow state after method completion.

        Args:
            flow_uuid: Unique identifier for the flow instance
            method_name: Name of the method that just completed
            state_data: Current state data (either dict or Pydantic model)
        """
        pass

    @abc.abstractmethod
    def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
        """Load the most recent state for a given flow UUID.

        Args:
            flow_uuid: Unique identifier for the flow instance

        Returns:
            The most recent state as a dictionary, or None if no state exists
        """
        pass

    def _convert_to_dict(self, obj: Any) -> Any:
        """Recursively convert Pydantic models and containers for JSON output.

        Pydantic models become dicts; dicts, lists and tuples are rebuilt with
        their contents converted; sets and frozensets become lists. Any other
        value is returned unchanged, so a value that ``json.dumps`` cannot
        encode is still rejected at serialization time.

        Args:
            obj: The object to convert

        Returns:
            The converted object
        """
        if isinstance(obj, BaseModel):
            # model_dump() is the Pydantic v2 API; dict() is the v1 fallback.
            if hasattr(obj, "model_dump"):
                dumped = obj.model_dump()
            else:
                dumped = obj.dict()
            return {key: self._convert_to_dict(value) for key, value in dumped.items()}

        if isinstance(obj, dict):
            return {key: self._convert_to_dict(value) for key, value in obj.items()}

        if isinstance(obj, list):
            return [self._convert_to_dict(item) for item in obj]

        if isinstance(obj, tuple):
            return tuple(self._convert_to_dict(item) for item in obj)

        if isinstance(obj, (set, frozenset)):
            # Emit a list: json.dumps cannot encode sets, and converted
            # members may be dicts, which are unhashable and could not be
            # stored in a set anyway.
            return [self._convert_to_dict(item) for item in obj]

        return obj

View File

@@ -78,34 +78,53 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
# Convert state_data to dict, handling both Pydantic and dict cases
if isinstance(state_data, BaseModel):
state_dict = dict(state_data) # Use dict() for better type compatibility
elif isinstance(state_data, dict):
state_dict = state_data
else:
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
with sqlite3.connect(self.db_path) as conn:
conn.execute(
"""
INSERT INTO flow_states (
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
""",
(
flow_uuid,
method_name,
datetime.now(timezone.utc).isoformat(),
json.dumps(state_dict),
),
)
Raises:
ValueError: If state_data is neither a dict nor a BaseModel
RuntimeError: If database operations fail
TypeError: If JSON serialization fails
"""
try:
# Convert state_data to a JSON-serializable dict using the helper method
state_dict = self._convert_to_dict(state_data)
# Try to serialize to JSON to catch any serialization issues early
try:
state_json = json.dumps(state_dict)
except (TypeError, ValueError, OverflowError) as json_err:
raise TypeError(
f"Failed to serialize state to JSON: {json_err}"
) from json_err
# Perform database operation with error handling
try:
with sqlite3.connect(self.db_path) as conn:
conn.execute(
"""
INSERT INTO flow_states (
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
""",
(
flow_uuid,
method_name,
datetime.now(timezone.utc).isoformat(),
state_json,
),
)
except sqlite3.Error as db_err:
raise RuntimeError(f"Database operation failed: {db_err}") from db_err
except Exception as e:
# Log the error but don't crash the application
import logging
logging.error(f"Failed to save flow state: {e}")
# Re-raise to allow caller to handle or ignore
raise
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID.