mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
It fully works but not clean enought
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Callable, Dict, Generic, List, TypeVar, Union, get_args
|
from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union, get_args
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -32,6 +32,7 @@ class Flow(Generic[TState], metaclass=FlowMeta):
|
|||||||
_start_methods: List[str] = []
|
_start_methods: List[str] = []
|
||||||
_listeners: Dict[str, List[str]] = {}
|
_listeners: Dict[str, List[str]] = {}
|
||||||
state: TState
|
state: TState
|
||||||
|
state_class: Type[TState]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._methods: Dict[str, Callable] = {}
|
self._methods: Dict[str, Callable] = {}
|
||||||
@@ -44,18 +45,15 @@ class Flow(Generic[TState], metaclass=FlowMeta):
|
|||||||
self._methods[method_name] = getattr(self, method_name)
|
self._methods[method_name] = getattr(self, method_name)
|
||||||
|
|
||||||
def _create_default_state(self) -> TState:
|
def _create_default_state(self) -> TState:
|
||||||
state_type = self._get_state_type()
|
if not hasattr(self, "state_class"):
|
||||||
if state_type and issubclass(state_type, BaseModel):
|
raise AttributeError("state_class must be defined in the Flow subclass")
|
||||||
return state_type()
|
|
||||||
return DictWrapper() # type: ignore
|
|
||||||
|
|
||||||
def _get_state_type(self) -> type[TState] | None:
|
if issubclass(self.state_class, BaseModel):
|
||||||
for base in self.__class__.__bases__:
|
return self.state_class()
|
||||||
if hasattr(base, "__origin__") and base.__origin__ is Flow:
|
elif self.state_class is dict:
|
||||||
args = get_args(base)
|
return DictWrapper() # type: ignore
|
||||||
if args:
|
else:
|
||||||
return args[0]
|
raise TypeError(f"Unsupported state type: {self.state_class}")
|
||||||
return None
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
if not self._start_methods:
|
if not self._start_methods:
|
||||||
|
|||||||
@@ -8,17 +8,27 @@ class ExampleState(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class StructuredExampleFlow(Flow[ExampleState]):
|
class StructuredExampleFlow(Flow[ExampleState]):
|
||||||
|
state_class = ExampleState
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
print(f"Initial state after __init__: {self.state}") # Debug print
|
||||||
|
|
||||||
@start()
|
@start()
|
||||||
def start_method(self):
|
def start_method(self):
|
||||||
print("Starting the structured flow")
|
print("Starting the structured flow")
|
||||||
|
print(f"State in start_method: {self.state}") # Debug print
|
||||||
self.state.message = "Hello from structured flow"
|
self.state.message = "Hello from structured flow"
|
||||||
|
print(f"State after start_method: {self.state}") # Debug print
|
||||||
return "Start result"
|
return "Start result"
|
||||||
|
|
||||||
@listen(start_method)
|
@listen(start_method)
|
||||||
def second_method(self, result):
|
def second_method(self, result):
|
||||||
print(f"Second method, received: {result}")
|
print(f"Second method, received: {result}")
|
||||||
|
print(f"State before increment: {self.state}") # Debug print
|
||||||
self.state.counter += 1
|
self.state.counter += 1
|
||||||
self.state.message = "Hello from structured flow"
|
self.state.message += " - updated"
|
||||||
|
print(f"State after second_method: {self.state}") # Debug print
|
||||||
return "Second result"
|
return "Second result"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user