mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
nested models in flow persist
This commit is contained in:
91
chat.py
Normal file
91
chat.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
import json
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from crewai.flow import Flow, start
|
||||||
|
from crewai.flow.persistence.decorators import persist
|
||||||
|
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||||
|
from crewai.llm import LLM
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
role: str = Field(
|
||||||
|
description="The role of the message sender (e.g., 'user', 'assistant')"
|
||||||
|
)
|
||||||
|
content: str = Field(description="The actual content/text of the message")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatState(BaseModel):
|
||||||
|
message: Optional[Message] = None
|
||||||
|
history: list[Message] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@persist(SQLiteFlowPersistence(), verbose=True)
|
||||||
|
class PersonalAssistantFlow(Flow[ChatState]):
|
||||||
|
@start()
|
||||||
|
def chat(self):
|
||||||
|
user_message_pydantic = self.state.message
|
||||||
|
|
||||||
|
# Safety check for None message
|
||||||
|
if not user_message_pydantic:
|
||||||
|
return "No message provided"
|
||||||
|
|
||||||
|
# Format history for prompt
|
||||||
|
history_formatted = "\n".join(
|
||||||
|
[f"{msg.role}: {msg.content}" for msg in self.state.history]
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
You are a helpful assistant.
|
||||||
|
Answer the user's question: {user_message_pydantic.content}
|
||||||
|
|
||||||
|
Just for the sake of being context-aware, this is the entire conversation history:
|
||||||
|
{history_formatted}
|
||||||
|
|
||||||
|
Be friendly and helpful, yet to the point.
|
||||||
|
"""
|
||||||
|
|
||||||
|
response = LLM(model="gemini/gemini-2.0-flash", response_format=Message).call(
|
||||||
|
prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the response
|
||||||
|
if isinstance(response, str):
|
||||||
|
try:
|
||||||
|
llm_response_json = json.loads(response)
|
||||||
|
llm_response_pydantic = Message(**llm_response_json)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Fallback if response isn't valid JSON
|
||||||
|
llm_response_pydantic = Message(
|
||||||
|
role="assistant",
|
||||||
|
content="I'm sorry, I encountered an error processing your request.",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If response is already a Message object
|
||||||
|
llm_response_pydantic = response
|
||||||
|
|
||||||
|
# Update history - with type safety
|
||||||
|
if user_message_pydantic: # Ensure message is not None before adding to history
|
||||||
|
self.state.history.append(user_message_pydantic)
|
||||||
|
self.state.history.append(llm_response_pydantic)
|
||||||
|
|
||||||
|
print("History", self.state.history)
|
||||||
|
return llm_response_pydantic.content
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Example usage
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
user_input = " ".join(sys.argv[1:])
|
||||||
|
else:
|
||||||
|
user_input = input("> ")
|
||||||
|
|
||||||
|
flow = PersonalAssistantFlow()
|
||||||
|
flow.state.message = Message(role="user", content=user_input)
|
||||||
|
|
||||||
|
response = flow.kickoff()
|
||||||
|
print(response)
|
||||||
@@ -29,7 +29,7 @@ class FlowPersistence(abc.ABC):
|
|||||||
self,
|
self,
|
||||||
flow_uuid: str,
|
flow_uuid: str,
|
||||||
method_name: str,
|
method_name: str,
|
||||||
state_data: Union[Dict[str, Any], BaseModel]
|
state_data: Union[Dict[str, Any], BaseModel],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Persist the flow state after method completion.
|
"""Persist the flow state after method completion.
|
||||||
|
|
||||||
@@ -51,3 +51,41 @@ class FlowPersistence(abc.ABC):
|
|||||||
The most recent state as a dictionary, or None if no state exists
|
The most recent state as a dictionary, or None if no state exists
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _convert_to_dict(self, obj: Any) -> Any:
|
||||||
|
"""Recursively convert Pydantic models to dictionaries.
|
||||||
|
|
||||||
|
This helper method ensures all Pydantic models in the state are
|
||||||
|
properly converted to dictionaries for JSON serialization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: The object to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON-serializable version of the object
|
||||||
|
"""
|
||||||
|
if isinstance(obj, BaseModel):
|
||||||
|
# Convert Pydantic model to dict
|
||||||
|
if hasattr(obj, "model_dump"):
|
||||||
|
# Pydantic v2
|
||||||
|
obj_dict = obj.model_dump()
|
||||||
|
else:
|
||||||
|
# Pydantic v1
|
||||||
|
obj_dict = obj.dict()
|
||||||
|
# Recursively convert any nested models
|
||||||
|
return {k: self._convert_to_dict(v) for k, v in obj_dict.items()}
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
# Recursively convert dict values
|
||||||
|
return {k: self._convert_to_dict(v) for k, v in obj.items()}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
# Recursively convert list items
|
||||||
|
return [self._convert_to_dict(item) for item in obj]
|
||||||
|
elif isinstance(obj, tuple):
|
||||||
|
# Recursively convert tuple items
|
||||||
|
return tuple(self._convert_to_dict(item) for item in obj)
|
||||||
|
elif isinstance(obj, set):
|
||||||
|
# Recursively convert set items
|
||||||
|
return {self._convert_to_dict(item) for item in obj}
|
||||||
|
else:
|
||||||
|
# Return primitive types as is
|
||||||
|
return obj
|
||||||
|
|||||||
@@ -78,34 +78,53 @@ class SQLiteFlowPersistence(FlowPersistence):
|
|||||||
flow_uuid: Unique identifier for the flow instance
|
flow_uuid: Unique identifier for the flow instance
|
||||||
method_name: Name of the method that just completed
|
method_name: Name of the method that just completed
|
||||||
state_data: Current state data (either dict or Pydantic model)
|
state_data: Current state data (either dict or Pydantic model)
|
||||||
"""
|
|
||||||
# Convert state_data to dict, handling both Pydantic and dict cases
|
|
||||||
if isinstance(state_data, BaseModel):
|
|
||||||
state_dict = dict(state_data) # Use dict() for better type compatibility
|
|
||||||
elif isinstance(state_data, dict):
|
|
||||||
state_dict = state_data
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
Raises:
|
||||||
conn.execute(
|
ValueError: If state_data is neither a dict nor a BaseModel
|
||||||
"""
|
RuntimeError: If database operations fail
|
||||||
INSERT INTO flow_states (
|
TypeError: If JSON serialization fails
|
||||||
flow_uuid,
|
"""
|
||||||
method_name,
|
try:
|
||||||
timestamp,
|
# Convert state_data to a JSON-serializable dict using the helper method
|
||||||
state_json
|
state_dict = self._convert_to_dict(state_data)
|
||||||
) VALUES (?, ?, ?, ?)
|
|
||||||
""",
|
# Try to serialize to JSON to catch any serialization issues early
|
||||||
(
|
try:
|
||||||
flow_uuid,
|
state_json = json.dumps(state_dict)
|
||||||
method_name,
|
except (TypeError, ValueError, OverflowError) as json_err:
|
||||||
datetime.now(timezone.utc).isoformat(),
|
raise TypeError(
|
||||||
json.dumps(state_dict),
|
f"Failed to serialize state to JSON: {json_err}"
|
||||||
),
|
) from json_err
|
||||||
)
|
|
||||||
|
# Perform database operation with error handling
|
||||||
|
try:
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO flow_states (
|
||||||
|
flow_uuid,
|
||||||
|
method_name,
|
||||||
|
timestamp,
|
||||||
|
state_json
|
||||||
|
) VALUES (?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
flow_uuid,
|
||||||
|
method_name,
|
||||||
|
datetime.now(timezone.utc).isoformat(),
|
||||||
|
state_json,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except sqlite3.Error as db_err:
|
||||||
|
raise RuntimeError(f"Database operation failed: {db_err}") from db_err
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Log the error but don't crash the application
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.error(f"Failed to save flow state: {e}")
|
||||||
|
# Re-raise to allow caller to handle or ignore
|
||||||
|
raise
|
||||||
|
|
||||||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
|
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Load the most recent state for a given flow UUID.
|
"""Load the most recent state for a given flow UUID.
|
||||||
|
|||||||
Reference in New Issue
Block a user