quick fixes

This commit is contained in:
Brandon Hancock
2024-10-02 16:09:38 -04:00
parent f46a12b3b4
commit f775101b18
2 changed files with 9 additions and 10 deletions

View File

@@ -148,15 +148,15 @@ class Flow(Generic[T], metaclass=FlowMeta):
_router_paths: Dict[str, List[str]] = {} _router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None initial_state: Union[Type[T], T, None] = None
def __class_getitem__(cls, item): def __class_getitem__(cls, item: Type[T]) -> Type["Flow"]:
class _FlowGeneric(cls): class _FlowGeneric(cls):
_initial_state_T = item _initial_state_T: Type[T] = item
return _FlowGeneric return _FlowGeneric
def __init__(self): def __init__(self) -> None:
self._methods: Dict[str, Callable] = {} self._methods: Dict[str, Callable] = {}
self._state = self._create_initial_state() self._state: T = self._create_initial_state()
self._completed_methods: Set[str] = set() self._completed_methods: Set[str] = set()
self._pending_and_listeners: Dict[str, Set[str]] = {} self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs self._method_outputs: List[Any] = [] # List to store all method outputs
@@ -205,11 +205,11 @@ class Flow(Generic[T], metaclass=FlowMeta):
else: else:
return None # Or raise an exception if no methods were executed return None # Or raise an exception if no methods were executed
async def _execute_start_method(self, start_method: str): async def _execute_start_method(self, start_method: str) -> None:
result = await self._execute_method(self._methods[start_method]) result = await self._execute_method(self._methods[start_method])
await self._execute_listeners(start_method, result) await self._execute_listeners(start_method, result)
async def _execute_method(self, method: Callable, *args, **kwargs): async def _execute_method(self, method: Callable, *args: Any, **kwargs: Any) -> Any:
result = ( result = (
await method(*args, **kwargs) await method(*args, **kwargs)
if asyncio.iscoroutinefunction(method) if asyncio.iscoroutinefunction(method)
@@ -218,7 +218,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._method_outputs.append(result) # Store the output self._method_outputs.append(result) # Store the output
return result return result
async def _execute_listeners(self, trigger_method: str, result: Any): async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
listener_tasks = [] listener_tasks = []
if trigger_method in self._routers: if trigger_method in self._routers:
@@ -246,7 +246,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Run all listener tasks concurrently and wait for them to complete # Run all listener tasks concurrently and wait for them to complete
await asyncio.gather(*listener_tasks) await asyncio.gather(*listener_tasks)
async def _execute_single_listener(self, listener: str, result: Any): async def _execute_single_listener(self, listener: str, result: Any) -> None:
try: try:
method = self._methods[listener] method = self._methods[listener]
sig = inspect.signature(method) sig = inspect.signature(method)
@@ -270,5 +270,5 @@ class Flow(Generic[T], metaclass=FlowMeta):
traceback.print_exc() traceback.print_exc()
def plot(self, filename: str = "crewai_flow_graph"): def plot(self, filename: str = "crewai_flow_graph") -> None:
plot_flow(self, filename) plot_flow(self, filename)

View File

@@ -1,5 +1,4 @@
import base64 import base64
import os
import re import re