From 2ef5cc37a9cdd39a1a5163e53784b238d708a768 Mon Sep 17 00:00:00 2001 From: Thiago Moretto Date: Fri, 13 Sep 2024 15:47:46 -0300 Subject: [PATCH] Get initial state type from generic --- src/crewai/flow/flow.py | 10 +++++++++- src/crewai/flow/structured_test_flow.py | 3 +-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 3fa91f31e..dcc8068a1 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -115,6 +115,12 @@ class Flow(Generic[T], metaclass=FlowMeta): _listeners: Dict[str, tuple[str, List[str]]] = {} initial_state: Union[Type[T], T, None] = None + def __class_getitem__(cls, item): + print(f"[Flow.__class_getitem__] Getting initial state type: {item}") + class _FlowGeneric(cls): + _initial_state_T = item + return _FlowGeneric + def __init__(self): print("[Flow.__init__] Initializing Flow") self._methods: Dict[str, Callable] = {} @@ -134,7 +140,9 @@ class Flow(Generic[T], metaclass=FlowMeta): def _create_initial_state(self) -> T: print("[Flow._create_initial_state] Creating initial state") - if self.initial_state is None: + if self.initial_state is None and hasattr(self, "_initial_state_T"): + return self._initial_state_T() + elif self.initial_state is None: return {} # type: ignore elif isinstance(self.initial_state, type): return self.initial_state() diff --git a/src/crewai/flow/structured_test_flow.py b/src/crewai/flow/structured_test_flow.py index 20b687225..9872d0caa 100644 --- a/src/crewai/flow/structured_test_flow.py +++ b/src/crewai/flow/structured_test_flow.py @@ -9,8 +9,7 @@ class ExampleState(BaseModel): message: str = "" -class StructuredExampleFlow(Flow): - initial_state = ExampleState +class StructuredExampleFlow(Flow[ExampleState]): @start() async def start_method(self):