It fully works but not clean enought

This commit is contained in:
Brandon Hancock
2024-09-11 10:45:24 -04:00
parent d67c12a5a3
commit 8664f3912b
2 changed files with 21 additions and 13 deletions

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Generic, List, TypeVar, Union, get_args from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union, get_args
from pydantic import BaseModel from pydantic import BaseModel
@@ -32,6 +32,7 @@ class Flow(Generic[TState], metaclass=FlowMeta):
_start_methods: List[str] = [] _start_methods: List[str] = []
_listeners: Dict[str, List[str]] = {} _listeners: Dict[str, List[str]] = {}
state: TState state: TState
state_class: Type[TState]
def __init__(self): def __init__(self):
self._methods: Dict[str, Callable] = {} self._methods: Dict[str, Callable] = {}
@@ -44,18 +45,15 @@ class Flow(Generic[TState], metaclass=FlowMeta):
self._methods[method_name] = getattr(self, method_name) self._methods[method_name] = getattr(self, method_name)
def _create_default_state(self) -> TState: def _create_default_state(self) -> TState:
state_type = self._get_state_type() if not hasattr(self, "state_class"):
if state_type and issubclass(state_type, BaseModel): raise AttributeError("state_class must be defined in the Flow subclass")
return state_type()
return DictWrapper() # type: ignore
def _get_state_type(self) -> type[TState] | None: if issubclass(self.state_class, BaseModel):
for base in self.__class__.__bases__: return self.state_class()
if hasattr(base, "__origin__") and base.__origin__ is Flow: elif self.state_class is dict:
args = get_args(base) return DictWrapper() # type: ignore
if args: else:
return args[0] raise TypeError(f"Unsupported state type: {self.state_class}")
return None
def run(self): def run(self):
if not self._start_methods: if not self._start_methods:

View File

@@ -8,17 +8,27 @@ class ExampleState(BaseModel):
class StructuredExampleFlow(Flow[ExampleState]): class StructuredExampleFlow(Flow[ExampleState]):
state_class = ExampleState
def __init__(self):
super().__init__()
print(f"Initial state after __init__: {self.state}") # Debug print
@start() @start()
def start_method(self): def start_method(self):
print("Starting the structured flow") print("Starting the structured flow")
print(f"State in start_method: {self.state}") # Debug print
self.state.message = "Hello from structured flow" self.state.message = "Hello from structured flow"
print(f"State after start_method: {self.state}") # Debug print
return "Start result" return "Start result"
@listen(start_method) @listen(start_method)
def second_method(self, result): def second_method(self, result):
print(f"Second method, received: {result}") print(f"Second method, received: {result}")
print(f"State before increment: {self.state}") # Debug print
self.state.counter += 1 self.state.counter += 1
self.state.message = "Hello from structured flow" self.state.message += " - updated"
print(f"State after second_method: {self.state}") # Debug print
return "Second result" return "Second result"