mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-24 23:58:15 +00:00
Merge branch 'release/v1.0.0' of github.com:crewAIInc/crewAI into release/v1.0.0
@@ -5,10 +5,13 @@ This module provides the event infrastructure that allows users to:
- Track memory operations and performance
- Build custom logging and analytics
- Extend CrewAI with custom event handlers
- Declare handler dependencies for ordered execution
"""

from crewai.events.base_event_listener import BaseEventListener
from crewai.events.depends import Depends
from crewai.events.event_bus import crewai_event_bus
from crewai.events.handler_graph import CircularDependencyError
from crewai.events.types.agent_events import (
    AgentEvaluationCompletedEvent,
    AgentEvaluationFailedEvent,
@@ -109,6 +112,7 @@ __all__ = [
    "AgentReasoningFailedEvent",
    "AgentReasoningStartedEvent",
    "BaseEventListener",
    "CircularDependencyError",
    "CrewKickoffCompletedEvent",
    "CrewKickoffFailedEvent",
    "CrewKickoffStartedEvent",
@@ -119,6 +123,7 @@ __all__ = [
    "CrewTrainCompletedEvent",
    "CrewTrainFailedEvent",
    "CrewTrainStartedEvent",
    "Depends",
    "FlowCreatedEvent",
    "FlowEvent",
    "FlowFinishedEvent",
@@ -9,6 +9,7 @@ class BaseEventListener(ABC):
    def __init__(self):
        super().__init__()
        self.setup_listeners(crewai_event_bus)
        crewai_event_bus.validate_dependencies()

    @abstractmethod
    def setup_listeners(self, crewai_event_bus: CrewAIEventsBus):
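A minimal sketch of a listener built on this hook; the MyListener class and handler bodies are hypothetical, while the imports mirror the exports added in this change:

from crewai.events import BaseEventListener, Depends
from crewai.events.types.llm_events import LLMCallStartedEvent


class MyListener(BaseEventListener):
    def setup_listeners(self, crewai_event_bus):
        @crewai_event_bus.on(LLMCallStartedEvent)
        def record_call(source, event):
            print("LLM call started")

        # depends_on orders record_metrics after record_call for this event type
        @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(record_call))
        def record_metrics(source, event):
            print("runs after record_call")


MyListener()  # __init__ wires up handlers, then validate_dependencies() runs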
105
lib/crewai/src/crewai/events/depends.py
Normal file
@@ -0,0 +1,105 @@
"""Dependency injection system for event handlers.
|
||||
|
||||
This module provides a FastAPI-style dependency system that allows event handlers
|
||||
to declare dependencies on other handlers, ensuring proper execution order while
|
||||
maintaining parallelism where possible.
|
||||
"""
|
||||
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any, Generic, Protocol, TypeVar
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
EventT_co = TypeVar("EventT_co", bound=BaseEvent, contravariant=True)
|
||||
|
||||
|
||||
class EventHandler(Protocol[EventT_co]):
|
||||
"""Protocol for event handler functions.
|
||||
|
||||
Generic protocol that accepts any subclass of BaseEvent.
|
||||
Handlers can be either synchronous (returning None) or asynchronous
|
||||
(returning a coroutine).
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, source: Any, event: EventT_co, /
|
||||
) -> None | Coroutine[Any, Any, None]:
|
||||
"""Event handler signature.
|
||||
|
||||
Args:
|
||||
source: The object that emitted the event
|
||||
event: The event instance (any BaseEvent subclass)
|
||||
|
||||
Returns:
|
||||
None for sync handlers, Coroutine for async handlers
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=EventHandler[Any])
|
||||
|
||||
|
||||
class Depends(Generic[T]):
|
||||
"""Declares a dependency on another event handler.
|
||||
|
||||
Similar to FastAPI's Depends, this allows handlers to specify that they
|
||||
depend on other handlers completing first. Handlers with dependencies will
|
||||
execute after their dependencies, while independent handlers can run in parallel.
|
||||
|
||||
Args:
|
||||
handler: The handler function that this handler depends on
|
||||
|
||||
Example:
|
||||
>>> from crewai.events import Depends, crewai_event_bus
|
||||
>>> from crewai.events import LLMCallStartedEvent
|
||||
>>> @crewai_event_bus.on(LLMCallStartedEvent)
|
||||
>>> def setup_context(source, event):
|
||||
... return {"initialized": True}
|
||||
>>>
|
||||
>>> @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(setup_context))
|
||||
>>> def process(source, event):
|
||||
... # Runs after setup_context completes
|
||||
... pass
|
||||
"""
|
||||
|
||||
def __init__(self, handler: T) -> None:
|
||||
"""Initialize a dependency on a handler.
|
||||
|
||||
Args:
|
||||
handler: The handler function this depends on
|
||||
"""
|
||||
self.handler = handler
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the dependency.
|
||||
|
||||
Returns:
|
||||
A string showing the dependent handler name
|
||||
"""
|
||||
handler_name = getattr(self.handler, "__name__", repr(self.handler))
|
||||
return f"Depends({handler_name})"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Check equality based on the handler reference.
|
||||
|
||||
Args:
|
||||
other: Another Depends instance to compare
|
||||
|
||||
Returns:
|
||||
True if both depend on the same handler, False otherwise
|
||||
"""
|
||||
if not isinstance(other, Depends):
|
||||
return False
|
||||
return self.handler is other.handler
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash based on handler identity.
|
||||
|
||||
Since equality is based on identity (is), we hash the handler
|
||||
object directly rather than its id for consistency.
|
||||
|
||||
Returns:
|
||||
Hash of the handler object
|
||||
"""
|
||||
return id(self.handler)
|
||||
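A sketch of fan-in ordering with multiple dependencies (handler names are hypothetical): handler_a and handler_b share a level and may run in parallel, while handler_c runs only after both complete.

from crewai.events import Depends, crewai_event_bus
from crewai.events.types.llm_events import LLMCallStartedEvent


@crewai_event_bus.on(LLMCallStartedEvent)
def handler_a(source, event): ...


@crewai_event_bus.on(LLMCallStartedEvent)
def handler_b(source, event): ...


@crewai_event_bus.on(
    LLMCallStartedEvent, depends_on=[Depends(handler_a), Depends(handler_b)]
)
def handler_c(source, event): ...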
@@ -1,125 +1,507 @@
"""Event bus for managing and dispatching events in CrewAI.

This module provides a singleton event bus that allows registration and handling
of events throughout the CrewAI system, supporting both synchronous and asynchronous
event handlers with optional dependency management.
"""

from __future__ import annotations

import asyncio
import atexit
from collections.abc import Callable, Generator
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
import threading
from typing import Any, Final, ParamSpec, TypeVar

from typing_extensions import Self

from crewai.events.base_events import BaseEvent
from crewai.events.depends import Depends
from crewai.events.handler_graph import build_execution_plan
from crewai.events.types.event_bus_types import (
    AsyncHandler,
    AsyncHandlerSet,
    ExecutionPlan,
    Handler,
    SyncHandler,
    SyncHandlerSet,
)
from crewai.events.types.llm_events import LLMStreamChunkEvent
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.utils.handlers import is_async_handler, is_call_handler_safe
from crewai.events.utils.rw_lock import RWLock

EventT = TypeVar("EventT", bound=BaseEvent)

P = ParamSpec("P")
R = TypeVar("R")
class CrewAIEventsBus:
    """Singleton event bus for handling events in CrewAI.

    This class manages event registration and emission for both synchronous
    and asynchronous event handlers, automatically scheduling async handlers
    in a dedicated background event loop.

    Synchronous handlers execute in a thread pool executor to ensure completion
    before program exit. Asynchronous handlers execute in a dedicated event loop
    running in a daemon thread, with graceful shutdown waiting for completion.

    Attributes:
        _instance: Singleton instance of the event bus
        _instance_lock: Reentrant lock for singleton initialization (class-level)
        _rwlock: Read-write lock for handler registration and access (instance-level)
        _sync_handlers: Mapping of event types to registered synchronous handlers
        _async_handlers: Mapping of event types to registered asynchronous handlers
        _sync_executor: Thread pool executor for running synchronous handlers
        _loop: Dedicated asyncio event loop for async handler execution
        _loop_thread: Background daemon thread running the event loop
        _console: Console formatter for error output
    """

    _instance: Self | None = None
    _instance_lock: threading.RLock = threading.RLock()
    _rwlock: RWLock
    _sync_handlers: dict[type[BaseEvent], SyncHandlerSet]
    _async_handlers: dict[type[BaseEvent], AsyncHandlerSet]
    _handler_dependencies: dict[type[BaseEvent], dict[Handler, list[Depends]]]
    _execution_plan_cache: dict[type[BaseEvent], ExecutionPlan]
    _console: ConsoleFormatter
    _shutting_down: bool

    def __new__(cls) -> Self:
        """Create or return the singleton instance.

        Returns:
            The singleton CrewAIEventsBus instance
        """
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialize()
        return cls._instance

    def _initialize(self) -> None:
        """Initialize the event bus internal state.

        Creates handler dictionaries and starts a dedicated background
        event loop for async handler execution.
        """
        self._shutting_down = False
        self._rwlock = RWLock()
        self._sync_handlers: dict[type[BaseEvent], SyncHandlerSet] = {}
        self._async_handlers: dict[type[BaseEvent], AsyncHandlerSet] = {}
        self._handler_dependencies: dict[type[BaseEvent], dict[Handler, list[Depends]]] = {}
        self._execution_plan_cache: dict[type[BaseEvent], ExecutionPlan] = {}
        self._sync_executor = ThreadPoolExecutor(
            max_workers=10,
            thread_name_prefix="CrewAISyncHandler",
        )
        self._console = ConsoleFormatter()

        self._loop = asyncio.new_event_loop()
        self._loop_thread = threading.Thread(
            target=self._run_loop,
            name="CrewAIEventsLoop",
            daemon=True,
        )
        self._loop_thread.start()
    def _run_loop(self) -> None:
        """Run the background async event loop."""
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def _register_handler(
        self,
        event_type: type[BaseEvent],
        handler: Callable[..., Any],
        dependencies: list[Depends] | None = None,
    ) -> None:
        """Register a handler for the given event type.

        Args:
            event_type: The event class to listen for
            handler: The handler function to register
            dependencies: Optional list of dependencies
        """
        with self._rwlock.w_locked():
            if is_async_handler(handler):
                existing_async = self._async_handlers.get(event_type, frozenset())
                self._async_handlers[event_type] = existing_async | {handler}
            else:
                existing_sync = self._sync_handlers.get(event_type, frozenset())
                self._sync_handlers[event_type] = existing_sync | {handler}

            if dependencies:
                if event_type not in self._handler_dependencies:
                    self._handler_dependencies[event_type] = {}
                self._handler_dependencies[event_type][handler] = dependencies

            self._execution_plan_cache.pop(event_type, None)
    def on(
        self,
        event_type: type[BaseEvent],
        depends_on: Depends | list[Depends] | None = None,
    ) -> Callable[[Callable[P, R]], Callable[P, R]]:
        """Decorator to register an event handler for a specific event type.

        Args:
            event_type: The event class to listen for
            depends_on: Optional dependency or list of dependencies. Handlers with
                dependencies will execute after their dependencies complete.

        Returns:
            Decorator function that registers the handler

        Example:
            >>> from crewai.events import crewai_event_bus, Depends
            >>> from crewai.events.types.llm_events import LLMCallStartedEvent
            >>>
            >>> @crewai_event_bus.on(LLMCallStartedEvent)
            ... def setup_context(source, event):
            ...     print("Setting up context")
            >>>
            >>> @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(setup_context))
            ... def process(source, event):
            ...     print("Processing (runs after setup_context)")
        """

        def decorator(handler: Callable[P, R]) -> Callable[P, R]:
            """Register the handler and return it unchanged.

            Args:
                handler: Event handler function to register

            Returns:
                The same handler function unchanged
            """
            deps = None
            if depends_on is not None:
                deps = [depends_on] if isinstance(depends_on, Depends) else depends_on

            self._register_handler(event_type, handler, dependencies=deps)
            return handler

        return decorator
    def _call_handlers(
        self,
        source: Any,
        event: BaseEvent,
        handlers: SyncHandlerSet,
    ) -> None:
        """Call provided synchronous handlers.

        Args:
            source: The emitting object
            event: The event instance
            handlers: Frozenset of sync handlers to call
        """
        errors: list[tuple[SyncHandler, Exception]] = [
            (handler, error)
            for handler in handlers
            if (error := is_call_handler_safe(handler, source, event)) is not None
        ]

        if errors:
            for handler, error in errors:
                self._console.print(
                    f"[CrewAIEventsBus] Sync handler error in {handler.__name__}: {error}"
                )

    async def _acall_handlers(
        self,
        source: Any,
        event: BaseEvent,
        handlers: AsyncHandlerSet,
    ) -> None:
        """Asynchronously call provided async handlers.

        Args:
            source: The object that emitted the event
            event: The event instance
            handlers: Frozenset of async handlers to call
        """
        coros = [handler(source, event) for handler in handlers]
        results = await asyncio.gather(*coros, return_exceptions=True)
        for handler, result in zip(handlers, results, strict=False):
            if isinstance(result, Exception):
                self._console.print(
                    f"[CrewAIEventsBus] Async handler error in {getattr(handler, '__name__', handler)}: {result}"
                )
    async def _emit_with_dependencies(self, source: Any, event: BaseEvent) -> None:
        """Emit an event with dependency-aware handler execution.

        Handlers are grouped into execution levels based on their dependencies.
        Within each level, async handlers run concurrently while sync handlers
        run sequentially (or in thread pool). Each level completes before the
        next level starts.

        Uses a cached execution plan for performance. The plan is built once
        per event type and cached until handlers are modified.

        Args:
            source: The emitting object
            event: The event instance to emit
        """
        event_type = type(event)

        with self._rwlock.r_locked():
            if self._shutting_down:
                return
            cached_plan = self._execution_plan_cache.get(event_type)
            if cached_plan is not None:
                sync_handlers = self._sync_handlers.get(event_type, frozenset())
                async_handlers = self._async_handlers.get(event_type, frozenset())

        if cached_plan is None:
            with self._rwlock.w_locked():
                if self._shutting_down:
                    return
                cached_plan = self._execution_plan_cache.get(event_type)
                if cached_plan is None:
                    sync_handlers = self._sync_handlers.get(event_type, frozenset())
                    async_handlers = self._async_handlers.get(event_type, frozenset())
                    dependencies = dict(self._handler_dependencies.get(event_type, {}))
                    all_handlers = list(sync_handlers | async_handlers)

                    if not all_handlers:
                        return

                    cached_plan = build_execution_plan(all_handlers, dependencies)
                    self._execution_plan_cache[event_type] = cached_plan
                else:
                    sync_handlers = self._sync_handlers.get(event_type, frozenset())
                    async_handlers = self._async_handlers.get(event_type, frozenset())

        for level in cached_plan:
            level_sync = frozenset(h for h in level if h in sync_handlers)
            level_async = frozenset(h for h in level if h in async_handlers)

            if level_sync:
                if event_type is LLMStreamChunkEvent:
                    self._call_handlers(source, event, level_sync)
                else:
                    future = self._sync_executor.submit(
                        self._call_handlers, source, event, level_sync
                    )
                    await asyncio.get_running_loop().run_in_executor(
                        None, future.result
                    )

            if level_async:
                await self._acall_handlers(source, event, level_async)
    def emit(self, source: Any, event: BaseEvent) -> Future[None] | None:
        """Emit an event to all registered handlers.

        If handlers have dependencies (registered with depends_on), they execute
        in dependency order. Otherwise, handlers execute as before (sync in thread
        pool, async fire-and-forget).

        Stream chunk events always execute synchronously to preserve ordering.

        Args:
            source: The emitting object
            event: The event instance to emit

        Returns:
            Future that completes when handlers finish. Returns:
            - Future for sync-only handlers (ThreadPoolExecutor future)
            - Future for async handlers or mixed handlers (asyncio future)
            - Future for dependency-managed handlers (asyncio future)
            - None if no handlers or sync stream chunk events

        Example:
            >>> future = crewai_event_bus.emit(source, event)
            >>> if future:
            ...     await asyncio.wrap_future(future)  # In async test
            ...     # or future.result(timeout=5.0) in sync code
        """
        event_type = type(event)

        with self._rwlock.r_locked():
            if self._shutting_down:
                self._console.print(
                    "[CrewAIEventsBus] Warning: Attempted to emit event during shutdown. Ignoring."
                )
                return None
            has_dependencies = event_type in self._handler_dependencies
            sync_handlers = self._sync_handlers.get(event_type, frozenset())
            async_handlers = self._async_handlers.get(event_type, frozenset())

        if has_dependencies:
            return asyncio.run_coroutine_threadsafe(
                self._emit_with_dependencies(source, event),
                self._loop,
            )

        if sync_handlers:
            if event_type is LLMStreamChunkEvent:
                self._call_handlers(source, event, sync_handlers)
            else:
                sync_future = self._sync_executor.submit(
                    self._call_handlers, source, event, sync_handlers
                )
                if not async_handlers:
                    return sync_future

        if async_handlers:
            return asyncio.run_coroutine_threadsafe(
                self._acall_handlers(source, event, async_handlers),
                self._loop,
            )

        return None
    async def aemit(self, source: Any, event: BaseEvent) -> None:
        """Asynchronously emit an event to registered async handlers.

        Only processes async handlers. Use in async contexts.

        Args:
            source: The object emitting the event
            event: The event instance to emit
        """
        event_type = type(event)

        with self._rwlock.r_locked():
            if self._shutting_down:
                self._console.print(
                    "[CrewAIEventsBus] Warning: Attempted to emit event during shutdown. Ignoring."
                )
                return
            async_handlers = self._async_handlers.get(event_type, frozenset())

        if async_handlers:
            await self._acall_handlers(source, event, async_handlers)
    def register_handler(
        self,
        event_type: type[BaseEvent],
        handler: SyncHandler | AsyncHandler,
    ) -> None:
        """Register an event handler for a specific event type.

        Args:
            event_type: The event class to listen for
            handler: The handler function to register
        """
        self._register_handler(event_type, handler)
    def validate_dependencies(self) -> None:
        """Validate all registered handler dependencies.

        Attempts to build execution plans for all event types with dependencies.
        This detects circular dependencies and cross-event-type dependencies
        before events are emitted.

        Raises:
            CircularDependencyError: If circular dependencies or unresolved
                dependencies (e.g., cross-event-type) are detected
        """
        with self._rwlock.r_locked():
            for event_type in self._handler_dependencies:
                sync_handlers = self._sync_handlers.get(event_type, frozenset())
                async_handlers = self._async_handlers.get(event_type, frozenset())
                dependencies = dict(self._handler_dependencies.get(event_type, {}))
                all_handlers = list(sync_handlers | async_handlers)

                if all_handlers and dependencies:
                    build_execution_plan(all_handlers, dependencies)
    @contextmanager
    def scoped_handlers(self) -> Generator[None, Any, None]:
        """Context manager for temporary event handling scope.

        Useful for testing or temporary event handling. All handlers registered
        within this context are cleared when the context exits.

        Example:
            >>> from crewai.events.event_bus import crewai_event_bus
            >>> from crewai.events.event_types import CrewKickoffStartedEvent
            >>> with crewai_event_bus.scoped_handlers():
            ...
            ...     @crewai_event_bus.on(CrewKickoffStartedEvent)
            ...     def temp_handler(source, event):
            ...         print("Temporary handler")
            ...
            ...     # Do stuff...
            ... # Handlers are cleared after the context
        """
        with self._rwlock.w_locked():
            prev_sync = self._sync_handlers
            prev_async = self._async_handlers
            prev_deps = self._handler_dependencies
            prev_cache = self._execution_plan_cache
            self._sync_handlers = {}
            self._async_handlers = {}
            self._handler_dependencies = {}
            self._execution_plan_cache = {}

        try:
            yield
        finally:
            with self._rwlock.w_locked():
                self._sync_handlers = prev_sync
                self._async_handlers = prev_async
                self._handler_dependencies = prev_deps
                self._execution_plan_cache = prev_cache
    def shutdown(self, wait: bool = True) -> None:
        """Gracefully shutdown the event loop and wait for all tasks to finish.

        Args:
            wait: If True, wait for all pending tasks to complete before stopping.
                If False, cancel all pending tasks immediately.
        """
        with self._rwlock.w_locked():
            self._shutting_down = True
            loop = getattr(self, "_loop", None)

        if loop is None or loop.is_closed():
            return

        if wait:

            async def _wait_for_all_tasks() -> None:
                tasks = {
                    t
                    for t in asyncio.all_tasks(loop)
                    if t is not asyncio.current_task()
                }
                if tasks:
                    await asyncio.gather(*tasks, return_exceptions=True)

            future = asyncio.run_coroutine_threadsafe(_wait_for_all_tasks(), loop)
            try:
                future.result()
            except Exception as e:
                self._console.print(f"[CrewAIEventsBus] Error waiting for tasks: {e}")
        else:

            def _cancel_tasks() -> None:
                for task in asyncio.all_tasks(loop):
                    if task is not asyncio.current_task():
                        task.cancel()

            loop.call_soon_threadsafe(_cancel_tasks)

        loop.call_soon_threadsafe(loop.stop)
        self._loop_thread.join()
        loop.close()
        self._sync_executor.shutdown(wait=wait)

        with self._rwlock.w_locked():
            self._sync_handlers.clear()
            self._async_handlers.clear()
            self._execution_plan_cache.clear()
# Global instance
crewai_event_bus: Final[CrewAIEventsBus] = CrewAIEventsBus()

atexit.register(crewai_event_bus.shutdown)
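A sketch of consuming emit()'s return value from synchronous code; source and event stand in for an emitter object and an already-constructed BaseEvent instance (both placeholders):

from crewai.events.event_bus import crewai_event_bus

future = crewai_event_bus.emit(source, event)  # source/event: placeholders
if future is not None:
    future.result(timeout=5.0)  # block until all handlers for this event finish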
130
lib/crewai/src/crewai/events/handler_graph.py
Normal file
@@ -0,0 +1,130 @@
"""Dependency graph resolution for event handlers.
|
||||
|
||||
This module resolves handler dependencies into execution levels, ensuring
|
||||
handlers execute in correct order while maximizing parallelism.
|
||||
"""
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Sequence
|
||||
|
||||
from crewai.events.depends import Depends
|
||||
from crewai.events.types.event_bus_types import ExecutionPlan, Handler
|
||||
|
||||
|
||||
class CircularDependencyError(Exception):
|
||||
"""Exception raised when circular dependencies are detected in event handlers.
|
||||
|
||||
Attributes:
|
||||
handlers: The handlers involved in the circular dependency
|
||||
"""
|
||||
|
||||
def __init__(self, handlers: list[Handler]) -> None:
|
||||
"""Initialize the circular dependency error.
|
||||
|
||||
Args:
|
||||
handlers: The handlers involved in the circular dependency
|
||||
"""
|
||||
handler_names = ", ".join(
|
||||
getattr(h, "__name__", repr(h)) for h in handlers[:5]
|
||||
)
|
||||
message = f"Circular dependency detected in event handlers: {handler_names}"
|
||||
super().__init__(message)
|
||||
self.handlers = handlers
|
||||
|
||||
|
||||
class HandlerGraph:
|
||||
"""Resolves handler dependencies into parallel execution levels.
|
||||
|
||||
Handlers are organized into levels where:
|
||||
- Level 0: Handlers with no dependencies (can run first)
|
||||
- Level N: Handlers that depend on handlers in levels 0...N-1
|
||||
|
||||
Handlers within the same level can execute in parallel.
|
||||
|
||||
Attributes:
|
||||
levels: List of handler sets, where each level can execute in parallel
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: dict[Handler, list[Depends]],
|
||||
) -> None:
|
||||
"""Initialize the dependency graph.
|
||||
|
||||
Args:
|
||||
handlers: Mapping of handler -> list of `crewai.events.depends.Depends` objects
|
||||
"""
|
||||
self.handlers = handlers
|
||||
self.levels: ExecutionPlan = []
|
||||
self._resolve()
|
||||
|
||||
def _resolve(self) -> None:
|
||||
"""Resolve dependencies into execution levels using topological sort."""
|
||||
dependents: dict[Handler, set[Handler]] = defaultdict(set)
|
||||
in_degree: dict[Handler, int] = {}
|
||||
|
||||
for handler in self.handlers:
|
||||
in_degree[handler] = 0
|
||||
|
||||
for handler, deps in self.handlers.items():
|
||||
in_degree[handler] = len(deps)
|
||||
for dep in deps:
|
||||
dependents[dep.handler].add(handler)
|
||||
|
||||
queue: deque[Handler] = deque(
|
||||
[h for h, deg in in_degree.items() if deg == 0]
|
||||
)
|
||||
|
||||
while queue:
|
||||
current_level: set[Handler] = set()
|
||||
|
||||
for _ in range(len(queue)):
|
||||
handler = queue.popleft()
|
||||
current_level.add(handler)
|
||||
|
||||
for dependent in dependents[handler]:
|
||||
in_degree[dependent] -= 1
|
||||
if in_degree[dependent] == 0:
|
||||
queue.append(dependent)
|
||||
|
||||
if current_level:
|
||||
self.levels.append(current_level)
|
||||
|
||||
remaining = [h for h, deg in in_degree.items() if deg > 0]
|
||||
if remaining:
|
||||
raise CircularDependencyError(remaining)
|
||||
|
||||
def get_execution_plan(self) -> ExecutionPlan:
|
||||
"""Get the ordered execution plan.
|
||||
|
||||
Returns:
|
||||
List of handler sets, where each set represents handlers that can
|
||||
execute in parallel. Sets are ordered such that dependencies are
|
||||
satisfied.
|
||||
"""
|
||||
return self.levels
|
||||
|
||||
|
||||
def build_execution_plan(
|
||||
handlers: Sequence[Handler],
|
||||
dependencies: dict[Handler, list[Depends]],
|
||||
) -> ExecutionPlan:
|
||||
"""Build an execution plan from handlers and their dependencies.
|
||||
|
||||
Args:
|
||||
handlers: All handlers for an event type
|
||||
dependencies: Mapping of handler -> list of dependencies
|
||||
|
||||
Returns:
|
||||
Execution plan as list of levels, where each level is a set of
|
||||
handlers that can execute in parallel
|
||||
|
||||
Raises:
|
||||
CircularDependencyError: If circular dependencies are detected
|
||||
"""
|
||||
handler_dict: dict[Handler, list[Depends]] = {
|
||||
h: dependencies.get(h, []) for h in handlers
|
||||
}
|
||||
|
||||
graph = HandlerGraph(handler_dict)
|
||||
return graph.get_execution_plan()
|
||||
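A sketch of plan resolution for three dummy handlers where c depends on a and b:

from crewai.events.depends import Depends
from crewai.events.handler_graph import build_execution_plan


def a(source, event): ...
def b(source, event): ...
def c(source, event): ...


plan = build_execution_plan([a, b, c], {c: [Depends(a), Depends(b)]})
# plan == [{a, b}, {c}]: a and b form level 0, c runs alone in level 1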
@@ -1,8 +1,9 @@
from dataclasses import dataclass, field
from datetime import datetime, timezone
from logging import getLogger
from threading import Condition, Lock
from typing import Any
import uuid

from rich.console import Console
from rich.panel import Panel
@@ -14,6 +15,7 @@ from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.listeners.tracing.utils import should_auto_collect_first_time_traces
from crewai.utilities.constants import CREWAI_BASE_URL


logger = getLogger(__name__)
@@ -41,6 +43,11 @@ class TraceBatchManager:
    """Single responsibility: Manage batches and event buffering"""

    def __init__(self):
        self._init_lock = Lock()
        self._pending_events_lock = Lock()
        self._pending_events_cv = Condition(self._pending_events_lock)
        self._pending_events_count = 0

        self.is_current_batch_ephemeral: bool = False
        self.trace_batch_id: str | None = None
        self.current_batch: TraceBatch | None = None
@@ -64,24 +71,28 @@ class TraceBatchManager:
        execution_metadata: dict[str, Any],
        use_ephemeral: bool = False,
    ) -> TraceBatch:
        """Initialize a new trace batch (thread-safe)"""
        with self._init_lock:
            if self.current_batch is not None:
                logger.debug("Batch already initialized, skipping duplicate initialization")
                return self.current_batch

            self.current_batch = TraceBatch(
                user_context=user_context, execution_metadata=execution_metadata
            )
            self.is_current_batch_ephemeral = use_ephemeral

            self.record_start_time("execution")

            if should_auto_collect_first_time_traces():
                self.trace_batch_id = self.current_batch.batch_id
            else:
                self._initialize_backend_batch(
                    user_context, execution_metadata, use_ephemeral
                )
                self.backend_initialized = True

            return self.current_batch

    def _initialize_backend_batch(
        self,
@@ -148,6 +159,38 @@ class TraceBatchManager:
                f"Error initializing trace batch: {e}. Continuing without tracing."
            )
    def begin_event_processing(self):
        """Mark that an event handler started processing (for synchronization)"""
        with self._pending_events_lock:
            self._pending_events_count += 1

    def end_event_processing(self):
        """Mark that an event handler finished processing (for synchronization)"""
        with self._pending_events_cv:
            self._pending_events_count -= 1
            if self._pending_events_count == 0:
                self._pending_events_cv.notify_all()

    def wait_for_pending_events(self, timeout: float = 2.0) -> bool:
        """Wait for all pending event handlers to finish processing

        Args:
            timeout: Maximum time to wait in seconds (default: 2.0)

        Returns:
            True if all handlers completed, False if timeout occurred
        """
        with self._pending_events_cv:
            if self._pending_events_count > 0:
                logger.debug(f"Waiting for {self._pending_events_count} pending event handlers...")
                self._pending_events_cv.wait(timeout)
                if self._pending_events_count > 0:
                    logger.error(
                        f"Timeout waiting for event handlers. {self._pending_events_count} still pending. Events may be incomplete!"
                    )
                    return False
            return True
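The intended pairing, sketched with a placeholder batch_manager (a TraceBatchManager instance): every handler brackets its work with begin/end so that finalization can wait on the condition variable.

batch_manager.begin_event_processing()
try:
    pass  # create and buffer the trace event here
finally:
    batch_manager.end_event_processing()

# Later, before finalizing the batch:
if not batch_manager.wait_for_pending_events(timeout=2.0):
    pass  # some handlers did not finish; events may be incomplete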
    def add_event(self, trace_event: TraceEvent):
        """Add event to buffer"""
        self.event_buffer.append(trace_event)
@@ -180,8 +223,8 @@ class TraceBatchManager:
            self.event_buffer.clear()
            return 200

        logger.error(
            f"Failed to send events: {response.status_code}. Response: {response.text}. Events will be lost."
        )
        return 500
@@ -196,15 +239,33 @@ class TraceBatchManager:
        if not self.current_batch:
            return None

        if self.event_buffer:
            all_handlers_completed = self.wait_for_pending_events(timeout=2.0)

            if not all_handlers_completed:
                logger.error("Event handler timeout - marking batch as failed due to incomplete events")
                self.plus_api.mark_trace_batch_as_failed(
                    self.trace_batch_id, "Timeout waiting for event handlers - events incomplete"
                )
                return None

        sorted_events = sorted(
            self.event_buffer,
            key=lambda e: e.timestamp if hasattr(e, 'timestamp') and e.timestamp else ''
        )

        self.current_batch.events = sorted_events
        events_sent_count = len(sorted_events)
        if sorted_events:
            original_buffer = self.event_buffer
            self.event_buffer = sorted_events
            events_sent_to_backend_status = self._send_events_to_backend()
            self.event_buffer = original_buffer
            if events_sent_to_backend_status == 500:
                self.plus_api.mark_trace_batch_as_failed(
                    self.trace_batch_id, "Error sending events to backend"
                )
                return None
        self._finalize_backend_batch(events_sent_count)

        finalized_batch = self.current_batch
@@ -220,18 +281,20 @@ class TraceBatchManager:

        return finalized_batch

    def _finalize_backend_batch(self, events_count: int = 0):
        """Send batch finalization to backend

        Args:
            events_count: Number of events that were successfully sent
        """
        if not self.plus_api or not self.trace_batch_id:
            return

        try:
            payload = {
                "status": "completed",
                "duration_ms": self.calculate_duration("execution"),
                "final_event_count": events_count,
            }

            response = (
@@ -170,14 +170,6 @@ class TraceCollectionListener(BaseEventListener):
        def on_flow_finished(source, event):
            self._handle_trace_event("flow_finished", source, event)

        @event_bus.on(FlowPlotEvent)
        def on_flow_plot(source, event):
            self._handle_action_event("flow_plot", source, event)
@@ -383,10 +375,12 @@ class TraceCollectionListener(BaseEventListener):

    def _handle_trace_event(self, event_type: str, source: Any, event: Any):
        """Generic handler for context end events"""
        self.batch_manager.begin_event_processing()
        try:
            trace_event = self._create_trace_event(event_type, source, event)
            self.batch_manager.add_event(trace_event)
        finally:
            self.batch_manager.end_event_processing()

    def _handle_action_event(self, event_type: str, source: Any, event: Any):
        """Generic handler for action events (LLM calls, tool usage)"""
@@ -399,18 +393,29 @@ class TraceCollectionListener(BaseEventListener):
        }
        self.batch_manager.initialize_batch(user_context, execution_metadata)

        self.batch_manager.begin_event_processing()
        try:
            trace_event = self._create_trace_event(event_type, source, event)
            self.batch_manager.add_event(trace_event)
        finally:
            self.batch_manager.end_event_processing()

    def _create_trace_event(
        self, event_type: str, source: Any, event: Any
    ) -> TraceEvent:
        """Create a trace event"""
        if hasattr(event, 'timestamp') and event.timestamp:
            trace_event = TraceEvent(
                type=event_type,
                timestamp=event.timestamp.isoformat(),
            )
        else:
            trace_event = TraceEvent(
                type=event_type,
            )

        trace_event.event_data = self._build_event_data(event_type, event, source)

        return trace_event

    def _build_event_data(
14
lib/crewai/src/crewai/events/types/event_bus_types.py
Normal file
@@ -0,0 +1,14 @@
"""Type definitions for event handlers."""
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
SyncHandler: TypeAlias = Callable[[Any, BaseEvent], None]
|
||||
AsyncHandler: TypeAlias = Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
|
||||
SyncHandlerSet: TypeAlias = frozenset[SyncHandler]
|
||||
AsyncHandlerSet: TypeAlias = frozenset[AsyncHandler]
|
||||
|
||||
Handler: TypeAlias = Callable[[Any, BaseEvent], Any]
|
||||
ExecutionPlan: TypeAlias = list[set[Handler]]
|
||||
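For illustration, functions matching these aliases (the handler names are hypothetical):

from typing import Any

from crewai.events.base_events import BaseEvent


def on_event(source: Any, event: BaseEvent) -> None:  # satisfies SyncHandler
    ...


async def aon_event(source: Any, event: BaseEvent) -> None:  # satisfies AsyncHandler
    ...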
59
lib/crewai/src/crewai/events/utils/handlers.py
Normal file
@@ -0,0 +1,59 @@
"""Handler utility functions for event processing."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.types.event_bus_types import AsyncHandler, SyncHandler
|
||||
|
||||
|
||||
def is_async_handler(
|
||||
handler: Any,
|
||||
) -> TypeIs[AsyncHandler]:
|
||||
"""Type guard to check if handler is an async handler.
|
||||
|
||||
Args:
|
||||
handler: The handler to check
|
||||
|
||||
Returns:
|
||||
True if handler is an async coroutine function
|
||||
"""
|
||||
try:
|
||||
if inspect.iscoroutinefunction(handler) or (
|
||||
callable(handler) and inspect.iscoroutinefunction(handler.__call__)
|
||||
):
|
||||
return True
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
if isinstance(handler, functools.partial) and inspect.iscoroutinefunction(
|
||||
handler.func
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_call_handler_safe(
|
||||
handler: SyncHandler,
|
||||
source: Any,
|
||||
event: BaseEvent,
|
||||
) -> Exception | None:
|
||||
"""Safely call a single handler and return any exception.
|
||||
|
||||
Args:
|
||||
handler: The handler function to call
|
||||
source: The object that emitted the event
|
||||
event: The event instance
|
||||
|
||||
Returns:
|
||||
Exception if handler raised one, None otherwise
|
||||
"""
|
||||
try:
|
||||
handler(source, event)
|
||||
return None
|
||||
except Exception as e:
|
||||
return e
|
||||
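A quick sketch of what the guard accepts, using dummy handlers; a functools.partial wrapping a coroutine function is covered by the fallback branch:

import functools

from crewai.events.utils.handlers import is_async_handler


def sync_handler(source, event): ...


async def async_handler(source, event): ...


assert not is_async_handler(sync_handler)
assert is_async_handler(async_handler)
assert is_async_handler(functools.partial(async_handler))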
81
lib/crewai/src/crewai/events/utils/rw_lock.py
Normal file
@@ -0,0 +1,81 @@
"""Read-write lock for thread-safe concurrent access.
|
||||
|
||||
This module provides a reader-writer lock implementation that allows multiple
|
||||
concurrent readers or a single exclusive writer.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from threading import Condition
|
||||
|
||||
|
||||
class RWLock:
|
||||
"""Read-write lock for managing concurrent read and exclusive write access.
|
||||
|
||||
Allows multiple threads to acquire read locks simultaneously, but ensures
|
||||
exclusive access for write operations. Writers are prioritized when waiting.
|
||||
|
||||
Attributes:
|
||||
_cond: Condition variable for coordinating lock access
|
||||
_readers: Count of active readers
|
||||
_writer: Whether a writer currently holds the lock
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the read-write lock."""
|
||||
self._cond = Condition()
|
||||
self._readers = 0
|
||||
self._writer = False
|
||||
|
||||
def r_acquire(self) -> None:
|
||||
"""Acquire a read lock, blocking if a writer holds the lock."""
|
||||
with self._cond:
|
||||
while self._writer:
|
||||
self._cond.wait()
|
||||
self._readers += 1
|
||||
|
||||
def r_release(self) -> None:
|
||||
"""Release a read lock and notify waiting writers if last reader."""
|
||||
with self._cond:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
self._cond.notify_all()
|
||||
|
||||
@contextmanager
|
||||
def r_locked(self) -> Generator[None, None, None]:
|
||||
"""Context manager for acquiring a read lock.
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
self.r_acquire()
|
||||
yield
|
||||
finally:
|
||||
self.r_release()
|
||||
|
||||
def w_acquire(self) -> None:
|
||||
"""Acquire a write lock, blocking if any readers or writers are active."""
|
||||
with self._cond:
|
||||
while self._writer or self._readers > 0:
|
||||
self._cond.wait()
|
||||
self._writer = True
|
||||
|
||||
def w_release(self) -> None:
|
||||
"""Release a write lock and notify all waiting threads."""
|
||||
with self._cond:
|
||||
self._writer = False
|
||||
self._cond.notify_all()
|
||||
|
||||
@contextmanager
|
||||
def w_locked(self) -> Generator[None, None, None]:
|
||||
"""Context manager for acquiring a write lock.
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
self.w_acquire()
|
||||
yield
|
||||
finally:
|
||||
self.w_release()
|
||||
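A usage sketch with a hypothetical shared registry: readers may overlap freely, while a writer waits for exclusivity.

from crewai.events.utils.rw_lock import RWLock

lock = RWLock()
registry: dict[str, int] = {}


def read(key: str) -> int | None:
    with lock.r_locked():  # shared with other readers
        return registry.get(key)


def write(key: str, value: int) -> None:
    with lock.w_locked():  # exclusive: waits for active readers and writers
        registry[key] = value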
@@ -52,19 +52,14 @@ class AgentEvaluator:
        self.console_formatter = ConsoleFormatter()
        self.display_formatter = EvaluationDisplayFormatter()

        self._execution_state = ExecutionState()
        self._state_lock = threading.Lock()

        for agent in self.agents:
            self._execution_state.agent_evaluators[str(agent.id)] = self.evaluators

        self._subscribe_to_events()

    def _subscribe_to_events(self) -> None:
        from typing import cast
@@ -112,21 +107,22 @@ class AgentEvaluator:
            state=state,
        )

        with self._state_lock:
            current_iteration = self._execution_state.iteration
            if current_iteration not in self._execution_state.iterations_results:
                self._execution_state.iterations_results[current_iteration] = {}

            if (
                agent.role
                not in self._execution_state.iterations_results[current_iteration]
            ):
                self._execution_state.iterations_results[current_iteration][
                    agent.role
                ] = []

            self._execution_state.iterations_results[current_iteration][
                agent.role
            ].append(result)

    def _handle_lite_agent_completed(
        self, source: object, event: LiteAgentExecutionCompletedEvent
@@ -164,22 +160,23 @@ class AgentEvaluator:
            state=state,
        )

        with self._state_lock:
            current_iteration = self._execution_state.iteration
            if current_iteration not in self._execution_state.iterations_results:
                self._execution_state.iterations_results[current_iteration] = {}

            agent_role = target_agent.role
            if (
                agent_role
                not in self._execution_state.iterations_results[current_iteration]
            ):
                self._execution_state.iterations_results[current_iteration][
                    agent_role
                ] = []

            self._execution_state.iterations_results[current_iteration][
                agent_role
            ].append(result)

    def set_iteration(self, iteration: int) -> None:
        self._execution_state.iteration = iteration
@@ -3,6 +3,7 @@ import copy
import inspect
import logging
from collections.abc import Callable
from concurrent.futures import Future
from typing import Any, ClassVar, Generic, TypeVar, cast
from uuid import uuid4

@@ -463,6 +464,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        self._completed_methods: set[str] = set()  # Track completed methods for reload
        self._persistence: FlowPersistence | None = persistence
        self._is_execution_resuming: bool = False
        self._event_futures: list[Future[None]] = []

        # Initialize state with initial values
        self._state = self._create_initial_state()
@@ -855,7 +857,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        self._initialize_state(filtered_inputs)

        # Emit FlowStartedEvent and log the start of the flow.
        future = crewai_event_bus.emit(
            self,
            FlowStartedEvent(
                type="flow_started",
@@ -863,6 +865,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
                inputs=inputs,
            ),
        )
        if future:
            self._event_futures.append(future)
        self._log_flow_event(
            f"Flow started with ID: {self.flow_id}", color="bold_magenta"
        )
@@ -881,7 +885,7 @@ class Flow(Generic[T], metaclass=FlowMeta):

        final_output = self._method_outputs[-1] if self._method_outputs else None

        future = crewai_event_bus.emit(
            self,
            FlowFinishedEvent(
                type="flow_finished",
@@ -889,6 +893,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
                result=final_output,
            ),
        )
        if future:
            self._event_futures.append(future)

        if self._event_futures:
            await asyncio.gather(*[asyncio.wrap_future(f) for f in self._event_futures])
            self._event_futures.clear()

        if (
            is_tracing_enabled()
            or self.tracing
            or should_auto_collect_first_time_traces()
        ):
            trace_listener = TraceCollectionListener()
            if trace_listener.batch_manager.batch_owner_type == "flow":
                if trace_listener.first_time_handler.is_first_time:
                    trace_listener.first_time_handler.mark_events_collected()
                    trace_listener.first_time_handler.handle_execution_completion()
                else:
                    trace_listener.batch_manager.finalize_batch()

        return final_output
    finally:
@@ -971,7 +994,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
                kwargs or {}
            )
            future = crewai_event_bus.emit(
                self,
                MethodExecutionStartedEvent(
                    type="method_execution_started",
@@ -981,6 +1004,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
                    state=self._copy_state(),
                ),
            )
            if future:
                self._event_futures.append(future)

            result = (
                await method(*args, **kwargs)
@@ -994,7 +1019,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            )

            self._completed_methods.add(method_name)
            future = crewai_event_bus.emit(
                self,
                MethodExecutionFinishedEvent(
                    type="method_execution_finished",
@@ -1004,10 +1029,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
                    result=result,
                ),
            )
            if future:
                self._event_futures.append(future)

            return result
        except Exception as e:
            future = crewai_event_bus.emit(
                self,
                MethodExecutionFailedEvent(
                    type="method_execution_failed",
@@ -1016,6 +1043,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
                    error=e,
                ),
            )
            if future:
                self._event_futures.append(future)
            raise e

    async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
@@ -10,7 +10,7 @@ from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)
from openai import APIConnectionError, NotFoundError, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
@@ -33,6 +33,9 @@ class OpenAICompletion(BaseLLM):
        project: str | None = None,
        timeout: float | None = None,
        max_retries: int = 2,
        default_headers: dict[str, str] | None = None,
        default_query: dict[str, Any] | None = None,
        client_params: dict[str, Any] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        frequency_penalty: float | None = None,
@@ -44,8 +47,8 @@ class OpenAICompletion(BaseLLM):
        response_format: dict[str, Any] | type[BaseModel] | None = None,
        logprobs: bool | None = None,
        top_logprobs: int | None = None,
        reasoning_effort: str | None = None,
        provider: str | None = None,
        **kwargs,
    ):
        """Initialize OpenAI chat completion client."""
@@ -53,6 +56,16 @@ class OpenAICompletion(BaseLLM):
        if provider is None:
            provider = kwargs.pop("provider", "openai")

        # Client configuration attributes
        self.organization = organization
        self.project = project
        self.max_retries = max_retries
        self.default_headers = default_headers
        self.default_query = default_query
        self.client_params = client_params
        self.timeout = timeout
        self.base_url = base_url

        super().__init__(
            model=model,
            temperature=temperature,
@@ -63,15 +76,10 @@ class OpenAICompletion(BaseLLM):
            **kwargs,
        )

        client_config = self._get_client_params()
        self.client = OpenAI(**client_config)

        # Completion parameters
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
@@ -83,10 +91,35 @@ class OpenAICompletion(BaseLLM):
        self.logprobs = logprobs
        self.top_logprobs = top_logprobs
        self.reasoning_effort = reasoning_effort
        self.is_o1_model = "o1" in model.lower()
        self.is_gpt4_model = "gpt-4" in model.lower()

    def _get_client_params(self) -> dict[str, Any]:
        """Get OpenAI client parameters."""
        if self.api_key is None:
            self.api_key = os.getenv("OPENAI_API_KEY")
        if self.api_key is None:
            raise ValueError("OPENAI_API_KEY is required")

        base_params = {
            "api_key": self.api_key,
            "organization": self.organization,
            "project": self.project,
            "base_url": self.base_url,
            "timeout": self.timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "default_query": self.default_query,
        }

        client_params = {k: v for k, v in base_params.items() if v is not None}

        if self.client_params:
            client_params.update(self.client_params)

        return client_params

    def call(
        self,
        messages: str | list[dict[str, str]],
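A construction sketch using the parameters added above; the URL and header values are placeholders, and client_params entries override the assembled defaults before the OpenAI client is built.

llm = OpenAICompletion(
    model="gpt-4o",
    base_url="https://example.internal/v1",  # placeholder endpoint
    default_headers={"X-Team": "search"},  # placeholder header
    client_params={"max_retries": 5},  # overrides the max_retries default of 2
)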
@@ -207,7 +240,6 @@ class OpenAICompletion(BaseLLM):
            "api_key",
            "base_url",
            "timeout",
            "max_retries",
        }

        return {k: v for k, v in params.items() if k not in crewai_specific_params}
@@ -306,10 +338,31 @@ class OpenAICompletion(BaseLLM):

            if usage.get("total_tokens", 0) > 0:
                logging.info(f"OpenAI API usage: {usage}")
        except NotFoundError as e:
            error_msg = f"Model {self.model} not found: {e}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise ValueError(error_msg) from e
        except APIConnectionError as e:
            error_msg = f"Failed to connect to OpenAI API: {e}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise ConnectionError(error_msg) from e
        except Exception as e:
            # Handle context length exceeded and other errors
            if is_context_length_exceeded(e):
                logging.error(f"Context window exceeded: {e}")
                raise LLMContextLengthExceededError(str(e)) from e

            error_msg = f"OpenAI API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise e from e

        return content