From a4a14df72e120d4dd673639f3a23c22ad4cce43d Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 11 Sep 2024 11:59:12 -0400 Subject: [PATCH] Working but not clean engouth --- src/crewai/flow/flow.py | 16 ++++++++++------ src/crewai/flow/structured_test_flow.py | 2 +- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index a0f075c3e..ead696e90 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union, get_args +from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union, cast from pydantic import BaseModel @@ -31,12 +31,11 @@ class FlowMeta(type): class Flow(Generic[TState], metaclass=FlowMeta): _start_methods: List[str] = [] _listeners: Dict[str, List[str]] = {} - state: TState - state_class: Type[TState] + state_class: Type[TState] # Class-level state_class defined once in the subclass def __init__(self): self._methods: Dict[str, Callable] = {} - self.state = self._create_default_state() + self._state: TState = self._create_default_state() for method_name in dir(self): if callable(getattr(self, method_name)) and not method_name.startswith( @@ -44,14 +43,19 @@ class Flow(Generic[TState], metaclass=FlowMeta): ): self._methods[method_name] = getattr(self, method_name) + @property + def state(self) -> TState: + """Ensure state has the correct type.""" + return self._state + def _create_default_state(self) -> TState: if not hasattr(self, "state_class"): raise AttributeError("state_class must be defined in the Flow subclass") if issubclass(self.state_class, BaseModel): - return self.state_class() + return self.state_class() # Automatically initialize with Pydantic defaults elif self.state_class is dict: - return DictWrapper() # type: ignore + return cast(TState, DictWrapper()) # Cast to TState for DictWrapper else: raise TypeError(f"Unsupported state type: {self.state_class}") diff --git a/src/crewai/flow/structured_test_flow.py b/src/crewai/flow/structured_test_flow.py index d3f9e116b..c0d2382dd 100644 --- a/src/crewai/flow/structured_test_flow.py +++ b/src/crewai/flow/structured_test_flow.py @@ -32,6 +32,6 @@ class StructuredExampleFlow(Flow[ExampleState]): return "Second result" -# Run the flow +# Instantiate and run the flow structured_flow = StructuredExampleFlow() structured_flow.run()