mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
nested models in flow persist
This commit is contained in:
91
chat.py
Normal file
91
chat.py
Normal file
@@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.flow.persistence.decorators import persist
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
from crewai.llm import LLM
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str = Field(
|
||||
description="The role of the message sender (e.g., 'user', 'assistant')"
|
||||
)
|
||||
content: str = Field(description="The actual content/text of the message")
|
||||
|
||||
|
||||
class ChatState(BaseModel):
|
||||
message: Optional[Message] = None
|
||||
history: list[Message] = Field(default_factory=list)
|
||||
|
||||
|
||||
@persist(SQLiteFlowPersistence(), verbose=True)
|
||||
class PersonalAssistantFlow(Flow[ChatState]):
|
||||
@start()
|
||||
def chat(self):
|
||||
user_message_pydantic = self.state.message
|
||||
|
||||
# Safety check for None message
|
||||
if not user_message_pydantic:
|
||||
return "No message provided"
|
||||
|
||||
# Format history for prompt
|
||||
history_formatted = "\n".join(
|
||||
[f"{msg.role}: {msg.content}" for msg in self.state.history]
|
||||
)
|
||||
|
||||
prompt = f"""
|
||||
You are a helpful assistant.
|
||||
Answer the user's question: {user_message_pydantic.content}
|
||||
|
||||
Just for the sake of being context-aware, this is the entire conversation history:
|
||||
{history_formatted}
|
||||
|
||||
Be friendly and helpful, yet to the point.
|
||||
"""
|
||||
|
||||
response = LLM(model="gemini/gemini-2.0-flash", response_format=Message).call(
|
||||
prompt
|
||||
)
|
||||
|
||||
# Parse the response
|
||||
if isinstance(response, str):
|
||||
try:
|
||||
llm_response_json = json.loads(response)
|
||||
llm_response_pydantic = Message(**llm_response_json)
|
||||
except json.JSONDecodeError:
|
||||
# Fallback if response isn't valid JSON
|
||||
llm_response_pydantic = Message(
|
||||
role="assistant",
|
||||
content="I'm sorry, I encountered an error processing your request.",
|
||||
)
|
||||
else:
|
||||
# If response is already a Message object
|
||||
llm_response_pydantic = response
|
||||
|
||||
# Update history - with type safety
|
||||
if user_message_pydantic: # Ensure message is not None before adding to history
|
||||
self.state.history.append(user_message_pydantic)
|
||||
self.state.history.append(llm_response_pydantic)
|
||||
|
||||
print("History", self.state.history)
|
||||
return llm_response_pydantic.content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
import sys
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
user_input = " ".join(sys.argv[1:])
|
||||
else:
|
||||
user_input = input("> ")
|
||||
|
||||
flow = PersonalAssistantFlow()
|
||||
flow.state.message = Message(role="user", content=user_input)
|
||||
|
||||
response = flow.kickoff()
|
||||
print(response)
|
||||
@@ -8,46 +8,84 @@ from pydantic import BaseModel
|
||||
|
||||
class FlowPersistence(abc.ABC):
|
||||
"""Abstract base class for flow state persistence.
|
||||
|
||||
|
||||
This class defines the interface that all persistence implementations must follow.
|
||||
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
|
||||
"""
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
def init_db(self) -> None:
|
||||
"""Initialize the persistence backend.
|
||||
|
||||
|
||||
This method should handle any necessary setup, such as:
|
||||
- Creating tables
|
||||
- Establishing connections
|
||||
- Setting up indexes
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
def save_state(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_data: Union[Dict[str, Any], BaseModel]
|
||||
state_data: Union[Dict[str, Any], BaseModel],
|
||||
) -> None:
|
||||
"""Persist the flow state after method completion.
|
||||
|
||||
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
method_name: Name of the method that just completed
|
||||
state_data: Current state data (either dict or Pydantic model)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
|
||||
"""Load the most recent state for a given flow UUID.
|
||||
|
||||
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
|
||||
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
"""
|
||||
pass
|
||||
|
||||
def _convert_to_dict(self, obj: Any) -> Any:
|
||||
"""Recursively convert Pydantic models to dictionaries.
|
||||
|
||||
This helper method ensures all Pydantic models in the state are
|
||||
properly converted to dictionaries for JSON serialization.
|
||||
|
||||
Args:
|
||||
obj: The object to convert
|
||||
|
||||
Returns:
|
||||
A JSON-serializable version of the object
|
||||
"""
|
||||
if isinstance(obj, BaseModel):
|
||||
# Convert Pydantic model to dict
|
||||
if hasattr(obj, "model_dump"):
|
||||
# Pydantic v2
|
||||
obj_dict = obj.model_dump()
|
||||
else:
|
||||
# Pydantic v1
|
||||
obj_dict = obj.dict()
|
||||
# Recursively convert any nested models
|
||||
return {k: self._convert_to_dict(v) for k, v in obj_dict.items()}
|
||||
elif isinstance(obj, dict):
|
||||
# Recursively convert dict values
|
||||
return {k: self._convert_to_dict(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
# Recursively convert list items
|
||||
return [self._convert_to_dict(item) for item in obj]
|
||||
elif isinstance(obj, tuple):
|
||||
# Recursively convert tuple items
|
||||
return tuple(self._convert_to_dict(item) for item in obj)
|
||||
elif isinstance(obj, set):
|
||||
# Recursively convert set items
|
||||
return {self._convert_to_dict(item) for item in obj}
|
||||
else:
|
||||
# Return primitive types as is
|
||||
return obj
|
||||
|
||||
@@ -78,34 +78,53 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
method_name: Name of the method that just completed
|
||||
state_data: Current state data (either dict or Pydantic model)
|
||||
"""
|
||||
# Convert state_data to dict, handling both Pydantic and dict cases
|
||||
if isinstance(state_data, BaseModel):
|
||||
state_dict = dict(state_data) # Use dict() for better type compatibility
|
||||
elif isinstance(state_data, dict):
|
||||
state_dict = state_data
|
||||
else:
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
flow_uuid,
|
||||
method_name,
|
||||
timestamp,
|
||||
state_json
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
flow_uuid,
|
||||
method_name,
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
json.dumps(state_dict),
|
||||
),
|
||||
)
|
||||
Raises:
|
||||
ValueError: If state_data is neither a dict nor a BaseModel
|
||||
RuntimeError: If database operations fail
|
||||
TypeError: If JSON serialization fails
|
||||
"""
|
||||
try:
|
||||
# Convert state_data to a JSON-serializable dict using the helper method
|
||||
state_dict = self._convert_to_dict(state_data)
|
||||
|
||||
# Try to serialize to JSON to catch any serialization issues early
|
||||
try:
|
||||
state_json = json.dumps(state_dict)
|
||||
except (TypeError, ValueError, OverflowError) as json_err:
|
||||
raise TypeError(
|
||||
f"Failed to serialize state to JSON: {json_err}"
|
||||
) from json_err
|
||||
|
||||
# Perform database operation with error handling
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
flow_uuid,
|
||||
method_name,
|
||||
timestamp,
|
||||
state_json
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
flow_uuid,
|
||||
method_name,
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
state_json,
|
||||
),
|
||||
)
|
||||
except sqlite3.Error as db_err:
|
||||
raise RuntimeError(f"Database operation failed: {db_err}") from db_err
|
||||
|
||||
except Exception as e:
|
||||
# Log the error but don't crash the application
|
||||
import logging
|
||||
|
||||
logging.error(f"Failed to save flow state: {e}")
|
||||
# Re-raise to allow caller to handle or ignore
|
||||
raise
|
||||
|
||||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
|
||||
"""Load the most recent state for a given flow UUID.
|
||||
|
||||
Reference in New Issue
Block a user