From 6677c9c192bceb504a29f653f2b1609af33be09e Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 5 Mar 2025 16:14:50 -0500 Subject: [PATCH] nested models in flow persist --- chat.py | 91 +++++++++++++++++++++++++++ src/crewai/flow/persistence/base.py | 56 ++++++++++++++--- src/crewai/flow/persistence/sqlite.py | 73 +++++++++++++-------- 3 files changed, 184 insertions(+), 36 deletions(-) create mode 100644 chat.py diff --git a/chat.py b/chat.py new file mode 100644 index 000000000..3614ee52b --- /dev/null +++ b/chat.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +import json +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from crewai.flow import Flow, start +from crewai.flow.persistence.decorators import persist +from crewai.flow.persistence.sqlite import SQLiteFlowPersistence +from crewai.llm import LLM + + +class Message(BaseModel): + role: str = Field( + description="The role of the message sender (e.g., 'user', 'assistant')" + ) + content: str = Field(description="The actual content/text of the message") + + +class ChatState(BaseModel): + message: Optional[Message] = None + history: list[Message] = Field(default_factory=list) + + +@persist(SQLiteFlowPersistence(), verbose=True) +class PersonalAssistantFlow(Flow[ChatState]): + @start() + def chat(self): + user_message_pydantic = self.state.message + + # Safety check for None message + if not user_message_pydantic: + return "No message provided" + + # Format history for prompt + history_formatted = "\n".join( + [f"{msg.role}: {msg.content}" for msg in self.state.history] + ) + + prompt = f""" + You are a helpful assistant. + Answer the user's question: {user_message_pydantic.content} + + Just for the sake of being context-aware, this is the entire conversation history: + {history_formatted} + + Be friendly and helpful, yet to the point. + """ + + response = LLM(model="gemini/gemini-2.0-flash", response_format=Message).call( + prompt + ) + + # Parse the response + if isinstance(response, str): + try: + llm_response_json = json.loads(response) + llm_response_pydantic = Message(**llm_response_json) + except json.JSONDecodeError: + # Fallback if response isn't valid JSON + llm_response_pydantic = Message( + role="assistant", + content="I'm sorry, I encountered an error processing your request.", + ) + else: + # If response is already a Message object + llm_response_pydantic = response + + # Update history - with type safety + if user_message_pydantic: # Ensure message is not None before adding to history + self.state.history.append(user_message_pydantic) + self.state.history.append(llm_response_pydantic) + + print("History", self.state.history) + return llm_response_pydantic.content + + +if __name__ == "__main__": + # Example usage + import sys + + if len(sys.argv) > 1: + user_input = " ".join(sys.argv[1:]) + else: + user_input = input("> ") + + flow = PersonalAssistantFlow() + flow.state.message = Message(role="user", content=user_input) + + response = flow.kickoff() + print(response) diff --git a/src/crewai/flow/persistence/base.py b/src/crewai/flow/persistence/base.py index c926f6f34..d36a59ebe 100644 --- a/src/crewai/flow/persistence/base.py +++ b/src/crewai/flow/persistence/base.py @@ -8,46 +8,84 @@ from pydantic import BaseModel class FlowPersistence(abc.ABC): """Abstract base class for flow state persistence. - + This class defines the interface that all persistence implementations must follow. It supports both structured (Pydantic BaseModel) and unstructured (dict) states. 
""" - + @abc.abstractmethod def init_db(self) -> None: """Initialize the persistence backend. - + This method should handle any necessary setup, such as: - Creating tables - Establishing connections - Setting up indexes """ pass - + @abc.abstractmethod def save_state( self, flow_uuid: str, method_name: str, - state_data: Union[Dict[str, Any], BaseModel] + state_data: Union[Dict[str, Any], BaseModel], ) -> None: """Persist the flow state after method completion. - + Args: flow_uuid: Unique identifier for the flow instance method_name: Name of the method that just completed state_data: Current state data (either dict or Pydantic model) """ pass - + @abc.abstractmethod def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: """Load the most recent state for a given flow UUID. - + Args: flow_uuid: Unique identifier for the flow instance - + Returns: The most recent state as a dictionary, or None if no state exists """ pass + + def _convert_to_dict(self, obj: Any) -> Any: + """Recursively convert Pydantic models to dictionaries. + + This helper method ensures all Pydantic models in the state are + properly converted to dictionaries for JSON serialization. + + Args: + obj: The object to convert + + Returns: + A JSON-serializable version of the object + """ + if isinstance(obj, BaseModel): + # Convert Pydantic model to dict + if hasattr(obj, "model_dump"): + # Pydantic v2 + obj_dict = obj.model_dump() + else: + # Pydantic v1 + obj_dict = obj.dict() + # Recursively convert any nested models + return {k: self._convert_to_dict(v) for k, v in obj_dict.items()} + elif isinstance(obj, dict): + # Recursively convert dict values + return {k: self._convert_to_dict(v) for k, v in obj.items()} + elif isinstance(obj, list): + # Recursively convert list items + return [self._convert_to_dict(item) for item in obj] + elif isinstance(obj, tuple): + # Recursively convert tuple items + return tuple(self._convert_to_dict(item) for item in obj) + elif isinstance(obj, set): + # Recursively convert set items + return {self._convert_to_dict(item) for item in obj} + else: + # Return primitive types as is + return obj diff --git a/src/crewai/flow/persistence/sqlite.py b/src/crewai/flow/persistence/sqlite.py index 21e906afd..a0b678d66 100644 --- a/src/crewai/flow/persistence/sqlite.py +++ b/src/crewai/flow/persistence/sqlite.py @@ -78,34 +78,53 @@ class SQLiteFlowPersistence(FlowPersistence): flow_uuid: Unique identifier for the flow instance method_name: Name of the method that just completed state_data: Current state data (either dict or Pydantic model) - """ - # Convert state_data to dict, handling both Pydantic and dict cases - if isinstance(state_data, BaseModel): - state_dict = dict(state_data) # Use dict() for better type compatibility - elif isinstance(state_data, dict): - state_dict = state_data - else: - raise ValueError( - f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" - ) - with sqlite3.connect(self.db_path) as conn: - conn.execute( - """ - INSERT INTO flow_states ( - flow_uuid, - method_name, - timestamp, - state_json - ) VALUES (?, ?, ?, ?) 
- """, - ( - flow_uuid, - method_name, - datetime.now(timezone.utc).isoformat(), - json.dumps(state_dict), - ), - ) + Raises: + ValueError: If state_data is neither a dict nor a BaseModel + RuntimeError: If database operations fail + TypeError: If JSON serialization fails + """ + try: + # Convert state_data to a JSON-serializable dict using the helper method + state_dict = self._convert_to_dict(state_data) + + # Try to serialize to JSON to catch any serialization issues early + try: + state_json = json.dumps(state_dict) + except (TypeError, ValueError, OverflowError) as json_err: + raise TypeError( + f"Failed to serialize state to JSON: {json_err}" + ) from json_err + + # Perform database operation with error handling + try: + with sqlite3.connect(self.db_path) as conn: + conn.execute( + """ + INSERT INTO flow_states ( + flow_uuid, + method_name, + timestamp, + state_json + ) VALUES (?, ?, ?, ?) + """, + ( + flow_uuid, + method_name, + datetime.now(timezone.utc).isoformat(), + state_json, + ), + ) + except sqlite3.Error as db_err: + raise RuntimeError(f"Database operation failed: {db_err}") from db_err + + except Exception as e: + # Log the error but don't crash the application + import logging + + logging.error(f"Failed to save flow state: {e}") + # Re-raise to allow caller to handle or ignore + raise def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: """Load the most recent state for a given flow UUID.