Everything is workign

This commit is contained in:
Brandon Hancock
2024-09-11 13:01:36 -04:00
parent a4a14df72e
commit 3a266d6b40
3 changed files with 33 additions and 41 deletions

View File

@@ -1,8 +1,8 @@
from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union, cast from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union
from pydantic import BaseModel from pydantic import BaseModel
TState = TypeVar("TState", bound=Union[BaseModel, Dict[str, Any]]) T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
class FlowMeta(type): class FlowMeta(type):
@@ -25,17 +25,25 @@ class FlowMeta(type):
setattr(cls, "_start_methods", start_methods) setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners) setattr(cls, "_listeners", listeners)
# Inject the state type hint
if "initial_state" in dct:
initial_state = dct["initial_state"]
if isinstance(initial_state, type) and issubclass(initial_state, BaseModel):
cls.__annotations__["state"] = initial_state
elif isinstance(initial_state, dict):
cls.__annotations__["state"] = Dict[str, Any]
return cls return cls
class Flow(Generic[TState], metaclass=FlowMeta): class Flow(Generic[T], metaclass=FlowMeta):
_start_methods: List[str] = [] _start_methods: List[str] = []
_listeners: Dict[str, List[str]] = {} _listeners: Dict[str, List[str]] = {}
state_class: Type[TState] # Class-level state_class defined once in the subclass initial_state: Union[Type[T], T, None] = None
def __init__(self): def __init__(self):
self._methods: Dict[str, Callable] = {} self._methods: Dict[str, Callable] = {}
self._state: TState = self._create_default_state() self._state = self._create_initial_state()
for method_name in dir(self): for method_name in dir(self):
if callable(getattr(self, method_name)) and not method_name.startswith( if callable(getattr(self, method_name)) and not method_name.startswith(
@@ -43,21 +51,17 @@ class Flow(Generic[TState], metaclass=FlowMeta):
): ):
self._methods[method_name] = getattr(self, method_name) self._methods[method_name] = getattr(self, method_name)
@property def _create_initial_state(self) -> T:
def state(self) -> TState: if self.initial_state is None:
"""Ensure state has the correct type.""" return {} # type: ignore
return self._state elif isinstance(self.initial_state, type):
return self.initial_state()
def _create_default_state(self) -> TState:
if not hasattr(self, "state_class"):
raise AttributeError("state_class must be defined in the Flow subclass")
if issubclass(self.state_class, BaseModel):
return self.state_class() # Automatically initialize with Pydantic defaults
elif self.state_class is dict:
return cast(TState, DictWrapper()) # Cast to TState for DictWrapper
else: else:
raise TypeError(f"Unsupported state type: {self.state_class}") return self.initial_state
@property
def state(self) -> T:
return self._state
def run(self): def run(self):
if not self._start_methods: if not self._start_methods:
@@ -78,14 +82,6 @@ class Flow(Generic[TState], metaclass=FlowMeta):
return return
class DictWrapper(Dict[str, Any]):
def __getattr__(self, name: str) -> Any:
return self.get(name)
def __setattr__(self, name: str, value: Any) -> None:
self[name] = value
def start(): def start():
def decorator(func): def decorator(func):
func.__is_start_method__ = True func.__is_start_method__ = True

View File

@@ -8,27 +8,23 @@ class ExampleState(BaseModel):
class StructuredExampleFlow(Flow[ExampleState]): class StructuredExampleFlow(Flow[ExampleState]):
state_class = ExampleState initial_state = ExampleState
def __init__(self):
super().__init__()
print(f"Initial state after __init__: {self.state}") # Debug print
@start() @start()
def start_method(self): def start_method(self):
print("Starting the structured flow") print("Starting the structured flow")
print(f"State in start_method: {self.state}") # Debug print print(f"State in start_method: {self.state}")
self.state.message = "Hello from structured flow" self.state.message = "Hello from structured flow"
print(f"State after start_method: {self.state}") # Debug print print(f"State after start_method: {self.state}")
return "Start result" return "Start result"
@listen(start_method) @listen(start_method)
def second_method(self, result): def second_method(self, result):
print(f"Second method, received: {result}") print(f"Second method, received: {result}")
print(f"State before increment: {self.state}") # Debug print print(f"State before increment: {self.state}")
self.state.counter += 1 self.state.counter += 1
self.state.message += " - updated" self.state.message += " - updated"
print(f"State after second_method: {self.state}") # Debug print print(f"State after second_method: {self.state}")
return "Second result" return "Second result"

View File

@@ -5,21 +5,21 @@ class FlexibleExampleFlow(Flow):
@start() @start()
def start_method(self): def start_method(self):
print("Starting the flexible flow") print("Starting the flexible flow")
self.state.counter = 1 self.state["counter"] = 1
return "Start result" return "Start result"
@listen(start_method) @listen(start_method)
def second_method(self, result): def second_method(self, result):
print(f"Second method, received: {result}") print(f"Second method, received: {result}")
self.state.counter += 1 self.state["counter"] += 1
self.state.message = "Hello from flexible flow" self.state["message"] = "Hello from flexible flow"
return "Second result" return "Second result"
@listen(second_method) @listen(second_method)
def third_method(self, result): def third_method(self, result):
print(f"Third method, received: {result}") print(f"Third method, received: {result}")
print(f"Final counter value: {self.state.counter}") print(f"Final counter value: {self.state["counter"]}")
print(f"Final message: {self.state.message}") print(f"Final message: {self.state["message"]}")
return "Third result" return "Third result"