From 3a266d6b40a336176ffa462e6f31208e846dfeda Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 11 Sep 2024 13:01:36 -0400 Subject: [PATCH] Everything is workign --- src/crewai/flow/flow.py | 50 +++++++++++------------ src/crewai/flow/structured_test_flow.py | 14 +++---- src/crewai/flow/unstructured_test_flow.py | 10 ++--- 3 files changed, 33 insertions(+), 41 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index ead696e90..a5181982d 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,8 +1,8 @@ -from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union, cast +from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union from pydantic import BaseModel -TState = TypeVar("TState", bound=Union[BaseModel, Dict[str, Any]]) +T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]]) class FlowMeta(type): @@ -25,17 +25,25 @@ class FlowMeta(type): setattr(cls, "_start_methods", start_methods) setattr(cls, "_listeners", listeners) + # Inject the state type hint + if "initial_state" in dct: + initial_state = dct["initial_state"] + if isinstance(initial_state, type) and issubclass(initial_state, BaseModel): + cls.__annotations__["state"] = initial_state + elif isinstance(initial_state, dict): + cls.__annotations__["state"] = Dict[str, Any] + return cls -class Flow(Generic[TState], metaclass=FlowMeta): +class Flow(Generic[T], metaclass=FlowMeta): _start_methods: List[str] = [] _listeners: Dict[str, List[str]] = {} - state_class: Type[TState] # Class-level state_class defined once in the subclass + initial_state: Union[Type[T], T, None] = None def __init__(self): self._methods: Dict[str, Callable] = {} - self._state: TState = self._create_default_state() + self._state = self._create_initial_state() for method_name in dir(self): if callable(getattr(self, method_name)) and not method_name.startswith( @@ -43,21 +51,17 @@ class Flow(Generic[TState], metaclass=FlowMeta): ): self._methods[method_name] = getattr(self, method_name) - @property - def state(self) -> TState: - """Ensure state has the correct type.""" - return self._state - - def _create_default_state(self) -> TState: - if not hasattr(self, "state_class"): - raise AttributeError("state_class must be defined in the Flow subclass") - - if issubclass(self.state_class, BaseModel): - return self.state_class() # Automatically initialize with Pydantic defaults - elif self.state_class is dict: - return cast(TState, DictWrapper()) # Cast to TState for DictWrapper + def _create_initial_state(self) -> T: + if self.initial_state is None: + return {} # type: ignore + elif isinstance(self.initial_state, type): + return self.initial_state() else: - raise TypeError(f"Unsupported state type: {self.state_class}") + return self.initial_state + + @property + def state(self) -> T: + return self._state def run(self): if not self._start_methods: @@ -78,14 +82,6 @@ class Flow(Generic[TState], metaclass=FlowMeta): return -class DictWrapper(Dict[str, Any]): - def __getattr__(self, name: str) -> Any: - return self.get(name) - - def __setattr__(self, name: str, value: Any) -> None: - self[name] = value - - def start(): def decorator(func): func.__is_start_method__ = True diff --git a/src/crewai/flow/structured_test_flow.py b/src/crewai/flow/structured_test_flow.py index c0d2382dd..befb65ddb 100644 --- a/src/crewai/flow/structured_test_flow.py +++ b/src/crewai/flow/structured_test_flow.py @@ -8,27 +8,23 @@ class ExampleState(BaseModel): class StructuredExampleFlow(Flow[ExampleState]): - state_class = ExampleState - - def __init__(self): - super().__init__() - print(f"Initial state after __init__: {self.state}") # Debug print + initial_state = ExampleState @start() def start_method(self): print("Starting the structured flow") - print(f"State in start_method: {self.state}") # Debug print + print(f"State in start_method: {self.state}") self.state.message = "Hello from structured flow" - print(f"State after start_method: {self.state}") # Debug print + print(f"State after start_method: {self.state}") return "Start result" @listen(start_method) def second_method(self, result): print(f"Second method, received: {result}") - print(f"State before increment: {self.state}") # Debug print + print(f"State before increment: {self.state}") self.state.counter += 1 self.state.message += " - updated" - print(f"State after second_method: {self.state}") # Debug print + print(f"State after second_method: {self.state}") return "Second result" diff --git a/src/crewai/flow/unstructured_test_flow.py b/src/crewai/flow/unstructured_test_flow.py index d3a7c5f8b..b6cc5e573 100644 --- a/src/crewai/flow/unstructured_test_flow.py +++ b/src/crewai/flow/unstructured_test_flow.py @@ -5,21 +5,21 @@ class FlexibleExampleFlow(Flow): @start() def start_method(self): print("Starting the flexible flow") - self.state.counter = 1 + self.state["counter"] = 1 return "Start result" @listen(start_method) def second_method(self, result): print(f"Second method, received: {result}") - self.state.counter += 1 - self.state.message = "Hello from flexible flow" + self.state["counter"] += 1 + self.state["message"] = "Hello from flexible flow" return "Second result" @listen(second_method) def third_method(self, result): print(f"Third method, received: {result}") - print(f"Final counter value: {self.state.counter}") - print(f"Final message: {self.state.message}") + print(f"Final counter value: {self.state["counter"]}") + print(f"Final message: {self.state["message"]}") return "Third result"