Fix flows to support cycles and added in test (#1556)

This commit is contained in:
Brandon Hancock (bhancock_ai)
2024-11-05 12:02:54 -05:00
committed by GitHub
parent 3d44795476
commit faa231e278
2 changed files with 289 additions and 28 deletions

View File

@@ -131,7 +131,6 @@ class FlowMeta(type):
condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods)
# TODO: should we add a check for __condition_type__ 'AND'?
elif hasattr(attr_value, "__is_router__"):
routers[attr_value.__router_for__] = attr_name
possible_returns = get_possible_return_constants(attr_value)
@@ -171,8 +170,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
def __init__(self) -> None:
self._methods: Dict[str, Callable] = {}
self._state: T = self._create_initial_state()
self._executed_methods: Set[str] = set()
self._scheduled_tasks: Set[str] = set()
self._method_execution_counts: Dict[str, int] = {}
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs
@@ -309,7 +307,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
self._method_outputs.append(result) # Store the output
self._executed_methods.add(method_name)
# Track method execution counts
self._method_execution_counts[method_name] = (
self._method_execution_counts.get(method_name, 0) + 1
)
return result
@@ -319,35 +320,34 @@ class Flow(Generic[T], metaclass=FlowMeta):
if trigger_method in self._routers:
router_method = self._methods[self._routers[trigger_method]]
path = await self._execute_method(
trigger_method, router_method
) # TODO: Change or not?
# Use the path as the new trigger method
self._routers[trigger_method], router_method
)
trigger_method = path
for listener_name, (condition_type, methods) in self._listeners.items():
if condition_type == "OR":
if trigger_method in methods:
if (
listener_name not in self._executed_methods
and listener_name not in self._scheduled_tasks
):
self._scheduled_tasks.add(listener_name)
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
# Schedule the listener without preventing re-execution
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
elif condition_type == "AND":
if all(method in self._executed_methods for method in methods):
if (
listener_name not in self._executed_methods
and listener_name not in self._scheduled_tasks
):
self._scheduled_tasks.add(listener_name)
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
# Initialize pending methods for this listener if not already done
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
# Remove the trigger method from pending methods
self._pending_and_listeners[listener_name].discard(trigger_method)
if not self._pending_and_listeners[listener_name]:
# All required methods have been executed
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
# Reset pending methods for this listener
self._pending_and_listeners.pop(listener_name, None)
# Run all listener tasks concurrently and wait for them to complete
await asyncio.gather(*listener_tasks)
if listener_tasks:
await asyncio.gather(*listener_tasks)
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
try:
@@ -367,9 +367,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
# If listener does not expect parameters, call without arguments
listener_result = await self._execute_method(listener_name, method)
# Remove from scheduled tasks after execution
self._scheduled_tasks.discard(listener_name)
# Execute listeners of this listener
await self._execute_listeners(listener_name, listener_result)
except Exception as e: