mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 17:18:13 +00:00
WIP. Working on adding and & or to flows. In the middle of setting up template for flow as well
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -16,11 +18,34 @@ class FlowMeta(type):
|
||||
if hasattr(attr_value, "__is_start_method__"):
|
||||
start_methods.append(attr_name)
|
||||
if hasattr(attr_value, "__trigger_methods__"):
|
||||
for trigger in attr_value.__trigger_methods__:
|
||||
trigger_name = trigger.__name__ if callable(trigger) else trigger
|
||||
if trigger_name not in listeners:
|
||||
listeners[trigger_name] = []
|
||||
listeners[trigger_name].append(attr_name)
|
||||
condition = attr_value.__trigger_methods__
|
||||
if callable(condition):
|
||||
# Single method reference
|
||||
method_name = condition.__name__
|
||||
if method_name not in listeners:
|
||||
listeners[method_name] = []
|
||||
listeners[method_name].append((attr_name, "SINGLE", [method_name]))
|
||||
elif isinstance(condition, str):
|
||||
# Single method name
|
||||
if condition not in listeners:
|
||||
listeners[condition] = []
|
||||
listeners[condition].append((attr_name, "SINGLE", [condition]))
|
||||
elif isinstance(condition, tuple):
|
||||
# AND or OR condition
|
||||
condition_type = (
|
||||
"AND" if any(item == "and" for item in condition) else "OR"
|
||||
)
|
||||
methods = [
|
||||
m.__name__ if callable(m) else m
|
||||
for m in condition
|
||||
if m != "and" and m != "or"
|
||||
]
|
||||
for method in methods:
|
||||
if method not in listeners:
|
||||
listeners[method] = []
|
||||
listeners[method].append((attr_name, condition_type, methods))
|
||||
else:
|
||||
raise ValueError(f"Invalid listener format for {attr_name}")
|
||||
|
||||
setattr(cls, "_start_methods", start_methods)
|
||||
setattr(cls, "_listeners", listeners)
|
||||
@@ -38,12 +63,13 @@ class FlowMeta(type):
|
||||
|
||||
class Flow(Generic[T], metaclass=FlowMeta):
|
||||
_start_methods: List[str] = []
|
||||
_listeners: Dict[str, List[str]] = {}
|
||||
_listeners: Dict[str, List[tuple[str, str, List[str]]]] = {}
|
||||
initial_state: Union[Type[T], T, None] = None
|
||||
|
||||
def __init__(self):
|
||||
self._methods: Dict[str, Callable] = {}
|
||||
self._state = self._create_initial_state()
|
||||
self._completed_methods: set[str] = set()
|
||||
|
||||
for method_name in dir(self):
|
||||
if callable(getattr(self, method_name)) and not method_name.startswith(
|
||||
@@ -63,23 +89,55 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
def state(self) -> T:
|
||||
return self._state
|
||||
|
||||
def run(self):
|
||||
async def run(self):
|
||||
if not self._start_methods:
|
||||
raise ValueError("No start method defined")
|
||||
|
||||
for start_method in self._start_methods:
|
||||
result = self._methods[start_method]()
|
||||
self._execute_listeners(start_method, result)
|
||||
result = await self._execute_method(self._methods[start_method])
|
||||
await self._execute_listeners(start_method, result)
|
||||
|
||||
async def _execute_method(self, method: Callable, *args, **kwargs):
|
||||
if inspect.iscoroutinefunction(method):
|
||||
return await method(*args, **kwargs)
|
||||
else:
|
||||
return method(*args, **kwargs)
|
||||
|
||||
async def _execute_listeners(self, trigger_method: str, result: Any):
|
||||
self._completed_methods.add(trigger_method)
|
||||
|
||||
def _execute_listeners(self, trigger_method: str, result: Any):
|
||||
if trigger_method in self._listeners:
|
||||
for listener in self._listeners[trigger_method]:
|
||||
try:
|
||||
listener_result = self._methods[listener](result)
|
||||
self._execute_listeners(listener, listener_result)
|
||||
except Exception as e:
|
||||
print(f"Error in method {listener}: {str(e)}")
|
||||
return
|
||||
listener_tasks = []
|
||||
for listener, condition_type, methods in self._listeners[trigger_method]:
|
||||
if condition_type == "OR":
|
||||
if trigger_method in methods:
|
||||
listener_tasks.append(
|
||||
self._execute_single_listener(listener, result)
|
||||
)
|
||||
elif condition_type == "AND":
|
||||
if all(method in self._completed_methods for method in methods):
|
||||
listener_tasks.append(
|
||||
self._execute_single_listener(listener, result)
|
||||
)
|
||||
elif condition_type == "SINGLE":
|
||||
listener_tasks.append(
|
||||
self._execute_single_listener(listener, result)
|
||||
)
|
||||
|
||||
# Run all listener tasks concurrently and wait for them to complete
|
||||
await asyncio.gather(*listener_tasks)
|
||||
|
||||
async def _execute_single_listener(self, listener: str, result: Any):
|
||||
try:
|
||||
method = self._methods[listener]
|
||||
sig = inspect.signature(method)
|
||||
if len(sig.parameters) > 1: # More than just 'self'
|
||||
listener_result = await self._execute_method(method, result)
|
||||
else:
|
||||
listener_result = await self._execute_method(method)
|
||||
await self._execute_listeners(listener, listener_result)
|
||||
except Exception as e:
|
||||
print(f"Error in method {listener}: {str(e)}")
|
||||
|
||||
|
||||
def start():
|
||||
@@ -90,9 +148,9 @@ def start():
|
||||
return decorator
|
||||
|
||||
|
||||
def listen(*trigger_methods):
|
||||
def listen(condition):
|
||||
def decorator(func):
|
||||
func.__trigger_methods__ = trigger_methods
|
||||
func.__trigger_methods__ = condition
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
Reference in New Issue
Block a user