From 8664f3912b6417e9e8d01311d7c80411ef870d3f Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 11 Sep 2024 10:45:24 -0400 Subject: [PATCH] It fully works but not clean enought --- src/crewai/flow/flow.py | 22 ++++++++++------------ src/crewai/flow/structured_test_flow.py | 12 +++++++++++- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 6cac9a7b7..a0f075c3e 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Generic, List, TypeVar, Union, get_args +from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union, get_args from pydantic import BaseModel @@ -32,6 +32,7 @@ class Flow(Generic[TState], metaclass=FlowMeta): _start_methods: List[str] = [] _listeners: Dict[str, List[str]] = {} state: TState + state_class: Type[TState] def __init__(self): self._methods: Dict[str, Callable] = {} @@ -44,18 +45,15 @@ class Flow(Generic[TState], metaclass=FlowMeta): self._methods[method_name] = getattr(self, method_name) def _create_default_state(self) -> TState: - state_type = self._get_state_type() - if state_type and issubclass(state_type, BaseModel): - return state_type() - return DictWrapper() # type: ignore + if not hasattr(self, "state_class"): + raise AttributeError("state_class must be defined in the Flow subclass") - def _get_state_type(self) -> type[TState] | None: - for base in self.__class__.__bases__: - if hasattr(base, "__origin__") and base.__origin__ is Flow: - args = get_args(base) - if args: - return args[0] - return None + if issubclass(self.state_class, BaseModel): + return self.state_class() + elif self.state_class is dict: + return DictWrapper() # type: ignore + else: + raise TypeError(f"Unsupported state type: {self.state_class}") def run(self): if not self._start_methods: diff --git a/src/crewai/flow/structured_test_flow.py b/src/crewai/flow/structured_test_flow.py index dfd92972c..d3f9e116b 100644 --- a/src/crewai/flow/structured_test_flow.py +++ b/src/crewai/flow/structured_test_flow.py @@ -8,17 +8,27 @@ class ExampleState(BaseModel): class StructuredExampleFlow(Flow[ExampleState]): + state_class = ExampleState + + def __init__(self): + super().__init__() + print(f"Initial state after __init__: {self.state}") # Debug print + @start() def start_method(self): print("Starting the structured flow") + print(f"State in start_method: {self.state}") # Debug print self.state.message = "Hello from structured flow" + print(f"State after start_method: {self.state}") # Debug print return "Start result" @listen(start_method) def second_method(self, result): print(f"Second method, received: {result}") + print(f"State before increment: {self.state}") # Debug print self.state.counter += 1 - self.state.message = "Hello from structured flow" + self.state.message += " - updated" + print(f"State after second_method: {self.state}") # Debug print return "Second result"