From a4fad7cafd33df055c54b3ec38948465a721bae1 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Thu, 19 Sep 2024 13:32:04 -0400 Subject: [PATCH] Added in Thiago fix --- src/crewai/flow/flow.py | 8 ++++++++ src/crewai/flow/structured_test_flow.py | 9 +++------ src/crewai/flow/structured_test_flow_or.py | 1 - src/crewai/flow/unstructured_test_flow.py | 8 ++++---- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 8623009c6..0ff5d111b 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -137,6 +137,12 @@ class Flow(Generic[T], metaclass=FlowMeta): _routers: Dict[str, str] = {} initial_state: Union[Type[T], T, None] = None + def __class_getitem__(cls, item): + class _FlowGeneric(cls): + _initial_state_T = item + + return _FlowGeneric + def __init__(self): print("[Flow.__init__] Initializing Flow") self._methods: Dict[str, Callable] = {} @@ -152,6 +158,8 @@ class Flow(Generic[T], metaclass=FlowMeta): def _create_initial_state(self) -> T: print("[Flow._create_initial_state] Creating initial state") + if self.initial_state is None and hasattr(self, "_initial_state_T"): + return self._initial_state_T() # type: ignore if self.initial_state is None: return {} # type: ignore elif isinstance(self.initial_state, type): diff --git a/src/crewai/flow/structured_test_flow.py b/src/crewai/flow/structured_test_flow.py index 20b687225..d6c589714 100644 --- a/src/crewai/flow/structured_test_flow.py +++ b/src/crewai/flow/structured_test_flow.py @@ -9,8 +9,7 @@ class ExampleState(BaseModel): message: str = "" -class StructuredExampleFlow(Flow): - initial_state = ExampleState +class StructuredExampleFlow(Flow[ExampleState]): @start() async def start_method(self): @@ -21,8 +20,7 @@ class StructuredExampleFlow(Flow): return "Start result" @listen(start_method) - async def second_method(self, result): - print(f"Second method, received: {result}") + async def second_method(self): print(f"State before increment: {self.state}") self.state.counter += 1 self.state.message += " - updated" @@ -30,8 +28,7 @@ class StructuredExampleFlow(Flow): return "Second result" @listen(start_method) - async def third_method(self, result): - print(f"Third method, received: {result}") + async def third_method(self): print(f"State before increment: {self.state}") self.state.counter += 1 self.state.message += " - updated" diff --git a/src/crewai/flow/structured_test_flow_or.py b/src/crewai/flow/structured_test_flow_or.py index a7e58ac35..e7716277e 100644 --- a/src/crewai/flow/structured_test_flow_or.py +++ b/src/crewai/flow/structured_test_flow_or.py @@ -10,7 +10,6 @@ class ExampleState(BaseModel): class StructuredExampleFlow(Flow[ExampleState]): - initial_state = ExampleState @start() async def start_method(self): diff --git a/src/crewai/flow/unstructured_test_flow.py b/src/crewai/flow/unstructured_test_flow.py index cff142e4f..73b4cffaf 100644 --- a/src/crewai/flow/unstructured_test_flow.py +++ b/src/crewai/flow/unstructured_test_flow.py @@ -11,15 +11,15 @@ class FlexibleExampleFlow(Flow): return "Start result" @listen(start_method) - def second_method(self, result): - print(f"Second method, received: {result}") + def second_method(self): + print("Second method") self.state["counter"] += 1 self.state["message"] = "Hello from flexible flow" return "Second result" @listen(second_method) - def third_method(self, result): - print(f"Third method, received: {result}") + def third_method(self): + print("Third method") print(f"Final counter value: {self.state["counter"]}") print(f"Final message: {self.state["message"]}") return "Third result"