Everything is workign

This commit is contained in:
Brandon Hancock
2024-09-11 13:01:36 -04:00
parent a4a14df72e
commit 3a266d6b40
3 changed files with 33 additions and 41 deletions

View File

@@ -1,8 +1,8 @@
from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union, cast
from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union
from pydantic import BaseModel
TState = TypeVar("TState", bound=Union[BaseModel, Dict[str, Any]])
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
class FlowMeta(type):
@@ -25,17 +25,25 @@ class FlowMeta(type):
setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners)
# Inject the state type hint
if "initial_state" in dct:
initial_state = dct["initial_state"]
if isinstance(initial_state, type) and issubclass(initial_state, BaseModel):
cls.__annotations__["state"] = initial_state
elif isinstance(initial_state, dict):
cls.__annotations__["state"] = Dict[str, Any]
return cls
class Flow(Generic[TState], metaclass=FlowMeta):
class Flow(Generic[T], metaclass=FlowMeta):
_start_methods: List[str] = []
_listeners: Dict[str, List[str]] = {}
state_class: Type[TState] # Class-level state_class defined once in the subclass
initial_state: Union[Type[T], T, None] = None
def __init__(self):
self._methods: Dict[str, Callable] = {}
self._state: TState = self._create_default_state()
self._state = self._create_initial_state()
for method_name in dir(self):
if callable(getattr(self, method_name)) and not method_name.startswith(
@@ -43,21 +51,17 @@ class Flow(Generic[TState], metaclass=FlowMeta):
):
self._methods[method_name] = getattr(self, method_name)
@property
def state(self) -> TState:
"""Ensure state has the correct type."""
return self._state
def _create_default_state(self) -> TState:
if not hasattr(self, "state_class"):
raise AttributeError("state_class must be defined in the Flow subclass")
if issubclass(self.state_class, BaseModel):
return self.state_class() # Automatically initialize with Pydantic defaults
elif self.state_class is dict:
return cast(TState, DictWrapper()) # Cast to TState for DictWrapper
def _create_initial_state(self) -> T:
if self.initial_state is None:
return {} # type: ignore
elif isinstance(self.initial_state, type):
return self.initial_state()
else:
raise TypeError(f"Unsupported state type: {self.state_class}")
return self.initial_state
@property
def state(self) -> T:
return self._state
def run(self):
if not self._start_methods:
@@ -78,14 +82,6 @@ class Flow(Generic[TState], metaclass=FlowMeta):
return
class DictWrapper(Dict[str, Any]):
def __getattr__(self, name: str) -> Any:
return self.get(name)
def __setattr__(self, name: str, value: Any) -> None:
self[name] = value
def start():
def decorator(func):
func.__is_start_method__ = True

View File

@@ -8,27 +8,23 @@ class ExampleState(BaseModel):
class StructuredExampleFlow(Flow[ExampleState]):
state_class = ExampleState
def __init__(self):
super().__init__()
print(f"Initial state after __init__: {self.state}") # Debug print
initial_state = ExampleState
@start()
def start_method(self):
print("Starting the structured flow")
print(f"State in start_method: {self.state}") # Debug print
print(f"State in start_method: {self.state}")
self.state.message = "Hello from structured flow"
print(f"State after start_method: {self.state}") # Debug print
print(f"State after start_method: {self.state}")
return "Start result"
@listen(start_method)
def second_method(self, result):
print(f"Second method, received: {result}")
print(f"State before increment: {self.state}") # Debug print
print(f"State before increment: {self.state}")
self.state.counter += 1
self.state.message += " - updated"
print(f"State after second_method: {self.state}") # Debug print
print(f"State after second_method: {self.state}")
return "Second result"

View File

@@ -5,21 +5,21 @@ class FlexibleExampleFlow(Flow):
@start()
def start_method(self):
print("Starting the flexible flow")
self.state.counter = 1
self.state["counter"] = 1
return "Start result"
@listen(start_method)
def second_method(self, result):
print(f"Second method, received: {result}")
self.state.counter += 1
self.state.message = "Hello from flexible flow"
self.state["counter"] += 1
self.state["message"] = "Hello from flexible flow"
return "Second result"
@listen(second_method)
def third_method(self, result):
print(f"Third method, received: {result}")
print(f"Final counter value: {self.state.counter}")
print(f"Final message: {self.state.message}")
print(f"Final counter value: {self.state["counter"]}")
print(f"Final message: {self.state["message"]}")
return "Third result"