feat: improve event bus thread safety and async support

Add thread-safe, async-compatible event bus with read–write locking and
handler dependency ordering. Remove blinker dependency and implement
direct dispatch. Improve type safety, error handling, and deterministic
event synchronization.

Refactor tests to auto-wait for async handlers, ensure clean teardown,
and add comprehensive concurrency coverage. Replace thread-local state
in AgentEvaluator with instance-based locking for correct cross-thread
access. Enhance tracing reliability and event finalization.
This commit is contained in:
Greyson LaLonde
2025-10-14 13:28:58 -04:00
committed by GitHub
parent cec4e4c2e9
commit 53b239c6df
34 changed files with 3360 additions and 876 deletions

View File

@@ -35,7 +35,6 @@ dependencies = [
"uv>=0.4.25", "uv>=0.4.25",
"tomli-w>=1.1.0", "tomli-w>=1.1.0",
"tomli>=2.0.2", "tomli>=2.0.2",
"blinker>=1.9.0",
"json5>=0.10.0", "json5>=0.10.0",
"portalocker==2.7.0", "portalocker==2.7.0",
"pydantic-settings>=2.10.1", "pydantic-settings>=2.10.1",

View File

@@ -5,10 +5,13 @@ This module provides the event infrastructure that allows users to:
- Track memory operations and performance - Track memory operations and performance
- Build custom logging and analytics - Build custom logging and analytics
- Extend CrewAI with custom event handlers - Extend CrewAI with custom event handlers
- Declare handler dependencies for ordered execution
""" """
from crewai.events.base_event_listener import BaseEventListener from crewai.events.base_event_listener import BaseEventListener
from crewai.events.depends import Depends
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.handler_graph import CircularDependencyError
from crewai.events.types.agent_events import ( from crewai.events.types.agent_events import (
AgentEvaluationCompletedEvent, AgentEvaluationCompletedEvent,
AgentEvaluationFailedEvent, AgentEvaluationFailedEvent,
@@ -109,6 +112,7 @@ __all__ = [
"AgentReasoningFailedEvent", "AgentReasoningFailedEvent",
"AgentReasoningStartedEvent", "AgentReasoningStartedEvent",
"BaseEventListener", "BaseEventListener",
"CircularDependencyError",
"CrewKickoffCompletedEvent", "CrewKickoffCompletedEvent",
"CrewKickoffFailedEvent", "CrewKickoffFailedEvent",
"CrewKickoffStartedEvent", "CrewKickoffStartedEvent",
@@ -119,6 +123,7 @@ __all__ = [
"CrewTrainCompletedEvent", "CrewTrainCompletedEvent",
"CrewTrainFailedEvent", "CrewTrainFailedEvent",
"CrewTrainStartedEvent", "CrewTrainStartedEvent",
"Depends",
"FlowCreatedEvent", "FlowCreatedEvent",
"FlowEvent", "FlowEvent",
"FlowFinishedEvent", "FlowFinishedEvent",

View File

@@ -9,6 +9,7 @@ class BaseEventListener(ABC):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.setup_listeners(crewai_event_bus) self.setup_listeners(crewai_event_bus)
crewai_event_bus.validate_dependencies()
@abstractmethod @abstractmethod
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus): def setup_listeners(self, crewai_event_bus: CrewAIEventsBus):

View File

@@ -0,0 +1,105 @@
"""Dependency injection system for event handlers.
This module provides a FastAPI-style dependency system that allows event handlers
to declare dependencies on other handlers, ensuring proper execution order while
maintaining parallelism where possible.
"""
from collections.abc import Coroutine
from typing import Any, Generic, Protocol, TypeVar
from crewai.events.base_events import BaseEvent
EventT_co = TypeVar("EventT_co", bound=BaseEvent, contravariant=True)
class EventHandler(Protocol[EventT_co]):
"""Protocol for event handler functions.
Generic protocol that accepts any subclass of BaseEvent.
Handlers can be either synchronous (returning None) or asynchronous
(returning a coroutine).
"""
def __call__(
self, source: Any, event: EventT_co, /
) -> None | Coroutine[Any, Any, None]:
"""Event handler signature.
Args:
source: The object that emitted the event
event: The event instance (any BaseEvent subclass)
Returns:
None for sync handlers, Coroutine for async handlers
"""
...
T = TypeVar("T", bound=EventHandler[Any])
class Depends(Generic[T]):
"""Declares a dependency on another event handler.
Similar to FastAPI's Depends, this allows handlers to specify that they
depend on other handlers completing first. Handlers with dependencies will
execute after their dependencies, while independent handlers can run in parallel.
Args:
handler: The handler function that this handler depends on
Example:
>>> from crewai.events import Depends, crewai_event_bus
>>> from crewai.events import LLMCallStartedEvent
>>> @crewai_event_bus.on(LLMCallStartedEvent)
>>> def setup_context(source, event):
... return {"initialized": True}
>>>
>>> @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(setup_context))
>>> def process(source, event):
... # Runs after setup_context completes
... pass
"""
def __init__(self, handler: T) -> None:
"""Initialize a dependency on a handler.
Args:
handler: The handler function this depends on
"""
self.handler = handler
def __repr__(self) -> str:
"""Return a string representation of the dependency.
Returns:
A string showing the dependent handler name
"""
handler_name = getattr(self.handler, "__name__", repr(self.handler))
return f"Depends({handler_name})"
def __eq__(self, other: object) -> bool:
"""Check equality based on the handler reference.
Args:
other: Another Depends instance to compare
Returns:
True if both depend on the same handler, False otherwise
"""
if not isinstance(other, Depends):
return False
return self.handler is other.handler
def __hash__(self) -> int:
"""Return hash based on handler identity.
Since equality is based on identity (is), we hash the handler
object directly rather than its id for consistency.
Returns:
Hash of the handler object
"""
return id(self.handler)

View File

@@ -1,125 +1,507 @@
from __future__ import annotations """Event bus for managing and dispatching events in CrewAI.
import threading This module provides a singleton event bus that allows registration and handling
from collections.abc import Callable of events throughout the CrewAI system, supporting both synchronous and asynchronous
event handlers with optional dependency management.
"""
import asyncio
import atexit
from collections.abc import Callable, Generator
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, TypeVar, cast import threading
from typing import Any, Final, ParamSpec, TypeVar
from blinker import Signal from typing_extensions import Self
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
from crewai.events.event_types import EventTypes from crewai.events.depends import Depends
from crewai.events.handler_graph import build_execution_plan
from crewai.events.types.event_bus_types import (
AsyncHandler,
AsyncHandlerSet,
ExecutionPlan,
Handler,
SyncHandler,
SyncHandlerSet,
)
from crewai.events.types.llm_events import LLMStreamChunkEvent
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.utils.handlers import is_async_handler, is_call_handler_safe
from crewai.events.utils.rw_lock import RWLock
EventT = TypeVar("EventT", bound=BaseEvent)
P = ParamSpec("P")
R = TypeVar("R")
class CrewAIEventsBus: class CrewAIEventsBus:
""" """Singleton event bus for handling events in CrewAI.
A singleton event bus that uses blinker signals for event handling.
Allows both internal (Flow/Crew) and external event handling. This class manages event registration and emission for both synchronous
and asynchronous event handlers, automatically scheduling async handlers
in a dedicated background event loop.
Synchronous handlers execute in a thread pool executor to ensure completion
before program exit. Asynchronous handlers execute in a dedicated event loop
running in a daemon thread, with graceful shutdown waiting for completion.
Attributes:
_instance: Singleton instance of the event bus
_instance_lock: Reentrant lock for singleton initialization (class-level)
_rwlock: Read-write lock for handler registration and access (instance-level)
_sync_handlers: Mapping of event types to registered synchronous handlers
_async_handlers: Mapping of event types to registered asynchronous handlers
_sync_executor: Thread pool executor for running synchronous handlers
_loop: Dedicated asyncio event loop for async handler execution
_loop_thread: Background daemon thread running the event loop
_console: Console formatter for error output
""" """
_instance = None _instance: Self | None = None
_lock = threading.Lock() _instance_lock: threading.RLock = threading.RLock()
_rwlock: RWLock
_sync_handlers: dict[type[BaseEvent], SyncHandlerSet]
_async_handlers: dict[type[BaseEvent], AsyncHandlerSet]
_handler_dependencies: dict[type[BaseEvent], dict[Handler, list[Depends]]]
_execution_plan_cache: dict[type[BaseEvent], ExecutionPlan]
_console: ConsoleFormatter
_shutting_down: bool
def __new__(cls): def __new__(cls) -> Self:
"""Create or return the singleton instance.
Returns:
The singleton CrewAIEventsBus instance
"""
if cls._instance is None: if cls._instance is None:
with cls._lock: with cls._instance_lock:
if cls._instance is None: # prevent race condition if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance._initialize() cls._instance._initialize()
return cls._instance return cls._instance
def _initialize(self) -> None: def _initialize(self) -> None:
"""Initialize the event bus internal state""" """Initialize the event bus internal state.
self._signal = Signal("crewai_event_bus")
self._handlers: dict[type[BaseEvent], list[Callable]] = {} Creates handler dictionaries and starts a dedicated background
event loop for async handler execution.
"""
self._shutting_down = False
self._rwlock = RWLock()
self._sync_handlers: dict[type[BaseEvent], SyncHandlerSet] = {}
self._async_handlers: dict[type[BaseEvent], AsyncHandlerSet] = {}
self._handler_dependencies: dict[type[BaseEvent], dict[Handler, list[Depends]]] = {}
self._execution_plan_cache: dict[type[BaseEvent], ExecutionPlan] = {}
self._sync_executor = ThreadPoolExecutor(
max_workers=10,
thread_name_prefix="CrewAISyncHandler",
)
self._console = ConsoleFormatter()
self._loop = asyncio.new_event_loop()
self._loop_thread = threading.Thread(
target=self._run_loop,
name="CrewAIEventsLoop",
daemon=True,
)
self._loop_thread.start()
def _run_loop(self) -> None:
"""Run the background async event loop."""
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
def _register_handler(
self,
event_type: type[BaseEvent],
handler: Callable[..., Any],
dependencies: list[Depends] | None = None,
) -> None:
"""Register a handler for the given event type.
Args:
event_type: The event class to listen for
handler: The handler function to register
dependencies: Optional list of dependencies
"""
with self._rwlock.w_locked():
if is_async_handler(handler):
existing_async = self._async_handlers.get(event_type, frozenset())
self._async_handlers[event_type] = existing_async | {handler}
else:
existing_sync = self._sync_handlers.get(event_type, frozenset())
self._sync_handlers[event_type] = existing_sync | {handler}
if dependencies:
if event_type not in self._handler_dependencies:
self._handler_dependencies[event_type] = {}
self._handler_dependencies[event_type][handler] = dependencies
self._execution_plan_cache.pop(event_type, None)
def on( def on(
self, event_type: type[EventT] self,
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]: event_type: type[BaseEvent],
""" depends_on: Depends | list[Depends] | None = None,
Decorator to register an event handler for a specific event type. ) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator to register an event handler for a specific event type.
Usage: Args:
@crewai_event_bus.on(AgentExecutionCompletedEvent) event_type: The event class to listen for
def on_agent_execution_completed( depends_on: Optional dependency or list of dependencies. Handlers with
source: Any, event: AgentExecutionCompletedEvent dependencies will execute after their dependencies complete.
):
print(f"👍 Agent '{event.agent}' completed task") Returns:
print(f" Output: {event.output}") Decorator function that registers the handler
Example:
>>> from crewai.events import crewai_event_bus, Depends
>>> from crewai.events.types.llm_events import LLMCallStartedEvent
>>>
>>> @crewai_event_bus.on(LLMCallStartedEvent)
>>> def setup_context(source, event):
... print("Setting up context")
>>>
>>> @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(setup_context))
>>> def process(source, event):
... print("Processing (runs after setup_context)")
""" """
def decorator( def decorator(handler: Callable[P, R]) -> Callable[P, R]:
handler: Callable[[Any, EventT], None], """Register the handler and return it unchanged.
) -> Callable[[Any, EventT], None]:
if event_type not in self._handlers: Args:
self._handlers[event_type] = [] handler: Event handler function to register
self._handlers[event_type].append(
cast(Callable[[Any, EventT], None], handler) Returns:
) The same handler function unchanged
"""
deps = None
if depends_on is not None:
deps = [depends_on] if isinstance(depends_on, Depends) else depends_on
self._register_handler(event_type, handler, dependencies=deps)
return handler return handler
return decorator return decorator
@staticmethod def _call_handlers(
def _call_handler( self,
handler: Callable, source: Any, event: BaseEvent, event_type: type source: Any,
event: BaseEvent,
handlers: SyncHandlerSet,
) -> None: ) -> None:
"""Call a single handler with error handling.""" """Call provided synchronous handlers.
try:
handler(source, event) Args:
except Exception as e: source: The emitting object
print( event: The event instance
f"[EventBus Error] Handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}" handlers: Frozenset of sync handlers to call
"""
errors: list[tuple[SyncHandler, Exception]] = [
(handler, error)
for handler in handlers
if (error := is_call_handler_safe(handler, source, event)) is not None
]
if errors:
for handler, error in errors:
self._console.print(
f"[CrewAIEventsBus] Sync handler error in {handler.__name__}: {error}"
)
async def _acall_handlers(
self,
source: Any,
event: BaseEvent,
handlers: AsyncHandlerSet,
) -> None:
"""Asynchronously call provided async handlers.
Args:
source: The object that emitted the event
event: The event instance
handlers: Frozenset of async handlers to call
"""
coros = [handler(source, event) for handler in handlers]
results = await asyncio.gather(*coros, return_exceptions=True)
for handler, result in zip(handlers, results, strict=False):
if isinstance(result, Exception):
self._console.print(
f"[CrewAIEventsBus] Async handler error in {getattr(handler, '__name__', handler)}: {result}"
)
async def _emit_with_dependencies(self, source: Any, event: BaseEvent) -> None:
"""Emit an event with dependency-aware handler execution.
Handlers are grouped into execution levels based on their dependencies.
Within each level, async handlers run concurrently while sync handlers
run sequentially (or in thread pool). Each level completes before the
next level starts.
Uses a cached execution plan for performance. The plan is built once
per event type and cached until handlers are modified.
Args:
source: The emitting object
event: The event instance to emit
"""
event_type = type(event)
with self._rwlock.r_locked():
if self._shutting_down:
return
cached_plan = self._execution_plan_cache.get(event_type)
if cached_plan is not None:
sync_handlers = self._sync_handlers.get(event_type, frozenset())
async_handlers = self._async_handlers.get(event_type, frozenset())
if cached_plan is None:
with self._rwlock.w_locked():
if self._shutting_down:
return
cached_plan = self._execution_plan_cache.get(event_type)
if cached_plan is None:
sync_handlers = self._sync_handlers.get(event_type, frozenset())
async_handlers = self._async_handlers.get(event_type, frozenset())
dependencies = dict(self._handler_dependencies.get(event_type, {}))
all_handlers = list(sync_handlers | async_handlers)
if not all_handlers:
return
cached_plan = build_execution_plan(all_handlers, dependencies)
self._execution_plan_cache[event_type] = cached_plan
else:
sync_handlers = self._sync_handlers.get(event_type, frozenset())
async_handlers = self._async_handlers.get(event_type, frozenset())
for level in cached_plan:
level_sync = frozenset(h for h in level if h in sync_handlers)
level_async = frozenset(h for h in level if h in async_handlers)
if level_sync:
if event_type is LLMStreamChunkEvent:
self._call_handlers(source, event, level_sync)
else:
future = self._sync_executor.submit(
self._call_handlers, source, event, level_sync
)
await asyncio.get_running_loop().run_in_executor(
None, future.result
)
if level_async:
await self._acall_handlers(source, event, level_async)
def emit(self, source: Any, event: BaseEvent) -> Future[None] | None:
"""Emit an event to all registered handlers.
If handlers have dependencies (registered with depends_on), they execute
in dependency order. Otherwise, handlers execute as before (sync in thread
pool, async fire-and-forget).
Stream chunk events always execute synchronously to preserve ordering.
Args:
source: The emitting object
event: The event instance to emit
Returns:
Future that completes when handlers finish. Returns:
- Future for sync-only handlers (ThreadPoolExecutor future)
- Future for async handlers or mixed handlers (asyncio future)
- Future for dependency-managed handlers (asyncio future)
- None if no handlers or sync stream chunk events
Example:
>>> future = crewai_event_bus.emit(source, event)
>>> if future:
... await asyncio.wrap_future(future) # In async test
... # or future.result(timeout=5.0) in sync code
"""
event_type = type(event)
with self._rwlock.r_locked():
if self._shutting_down:
self._console.print(
"[CrewAIEventsBus] Warning: Attempted to emit event during shutdown. Ignoring."
)
return None
has_dependencies = event_type in self._handler_dependencies
sync_handlers = self._sync_handlers.get(event_type, frozenset())
async_handlers = self._async_handlers.get(event_type, frozenset())
if has_dependencies:
return asyncio.run_coroutine_threadsafe(
self._emit_with_dependencies(source, event),
self._loop,
) )
def emit(self, source: Any, event: BaseEvent) -> None: if sync_handlers:
""" if event_type is LLMStreamChunkEvent:
Emit an event to all registered handlers self._call_handlers(source, event, sync_handlers)
else:
sync_future = self._sync_executor.submit(
self._call_handlers, source, event, sync_handlers
)
if not async_handlers:
return sync_future
if async_handlers:
return asyncio.run_coroutine_threadsafe(
self._acall_handlers(source, event, async_handlers),
self._loop,
)
return None
async def aemit(self, source: Any, event: BaseEvent) -> None:
"""Asynchronously emit an event to registered async handlers.
Only processes async handlers. Use in async contexts.
Args: Args:
source: The object emitting the event source: The object emitting the event
event: The event instance to emit event: The event instance to emit
""" """
for event_type, handlers in self._handlers.items(): event_type = type(event)
if isinstance(event, event_type):
for handler in handlers:
self._call_handler(handler, source, event, event_type)
self._signal.send(source, event=event) with self._rwlock.r_locked():
if self._shutting_down:
self._console.print(
"[CrewAIEventsBus] Warning: Attempted to emit event during shutdown. Ignoring."
)
return
async_handlers = self._async_handlers.get(event_type, frozenset())
if async_handlers:
await self._acall_handlers(source, event, async_handlers)
def register_handler( def register_handler(
self, event_type: type[EventTypes], handler: Callable[[Any, EventTypes], None] self,
event_type: type[BaseEvent],
handler: SyncHandler | AsyncHandler,
) -> None: ) -> None:
"""Register an event handler for a specific event type""" """Register an event handler for a specific event type.
if event_type not in self._handlers:
self._handlers[event_type] = [] Args:
self._handlers[event_type].append( event_type: The event class to listen for
cast(Callable[[Any, EventTypes], None], handler) handler: The handler function to register
) """
self._register_handler(event_type, handler)
def validate_dependencies(self) -> None:
"""Validate all registered handler dependencies.
Attempts to build execution plans for all event types with dependencies.
This detects circular dependencies and cross-event-type dependencies
before events are emitted.
Raises:
CircularDependencyError: If circular dependencies or unresolved
dependencies (e.g., cross-event-type) are detected
"""
with self._rwlock.r_locked():
for event_type in self._handler_dependencies:
sync_handlers = self._sync_handlers.get(event_type, frozenset())
async_handlers = self._async_handlers.get(event_type, frozenset())
dependencies = dict(self._handler_dependencies.get(event_type, {}))
all_handlers = list(sync_handlers | async_handlers)
if all_handlers and dependencies:
build_execution_plan(all_handlers, dependencies)
@contextmanager @contextmanager
def scoped_handlers(self): def scoped_handlers(self) -> Generator[None, Any, None]:
""" """Context manager for temporary event handling scope.
Context manager for temporary event handling scope.
Useful for testing or temporary event handling.
Usage: Useful for testing or temporary event handling. All handlers registered
with crewai_event_bus.scoped_handlers(): within this context are cleared when the context exits.
@crewai_event_bus.on(CrewKickoffStarted)
def temp_handler(source, event): Example:
print("Temporary handler") >>> from crewai.events.event_bus import crewai_event_bus
# Do stuff... >>> from crewai.events.event_types import CrewKickoffStartedEvent
# Handlers are cleared after the context >>> with crewai_event_bus.scoped_handlers():
...
... @crewai_event_bus.on(CrewKickoffStartedEvent)
... def temp_handler(source, event):
... print("Temporary handler")
...
... # Do stuff...
... # Handlers are cleared after the context
""" """
previous_handlers = self._handlers.copy() with self._rwlock.w_locked():
self._handlers.clear() prev_sync = self._sync_handlers
prev_async = self._async_handlers
prev_deps = self._handler_dependencies
prev_cache = self._execution_plan_cache
self._sync_handlers = {}
self._async_handlers = {}
self._handler_dependencies = {}
self._execution_plan_cache = {}
try: try:
yield yield
finally: finally:
self._handlers = previous_handlers with self._rwlock.w_locked():
self._sync_handlers = prev_sync
self._async_handlers = prev_async
self._handler_dependencies = prev_deps
self._execution_plan_cache = prev_cache
def shutdown(self, wait: bool = True) -> None:
"""Gracefully shutdown the event loop and wait for all tasks to finish.
Args:
wait: If True, wait for all pending tasks to complete before stopping.
If False, cancel all pending tasks immediately.
"""
with self._rwlock.w_locked():
self._shutting_down = True
loop = getattr(self, "_loop", None)
if loop is None or loop.is_closed():
return
if wait:
async def _wait_for_all_tasks() -> None:
tasks = {
t
for t in asyncio.all_tasks(loop)
if t is not asyncio.current_task()
}
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
future = asyncio.run_coroutine_threadsafe(_wait_for_all_tasks(), loop)
try:
future.result()
except Exception as e:
self._console.print(f"[CrewAIEventsBus] Error waiting for tasks: {e}")
else:
def _cancel_tasks() -> None:
for task in asyncio.all_tasks(loop):
if task is not asyncio.current_task():
task.cancel()
loop.call_soon_threadsafe(_cancel_tasks)
loop.call_soon_threadsafe(loop.stop)
self._loop_thread.join()
loop.close()
self._sync_executor.shutdown(wait=wait)
with self._rwlock.w_locked():
self._sync_handlers.clear()
self._async_handlers.clear()
self._execution_plan_cache.clear()
# Global instance crewai_event_bus: Final[CrewAIEventsBus] = CrewAIEventsBus()
crewai_event_bus = CrewAIEventsBus()
atexit.register(crewai_event_bus.shutdown)

View File

@@ -0,0 +1,130 @@
"""Dependency graph resolution for event handlers.
This module resolves handler dependencies into execution levels, ensuring
handlers execute in correct order while maximizing parallelism.
"""
from collections import defaultdict, deque
from collections.abc import Sequence
from crewai.events.depends import Depends
from crewai.events.types.event_bus_types import ExecutionPlan, Handler
class CircularDependencyError(Exception):
"""Exception raised when circular dependencies are detected in event handlers.
Attributes:
handlers: The handlers involved in the circular dependency
"""
def __init__(self, handlers: list[Handler]) -> None:
"""Initialize the circular dependency error.
Args:
handlers: The handlers involved in the circular dependency
"""
handler_names = ", ".join(
getattr(h, "__name__", repr(h)) for h in handlers[:5]
)
message = f"Circular dependency detected in event handlers: {handler_names}"
super().__init__(message)
self.handlers = handlers
class HandlerGraph:
"""Resolves handler dependencies into parallel execution levels.
Handlers are organized into levels where:
- Level 0: Handlers with no dependencies (can run first)
- Level N: Handlers that depend on handlers in levels 0...N-1
Handlers within the same level can execute in parallel.
Attributes:
levels: List of handler sets, where each level can execute in parallel
"""
def __init__(
self,
handlers: dict[Handler, list[Depends]],
) -> None:
"""Initialize the dependency graph.
Args:
handlers: Mapping of handler -> list of `crewai.events.depends.Depends` objects
"""
self.handlers = handlers
self.levels: ExecutionPlan = []
self._resolve()
def _resolve(self) -> None:
"""Resolve dependencies into execution levels using topological sort."""
dependents: dict[Handler, set[Handler]] = defaultdict(set)
in_degree: dict[Handler, int] = {}
for handler in self.handlers:
in_degree[handler] = 0
for handler, deps in self.handlers.items():
in_degree[handler] = len(deps)
for dep in deps:
dependents[dep.handler].add(handler)
queue: deque[Handler] = deque(
[h for h, deg in in_degree.items() if deg == 0]
)
while queue:
current_level: set[Handler] = set()
for _ in range(len(queue)):
handler = queue.popleft()
current_level.add(handler)
for dependent in dependents[handler]:
in_degree[dependent] -= 1
if in_degree[dependent] == 0:
queue.append(dependent)
if current_level:
self.levels.append(current_level)
remaining = [h for h, deg in in_degree.items() if deg > 0]
if remaining:
raise CircularDependencyError(remaining)
def get_execution_plan(self) -> ExecutionPlan:
"""Get the ordered execution plan.
Returns:
List of handler sets, where each set represents handlers that can
execute in parallel. Sets are ordered such that dependencies are
satisfied.
"""
return self.levels
def build_execution_plan(
handlers: Sequence[Handler],
dependencies: dict[Handler, list[Depends]],
) -> ExecutionPlan:
"""Build an execution plan from handlers and their dependencies.
Args:
handlers: All handlers for an event type
dependencies: Mapping of handler -> list of dependencies
Returns:
Execution plan as list of levels, where each level is a set of
handlers that can execute in parallel
Raises:
CircularDependencyError: If circular dependencies are detected
"""
handler_dict: dict[Handler, list[Depends]] = {
h: dependencies.get(h, []) for h in handlers
}
graph = HandlerGraph(handler_dict)
return graph.get_execution_plan()

View File

@@ -1,8 +1,9 @@
import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from logging import getLogger from logging import getLogger
from threading import Condition, Lock
from typing import Any from typing import Any
import uuid
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
@@ -14,6 +15,7 @@ from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.listeners.tracing.utils import should_auto_collect_first_time_traces from crewai.events.listeners.tracing.utils import should_auto_collect_first_time_traces
from crewai.utilities.constants import CREWAI_BASE_URL from crewai.utilities.constants import CREWAI_BASE_URL
logger = getLogger(__name__) logger = getLogger(__name__)
@@ -41,6 +43,11 @@ class TraceBatchManager:
"""Single responsibility: Manage batches and event buffering""" """Single responsibility: Manage batches and event buffering"""
def __init__(self): def __init__(self):
self._init_lock = Lock()
self._pending_events_lock = Lock()
self._pending_events_cv = Condition(self._pending_events_lock)
self._pending_events_count = 0
self.is_current_batch_ephemeral: bool = False self.is_current_batch_ephemeral: bool = False
self.trace_batch_id: str | None = None self.trace_batch_id: str | None = None
self.current_batch: TraceBatch | None = None self.current_batch: TraceBatch | None = None
@@ -64,24 +71,28 @@ class TraceBatchManager:
execution_metadata: dict[str, Any], execution_metadata: dict[str, Any],
use_ephemeral: bool = False, use_ephemeral: bool = False,
) -> TraceBatch: ) -> TraceBatch:
"""Initialize a new trace batch""" """Initialize a new trace batch (thread-safe)"""
self.current_batch = TraceBatch( with self._init_lock:
user_context=user_context, execution_metadata=execution_metadata if self.current_batch is not None:
) logger.debug("Batch already initialized, skipping duplicate initialization")
self.event_buffer.clear() return self.current_batch
self.is_current_batch_ephemeral = use_ephemeral
self.record_start_time("execution") self.current_batch = TraceBatch(
user_context=user_context, execution_metadata=execution_metadata
if should_auto_collect_first_time_traces():
self.trace_batch_id = self.current_batch.batch_id
else:
self._initialize_backend_batch(
user_context, execution_metadata, use_ephemeral
) )
self.backend_initialized = True self.is_current_batch_ephemeral = use_ephemeral
return self.current_batch self.record_start_time("execution")
if should_auto_collect_first_time_traces():
self.trace_batch_id = self.current_batch.batch_id
else:
self._initialize_backend_batch(
user_context, execution_metadata, use_ephemeral
)
self.backend_initialized = True
return self.current_batch
def _initialize_backend_batch( def _initialize_backend_batch(
self, self,
@@ -148,6 +159,38 @@ class TraceBatchManager:
f"Error initializing trace batch: {e}. Continuing without tracing." f"Error initializing trace batch: {e}. Continuing without tracing."
) )
def begin_event_processing(self):
"""Mark that an event handler started processing (for synchronization)"""
with self._pending_events_lock:
self._pending_events_count += 1
def end_event_processing(self):
"""Mark that an event handler finished processing (for synchronization)"""
with self._pending_events_cv:
self._pending_events_count -= 1
if self._pending_events_count == 0:
self._pending_events_cv.notify_all()
def wait_for_pending_events(self, timeout: float = 2.0) -> bool:
"""Wait for all pending event handlers to finish processing
Args:
timeout: Maximum time to wait in seconds (default: 2.0)
Returns:
True if all handlers completed, False if timeout occurred
"""
with self._pending_events_cv:
if self._pending_events_count > 0:
logger.debug(f"Waiting for {self._pending_events_count} pending event handlers...")
self._pending_events_cv.wait(timeout)
if self._pending_events_count > 0:
logger.error(
f"Timeout waiting for event handlers. {self._pending_events_count} still pending. Events may be incomplete!"
)
return False
return True
def add_event(self, trace_event: TraceEvent): def add_event(self, trace_event: TraceEvent):
"""Add event to buffer""" """Add event to buffer"""
self.event_buffer.append(trace_event) self.event_buffer.append(trace_event)
@@ -180,8 +223,8 @@ class TraceBatchManager:
self.event_buffer.clear() self.event_buffer.clear()
return 200 return 200
logger.warning( logger.error(
f"Failed to send events: {response.status_code}. Events will be lost." f"Failed to send events: {response.status_code}. Response: {response.text}. Events will be lost."
) )
return 500 return 500
@@ -196,15 +239,33 @@ class TraceBatchManager:
if not self.current_batch: if not self.current_batch:
return None return None
self.current_batch.events = self.event_buffer.copy() all_handlers_completed = self.wait_for_pending_events(timeout=2.0)
if self.event_buffer:
if not all_handlers_completed:
logger.error("Event handler timeout - marking batch as failed due to incomplete events")
self.plus_api.mark_trace_batch_as_failed(
self.trace_batch_id, "Timeout waiting for event handlers - events incomplete"
)
return None
sorted_events = sorted(
self.event_buffer,
key=lambda e: e.timestamp if hasattr(e, 'timestamp') and e.timestamp else ''
)
self.current_batch.events = sorted_events
events_sent_count = len(sorted_events)
if sorted_events:
original_buffer = self.event_buffer
self.event_buffer = sorted_events
events_sent_to_backend_status = self._send_events_to_backend() events_sent_to_backend_status = self._send_events_to_backend()
self.event_buffer = original_buffer
if events_sent_to_backend_status == 500: if events_sent_to_backend_status == 500:
self.plus_api.mark_trace_batch_as_failed( self.plus_api.mark_trace_batch_as_failed(
self.trace_batch_id, "Error sending events to backend" self.trace_batch_id, "Error sending events to backend"
) )
return None return None
self._finalize_backend_batch() self._finalize_backend_batch(events_sent_count)
finalized_batch = self.current_batch finalized_batch = self.current_batch
@@ -220,18 +281,20 @@ class TraceBatchManager:
return finalized_batch return finalized_batch
def _finalize_backend_batch(self): def _finalize_backend_batch(self, events_count: int = 0):
"""Send batch finalization to backend""" """Send batch finalization to backend
Args:
events_count: Number of events that were successfully sent
"""
if not self.plus_api or not self.trace_batch_id: if not self.plus_api or not self.trace_batch_id:
return return
try: try:
total_events = len(self.current_batch.events) if self.current_batch else 0
payload = { payload = {
"status": "completed", "status": "completed",
"duration_ms": self.calculate_duration("execution"), "duration_ms": self.calculate_duration("execution"),
"final_event_count": total_events, "final_event_count": events_count,
} }
response = ( response = (

View File

@@ -170,14 +170,6 @@ class TraceCollectionListener(BaseEventListener):
def on_flow_finished(source, event): def on_flow_finished(source, event):
self._handle_trace_event("flow_finished", source, event) self._handle_trace_event("flow_finished", source, event)
if self.batch_manager.batch_owner_type == "flow":
if self.first_time_handler.is_first_time:
self.first_time_handler.mark_events_collected()
self.first_time_handler.handle_execution_completion()
else:
# Normal flow finalization
self.batch_manager.finalize_batch()
@event_bus.on(FlowPlotEvent) @event_bus.on(FlowPlotEvent)
def on_flow_plot(source, event): def on_flow_plot(source, event):
self._handle_action_event("flow_plot", source, event) self._handle_action_event("flow_plot", source, event)
@@ -383,10 +375,12 @@ class TraceCollectionListener(BaseEventListener):
def _handle_trace_event(self, event_type: str, source: Any, event: Any): def _handle_trace_event(self, event_type: str, source: Any, event: Any):
"""Generic handler for context end events""" """Generic handler for context end events"""
self.batch_manager.begin_event_processing()
trace_event = self._create_trace_event(event_type, source, event) try:
trace_event = self._create_trace_event(event_type, source, event)
self.batch_manager.add_event(trace_event) self.batch_manager.add_event(trace_event)
finally:
self.batch_manager.end_event_processing()
def _handle_action_event(self, event_type: str, source: Any, event: Any): def _handle_action_event(self, event_type: str, source: Any, event: Any):
"""Generic handler for action events (LLM calls, tool usage)""" """Generic handler for action events (LLM calls, tool usage)"""
@@ -399,18 +393,29 @@ class TraceCollectionListener(BaseEventListener):
} }
self.batch_manager.initialize_batch(user_context, execution_metadata) self.batch_manager.initialize_batch(user_context, execution_metadata)
trace_event = self._create_trace_event(event_type, source, event) self.batch_manager.begin_event_processing()
self.batch_manager.add_event(trace_event) try:
trace_event = self._create_trace_event(event_type, source, event)
self.batch_manager.add_event(trace_event)
finally:
self.batch_manager.end_event_processing()
def _create_trace_event( def _create_trace_event(
self, event_type: str, source: Any, event: Any self, event_type: str, source: Any, event: Any
) -> TraceEvent: ) -> TraceEvent:
"""Create a trace event""" """Create a trace event"""
trace_event = TraceEvent( if hasattr(event, 'timestamp') and event.timestamp:
type=event_type, trace_event = TraceEvent(
) type=event_type,
timestamp=event.timestamp.isoformat(),
)
else:
trace_event = TraceEvent(
type=event_type,
)
trace_event.event_data = self._build_event_data(event_type, event, source) trace_event.event_data = self._build_event_data(event_type, event, source)
return trace_event return trace_event
def _build_event_data( def _build_event_data(

View File

@@ -0,0 +1,14 @@
"""Type definitions for event handlers."""
from collections.abc import Callable, Coroutine
from typing import Any, TypeAlias
from crewai.events.base_events import BaseEvent
SyncHandler: TypeAlias = Callable[[Any, BaseEvent], None]
AsyncHandler: TypeAlias = Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
SyncHandlerSet: TypeAlias = frozenset[SyncHandler]
AsyncHandlerSet: TypeAlias = frozenset[AsyncHandler]
Handler: TypeAlias = Callable[[Any, BaseEvent], Any]
ExecutionPlan: TypeAlias = list[set[Handler]]

View File

@@ -0,0 +1,59 @@
"""Handler utility functions for event processing."""
import functools
import inspect
from typing import Any
from typing_extensions import TypeIs
from crewai.events.base_events import BaseEvent
from crewai.events.types.event_bus_types import AsyncHandler, SyncHandler
def is_async_handler(
handler: Any,
) -> TypeIs[AsyncHandler]:
"""Type guard to check if handler is an async handler.
Args:
handler: The handler to check
Returns:
True if handler is an async coroutine function
"""
try:
if inspect.iscoroutinefunction(handler) or (
callable(handler) and inspect.iscoroutinefunction(handler.__call__)
):
return True
except AttributeError:
return False
if isinstance(handler, functools.partial) and inspect.iscoroutinefunction(
handler.func
):
return True
return False
def is_call_handler_safe(
handler: SyncHandler,
source: Any,
event: BaseEvent,
) -> Exception | None:
"""Safely call a single handler and return any exception.
Args:
handler: The handler function to call
source: The object that emitted the event
event: The event instance
Returns:
Exception if handler raised one, None otherwise
"""
try:
handler(source, event)
return None
except Exception as e:
return e

View File

@@ -0,0 +1,81 @@
"""Read-write lock for thread-safe concurrent access.
This module provides a reader-writer lock implementation that allows multiple
concurrent readers or a single exclusive writer.
"""
from collections.abc import Generator
from contextlib import contextmanager
from threading import Condition
class RWLock:
"""Read-write lock for managing concurrent read and exclusive write access.
Allows multiple threads to acquire read locks simultaneously, but ensures
exclusive access for write operations. Writers are prioritized when waiting.
Attributes:
_cond: Condition variable for coordinating lock access
_readers: Count of active readers
_writer: Whether a writer currently holds the lock
"""
def __init__(self) -> None:
"""Initialize the read-write lock."""
self._cond = Condition()
self._readers = 0
self._writer = False
def r_acquire(self) -> None:
"""Acquire a read lock, blocking if a writer holds the lock."""
with self._cond:
while self._writer:
self._cond.wait()
self._readers += 1
def r_release(self) -> None:
"""Release a read lock and notify waiting writers if last reader."""
with self._cond:
self._readers -= 1
if self._readers == 0:
self._cond.notify_all()
@contextmanager
def r_locked(self) -> Generator[None, None, None]:
"""Context manager for acquiring a read lock.
Yields:
None
"""
try:
self.r_acquire()
yield
finally:
self.r_release()
def w_acquire(self) -> None:
"""Acquire a write lock, blocking if any readers or writers are active."""
with self._cond:
while self._writer or self._readers > 0:
self._cond.wait()
self._writer = True
def w_release(self) -> None:
"""Release a write lock and notify all waiting threads."""
with self._cond:
self._writer = False
self._cond.notify_all()
@contextmanager
def w_locked(self) -> Generator[None, None, None]:
"""Context manager for acquiring a write lock.
Yields:
None
"""
try:
self.w_acquire()
yield
finally:
self.w_release()

View File

@@ -52,19 +52,14 @@ class AgentEvaluator:
self.console_formatter = ConsoleFormatter() self.console_formatter = ConsoleFormatter()
self.display_formatter = EvaluationDisplayFormatter() self.display_formatter = EvaluationDisplayFormatter()
self._thread_local: threading.local = threading.local() self._execution_state = ExecutionState()
self._state_lock = threading.Lock()
for agent in self.agents: for agent in self.agents:
self._execution_state.agent_evaluators[str(agent.id)] = self.evaluators self._execution_state.agent_evaluators[str(agent.id)] = self.evaluators
self._subscribe_to_events() self._subscribe_to_events()
@property
def _execution_state(self) -> ExecutionState:
if not hasattr(self._thread_local, "execution_state"):
self._thread_local.execution_state = ExecutionState()
return self._thread_local.execution_state
def _subscribe_to_events(self) -> None: def _subscribe_to_events(self) -> None:
from typing import cast from typing import cast
@@ -112,21 +107,22 @@ class AgentEvaluator:
state=state, state=state,
) )
current_iteration = self._execution_state.iteration with self._state_lock:
if current_iteration not in self._execution_state.iterations_results: current_iteration = self._execution_state.iteration
self._execution_state.iterations_results[current_iteration] = {} if current_iteration not in self._execution_state.iterations_results:
self._execution_state.iterations_results[current_iteration] = {}
if (
agent.role
not in self._execution_state.iterations_results[current_iteration]
):
self._execution_state.iterations_results[current_iteration][
agent.role
] = []
if (
agent.role
not in self._execution_state.iterations_results[current_iteration]
):
self._execution_state.iterations_results[current_iteration][ self._execution_state.iterations_results[current_iteration][
agent.role agent.role
] = [] ].append(result)
self._execution_state.iterations_results[current_iteration][
agent.role
].append(result)
def _handle_lite_agent_completed( def _handle_lite_agent_completed(
self, source: object, event: LiteAgentExecutionCompletedEvent self, source: object, event: LiteAgentExecutionCompletedEvent
@@ -164,22 +160,23 @@ class AgentEvaluator:
state=state, state=state,
) )
current_iteration = self._execution_state.iteration with self._state_lock:
if current_iteration not in self._execution_state.iterations_results: current_iteration = self._execution_state.iteration
self._execution_state.iterations_results[current_iteration] = {} if current_iteration not in self._execution_state.iterations_results:
self._execution_state.iterations_results[current_iteration] = {}
agent_role = target_agent.role
if (
agent_role
not in self._execution_state.iterations_results[current_iteration]
):
self._execution_state.iterations_results[current_iteration][
agent_role
] = []
agent_role = target_agent.role
if (
agent_role
not in self._execution_state.iterations_results[current_iteration]
):
self._execution_state.iterations_results[current_iteration][ self._execution_state.iterations_results[current_iteration][
agent_role agent_role
] = [] ].append(result)
self._execution_state.iterations_results[current_iteration][
agent_role
].append(result)
def set_iteration(self, iteration: int) -> None: def set_iteration(self, iteration: int) -> None:
self._execution_state.iteration = iteration self._execution_state.iteration = iteration

View File

@@ -3,6 +3,7 @@ import copy
import inspect import inspect
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import Future
from typing import Any, ClassVar, Generic, TypeVar, cast from typing import Any, ClassVar, Generic, TypeVar, cast
from uuid import uuid4 from uuid import uuid4
@@ -463,6 +464,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._completed_methods: set[str] = set() # Track completed methods for reload self._completed_methods: set[str] = set() # Track completed methods for reload
self._persistence: FlowPersistence | None = persistence self._persistence: FlowPersistence | None = persistence
self._is_execution_resuming: bool = False self._is_execution_resuming: bool = False
self._event_futures: list[Future[None]] = []
# Initialize state with initial values # Initialize state with initial values
self._state = self._create_initial_state() self._state = self._create_initial_state()
@@ -855,7 +857,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._initialize_state(filtered_inputs) self._initialize_state(filtered_inputs)
# Emit FlowStartedEvent and log the start of the flow. # Emit FlowStartedEvent and log the start of the flow.
crewai_event_bus.emit( future = crewai_event_bus.emit(
self, self,
FlowStartedEvent( FlowStartedEvent(
type="flow_started", type="flow_started",
@@ -863,6 +865,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
inputs=inputs, inputs=inputs,
), ),
) )
if future:
self._event_futures.append(future)
self._log_flow_event( self._log_flow_event(
f"Flow started with ID: {self.flow_id}", color="bold_magenta" f"Flow started with ID: {self.flow_id}", color="bold_magenta"
) )
@@ -881,7 +885,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
final_output = self._method_outputs[-1] if self._method_outputs else None final_output = self._method_outputs[-1] if self._method_outputs else None
crewai_event_bus.emit( future = crewai_event_bus.emit(
self, self,
FlowFinishedEvent( FlowFinishedEvent(
type="flow_finished", type="flow_finished",
@@ -889,6 +893,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
result=final_output, result=final_output,
), ),
) )
if future:
self._event_futures.append(future)
if self._event_futures:
await asyncio.gather(*[asyncio.wrap_future(f) for f in self._event_futures])
self._event_futures.clear()
if (
is_tracing_enabled()
or self.tracing
or should_auto_collect_first_time_traces()
):
trace_listener = TraceCollectionListener()
if trace_listener.batch_manager.batch_owner_type == "flow":
if trace_listener.first_time_handler.is_first_time:
trace_listener.first_time_handler.mark_events_collected()
trace_listener.first_time_handler.handle_execution_completion()
else:
trace_listener.batch_manager.finalize_batch()
return final_output return final_output
finally: finally:
@@ -971,7 +994,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | ( dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
kwargs or {} kwargs or {}
) )
crewai_event_bus.emit( future = crewai_event_bus.emit(
self, self,
MethodExecutionStartedEvent( MethodExecutionStartedEvent(
type="method_execution_started", type="method_execution_started",
@@ -981,6 +1004,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
state=self._copy_state(), state=self._copy_state(),
), ),
) )
if future:
self._event_futures.append(future)
result = ( result = (
await method(*args, **kwargs) await method(*args, **kwargs)
@@ -994,7 +1019,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
) )
self._completed_methods.add(method_name) self._completed_methods.add(method_name)
crewai_event_bus.emit( future = crewai_event_bus.emit(
self, self,
MethodExecutionFinishedEvent( MethodExecutionFinishedEvent(
type="method_execution_finished", type="method_execution_finished",
@@ -1004,10 +1029,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
result=result, result=result,
), ),
) )
if future:
self._event_futures.append(future)
return result return result
except Exception as e: except Exception as e:
crewai_event_bus.emit( future = crewai_event_bus.emit(
self, self,
MethodExecutionFailedEvent( MethodExecutionFailedEvent(
type="method_execution_failed", type="method_execution_failed",
@@ -1016,6 +1043,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
error=e, error=e,
), ),
) )
if future:
self._event_futures.append(future)
raise e raise e
async def _execute_listeners(self, trigger_method: str, result: Any) -> None: async def _execute_listeners(self, trigger_method: str, result: Any) -> None:

View File

@@ -1,6 +1,7 @@
"""Test Agent creation and execution basic functionality.""" """Test Agent creation and execution basic functionality."""
import os import os
import threading
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@@ -185,14 +186,17 @@ def test_agent_execution_with_tools():
expected_output="The result of the multiplication.", expected_output="The result of the multiplication.",
) )
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageFinishedEvent) @crewai_event_bus.on(ToolUsageFinishedEvent)
def handle_tool_end(source, event): def handle_tool_end(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
output = agent.execute_task(task) output = agent.execute_task(task)
assert output == "The result of the multiplication is 12." assert output == "The result of the multiplication is 12."
assert event_received.wait(timeout=5), "Timeout waiting for tool usage event"
assert len(received_events) == 1 assert len(received_events) == 1
assert isinstance(received_events[0], ToolUsageFinishedEvent) assert isinstance(received_events[0], ToolUsageFinishedEvent)
assert received_events[0].tool_name == "multiplier" assert received_events[0].tool_name == "multiplier"
@@ -284,10 +288,12 @@ def test_cache_hitting():
'multiplier-{"first_number": 12, "second_number": 3}': 36, 'multiplier-{"first_number": 12, "second_number": 3}': 36,
} }
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageFinishedEvent) @crewai_event_bus.on(ToolUsageFinishedEvent)
def handle_tool_end(source, event): def handle_tool_end(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
with ( with (
patch.object(CacheHandler, "read") as read, patch.object(CacheHandler, "read") as read,
@@ -303,6 +309,7 @@ def test_cache_hitting():
read.assert_called_with( read.assert_called_with(
tool="multiplier", input='{"first_number": 2, "second_number": 6}' tool="multiplier", input='{"first_number": 2, "second_number": 6}'
) )
assert event_received.wait(timeout=5), "Timeout waiting for tool usage event"
assert len(received_events) == 1 assert len(received_events) == 1
assert isinstance(received_events[0], ToolUsageFinishedEvent) assert isinstance(received_events[0], ToolUsageFinishedEvent)
assert received_events[0].from_cache assert received_events[0].from_cache

View File

@@ -1,4 +1,5 @@
# mypy: ignore-errors # mypy: ignore-errors
import threading
from collections import defaultdict from collections import defaultdict
from typing import cast from typing import cast
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@@ -156,14 +157,17 @@ def test_lite_agent_with_tools():
) )
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageStartedEvent) @crewai_event_bus.on(ToolUsageStartedEvent)
def event_handler(source, event): def event_handler(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
agent.kickoff("What are the effects of climate change on coral reefs?") agent.kickoff("What are the effects of climate change on coral reefs?")
# Verify tool usage events were emitted # Verify tool usage events were emitted
assert event_received.wait(timeout=5), "Timeout waiting for tool usage events"
assert len(received_events) > 0, "Tool usage events should be emitted" assert len(received_events) > 0, "Tool usage events should be emitted"
event = received_events[0] event = received_events[0]
assert isinstance(event, ToolUsageStartedEvent) assert isinstance(event, ToolUsageStartedEvent)
@@ -316,15 +320,18 @@ def test_sets_parent_flow_when_inside_flow():
return agent.kickoff("Test query") return agent.kickoff("Test query")
flow = MyFlow() flow = MyFlow()
with crewai_event_bus.scoped_handlers(): event_received = threading.Event()
@crewai_event_bus.on(LiteAgentExecutionStartedEvent) @crewai_event_bus.on(LiteAgentExecutionStartedEvent)
def capture_agent(source, event): def capture_agent(source, event):
nonlocal captured_agent nonlocal captured_agent
captured_agent = source captured_agent = source
event_received.set()
flow.kickoff() flow.kickoff()
assert captured_agent.parent_flow is flow
assert event_received.wait(timeout=5), "Timeout waiting for agent execution event"
assert captured_agent.parent_flow is flow
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -342,30 +349,43 @@ def test_guardrail_is_called_using_string():
guardrail="""Only include Brazilian players, both women and men""", guardrail="""Only include Brazilian players, both women and men""",
) )
with crewai_event_bus.scoped_handlers(): all_events_received = threading.Event()
@crewai_event_bus.on(LLMGuardrailStartedEvent) @crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event): def capture_guardrail_started(source, event):
assert isinstance(source, LiteAgent) assert isinstance(source, LiteAgent)
assert source.original_agent == agent assert source.original_agent == agent
guardrail_events["started"].append(event) guardrail_events["started"].append(event)
if (
len(guardrail_events["started"]) == 2
and len(guardrail_events["completed"]) == 2
):
all_events_received.set()
@crewai_event_bus.on(LLMGuardrailCompletedEvent) @crewai_event_bus.on(LLMGuardrailCompletedEvent)
def capture_guardrail_completed(source, event): def capture_guardrail_completed(source, event):
assert isinstance(source, LiteAgent) assert isinstance(source, LiteAgent)
assert source.original_agent == agent assert source.original_agent == agent
guardrail_events["completed"].append(event) guardrail_events["completed"].append(event)
if (
len(guardrail_events["started"]) == 2
and len(guardrail_events["completed"]) == 2
):
all_events_received.set()
result = agent.kickoff(messages="Top 10 best players in the world?") result = agent.kickoff(messages="Top 10 best players in the world?")
assert len(guardrail_events["started"]) == 2 assert all_events_received.wait(timeout=10), (
assert len(guardrail_events["completed"]) == 2 "Timeout waiting for all guardrail events"
assert not guardrail_events["completed"][0].success )
assert guardrail_events["completed"][1].success assert len(guardrail_events["started"]) == 2
assert ( assert len(guardrail_events["completed"]) == 2
"Here are the top 10 best soccer players in the world, focusing exclusively on Brazilian players" assert not guardrail_events["completed"][0].success
in result.raw assert guardrail_events["completed"][1].success
) assert (
"Here are the top 10 best soccer players in the world, focusing exclusively on Brazilian players"
in result.raw
)
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -376,29 +396,42 @@ def test_guardrail_is_called_using_callable():
LLMGuardrailStartedEvent, LLMGuardrailStartedEvent,
) )
with crewai_event_bus.scoped_handlers(): all_events_received = threading.Event()
@crewai_event_bus.on(LLMGuardrailStartedEvent) @crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event): def capture_guardrail_started(source, event):
guardrail_events["started"].append(event) guardrail_events["started"].append(event)
if (
len(guardrail_events["started"]) == 1
and len(guardrail_events["completed"]) == 1
):
all_events_received.set()
@crewai_event_bus.on(LLMGuardrailCompletedEvent) @crewai_event_bus.on(LLMGuardrailCompletedEvent)
def capture_guardrail_completed(source, event): def capture_guardrail_completed(source, event):
guardrail_events["completed"].append(event) guardrail_events["completed"].append(event)
if (
len(guardrail_events["started"]) == 1
and len(guardrail_events["completed"]) == 1
):
all_events_received.set()
agent = Agent( agent = Agent(
role="Sports Analyst", role="Sports Analyst",
goal="Gather information about the best soccer players", goal="Gather information about the best soccer players",
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""", backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
guardrail=lambda output: (True, "Pelé - Santos, 1958"), guardrail=lambda output: (True, "Pelé - Santos, 1958"),
) )
result = agent.kickoff(messages="Top 1 best players in the world?") result = agent.kickoff(messages="Top 1 best players in the world?")
assert len(guardrail_events["started"]) == 1 assert all_events_received.wait(timeout=10), (
assert len(guardrail_events["completed"]) == 1 "Timeout waiting for all guardrail events"
assert guardrail_events["completed"][0].success )
assert "Pelé - Santos, 1958" in result.raw assert len(guardrail_events["started"]) == 1
assert len(guardrail_events["completed"]) == 1
assert guardrail_events["completed"][0].success
assert "Pelé - Santos, 1958" in result.raw
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -409,37 +442,50 @@ def test_guardrail_reached_attempt_limit():
LLMGuardrailStartedEvent, LLMGuardrailStartedEvent,
) )
with crewai_event_bus.scoped_handlers(): all_events_received = threading.Event()
@crewai_event_bus.on(LLMGuardrailStartedEvent) @crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event): def capture_guardrail_started(source, event):
guardrail_events["started"].append(event) guardrail_events["started"].append(event)
if (
@crewai_event_bus.on(LLMGuardrailCompletedEvent) len(guardrail_events["started"]) == 3
def capture_guardrail_completed(source, event): and len(guardrail_events["completed"]) == 3
guardrail_events["completed"].append(event)
agent = Agent(
role="Sports Analyst",
goal="Gather information about the best soccer players",
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
guardrail=lambda output: (
False,
"You are not allowed to include Brazilian players",
),
guardrail_max_retries=2,
)
with pytest.raises(
Exception, match="Agent's guardrail failed validation after 2 retries"
): ):
agent.kickoff(messages="Top 10 best players in the world?") all_events_received.set()
assert len(guardrail_events["started"]) == 3 # 2 retries + 1 initial call @crewai_event_bus.on(LLMGuardrailCompletedEvent)
assert len(guardrail_events["completed"]) == 3 # 2 retries + 1 initial call def capture_guardrail_completed(source, event):
assert not guardrail_events["completed"][0].success guardrail_events["completed"].append(event)
assert not guardrail_events["completed"][1].success if (
assert not guardrail_events["completed"][2].success len(guardrail_events["started"]) == 3
and len(guardrail_events["completed"]) == 3
):
all_events_received.set()
agent = Agent(
role="Sports Analyst",
goal="Gather information about the best soccer players",
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
guardrail=lambda output: (
False,
"You are not allowed to include Brazilian players",
),
guardrail_max_retries=2,
)
with pytest.raises(
Exception, match="Agent's guardrail failed validation after 2 retries"
):
agent.kickoff(messages="Top 10 best players in the world?")
assert all_events_received.wait(timeout=10), (
"Timeout waiting for all guardrail events"
)
assert len(guardrail_events["started"]) == 3 # 2 retries + 1 initial call
assert len(guardrail_events["completed"]) == 3 # 2 retries + 1 initial call
assert not guardrail_events["completed"][0].success
assert not guardrail_events["completed"][1].success
assert not guardrail_events["completed"][2].success
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])

View File

@@ -33,7 +33,7 @@ def setup_test_environment():
except (OSError, IOError) as e: except (OSError, IOError) as e:
raise RuntimeError( raise RuntimeError(
f"Test storage directory {storage_dir} is not writable: {e}" f"Test storage directory {storage_dir} is not writable: {e}"
) ) from e
os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir) os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
os.environ["CREWAI_TESTING"] = "true" os.environ["CREWAI_TESTING"] = "true"
@@ -159,6 +159,29 @@ def mock_opentelemetry_components():
} }
@pytest.fixture(autouse=True)
def clear_event_bus_handlers():
"""Clear event bus handlers after each test for isolation.
Handlers registered during the test are allowed to run, then cleaned up
after the test completes.
"""
from crewai.events.event_bus import crewai_event_bus
from crewai.experimental.evaluation.evaluation_listener import (
EvaluationTraceCallback,
)
yield
crewai_event_bus.shutdown(wait=True)
crewai_event_bus._initialize()
callback = EvaluationTraceCallback()
callback.traces.clear()
callback.current_agent_id = None
callback.current_task_id = None
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def vcr_config(request) -> dict: def vcr_config(request) -> dict:
import os import os

View File

@@ -0,0 +1,286 @@
"""Tests for FastAPI-style dependency injection in event handlers."""
import asyncio
import pytest
from crewai.events import Depends, crewai_event_bus
from crewai.events.base_events import BaseEvent
class DependsTestEvent(BaseEvent):
"""Test event for dependency tests."""
value: int = 0
type: str = "test_event"
@pytest.mark.asyncio
async def test_basic_dependency():
"""Test that handler with dependency runs after its dependency."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def setup(source, event: DependsTestEvent):
execution_order.append("setup")
@crewai_event_bus.on(DependsTestEvent, Depends(setup))
def process(source, event: DependsTestEvent):
execution_order.append("process")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
assert execution_order == ["setup", "process"]
@pytest.mark.asyncio
async def test_multiple_dependencies():
"""Test handler with multiple dependencies."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def setup_a(source, event: DependsTestEvent):
execution_order.append("setup_a")
@crewai_event_bus.on(DependsTestEvent)
def setup_b(source, event: DependsTestEvent):
execution_order.append("setup_b")
@crewai_event_bus.on(
DependsTestEvent, depends_on=[Depends(setup_a), Depends(setup_b)]
)
def process(source, event: DependsTestEvent):
execution_order.append("process")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
# setup_a and setup_b can run in any order (same level)
assert "process" in execution_order
assert execution_order.index("process") > execution_order.index("setup_a")
assert execution_order.index("process") > execution_order.index("setup_b")
@pytest.mark.asyncio
async def test_chain_of_dependencies():
"""Test chain of dependencies (A -> B -> C)."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def handler_a(source, event: DependsTestEvent):
execution_order.append("handler_a")
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(handler_a))
def handler_b(source, event: DependsTestEvent):
execution_order.append("handler_b")
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(handler_b))
def handler_c(source, event: DependsTestEvent):
execution_order.append("handler_c")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
assert execution_order == ["handler_a", "handler_b", "handler_c"]
@pytest.mark.asyncio
async def test_async_handler_with_dependency():
"""Test async handler with dependency on sync handler."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def sync_setup(source, event: DependsTestEvent):
execution_order.append("sync_setup")
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(sync_setup))
async def async_process(source, event: DependsTestEvent):
await asyncio.sleep(0.01)
execution_order.append("async_process")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
assert execution_order == ["sync_setup", "async_process"]
@pytest.mark.asyncio
async def test_mixed_handlers_with_dependencies():
"""Test mix of sync and async handlers with dependencies."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def setup(source, event: DependsTestEvent):
execution_order.append("setup")
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(setup))
def sync_process(source, event: DependsTestEvent):
execution_order.append("sync_process")
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(setup))
async def async_process(source, event: DependsTestEvent):
await asyncio.sleep(0.01)
execution_order.append("async_process")
@crewai_event_bus.on(
DependsTestEvent, depends_on=[Depends(sync_process), Depends(async_process)]
)
def finalize(source, event: DependsTestEvent):
execution_order.append("finalize")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
# Verify execution order
assert execution_order[0] == "setup"
assert "finalize" in execution_order
assert execution_order.index("finalize") > execution_order.index("sync_process")
assert execution_order.index("finalize") > execution_order.index("async_process")
@pytest.mark.asyncio
async def test_independent_handlers_run_concurrently():
"""Test that handlers without dependencies can run concurrently."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
async def handler_a(source, event: DependsTestEvent):
await asyncio.sleep(0.01)
execution_order.append("handler_a")
@crewai_event_bus.on(DependsTestEvent)
async def handler_b(source, event: DependsTestEvent):
await asyncio.sleep(0.01)
execution_order.append("handler_b")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
# Both handlers should have executed
assert len(execution_order) == 2
assert "handler_a" in execution_order
assert "handler_b" in execution_order
@pytest.mark.asyncio
async def test_circular_dependency_detection():
"""Test that circular dependencies are detected and raise an error."""
from crewai.events.handler_graph import CircularDependencyError, build_execution_plan
# Create circular dependency: handler_a -> handler_b -> handler_c -> handler_a
def handler_a(source, event: DependsTestEvent):
pass
def handler_b(source, event: DependsTestEvent):
pass
def handler_c(source, event: DependsTestEvent):
pass
# Build a dependency graph with a cycle
handlers = [handler_a, handler_b, handler_c]
dependencies = {
handler_a: [Depends(handler_b)],
handler_b: [Depends(handler_c)],
handler_c: [Depends(handler_a)], # Creates the cycle
}
# Should raise CircularDependencyError about circular dependency
with pytest.raises(CircularDependencyError, match="Circular dependency"):
build_execution_plan(handlers, dependencies)
@pytest.mark.asyncio
async def test_handler_without_dependency_runs_normally():
"""Test that handlers without dependencies still work as before."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def simple_handler(source, event: DependsTestEvent):
execution_order.append("simple_handler")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
assert execution_order == ["simple_handler"]
@pytest.mark.asyncio
async def test_depends_equality():
"""Test Depends equality and hashing."""
def handler_a(source, event):
pass
def handler_b(source, event):
pass
dep_a1 = Depends(handler_a)
dep_a2 = Depends(handler_a)
dep_b = Depends(handler_b)
# Same handler should be equal
assert dep_a1 == dep_a2
assert hash(dep_a1) == hash(dep_a2)
# Different handlers should not be equal
assert dep_a1 != dep_b
assert hash(dep_a1) != hash(dep_b)
@pytest.mark.asyncio
async def test_aemit_ignores_dependencies():
"""Test that aemit only processes async handlers (no dependency support yet)."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def sync_handler(source, event: DependsTestEvent):
execution_order.append("sync_handler")
@crewai_event_bus.on(DependsTestEvent)
async def async_handler(source, event: DependsTestEvent):
execution_order.append("async_handler")
event = DependsTestEvent(value=1)
await crewai_event_bus.aemit("test_source", event)
# Only async handler should execute
assert execution_order == ["async_handler"]

View File

@@ -1,3 +1,5 @@
import threading
import pytest import pytest
from crewai.agent import Agent from crewai.agent import Agent
from crewai.crew import Crew from crewai.crew import Crew
@@ -19,7 +21,10 @@ from crewai.experimental.evaluation import (
create_default_evaluator, create_default_evaluator,
) )
from crewai.experimental.evaluation.agent_evaluator import AgentEvaluator from crewai.experimental.evaluation.agent_evaluator import AgentEvaluator
from crewai.experimental.evaluation.base_evaluator import AgentEvaluationResult from crewai.experimental.evaluation.base_evaluator import (
AgentEvaluationResult,
BaseEvaluator,
)
from crewai.task import Task from crewai.task import Task
@@ -51,12 +56,25 @@ class TestAgentEvaluator:
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_evaluate_current_iteration(self, mock_crew): def test_evaluate_current_iteration(self, mock_crew):
from crewai.events.types.task_events import TaskCompletedEvent
agent_evaluator = AgentEvaluator( agent_evaluator = AgentEvaluator(
agents=mock_crew.agents, evaluators=[GoalAlignmentEvaluator()] agents=mock_crew.agents, evaluators=[GoalAlignmentEvaluator()]
) )
task_completed_event = threading.Event()
@crewai_event_bus.on(TaskCompletedEvent)
async def on_task_completed(source, event):
# TaskCompletedEvent fires AFTER evaluation results are stored
task_completed_event.set()
mock_crew.kickoff() mock_crew.kickoff()
assert task_completed_event.wait(timeout=5), (
"Timeout waiting for task completion"
)
results = agent_evaluator.get_evaluation_results() results = agent_evaluator.get_evaluation_results()
assert isinstance(results, dict) assert isinstance(results, dict)
@@ -98,73 +116,15 @@ class TestAgentEvaluator:
] ]
assert len(agent_evaluator.evaluators) == len(expected_types) assert len(agent_evaluator.evaluators) == len(expected_types)
for evaluator, expected_type in zip(agent_evaluator.evaluators, expected_types): for evaluator, expected_type in zip(
agent_evaluator.evaluators, expected_types, strict=False
):
assert isinstance(evaluator, expected_type) assert isinstance(evaluator, expected_type)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_eval_lite_agent(self):
agent = Agent(
role="Test Agent",
goal="Complete test tasks successfully",
backstory="An agent created for testing purposes",
)
with crewai_event_bus.scoped_handlers():
events = {}
@crewai_event_bus.on(AgentEvaluationStartedEvent)
def capture_started(source, event):
events["started"] = event
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
def capture_completed(source, event):
events["completed"] = event
@crewai_event_bus.on(AgentEvaluationFailedEvent)
def capture_failed(source, event):
events["failed"] = event
agent_evaluator = AgentEvaluator(
agents=[agent], evaluators=[GoalAlignmentEvaluator()]
)
agent.kickoff(messages="Complete this task successfully")
assert events.keys() == {"started", "completed"}
assert events["started"].agent_id == str(agent.id)
assert events["started"].agent_role == agent.role
assert events["started"].task_id is None
assert events["started"].iteration == 1
assert events["completed"].agent_id == str(agent.id)
assert events["completed"].agent_role == agent.role
assert events["completed"].task_id is None
assert events["completed"].iteration == 1
assert events["completed"].metric_category == MetricCategory.GOAL_ALIGNMENT
assert isinstance(events["completed"].score, EvaluationScore)
assert events["completed"].score.score == 2.0
results = agent_evaluator.get_evaluation_results()
assert isinstance(results, dict)
(result,) = results[agent.role]
assert isinstance(result, AgentEvaluationResult)
assert result.agent_id == str(agent.id)
assert result.task_id == "lite_task"
(goal_alignment,) = result.metrics.values()
assert goal_alignment.score == 2.0
expected_feedback = "The agent did not demonstrate a clear understanding of the task goal, which is to complete test tasks successfully"
assert expected_feedback in goal_alignment.feedback
assert goal_alignment.raw_response is not None
assert '"score": 2' in goal_alignment.raw_response
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_eval_specific_agents_from_crew(self, mock_crew): def test_eval_specific_agents_from_crew(self, mock_crew):
from crewai.events.types.task_events import TaskCompletedEvent
agent = Agent( agent = Agent(
role="Test Agent Eval", role="Test Agent Eval",
goal="Complete test tasks successfully", goal="Complete test tasks successfully",
@@ -178,111 +138,132 @@ class TestAgentEvaluator:
mock_crew.agents.append(agent) mock_crew.agents.append(agent)
mock_crew.tasks.append(task) mock_crew.tasks.append(task)
with crewai_event_bus.scoped_handlers(): events = {}
events = {} started_event = threading.Event()
completed_event = threading.Event()
task_completed_event = threading.Event()
@crewai_event_bus.on(AgentEvaluationStartedEvent) agent_evaluator = AgentEvaluator(
def capture_started(source, event): agents=[agent], evaluators=[GoalAlignmentEvaluator()]
)
@crewai_event_bus.on(AgentEvaluationStartedEvent)
async def capture_started(source, event):
if event.agent_id == str(agent.id):
events["started"] = event events["started"] = event
started_event.set()
@crewai_event_bus.on(AgentEvaluationCompletedEvent) @crewai_event_bus.on(AgentEvaluationCompletedEvent)
def capture_completed(source, event): async def capture_completed(source, event):
if event.agent_id == str(agent.id):
events["completed"] = event events["completed"] = event
completed_event.set()
@crewai_event_bus.on(AgentEvaluationFailedEvent) @crewai_event_bus.on(AgentEvaluationFailedEvent)
def capture_failed(source, event): def capture_failed(source, event):
events["failed"] = event events["failed"] = event
agent_evaluator = AgentEvaluator( @crewai_event_bus.on(TaskCompletedEvent)
agents=[agent], evaluators=[GoalAlignmentEvaluator()] async def on_task_completed(source, event):
) # TaskCompletedEvent fires AFTER evaluation results are stored
mock_crew.kickoff() if event.task and event.task.id == task.id:
task_completed_event.set()
assert events.keys() == {"started", "completed"} mock_crew.kickoff()
assert events["started"].agent_id == str(agent.id)
assert events["started"].agent_role == agent.role
assert events["started"].task_id == str(task.id)
assert events["started"].iteration == 1
assert events["completed"].agent_id == str(agent.id) assert started_event.wait(timeout=5), "Timeout waiting for started event"
assert events["completed"].agent_role == agent.role assert completed_event.wait(timeout=5), "Timeout waiting for completed event"
assert events["completed"].task_id == str(task.id) assert task_completed_event.wait(timeout=5), (
assert events["completed"].iteration == 1 "Timeout waiting for task completion"
assert events["completed"].metric_category == MetricCategory.GOAL_ALIGNMENT )
assert isinstance(events["completed"].score, EvaluationScore)
assert events["completed"].score.score == 5.0
results = agent_evaluator.get_evaluation_results() assert events.keys() == {"started", "completed"}
assert events["started"].agent_id == str(agent.id)
assert events["started"].agent_role == agent.role
assert events["started"].task_id == str(task.id)
assert events["started"].iteration == 1
assert isinstance(results, dict) assert events["completed"].agent_id == str(agent.id)
assert len(results.keys()) == 1 assert events["completed"].agent_role == agent.role
(result,) = results[agent.role] assert events["completed"].task_id == str(task.id)
assert isinstance(result, AgentEvaluationResult) assert events["completed"].iteration == 1
assert events["completed"].metric_category == MetricCategory.GOAL_ALIGNMENT
assert isinstance(events["completed"].score, EvaluationScore)
assert events["completed"].score.score == 5.0
assert result.agent_id == str(agent.id) results = agent_evaluator.get_evaluation_results()
assert result.task_id == str(task.id)
(goal_alignment,) = result.metrics.values() assert isinstance(results, dict)
assert goal_alignment.score == 5.0 assert len(results.keys()) == 1
(result,) = results[agent.role]
assert isinstance(result, AgentEvaluationResult)
expected_feedback = "The agent provided a thorough guide on how to conduct a test task but failed to produce specific expected output" assert result.agent_id == str(agent.id)
assert expected_feedback in goal_alignment.feedback assert result.task_id == str(task.id)
assert goal_alignment.raw_response is not None (goal_alignment,) = result.metrics.values()
assert '"score": 5' in goal_alignment.raw_response assert goal_alignment.score == 5.0
expected_feedback = "The agent provided a thorough guide on how to conduct a test task but failed to produce specific expected output"
assert expected_feedback in goal_alignment.feedback
assert goal_alignment.raw_response is not None
assert '"score": 5' in goal_alignment.raw_response
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_failed_evaluation(self, mock_crew): def test_failed_evaluation(self, mock_crew):
(agent,) = mock_crew.agents (agent,) = mock_crew.agents
(task,) = mock_crew.tasks (task,) = mock_crew.tasks
with crewai_event_bus.scoped_handlers(): events = {}
events = {} started_event = threading.Event()
failed_event = threading.Event()
@crewai_event_bus.on(AgentEvaluationStartedEvent) @crewai_event_bus.on(AgentEvaluationStartedEvent)
def capture_started(source, event): def capture_started(source, event):
events["started"] = event events["started"] = event
started_event.set()
@crewai_event_bus.on(AgentEvaluationCompletedEvent) @crewai_event_bus.on(AgentEvaluationCompletedEvent)
def capture_completed(source, event): def capture_completed(source, event):
events["completed"] = event events["completed"] = event
@crewai_event_bus.on(AgentEvaluationFailedEvent) @crewai_event_bus.on(AgentEvaluationFailedEvent)
def capture_failed(source, event): def capture_failed(source, event):
events["failed"] = event events["failed"] = event
failed_event.set()
# Create a mock evaluator that will raise an exception class FailingEvaluator(BaseEvaluator):
from crewai.experimental.evaluation import MetricCategory metric_category = MetricCategory.GOAL_ALIGNMENT
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator
class FailingEvaluator(BaseEvaluator): def evaluate(self, agent, task, execution_trace, final_output):
metric_category = MetricCategory.GOAL_ALIGNMENT raise ValueError("Forced evaluation failure")
def evaluate(self, agent, task, execution_trace, final_output): agent_evaluator = AgentEvaluator(
raise ValueError("Forced evaluation failure") agents=[agent], evaluators=[FailingEvaluator()]
)
mock_crew.kickoff()
agent_evaluator = AgentEvaluator( assert started_event.wait(timeout=5), "Timeout waiting for started event"
agents=[agent], evaluators=[FailingEvaluator()] assert failed_event.wait(timeout=5), "Timeout waiting for failed event"
)
mock_crew.kickoff()
assert events.keys() == {"started", "failed"} assert events.keys() == {"started", "failed"}
assert events["started"].agent_id == str(agent.id) assert events["started"].agent_id == str(agent.id)
assert events["started"].agent_role == agent.role assert events["started"].agent_role == agent.role
assert events["started"].task_id == str(task.id) assert events["started"].task_id == str(task.id)
assert events["started"].iteration == 1 assert events["started"].iteration == 1
assert events["failed"].agent_id == str(agent.id) assert events["failed"].agent_id == str(agent.id)
assert events["failed"].agent_role == agent.role assert events["failed"].agent_role == agent.role
assert events["failed"].task_id == str(task.id) assert events["failed"].task_id == str(task.id)
assert events["failed"].iteration == 1 assert events["failed"].iteration == 1
assert events["failed"].error == "Forced evaluation failure" assert events["failed"].error == "Forced evaluation failure"
results = agent_evaluator.get_evaluation_results() results = agent_evaluator.get_evaluation_results()
(result,) = results[agent.role] (result,) = results[agent.role]
assert isinstance(result, AgentEvaluationResult) assert isinstance(result, AgentEvaluationResult)
assert result.agent_id == str(agent.id) assert result.agent_id == str(agent.id)
assert result.task_id == str(task.id) assert result.task_id == str(task.id)
assert result.metrics == {} assert result.metrics == {}

View File

@@ -1,23 +1,36 @@
from unittest.mock import MagicMock, patch, ANY import threading
from collections import defaultdict from collections import defaultdict
from crewai.events.event_bus import crewai_event_bus from unittest.mock import ANY, MagicMock, patch
from crewai.events.types.memory_events import (
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
)
import pytest import pytest
from mem0.memory.main import Memory from mem0.memory.main import Memory
from crewai.agent import Agent from crewai.agent import Agent
from crewai.crew import Crew, Process from crewai.crew import Crew, Process
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.external.external_memory import ExternalMemory from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.external.external_memory_item import ExternalMemoryItem from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.storage.interface import Storage from crewai.memory.storage.interface import Storage
from crewai.task import Task from crewai.task import Task
@pytest.fixture(autouse=True)
def cleanup_event_handlers():
"""Cleanup event handlers after each test"""
yield
with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
crewai_event_bus._handler_dependencies = {}
crewai_event_bus._execution_plan_cache = {}
@pytest.fixture @pytest.fixture
def mock_mem0_memory(): def mock_mem0_memory():
mock_memory = MagicMock(spec=Memory) mock_memory = MagicMock(spec=Memory)
@@ -238,24 +251,26 @@ def test_external_memory_search_events(
custom_storage, external_memory_with_mocked_config custom_storage, external_memory_with_mocked_config
): ):
events = defaultdict(list) events = defaultdict(list)
event_received = threading.Event()
external_memory_with_mocked_config.storage = custom_storage external_memory_with_mocked_config.storage = custom_storage
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemoryQueryStartedEvent) @crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event): def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event) events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent) @crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event): def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event) events["MemoryQueryCompletedEvent"].append(event)
event_received.set()
external_memory_with_mocked_config.search( external_memory_with_mocked_config.search(
query="test value", query="test value",
limit=3, limit=3,
score_threshold=0.35, score_threshold=0.35,
) )
assert event_received.wait(timeout=5), "Timeout waiting for search events"
assert len(events["MemoryQueryStartedEvent"]) == 1 assert len(events["MemoryQueryStartedEvent"]) == 1
assert len(events["MemoryQueryCompletedEvent"]) == 1 assert len(events["MemoryQueryCompletedEvent"]) == 1
@@ -300,24 +315,25 @@ def test_external_memory_save_events(
custom_storage, external_memory_with_mocked_config custom_storage, external_memory_with_mocked_config
): ):
events = defaultdict(list) events = defaultdict(list)
event_received = threading.Event()
external_memory_with_mocked_config.storage = custom_storage external_memory_with_mocked_config.storage = custom_storage
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
@crewai_event_bus.on(MemorySaveStartedEvent) @crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_started(source, event): def on_save_completed(source, event):
events["MemorySaveStartedEvent"].append(event) events["MemorySaveCompletedEvent"].append(event)
event_received.set()
@crewai_event_bus.on(MemorySaveCompletedEvent) external_memory_with_mocked_config.save(
def on_save_completed(source, event): value="saving value",
events["MemorySaveCompletedEvent"].append(event) metadata={"task": "test_task"},
)
external_memory_with_mocked_config.save(
value="saving value",
metadata={"task": "test_task"},
)
assert event_received.wait(timeout=5), "Timeout waiting for save events"
assert len(events["MemorySaveStartedEvent"]) == 1 assert len(events["MemorySaveStartedEvent"]) == 1
assert len(events["MemorySaveCompletedEvent"]) == 1 assert len(events["MemorySaveCompletedEvent"]) == 1

View File

@@ -1,7 +1,9 @@
import threading
from collections import defaultdict from collections import defaultdict
from unittest.mock import ANY from unittest.mock import ANY
import pytest import pytest
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import ( from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent, MemoryQueryCompletedEvent,
@@ -21,27 +23,37 @@ def long_term_memory():
def test_long_term_memory_save_events(long_term_memory): def test_long_term_memory_save_events(long_term_memory):
events = defaultdict(list) events = defaultdict(list)
all_events_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
if (
len(events["MemorySaveStartedEvent"]) == 1
and len(events["MemorySaveCompletedEvent"]) == 1
):
all_events_received.set()
@crewai_event_bus.on(MemorySaveStartedEvent) @crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_started(source, event): def on_save_completed(source, event):
events["MemorySaveStartedEvent"].append(event) events["MemorySaveCompletedEvent"].append(event)
if (
len(events["MemorySaveStartedEvent"]) == 1
and len(events["MemorySaveCompletedEvent"]) == 1
):
all_events_received.set()
@crewai_event_bus.on(MemorySaveCompletedEvent) memory = LongTermMemoryItem(
def on_save_completed(source, event): agent="test_agent",
events["MemorySaveCompletedEvent"].append(event) task="test_task",
expected_output="test_output",
memory = LongTermMemoryItem( datetime="test_datetime",
agent="test_agent", quality=0.5,
task="test_task", metadata={"task": "test_task", "quality": 0.5},
expected_output="test_output", )
datetime="test_datetime", long_term_memory.save(memory)
quality=0.5,
metadata={"task": "test_task", "quality": 0.5},
)
long_term_memory.save(memory)
assert all_events_received.wait(timeout=5), "Timeout waiting for save events"
assert len(events["MemorySaveStartedEvent"]) == 1 assert len(events["MemorySaveStartedEvent"]) == 1
assert len(events["MemorySaveCompletedEvent"]) == 1 assert len(events["MemorySaveCompletedEvent"]) == 1
assert len(events["MemorySaveFailedEvent"]) == 0 assert len(events["MemorySaveFailedEvent"]) == 0
@@ -86,21 +98,31 @@ def test_long_term_memory_save_events(long_term_memory):
def test_long_term_memory_search_events(long_term_memory): def test_long_term_memory_search_events(long_term_memory):
events = defaultdict(list) events = defaultdict(list)
all_events_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
if (
len(events["MemoryQueryStartedEvent"]) == 1
and len(events["MemoryQueryCompletedEvent"]) == 1
):
all_events_received.set()
@crewai_event_bus.on(MemoryQueryStartedEvent) @crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_started(source, event): def on_search_completed(source, event):
events["MemoryQueryStartedEvent"].append(event) events["MemoryQueryCompletedEvent"].append(event)
if (
len(events["MemoryQueryStartedEvent"]) == 1
and len(events["MemoryQueryCompletedEvent"]) == 1
):
all_events_received.set()
@crewai_event_bus.on(MemoryQueryCompletedEvent) test_query = "test query"
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
test_query = "test query" long_term_memory.search(test_query, latest_n=5)
long_term_memory.search(test_query, latest_n=5)
assert all_events_received.wait(timeout=5), "Timeout waiting for search events"
assert len(events["MemoryQueryStartedEvent"]) == 1 assert len(events["MemoryQueryStartedEvent"]) == 1
assert len(events["MemoryQueryCompletedEvent"]) == 1 assert len(events["MemoryQueryCompletedEvent"]) == 1
assert len(events["MemoryQueryFailedEvent"]) == 0 assert len(events["MemoryQueryFailedEvent"]) == 0

View File

@@ -1,3 +1,4 @@
import threading
from collections import defaultdict from collections import defaultdict
from unittest.mock import ANY, patch from unittest.mock import ANY, patch
@@ -37,24 +38,33 @@ def short_term_memory():
def test_short_term_memory_search_events(short_term_memory): def test_short_term_memory_search_events(short_term_memory):
events = defaultdict(list) events = defaultdict(list)
search_started = threading.Event()
search_completed = threading.Event()
with patch.object(short_term_memory.storage, "search", return_value=[]): with patch.object(short_term_memory.storage, "search", return_value=[]):
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemoryQueryStartedEvent) @crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event): def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event) events["MemoryQueryStartedEvent"].append(event)
search_started.set()
@crewai_event_bus.on(MemoryQueryCompletedEvent) @crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event): def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event) events["MemoryQueryCompletedEvent"].append(event)
search_completed.set()
# Call the save method short_term_memory.search(
short_term_memory.search( query="test value",
query="test value", limit=3,
limit=3, score_threshold=0.35,
score_threshold=0.35, )
)
assert search_started.wait(timeout=2), (
"Timeout waiting for search started event"
)
assert search_completed.wait(timeout=2), (
"Timeout waiting for search completed event"
)
assert len(events["MemoryQueryStartedEvent"]) == 1 assert len(events["MemoryQueryStartedEvent"]) == 1
assert len(events["MemoryQueryCompletedEvent"]) == 1 assert len(events["MemoryQueryCompletedEvent"]) == 1
@@ -98,20 +108,26 @@ def test_short_term_memory_search_events(short_term_memory):
def test_short_term_memory_save_events(short_term_memory): def test_short_term_memory_save_events(short_term_memory):
events = defaultdict(list) events = defaultdict(list)
with crewai_event_bus.scoped_handlers(): save_started = threading.Event()
save_completed = threading.Event()
@crewai_event_bus.on(MemorySaveStartedEvent) @crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event): def on_save_started(source, event):
events["MemorySaveStartedEvent"].append(event) events["MemorySaveStartedEvent"].append(event)
save_started.set()
@crewai_event_bus.on(MemorySaveCompletedEvent) @crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event): def on_save_completed(source, event):
events["MemorySaveCompletedEvent"].append(event) events["MemorySaveCompletedEvent"].append(event)
save_completed.set()
short_term_memory.save( short_term_memory.save(
value="test value", value="test value",
metadata={"task": "test_task"}, metadata={"task": "test_task"},
) )
assert save_started.wait(timeout=2), "Timeout waiting for save started event"
assert save_completed.wait(timeout=2), "Timeout waiting for save completed event"
assert len(events["MemorySaveStartedEvent"]) == 1 assert len(events["MemorySaveStartedEvent"]) == 1
assert len(events["MemorySaveCompletedEvent"]) == 1 assert len(events["MemorySaveCompletedEvent"]) == 1

View File

@@ -1,9 +1,10 @@
"""Test Agent creation and execution basic functionality.""" """Test Agent creation and execution basic functionality."""
import json
import threading
from collections import defaultdict from collections import defaultdict
from concurrent.futures import Future from concurrent.futures import Future
from hashlib import md5 from hashlib import md5
import json
import re import re
from unittest import mock from unittest import mock
from unittest.mock import ANY, MagicMock, patch from unittest.mock import ANY, MagicMock, patch
@@ -2476,62 +2477,63 @@ def test_using_contextual_memory():
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_memory_events_are_emitted(): def test_memory_events_are_emitted():
events = defaultdict(list) events = defaultdict(list)
event_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(MemorySaveStartedEvent)
def handle_memory_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
@crewai_event_bus.on(MemorySaveStartedEvent) @crewai_event_bus.on(MemorySaveCompletedEvent)
def handle_memory_save_started(source, event): def handle_memory_save_completed(source, event):
events["MemorySaveStartedEvent"].append(event) events["MemorySaveCompletedEvent"].append(event)
@crewai_event_bus.on(MemorySaveCompletedEvent) @crewai_event_bus.on(MemorySaveFailedEvent)
def handle_memory_save_completed(source, event): def handle_memory_save_failed(source, event):
events["MemorySaveCompletedEvent"].append(event) events["MemorySaveFailedEvent"].append(event)
@crewai_event_bus.on(MemorySaveFailedEvent) @crewai_event_bus.on(MemoryQueryStartedEvent)
def handle_memory_save_failed(source, event): def handle_memory_query_started(source, event):
events["MemorySaveFailedEvent"].append(event) events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryStartedEvent) @crewai_event_bus.on(MemoryQueryCompletedEvent)
def handle_memory_query_started(source, event): def handle_memory_query_completed(source, event):
events["MemoryQueryStartedEvent"].append(event) events["MemoryQueryCompletedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent) @crewai_event_bus.on(MemoryQueryFailedEvent)
def handle_memory_query_completed(source, event): def handle_memory_query_failed(source, event):
events["MemoryQueryCompletedEvent"].append(event) events["MemoryQueryFailedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryFailedEvent) @crewai_event_bus.on(MemoryRetrievalStartedEvent)
def handle_memory_query_failed(source, event): def handle_memory_retrieval_started(source, event):
events["MemoryQueryFailedEvent"].append(event) events["MemoryRetrievalStartedEvent"].append(event)
@crewai_event_bus.on(MemoryRetrievalStartedEvent) @crewai_event_bus.on(MemoryRetrievalCompletedEvent)
def handle_memory_retrieval_started(source, event): def handle_memory_retrieval_completed(source, event):
events["MemoryRetrievalStartedEvent"].append(event) events["MemoryRetrievalCompletedEvent"].append(event)
event_received.set()
@crewai_event_bus.on(MemoryRetrievalCompletedEvent) math_researcher = Agent(
def handle_memory_retrieval_completed(source, event): role="Researcher",
events["MemoryRetrievalCompletedEvent"].append(event) goal="You research about math.",
backstory="You're an expert in research and you love to learn new things.",
allow_delegation=False,
)
math_researcher = Agent( task1 = Task(
role="Researcher", description="Research a topic to teach a kid aged 6 about math.",
goal="You research about math.", expected_output="A topic, explanation, angle, and examples.",
backstory="You're an expert in research and you love to learn new things.", agent=math_researcher,
allow_delegation=False, )
)
task1 = Task( crew = Crew(
description="Research a topic to teach a kid aged 6 about math.", agents=[math_researcher],
expected_output="A topic, explanation, angle, and examples.", tasks=[task1],
agent=math_researcher, memory=True,
) )
crew = Crew( crew.kickoff()
agents=[math_researcher],
tasks=[task1],
memory=True,
)
crew.kickoff()
assert event_received.wait(timeout=5), "Timeout waiting for memory events"
assert len(events["MemorySaveStartedEvent"]) == 3 assert len(events["MemorySaveStartedEvent"]) == 3
assert len(events["MemorySaveCompletedEvent"]) == 3 assert len(events["MemorySaveCompletedEvent"]) == 3
assert len(events["MemorySaveFailedEvent"]) == 0 assert len(events["MemorySaveFailedEvent"]) == 0
@@ -2907,19 +2909,29 @@ def test_crew_train_success(
copy_mock.return_value = crew copy_mock.return_value = crew
received_events = [] received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
@crewai_event_bus.on(CrewTrainStartedEvent) @crewai_event_bus.on(CrewTrainStartedEvent)
def on_crew_train_started(source, event: CrewTrainStartedEvent): def on_crew_train_started(source, event: CrewTrainStartedEvent):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == 2:
all_events_received.set()
@crewai_event_bus.on(CrewTrainCompletedEvent) @crewai_event_bus.on(CrewTrainCompletedEvent)
def on_crew_train_completed(source, event: CrewTrainCompletedEvent): def on_crew_train_completed(source, event: CrewTrainCompletedEvent):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == 2:
all_events_received.set()
crew.train( crew.train(
n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl" n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
) )
assert all_events_received.wait(timeout=5), "Timeout waiting for all train events"
# Ensure kickoff is called on the copied crew # Ensure kickoff is called on the copied crew
kickoff_mock.assert_has_calls( kickoff_mock.assert_has_calls(
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})] [mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
@@ -3726,17 +3738,27 @@ def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator, research
llm_instance = LLM("gpt-4o-mini") llm_instance = LLM("gpt-4o-mini")
received_events = [] received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
@crewai_event_bus.on(CrewTestStartedEvent) @crewai_event_bus.on(CrewTestStartedEvent)
def on_crew_test_started(source, event: CrewTestStartedEvent): def on_crew_test_started(source, event: CrewTestStartedEvent):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == 2:
all_events_received.set()
@crewai_event_bus.on(CrewTestCompletedEvent) @crewai_event_bus.on(CrewTestCompletedEvent)
def on_crew_test_completed(source, event: CrewTestCompletedEvent): def on_crew_test_completed(source, event: CrewTestCompletedEvent):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == 2:
all_events_received.set()
crew.test(n_iterations, llm_instance, inputs={"topic": "AI"}) crew.test(n_iterations, llm_instance, inputs={"topic": "AI"})
assert all_events_received.wait(timeout=5), "Timeout waiting for all test events"
# Ensure kickoff is called on the copied crew # Ensure kickoff is called on the copied crew
kickoff_mock.assert_has_calls( kickoff_mock.assert_has_calls(
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})] [mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]

View File

@@ -1,9 +1,12 @@
"""Test Flow creation and execution basic functionality.""" """Test Flow creation and execution basic functionality."""
import asyncio import asyncio
import threading
from datetime import datetime from datetime import datetime
import pytest import pytest
from pydantic import BaseModel
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.flow_events import ( from crewai.events.types.flow_events import (
FlowFinishedEvent, FlowFinishedEvent,
@@ -13,7 +16,6 @@ from crewai.events.types.flow_events import (
MethodExecutionStartedEvent, MethodExecutionStartedEvent,
) )
from crewai.flow.flow import Flow, and_, listen, or_, router, start from crewai.flow.flow import Flow, and_, listen, or_, router, start
from pydantic import BaseModel
def test_simple_sequential_flow(): def test_simple_sequential_flow():
@@ -439,20 +441,42 @@ def test_unstructured_flow_event_emission():
flow = PoemFlow() flow = PoemFlow()
received_events = [] received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
expected_event_count = (
7 # 1 FlowStarted + 5 MethodExecutionStarted + 1 FlowFinished
)
@crewai_event_bus.on(FlowStartedEvent) @crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event): def handle_flow_start(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(MethodExecutionStartedEvent) @crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_start(source, event): def handle_method_start(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(FlowFinishedEvent) @crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_end(source, event): def handle_flow_end(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
flow.kickoff(inputs={"separator": ", "}) flow.kickoff(inputs={"separator": ", "})
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
# Sort events by timestamp to ensure deterministic order
# (async handlers may append out of order)
with lock:
received_events.sort(key=lambda e: e.timestamp)
assert isinstance(received_events[0], FlowStartedEvent) assert isinstance(received_events[0], FlowStartedEvent)
assert received_events[0].flow_name == "PoemFlow" assert received_events[0].flow_name == "PoemFlow"
assert received_events[0].inputs == {"separator": ", "} assert received_events[0].inputs == {"separator": ", "}
@@ -642,28 +666,48 @@ def test_structured_flow_event_emission():
return f"Welcome, {self.state.name}!" return f"Welcome, {self.state.name}!"
flow = OnboardingFlow() flow = OnboardingFlow()
flow.kickoff(inputs={"name": "Anakin"})
received_events = [] received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
expected_event_count = 6 # 1 FlowStarted + 2 MethodExecutionStarted + 2 MethodExecutionFinished + 1 FlowFinished
@crewai_event_bus.on(FlowStartedEvent) @crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event): def handle_flow_start(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(MethodExecutionStartedEvent) @crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_start(source, event): def handle_method_start(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(MethodExecutionFinishedEvent) @crewai_event_bus.on(MethodExecutionFinishedEvent)
def handle_method_end(source, event): def handle_method_end(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(FlowFinishedEvent) @crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_end(source, event): def handle_flow_end(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
flow.kickoff(inputs={"name": "Anakin"}) flow.kickoff(inputs={"name": "Anakin"})
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
# Sort events by timestamp to ensure deterministic order
with lock:
received_events.sort(key=lambda e: e.timestamp)
assert isinstance(received_events[0], FlowStartedEvent) assert isinstance(received_events[0], FlowStartedEvent)
assert received_events[0].flow_name == "OnboardingFlow" assert received_events[0].flow_name == "OnboardingFlow"
assert received_events[0].inputs == {"name": "Anakin"} assert received_events[0].inputs == {"name": "Anakin"}
@@ -711,25 +755,46 @@ def test_stateless_flow_event_emission():
flow = StatelessFlow() flow = StatelessFlow()
received_events = [] received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
expected_event_count = 6 # 1 FlowStarted + 2 MethodExecutionStarted + 2 MethodExecutionFinished + 1 FlowFinished
@crewai_event_bus.on(FlowStartedEvent) @crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event): def handle_flow_start(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(MethodExecutionStartedEvent) @crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_start(source, event): def handle_method_start(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(MethodExecutionFinishedEvent) @crewai_event_bus.on(MethodExecutionFinishedEvent)
def handle_method_end(source, event): def handle_method_end(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(FlowFinishedEvent) @crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_end(source, event): def handle_flow_end(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
flow.kickoff() flow.kickoff()
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
# Sort events by timestamp to ensure deterministic order
with lock:
received_events.sort(key=lambda e: e.timestamp)
assert isinstance(received_events[0], FlowStartedEvent) assert isinstance(received_events[0], FlowStartedEvent)
assert received_events[0].flow_name == "StatelessFlow" assert received_events[0].flow_name == "StatelessFlow"
assert received_events[0].inputs is None assert received_events[0].inputs is None
@@ -769,13 +834,16 @@ def test_flow_plotting():
flow = StatelessFlow() flow = StatelessFlow()
flow.kickoff() flow.kickoff()
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(FlowPlotEvent) @crewai_event_bus.on(FlowPlotEvent)
def handle_flow_plot(source, event): def handle_flow_plot(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
flow.plot("test_flow") flow.plot("test_flow")
assert event_received.wait(timeout=5), "Timeout waiting for plot event"
assert len(received_events) == 1 assert len(received_events) == 1
assert isinstance(received_events[0], FlowPlotEvent) assert isinstance(received_events[0], FlowPlotEvent)
assert received_events[0].flow_name == "StatelessFlow" assert received_events[0].flow_name == "StatelessFlow"

View File

@@ -1,3 +1,4 @@
import threading
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
@@ -175,78 +176,92 @@ def test_task_guardrail_process_output(task_output):
def test_guardrail_emits_events(sample_agent): def test_guardrail_emits_events(sample_agent):
started_guardrail = [] started_guardrail = []
completed_guardrail = [] completed_guardrail = []
all_events_received = threading.Event()
expected_started = 3 # 2 from first task, 1 from second
expected_completed = 3 # 2 from first task, 1 from second
task = Task( task1 = Task(
description="Gather information about available books on the First World War", description="Gather information about available books on the First World War",
agent=sample_agent, agent=sample_agent,
expected_output="A list of available books on the First World War", expected_output="A list of available books on the First World War",
guardrail="Ensure the authors are from Italy", guardrail="Ensure the authors are from Italy",
) )
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(LLMGuardrailStartedEvent)
def handle_guardrail_started(source, event):
@crewai_event_bus.on(LLMGuardrailStartedEvent) started_guardrail.append(
def handle_guardrail_started(source, event): {"guardrail": event.guardrail, "retry_count": event.retry_count}
assert source == task
started_guardrail.append(
{"guardrail": event.guardrail, "retry_count": event.retry_count}
)
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def handle_guardrail_completed(source, event):
assert source == task
completed_guardrail.append(
{
"success": event.success,
"result": event.result,
"error": event.error,
"retry_count": event.retry_count,
}
)
result = task.execute_sync(agent=sample_agent)
def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")
task = Task(
description="Test task",
expected_output="Output",
guardrail=custom_guardrail,
) )
if (
len(started_guardrail) >= expected_started
and len(completed_guardrail) >= expected_completed
):
all_events_received.set()
task.execute_sync(agent=sample_agent) @crewai_event_bus.on(LLMGuardrailCompletedEvent)
def handle_guardrail_completed(source, event):
completed_guardrail.append(
{
"success": event.success,
"result": event.result,
"error": event.error,
"retry_count": event.retry_count,
}
)
if (
len(started_guardrail) >= expected_started
and len(completed_guardrail) >= expected_completed
):
all_events_received.set()
expected_started_events = [ result = task1.execute_sync(agent=sample_agent)
{"guardrail": "Ensure the authors are from Italy", "retry_count": 0},
{"guardrail": "Ensure the authors are from Italy", "retry_count": 1},
{
"guardrail": """def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")""",
"retry_count": 0,
},
]
expected_completed_events = [ def custom_guardrail(result: TaskOutput):
{ return (True, "good result from callable function")
"success": False,
"result": None, task2 = Task(
"error": "The task result does not comply with the guardrail because none of " description="Test task",
"the listed authors are from Italy. All authors mentioned are from " expected_output="Output",
"different countries, including Germany, the UK, the USA, and others, " guardrail=custom_guardrail,
"which violates the requirement that authors must be Italian.", )
"retry_count": 0,
}, task2.execute_sync(agent=sample_agent)
{"success": True, "result": result.raw, "error": None, "retry_count": 1},
{ # Wait for all events to be received
"success": True, assert all_events_received.wait(timeout=10), (
"result": "good result from callable function", "Timeout waiting for all guardrail events"
"error": None, )
"retry_count": 0,
}, expected_started_events = [
] {"guardrail": "Ensure the authors are from Italy", "retry_count": 0},
assert started_guardrail == expected_started_events {"guardrail": "Ensure the authors are from Italy", "retry_count": 1},
assert completed_guardrail == expected_completed_events {
"guardrail": """def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")""",
"retry_count": 0,
},
]
expected_completed_events = [
{
"success": False,
"result": None,
"error": "The task result does not comply with the guardrail because none of "
"the listed authors are from Italy. All authors mentioned are from "
"different countries, including Germany, the UK, the USA, and others, "
"which violates the requirement that authors must be Italian.",
"retry_count": 0,
},
{"success": True, "result": result.raw, "error": None, "retry_count": 1},
{
"success": True,
"result": "good result from callable function",
"error": None,
"retry_count": 0,
},
]
assert started_guardrail == expected_started_events
assert completed_guardrail == expected_completed_events
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])

View File

@@ -1,6 +1,7 @@
import datetime import datetime
import json import json
import random import random
import threading
import time import time
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@@ -32,7 +33,7 @@ class RandomNumberTool(BaseTool):
args_schema: type[BaseModel] = RandomNumberToolInput args_schema: type[BaseModel] = RandomNumberToolInput
def _run(self, min_value: int, max_value: int) -> int: def _run(self, min_value: int, max_value: int) -> int:
return random.randint(min_value, max_value) return random.randint(min_value, max_value) # noqa: S311
# Example agent and task # Example agent and task
@@ -470,13 +471,21 @@ def test_tool_selection_error_event_direct():
) )
received_events = [] received_events = []
first_event_received = threading.Event()
second_event_received = threading.Event()
@crewai_event_bus.on(ToolSelectionErrorEvent) @crewai_event_bus.on(ToolSelectionErrorEvent)
def event_handler(source, event): def event_handler(source, event):
received_events.append(event) received_events.append(event)
if event.tool_name == "Non Existent Tool":
first_event_received.set()
elif event.tool_name == "":
second_event_received.set()
with pytest.raises(Exception): with pytest.raises(Exception): # noqa: B017
tool_usage._select_tool("Non Existent Tool") tool_usage._select_tool("Non Existent Tool")
assert first_event_received.wait(timeout=5), "Timeout waiting for first event"
assert len(received_events) == 1 assert len(received_events) == 1
event = received_events[0] event = received_events[0]
assert isinstance(event, ToolSelectionErrorEvent) assert isinstance(event, ToolSelectionErrorEvent)
@@ -488,12 +497,12 @@ def test_tool_selection_error_event_direct():
assert "A test tool" in event.tool_class assert "A test tool" in event.tool_class
assert "don't exist" in event.error assert "don't exist" in event.error
received_events.clear() with pytest.raises(Exception): # noqa: B017
with pytest.raises(Exception):
tool_usage._select_tool("") tool_usage._select_tool("")
assert len(received_events) == 1 assert second_event_received.wait(timeout=5), "Timeout waiting for second event"
event = received_events[0] assert len(received_events) == 2
event = received_events[1]
assert isinstance(event, ToolSelectionErrorEvent) assert isinstance(event, ToolSelectionErrorEvent)
assert event.agent_key == "test_key" assert event.agent_key == "test_key"
assert event.agent_role == "test_role" assert event.agent_role == "test_role"
@@ -562,7 +571,7 @@ def test_tool_validate_input_error_event():
# Test invalid input # Test invalid input
invalid_input = "invalid json {[}" invalid_input = "invalid json {[}"
with pytest.raises(Exception): with pytest.raises(Exception): # noqa: B017
tool_usage._validate_tool_input(invalid_input) tool_usage._validate_tool_input(invalid_input)
# Verify event was emitted # Verify event was emitted
@@ -616,12 +625,13 @@ def test_tool_usage_finished_event_with_result():
action=MagicMock(), action=MagicMock(),
) )
# Track received events
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageFinishedEvent) @crewai_event_bus.on(ToolUsageFinishedEvent)
def event_handler(source, event): def event_handler(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
# Call on_tool_use_finished with test data # Call on_tool_use_finished with test data
started_at = time.time() started_at = time.time()
@@ -634,7 +644,7 @@ def test_tool_usage_finished_event_with_result():
result=result, result=result,
) )
# Verify event was emitted assert event_received.wait(timeout=5), "Timeout waiting for event"
assert len(received_events) == 1, "Expected one event to be emitted" assert len(received_events) == 1, "Expected one event to be emitted"
event = received_events[0] event = received_events[0]
assert isinstance(event, ToolUsageFinishedEvent) assert isinstance(event, ToolUsageFinishedEvent)
@@ -695,12 +705,13 @@ def test_tool_usage_finished_event_with_cached_result():
action=MagicMock(), action=MagicMock(),
) )
# Track received events
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageFinishedEvent) @crewai_event_bus.on(ToolUsageFinishedEvent)
def event_handler(source, event): def event_handler(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
# Call on_tool_use_finished with test data and from_cache=True # Call on_tool_use_finished with test data and from_cache=True
started_at = time.time() started_at = time.time()
@@ -713,7 +724,7 @@ def test_tool_usage_finished_event_with_cached_result():
result=result, result=result,
) )
# Verify event was emitted assert event_received.wait(timeout=5), "Timeout waiting for event"
assert len(received_events) == 1, "Expected one event to be emitted" assert len(received_events) == 1, "Expected one event to be emitted"
event = received_events[0] event = received_events[0]
assert isinstance(event, ToolUsageFinishedEvent) assert isinstance(event, ToolUsageFinishedEvent)

View File

@@ -14,6 +14,7 @@ from crewai.events.listeners.tracing.trace_listener import (
) )
from crewai.events.listeners.tracing.types import TraceEvent from crewai.events.listeners.tracing.types import TraceEvent
from crewai.flow.flow import Flow, start from crewai.flow.flow import Flow, start
from tests.utils import wait_for_event_handlers
class TestTraceListenerSetup: class TestTraceListenerSetup:
@@ -39,38 +40,44 @@ class TestTraceListenerSetup:
): ):
yield yield
@pytest.fixture(autouse=True)
def clear_event_bus(self):
"""Clear event bus listeners before and after each test"""
from crewai.events.event_bus import crewai_event_bus
# Store original handlers
original_handlers = crewai_event_bus._handlers.copy()
# Clear for test
crewai_event_bus._handlers.clear()
yield
# Restore original state
crewai_event_bus._handlers.clear()
crewai_event_bus._handlers.update(original_handlers)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_tracing_singletons(self): def reset_tracing_singletons(self):
"""Reset tracing singleton instances between tests""" """Reset tracing singleton instances between tests"""
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
# Clear event bus handlers BEFORE creating any new singletons
with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
crewai_event_bus._handler_dependencies = {}
crewai_event_bus._execution_plan_cache = {}
# Reset TraceCollectionListener singleton # Reset TraceCollectionListener singleton
if hasattr(TraceCollectionListener, "_instance"): if hasattr(TraceCollectionListener, "_instance"):
TraceCollectionListener._instance = None TraceCollectionListener._instance = None
TraceCollectionListener._initialized = False TraceCollectionListener._initialized = False
# Reset EventListener singleton
if hasattr(EventListener, "_instance"):
EventListener._instance = None
yield yield
# Clean up after test # Clean up after test
with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
crewai_event_bus._handler_dependencies = {}
crewai_event_bus._execution_plan_cache = {}
if hasattr(TraceCollectionListener, "_instance"): if hasattr(TraceCollectionListener, "_instance"):
TraceCollectionListener._instance = None TraceCollectionListener._instance = None
TraceCollectionListener._initialized = False TraceCollectionListener._initialized = False
if hasattr(EventListener, "_instance"):
EventListener._instance = None
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_plus_api_calls(self): def mock_plus_api_calls(self):
"""Mock all PlusAPI HTTP calls to avoid network requests""" """Mock all PlusAPI HTTP calls to avoid network requests"""
@@ -167,15 +174,26 @@ class TestTraceListenerSetup:
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
trace_listener = None trace_listener = None
for handler_list in crewai_event_bus._handlers.values(): with crewai_event_bus._rwlock.r_locked():
for handler in handler_list: for handler_set in crewai_event_bus._sync_handlers.values():
if hasattr(handler, "__self__") and isinstance( for handler in handler_set:
handler.__self__, TraceCollectionListener if hasattr(handler, "__self__") and isinstance(
): handler.__self__, TraceCollectionListener
trace_listener = handler.__self__ ):
trace_listener = handler.__self__
break
if trace_listener:
break break
if trace_listener: if not trace_listener:
break for handler_set in crewai_event_bus._async_handlers.values():
for handler in handler_set:
if hasattr(handler, "__self__") and isinstance(
handler.__self__, TraceCollectionListener
):
trace_listener = handler.__self__
break
if trace_listener:
break
if not trace_listener: if not trace_listener:
pytest.skip( pytest.skip(
@@ -221,6 +239,7 @@ class TestTraceListenerSetup:
wraps=trace_listener.batch_manager.add_event, wraps=trace_listener.batch_manager.add_event,
) as add_event_mock: ) as add_event_mock:
crew.kickoff() crew.kickoff()
wait_for_event_handlers()
assert add_event_mock.call_count >= 2 assert add_event_mock.call_count >= 2
@@ -267,24 +286,22 @@ class TestTraceListenerSetup:
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
trace_handlers = [] trace_handlers = []
for handlers in crewai_event_bus._handlers.values(): with crewai_event_bus._rwlock.r_locked():
for handler in handlers: for handlers in crewai_event_bus._sync_handlers.values():
if hasattr(handler, "__self__") and isinstance( for handler in handlers:
handler.__self__, TraceCollectionListener if hasattr(handler, "__self__") and isinstance(
): handler.__self__, TraceCollectionListener
trace_handlers.append(handler) ):
elif hasattr(handler, "__name__") and any( trace_handlers.append(handler)
trace_name in handler.__name__ for handlers in crewai_event_bus._async_handlers.values():
for trace_name in [ for handler in handlers:
"on_crew_started", if hasattr(handler, "__self__") and isinstance(
"on_crew_completed", handler.__self__, TraceCollectionListener
"on_flow_started", ):
] trace_handlers.append(handler)
):
trace_handlers.append(handler)
assert len(trace_handlers) == 0, ( assert len(trace_handlers) == 0, (
f"Found {len(trace_handlers)} trace handlers when tracing should be disabled" f"Found {len(trace_handlers)} TraceCollectionListener handlers when tracing should be disabled"
) )
def test_trace_listener_setup_correctly_for_crew(self): def test_trace_listener_setup_correctly_for_crew(self):
@@ -385,6 +402,7 @@ class TestTraceListenerSetup:
): ):
crew = Crew(agents=[agent], tasks=[task], tracing=True) crew = Crew(agents=[agent], tasks=[task], tracing=True)
crew.kickoff() crew.kickoff()
wait_for_event_handlers()
mock_plus_api_class.assert_called_with(api_key="mock_token_12345") mock_plus_api_class.assert_called_with(api_key="mock_token_12345")
@@ -396,15 +414,33 @@ class TestTraceListenerSetup:
def teardown_method(self): def teardown_method(self):
"""Cleanup after each test method""" """Cleanup after each test method"""
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
crewai_event_bus._handlers.clear() with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
crewai_event_bus._handler_dependencies = {}
crewai_event_bus._execution_plan_cache = {}
# Reset EventListener singleton
if hasattr(EventListener, "_instance"):
EventListener._instance = None
@classmethod @classmethod
def teardown_class(cls): def teardown_class(cls):
"""Final cleanup after all tests in this class""" """Final cleanup after all tests in this class"""
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
crewai_event_bus._handlers.clear() with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
crewai_event_bus._handler_dependencies = {}
crewai_event_bus._execution_plan_cache = {}
# Reset EventListener singleton
if hasattr(EventListener, "_instance"):
EventListener._instance = None
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_first_time_user_trace_collection_with_timeout(self, mock_plus_api_calls): def test_first_time_user_trace_collection_with_timeout(self, mock_plus_api_calls):
@@ -466,6 +502,7 @@ class TestTraceListenerSetup:
) as mock_add_event, ) as mock_add_event,
): ):
result = crew.kickoff() result = crew.kickoff()
wait_for_event_handlers()
assert result is not None assert result is not None
assert mock_handle_completion.call_count >= 1 assert mock_handle_completion.call_count >= 1
@@ -543,6 +580,7 @@ class TestTraceListenerSetup:
) )
crew.kickoff() crew.kickoff()
wait_for_event_handlers()
assert mock_handle_completion.call_count >= 1, ( assert mock_handle_completion.call_count >= 1, (
"handle_execution_completion should be called" "handle_execution_completion should be called"
@@ -561,7 +599,6 @@ class TestTraceListenerSetup:
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_first_time_user_trace_consolidation_logic(self, mock_plus_api_calls): def test_first_time_user_trace_consolidation_logic(self, mock_plus_api_calls):
"""Test the consolidation logic for first-time users vs regular tracing""" """Test the consolidation logic for first-time users vs regular tracing"""
with ( with (
patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}), patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
patch( patch(
@@ -579,7 +616,9 @@ class TestTraceListenerSetup:
): ):
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
crewai_event_bus._handlers.clear() with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
trace_listener = TraceCollectionListener() trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus) trace_listener.setup_listeners(crewai_event_bus)
@@ -600,6 +639,9 @@ class TestTraceListenerSetup:
with patch.object(TraceBatchManager, "initialize_batch") as mock_initialize: with patch.object(TraceBatchManager, "initialize_batch") as mock_initialize:
result = crew.kickoff() result = crew.kickoff()
assert trace_listener.batch_manager.wait_for_pending_events(timeout=5.0), (
"Timeout waiting for trace event handlers to complete"
)
assert mock_initialize.call_count >= 1 assert mock_initialize.call_count >= 1
assert mock_initialize.call_args_list[0][1]["use_ephemeral"] is True assert mock_initialize.call_args_list[0][1]["use_ephemeral"] is True
assert result is not None assert result is not None
@@ -700,6 +742,7 @@ class TestTraceListenerSetup:
) as mock_mark_failed, ) as mock_mark_failed,
): ):
crew.kickoff() crew.kickoff()
wait_for_event_handlers()
mock_mark_failed.assert_called_once() mock_mark_failed.assert_called_once()
call_args = mock_mark_failed.call_args_list[0] call_args = mock_mark_failed.call_args_list[0]

View File

@@ -0,0 +1,206 @@
"""Tests for async event handling in CrewAI event bus.
This module tests async handler registration, execution, and the aemit method.
"""
import asyncio
import pytest
from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import crewai_event_bus
class AsyncTestEvent(BaseEvent):
pass
@pytest.mark.asyncio
async def test_async_handler_execution():
received_events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
received_events.append(event)
event = AsyncTestEvent(type="async_test")
crewai_event_bus.emit("test_source", event)
await asyncio.sleep(0.1)
assert len(received_events) == 1
assert received_events[0] == event
@pytest.mark.asyncio
async def test_aemit_with_async_handlers():
received_events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
received_events.append(event)
event = AsyncTestEvent(type="async_test")
await crewai_event_bus.aemit("test_source", event)
assert len(received_events) == 1
assert received_events[0] == event
@pytest.mark.asyncio
async def test_multiple_async_handlers():
received_events_1 = []
received_events_2 = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def handler_1(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
received_events_1.append(event)
@crewai_event_bus.on(AsyncTestEvent)
async def handler_2(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.02)
received_events_2.append(event)
event = AsyncTestEvent(type="async_test")
await crewai_event_bus.aemit("test_source", event)
assert len(received_events_1) == 1
assert len(received_events_2) == 1
@pytest.mark.asyncio
async def test_mixed_sync_and_async_handlers():
sync_events = []
async_events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
def sync_handler(source: object, event: BaseEvent) -> None:
sync_events.append(event)
@crewai_event_bus.on(AsyncTestEvent)
async def async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
async_events.append(event)
event = AsyncTestEvent(type="mixed_test")
crewai_event_bus.emit("test_source", event)
await asyncio.sleep(0.1)
assert len(sync_events) == 1
assert len(async_events) == 1
@pytest.mark.asyncio
async def test_async_handler_error_handling():
successful_handler_called = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def failing_handler(source: object, event: BaseEvent) -> None:
raise ValueError("Async handler error")
@crewai_event_bus.on(AsyncTestEvent)
async def successful_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
successful_handler_called.append(True)
event = AsyncTestEvent(type="error_test")
await crewai_event_bus.aemit("test_source", event)
assert len(successful_handler_called) == 1
@pytest.mark.asyncio
async def test_aemit_with_no_handlers():
with crewai_event_bus.scoped_handlers():
event = AsyncTestEvent(type="no_handlers")
await crewai_event_bus.aemit("test_source", event)
@pytest.mark.asyncio
async def test_async_handler_registration_via_register_handler():
received_events = []
with crewai_event_bus.scoped_handlers():
async def custom_async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
received_events.append(event)
crewai_event_bus.register_handler(AsyncTestEvent, custom_async_handler)
event = AsyncTestEvent(type="register_test")
await crewai_event_bus.aemit("test_source", event)
assert len(received_events) == 1
assert received_events[0] == event
@pytest.mark.asyncio
async def test_emit_async_handlers_fire_and_forget():
received_events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def slow_async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.05)
received_events.append(event)
event = AsyncTestEvent(type="fire_forget_test")
crewai_event_bus.emit("test_source", event)
assert len(received_events) == 0
await asyncio.sleep(0.1)
assert len(received_events) == 1
@pytest.mark.asyncio
async def test_scoped_handlers_with_async():
received_before = []
received_during = []
received_after = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def before_handler(source: object, event: BaseEvent) -> None:
received_before.append(event)
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def scoped_handler(source: object, event: BaseEvent) -> None:
received_during.append(event)
event1 = AsyncTestEvent(type="during_scope")
await crewai_event_bus.aemit("test_source", event1)
assert len(received_before) == 0
assert len(received_during) == 1
@crewai_event_bus.on(AsyncTestEvent)
async def after_handler(source: object, event: BaseEvent) -> None:
received_after.append(event)
event2 = AsyncTestEvent(type="after_scope")
await crewai_event_bus.aemit("test_source", event2)
assert len(received_before) == 1
assert len(received_during) == 1
assert len(received_after) == 1

View File

@@ -1,3 +1,4 @@
import threading
from unittest.mock import Mock from unittest.mock import Mock
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
@@ -21,27 +22,42 @@ def test_specific_event_handler():
mock_handler.assert_called_once_with("source_object", event) mock_handler.assert_called_once_with("source_object", event)
def test_wildcard_event_handler(): def test_multiple_handlers_same_event():
mock_handler = Mock() """Test that multiple handlers can be registered for the same event type."""
mock_handler1 = Mock()
mock_handler2 = Mock()
@crewai_event_bus.on(BaseEvent) @crewai_event_bus.on(TestEvent)
def handler(source, event): def handler1(source, event):
mock_handler(source, event) mock_handler1(source, event)
@crewai_event_bus.on(TestEvent)
def handler2(source, event):
mock_handler2(source, event)
event = TestEvent(type="test_event") event = TestEvent(type="test_event")
crewai_event_bus.emit("source_object", event) crewai_event_bus.emit("source_object", event)
mock_handler.assert_called_once_with("source_object", event) mock_handler1.assert_called_once_with("source_object", event)
mock_handler2.assert_called_once_with("source_object", event)
def test_event_bus_error_handling(capfd): def test_event_bus_error_handling():
@crewai_event_bus.on(BaseEvent) """Test that handler exceptions are caught and don't break the event bus."""
called = threading.Event()
error_caught = threading.Event()
@crewai_event_bus.on(TestEvent)
def broken_handler(source, event): def broken_handler(source, event):
called.set()
raise ValueError("Simulated handler failure") raise ValueError("Simulated handler failure")
@crewai_event_bus.on(TestEvent)
def working_handler(source, event):
error_caught.set()
event = TestEvent(type="test_event") event = TestEvent(type="test_event")
crewai_event_bus.emit("source_object", event) crewai_event_bus.emit("source_object", event)
out, err = capfd.readouterr() assert called.wait(timeout=2), "Broken handler was never called"
assert "Simulated handler failure" in out assert error_caught.wait(timeout=2), "Working handler was never called after error"
assert "Handler 'broken_handler' failed" in out

View File

@@ -0,0 +1,264 @@
"""Tests for read-write lock implementation.
This module tests the RWLock class for correct concurrent read and write behavior.
"""
import threading
import time
from crewai.events.utils.rw_lock import RWLock
def test_multiple_readers_concurrent():
lock = RWLock()
active_readers = [0]
max_concurrent_readers = [0]
lock_for_counters = threading.Lock()
def reader(reader_id: int) -> None:
with lock.r_locked():
with lock_for_counters:
active_readers[0] += 1
max_concurrent_readers[0] = max(
max_concurrent_readers[0], active_readers[0]
)
time.sleep(0.1)
with lock_for_counters:
active_readers[0] -= 1
threads = [threading.Thread(target=reader, args=(i,)) for i in range(5)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert max_concurrent_readers[0] == 5
def test_writer_blocks_readers():
lock = RWLock()
writer_holding_lock = [False]
reader_accessed_during_write = [False]
def writer() -> None:
with lock.w_locked():
writer_holding_lock[0] = True
time.sleep(0.2)
writer_holding_lock[0] = False
def reader() -> None:
time.sleep(0.05)
with lock.r_locked():
if writer_holding_lock[0]:
reader_accessed_during_write[0] = True
writer_thread = threading.Thread(target=writer)
reader_thread = threading.Thread(target=reader)
writer_thread.start()
reader_thread.start()
writer_thread.join()
reader_thread.join()
assert not reader_accessed_during_write[0]
def test_writer_blocks_other_writers():
lock = RWLock()
execution_order: list[int] = []
lock_for_order = threading.Lock()
def writer(writer_id: int) -> None:
with lock.w_locked():
with lock_for_order:
execution_order.append(writer_id)
time.sleep(0.1)
threads = [threading.Thread(target=writer, args=(i,)) for i in range(3)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert len(execution_order) == 3
assert len(set(execution_order)) == 3
def test_readers_block_writers():
lock = RWLock()
reader_count = [0]
writer_accessed_during_read = [False]
lock_for_counters = threading.Lock()
def reader() -> None:
with lock.r_locked():
with lock_for_counters:
reader_count[0] += 1
time.sleep(0.2)
with lock_for_counters:
reader_count[0] -= 1
def writer() -> None:
time.sleep(0.05)
with lock.w_locked():
with lock_for_counters:
if reader_count[0] > 0:
writer_accessed_during_read[0] = True
reader_thread = threading.Thread(target=reader)
writer_thread = threading.Thread(target=writer)
reader_thread.start()
writer_thread.start()
reader_thread.join()
writer_thread.join()
assert not writer_accessed_during_read[0]
def test_alternating_readers_and_writers():
lock = RWLock()
operations: list[str] = []
lock_for_operations = threading.Lock()
def reader(reader_id: int) -> None:
with lock.r_locked():
with lock_for_operations:
operations.append(f"r{reader_id}_start")
time.sleep(0.05)
with lock_for_operations:
operations.append(f"r{reader_id}_end")
def writer(writer_id: int) -> None:
with lock.w_locked():
with lock_for_operations:
operations.append(f"w{writer_id}_start")
time.sleep(0.05)
with lock_for_operations:
operations.append(f"w{writer_id}_end")
threads = [
threading.Thread(target=reader, args=(0,)),
threading.Thread(target=writer, args=(0,)),
threading.Thread(target=reader, args=(1,)),
threading.Thread(target=writer, args=(1,)),
threading.Thread(target=reader, args=(2,)),
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert len(operations) == 10
start_ops = [op for op in operations if "_start" in op]
end_ops = [op for op in operations if "_end" in op]
assert len(start_ops) == 5
assert len(end_ops) == 5
def test_context_manager_releases_on_exception():
lock = RWLock()
exception_raised = False
try:
with lock.r_locked():
raise ValueError("Test exception")
except ValueError:
exception_raised = True
assert exception_raised
acquired = False
with lock.w_locked():
acquired = True
assert acquired
def test_write_lock_releases_on_exception():
lock = RWLock()
exception_raised = False
try:
with lock.w_locked():
raise ValueError("Test exception")
except ValueError:
exception_raised = True
assert exception_raised
acquired = False
with lock.r_locked():
acquired = True
assert acquired
def test_stress_many_readers_few_writers():
lock = RWLock()
read_count = [0]
write_count = [0]
lock_for_counters = threading.Lock()
def reader() -> None:
for _ in range(10):
with lock.r_locked():
with lock_for_counters:
read_count[0] += 1
time.sleep(0.001)
def writer() -> None:
for _ in range(5):
with lock.w_locked():
with lock_for_counters:
write_count[0] += 1
time.sleep(0.01)
reader_threads = [threading.Thread(target=reader) for _ in range(10)]
writer_threads = [threading.Thread(target=writer) for _ in range(2)]
all_threads = reader_threads + writer_threads
for thread in all_threads:
thread.start()
for thread in all_threads:
thread.join()
assert read_count[0] == 100
assert write_count[0] == 10
def test_nested_read_locks_same_thread():
lock = RWLock()
nested_acquired = False
with lock.r_locked():
with lock.r_locked():
nested_acquired = True
assert nested_acquired
def test_manual_acquire_release():
lock = RWLock()
lock.r_acquire()
lock.r_release()
lock.w_acquire()
lock.w_release()
with lock.r_locked():
pass

View File

@@ -0,0 +1,247 @@
"""Tests for event bus shutdown and cleanup behavior.
This module tests graceful shutdown, task completion, and cleanup operations.
"""
import asyncio
import threading
import time
import pytest
from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import CrewAIEventsBus
class ShutdownTestEvent(BaseEvent):
pass
def test_shutdown_prevents_new_events():
bus = CrewAIEventsBus()
received_events = []
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def handler(source: object, event: BaseEvent) -> None:
received_events.append(event)
bus._shutting_down = True
event = ShutdownTestEvent(type="after_shutdown")
bus.emit("test_source", event)
time.sleep(0.1)
assert len(received_events) == 0
bus._shutting_down = False
@pytest.mark.asyncio
async def test_aemit_during_shutdown():
bus = CrewAIEventsBus()
received_events = []
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
async def handler(source: object, event: BaseEvent) -> None:
received_events.append(event)
bus._shutting_down = True
event = ShutdownTestEvent(type="aemit_during_shutdown")
await bus.aemit("test_source", event)
assert len(received_events) == 0
bus._shutting_down = False
def test_shutdown_flag_prevents_emit():
bus = CrewAIEventsBus()
emitted_count = [0]
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def handler(source: object, event: BaseEvent) -> None:
emitted_count[0] += 1
event1 = ShutdownTestEvent(type="before_shutdown")
bus.emit("test_source", event1)
time.sleep(0.1)
assert emitted_count[0] == 1
bus._shutting_down = True
event2 = ShutdownTestEvent(type="during_shutdown")
bus.emit("test_source", event2)
time.sleep(0.1)
assert emitted_count[0] == 1
bus._shutting_down = False
def test_concurrent_access_during_shutdown_flag():
bus = CrewAIEventsBus()
received_events = []
lock = threading.Lock()
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def handler(source: object, event: BaseEvent) -> None:
with lock:
received_events.append(event)
def emit_events() -> None:
for i in range(10):
event = ShutdownTestEvent(type=f"event_{i}")
bus.emit("source", event)
time.sleep(0.01)
def set_shutdown_flag() -> None:
time.sleep(0.05)
bus._shutting_down = True
emit_thread = threading.Thread(target=emit_events)
shutdown_thread = threading.Thread(target=set_shutdown_flag)
emit_thread.start()
shutdown_thread.start()
emit_thread.join()
shutdown_thread.join()
time.sleep(0.2)
assert len(received_events) < 10
assert len(received_events) > 0
bus._shutting_down = False
@pytest.mark.asyncio
async def test_async_handlers_complete_before_shutdown_flag():
bus = CrewAIEventsBus()
completed_handlers = []
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
async def async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.05)
if not bus._shutting_down:
completed_handlers.append(event)
for i in range(5):
event = ShutdownTestEvent(type=f"event_{i}")
bus.emit("source", event)
await asyncio.sleep(0.3)
assert len(completed_handlers) == 5
def test_scoped_handlers_cleanup():
bus = CrewAIEventsBus()
received_before = []
received_during = []
received_after = []
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def before_handler(source: object, event: BaseEvent) -> None:
received_before.append(event)
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def during_handler(source: object, event: BaseEvent) -> None:
received_during.append(event)
event1 = ShutdownTestEvent(type="during")
bus.emit("source", event1)
time.sleep(0.1)
assert len(received_before) == 0
assert len(received_during) == 1
event2 = ShutdownTestEvent(type="after_inner_scope")
bus.emit("source", event2)
time.sleep(0.1)
assert len(received_before) == 1
assert len(received_during) == 1
event3 = ShutdownTestEvent(type="after_outer_scope")
bus.emit("source", event3)
time.sleep(0.1)
assert len(received_before) == 1
assert len(received_during) == 1
assert len(received_after) == 0
def test_handler_registration_thread_safety():
bus = CrewAIEventsBus()
handlers_registered = [0]
lock = threading.Lock()
with bus.scoped_handlers():
def register_handlers() -> None:
for _ in range(20):
@bus.on(ShutdownTestEvent)
def handler(source: object, event: BaseEvent) -> None:
pass
with lock:
handlers_registered[0] += 1
time.sleep(0.001)
threads = [threading.Thread(target=register_handlers) for _ in range(3)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert handlers_registered[0] == 60
@pytest.mark.asyncio
async def test_mixed_sync_async_handler_execution():
bus = CrewAIEventsBus()
sync_executed = []
async_executed = []
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def sync_handler(source: object, event: BaseEvent) -> None:
time.sleep(0.01)
sync_executed.append(event)
@bus.on(ShutdownTestEvent)
async def async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
async_executed.append(event)
for i in range(5):
event = ShutdownTestEvent(type=f"event_{i}")
bus.emit("source", event)
await asyncio.sleep(0.2)
assert len(sync_executed) == 5
assert len(async_executed) == 5

View File

@@ -0,0 +1,189 @@
"""Tests for thread safety in CrewAI event bus.
This module tests concurrent event emission and handler registration.
"""
import threading
import time
from collections.abc import Callable
from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import crewai_event_bus
class ThreadSafetyTestEvent(BaseEvent):
pass
def test_concurrent_emit_from_multiple_threads():
received_events: list[BaseEvent] = []
lock = threading.Lock()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(ThreadSafetyTestEvent)
def handler(source: object, event: BaseEvent) -> None:
with lock:
received_events.append(event)
threads: list[threading.Thread] = []
num_threads = 10
events_per_thread = 10
def emit_events(thread_id: int) -> None:
for i in range(events_per_thread):
event = ThreadSafetyTestEvent(type=f"thread_{thread_id}_event_{i}")
crewai_event_bus.emit(f"source_{thread_id}", event)
for i in range(num_threads):
thread = threading.Thread(target=emit_events, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
time.sleep(0.5)
assert len(received_events) == num_threads * events_per_thread
def test_concurrent_handler_registration():
handlers_executed: list[int] = []
lock = threading.Lock()
def create_handler(handler_id: int) -> Callable[[object, BaseEvent], None]:
def handler(source: object, event: BaseEvent) -> None:
with lock:
handlers_executed.append(handler_id)
return handler
with crewai_event_bus.scoped_handlers():
threads: list[threading.Thread] = []
num_handlers = 20
def register_handler(handler_id: int) -> None:
crewai_event_bus.register_handler(
ThreadSafetyTestEvent, create_handler(handler_id)
)
for i in range(num_handlers):
thread = threading.Thread(target=register_handler, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
event = ThreadSafetyTestEvent(type="registration_test")
crewai_event_bus.emit("test_source", event)
time.sleep(0.5)
assert len(handlers_executed) == num_handlers
assert set(handlers_executed) == set(range(num_handlers))
def test_concurrent_emit_and_registration():
received_events: list[BaseEvent] = []
lock = threading.Lock()
with crewai_event_bus.scoped_handlers():
def emit_continuously() -> None:
for i in range(50):
event = ThreadSafetyTestEvent(type=f"emit_event_{i}")
crewai_event_bus.emit("emitter", event)
time.sleep(0.001)
def register_continuously() -> None:
for _ in range(10):
@crewai_event_bus.on(ThreadSafetyTestEvent)
def handler(source: object, event: BaseEvent) -> None:
with lock:
received_events.append(event)
time.sleep(0.005)
emit_thread = threading.Thread(target=emit_continuously)
register_thread = threading.Thread(target=register_continuously)
emit_thread.start()
register_thread.start()
emit_thread.join()
register_thread.join()
time.sleep(0.5)
assert len(received_events) > 0
def test_stress_test_rapid_emit():
received_count = [0]
lock = threading.Lock()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(ThreadSafetyTestEvent)
def counter_handler(source: object, event: BaseEvent) -> None:
with lock:
received_count[0] += 1
num_events = 1000
for i in range(num_events):
event = ThreadSafetyTestEvent(type=f"rapid_event_{i}")
crewai_event_bus.emit("rapid_source", event)
time.sleep(1.0)
assert received_count[0] == num_events
def test_multiple_event_types_concurrent():
class EventTypeA(BaseEvent):
pass
class EventTypeB(BaseEvent):
pass
received_a: list[BaseEvent] = []
received_b: list[BaseEvent] = []
lock = threading.Lock()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(EventTypeA)
def handler_a(source: object, event: BaseEvent) -> None:
with lock:
received_a.append(event)
@crewai_event_bus.on(EventTypeB)
def handler_b(source: object, event: BaseEvent) -> None:
with lock:
received_b.append(event)
def emit_type_a() -> None:
for i in range(50):
crewai_event_bus.emit("source_a", EventTypeA(type=f"type_a_{i}"))
def emit_type_b() -> None:
for i in range(50):
crewai_event_bus.emit("source_b", EventTypeB(type=f"type_b_{i}"))
thread_a = threading.Thread(target=emit_type_a)
thread_b = threading.Thread(target=emit_type_b)
thread_a.start()
thread_b.start()
thread_a.join()
thread_b.join()
time.sleep(0.5)
assert len(received_a) == 50
assert len(received_b) == 50

View File

@@ -1,3 +1,4 @@
import threading
from datetime import datetime from datetime import datetime
import os import os
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@@ -49,6 +50,8 @@ from crewai.tools.base_tool import BaseTool
from pydantic import Field from pydantic import Field
import pytest import pytest
from ..utils import wait_for_event_handlers
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def vcr_config(request) -> dict: def vcr_config(request) -> dict:
@@ -118,6 +121,7 @@ def test_crew_emits_start_kickoff_event(
# Now when Crew creates EventListener, it will use our mocked telemetry # Now when Crew creates EventListener, it will use our mocked telemetry
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew") crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff() crew.kickoff()
wait_for_event_handlers()
mock_telemetry.crew_execution_span.assert_called_once_with(crew, None) mock_telemetry.crew_execution_span.assert_called_once_with(crew, None)
mock_telemetry.end_crew.assert_called_once_with(crew, "hi") mock_telemetry.end_crew.assert_called_once_with(crew, "hi")
@@ -131,15 +135,20 @@ def test_crew_emits_start_kickoff_event(
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_end_kickoff_event(base_agent, base_task): def test_crew_emits_end_kickoff_event(base_agent, base_task):
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(CrewKickoffCompletedEvent) @crewai_event_bus.on(CrewKickoffCompletedEvent)
def handle_crew_end(source, event): def handle_crew_end(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew") crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff() crew.kickoff()
assert event_received.wait(timeout=5), (
"Timeout waiting for crew kickoff completed event"
)
assert len(received_events) == 1 assert len(received_events) == 1
assert received_events[0].crew_name == "TestCrew" assert received_events[0].crew_name == "TestCrew"
assert isinstance(received_events[0].timestamp, datetime) assert isinstance(received_events[0].timestamp, datetime)
@@ -165,6 +174,7 @@ def test_crew_emits_test_kickoff_type_event(base_agent, base_task):
eval_llm = LLM(model="gpt-4o-mini") eval_llm = LLM(model="gpt-4o-mini")
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew") crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.test(n_iterations=1, eval_llm=eval_llm) crew.test(n_iterations=1, eval_llm=eval_llm)
wait_for_event_handlers()
assert len(received_events) == 3 assert len(received_events) == 3
assert received_events[0].crew_name == "TestCrew" assert received_events[0].crew_name == "TestCrew"
@@ -181,40 +191,44 @@ def test_crew_emits_test_kickoff_type_event(base_agent, base_task):
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_kickoff_failed_event(base_agent, base_task): def test_crew_emits_kickoff_failed_event(base_agent, base_task):
received_events = [] received_events = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(CrewKickoffFailedEvent)
def handle_crew_failed(source, event):
received_events.append(event)
event_received.set()
@crewai_event_bus.on(CrewKickoffFailedEvent) crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
def handle_crew_failed(source, event):
received_events.append(event)
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew") with patch.object(Crew, "_execute_tasks") as mock_execute:
error_message = "Simulated crew kickoff failure"
mock_execute.side_effect = Exception(error_message)
with patch.object(Crew, "_execute_tasks") as mock_execute: with pytest.raises(Exception): # noqa: B017
error_message = "Simulated crew kickoff failure" crew.kickoff()
mock_execute.side_effect = Exception(error_message)
with pytest.raises(Exception): # noqa: B017 assert event_received.wait(timeout=5), "Timeout waiting for failed event"
crew.kickoff() assert len(received_events) == 1
assert received_events[0].error == error_message
assert len(received_events) == 1 assert isinstance(received_events[0].timestamp, datetime)
assert received_events[0].error == error_message assert received_events[0].type == "crew_kickoff_failed"
assert isinstance(received_events[0].timestamp, datetime)
assert received_events[0].type == "crew_kickoff_failed"
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_start_task_event(base_agent, base_task): def test_crew_emits_start_task_event(base_agent, base_task):
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(TaskStartedEvent) @crewai_event_bus.on(TaskStartedEvent)
def handle_task_start(source, event): def handle_task_start(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew") crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff() crew.kickoff()
assert event_received.wait(timeout=5), "Timeout waiting for task started event"
assert len(received_events) == 1 assert len(received_events) == 1
assert isinstance(received_events[0].timestamp, datetime) assert isinstance(received_events[0].timestamp, datetime)
assert received_events[0].type == "task_started" assert received_events[0].type == "task_started"
@@ -225,10 +239,12 @@ def test_crew_emits_end_task_event(
base_agent, base_task, reset_event_listener_singleton base_agent, base_task, reset_event_listener_singleton
): ):
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(TaskCompletedEvent) @crewai_event_bus.on(TaskCompletedEvent)
def handle_task_end(source, event): def handle_task_end(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
mock_span = Mock() mock_span = Mock()
@@ -246,6 +262,7 @@ def test_crew_emits_end_task_event(
mock_telemetry.task_started.assert_called_once_with(crew=crew, task=base_task) mock_telemetry.task_started.assert_called_once_with(crew=crew, task=base_task)
mock_telemetry.task_ended.assert_called_once_with(mock_span, base_task, crew) mock_telemetry.task_ended.assert_called_once_with(mock_span, base_task, crew)
assert event_received.wait(timeout=5), "Timeout waiting for task completed event"
assert len(received_events) == 1 assert len(received_events) == 1
assert isinstance(received_events[0].timestamp, datetime) assert isinstance(received_events[0].timestamp, datetime)
assert received_events[0].type == "task_completed" assert received_events[0].type == "task_completed"
@@ -255,11 +272,13 @@ def test_crew_emits_end_task_event(
def test_task_emits_failed_event_on_execution_error(base_agent, base_task): def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
received_events = [] received_events = []
received_sources = [] received_sources = []
event_received = threading.Event()
@crewai_event_bus.on(TaskFailedEvent) @crewai_event_bus.on(TaskFailedEvent)
def handle_task_failed(source, event): def handle_task_failed(source, event):
received_events.append(event) received_events.append(event)
received_sources.append(source) received_sources.append(source)
event_received.set()
with patch.object( with patch.object(
Task, Task,
@@ -281,6 +300,9 @@ def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
with pytest.raises(Exception): # noqa: B017 with pytest.raises(Exception): # noqa: B017
agent.execute_task(task=task) agent.execute_task(task=task)
assert event_received.wait(timeout=5), (
"Timeout waiting for task failed event"
)
assert len(received_events) == 1 assert len(received_events) == 1
assert received_sources[0] == task assert received_sources[0] == task
assert received_events[0].error == error_message assert received_events[0].error == error_message
@@ -291,17 +313,27 @@ def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_emits_execution_started_and_completed_events(base_agent, base_task): def test_agent_emits_execution_started_and_completed_events(base_agent, base_task):
received_events = [] received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
@crewai_event_bus.on(AgentExecutionStartedEvent) @crewai_event_bus.on(AgentExecutionStartedEvent)
def handle_agent_start(source, event): def handle_agent_start(source, event):
received_events.append(event) with lock:
received_events.append(event)
@crewai_event_bus.on(AgentExecutionCompletedEvent) @crewai_event_bus.on(AgentExecutionCompletedEvent)
def handle_agent_completed(source, event): def handle_agent_completed(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) >= 2:
all_events_received.set()
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew") crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff() crew.kickoff()
assert all_events_received.wait(timeout=5), (
"Timeout waiting for agent execution events"
)
assert len(received_events) == 2 assert len(received_events) == 2
assert received_events[0].agent == base_agent assert received_events[0].agent == base_agent
assert received_events[0].task == base_task assert received_events[0].task == base_task
@@ -320,10 +352,12 @@ def test_agent_emits_execution_started_and_completed_events(base_agent, base_tas
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_emits_execution_error_event(base_agent, base_task): def test_agent_emits_execution_error_event(base_agent, base_task):
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(AgentExecutionErrorEvent) @crewai_event_bus.on(AgentExecutionErrorEvent)
def handle_agent_start(source, event): def handle_agent_start(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
error_message = "Error happening while sending prompt to model." error_message = "Error happening while sending prompt to model."
base_agent.max_retry_limit = 0 base_agent.max_retry_limit = 0
@@ -337,6 +371,9 @@ def test_agent_emits_execution_error_event(base_agent, base_task):
task=base_task, task=base_task,
) )
assert event_received.wait(timeout=5), (
"Timeout waiting for agent execution error event"
)
assert len(received_events) == 1 assert len(received_events) == 1
assert received_events[0].agent == base_agent assert received_events[0].agent == base_agent
assert received_events[0].task == base_task assert received_events[0].task == base_task
@@ -358,10 +395,12 @@ class SayHiTool(BaseTool):
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_tools_emits_finished_events(): def test_tools_emits_finished_events():
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageFinishedEvent) @crewai_event_bus.on(ToolUsageFinishedEvent)
def handle_tool_end(source, event): def handle_tool_end(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
agent = Agent( agent = Agent(
role="base_agent", role="base_agent",
@@ -377,6 +416,10 @@ def test_tools_emits_finished_events():
) )
crew = Crew(agents=[agent], tasks=[task], name="TestCrew") crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
crew.kickoff() crew.kickoff()
assert event_received.wait(timeout=5), (
"Timeout waiting for tool usage finished event"
)
assert len(received_events) == 1 assert len(received_events) == 1
assert received_events[0].agent_key == agent.key assert received_events[0].agent_key == agent.key
assert received_events[0].agent_role == agent.role assert received_events[0].agent_role == agent.role
@@ -389,10 +432,15 @@ def test_tools_emits_finished_events():
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_tools_emits_error_events(): def test_tools_emits_error_events():
received_events = [] received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
@crewai_event_bus.on(ToolUsageErrorEvent) @crewai_event_bus.on(ToolUsageErrorEvent)
def handle_tool_end(source, event): def handle_tool_end(source, event):
received_events.append(event) with lock:
received_events.append(event)
if len(received_events) >= 48:
all_events_received.set()
class ErrorTool(BaseTool): class ErrorTool(BaseTool):
name: str = Field( name: str = Field(
@@ -423,6 +471,9 @@ def test_tools_emits_error_events():
crew = Crew(agents=[agent], tasks=[task], name="TestCrew") crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
crew.kickoff() crew.kickoff()
assert all_events_received.wait(timeout=5), (
"Timeout waiting for tool usage error events"
)
assert len(received_events) == 48 assert len(received_events) == 48
assert received_events[0].agent_key == agent.key assert received_events[0].agent_key == agent.key
assert received_events[0].agent_role == agent.role assert received_events[0].agent_role == agent.role
@@ -435,11 +486,13 @@ def test_tools_emits_error_events():
def test_flow_emits_start_event(reset_event_listener_singleton): def test_flow_emits_start_event(reset_event_listener_singleton):
received_events = [] received_events = []
event_received = threading.Event()
mock_span = Mock() mock_span = Mock()
@crewai_event_bus.on(FlowStartedEvent) @crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event): def handle_flow_start(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
class TestFlow(Flow[dict]): class TestFlow(Flow[dict]):
@start() @start()
@@ -458,6 +511,7 @@ def test_flow_emits_start_event(reset_event_listener_singleton):
flow = TestFlow() flow = TestFlow()
flow.kickoff() flow.kickoff()
assert event_received.wait(timeout=5), "Timeout waiting for flow started event"
mock_telemetry.flow_execution_span.assert_called_once_with("TestFlow", ["begin"]) mock_telemetry.flow_execution_span.assert_called_once_with("TestFlow", ["begin"])
assert len(received_events) == 1 assert len(received_events) == 1
assert received_events[0].flow_name == "TestFlow" assert received_events[0].flow_name == "TestFlow"
@@ -466,6 +520,7 @@ def test_flow_emits_start_event(reset_event_listener_singleton):
def test_flow_name_emitted_to_event_bus(): def test_flow_name_emitted_to_event_bus():
received_events = [] received_events = []
event_received = threading.Event()
class MyFlowClass(Flow): class MyFlowClass(Flow):
name = "PRODUCTION_FLOW" name = "PRODUCTION_FLOW"
@@ -477,118 +532,133 @@ def test_flow_name_emitted_to_event_bus():
@crewai_event_bus.on(FlowStartedEvent) @crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event): def handle_flow_start(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
flow = MyFlowClass() flow = MyFlowClass()
flow.kickoff() flow.kickoff()
assert event_received.wait(timeout=5), "Timeout waiting for flow started event"
assert len(received_events) == 1 assert len(received_events) == 1
assert received_events[0].flow_name == "PRODUCTION_FLOW" assert received_events[0].flow_name == "PRODUCTION_FLOW"
def test_flow_emits_finish_event(): def test_flow_emits_finish_event():
received_events = [] received_events = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_finish(source, event):
received_events.append(event)
event_received.set()
@crewai_event_bus.on(FlowFinishedEvent) class TestFlow(Flow[dict]):
def handle_flow_finish(source, event): @start()
received_events.append(event) def begin(self):
return "completed"
class TestFlow(Flow[dict]): flow = TestFlow()
@start() result = flow.kickoff()
def begin(self):
return "completed"
flow = TestFlow() assert event_received.wait(timeout=5), "Timeout waiting for finish event"
result = flow.kickoff() assert len(received_events) == 1
assert received_events[0].flow_name == "TestFlow"
assert len(received_events) == 1 assert received_events[0].type == "flow_finished"
assert received_events[0].flow_name == "TestFlow" assert received_events[0].result == "completed"
assert received_events[0].type == "flow_finished" assert result == "completed"
assert received_events[0].result == "completed"
assert result == "completed"
def test_flow_emits_method_execution_started_event(): def test_flow_emits_method_execution_started_event():
received_events = [] received_events = []
lock = threading.Lock()
second_event_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(MethodExecutionStartedEvent)
async def handle_method_start(source, event):
@crewai_event_bus.on(MethodExecutionStartedEvent) with lock:
def handle_method_start(source, event):
received_events.append(event) received_events.append(event)
if event.method_name == "second_method":
second_event_received.set()
class TestFlow(Flow[dict]): class TestFlow(Flow[dict]):
@start() @start()
def begin(self): def begin(self):
return "started" return "started"
@listen("begin") @listen("begin")
def second_method(self): def second_method(self):
return "executed" return "executed"
flow = TestFlow() flow = TestFlow()
flow.kickoff() flow.kickoff()
assert len(received_events) == 2 assert second_event_received.wait(timeout=5), (
"Timeout waiting for second_method event"
)
assert len(received_events) == 2
assert received_events[0].method_name == "begin" # Events may arrive in any order due to async handlers, so check both are present
assert received_events[0].flow_name == "TestFlow" method_names = {event.method_name for event in received_events}
assert received_events[0].type == "method_execution_started" assert method_names == {"begin", "second_method"}
assert received_events[1].method_name == "second_method" for event in received_events:
assert received_events[1].flow_name == "TestFlow" assert event.flow_name == "TestFlow"
assert received_events[1].type == "method_execution_started" assert event.type == "method_execution_started"
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_register_handler_adds_new_handler(base_agent, base_task): def test_register_handler_adds_new_handler(base_agent, base_task):
received_events = [] received_events = []
event_received = threading.Event()
def custom_handler(source, event): def custom_handler(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
with crewai_event_bus.scoped_handlers(): crewai_event_bus.register_handler(CrewKickoffStartedEvent, custom_handler)
crewai_event_bus.register_handler(CrewKickoffStartedEvent, custom_handler)
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew") crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff() crew.kickoff()
assert len(received_events) == 1 assert event_received.wait(timeout=5), "Timeout waiting for handler event"
assert isinstance(received_events[0].timestamp, datetime) assert len(received_events) == 1
assert received_events[0].type == "crew_kickoff_started" assert isinstance(received_events[0].timestamp, datetime)
assert received_events[0].type == "crew_kickoff_started"
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_multiple_handlers_for_same_event(base_agent, base_task): def test_multiple_handlers_for_same_event(base_agent, base_task):
received_events_1 = [] received_events_1 = []
received_events_2 = [] received_events_2 = []
event_received = threading.Event()
def handler_1(source, event): def handler_1(source, event):
received_events_1.append(event) received_events_1.append(event)
def handler_2(source, event): def handler_2(source, event):
received_events_2.append(event) received_events_2.append(event)
event_received.set()
with crewai_event_bus.scoped_handlers(): crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_1)
crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_1) crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_2)
crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_2)
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew") crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff() crew.kickoff()
assert len(received_events_1) == 1 assert event_received.wait(timeout=5), "Timeout waiting for handler events"
assert len(received_events_2) == 1 assert len(received_events_1) == 1
assert received_events_1[0].type == "crew_kickoff_started" assert len(received_events_2) == 1
assert received_events_2[0].type == "crew_kickoff_started" assert received_events_1[0].type == "crew_kickoff_started"
assert received_events_2[0].type == "crew_kickoff_started"
def test_flow_emits_created_event(): def test_flow_emits_created_event():
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(FlowCreatedEvent) @crewai_event_bus.on(FlowCreatedEvent)
def handle_flow_created(source, event): def handle_flow_created(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
class TestFlow(Flow[dict]): class TestFlow(Flow[dict]):
@start() @start()
@@ -598,6 +668,7 @@ def test_flow_emits_created_event():
flow = TestFlow() flow = TestFlow()
flow.kickoff() flow.kickoff()
assert event_received.wait(timeout=5), "Timeout waiting for flow created event"
assert len(received_events) == 1 assert len(received_events) == 1
assert received_events[0].flow_name == "TestFlow" assert received_events[0].flow_name == "TestFlow"
assert received_events[0].type == "flow_created" assert received_events[0].type == "flow_created"
@@ -605,11 +676,13 @@ def test_flow_emits_created_event():
def test_flow_emits_method_execution_failed_event(): def test_flow_emits_method_execution_failed_event():
received_events = [] received_events = []
event_received = threading.Event()
error = Exception("Simulated method failure") error = Exception("Simulated method failure")
@crewai_event_bus.on(MethodExecutionFailedEvent) @crewai_event_bus.on(MethodExecutionFailedEvent)
def handle_method_failed(source, event): def handle_method_failed(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
class TestFlow(Flow[dict]): class TestFlow(Flow[dict]):
@start() @start()
@@ -620,6 +693,9 @@ def test_flow_emits_method_execution_failed_event():
with pytest.raises(Exception): # noqa: B017 with pytest.raises(Exception): # noqa: B017
flow.kickoff() flow.kickoff()
assert event_received.wait(timeout=5), (
"Timeout waiting for method execution failed event"
)
assert len(received_events) == 1 assert len(received_events) == 1
assert received_events[0].method_name == "begin" assert received_events[0].method_name == "begin"
assert received_events[0].flow_name == "TestFlow" assert received_events[0].flow_name == "TestFlow"
@@ -641,6 +717,7 @@ def test_llm_emits_call_started_event():
llm = LLM(model="gpt-4o-mini") llm = LLM(model="gpt-4o-mini")
llm.call("Hello, how are you?") llm.call("Hello, how are you?")
wait_for_event_handlers()
assert len(received_events) == 2 assert len(received_events) == 2
assert received_events[0].type == "llm_call_started" assert received_events[0].type == "llm_call_started"
@@ -656,10 +733,12 @@ def test_llm_emits_call_started_event():
@pytest.mark.isolated @pytest.mark.isolated
def test_llm_emits_call_failed_event(): def test_llm_emits_call_failed_event():
received_events = [] received_events = []
event_received = threading.Event()
@crewai_event_bus.on(LLMCallFailedEvent) @crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_call_failed(source, event): def handle_llm_call_failed(source, event):
received_events.append(event) received_events.append(event)
event_received.set()
error_message = "OpenAI API call failed: Simulated API failure" error_message = "OpenAI API call failed: Simulated API failure"
@@ -673,6 +752,7 @@ def test_llm_emits_call_failed_event():
llm.call("Hello, how are you?") llm.call("Hello, how are you?")
assert str(exc_info.value) == "Simulated API failure" assert str(exc_info.value) == "Simulated API failure"
assert event_received.wait(timeout=5), "Timeout waiting for failed event"
assert len(received_events) == 1 assert len(received_events) == 1
assert received_events[0].type == "llm_call_failed" assert received_events[0].type == "llm_call_failed"
assert received_events[0].error == error_message assert received_events[0].error == error_message
@@ -686,24 +766,28 @@ def test_llm_emits_call_failed_event():
def test_llm_emits_stream_chunk_events(): def test_llm_emits_stream_chunk_events():
"""Test that LLM emits stream chunk events when streaming is enabled.""" """Test that LLM emits stream chunk events when streaming is enabled."""
received_chunks = [] received_chunks = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
if len(received_chunks) >= 1:
event_received.set()
@crewai_event_bus.on(LLMStreamChunkEvent) # Create an LLM with streaming enabled
def handle_stream_chunk(source, event): llm = LLM(model="gpt-4o", stream=True)
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled # Call the LLM with a simple message
llm = LLM(model="gpt-4o", stream=True) response = llm.call("Tell me a short joke")
# Call the LLM with a simple message # Wait for at least one chunk
response = llm.call("Tell me a short joke") assert event_received.wait(timeout=5), "Timeout waiting for stream chunks"
# Verify that we received chunks # Verify that we received chunks
assert len(received_chunks) > 0 assert len(received_chunks) > 0
# Verify that concatenating all chunks equals the final response # Verify that concatenating all chunks equals the final response
assert "".join(received_chunks) == response assert "".join(received_chunks) == response
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -711,23 +795,21 @@ def test_llm_no_stream_chunks_when_streaming_disabled():
"""Test that LLM doesn't emit stream chunk events when streaming is disabled.""" """Test that LLM doesn't emit stream chunk events when streaming is disabled."""
received_chunks = [] received_chunks = []
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
@crewai_event_bus.on(LLMStreamChunkEvent) # Create an LLM with streaming disabled
def handle_stream_chunk(source, event): llm = LLM(model="gpt-4o", stream=False)
received_chunks.append(event.chunk)
# Create an LLM with streaming disabled # Call the LLM with a simple message
llm = LLM(model="gpt-4o", stream=False) response = llm.call("Tell me a short joke")
# Call the LLM with a simple message # Verify that we didn't receive any chunks
response = llm.call("Tell me a short joke") assert len(received_chunks) == 0
# Verify that we didn't receive any chunks # Verify we got a response
assert len(received_chunks) == 0 assert response and isinstance(response, str)
# Verify we got a response
assert response and isinstance(response, str)
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -735,98 +817,105 @@ def test_streaming_fallback_to_non_streaming():
"""Test that streaming falls back to non-streaming when there's an error.""" """Test that streaming falls back to non-streaming when there's an error."""
received_chunks = [] received_chunks = []
fallback_called = False fallback_called = False
event_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
if len(received_chunks) >= 2:
event_received.set()
@crewai_event_bus.on(LLMStreamChunkEvent) # Create an LLM with streaming enabled
def handle_stream_chunk(source, event): llm = LLM(model="gpt-4o", stream=True)
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled # Store original methods
llm = LLM(model="gpt-4o", stream=True) original_call = llm.call
# Store original methods # Create a mock call method that handles the streaming error
original_call = llm.call def mock_call(messages, tools=None, callbacks=None, available_functions=None):
nonlocal fallback_called
# Emit a couple of chunks to simulate partial streaming
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))
# Create a mock call method that handles the streaming error # Mark that fallback would be called
def mock_call(messages, tools=None, callbacks=None, available_functions=None): fallback_called = True
nonlocal fallback_called
# Emit a couple of chunks to simulate partial streaming
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))
# Mark that fallback would be called # Return a response as if fallback succeeded
fallback_called = True return "Fallback response after streaming error"
# Return a response as if fallback succeeded # Replace the call method with our mock
return "Fallback response after streaming error" llm.call = mock_call
# Replace the call method with our mock try:
llm.call = mock_call # Call the LLM
response = llm.call("Tell me a short joke")
wait_for_event_handlers()
try: assert event_received.wait(timeout=5), "Timeout waiting for stream chunks"
# Call the LLM
response = llm.call("Tell me a short joke")
# Verify that we received some chunks # Verify that we received some chunks
assert len(received_chunks) == 2 assert len(received_chunks) == 2
assert received_chunks[0] == "Test chunk 1" assert received_chunks[0] == "Test chunk 1"
assert received_chunks[1] == "Test chunk 2" assert received_chunks[1] == "Test chunk 2"
# Verify fallback was triggered # Verify fallback was triggered
assert fallback_called assert fallback_called
# Verify we got the fallback response # Verify we got the fallback response
assert response == "Fallback response after streaming error" assert response == "Fallback response after streaming error"
finally: finally:
# Restore the original method # Restore the original method
llm.call = original_call llm.call = original_call
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_streaming_empty_response_handling(): def test_streaming_empty_response_handling():
"""Test that streaming handles empty responses correctly.""" """Test that streaming handles empty responses correctly."""
received_chunks = [] received_chunks = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
if len(received_chunks) >= 3:
event_received.set()
@crewai_event_bus.on(LLMStreamChunkEvent) # Create an LLM with streaming enabled
def handle_stream_chunk(source, event): llm = LLM(model="gpt-3.5-turbo", stream=True)
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled # Store original methods
llm = LLM(model="gpt-3.5-turbo", stream=True) original_call = llm.call
# Store original methods # Create a mock call method that simulates empty chunks
original_call = llm.call def mock_call(messages, tools=None, callbacks=None, available_functions=None):
# Emit a few empty chunks
for _ in range(3):
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
# Create a mock call method that simulates empty chunks # Return the default message for empty responses
def mock_call(messages, tools=None, callbacks=None, available_functions=None): return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
# Emit a few empty chunks
for _ in range(3):
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
# Return the default message for empty responses # Replace the call method with our mock
return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request." llm.call = mock_call
# Replace the call method with our mock try:
llm.call = mock_call # Call the LLM - this should handle empty response
response = llm.call("Tell me a short joke")
try: assert event_received.wait(timeout=5), "Timeout waiting for empty chunks"
# Call the LLM - this should handle empty response
response = llm.call("Tell me a short joke")
# Verify that we received empty chunks # Verify that we received empty chunks
assert len(received_chunks) == 3 assert len(received_chunks) == 3
assert all(chunk == "" for chunk in received_chunks) assert all(chunk == "" for chunk in received_chunks)
# Verify the response is the default message for empty responses # Verify the response is the default message for empty responses
assert "I apologize" in response and "couldn't generate" in response assert "I apologize" in response and "couldn't generate" in response
finally: finally:
# Restore the original method # Restore the original method
llm.call = original_call llm.call = original_call
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -835,41 +924,49 @@ def test_stream_llm_emits_event_with_task_and_agent_info():
failed_event = [] failed_event = []
started_event = [] started_event = []
stream_event = [] stream_event = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@crewai_event_bus.on(LLMCallFailedEvent) @crewai_event_bus.on(LLMCallStartedEvent)
def handle_llm_failed(source, event): def handle_llm_started(source, event):
failed_event.append(event) started_event.append(event)
@crewai_event_bus.on(LLMCallStartedEvent) @crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_started(source, event): def handle_llm_completed(source, event):
started_event.append(event) completed_event.append(event)
if len(started_event) >= 1 and len(stream_event) >= 12:
event_received.set()
@crewai_event_bus.on(LLMCallCompletedEvent) @crewai_event_bus.on(LLMStreamChunkEvent)
def handle_llm_completed(source, event): def handle_llm_stream_chunk(source, event):
completed_event.append(event) stream_event.append(event)
if (
len(completed_event) >= 1
and len(started_event) >= 1
and len(stream_event) >= 12
):
event_received.set()
@crewai_event_bus.on(LLMStreamChunkEvent) agent = Agent(
def handle_llm_stream_chunk(source, event): role="TestAgent",
stream_event.append(event) llm=LLM(model="gpt-4o-mini", stream=True),
goal="Just say hi",
backstory="You are a helpful assistant that just says hi",
)
task = Task(
description="Just say hi",
expected_output="hi",
llm=LLM(model="gpt-4o-mini", stream=True),
agent=agent,
)
agent = Agent( crew = Crew(agents=[agent], tasks=[task])
role="TestAgent", crew.kickoff()
llm=LLM(model="gpt-4o-mini", stream=True),
goal="Just say hi",
backstory="You are a helpful assistant that just says hi",
)
task = Task(
description="Just say hi",
expected_output="hi",
llm=LLM(model="gpt-4o-mini", stream=True),
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
crew.kickoff()
assert event_received.wait(timeout=10), "Timeout waiting for LLM events"
assert len(completed_event) == 1 assert len(completed_event) == 1
assert len(failed_event) == 0 assert len(failed_event) == 0
assert len(started_event) == 1 assert len(started_event) == 1
@@ -899,28 +996,30 @@ def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task):
failed_event = [] failed_event = []
started_event = [] started_event = []
stream_event = [] stream_event = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@crewai_event_bus.on(LLMCallFailedEvent) @crewai_event_bus.on(LLMCallStartedEvent)
def handle_llm_failed(source, event): def handle_llm_started(source, event):
failed_event.append(event) started_event.append(event)
@crewai_event_bus.on(LLMCallStartedEvent) @crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_started(source, event): def handle_llm_completed(source, event):
started_event.append(event) completed_event.append(event)
if len(started_event) >= 1:
event_received.set()
@crewai_event_bus.on(LLMCallCompletedEvent) @crewai_event_bus.on(LLMStreamChunkEvent)
def handle_llm_completed(source, event): def handle_llm_stream_chunk(source, event):
completed_event.append(event) stream_event.append(event)
@crewai_event_bus.on(LLMStreamChunkEvent) crew = Crew(agents=[base_agent], tasks=[base_task])
def handle_llm_stream_chunk(source, event): crew.kickoff()
stream_event.append(event)
crew = Crew(agents=[base_agent], tasks=[base_task])
crew.kickoff()
assert event_received.wait(timeout=10), "Timeout waiting for LLM events"
assert len(completed_event) == 1 assert len(completed_event) == 1
assert len(failed_event) == 0 assert len(failed_event) == 0
assert len(started_event) == 1 assert len(started_event) == 1
@@ -950,32 +1049,41 @@ def test_llm_emits_event_with_lite_agent():
failed_event = [] failed_event = []
started_event = [] started_event = []
stream_event = [] stream_event = []
all_events_received = threading.Event()
with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@crewai_event_bus.on(LLMCallFailedEvent) @crewai_event_bus.on(LLMCallStartedEvent)
def handle_llm_failed(source, event): def handle_llm_started(source, event):
failed_event.append(event) started_event.append(event)
@crewai_event_bus.on(LLMCallStartedEvent) @crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_started(source, event): def handle_llm_completed(source, event):
started_event.append(event) completed_event.append(event)
if len(started_event) >= 1 and len(stream_event) >= 15:
all_events_received.set()
@crewai_event_bus.on(LLMCallCompletedEvent) @crewai_event_bus.on(LLMStreamChunkEvent)
def handle_llm_completed(source, event): def handle_llm_stream_chunk(source, event):
completed_event.append(event) stream_event.append(event)
if (
len(completed_event) >= 1
and len(started_event) >= 1
and len(stream_event) >= 15
):
all_events_received.set()
@crewai_event_bus.on(LLMStreamChunkEvent) agent = Agent(
def handle_llm_stream_chunk(source, event): role="Speaker",
stream_event.append(event) llm=LLM(model="gpt-4o-mini", stream=True),
goal="Just say hi",
backstory="You are a helpful assistant that just says hi",
)
agent.kickoff(messages=[{"role": "user", "content": "say hi!"}])
agent = Agent( assert all_events_received.wait(timeout=10), "Timeout waiting for all events"
role="Speaker",
llm=LLM(model="gpt-4o-mini", stream=True),
goal="Just say hi",
backstory="You are a helpful assistant that just says hi",
)
agent.kickoff(messages=[{"role": "user", "content": "say hi!"}])
assert len(completed_event) == 1 assert len(completed_event) == 1
assert len(failed_event) == 0 assert len(failed_event) == 0

39
lib/crewai/tests/utils.py Normal file
View File

@@ -0,0 +1,39 @@
"""Test utilities for CrewAI tests."""
import asyncio
from concurrent.futures import ThreadPoolExecutor
def wait_for_event_handlers(timeout: float = 5.0) -> None:
"""Wait for all pending event handlers to complete.
This helper ensures all sync and async handlers finish processing before
proceeding. Useful in tests to make assertions deterministic.
Args:
timeout: Maximum time to wait in seconds.
"""
from crewai.events.event_bus import crewai_event_bus
loop = getattr(crewai_event_bus, "_loop", None)
if loop and not loop.is_closed():
async def _wait_for_async_tasks() -> None:
tasks = {
t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task()
}
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
future = asyncio.run_coroutine_threadsafe(_wait_for_async_tasks(), loop)
try:
future.result(timeout=timeout)
except Exception: # noqa: S110
pass
crewai_event_bus._sync_executor.shutdown(wait=True)
crewai_event_bus._sync_executor = ThreadPoolExecutor(
max_workers=10,
thread_name_prefix="CrewAISyncHandler",
)

11
uv.lock generated
View File

@@ -465,15 +465,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/01/f3/a9d961cfba236dc85f27f2f2c6eab88e12698754aaa02459ba7dfafc5062/bedrock_agentcore-0.1.7-py3-none-any.whl", hash = "sha256:441dde64fea596e9571e47ae37ee3b033e58d8d255018f13bdcde8ae8bef2075", size = 77216, upload-time = "2025-10-01T16:18:38.153Z" }, { url = "https://files.pythonhosted.org/packages/01/f3/a9d961cfba236dc85f27f2f2c6eab88e12698754aaa02459ba7dfafc5062/bedrock_agentcore-0.1.7-py3-none-any.whl", hash = "sha256:441dde64fea596e9571e47ae37ee3b033e58d8d255018f13bdcde8ae8bef2075", size = 77216, upload-time = "2025-10-01T16:18:38.153Z" },
] ]
[[package]]
name = "blinker"
version = "1.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/21/28/9b3f50ce0e048515135495f198351908d99540d69bfdc8c1d15b73dc55ce/blinker-1.9.0.tar.gz", hash = "sha256:b4ce2265a7abece45e7cc896e98dbebe6cead56bcf805a3d23136d145f5445bf", size = 22460, upload-time = "2024-11-08T17:25:47.436Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" },
]
[[package]] [[package]]
name = "boto3" name = "boto3"
version = "1.40.45" version = "1.40.45"
@@ -987,7 +978,6 @@ name = "crewai"
source = { editable = "lib/crewai" } source = { editable = "lib/crewai" }
dependencies = [ dependencies = [
{ name = "appdirs" }, { name = "appdirs" },
{ name = "blinker" },
{ name = "chromadb" }, { name = "chromadb" },
{ name = "click" }, { name = "click" },
{ name = "instructor" }, { name = "instructor" },
@@ -1061,7 +1051,6 @@ watson = [
requires-dist = [ requires-dist = [
{ name = "aisuite", marker = "extra == 'aisuite'", specifier = ">=0.1.10" }, { name = "aisuite", marker = "extra == 'aisuite'", specifier = ">=0.1.10" },
{ name = "appdirs", specifier = ">=1.4.4" }, { name = "appdirs", specifier = ">=1.4.4" },
{ name = "blinker", specifier = ">=1.9.0" },
{ name = "boto3", marker = "extra == 'aws'", specifier = ">=1.40.38" }, { name = "boto3", marker = "extra == 'aws'", specifier = ">=1.40.38" },
{ name = "boto3", marker = "extra == 'boto3'", specifier = ">=1.40.45" }, { name = "boto3", marker = "extra == 'boto3'", specifier = ">=1.40.45" },
{ name = "chromadb", specifier = "~=1.1.0" }, { name = "chromadb", specifier = "~=1.1.0" },