Everything is workign

This commit is contained in:
Brandon Hancock
2024-09-11 13:01:36 -04:00
parent a4a14df72e
commit 3a266d6b40
3 changed files with 33 additions and 41 deletions

View File

@@ -1,8 +1,8 @@
from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union, cast
from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union
from pydantic import BaseModel
TState = TypeVar("TState", bound=Union[BaseModel, Dict[str, Any]])
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
class FlowMeta(type):
@@ -25,17 +25,25 @@ class FlowMeta(type):
setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners)
# Inject the state type hint
if "initial_state" in dct:
initial_state = dct["initial_state"]
if isinstance(initial_state, type) and issubclass(initial_state, BaseModel):
cls.__annotations__["state"] = initial_state
elif isinstance(initial_state, dict):
cls.__annotations__["state"] = Dict[str, Any]
return cls
class Flow(Generic[TState], metaclass=FlowMeta):
class Flow(Generic[T], metaclass=FlowMeta):
_start_methods: List[str] = []
_listeners: Dict[str, List[str]] = {}
state_class: Type[TState] # Class-level state_class defined once in the subclass
initial_state: Union[Type[T], T, None] = None
def __init__(self):
self._methods: Dict[str, Callable] = {}
self._state: TState = self._create_default_state()
self._state = self._create_initial_state()
for method_name in dir(self):
if callable(getattr(self, method_name)) and not method_name.startswith(
@@ -43,21 +51,17 @@ class Flow(Generic[TState], metaclass=FlowMeta):
):
self._methods[method_name] = getattr(self, method_name)
@property
def state(self) -> TState:
"""Ensure state has the correct type."""
return self._state
def _create_default_state(self) -> TState:
if not hasattr(self, "state_class"):
raise AttributeError("state_class must be defined in the Flow subclass")
if issubclass(self.state_class, BaseModel):
return self.state_class() # Automatically initialize with Pydantic defaults
elif self.state_class is dict:
return cast(TState, DictWrapper()) # Cast to TState for DictWrapper
def _create_initial_state(self) -> T:
if self.initial_state is None:
return {} # type: ignore
elif isinstance(self.initial_state, type):
return self.initial_state()
else:
raise TypeError(f"Unsupported state type: {self.state_class}")
return self.initial_state
@property
def state(self) -> T:
return self._state
def run(self):
if not self._start_methods:
@@ -78,14 +82,6 @@ class Flow(Generic[TState], metaclass=FlowMeta):
return
class DictWrapper(Dict[str, Any]):
def __getattr__(self, name: str) -> Any:
return self.get(name)
def __setattr__(self, name: str, value: Any) -> None:
self[name] = value
def start():
def decorator(func):
func.__is_start_method__ = True