feat: improve event bus thread safety and async support

Add thread-safe, async-compatible event bus with read–write locking and
handler dependency ordering. Remove blinker dependency and implement
direct dispatch. Improve type safety, error handling, and deterministic
event synchronization.

Refactor tests to auto-wait for async handlers, ensure clean teardown,
and add comprehensive concurrency coverage. Replace thread-local state
in AgentEvaluator with instance-based locking for correct cross-thread
access. Enhance tracing reliability and event finalization.
This commit is contained in:
Greyson LaLonde
2025-10-14 13:28:58 -04:00
committed by GitHub
parent cec4e4c2e9
commit 53b239c6df
34 changed files with 3360 additions and 876 deletions

View File

@@ -35,7 +35,6 @@ dependencies = [
"uv>=0.4.25",
"tomli-w>=1.1.0",
"tomli>=2.0.2",
"blinker>=1.9.0",
"json5>=0.10.0",
"portalocker==2.7.0",
"pydantic-settings>=2.10.1",

View File

@@ -5,10 +5,13 @@ This module provides the event infrastructure that allows users to:
- Track memory operations and performance
- Build custom logging and analytics
- Extend CrewAI with custom event handlers
- Declare handler dependencies for ordered execution
"""
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.depends import Depends
from crewai.events.event_bus import crewai_event_bus
from crewai.events.handler_graph import CircularDependencyError
from crewai.events.types.agent_events import (
AgentEvaluationCompletedEvent,
AgentEvaluationFailedEvent,
@@ -109,6 +112,7 @@ __all__ = [
"AgentReasoningFailedEvent",
"AgentReasoningStartedEvent",
"BaseEventListener",
"CircularDependencyError",
"CrewKickoffCompletedEvent",
"CrewKickoffFailedEvent",
"CrewKickoffStartedEvent",
@@ -119,6 +123,7 @@ __all__ = [
"CrewTrainCompletedEvent",
"CrewTrainFailedEvent",
"CrewTrainStartedEvent",
"Depends",
"FlowCreatedEvent",
"FlowEvent",
"FlowFinishedEvent",

View File

@@ -9,6 +9,7 @@ class BaseEventListener(ABC):
def __init__(self):
super().__init__()
self.setup_listeners(crewai_event_bus)
crewai_event_bus.validate_dependencies()
@abstractmethod
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus):

View File

@@ -0,0 +1,105 @@
"""Dependency injection system for event handlers.
This module provides a FastAPI-style dependency system that allows event handlers
to declare dependencies on other handlers, ensuring proper execution order while
maintaining parallelism where possible.
"""
from collections.abc import Coroutine
from typing import Any, Generic, Protocol, TypeVar
from crewai.events.base_events import BaseEvent
EventT_co = TypeVar("EventT_co", bound=BaseEvent, contravariant=True)
class EventHandler(Protocol[EventT_co]):
"""Protocol for event handler functions.
Generic protocol that accepts any subclass of BaseEvent.
Handlers can be either synchronous (returning None) or asynchronous
(returning a coroutine).
"""
def __call__(
self, source: Any, event: EventT_co, /
) -> None | Coroutine[Any, Any, None]:
"""Event handler signature.
Args:
source: The object that emitted the event
event: The event instance (any BaseEvent subclass)
Returns:
None for sync handlers, Coroutine for async handlers
"""
...
T = TypeVar("T", bound=EventHandler[Any])
class Depends(Generic[T]):
"""Declares a dependency on another event handler.
Similar to FastAPI's Depends, this allows handlers to specify that they
depend on other handlers completing first. Handlers with dependencies will
execute after their dependencies, while independent handlers can run in parallel.
Args:
handler: The handler function that this handler depends on
Example:
>>> from crewai.events import Depends, crewai_event_bus
>>> from crewai.events import LLMCallStartedEvent
>>> @crewai_event_bus.on(LLMCallStartedEvent)
>>> def setup_context(source, event):
... return {"initialized": True}
>>>
>>> @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(setup_context))
>>> def process(source, event):
... # Runs after setup_context completes
... pass
"""
def __init__(self, handler: T) -> None:
"""Initialize a dependency on a handler.
Args:
handler: The handler function this depends on
"""
self.handler = handler
def __repr__(self) -> str:
"""Return a string representation of the dependency.
Returns:
A string showing the dependent handler name
"""
handler_name = getattr(self.handler, "__name__", repr(self.handler))
return f"Depends({handler_name})"
def __eq__(self, other: object) -> bool:
"""Check equality based on the handler reference.
Args:
other: Another Depends instance to compare
Returns:
True if both depend on the same handler, False otherwise
"""
if not isinstance(other, Depends):
return False
return self.handler is other.handler
def __hash__(self) -> int:
"""Return hash based on handler identity.
Since equality is based on identity (is), we hash the handler
object directly rather than its id for consistency.
Returns:
Hash of the handler object
"""
return id(self.handler)

View File

@@ -1,125 +1,507 @@
from __future__ import annotations
"""Event bus for managing and dispatching events in CrewAI.
import threading
from collections.abc import Callable
This module provides a singleton event bus that allows registration and handling
of events throughout the CrewAI system, supporting both synchronous and asynchronous
event handlers with optional dependency management.
"""
import asyncio
import atexit
from collections.abc import Callable, Generator
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, TypeVar, cast
import threading
from typing import Any, Final, ParamSpec, TypeVar
from blinker import Signal
from typing_extensions import Self
from crewai.events.base_events import BaseEvent
from crewai.events.event_types import EventTypes
from crewai.events.depends import Depends
from crewai.events.handler_graph import build_execution_plan
from crewai.events.types.event_bus_types import (
AsyncHandler,
AsyncHandlerSet,
ExecutionPlan,
Handler,
SyncHandler,
SyncHandlerSet,
)
from crewai.events.types.llm_events import LLMStreamChunkEvent
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.utils.handlers import is_async_handler, is_call_handler_safe
from crewai.events.utils.rw_lock import RWLock
EventT = TypeVar("EventT", bound=BaseEvent)
P = ParamSpec("P")
R = TypeVar("R")
class CrewAIEventsBus:
"""
A singleton event bus that uses blinker signals for event handling.
Allows both internal (Flow/Crew) and external event handling.
"""Singleton event bus for handling events in CrewAI.
This class manages event registration and emission for both synchronous
and asynchronous event handlers, automatically scheduling async handlers
in a dedicated background event loop.
Synchronous handlers execute in a thread pool executor to ensure completion
before program exit. Asynchronous handlers execute in a dedicated event loop
running in a daemon thread, with graceful shutdown waiting for completion.
Attributes:
_instance: Singleton instance of the event bus
_instance_lock: Reentrant lock for singleton initialization (class-level)
_rwlock: Read-write lock for handler registration and access (instance-level)
_sync_handlers: Mapping of event types to registered synchronous handlers
_async_handlers: Mapping of event types to registered asynchronous handlers
_sync_executor: Thread pool executor for running synchronous handlers
_loop: Dedicated asyncio event loop for async handler execution
_loop_thread: Background daemon thread running the event loop
_console: Console formatter for error output
"""
_instance = None
_lock = threading.Lock()
_instance: Self | None = None
_instance_lock: threading.RLock = threading.RLock()
_rwlock: RWLock
_sync_handlers: dict[type[BaseEvent], SyncHandlerSet]
_async_handlers: dict[type[BaseEvent], AsyncHandlerSet]
_handler_dependencies: dict[type[BaseEvent], dict[Handler, list[Depends]]]
_execution_plan_cache: dict[type[BaseEvent], ExecutionPlan]
_console: ConsoleFormatter
_shutting_down: bool
def __new__(cls):
def __new__(cls) -> Self:
"""Create or return the singleton instance.
Returns:
The singleton CrewAIEventsBus instance
"""
if cls._instance is None:
with cls._lock:
if cls._instance is None: # prevent race condition
with cls._instance_lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialize()
return cls._instance
def _initialize(self) -> None:
"""Initialize the event bus internal state"""
self._signal = Signal("crewai_event_bus")
self._handlers: dict[type[BaseEvent], list[Callable]] = {}
"""Initialize the event bus internal state.
Creates handler dictionaries and starts a dedicated background
event loop for async handler execution.
"""
self._shutting_down = False
self._rwlock = RWLock()
self._sync_handlers: dict[type[BaseEvent], SyncHandlerSet] = {}
self._async_handlers: dict[type[BaseEvent], AsyncHandlerSet] = {}
self._handler_dependencies: dict[type[BaseEvent], dict[Handler, list[Depends]]] = {}
self._execution_plan_cache: dict[type[BaseEvent], ExecutionPlan] = {}
self._sync_executor = ThreadPoolExecutor(
max_workers=10,
thread_name_prefix="CrewAISyncHandler",
)
self._console = ConsoleFormatter()
self._loop = asyncio.new_event_loop()
self._loop_thread = threading.Thread(
target=self._run_loop,
name="CrewAIEventsLoop",
daemon=True,
)
self._loop_thread.start()
def _run_loop(self) -> None:
"""Run the background async event loop."""
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
def _register_handler(
self,
event_type: type[BaseEvent],
handler: Callable[..., Any],
dependencies: list[Depends] | None = None,
) -> None:
"""Register a handler for the given event type.
Args:
event_type: The event class to listen for
handler: The handler function to register
dependencies: Optional list of dependencies
"""
with self._rwlock.w_locked():
if is_async_handler(handler):
existing_async = self._async_handlers.get(event_type, frozenset())
self._async_handlers[event_type] = existing_async | {handler}
else:
existing_sync = self._sync_handlers.get(event_type, frozenset())
self._sync_handlers[event_type] = existing_sync | {handler}
if dependencies:
if event_type not in self._handler_dependencies:
self._handler_dependencies[event_type] = {}
self._handler_dependencies[event_type][handler] = dependencies
self._execution_plan_cache.pop(event_type, None)
def on(
self, event_type: type[EventT]
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
"""
Decorator to register an event handler for a specific event type.
self,
event_type: type[BaseEvent],
depends_on: Depends | list[Depends] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator to register an event handler for a specific event type.
Usage:
@crewai_event_bus.on(AgentExecutionCompletedEvent)
def on_agent_execution_completed(
source: Any, event: AgentExecutionCompletedEvent
):
print(f"👍 Agent '{event.agent}' completed task")
print(f" Output: {event.output}")
Args:
event_type: The event class to listen for
depends_on: Optional dependency or list of dependencies. Handlers with
dependencies will execute after their dependencies complete.
Returns:
Decorator function that registers the handler
Example:
>>> from crewai.events import crewai_event_bus, Depends
>>> from crewai.events.types.llm_events import LLMCallStartedEvent
>>>
>>> @crewai_event_bus.on(LLMCallStartedEvent)
>>> def setup_context(source, event):
... print("Setting up context")
>>>
>>> @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(setup_context))
>>> def process(source, event):
... print("Processing (runs after setup_context)")
"""
def decorator(
handler: Callable[[Any, EventT], None],
) -> Callable[[Any, EventT], None]:
if event_type not in self._handlers:
self._handlers[event_type] = []
self._handlers[event_type].append(
cast(Callable[[Any, EventT], None], handler)
)
def decorator(handler: Callable[P, R]) -> Callable[P, R]:
"""Register the handler and return it unchanged.
Args:
handler: Event handler function to register
Returns:
The same handler function unchanged
"""
deps = None
if depends_on is not None:
deps = [depends_on] if isinstance(depends_on, Depends) else depends_on
self._register_handler(event_type, handler, dependencies=deps)
return handler
return decorator
@staticmethod
def _call_handler(
handler: Callable, source: Any, event: BaseEvent, event_type: type
def _call_handlers(
self,
source: Any,
event: BaseEvent,
handlers: SyncHandlerSet,
) -> None:
"""Call a single handler with error handling."""
try:
handler(source, event)
except Exception as e:
print(
f"[EventBus Error] Handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}"
"""Call provided synchronous handlers.
Args:
source: The emitting object
event: The event instance
handlers: Frozenset of sync handlers to call
"""
errors: list[tuple[SyncHandler, Exception]] = [
(handler, error)
for handler in handlers
if (error := is_call_handler_safe(handler, source, event)) is not None
]
if errors:
for handler, error in errors:
self._console.print(
f"[CrewAIEventsBus] Sync handler error in {handler.__name__}: {error}"
)
async def _acall_handlers(
self,
source: Any,
event: BaseEvent,
handlers: AsyncHandlerSet,
) -> None:
"""Asynchronously call provided async handlers.
Args:
source: The object that emitted the event
event: The event instance
handlers: Frozenset of async handlers to call
"""
coros = [handler(source, event) for handler in handlers]
results = await asyncio.gather(*coros, return_exceptions=True)
for handler, result in zip(handlers, results, strict=False):
if isinstance(result, Exception):
self._console.print(
f"[CrewAIEventsBus] Async handler error in {getattr(handler, '__name__', handler)}: {result}"
)
async def _emit_with_dependencies(self, source: Any, event: BaseEvent) -> None:
"""Emit an event with dependency-aware handler execution.
Handlers are grouped into execution levels based on their dependencies.
Within each level, async handlers run concurrently while sync handlers
run sequentially (or in thread pool). Each level completes before the
next level starts.
Uses a cached execution plan for performance. The plan is built once
per event type and cached until handlers are modified.
Args:
source: The emitting object
event: The event instance to emit
"""
event_type = type(event)
with self._rwlock.r_locked():
if self._shutting_down:
return
cached_plan = self._execution_plan_cache.get(event_type)
if cached_plan is not None:
sync_handlers = self._sync_handlers.get(event_type, frozenset())
async_handlers = self._async_handlers.get(event_type, frozenset())
if cached_plan is None:
with self._rwlock.w_locked():
if self._shutting_down:
return
cached_plan = self._execution_plan_cache.get(event_type)
if cached_plan is None:
sync_handlers = self._sync_handlers.get(event_type, frozenset())
async_handlers = self._async_handlers.get(event_type, frozenset())
dependencies = dict(self._handler_dependencies.get(event_type, {}))
all_handlers = list(sync_handlers | async_handlers)
if not all_handlers:
return
cached_plan = build_execution_plan(all_handlers, dependencies)
self._execution_plan_cache[event_type] = cached_plan
else:
sync_handlers = self._sync_handlers.get(event_type, frozenset())
async_handlers = self._async_handlers.get(event_type, frozenset())
for level in cached_plan:
level_sync = frozenset(h for h in level if h in sync_handlers)
level_async = frozenset(h for h in level if h in async_handlers)
if level_sync:
if event_type is LLMStreamChunkEvent:
self._call_handlers(source, event, level_sync)
else:
future = self._sync_executor.submit(
self._call_handlers, source, event, level_sync
)
await asyncio.get_running_loop().run_in_executor(
None, future.result
)
if level_async:
await self._acall_handlers(source, event, level_async)
def emit(self, source: Any, event: BaseEvent) -> Future[None] | None:
"""Emit an event to all registered handlers.
If handlers have dependencies (registered with depends_on), they execute
in dependency order. Otherwise, handlers execute as before (sync in thread
pool, async fire-and-forget).
Stream chunk events always execute synchronously to preserve ordering.
Args:
source: The emitting object
event: The event instance to emit
Returns:
Future that completes when handlers finish. Returns:
- Future for sync-only handlers (ThreadPoolExecutor future)
- Future for async handlers or mixed handlers (asyncio future)
- Future for dependency-managed handlers (asyncio future)
- None if no handlers or sync stream chunk events
Example:
>>> future = crewai_event_bus.emit(source, event)
>>> if future:
... await asyncio.wrap_future(future) # In async test
... # or future.result(timeout=5.0) in sync code
"""
event_type = type(event)
with self._rwlock.r_locked():
if self._shutting_down:
self._console.print(
"[CrewAIEventsBus] Warning: Attempted to emit event during shutdown. Ignoring."
)
return None
has_dependencies = event_type in self._handler_dependencies
sync_handlers = self._sync_handlers.get(event_type, frozenset())
async_handlers = self._async_handlers.get(event_type, frozenset())
if has_dependencies:
return asyncio.run_coroutine_threadsafe(
self._emit_with_dependencies(source, event),
self._loop,
)
def emit(self, source: Any, event: BaseEvent) -> None:
"""
Emit an event to all registered handlers
if sync_handlers:
if event_type is LLMStreamChunkEvent:
self._call_handlers(source, event, sync_handlers)
else:
sync_future = self._sync_executor.submit(
self._call_handlers, source, event, sync_handlers
)
if not async_handlers:
return sync_future
if async_handlers:
return asyncio.run_coroutine_threadsafe(
self._acall_handlers(source, event, async_handlers),
self._loop,
)
return None
async def aemit(self, source: Any, event: BaseEvent) -> None:
"""Asynchronously emit an event to registered async handlers.
Only processes async handlers. Use in async contexts.
Args:
source: The object emitting the event
event: The event instance to emit
"""
for event_type, handlers in self._handlers.items():
if isinstance(event, event_type):
for handler in handlers:
self._call_handler(handler, source, event, event_type)
event_type = type(event)
self._signal.send(source, event=event)
with self._rwlock.r_locked():
if self._shutting_down:
self._console.print(
"[CrewAIEventsBus] Warning: Attempted to emit event during shutdown. Ignoring."
)
return
async_handlers = self._async_handlers.get(event_type, frozenset())
if async_handlers:
await self._acall_handlers(source, event, async_handlers)
def register_handler(
self, event_type: type[EventTypes], handler: Callable[[Any, EventTypes], None]
self,
event_type: type[BaseEvent],
handler: SyncHandler | AsyncHandler,
) -> None:
"""Register an event handler for a specific event type"""
if event_type not in self._handlers:
self._handlers[event_type] = []
self._handlers[event_type].append(
cast(Callable[[Any, EventTypes], None], handler)
)
"""Register an event handler for a specific event type.
Args:
event_type: The event class to listen for
handler: The handler function to register
"""
self._register_handler(event_type, handler)
def validate_dependencies(self) -> None:
"""Validate all registered handler dependencies.
Attempts to build execution plans for all event types with dependencies.
This detects circular dependencies and cross-event-type dependencies
before events are emitted.
Raises:
CircularDependencyError: If circular dependencies or unresolved
dependencies (e.g., cross-event-type) are detected
"""
with self._rwlock.r_locked():
for event_type in self._handler_dependencies:
sync_handlers = self._sync_handlers.get(event_type, frozenset())
async_handlers = self._async_handlers.get(event_type, frozenset())
dependencies = dict(self._handler_dependencies.get(event_type, {}))
all_handlers = list(sync_handlers | async_handlers)
if all_handlers and dependencies:
build_execution_plan(all_handlers, dependencies)
@contextmanager
def scoped_handlers(self):
"""
Context manager for temporary event handling scope.
Useful for testing or temporary event handling.
def scoped_handlers(self) -> Generator[None, Any, None]:
"""Context manager for temporary event handling scope.
Usage:
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(CrewKickoffStarted)
def temp_handler(source, event):
print("Temporary handler")
# Do stuff...
# Handlers are cleared after the context
Useful for testing or temporary event handling. All handlers registered
within this context are cleared when the context exits.
Example:
>>> from crewai.events.event_bus import crewai_event_bus
>>> from crewai.events.event_types import CrewKickoffStartedEvent
>>> with crewai_event_bus.scoped_handlers():
...
... @crewai_event_bus.on(CrewKickoffStartedEvent)
... def temp_handler(source, event):
... print("Temporary handler")
...
... # Do stuff...
... # Handlers are cleared after the context
"""
previous_handlers = self._handlers.copy()
self._handlers.clear()
with self._rwlock.w_locked():
prev_sync = self._sync_handlers
prev_async = self._async_handlers
prev_deps = self._handler_dependencies
prev_cache = self._execution_plan_cache
self._sync_handlers = {}
self._async_handlers = {}
self._handler_dependencies = {}
self._execution_plan_cache = {}
try:
yield
finally:
self._handlers = previous_handlers
with self._rwlock.w_locked():
self._sync_handlers = prev_sync
self._async_handlers = prev_async
self._handler_dependencies = prev_deps
self._execution_plan_cache = prev_cache
def shutdown(self, wait: bool = True) -> None:
"""Gracefully shutdown the event loop and wait for all tasks to finish.
Args:
wait: If True, wait for all pending tasks to complete before stopping.
If False, cancel all pending tasks immediately.
"""
with self._rwlock.w_locked():
self._shutting_down = True
loop = getattr(self, "_loop", None)
if loop is None or loop.is_closed():
return
if wait:
async def _wait_for_all_tasks() -> None:
tasks = {
t
for t in asyncio.all_tasks(loop)
if t is not asyncio.current_task()
}
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
future = asyncio.run_coroutine_threadsafe(_wait_for_all_tasks(), loop)
try:
future.result()
except Exception as e:
self._console.print(f"[CrewAIEventsBus] Error waiting for tasks: {e}")
else:
def _cancel_tasks() -> None:
for task in asyncio.all_tasks(loop):
if task is not asyncio.current_task():
task.cancel()
loop.call_soon_threadsafe(_cancel_tasks)
loop.call_soon_threadsafe(loop.stop)
self._loop_thread.join()
loop.close()
self._sync_executor.shutdown(wait=wait)
with self._rwlock.w_locked():
self._sync_handlers.clear()
self._async_handlers.clear()
self._execution_plan_cache.clear()
# Global instance
crewai_event_bus = CrewAIEventsBus()
crewai_event_bus: Final[CrewAIEventsBus] = CrewAIEventsBus()
atexit.register(crewai_event_bus.shutdown)

View File

@@ -0,0 +1,130 @@
"""Dependency graph resolution for event handlers.
This module resolves handler dependencies into execution levels, ensuring
handlers execute in correct order while maximizing parallelism.
"""
from collections import defaultdict, deque
from collections.abc import Sequence
from crewai.events.depends import Depends
from crewai.events.types.event_bus_types import ExecutionPlan, Handler
class CircularDependencyError(Exception):
"""Exception raised when circular dependencies are detected in event handlers.
Attributes:
handlers: The handlers involved in the circular dependency
"""
def __init__(self, handlers: list[Handler]) -> None:
"""Initialize the circular dependency error.
Args:
handlers: The handlers involved in the circular dependency
"""
handler_names = ", ".join(
getattr(h, "__name__", repr(h)) for h in handlers[:5]
)
message = f"Circular dependency detected in event handlers: {handler_names}"
super().__init__(message)
self.handlers = handlers
class HandlerGraph:
"""Resolves handler dependencies into parallel execution levels.
Handlers are organized into levels where:
- Level 0: Handlers with no dependencies (can run first)
- Level N: Handlers that depend on handlers in levels 0...N-1
Handlers within the same level can execute in parallel.
Attributes:
levels: List of handler sets, where each level can execute in parallel
"""
def __init__(
self,
handlers: dict[Handler, list[Depends]],
) -> None:
"""Initialize the dependency graph.
Args:
handlers: Mapping of handler -> list of `crewai.events.depends.Depends` objects
"""
self.handlers = handlers
self.levels: ExecutionPlan = []
self._resolve()
def _resolve(self) -> None:
"""Resolve dependencies into execution levels using topological sort."""
dependents: dict[Handler, set[Handler]] = defaultdict(set)
in_degree: dict[Handler, int] = {}
for handler in self.handlers:
in_degree[handler] = 0
for handler, deps in self.handlers.items():
in_degree[handler] = len(deps)
for dep in deps:
dependents[dep.handler].add(handler)
queue: deque[Handler] = deque(
[h for h, deg in in_degree.items() if deg == 0]
)
while queue:
current_level: set[Handler] = set()
for _ in range(len(queue)):
handler = queue.popleft()
current_level.add(handler)
for dependent in dependents[handler]:
in_degree[dependent] -= 1
if in_degree[dependent] == 0:
queue.append(dependent)
if current_level:
self.levels.append(current_level)
remaining = [h for h, deg in in_degree.items() if deg > 0]
if remaining:
raise CircularDependencyError(remaining)
def get_execution_plan(self) -> ExecutionPlan:
"""Get the ordered execution plan.
Returns:
List of handler sets, where each set represents handlers that can
execute in parallel. Sets are ordered such that dependencies are
satisfied.
"""
return self.levels
def build_execution_plan(
handlers: Sequence[Handler],
dependencies: dict[Handler, list[Depends]],
) -> ExecutionPlan:
"""Build an execution plan from handlers and their dependencies.
Args:
handlers: All handlers for an event type
dependencies: Mapping of handler -> list of dependencies
Returns:
Execution plan as list of levels, where each level is a set of
handlers that can execute in parallel
Raises:
CircularDependencyError: If circular dependencies are detected
"""
handler_dict: dict[Handler, list[Depends]] = {
h: dependencies.get(h, []) for h in handlers
}
graph = HandlerGraph(handler_dict)
return graph.get_execution_plan()

View File

@@ -1,8 +1,9 @@
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from logging import getLogger
from threading import Condition, Lock
from typing import Any
import uuid
from rich.console import Console
from rich.panel import Panel
@@ -14,6 +15,7 @@ from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.listeners.tracing.utils import should_auto_collect_first_time_traces
from crewai.utilities.constants import CREWAI_BASE_URL
logger = getLogger(__name__)
@@ -41,6 +43,11 @@ class TraceBatchManager:
"""Single responsibility: Manage batches and event buffering"""
def __init__(self):
self._init_lock = Lock()
self._pending_events_lock = Lock()
self._pending_events_cv = Condition(self._pending_events_lock)
self._pending_events_count = 0
self.is_current_batch_ephemeral: bool = False
self.trace_batch_id: str | None = None
self.current_batch: TraceBatch | None = None
@@ -64,24 +71,28 @@ class TraceBatchManager:
execution_metadata: dict[str, Any],
use_ephemeral: bool = False,
) -> TraceBatch:
"""Initialize a new trace batch"""
self.current_batch = TraceBatch(
user_context=user_context, execution_metadata=execution_metadata
)
self.event_buffer.clear()
self.is_current_batch_ephemeral = use_ephemeral
"""Initialize a new trace batch (thread-safe)"""
with self._init_lock:
if self.current_batch is not None:
logger.debug("Batch already initialized, skipping duplicate initialization")
return self.current_batch
self.record_start_time("execution")
if should_auto_collect_first_time_traces():
self.trace_batch_id = self.current_batch.batch_id
else:
self._initialize_backend_batch(
user_context, execution_metadata, use_ephemeral
self.current_batch = TraceBatch(
user_context=user_context, execution_metadata=execution_metadata
)
self.backend_initialized = True
self.is_current_batch_ephemeral = use_ephemeral
return self.current_batch
self.record_start_time("execution")
if should_auto_collect_first_time_traces():
self.trace_batch_id = self.current_batch.batch_id
else:
self._initialize_backend_batch(
user_context, execution_metadata, use_ephemeral
)
self.backend_initialized = True
return self.current_batch
def _initialize_backend_batch(
self,
@@ -148,6 +159,38 @@ class TraceBatchManager:
f"Error initializing trace batch: {e}. Continuing without tracing."
)
def begin_event_processing(self):
"""Mark that an event handler started processing (for synchronization)"""
with self._pending_events_lock:
self._pending_events_count += 1
def end_event_processing(self):
"""Mark that an event handler finished processing (for synchronization)"""
with self._pending_events_cv:
self._pending_events_count -= 1
if self._pending_events_count == 0:
self._pending_events_cv.notify_all()
def wait_for_pending_events(self, timeout: float = 2.0) -> bool:
"""Wait for all pending event handlers to finish processing
Args:
timeout: Maximum time to wait in seconds (default: 2.0)
Returns:
True if all handlers completed, False if timeout occurred
"""
with self._pending_events_cv:
if self._pending_events_count > 0:
logger.debug(f"Waiting for {self._pending_events_count} pending event handlers...")
self._pending_events_cv.wait(timeout)
if self._pending_events_count > 0:
logger.error(
f"Timeout waiting for event handlers. {self._pending_events_count} still pending. Events may be incomplete!"
)
return False
return True
def add_event(self, trace_event: TraceEvent):
"""Add event to buffer"""
self.event_buffer.append(trace_event)
@@ -180,8 +223,8 @@ class TraceBatchManager:
self.event_buffer.clear()
return 200
logger.warning(
f"Failed to send events: {response.status_code}. Events will be lost."
logger.error(
f"Failed to send events: {response.status_code}. Response: {response.text}. Events will be lost."
)
return 500
@@ -196,15 +239,33 @@ class TraceBatchManager:
if not self.current_batch:
return None
self.current_batch.events = self.event_buffer.copy()
if self.event_buffer:
all_handlers_completed = self.wait_for_pending_events(timeout=2.0)
if not all_handlers_completed:
logger.error("Event handler timeout - marking batch as failed due to incomplete events")
self.plus_api.mark_trace_batch_as_failed(
self.trace_batch_id, "Timeout waiting for event handlers - events incomplete"
)
return None
sorted_events = sorted(
self.event_buffer,
key=lambda e: e.timestamp if hasattr(e, 'timestamp') and e.timestamp else ''
)
self.current_batch.events = sorted_events
events_sent_count = len(sorted_events)
if sorted_events:
original_buffer = self.event_buffer
self.event_buffer = sorted_events
events_sent_to_backend_status = self._send_events_to_backend()
self.event_buffer = original_buffer
if events_sent_to_backend_status == 500:
self.plus_api.mark_trace_batch_as_failed(
self.trace_batch_id, "Error sending events to backend"
)
return None
self._finalize_backend_batch()
self._finalize_backend_batch(events_sent_count)
finalized_batch = self.current_batch
@@ -220,18 +281,20 @@ class TraceBatchManager:
return finalized_batch
def _finalize_backend_batch(self):
"""Send batch finalization to backend"""
def _finalize_backend_batch(self, events_count: int = 0):
"""Send batch finalization to backend
Args:
events_count: Number of events that were successfully sent
"""
if not self.plus_api or not self.trace_batch_id:
return
try:
total_events = len(self.current_batch.events) if self.current_batch else 0
payload = {
"status": "completed",
"duration_ms": self.calculate_duration("execution"),
"final_event_count": total_events,
"final_event_count": events_count,
}
response = (

View File

@@ -170,14 +170,6 @@ class TraceCollectionListener(BaseEventListener):
def on_flow_finished(source, event):
self._handle_trace_event("flow_finished", source, event)
if self.batch_manager.batch_owner_type == "flow":
if self.first_time_handler.is_first_time:
self.first_time_handler.mark_events_collected()
self.first_time_handler.handle_execution_completion()
else:
# Normal flow finalization
self.batch_manager.finalize_batch()
@event_bus.on(FlowPlotEvent)
def on_flow_plot(source, event):
self._handle_action_event("flow_plot", source, event)
@@ -383,10 +375,12 @@ class TraceCollectionListener(BaseEventListener):
def _handle_trace_event(self, event_type: str, source: Any, event: Any):
"""Generic handler for context end events"""
trace_event = self._create_trace_event(event_type, source, event)
self.batch_manager.add_event(trace_event)
self.batch_manager.begin_event_processing()
try:
trace_event = self._create_trace_event(event_type, source, event)
self.batch_manager.add_event(trace_event)
finally:
self.batch_manager.end_event_processing()
def _handle_action_event(self, event_type: str, source: Any, event: Any):
"""Generic handler for action events (LLM calls, tool usage)"""
@@ -399,18 +393,29 @@ class TraceCollectionListener(BaseEventListener):
}
self.batch_manager.initialize_batch(user_context, execution_metadata)
trace_event = self._create_trace_event(event_type, source, event)
self.batch_manager.add_event(trace_event)
self.batch_manager.begin_event_processing()
try:
trace_event = self._create_trace_event(event_type, source, event)
self.batch_manager.add_event(trace_event)
finally:
self.batch_manager.end_event_processing()
def _create_trace_event(
self, event_type: str, source: Any, event: Any
) -> TraceEvent:
"""Create a trace event"""
trace_event = TraceEvent(
type=event_type,
)
if hasattr(event, 'timestamp') and event.timestamp:
trace_event = TraceEvent(
type=event_type,
timestamp=event.timestamp.isoformat(),
)
else:
trace_event = TraceEvent(
type=event_type,
)
trace_event.event_data = self._build_event_data(event_type, event, source)
return trace_event
def _build_event_data(

View File

@@ -0,0 +1,14 @@
"""Type definitions for event handlers."""
from collections.abc import Callable, Coroutine
from typing import Any, TypeAlias
from crewai.events.base_events import BaseEvent
SyncHandler: TypeAlias = Callable[[Any, BaseEvent], None]
AsyncHandler: TypeAlias = Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
SyncHandlerSet: TypeAlias = frozenset[SyncHandler]
AsyncHandlerSet: TypeAlias = frozenset[AsyncHandler]
Handler: TypeAlias = Callable[[Any, BaseEvent], Any]
ExecutionPlan: TypeAlias = list[set[Handler]]

View File

@@ -0,0 +1,59 @@
"""Handler utility functions for event processing."""
import functools
import inspect
from typing import Any
from typing_extensions import TypeIs
from crewai.events.base_events import BaseEvent
from crewai.events.types.event_bus_types import AsyncHandler, SyncHandler
def is_async_handler(
handler: Any,
) -> TypeIs[AsyncHandler]:
"""Type guard to check if handler is an async handler.
Args:
handler: The handler to check
Returns:
True if handler is an async coroutine function
"""
try:
if inspect.iscoroutinefunction(handler) or (
callable(handler) and inspect.iscoroutinefunction(handler.__call__)
):
return True
except AttributeError:
return False
if isinstance(handler, functools.partial) and inspect.iscoroutinefunction(
handler.func
):
return True
return False
def is_call_handler_safe(
handler: SyncHandler,
source: Any,
event: BaseEvent,
) -> Exception | None:
"""Safely call a single handler and return any exception.
Args:
handler: The handler function to call
source: The object that emitted the event
event: The event instance
Returns:
Exception if handler raised one, None otherwise
"""
try:
handler(source, event)
return None
except Exception as e:
return e

View File

@@ -0,0 +1,81 @@
"""Read-write lock for thread-safe concurrent access.
This module provides a reader-writer lock implementation that allows multiple
concurrent readers or a single exclusive writer.
"""
from collections.abc import Generator
from contextlib import contextmanager
from threading import Condition
class RWLock:
"""Read-write lock for managing concurrent read and exclusive write access.
Allows multiple threads to acquire read locks simultaneously, but ensures
exclusive access for write operations. Writers are prioritized when waiting.
Attributes:
_cond: Condition variable for coordinating lock access
_readers: Count of active readers
_writer: Whether a writer currently holds the lock
"""
def __init__(self) -> None:
"""Initialize the read-write lock."""
self._cond = Condition()
self._readers = 0
self._writer = False
def r_acquire(self) -> None:
"""Acquire a read lock, blocking if a writer holds the lock."""
with self._cond:
while self._writer:
self._cond.wait()
self._readers += 1
def r_release(self) -> None:
"""Release a read lock and notify waiting writers if last reader."""
with self._cond:
self._readers -= 1
if self._readers == 0:
self._cond.notify_all()
@contextmanager
def r_locked(self) -> Generator[None, None, None]:
"""Context manager for acquiring a read lock.
Yields:
None
"""
try:
self.r_acquire()
yield
finally:
self.r_release()
def w_acquire(self) -> None:
"""Acquire a write lock, blocking if any readers or writers are active."""
with self._cond:
while self._writer or self._readers > 0:
self._cond.wait()
self._writer = True
def w_release(self) -> None:
"""Release a write lock and notify all waiting threads."""
with self._cond:
self._writer = False
self._cond.notify_all()
@contextmanager
def w_locked(self) -> Generator[None, None, None]:
"""Context manager for acquiring a write lock.
Yields:
None
"""
try:
self.w_acquire()
yield
finally:
self.w_release()

View File

@@ -52,19 +52,14 @@ class AgentEvaluator:
self.console_formatter = ConsoleFormatter()
self.display_formatter = EvaluationDisplayFormatter()
self._thread_local: threading.local = threading.local()
self._execution_state = ExecutionState()
self._state_lock = threading.Lock()
for agent in self.agents:
self._execution_state.agent_evaluators[str(agent.id)] = self.evaluators
self._subscribe_to_events()
@property
def _execution_state(self) -> ExecutionState:
if not hasattr(self._thread_local, "execution_state"):
self._thread_local.execution_state = ExecutionState()
return self._thread_local.execution_state
def _subscribe_to_events(self) -> None:
from typing import cast
@@ -112,21 +107,22 @@ class AgentEvaluator:
state=state,
)
current_iteration = self._execution_state.iteration
if current_iteration not in self._execution_state.iterations_results:
self._execution_state.iterations_results[current_iteration] = {}
with self._state_lock:
current_iteration = self._execution_state.iteration
if current_iteration not in self._execution_state.iterations_results:
self._execution_state.iterations_results[current_iteration] = {}
if (
agent.role
not in self._execution_state.iterations_results[current_iteration]
):
self._execution_state.iterations_results[current_iteration][
agent.role
] = []
if (
agent.role
not in self._execution_state.iterations_results[current_iteration]
):
self._execution_state.iterations_results[current_iteration][
agent.role
] = []
self._execution_state.iterations_results[current_iteration][
agent.role
].append(result)
].append(result)
def _handle_lite_agent_completed(
self, source: object, event: LiteAgentExecutionCompletedEvent
@@ -164,22 +160,23 @@ class AgentEvaluator:
state=state,
)
current_iteration = self._execution_state.iteration
if current_iteration not in self._execution_state.iterations_results:
self._execution_state.iterations_results[current_iteration] = {}
with self._state_lock:
current_iteration = self._execution_state.iteration
if current_iteration not in self._execution_state.iterations_results:
self._execution_state.iterations_results[current_iteration] = {}
agent_role = target_agent.role
if (
agent_role
not in self._execution_state.iterations_results[current_iteration]
):
self._execution_state.iterations_results[current_iteration][
agent_role
] = []
agent_role = target_agent.role
if (
agent_role
not in self._execution_state.iterations_results[current_iteration]
):
self._execution_state.iterations_results[current_iteration][
agent_role
] = []
self._execution_state.iterations_results[current_iteration][
agent_role
].append(result)
].append(result)
def set_iteration(self, iteration: int) -> None:
self._execution_state.iteration = iteration

View File

@@ -3,6 +3,7 @@ import copy
import inspect
import logging
from collections.abc import Callable
from concurrent.futures import Future
from typing import Any, ClassVar, Generic, TypeVar, cast
from uuid import uuid4
@@ -463,6 +464,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._completed_methods: set[str] = set() # Track completed methods for reload
self._persistence: FlowPersistence | None = persistence
self._is_execution_resuming: bool = False
self._event_futures: list[Future[None]] = []
# Initialize state with initial values
self._state = self._create_initial_state()
@@ -855,7 +857,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._initialize_state(filtered_inputs)
# Emit FlowStartedEvent and log the start of the flow.
crewai_event_bus.emit(
future = crewai_event_bus.emit(
self,
FlowStartedEvent(
type="flow_started",
@@ -863,6 +865,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
inputs=inputs,
),
)
if future:
self._event_futures.append(future)
self._log_flow_event(
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
)
@@ -881,7 +885,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
final_output = self._method_outputs[-1] if self._method_outputs else None
crewai_event_bus.emit(
future = crewai_event_bus.emit(
self,
FlowFinishedEvent(
type="flow_finished",
@@ -889,6 +893,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
result=final_output,
),
)
if future:
self._event_futures.append(future)
if self._event_futures:
await asyncio.gather(*[asyncio.wrap_future(f) for f in self._event_futures])
self._event_futures.clear()
if (
is_tracing_enabled()
or self.tracing
or should_auto_collect_first_time_traces()
):
trace_listener = TraceCollectionListener()
if trace_listener.batch_manager.batch_owner_type == "flow":
if trace_listener.first_time_handler.is_first_time:
trace_listener.first_time_handler.mark_events_collected()
trace_listener.first_time_handler.handle_execution_completion()
else:
trace_listener.batch_manager.finalize_batch()
return final_output
finally:
@@ -971,7 +994,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
kwargs or {}
)
crewai_event_bus.emit(
future = crewai_event_bus.emit(
self,
MethodExecutionStartedEvent(
type="method_execution_started",
@@ -981,6 +1004,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
state=self._copy_state(),
),
)
if future:
self._event_futures.append(future)
result = (
await method(*args, **kwargs)
@@ -994,7 +1019,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
self._completed_methods.add(method_name)
crewai_event_bus.emit(
future = crewai_event_bus.emit(
self,
MethodExecutionFinishedEvent(
type="method_execution_finished",
@@ -1004,10 +1029,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
result=result,
),
)
if future:
self._event_futures.append(future)
return result
except Exception as e:
crewai_event_bus.emit(
future = crewai_event_bus.emit(
self,
MethodExecutionFailedEvent(
type="method_execution_failed",
@@ -1016,6 +1043,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
error=e,
),
)
if future:
self._event_futures.append(future)
raise e
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:

View File

@@ -1,6 +1,7 @@
"""Test Agent creation and execution basic functionality."""
import os
import threading
from unittest import mock
from unittest.mock import MagicMock, patch
@@ -185,14 +186,17 @@ def test_agent_execution_with_tools():
expected_output="The result of the multiplication.",
)
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageFinishedEvent)
def handle_tool_end(source, event):
received_events.append(event)
event_received.set()
output = agent.execute_task(task)
assert output == "The result of the multiplication is 12."
assert event_received.wait(timeout=5), "Timeout waiting for tool usage event"
assert len(received_events) == 1
assert isinstance(received_events[0], ToolUsageFinishedEvent)
assert received_events[0].tool_name == "multiplier"
@@ -284,10 +288,12 @@ def test_cache_hitting():
'multiplier-{"first_number": 12, "second_number": 3}': 36,
}
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageFinishedEvent)
def handle_tool_end(source, event):
received_events.append(event)
event_received.set()
with (
patch.object(CacheHandler, "read") as read,
@@ -303,6 +309,7 @@ def test_cache_hitting():
read.assert_called_with(
tool="multiplier", input='{"first_number": 2, "second_number": 6}'
)
assert event_received.wait(timeout=5), "Timeout waiting for tool usage event"
assert len(received_events) == 1
assert isinstance(received_events[0], ToolUsageFinishedEvent)
assert received_events[0].from_cache

View File

@@ -1,4 +1,5 @@
# mypy: ignore-errors
import threading
from collections import defaultdict
from typing import cast
from unittest.mock import Mock, patch
@@ -156,14 +157,17 @@ def test_lite_agent_with_tools():
)
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageStartedEvent)
def event_handler(source, event):
received_events.append(event)
event_received.set()
agent.kickoff("What are the effects of climate change on coral reefs?")
# Verify tool usage events were emitted
assert event_received.wait(timeout=5), "Timeout waiting for tool usage events"
assert len(received_events) > 0, "Tool usage events should be emitted"
event = received_events[0]
assert isinstance(event, ToolUsageStartedEvent)
@@ -316,15 +320,18 @@ def test_sets_parent_flow_when_inside_flow():
return agent.kickoff("Test query")
flow = MyFlow()
with crewai_event_bus.scoped_handlers():
event_received = threading.Event()
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
def capture_agent(source, event):
nonlocal captured_agent
captured_agent = source
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
def capture_agent(source, event):
nonlocal captured_agent
captured_agent = source
event_received.set()
flow.kickoff()
assert captured_agent.parent_flow is flow
flow.kickoff()
assert event_received.wait(timeout=5), "Timeout waiting for agent execution event"
assert captured_agent.parent_flow is flow
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -342,30 +349,43 @@ def test_guardrail_is_called_using_string():
guardrail="""Only include Brazilian players, both women and men""",
)
with crewai_event_bus.scoped_handlers():
all_events_received = threading.Event()
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event):
assert isinstance(source, LiteAgent)
assert source.original_agent == agent
guardrail_events["started"].append(event)
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event):
assert isinstance(source, LiteAgent)
assert source.original_agent == agent
guardrail_events["started"].append(event)
if (
len(guardrail_events["started"]) == 2
and len(guardrail_events["completed"]) == 2
):
all_events_received.set()
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def capture_guardrail_completed(source, event):
assert isinstance(source, LiteAgent)
assert source.original_agent == agent
guardrail_events["completed"].append(event)
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def capture_guardrail_completed(source, event):
assert isinstance(source, LiteAgent)
assert source.original_agent == agent
guardrail_events["completed"].append(event)
if (
len(guardrail_events["started"]) == 2
and len(guardrail_events["completed"]) == 2
):
all_events_received.set()
result = agent.kickoff(messages="Top 10 best players in the world?")
result = agent.kickoff(messages="Top 10 best players in the world?")
assert len(guardrail_events["started"]) == 2
assert len(guardrail_events["completed"]) == 2
assert not guardrail_events["completed"][0].success
assert guardrail_events["completed"][1].success
assert (
"Here are the top 10 best soccer players in the world, focusing exclusively on Brazilian players"
in result.raw
)
assert all_events_received.wait(timeout=10), (
"Timeout waiting for all guardrail events"
)
assert len(guardrail_events["started"]) == 2
assert len(guardrail_events["completed"]) == 2
assert not guardrail_events["completed"][0].success
assert guardrail_events["completed"][1].success
assert (
"Here are the top 10 best soccer players in the world, focusing exclusively on Brazilian players"
in result.raw
)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -376,29 +396,42 @@ def test_guardrail_is_called_using_callable():
LLMGuardrailStartedEvent,
)
with crewai_event_bus.scoped_handlers():
all_events_received = threading.Event()
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event):
guardrail_events["started"].append(event)
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event):
guardrail_events["started"].append(event)
if (
len(guardrail_events["started"]) == 1
and len(guardrail_events["completed"]) == 1
):
all_events_received.set()
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def capture_guardrail_completed(source, event):
guardrail_events["completed"].append(event)
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def capture_guardrail_completed(source, event):
guardrail_events["completed"].append(event)
if (
len(guardrail_events["started"]) == 1
and len(guardrail_events["completed"]) == 1
):
all_events_received.set()
agent = Agent(
role="Sports Analyst",
goal="Gather information about the best soccer players",
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
guardrail=lambda output: (True, "Pelé - Santos, 1958"),
)
agent = Agent(
role="Sports Analyst",
goal="Gather information about the best soccer players",
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
guardrail=lambda output: (True, "Pelé - Santos, 1958"),
)
result = agent.kickoff(messages="Top 1 best players in the world?")
result = agent.kickoff(messages="Top 1 best players in the world?")
assert len(guardrail_events["started"]) == 1
assert len(guardrail_events["completed"]) == 1
assert guardrail_events["completed"][0].success
assert "Pelé - Santos, 1958" in result.raw
assert all_events_received.wait(timeout=10), (
"Timeout waiting for all guardrail events"
)
assert len(guardrail_events["started"]) == 1
assert len(guardrail_events["completed"]) == 1
assert guardrail_events["completed"][0].success
assert "Pelé - Santos, 1958" in result.raw
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -409,37 +442,50 @@ def test_guardrail_reached_attempt_limit():
LLMGuardrailStartedEvent,
)
with crewai_event_bus.scoped_handlers():
all_events_received = threading.Event()
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event):
guardrail_events["started"].append(event)
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def capture_guardrail_completed(source, event):
guardrail_events["completed"].append(event)
agent = Agent(
role="Sports Analyst",
goal="Gather information about the best soccer players",
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
guardrail=lambda output: (
False,
"You are not allowed to include Brazilian players",
),
guardrail_max_retries=2,
)
with pytest.raises(
Exception, match="Agent's guardrail failed validation after 2 retries"
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event):
guardrail_events["started"].append(event)
if (
len(guardrail_events["started"]) == 3
and len(guardrail_events["completed"]) == 3
):
agent.kickoff(messages="Top 10 best players in the world?")
all_events_received.set()
assert len(guardrail_events["started"]) == 3 # 2 retries + 1 initial call
assert len(guardrail_events["completed"]) == 3 # 2 retries + 1 initial call
assert not guardrail_events["completed"][0].success
assert not guardrail_events["completed"][1].success
assert not guardrail_events["completed"][2].success
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def capture_guardrail_completed(source, event):
guardrail_events["completed"].append(event)
if (
len(guardrail_events["started"]) == 3
and len(guardrail_events["completed"]) == 3
):
all_events_received.set()
agent = Agent(
role="Sports Analyst",
goal="Gather information about the best soccer players",
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
guardrail=lambda output: (
False,
"You are not allowed to include Brazilian players",
),
guardrail_max_retries=2,
)
with pytest.raises(
Exception, match="Agent's guardrail failed validation after 2 retries"
):
agent.kickoff(messages="Top 10 best players in the world?")
assert all_events_received.wait(timeout=10), (
"Timeout waiting for all guardrail events"
)
assert len(guardrail_events["started"]) == 3 # 2 retries + 1 initial call
assert len(guardrail_events["completed"]) == 3 # 2 retries + 1 initial call
assert not guardrail_events["completed"][0].success
assert not guardrail_events["completed"][1].success
assert not guardrail_events["completed"][2].success
@pytest.mark.vcr(filter_headers=["authorization"])

View File

@@ -33,7 +33,7 @@ def setup_test_environment():
except (OSError, IOError) as e:
raise RuntimeError(
f"Test storage directory {storage_dir} is not writable: {e}"
)
) from e
os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
os.environ["CREWAI_TESTING"] = "true"
@@ -159,6 +159,29 @@ def mock_opentelemetry_components():
}
@pytest.fixture(autouse=True)
def clear_event_bus_handlers():
"""Clear event bus handlers after each test for isolation.
Handlers registered during the test are allowed to run, then cleaned up
after the test completes.
"""
from crewai.events.event_bus import crewai_event_bus
from crewai.experimental.evaluation.evaluation_listener import (
EvaluationTraceCallback,
)
yield
crewai_event_bus.shutdown(wait=True)
crewai_event_bus._initialize()
callback = EvaluationTraceCallback()
callback.traces.clear()
callback.current_agent_id = None
callback.current_task_id = None
@pytest.fixture(scope="module")
def vcr_config(request) -> dict:
import os

View File

@@ -0,0 +1,286 @@
"""Tests for FastAPI-style dependency injection in event handlers."""
import asyncio
import pytest
from crewai.events import Depends, crewai_event_bus
from crewai.events.base_events import BaseEvent
class DependsTestEvent(BaseEvent):
"""Test event for dependency tests."""
value: int = 0
type: str = "test_event"
@pytest.mark.asyncio
async def test_basic_dependency():
"""Test that handler with dependency runs after its dependency."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def setup(source, event: DependsTestEvent):
execution_order.append("setup")
@crewai_event_bus.on(DependsTestEvent, Depends(setup))
def process(source, event: DependsTestEvent):
execution_order.append("process")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
assert execution_order == ["setup", "process"]
@pytest.mark.asyncio
async def test_multiple_dependencies():
"""Test handler with multiple dependencies."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def setup_a(source, event: DependsTestEvent):
execution_order.append("setup_a")
@crewai_event_bus.on(DependsTestEvent)
def setup_b(source, event: DependsTestEvent):
execution_order.append("setup_b")
@crewai_event_bus.on(
DependsTestEvent, depends_on=[Depends(setup_a), Depends(setup_b)]
)
def process(source, event: DependsTestEvent):
execution_order.append("process")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
# setup_a and setup_b can run in any order (same level)
assert "process" in execution_order
assert execution_order.index("process") > execution_order.index("setup_a")
assert execution_order.index("process") > execution_order.index("setup_b")
@pytest.mark.asyncio
async def test_chain_of_dependencies():
"""Test chain of dependencies (A -> B -> C)."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def handler_a(source, event: DependsTestEvent):
execution_order.append("handler_a")
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(handler_a))
def handler_b(source, event: DependsTestEvent):
execution_order.append("handler_b")
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(handler_b))
def handler_c(source, event: DependsTestEvent):
execution_order.append("handler_c")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
assert execution_order == ["handler_a", "handler_b", "handler_c"]
@pytest.mark.asyncio
async def test_async_handler_with_dependency():
"""Test async handler with dependency on sync handler."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def sync_setup(source, event: DependsTestEvent):
execution_order.append("sync_setup")
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(sync_setup))
async def async_process(source, event: DependsTestEvent):
await asyncio.sleep(0.01)
execution_order.append("async_process")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
assert execution_order == ["sync_setup", "async_process"]
@pytest.mark.asyncio
async def test_mixed_handlers_with_dependencies():
"""Test mix of sync and async handlers with dependencies."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def setup(source, event: DependsTestEvent):
execution_order.append("setup")
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(setup))
def sync_process(source, event: DependsTestEvent):
execution_order.append("sync_process")
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(setup))
async def async_process(source, event: DependsTestEvent):
await asyncio.sleep(0.01)
execution_order.append("async_process")
@crewai_event_bus.on(
DependsTestEvent, depends_on=[Depends(sync_process), Depends(async_process)]
)
def finalize(source, event: DependsTestEvent):
execution_order.append("finalize")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
# Verify execution order
assert execution_order[0] == "setup"
assert "finalize" in execution_order
assert execution_order.index("finalize") > execution_order.index("sync_process")
assert execution_order.index("finalize") > execution_order.index("async_process")
@pytest.mark.asyncio
async def test_independent_handlers_run_concurrently():
"""Test that handlers without dependencies can run concurrently."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
async def handler_a(source, event: DependsTestEvent):
await asyncio.sleep(0.01)
execution_order.append("handler_a")
@crewai_event_bus.on(DependsTestEvent)
async def handler_b(source, event: DependsTestEvent):
await asyncio.sleep(0.01)
execution_order.append("handler_b")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
# Both handlers should have executed
assert len(execution_order) == 2
assert "handler_a" in execution_order
assert "handler_b" in execution_order
@pytest.mark.asyncio
async def test_circular_dependency_detection():
"""Test that circular dependencies are detected and raise an error."""
from crewai.events.handler_graph import CircularDependencyError, build_execution_plan
# Create circular dependency: handler_a -> handler_b -> handler_c -> handler_a
def handler_a(source, event: DependsTestEvent):
pass
def handler_b(source, event: DependsTestEvent):
pass
def handler_c(source, event: DependsTestEvent):
pass
# Build a dependency graph with a cycle
handlers = [handler_a, handler_b, handler_c]
dependencies = {
handler_a: [Depends(handler_b)],
handler_b: [Depends(handler_c)],
handler_c: [Depends(handler_a)], # Creates the cycle
}
# Should raise CircularDependencyError about circular dependency
with pytest.raises(CircularDependencyError, match="Circular dependency"):
build_execution_plan(handlers, dependencies)
@pytest.mark.asyncio
async def test_handler_without_dependency_runs_normally():
"""Test that handlers without dependencies still work as before."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def simple_handler(source, event: DependsTestEvent):
execution_order.append("simple_handler")
event = DependsTestEvent(value=1)
future = crewai_event_bus.emit("test_source", event)
if future:
await asyncio.wrap_future(future)
assert execution_order == ["simple_handler"]
@pytest.mark.asyncio
async def test_depends_equality():
"""Test Depends equality and hashing."""
def handler_a(source, event):
pass
def handler_b(source, event):
pass
dep_a1 = Depends(handler_a)
dep_a2 = Depends(handler_a)
dep_b = Depends(handler_b)
# Same handler should be equal
assert dep_a1 == dep_a2
assert hash(dep_a1) == hash(dep_a2)
# Different handlers should not be equal
assert dep_a1 != dep_b
assert hash(dep_a1) != hash(dep_b)
@pytest.mark.asyncio
async def test_aemit_ignores_dependencies():
"""Test that aemit only processes async handlers (no dependency support yet)."""
execution_order = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(DependsTestEvent)
def sync_handler(source, event: DependsTestEvent):
execution_order.append("sync_handler")
@crewai_event_bus.on(DependsTestEvent)
async def async_handler(source, event: DependsTestEvent):
execution_order.append("async_handler")
event = DependsTestEvent(value=1)
await crewai_event_bus.aemit("test_source", event)
# Only async handler should execute
assert execution_order == ["async_handler"]

View File

@@ -1,3 +1,5 @@
import threading
import pytest
from crewai.agent import Agent
from crewai.crew import Crew
@@ -19,7 +21,10 @@ from crewai.experimental.evaluation import (
create_default_evaluator,
)
from crewai.experimental.evaluation.agent_evaluator import AgentEvaluator
from crewai.experimental.evaluation.base_evaluator import AgentEvaluationResult
from crewai.experimental.evaluation.base_evaluator import (
AgentEvaluationResult,
BaseEvaluator,
)
from crewai.task import Task
@@ -51,12 +56,25 @@ class TestAgentEvaluator:
@pytest.mark.vcr(filter_headers=["authorization"])
def test_evaluate_current_iteration(self, mock_crew):
from crewai.events.types.task_events import TaskCompletedEvent
agent_evaluator = AgentEvaluator(
agents=mock_crew.agents, evaluators=[GoalAlignmentEvaluator()]
)
task_completed_event = threading.Event()
@crewai_event_bus.on(TaskCompletedEvent)
async def on_task_completed(source, event):
# TaskCompletedEvent fires AFTER evaluation results are stored
task_completed_event.set()
mock_crew.kickoff()
assert task_completed_event.wait(timeout=5), (
"Timeout waiting for task completion"
)
results = agent_evaluator.get_evaluation_results()
assert isinstance(results, dict)
@@ -98,73 +116,15 @@ class TestAgentEvaluator:
]
assert len(agent_evaluator.evaluators) == len(expected_types)
for evaluator, expected_type in zip(agent_evaluator.evaluators, expected_types):
for evaluator, expected_type in zip(
agent_evaluator.evaluators, expected_types, strict=False
):
assert isinstance(evaluator, expected_type)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_eval_lite_agent(self):
agent = Agent(
role="Test Agent",
goal="Complete test tasks successfully",
backstory="An agent created for testing purposes",
)
with crewai_event_bus.scoped_handlers():
events = {}
@crewai_event_bus.on(AgentEvaluationStartedEvent)
def capture_started(source, event):
events["started"] = event
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
def capture_completed(source, event):
events["completed"] = event
@crewai_event_bus.on(AgentEvaluationFailedEvent)
def capture_failed(source, event):
events["failed"] = event
agent_evaluator = AgentEvaluator(
agents=[agent], evaluators=[GoalAlignmentEvaluator()]
)
agent.kickoff(messages="Complete this task successfully")
assert events.keys() == {"started", "completed"}
assert events["started"].agent_id == str(agent.id)
assert events["started"].agent_role == agent.role
assert events["started"].task_id is None
assert events["started"].iteration == 1
assert events["completed"].agent_id == str(agent.id)
assert events["completed"].agent_role == agent.role
assert events["completed"].task_id is None
assert events["completed"].iteration == 1
assert events["completed"].metric_category == MetricCategory.GOAL_ALIGNMENT
assert isinstance(events["completed"].score, EvaluationScore)
assert events["completed"].score.score == 2.0
results = agent_evaluator.get_evaluation_results()
assert isinstance(results, dict)
(result,) = results[agent.role]
assert isinstance(result, AgentEvaluationResult)
assert result.agent_id == str(agent.id)
assert result.task_id == "lite_task"
(goal_alignment,) = result.metrics.values()
assert goal_alignment.score == 2.0
expected_feedback = "The agent did not demonstrate a clear understanding of the task goal, which is to complete test tasks successfully"
assert expected_feedback in goal_alignment.feedback
assert goal_alignment.raw_response is not None
assert '"score": 2' in goal_alignment.raw_response
@pytest.mark.vcr(filter_headers=["authorization"])
def test_eval_specific_agents_from_crew(self, mock_crew):
from crewai.events.types.task_events import TaskCompletedEvent
agent = Agent(
role="Test Agent Eval",
goal="Complete test tasks successfully",
@@ -178,111 +138,132 @@ class TestAgentEvaluator:
mock_crew.agents.append(agent)
mock_crew.tasks.append(task)
with crewai_event_bus.scoped_handlers():
events = {}
events = {}
started_event = threading.Event()
completed_event = threading.Event()
task_completed_event = threading.Event()
@crewai_event_bus.on(AgentEvaluationStartedEvent)
def capture_started(source, event):
agent_evaluator = AgentEvaluator(
agents=[agent], evaluators=[GoalAlignmentEvaluator()]
)
@crewai_event_bus.on(AgentEvaluationStartedEvent)
async def capture_started(source, event):
if event.agent_id == str(agent.id):
events["started"] = event
started_event.set()
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
def capture_completed(source, event):
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
async def capture_completed(source, event):
if event.agent_id == str(agent.id):
events["completed"] = event
completed_event.set()
@crewai_event_bus.on(AgentEvaluationFailedEvent)
def capture_failed(source, event):
events["failed"] = event
@crewai_event_bus.on(AgentEvaluationFailedEvent)
def capture_failed(source, event):
events["failed"] = event
agent_evaluator = AgentEvaluator(
agents=[agent], evaluators=[GoalAlignmentEvaluator()]
)
mock_crew.kickoff()
@crewai_event_bus.on(TaskCompletedEvent)
async def on_task_completed(source, event):
# TaskCompletedEvent fires AFTER evaluation results are stored
if event.task and event.task.id == task.id:
task_completed_event.set()
assert events.keys() == {"started", "completed"}
assert events["started"].agent_id == str(agent.id)
assert events["started"].agent_role == agent.role
assert events["started"].task_id == str(task.id)
assert events["started"].iteration == 1
mock_crew.kickoff()
assert events["completed"].agent_id == str(agent.id)
assert events["completed"].agent_role == agent.role
assert events["completed"].task_id == str(task.id)
assert events["completed"].iteration == 1
assert events["completed"].metric_category == MetricCategory.GOAL_ALIGNMENT
assert isinstance(events["completed"].score, EvaluationScore)
assert events["completed"].score.score == 5.0
assert started_event.wait(timeout=5), "Timeout waiting for started event"
assert completed_event.wait(timeout=5), "Timeout waiting for completed event"
assert task_completed_event.wait(timeout=5), (
"Timeout waiting for task completion"
)
results = agent_evaluator.get_evaluation_results()
assert events.keys() == {"started", "completed"}
assert events["started"].agent_id == str(agent.id)
assert events["started"].agent_role == agent.role
assert events["started"].task_id == str(task.id)
assert events["started"].iteration == 1
assert isinstance(results, dict)
assert len(results.keys()) == 1
(result,) = results[agent.role]
assert isinstance(result, AgentEvaluationResult)
assert events["completed"].agent_id == str(agent.id)
assert events["completed"].agent_role == agent.role
assert events["completed"].task_id == str(task.id)
assert events["completed"].iteration == 1
assert events["completed"].metric_category == MetricCategory.GOAL_ALIGNMENT
assert isinstance(events["completed"].score, EvaluationScore)
assert events["completed"].score.score == 5.0
assert result.agent_id == str(agent.id)
assert result.task_id == str(task.id)
results = agent_evaluator.get_evaluation_results()
(goal_alignment,) = result.metrics.values()
assert goal_alignment.score == 5.0
assert isinstance(results, dict)
assert len(results.keys()) == 1
(result,) = results[agent.role]
assert isinstance(result, AgentEvaluationResult)
expected_feedback = "The agent provided a thorough guide on how to conduct a test task but failed to produce specific expected output"
assert expected_feedback in goal_alignment.feedback
assert result.agent_id == str(agent.id)
assert result.task_id == str(task.id)
assert goal_alignment.raw_response is not None
assert '"score": 5' in goal_alignment.raw_response
(goal_alignment,) = result.metrics.values()
assert goal_alignment.score == 5.0
expected_feedback = "The agent provided a thorough guide on how to conduct a test task but failed to produce specific expected output"
assert expected_feedback in goal_alignment.feedback
assert goal_alignment.raw_response is not None
assert '"score": 5' in goal_alignment.raw_response
@pytest.mark.vcr(filter_headers=["authorization"])
def test_failed_evaluation(self, mock_crew):
(agent,) = mock_crew.agents
(task,) = mock_crew.tasks
with crewai_event_bus.scoped_handlers():
events = {}
events = {}
started_event = threading.Event()
failed_event = threading.Event()
@crewai_event_bus.on(AgentEvaluationStartedEvent)
def capture_started(source, event):
events["started"] = event
@crewai_event_bus.on(AgentEvaluationStartedEvent)
def capture_started(source, event):
events["started"] = event
started_event.set()
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
def capture_completed(source, event):
events["completed"] = event
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
def capture_completed(source, event):
events["completed"] = event
@crewai_event_bus.on(AgentEvaluationFailedEvent)
def capture_failed(source, event):
events["failed"] = event
@crewai_event_bus.on(AgentEvaluationFailedEvent)
def capture_failed(source, event):
events["failed"] = event
failed_event.set()
# Create a mock evaluator that will raise an exception
from crewai.experimental.evaluation import MetricCategory
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator
class FailingEvaluator(BaseEvaluator):
metric_category = MetricCategory.GOAL_ALIGNMENT
class FailingEvaluator(BaseEvaluator):
metric_category = MetricCategory.GOAL_ALIGNMENT
def evaluate(self, agent, task, execution_trace, final_output):
raise ValueError("Forced evaluation failure")
def evaluate(self, agent, task, execution_trace, final_output):
raise ValueError("Forced evaluation failure")
agent_evaluator = AgentEvaluator(
agents=[agent], evaluators=[FailingEvaluator()]
)
mock_crew.kickoff()
agent_evaluator = AgentEvaluator(
agents=[agent], evaluators=[FailingEvaluator()]
)
mock_crew.kickoff()
assert started_event.wait(timeout=5), "Timeout waiting for started event"
assert failed_event.wait(timeout=5), "Timeout waiting for failed event"
assert events.keys() == {"started", "failed"}
assert events["started"].agent_id == str(agent.id)
assert events["started"].agent_role == agent.role
assert events["started"].task_id == str(task.id)
assert events["started"].iteration == 1
assert events.keys() == {"started", "failed"}
assert events["started"].agent_id == str(agent.id)
assert events["started"].agent_role == agent.role
assert events["started"].task_id == str(task.id)
assert events["started"].iteration == 1
assert events["failed"].agent_id == str(agent.id)
assert events["failed"].agent_role == agent.role
assert events["failed"].task_id == str(task.id)
assert events["failed"].iteration == 1
assert events["failed"].error == "Forced evaluation failure"
assert events["failed"].agent_id == str(agent.id)
assert events["failed"].agent_role == agent.role
assert events["failed"].task_id == str(task.id)
assert events["failed"].iteration == 1
assert events["failed"].error == "Forced evaluation failure"
results = agent_evaluator.get_evaluation_results()
(result,) = results[agent.role]
assert isinstance(result, AgentEvaluationResult)
results = agent_evaluator.get_evaluation_results()
(result,) = results[agent.role]
assert isinstance(result, AgentEvaluationResult)
assert result.agent_id == str(agent.id)
assert result.task_id == str(task.id)
assert result.agent_id == str(agent.id)
assert result.task_id == str(task.id)
assert result.metrics == {}
assert result.metrics == {}

View File

@@ -1,23 +1,36 @@
from unittest.mock import MagicMock, patch, ANY
import threading
from collections import defaultdict
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
)
from unittest.mock import ANY, MagicMock, patch
import pytest
from mem0.memory.main import Memory
from crewai.agent import Agent
from crewai.crew import Crew, Process
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.storage.interface import Storage
from crewai.task import Task
@pytest.fixture(autouse=True)
def cleanup_event_handlers():
"""Cleanup event handlers after each test"""
yield
with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
crewai_event_bus._handler_dependencies = {}
crewai_event_bus._execution_plan_cache = {}
@pytest.fixture
def mock_mem0_memory():
mock_memory = MagicMock(spec=Memory)
@@ -238,24 +251,26 @@ def test_external_memory_search_events(
custom_storage, external_memory_with_mocked_config
):
events = defaultdict(list)
event_received = threading.Event()
external_memory_with_mocked_config.storage = custom_storage
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
event_received.set()
external_memory_with_mocked_config.search(
query="test value",
limit=3,
score_threshold=0.35,
)
external_memory_with_mocked_config.search(
query="test value",
limit=3,
score_threshold=0.35,
)
assert event_received.wait(timeout=5), "Timeout waiting for search events"
assert len(events["MemoryQueryStartedEvent"]) == 1
assert len(events["MemoryQueryCompletedEvent"]) == 1
@@ -300,24 +315,25 @@ def test_external_memory_save_events(
custom_storage, external_memory_with_mocked_config
):
events = defaultdict(list)
event_received = threading.Event()
external_memory_with_mocked_config.storage = custom_storage
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
events["MemorySaveCompletedEvent"].append(event)
event_received.set()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
events["MemorySaveCompletedEvent"].append(event)
external_memory_with_mocked_config.save(
value="saving value",
metadata={"task": "test_task"},
)
external_memory_with_mocked_config.save(
value="saving value",
metadata={"task": "test_task"},
)
assert event_received.wait(timeout=5), "Timeout waiting for save events"
assert len(events["MemorySaveStartedEvent"]) == 1
assert len(events["MemorySaveCompletedEvent"]) == 1

View File

@@ -1,7 +1,9 @@
import threading
from collections import defaultdict
from unittest.mock import ANY
import pytest
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
@@ -21,27 +23,37 @@ def long_term_memory():
def test_long_term_memory_save_events(long_term_memory):
events = defaultdict(list)
all_events_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
if (
len(events["MemorySaveStartedEvent"]) == 1
and len(events["MemorySaveCompletedEvent"]) == 1
):
all_events_received.set()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
events["MemorySaveCompletedEvent"].append(event)
if (
len(events["MemorySaveStartedEvent"]) == 1
and len(events["MemorySaveCompletedEvent"]) == 1
):
all_events_received.set()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
events["MemorySaveCompletedEvent"].append(event)
memory = LongTermMemoryItem(
agent="test_agent",
task="test_task",
expected_output="test_output",
datetime="test_datetime",
quality=0.5,
metadata={"task": "test_task", "quality": 0.5},
)
long_term_memory.save(memory)
memory = LongTermMemoryItem(
agent="test_agent",
task="test_task",
expected_output="test_output",
datetime="test_datetime",
quality=0.5,
metadata={"task": "test_task", "quality": 0.5},
)
long_term_memory.save(memory)
assert all_events_received.wait(timeout=5), "Timeout waiting for save events"
assert len(events["MemorySaveStartedEvent"]) == 1
assert len(events["MemorySaveCompletedEvent"]) == 1
assert len(events["MemorySaveFailedEvent"]) == 0
@@ -86,21 +98,31 @@ def test_long_term_memory_save_events(long_term_memory):
def test_long_term_memory_search_events(long_term_memory):
events = defaultdict(list)
all_events_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
if (
len(events["MemoryQueryStartedEvent"]) == 1
and len(events["MemoryQueryCompletedEvent"]) == 1
):
all_events_received.set()
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
if (
len(events["MemoryQueryStartedEvent"]) == 1
and len(events["MemoryQueryCompletedEvent"]) == 1
):
all_events_received.set()
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
test_query = "test query"
test_query = "test query"
long_term_memory.search(test_query, latest_n=5)
long_term_memory.search(test_query, latest_n=5)
assert all_events_received.wait(timeout=5), "Timeout waiting for search events"
assert len(events["MemoryQueryStartedEvent"]) == 1
assert len(events["MemoryQueryCompletedEvent"]) == 1
assert len(events["MemoryQueryFailedEvent"]) == 0

View File

@@ -1,3 +1,4 @@
import threading
from collections import defaultdict
from unittest.mock import ANY, patch
@@ -37,24 +38,33 @@ def short_term_memory():
def test_short_term_memory_search_events(short_term_memory):
events = defaultdict(list)
search_started = threading.Event()
search_completed = threading.Event()
with patch.object(short_term_memory.storage, "search", return_value=[]):
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
search_started.set()
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
search_completed.set()
# Call the save method
short_term_memory.search(
query="test value",
limit=3,
score_threshold=0.35,
)
short_term_memory.search(
query="test value",
limit=3,
score_threshold=0.35,
)
assert search_started.wait(timeout=2), (
"Timeout waiting for search started event"
)
assert search_completed.wait(timeout=2), (
"Timeout waiting for search completed event"
)
assert len(events["MemoryQueryStartedEvent"]) == 1
assert len(events["MemoryQueryCompletedEvent"]) == 1
@@ -98,20 +108,26 @@ def test_short_term_memory_search_events(short_term_memory):
def test_short_term_memory_save_events(short_term_memory):
events = defaultdict(list)
with crewai_event_bus.scoped_handlers():
save_started = threading.Event()
save_completed = threading.Event()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
save_started.set()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
events["MemorySaveCompletedEvent"].append(event)
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
events["MemorySaveCompletedEvent"].append(event)
save_completed.set()
short_term_memory.save(
value="test value",
metadata={"task": "test_task"},
)
short_term_memory.save(
value="test value",
metadata={"task": "test_task"},
)
assert save_started.wait(timeout=2), "Timeout waiting for save started event"
assert save_completed.wait(timeout=2), "Timeout waiting for save completed event"
assert len(events["MemorySaveStartedEvent"]) == 1
assert len(events["MemorySaveCompletedEvent"]) == 1

View File

@@ -1,9 +1,10 @@
"""Test Agent creation and execution basic functionality."""
import json
import threading
from collections import defaultdict
from concurrent.futures import Future
from hashlib import md5
import json
import re
from unittest import mock
from unittest.mock import ANY, MagicMock, patch
@@ -2476,62 +2477,63 @@ def test_using_contextual_memory():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_memory_events_are_emitted():
events = defaultdict(list)
event_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemorySaveStartedEvent)
def handle_memory_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
@crewai_event_bus.on(MemorySaveStartedEvent)
def handle_memory_save_started(source, event):
events["MemorySaveStartedEvent"].append(event)
@crewai_event_bus.on(MemorySaveCompletedEvent)
def handle_memory_save_completed(source, event):
events["MemorySaveCompletedEvent"].append(event)
@crewai_event_bus.on(MemorySaveCompletedEvent)
def handle_memory_save_completed(source, event):
events["MemorySaveCompletedEvent"].append(event)
@crewai_event_bus.on(MemorySaveFailedEvent)
def handle_memory_save_failed(source, event):
events["MemorySaveFailedEvent"].append(event)
@crewai_event_bus.on(MemorySaveFailedEvent)
def handle_memory_save_failed(source, event):
events["MemorySaveFailedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryStartedEvent)
def handle_memory_query_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryStartedEvent)
def handle_memory_query_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def handle_memory_query_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def handle_memory_query_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryFailedEvent)
def handle_memory_query_failed(source, event):
events["MemoryQueryFailedEvent"].append(event)
@crewai_event_bus.on(MemoryQueryFailedEvent)
def handle_memory_query_failed(source, event):
events["MemoryQueryFailedEvent"].append(event)
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
def handle_memory_retrieval_started(source, event):
events["MemoryRetrievalStartedEvent"].append(event)
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
def handle_memory_retrieval_started(source, event):
events["MemoryRetrievalStartedEvent"].append(event)
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
def handle_memory_retrieval_completed(source, event):
events["MemoryRetrievalCompletedEvent"].append(event)
event_received.set()
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
def handle_memory_retrieval_completed(source, event):
events["MemoryRetrievalCompletedEvent"].append(event)
math_researcher = Agent(
role="Researcher",
goal="You research about math.",
backstory="You're an expert in research and you love to learn new things.",
allow_delegation=False,
)
math_researcher = Agent(
role="Researcher",
goal="You research about math.",
backstory="You're an expert in research and you love to learn new things.",
allow_delegation=False,
)
task1 = Task(
description="Research a topic to teach a kid aged 6 about math.",
expected_output="A topic, explanation, angle, and examples.",
agent=math_researcher,
)
task1 = Task(
description="Research a topic to teach a kid aged 6 about math.",
expected_output="A topic, explanation, angle, and examples.",
agent=math_researcher,
)
crew = Crew(
agents=[math_researcher],
tasks=[task1],
memory=True,
)
crew = Crew(
agents=[math_researcher],
tasks=[task1],
memory=True,
)
crew.kickoff()
crew.kickoff()
assert event_received.wait(timeout=5), "Timeout waiting for memory events"
assert len(events["MemorySaveStartedEvent"]) == 3
assert len(events["MemorySaveCompletedEvent"]) == 3
assert len(events["MemorySaveFailedEvent"]) == 0
@@ -2907,19 +2909,29 @@ def test_crew_train_success(
copy_mock.return_value = crew
received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
@crewai_event_bus.on(CrewTrainStartedEvent)
def on_crew_train_started(source, event: CrewTrainStartedEvent):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == 2:
all_events_received.set()
@crewai_event_bus.on(CrewTrainCompletedEvent)
def on_crew_train_completed(source, event: CrewTrainCompletedEvent):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == 2:
all_events_received.set()
crew.train(
n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
)
assert all_events_received.wait(timeout=5), "Timeout waiting for all train events"
# Ensure kickoff is called on the copied crew
kickoff_mock.assert_has_calls(
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
@@ -3726,17 +3738,27 @@ def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator, research
llm_instance = LLM("gpt-4o-mini")
received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
@crewai_event_bus.on(CrewTestStartedEvent)
def on_crew_test_started(source, event: CrewTestStartedEvent):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == 2:
all_events_received.set()
@crewai_event_bus.on(CrewTestCompletedEvent)
def on_crew_test_completed(source, event: CrewTestCompletedEvent):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == 2:
all_events_received.set()
crew.test(n_iterations, llm_instance, inputs={"topic": "AI"})
assert all_events_received.wait(timeout=5), "Timeout waiting for all test events"
# Ensure kickoff is called on the copied crew
kickoff_mock.assert_has_calls(
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]

View File

@@ -1,9 +1,12 @@
"""Test Flow creation and execution basic functionality."""
import asyncio
import threading
from datetime import datetime
import pytest
from pydantic import BaseModel
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.flow_events import (
FlowFinishedEvent,
@@ -13,7 +16,6 @@ from crewai.events.types.flow_events import (
MethodExecutionStartedEvent,
)
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from pydantic import BaseModel
def test_simple_sequential_flow():
@@ -439,20 +441,42 @@ def test_unstructured_flow_event_emission():
flow = PoemFlow()
received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
expected_event_count = (
7 # 1 FlowStarted + 5 MethodExecutionStarted + 1 FlowFinished
)
@crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_start(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_end(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
flow.kickoff(inputs={"separator": ", "})
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
# Sort events by timestamp to ensure deterministic order
# (async handlers may append out of order)
with lock:
received_events.sort(key=lambda e: e.timestamp)
assert isinstance(received_events[0], FlowStartedEvent)
assert received_events[0].flow_name == "PoemFlow"
assert received_events[0].inputs == {"separator": ", "}
@@ -642,28 +666,48 @@ def test_structured_flow_event_emission():
return f"Welcome, {self.state.name}!"
flow = OnboardingFlow()
flow.kickoff(inputs={"name": "Anakin"})
received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
expected_event_count = 6 # 1 FlowStarted + 2 MethodExecutionStarted + 2 MethodExecutionFinished + 1 FlowFinished
@crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_start(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(MethodExecutionFinishedEvent)
def handle_method_end(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_end(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
flow.kickoff(inputs={"name": "Anakin"})
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
# Sort events by timestamp to ensure deterministic order
with lock:
received_events.sort(key=lambda e: e.timestamp)
assert isinstance(received_events[0], FlowStartedEvent)
assert received_events[0].flow_name == "OnboardingFlow"
assert received_events[0].inputs == {"name": "Anakin"}
@@ -711,25 +755,46 @@ def test_stateless_flow_event_emission():
flow = StatelessFlow()
received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
expected_event_count = 6 # 1 FlowStarted + 2 MethodExecutionStarted + 2 MethodExecutionFinished + 1 FlowFinished
@crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_start(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(MethodExecutionFinishedEvent)
def handle_method_end(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
@crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_end(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) == expected_event_count:
all_events_received.set()
flow.kickoff()
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
# Sort events by timestamp to ensure deterministic order
with lock:
received_events.sort(key=lambda e: e.timestamp)
assert isinstance(received_events[0], FlowStartedEvent)
assert received_events[0].flow_name == "StatelessFlow"
assert received_events[0].inputs is None
@@ -769,13 +834,16 @@ def test_flow_plotting():
flow = StatelessFlow()
flow.kickoff()
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(FlowPlotEvent)
def handle_flow_plot(source, event):
received_events.append(event)
event_received.set()
flow.plot("test_flow")
assert event_received.wait(timeout=5), "Timeout waiting for plot event"
assert len(received_events) == 1
assert isinstance(received_events[0], FlowPlotEvent)
assert received_events[0].flow_name == "StatelessFlow"

View File

@@ -1,3 +1,4 @@
import threading
from unittest.mock import Mock, patch
import pytest
@@ -175,78 +176,92 @@ def test_task_guardrail_process_output(task_output):
def test_guardrail_emits_events(sample_agent):
started_guardrail = []
completed_guardrail = []
all_events_received = threading.Event()
expected_started = 3 # 2 from first task, 1 from second
expected_completed = 3 # 2 from first task, 1 from second
task = Task(
task1 = Task(
description="Gather information about available books on the First World War",
agent=sample_agent,
expected_output="A list of available books on the First World War",
guardrail="Ensure the authors are from Italy",
)
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def handle_guardrail_started(source, event):
assert source == task
started_guardrail.append(
{"guardrail": event.guardrail, "retry_count": event.retry_count}
)
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def handle_guardrail_completed(source, event):
assert source == task
completed_guardrail.append(
{
"success": event.success,
"result": event.result,
"error": event.error,
"retry_count": event.retry_count,
}
)
result = task.execute_sync(agent=sample_agent)
def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")
task = Task(
description="Test task",
expected_output="Output",
guardrail=custom_guardrail,
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def handle_guardrail_started(source, event):
started_guardrail.append(
{"guardrail": event.guardrail, "retry_count": event.retry_count}
)
if (
len(started_guardrail) >= expected_started
and len(completed_guardrail) >= expected_completed
):
all_events_received.set()
task.execute_sync(agent=sample_agent)
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def handle_guardrail_completed(source, event):
completed_guardrail.append(
{
"success": event.success,
"result": event.result,
"error": event.error,
"retry_count": event.retry_count,
}
)
if (
len(started_guardrail) >= expected_started
and len(completed_guardrail) >= expected_completed
):
all_events_received.set()
expected_started_events = [
{"guardrail": "Ensure the authors are from Italy", "retry_count": 0},
{"guardrail": "Ensure the authors are from Italy", "retry_count": 1},
{
"guardrail": """def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")""",
"retry_count": 0,
},
]
result = task1.execute_sync(agent=sample_agent)
expected_completed_events = [
{
"success": False,
"result": None,
"error": "The task result does not comply with the guardrail because none of "
"the listed authors are from Italy. All authors mentioned are from "
"different countries, including Germany, the UK, the USA, and others, "
"which violates the requirement that authors must be Italian.",
"retry_count": 0,
},
{"success": True, "result": result.raw, "error": None, "retry_count": 1},
{
"success": True,
"result": "good result from callable function",
"error": None,
"retry_count": 0,
},
]
assert started_guardrail == expected_started_events
assert completed_guardrail == expected_completed_events
def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")
task2 = Task(
description="Test task",
expected_output="Output",
guardrail=custom_guardrail,
)
task2.execute_sync(agent=sample_agent)
# Wait for all events to be received
assert all_events_received.wait(timeout=10), (
"Timeout waiting for all guardrail events"
)
expected_started_events = [
{"guardrail": "Ensure the authors are from Italy", "retry_count": 0},
{"guardrail": "Ensure the authors are from Italy", "retry_count": 1},
{
"guardrail": """def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")""",
"retry_count": 0,
},
]
expected_completed_events = [
{
"success": False,
"result": None,
"error": "The task result does not comply with the guardrail because none of "
"the listed authors are from Italy. All authors mentioned are from "
"different countries, including Germany, the UK, the USA, and others, "
"which violates the requirement that authors must be Italian.",
"retry_count": 0,
},
{"success": True, "result": result.raw, "error": None, "retry_count": 1},
{
"success": True,
"result": "good result from callable function",
"error": None,
"retry_count": 0,
},
]
assert started_guardrail == expected_started_events
assert completed_guardrail == expected_completed_events
@pytest.mark.vcr(filter_headers=["authorization"])

View File

@@ -1,6 +1,7 @@
import datetime
import json
import random
import threading
import time
from unittest.mock import MagicMock, patch
@@ -32,7 +33,7 @@ class RandomNumberTool(BaseTool):
args_schema: type[BaseModel] = RandomNumberToolInput
def _run(self, min_value: int, max_value: int) -> int:
return random.randint(min_value, max_value)
return random.randint(min_value, max_value) # noqa: S311
# Example agent and task
@@ -470,13 +471,21 @@ def test_tool_selection_error_event_direct():
)
received_events = []
first_event_received = threading.Event()
second_event_received = threading.Event()
@crewai_event_bus.on(ToolSelectionErrorEvent)
def event_handler(source, event):
received_events.append(event)
if event.tool_name == "Non Existent Tool":
first_event_received.set()
elif event.tool_name == "":
second_event_received.set()
with pytest.raises(Exception):
with pytest.raises(Exception): # noqa: B017
tool_usage._select_tool("Non Existent Tool")
assert first_event_received.wait(timeout=5), "Timeout waiting for first event"
assert len(received_events) == 1
event = received_events[0]
assert isinstance(event, ToolSelectionErrorEvent)
@@ -488,12 +497,12 @@ def test_tool_selection_error_event_direct():
assert "A test tool" in event.tool_class
assert "don't exist" in event.error
received_events.clear()
with pytest.raises(Exception):
with pytest.raises(Exception): # noqa: B017
tool_usage._select_tool("")
assert len(received_events) == 1
event = received_events[0]
assert second_event_received.wait(timeout=5), "Timeout waiting for second event"
assert len(received_events) == 2
event = received_events[1]
assert isinstance(event, ToolSelectionErrorEvent)
assert event.agent_key == "test_key"
assert event.agent_role == "test_role"
@@ -562,7 +571,7 @@ def test_tool_validate_input_error_event():
# Test invalid input
invalid_input = "invalid json {[}"
with pytest.raises(Exception):
with pytest.raises(Exception): # noqa: B017
tool_usage._validate_tool_input(invalid_input)
# Verify event was emitted
@@ -616,12 +625,13 @@ def test_tool_usage_finished_event_with_result():
action=MagicMock(),
)
# Track received events
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageFinishedEvent)
def event_handler(source, event):
received_events.append(event)
event_received.set()
# Call on_tool_use_finished with test data
started_at = time.time()
@@ -634,7 +644,7 @@ def test_tool_usage_finished_event_with_result():
result=result,
)
# Verify event was emitted
assert event_received.wait(timeout=5), "Timeout waiting for event"
assert len(received_events) == 1, "Expected one event to be emitted"
event = received_events[0]
assert isinstance(event, ToolUsageFinishedEvent)
@@ -695,12 +705,13 @@ def test_tool_usage_finished_event_with_cached_result():
action=MagicMock(),
)
# Track received events
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageFinishedEvent)
def event_handler(source, event):
received_events.append(event)
event_received.set()
# Call on_tool_use_finished with test data and from_cache=True
started_at = time.time()
@@ -713,7 +724,7 @@ def test_tool_usage_finished_event_with_cached_result():
result=result,
)
# Verify event was emitted
assert event_received.wait(timeout=5), "Timeout waiting for event"
assert len(received_events) == 1, "Expected one event to be emitted"
event = received_events[0]
assert isinstance(event, ToolUsageFinishedEvent)

View File

@@ -14,6 +14,7 @@ from crewai.events.listeners.tracing.trace_listener import (
)
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.flow.flow import Flow, start
from tests.utils import wait_for_event_handlers
class TestTraceListenerSetup:
@@ -39,38 +40,44 @@ class TestTraceListenerSetup:
):
yield
@pytest.fixture(autouse=True)
def clear_event_bus(self):
"""Clear event bus listeners before and after each test"""
from crewai.events.event_bus import crewai_event_bus
# Store original handlers
original_handlers = crewai_event_bus._handlers.copy()
# Clear for test
crewai_event_bus._handlers.clear()
yield
# Restore original state
crewai_event_bus._handlers.clear()
crewai_event_bus._handlers.update(original_handlers)
@pytest.fixture(autouse=True)
def reset_tracing_singletons(self):
"""Reset tracing singleton instances between tests"""
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
# Clear event bus handlers BEFORE creating any new singletons
with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
crewai_event_bus._handler_dependencies = {}
crewai_event_bus._execution_plan_cache = {}
# Reset TraceCollectionListener singleton
if hasattr(TraceCollectionListener, "_instance"):
TraceCollectionListener._instance = None
TraceCollectionListener._initialized = False
# Reset EventListener singleton
if hasattr(EventListener, "_instance"):
EventListener._instance = None
yield
# Clean up after test
with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
crewai_event_bus._handler_dependencies = {}
crewai_event_bus._execution_plan_cache = {}
if hasattr(TraceCollectionListener, "_instance"):
TraceCollectionListener._instance = None
TraceCollectionListener._initialized = False
if hasattr(EventListener, "_instance"):
EventListener._instance = None
@pytest.fixture(autouse=True)
def mock_plus_api_calls(self):
"""Mock all PlusAPI HTTP calls to avoid network requests"""
@@ -167,15 +174,26 @@ class TestTraceListenerSetup:
from crewai.events.event_bus import crewai_event_bus
trace_listener = None
for handler_list in crewai_event_bus._handlers.values():
for handler in handler_list:
if hasattr(handler, "__self__") and isinstance(
handler.__self__, TraceCollectionListener
):
trace_listener = handler.__self__
with crewai_event_bus._rwlock.r_locked():
for handler_set in crewai_event_bus._sync_handlers.values():
for handler in handler_set:
if hasattr(handler, "__self__") and isinstance(
handler.__self__, TraceCollectionListener
):
trace_listener = handler.__self__
break
if trace_listener:
break
if trace_listener:
break
if not trace_listener:
for handler_set in crewai_event_bus._async_handlers.values():
for handler in handler_set:
if hasattr(handler, "__self__") and isinstance(
handler.__self__, TraceCollectionListener
):
trace_listener = handler.__self__
break
if trace_listener:
break
if not trace_listener:
pytest.skip(
@@ -221,6 +239,7 @@ class TestTraceListenerSetup:
wraps=trace_listener.batch_manager.add_event,
) as add_event_mock:
crew.kickoff()
wait_for_event_handlers()
assert add_event_mock.call_count >= 2
@@ -267,24 +286,22 @@ class TestTraceListenerSetup:
from crewai.events.event_bus import crewai_event_bus
trace_handlers = []
for handlers in crewai_event_bus._handlers.values():
for handler in handlers:
if hasattr(handler, "__self__") and isinstance(
handler.__self__, TraceCollectionListener
):
trace_handlers.append(handler)
elif hasattr(handler, "__name__") and any(
trace_name in handler.__name__
for trace_name in [
"on_crew_started",
"on_crew_completed",
"on_flow_started",
]
):
trace_handlers.append(handler)
with crewai_event_bus._rwlock.r_locked():
for handlers in crewai_event_bus._sync_handlers.values():
for handler in handlers:
if hasattr(handler, "__self__") and isinstance(
handler.__self__, TraceCollectionListener
):
trace_handlers.append(handler)
for handlers in crewai_event_bus._async_handlers.values():
for handler in handlers:
if hasattr(handler, "__self__") and isinstance(
handler.__self__, TraceCollectionListener
):
trace_handlers.append(handler)
assert len(trace_handlers) == 0, (
f"Found {len(trace_handlers)} trace handlers when tracing should be disabled"
f"Found {len(trace_handlers)} TraceCollectionListener handlers when tracing should be disabled"
)
def test_trace_listener_setup_correctly_for_crew(self):
@@ -385,6 +402,7 @@ class TestTraceListenerSetup:
):
crew = Crew(agents=[agent], tasks=[task], tracing=True)
crew.kickoff()
wait_for_event_handlers()
mock_plus_api_class.assert_called_with(api_key="mock_token_12345")
@@ -396,15 +414,33 @@ class TestTraceListenerSetup:
def teardown_method(self):
"""Cleanup after each test method"""
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
crewai_event_bus._handlers.clear()
with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
crewai_event_bus._handler_dependencies = {}
crewai_event_bus._execution_plan_cache = {}
# Reset EventListener singleton
if hasattr(EventListener, "_instance"):
EventListener._instance = None
@classmethod
def teardown_class(cls):
"""Final cleanup after all tests in this class"""
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
crewai_event_bus._handlers.clear()
with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
crewai_event_bus._handler_dependencies = {}
crewai_event_bus._execution_plan_cache = {}
# Reset EventListener singleton
if hasattr(EventListener, "_instance"):
EventListener._instance = None
@pytest.mark.vcr(filter_headers=["authorization"])
def test_first_time_user_trace_collection_with_timeout(self, mock_plus_api_calls):
@@ -466,6 +502,7 @@ class TestTraceListenerSetup:
) as mock_add_event,
):
result = crew.kickoff()
wait_for_event_handlers()
assert result is not None
assert mock_handle_completion.call_count >= 1
@@ -543,6 +580,7 @@ class TestTraceListenerSetup:
)
crew.kickoff()
wait_for_event_handlers()
assert mock_handle_completion.call_count >= 1, (
"handle_execution_completion should be called"
@@ -561,7 +599,6 @@ class TestTraceListenerSetup:
@pytest.mark.vcr(filter_headers=["authorization"])
def test_first_time_user_trace_consolidation_logic(self, mock_plus_api_calls):
"""Test the consolidation logic for first-time users vs regular tracing"""
with (
patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
patch(
@@ -579,7 +616,9 @@ class TestTraceListenerSetup:
):
from crewai.events.event_bus import crewai_event_bus
crewai_event_bus._handlers.clear()
with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers = {}
crewai_event_bus._async_handlers = {}
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
@@ -600,6 +639,9 @@ class TestTraceListenerSetup:
with patch.object(TraceBatchManager, "initialize_batch") as mock_initialize:
result = crew.kickoff()
assert trace_listener.batch_manager.wait_for_pending_events(timeout=5.0), (
"Timeout waiting for trace event handlers to complete"
)
assert mock_initialize.call_count >= 1
assert mock_initialize.call_args_list[0][1]["use_ephemeral"] is True
assert result is not None
@@ -700,6 +742,7 @@ class TestTraceListenerSetup:
) as mock_mark_failed,
):
crew.kickoff()
wait_for_event_handlers()
mock_mark_failed.assert_called_once()
call_args = mock_mark_failed.call_args_list[0]

View File

@@ -0,0 +1,206 @@
"""Tests for async event handling in CrewAI event bus.
This module tests async handler registration, execution, and the aemit method.
"""
import asyncio
import pytest
from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import crewai_event_bus
class AsyncTestEvent(BaseEvent):
pass
@pytest.mark.asyncio
async def test_async_handler_execution():
received_events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
received_events.append(event)
event = AsyncTestEvent(type="async_test")
crewai_event_bus.emit("test_source", event)
await asyncio.sleep(0.1)
assert len(received_events) == 1
assert received_events[0] == event
@pytest.mark.asyncio
async def test_aemit_with_async_handlers():
received_events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
received_events.append(event)
event = AsyncTestEvent(type="async_test")
await crewai_event_bus.aemit("test_source", event)
assert len(received_events) == 1
assert received_events[0] == event
@pytest.mark.asyncio
async def test_multiple_async_handlers():
received_events_1 = []
received_events_2 = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def handler_1(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
received_events_1.append(event)
@crewai_event_bus.on(AsyncTestEvent)
async def handler_2(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.02)
received_events_2.append(event)
event = AsyncTestEvent(type="async_test")
await crewai_event_bus.aemit("test_source", event)
assert len(received_events_1) == 1
assert len(received_events_2) == 1
@pytest.mark.asyncio
async def test_mixed_sync_and_async_handlers():
sync_events = []
async_events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
def sync_handler(source: object, event: BaseEvent) -> None:
sync_events.append(event)
@crewai_event_bus.on(AsyncTestEvent)
async def async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
async_events.append(event)
event = AsyncTestEvent(type="mixed_test")
crewai_event_bus.emit("test_source", event)
await asyncio.sleep(0.1)
assert len(sync_events) == 1
assert len(async_events) == 1
@pytest.mark.asyncio
async def test_async_handler_error_handling():
successful_handler_called = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def failing_handler(source: object, event: BaseEvent) -> None:
raise ValueError("Async handler error")
@crewai_event_bus.on(AsyncTestEvent)
async def successful_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
successful_handler_called.append(True)
event = AsyncTestEvent(type="error_test")
await crewai_event_bus.aemit("test_source", event)
assert len(successful_handler_called) == 1
@pytest.mark.asyncio
async def test_aemit_with_no_handlers():
with crewai_event_bus.scoped_handlers():
event = AsyncTestEvent(type="no_handlers")
await crewai_event_bus.aemit("test_source", event)
@pytest.mark.asyncio
async def test_async_handler_registration_via_register_handler():
received_events = []
with crewai_event_bus.scoped_handlers():
async def custom_async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
received_events.append(event)
crewai_event_bus.register_handler(AsyncTestEvent, custom_async_handler)
event = AsyncTestEvent(type="register_test")
await crewai_event_bus.aemit("test_source", event)
assert len(received_events) == 1
assert received_events[0] == event
@pytest.mark.asyncio
async def test_emit_async_handlers_fire_and_forget():
received_events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def slow_async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.05)
received_events.append(event)
event = AsyncTestEvent(type="fire_forget_test")
crewai_event_bus.emit("test_source", event)
assert len(received_events) == 0
await asyncio.sleep(0.1)
assert len(received_events) == 1
@pytest.mark.asyncio
async def test_scoped_handlers_with_async():
received_before = []
received_during = []
received_after = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def before_handler(source: object, event: BaseEvent) -> None:
received_before.append(event)
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(AsyncTestEvent)
async def scoped_handler(source: object, event: BaseEvent) -> None:
received_during.append(event)
event1 = AsyncTestEvent(type="during_scope")
await crewai_event_bus.aemit("test_source", event1)
assert len(received_before) == 0
assert len(received_during) == 1
@crewai_event_bus.on(AsyncTestEvent)
async def after_handler(source: object, event: BaseEvent) -> None:
received_after.append(event)
event2 = AsyncTestEvent(type="after_scope")
await crewai_event_bus.aemit("test_source", event2)
assert len(received_before) == 1
assert len(received_during) == 1
assert len(received_after) == 1

View File

@@ -1,3 +1,4 @@
import threading
from unittest.mock import Mock
from crewai.events.base_events import BaseEvent
@@ -21,27 +22,42 @@ def test_specific_event_handler():
mock_handler.assert_called_once_with("source_object", event)
def test_wildcard_event_handler():
mock_handler = Mock()
def test_multiple_handlers_same_event():
"""Test that multiple handlers can be registered for the same event type."""
mock_handler1 = Mock()
mock_handler2 = Mock()
@crewai_event_bus.on(BaseEvent)
def handler(source, event):
mock_handler(source, event)
@crewai_event_bus.on(TestEvent)
def handler1(source, event):
mock_handler1(source, event)
@crewai_event_bus.on(TestEvent)
def handler2(source, event):
mock_handler2(source, event)
event = TestEvent(type="test_event")
crewai_event_bus.emit("source_object", event)
mock_handler.assert_called_once_with("source_object", event)
mock_handler1.assert_called_once_with("source_object", event)
mock_handler2.assert_called_once_with("source_object", event)
def test_event_bus_error_handling(capfd):
@crewai_event_bus.on(BaseEvent)
def test_event_bus_error_handling():
"""Test that handler exceptions are caught and don't break the event bus."""
called = threading.Event()
error_caught = threading.Event()
@crewai_event_bus.on(TestEvent)
def broken_handler(source, event):
called.set()
raise ValueError("Simulated handler failure")
@crewai_event_bus.on(TestEvent)
def working_handler(source, event):
error_caught.set()
event = TestEvent(type="test_event")
crewai_event_bus.emit("source_object", event)
out, err = capfd.readouterr()
assert "Simulated handler failure" in out
assert "Handler 'broken_handler' failed" in out
assert called.wait(timeout=2), "Broken handler was never called"
assert error_caught.wait(timeout=2), "Working handler was never called after error"

View File

@@ -0,0 +1,264 @@
"""Tests for read-write lock implementation.
This module tests the RWLock class for correct concurrent read and write behavior.
"""
import threading
import time
from crewai.events.utils.rw_lock import RWLock
def test_multiple_readers_concurrent():
lock = RWLock()
active_readers = [0]
max_concurrent_readers = [0]
lock_for_counters = threading.Lock()
def reader(reader_id: int) -> None:
with lock.r_locked():
with lock_for_counters:
active_readers[0] += 1
max_concurrent_readers[0] = max(
max_concurrent_readers[0], active_readers[0]
)
time.sleep(0.1)
with lock_for_counters:
active_readers[0] -= 1
threads = [threading.Thread(target=reader, args=(i,)) for i in range(5)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert max_concurrent_readers[0] == 5
def test_writer_blocks_readers():
lock = RWLock()
writer_holding_lock = [False]
reader_accessed_during_write = [False]
def writer() -> None:
with lock.w_locked():
writer_holding_lock[0] = True
time.sleep(0.2)
writer_holding_lock[0] = False
def reader() -> None:
time.sleep(0.05)
with lock.r_locked():
if writer_holding_lock[0]:
reader_accessed_during_write[0] = True
writer_thread = threading.Thread(target=writer)
reader_thread = threading.Thread(target=reader)
writer_thread.start()
reader_thread.start()
writer_thread.join()
reader_thread.join()
assert not reader_accessed_during_write[0]
def test_writer_blocks_other_writers():
lock = RWLock()
execution_order: list[int] = []
lock_for_order = threading.Lock()
def writer(writer_id: int) -> None:
with lock.w_locked():
with lock_for_order:
execution_order.append(writer_id)
time.sleep(0.1)
threads = [threading.Thread(target=writer, args=(i,)) for i in range(3)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert len(execution_order) == 3
assert len(set(execution_order)) == 3
def test_readers_block_writers():
lock = RWLock()
reader_count = [0]
writer_accessed_during_read = [False]
lock_for_counters = threading.Lock()
def reader() -> None:
with lock.r_locked():
with lock_for_counters:
reader_count[0] += 1
time.sleep(0.2)
with lock_for_counters:
reader_count[0] -= 1
def writer() -> None:
time.sleep(0.05)
with lock.w_locked():
with lock_for_counters:
if reader_count[0] > 0:
writer_accessed_during_read[0] = True
reader_thread = threading.Thread(target=reader)
writer_thread = threading.Thread(target=writer)
reader_thread.start()
writer_thread.start()
reader_thread.join()
writer_thread.join()
assert not writer_accessed_during_read[0]
def test_alternating_readers_and_writers():
lock = RWLock()
operations: list[str] = []
lock_for_operations = threading.Lock()
def reader(reader_id: int) -> None:
with lock.r_locked():
with lock_for_operations:
operations.append(f"r{reader_id}_start")
time.sleep(0.05)
with lock_for_operations:
operations.append(f"r{reader_id}_end")
def writer(writer_id: int) -> None:
with lock.w_locked():
with lock_for_operations:
operations.append(f"w{writer_id}_start")
time.sleep(0.05)
with lock_for_operations:
operations.append(f"w{writer_id}_end")
threads = [
threading.Thread(target=reader, args=(0,)),
threading.Thread(target=writer, args=(0,)),
threading.Thread(target=reader, args=(1,)),
threading.Thread(target=writer, args=(1,)),
threading.Thread(target=reader, args=(2,)),
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert len(operations) == 10
start_ops = [op for op in operations if "_start" in op]
end_ops = [op for op in operations if "_end" in op]
assert len(start_ops) == 5
assert len(end_ops) == 5
def test_context_manager_releases_on_exception():
lock = RWLock()
exception_raised = False
try:
with lock.r_locked():
raise ValueError("Test exception")
except ValueError:
exception_raised = True
assert exception_raised
acquired = False
with lock.w_locked():
acquired = True
assert acquired
def test_write_lock_releases_on_exception():
lock = RWLock()
exception_raised = False
try:
with lock.w_locked():
raise ValueError("Test exception")
except ValueError:
exception_raised = True
assert exception_raised
acquired = False
with lock.r_locked():
acquired = True
assert acquired
def test_stress_many_readers_few_writers():
lock = RWLock()
read_count = [0]
write_count = [0]
lock_for_counters = threading.Lock()
def reader() -> None:
for _ in range(10):
with lock.r_locked():
with lock_for_counters:
read_count[0] += 1
time.sleep(0.001)
def writer() -> None:
for _ in range(5):
with lock.w_locked():
with lock_for_counters:
write_count[0] += 1
time.sleep(0.01)
reader_threads = [threading.Thread(target=reader) for _ in range(10)]
writer_threads = [threading.Thread(target=writer) for _ in range(2)]
all_threads = reader_threads + writer_threads
for thread in all_threads:
thread.start()
for thread in all_threads:
thread.join()
assert read_count[0] == 100
assert write_count[0] == 10
def test_nested_read_locks_same_thread():
lock = RWLock()
nested_acquired = False
with lock.r_locked():
with lock.r_locked():
nested_acquired = True
assert nested_acquired
def test_manual_acquire_release():
lock = RWLock()
lock.r_acquire()
lock.r_release()
lock.w_acquire()
lock.w_release()
with lock.r_locked():
pass

View File

@@ -0,0 +1,247 @@
"""Tests for event bus shutdown and cleanup behavior.
This module tests graceful shutdown, task completion, and cleanup operations.
"""
import asyncio
import threading
import time
import pytest
from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import CrewAIEventsBus
class ShutdownTestEvent(BaseEvent):
pass
def test_shutdown_prevents_new_events():
bus = CrewAIEventsBus()
received_events = []
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def handler(source: object, event: BaseEvent) -> None:
received_events.append(event)
bus._shutting_down = True
event = ShutdownTestEvent(type="after_shutdown")
bus.emit("test_source", event)
time.sleep(0.1)
assert len(received_events) == 0
bus._shutting_down = False
@pytest.mark.asyncio
async def test_aemit_during_shutdown():
bus = CrewAIEventsBus()
received_events = []
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
async def handler(source: object, event: BaseEvent) -> None:
received_events.append(event)
bus._shutting_down = True
event = ShutdownTestEvent(type="aemit_during_shutdown")
await bus.aemit("test_source", event)
assert len(received_events) == 0
bus._shutting_down = False
def test_shutdown_flag_prevents_emit():
bus = CrewAIEventsBus()
emitted_count = [0]
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def handler(source: object, event: BaseEvent) -> None:
emitted_count[0] += 1
event1 = ShutdownTestEvent(type="before_shutdown")
bus.emit("test_source", event1)
time.sleep(0.1)
assert emitted_count[0] == 1
bus._shutting_down = True
event2 = ShutdownTestEvent(type="during_shutdown")
bus.emit("test_source", event2)
time.sleep(0.1)
assert emitted_count[0] == 1
bus._shutting_down = False
def test_concurrent_access_during_shutdown_flag():
bus = CrewAIEventsBus()
received_events = []
lock = threading.Lock()
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def handler(source: object, event: BaseEvent) -> None:
with lock:
received_events.append(event)
def emit_events() -> None:
for i in range(10):
event = ShutdownTestEvent(type=f"event_{i}")
bus.emit("source", event)
time.sleep(0.01)
def set_shutdown_flag() -> None:
time.sleep(0.05)
bus._shutting_down = True
emit_thread = threading.Thread(target=emit_events)
shutdown_thread = threading.Thread(target=set_shutdown_flag)
emit_thread.start()
shutdown_thread.start()
emit_thread.join()
shutdown_thread.join()
time.sleep(0.2)
assert len(received_events) < 10
assert len(received_events) > 0
bus._shutting_down = False
@pytest.mark.asyncio
async def test_async_handlers_complete_before_shutdown_flag():
bus = CrewAIEventsBus()
completed_handlers = []
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
async def async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.05)
if not bus._shutting_down:
completed_handlers.append(event)
for i in range(5):
event = ShutdownTestEvent(type=f"event_{i}")
bus.emit("source", event)
await asyncio.sleep(0.3)
assert len(completed_handlers) == 5
def test_scoped_handlers_cleanup():
bus = CrewAIEventsBus()
received_before = []
received_during = []
received_after = []
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def before_handler(source: object, event: BaseEvent) -> None:
received_before.append(event)
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def during_handler(source: object, event: BaseEvent) -> None:
received_during.append(event)
event1 = ShutdownTestEvent(type="during")
bus.emit("source", event1)
time.sleep(0.1)
assert len(received_before) == 0
assert len(received_during) == 1
event2 = ShutdownTestEvent(type="after_inner_scope")
bus.emit("source", event2)
time.sleep(0.1)
assert len(received_before) == 1
assert len(received_during) == 1
event3 = ShutdownTestEvent(type="after_outer_scope")
bus.emit("source", event3)
time.sleep(0.1)
assert len(received_before) == 1
assert len(received_during) == 1
assert len(received_after) == 0
def test_handler_registration_thread_safety():
bus = CrewAIEventsBus()
handlers_registered = [0]
lock = threading.Lock()
with bus.scoped_handlers():
def register_handlers() -> None:
for _ in range(20):
@bus.on(ShutdownTestEvent)
def handler(source: object, event: BaseEvent) -> None:
pass
with lock:
handlers_registered[0] += 1
time.sleep(0.001)
threads = [threading.Thread(target=register_handlers) for _ in range(3)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert handlers_registered[0] == 60
@pytest.mark.asyncio
async def test_mixed_sync_async_handler_execution():
bus = CrewAIEventsBus()
sync_executed = []
async_executed = []
with bus.scoped_handlers():
@bus.on(ShutdownTestEvent)
def sync_handler(source: object, event: BaseEvent) -> None:
time.sleep(0.01)
sync_executed.append(event)
@bus.on(ShutdownTestEvent)
async def async_handler(source: object, event: BaseEvent) -> None:
await asyncio.sleep(0.01)
async_executed.append(event)
for i in range(5):
event = ShutdownTestEvent(type=f"event_{i}")
bus.emit("source", event)
await asyncio.sleep(0.2)
assert len(sync_executed) == 5
assert len(async_executed) == 5

View File

@@ -0,0 +1,189 @@
"""Tests for thread safety in CrewAI event bus.
This module tests concurrent event emission and handler registration.
"""
import threading
import time
from collections.abc import Callable
from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import crewai_event_bus
class ThreadSafetyTestEvent(BaseEvent):
pass
def test_concurrent_emit_from_multiple_threads():
received_events: list[BaseEvent] = []
lock = threading.Lock()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(ThreadSafetyTestEvent)
def handler(source: object, event: BaseEvent) -> None:
with lock:
received_events.append(event)
threads: list[threading.Thread] = []
num_threads = 10
events_per_thread = 10
def emit_events(thread_id: int) -> None:
for i in range(events_per_thread):
event = ThreadSafetyTestEvent(type=f"thread_{thread_id}_event_{i}")
crewai_event_bus.emit(f"source_{thread_id}", event)
for i in range(num_threads):
thread = threading.Thread(target=emit_events, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
time.sleep(0.5)
assert len(received_events) == num_threads * events_per_thread
def test_concurrent_handler_registration():
handlers_executed: list[int] = []
lock = threading.Lock()
def create_handler(handler_id: int) -> Callable[[object, BaseEvent], None]:
def handler(source: object, event: BaseEvent) -> None:
with lock:
handlers_executed.append(handler_id)
return handler
with crewai_event_bus.scoped_handlers():
threads: list[threading.Thread] = []
num_handlers = 20
def register_handler(handler_id: int) -> None:
crewai_event_bus.register_handler(
ThreadSafetyTestEvent, create_handler(handler_id)
)
for i in range(num_handlers):
thread = threading.Thread(target=register_handler, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
event = ThreadSafetyTestEvent(type="registration_test")
crewai_event_bus.emit("test_source", event)
time.sleep(0.5)
assert len(handlers_executed) == num_handlers
assert set(handlers_executed) == set(range(num_handlers))
def test_concurrent_emit_and_registration():
received_events: list[BaseEvent] = []
lock = threading.Lock()
with crewai_event_bus.scoped_handlers():
def emit_continuously() -> None:
for i in range(50):
event = ThreadSafetyTestEvent(type=f"emit_event_{i}")
crewai_event_bus.emit("emitter", event)
time.sleep(0.001)
def register_continuously() -> None:
for _ in range(10):
@crewai_event_bus.on(ThreadSafetyTestEvent)
def handler(source: object, event: BaseEvent) -> None:
with lock:
received_events.append(event)
time.sleep(0.005)
emit_thread = threading.Thread(target=emit_continuously)
register_thread = threading.Thread(target=register_continuously)
emit_thread.start()
register_thread.start()
emit_thread.join()
register_thread.join()
time.sleep(0.5)
assert len(received_events) > 0
def test_stress_test_rapid_emit():
received_count = [0]
lock = threading.Lock()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(ThreadSafetyTestEvent)
def counter_handler(source: object, event: BaseEvent) -> None:
with lock:
received_count[0] += 1
num_events = 1000
for i in range(num_events):
event = ThreadSafetyTestEvent(type=f"rapid_event_{i}")
crewai_event_bus.emit("rapid_source", event)
time.sleep(1.0)
assert received_count[0] == num_events
def test_multiple_event_types_concurrent():
class EventTypeA(BaseEvent):
pass
class EventTypeB(BaseEvent):
pass
received_a: list[BaseEvent] = []
received_b: list[BaseEvent] = []
lock = threading.Lock()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(EventTypeA)
def handler_a(source: object, event: BaseEvent) -> None:
with lock:
received_a.append(event)
@crewai_event_bus.on(EventTypeB)
def handler_b(source: object, event: BaseEvent) -> None:
with lock:
received_b.append(event)
def emit_type_a() -> None:
for i in range(50):
crewai_event_bus.emit("source_a", EventTypeA(type=f"type_a_{i}"))
def emit_type_b() -> None:
for i in range(50):
crewai_event_bus.emit("source_b", EventTypeB(type=f"type_b_{i}"))
thread_a = threading.Thread(target=emit_type_a)
thread_b = threading.Thread(target=emit_type_b)
thread_a.start()
thread_b.start()
thread_a.join()
thread_b.join()
time.sleep(0.5)
assert len(received_a) == 50
assert len(received_b) == 50

View File

@@ -1,3 +1,4 @@
import threading
from datetime import datetime
import os
from unittest.mock import Mock, patch
@@ -49,6 +50,8 @@ from crewai.tools.base_tool import BaseTool
from pydantic import Field
import pytest
from ..utils import wait_for_event_handlers
@pytest.fixture(scope="module")
def vcr_config(request) -> dict:
@@ -118,6 +121,7 @@ def test_crew_emits_start_kickoff_event(
# Now when Crew creates EventListener, it will use our mocked telemetry
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff()
wait_for_event_handlers()
mock_telemetry.crew_execution_span.assert_called_once_with(crew, None)
mock_telemetry.end_crew.assert_called_once_with(crew, "hi")
@@ -131,15 +135,20 @@ def test_crew_emits_start_kickoff_event(
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_end_kickoff_event(base_agent, base_task):
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(CrewKickoffCompletedEvent)
def handle_crew_end(source, event):
received_events.append(event)
event_received.set()
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff()
assert event_received.wait(timeout=5), (
"Timeout waiting for crew kickoff completed event"
)
assert len(received_events) == 1
assert received_events[0].crew_name == "TestCrew"
assert isinstance(received_events[0].timestamp, datetime)
@@ -165,6 +174,7 @@ def test_crew_emits_test_kickoff_type_event(base_agent, base_task):
eval_llm = LLM(model="gpt-4o-mini")
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.test(n_iterations=1, eval_llm=eval_llm)
wait_for_event_handlers()
assert len(received_events) == 3
assert received_events[0].crew_name == "TestCrew"
@@ -181,40 +191,44 @@ def test_crew_emits_test_kickoff_type_event(base_agent, base_task):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_kickoff_failed_event(base_agent, base_task):
received_events = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(CrewKickoffFailedEvent)
def handle_crew_failed(source, event):
received_events.append(event)
event_received.set()
@crewai_event_bus.on(CrewKickoffFailedEvent)
def handle_crew_failed(source, event):
received_events.append(event)
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
with patch.object(Crew, "_execute_tasks") as mock_execute:
error_message = "Simulated crew kickoff failure"
mock_execute.side_effect = Exception(error_message)
with patch.object(Crew, "_execute_tasks") as mock_execute:
error_message = "Simulated crew kickoff failure"
mock_execute.side_effect = Exception(error_message)
with pytest.raises(Exception): # noqa: B017
crew.kickoff()
with pytest.raises(Exception): # noqa: B017
crew.kickoff()
assert len(received_events) == 1
assert received_events[0].error == error_message
assert isinstance(received_events[0].timestamp, datetime)
assert received_events[0].type == "crew_kickoff_failed"
assert event_received.wait(timeout=5), "Timeout waiting for failed event"
assert len(received_events) == 1
assert received_events[0].error == error_message
assert isinstance(received_events[0].timestamp, datetime)
assert received_events[0].type == "crew_kickoff_failed"
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_start_task_event(base_agent, base_task):
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(TaskStartedEvent)
def handle_task_start(source, event):
received_events.append(event)
event_received.set()
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff()
assert event_received.wait(timeout=5), "Timeout waiting for task started event"
assert len(received_events) == 1
assert isinstance(received_events[0].timestamp, datetime)
assert received_events[0].type == "task_started"
@@ -225,10 +239,12 @@ def test_crew_emits_end_task_event(
base_agent, base_task, reset_event_listener_singleton
):
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(TaskCompletedEvent)
def handle_task_end(source, event):
received_events.append(event)
event_received.set()
mock_span = Mock()
@@ -246,6 +262,7 @@ def test_crew_emits_end_task_event(
mock_telemetry.task_started.assert_called_once_with(crew=crew, task=base_task)
mock_telemetry.task_ended.assert_called_once_with(mock_span, base_task, crew)
assert event_received.wait(timeout=5), "Timeout waiting for task completed event"
assert len(received_events) == 1
assert isinstance(received_events[0].timestamp, datetime)
assert received_events[0].type == "task_completed"
@@ -255,11 +272,13 @@ def test_crew_emits_end_task_event(
def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
received_events = []
received_sources = []
event_received = threading.Event()
@crewai_event_bus.on(TaskFailedEvent)
def handle_task_failed(source, event):
received_events.append(event)
received_sources.append(source)
event_received.set()
with patch.object(
Task,
@@ -281,6 +300,9 @@ def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
with pytest.raises(Exception): # noqa: B017
agent.execute_task(task=task)
assert event_received.wait(timeout=5), (
"Timeout waiting for task failed event"
)
assert len(received_events) == 1
assert received_sources[0] == task
assert received_events[0].error == error_message
@@ -291,17 +313,27 @@ def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_emits_execution_started_and_completed_events(base_agent, base_task):
received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
@crewai_event_bus.on(AgentExecutionStartedEvent)
def handle_agent_start(source, event):
received_events.append(event)
with lock:
received_events.append(event)
@crewai_event_bus.on(AgentExecutionCompletedEvent)
def handle_agent_completed(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) >= 2:
all_events_received.set()
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff()
assert all_events_received.wait(timeout=5), (
"Timeout waiting for agent execution events"
)
assert len(received_events) == 2
assert received_events[0].agent == base_agent
assert received_events[0].task == base_task
@@ -320,10 +352,12 @@ def test_agent_emits_execution_started_and_completed_events(base_agent, base_tas
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_emits_execution_error_event(base_agent, base_task):
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(AgentExecutionErrorEvent)
def handle_agent_start(source, event):
received_events.append(event)
event_received.set()
error_message = "Error happening while sending prompt to model."
base_agent.max_retry_limit = 0
@@ -337,6 +371,9 @@ def test_agent_emits_execution_error_event(base_agent, base_task):
task=base_task,
)
assert event_received.wait(timeout=5), (
"Timeout waiting for agent execution error event"
)
assert len(received_events) == 1
assert received_events[0].agent == base_agent
assert received_events[0].task == base_task
@@ -358,10 +395,12 @@ class SayHiTool(BaseTool):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_tools_emits_finished_events():
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(ToolUsageFinishedEvent)
def handle_tool_end(source, event):
received_events.append(event)
event_received.set()
agent = Agent(
role="base_agent",
@@ -377,6 +416,10 @@ def test_tools_emits_finished_events():
)
crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
crew.kickoff()
assert event_received.wait(timeout=5), (
"Timeout waiting for tool usage finished event"
)
assert len(received_events) == 1
assert received_events[0].agent_key == agent.key
assert received_events[0].agent_role == agent.role
@@ -389,10 +432,15 @@ def test_tools_emits_finished_events():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_tools_emits_error_events():
received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
@crewai_event_bus.on(ToolUsageErrorEvent)
def handle_tool_end(source, event):
received_events.append(event)
with lock:
received_events.append(event)
if len(received_events) >= 48:
all_events_received.set()
class ErrorTool(BaseTool):
name: str = Field(
@@ -423,6 +471,9 @@ def test_tools_emits_error_events():
crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
crew.kickoff()
assert all_events_received.wait(timeout=5), (
"Timeout waiting for tool usage error events"
)
assert len(received_events) == 48
assert received_events[0].agent_key == agent.key
assert received_events[0].agent_role == agent.role
@@ -435,11 +486,13 @@ def test_tools_emits_error_events():
def test_flow_emits_start_event(reset_event_listener_singleton):
received_events = []
event_received = threading.Event()
mock_span = Mock()
@crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event):
received_events.append(event)
event_received.set()
class TestFlow(Flow[dict]):
@start()
@@ -458,6 +511,7 @@ def test_flow_emits_start_event(reset_event_listener_singleton):
flow = TestFlow()
flow.kickoff()
assert event_received.wait(timeout=5), "Timeout waiting for flow started event"
mock_telemetry.flow_execution_span.assert_called_once_with("TestFlow", ["begin"])
assert len(received_events) == 1
assert received_events[0].flow_name == "TestFlow"
@@ -466,6 +520,7 @@ def test_flow_emits_start_event(reset_event_listener_singleton):
def test_flow_name_emitted_to_event_bus():
received_events = []
event_received = threading.Event()
class MyFlowClass(Flow):
name = "PRODUCTION_FLOW"
@@ -477,118 +532,133 @@ def test_flow_name_emitted_to_event_bus():
@crewai_event_bus.on(FlowStartedEvent)
def handle_flow_start(source, event):
received_events.append(event)
event_received.set()
flow = MyFlowClass()
flow.kickoff()
assert event_received.wait(timeout=5), "Timeout waiting for flow started event"
assert len(received_events) == 1
assert received_events[0].flow_name == "PRODUCTION_FLOW"
def test_flow_emits_finish_event():
received_events = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_finish(source, event):
received_events.append(event)
event_received.set()
@crewai_event_bus.on(FlowFinishedEvent)
def handle_flow_finish(source, event):
received_events.append(event)
class TestFlow(Flow[dict]):
@start()
def begin(self):
return "completed"
class TestFlow(Flow[dict]):
@start()
def begin(self):
return "completed"
flow = TestFlow()
result = flow.kickoff()
flow = TestFlow()
result = flow.kickoff()
assert len(received_events) == 1
assert received_events[0].flow_name == "TestFlow"
assert received_events[0].type == "flow_finished"
assert received_events[0].result == "completed"
assert result == "completed"
assert event_received.wait(timeout=5), "Timeout waiting for finish event"
assert len(received_events) == 1
assert received_events[0].flow_name == "TestFlow"
assert received_events[0].type == "flow_finished"
assert received_events[0].result == "completed"
assert result == "completed"
def test_flow_emits_method_execution_started_event():
received_events = []
lock = threading.Lock()
second_event_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_start(source, event):
@crewai_event_bus.on(MethodExecutionStartedEvent)
async def handle_method_start(source, event):
with lock:
received_events.append(event)
if event.method_name == "second_method":
second_event_received.set()
class TestFlow(Flow[dict]):
@start()
def begin(self):
return "started"
class TestFlow(Flow[dict]):
@start()
def begin(self):
return "started"
@listen("begin")
def second_method(self):
return "executed"
@listen("begin")
def second_method(self):
return "executed"
flow = TestFlow()
flow.kickoff()
flow = TestFlow()
flow.kickoff()
assert len(received_events) == 2
assert second_event_received.wait(timeout=5), (
"Timeout waiting for second_method event"
)
assert len(received_events) == 2
assert received_events[0].method_name == "begin"
assert received_events[0].flow_name == "TestFlow"
assert received_events[0].type == "method_execution_started"
# Events may arrive in any order due to async handlers, so check both are present
method_names = {event.method_name for event in received_events}
assert method_names == {"begin", "second_method"}
assert received_events[1].method_name == "second_method"
assert received_events[1].flow_name == "TestFlow"
assert received_events[1].type == "method_execution_started"
for event in received_events:
assert event.flow_name == "TestFlow"
assert event.type == "method_execution_started"
@pytest.mark.vcr(filter_headers=["authorization"])
def test_register_handler_adds_new_handler(base_agent, base_task):
received_events = []
event_received = threading.Event()
def custom_handler(source, event):
received_events.append(event)
event_received.set()
with crewai_event_bus.scoped_handlers():
crewai_event_bus.register_handler(CrewKickoffStartedEvent, custom_handler)
crewai_event_bus.register_handler(CrewKickoffStartedEvent, custom_handler)
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff()
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff()
assert len(received_events) == 1
assert isinstance(received_events[0].timestamp, datetime)
assert received_events[0].type == "crew_kickoff_started"
assert event_received.wait(timeout=5), "Timeout waiting for handler event"
assert len(received_events) == 1
assert isinstance(received_events[0].timestamp, datetime)
assert received_events[0].type == "crew_kickoff_started"
@pytest.mark.vcr(filter_headers=["authorization"])
def test_multiple_handlers_for_same_event(base_agent, base_task):
received_events_1 = []
received_events_2 = []
event_received = threading.Event()
def handler_1(source, event):
received_events_1.append(event)
def handler_2(source, event):
received_events_2.append(event)
event_received.set()
with crewai_event_bus.scoped_handlers():
crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_1)
crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_2)
crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_1)
crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_2)
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff()
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
crew.kickoff()
assert len(received_events_1) == 1
assert len(received_events_2) == 1
assert received_events_1[0].type == "crew_kickoff_started"
assert received_events_2[0].type == "crew_kickoff_started"
assert event_received.wait(timeout=5), "Timeout waiting for handler events"
assert len(received_events_1) == 1
assert len(received_events_2) == 1
assert received_events_1[0].type == "crew_kickoff_started"
assert received_events_2[0].type == "crew_kickoff_started"
def test_flow_emits_created_event():
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(FlowCreatedEvent)
def handle_flow_created(source, event):
received_events.append(event)
event_received.set()
class TestFlow(Flow[dict]):
@start()
@@ -598,6 +668,7 @@ def test_flow_emits_created_event():
flow = TestFlow()
flow.kickoff()
assert event_received.wait(timeout=5), "Timeout waiting for flow created event"
assert len(received_events) == 1
assert received_events[0].flow_name == "TestFlow"
assert received_events[0].type == "flow_created"
@@ -605,11 +676,13 @@ def test_flow_emits_created_event():
def test_flow_emits_method_execution_failed_event():
received_events = []
event_received = threading.Event()
error = Exception("Simulated method failure")
@crewai_event_bus.on(MethodExecutionFailedEvent)
def handle_method_failed(source, event):
received_events.append(event)
event_received.set()
class TestFlow(Flow[dict]):
@start()
@@ -620,6 +693,9 @@ def test_flow_emits_method_execution_failed_event():
with pytest.raises(Exception): # noqa: B017
flow.kickoff()
assert event_received.wait(timeout=5), (
"Timeout waiting for method execution failed event"
)
assert len(received_events) == 1
assert received_events[0].method_name == "begin"
assert received_events[0].flow_name == "TestFlow"
@@ -641,6 +717,7 @@ def test_llm_emits_call_started_event():
llm = LLM(model="gpt-4o-mini")
llm.call("Hello, how are you?")
wait_for_event_handlers()
assert len(received_events) == 2
assert received_events[0].type == "llm_call_started"
@@ -656,10 +733,12 @@ def test_llm_emits_call_started_event():
@pytest.mark.isolated
def test_llm_emits_call_failed_event():
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_call_failed(source, event):
received_events.append(event)
event_received.set()
error_message = "OpenAI API call failed: Simulated API failure"
@@ -673,6 +752,7 @@ def test_llm_emits_call_failed_event():
llm.call("Hello, how are you?")
assert str(exc_info.value) == "Simulated API failure"
assert event_received.wait(timeout=5), "Timeout waiting for failed event"
assert len(received_events) == 1
assert received_events[0].type == "llm_call_failed"
assert received_events[0].error == error_message
@@ -686,24 +766,28 @@ def test_llm_emits_call_failed_event():
def test_llm_emits_stream_chunk_events():
"""Test that LLM emits stream chunk events when streaming is enabled."""
received_chunks = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
if len(received_chunks) >= 1:
event_received.set()
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled
llm = LLM(model="gpt-4o", stream=True)
# Create an LLM with streaming enabled
llm = LLM(model="gpt-4o", stream=True)
# Call the LLM with a simple message
response = llm.call("Tell me a short joke")
# Call the LLM with a simple message
response = llm.call("Tell me a short joke")
# Wait for at least one chunk
assert event_received.wait(timeout=5), "Timeout waiting for stream chunks"
# Verify that we received chunks
assert len(received_chunks) > 0
# Verify that we received chunks
assert len(received_chunks) > 0
# Verify that concatenating all chunks equals the final response
assert "".join(received_chunks) == response
# Verify that concatenating all chunks equals the final response
assert "".join(received_chunks) == response
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -711,23 +795,21 @@ def test_llm_no_stream_chunks_when_streaming_disabled():
"""Test that LLM doesn't emit stream chunk events when streaming is disabled."""
received_chunks = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
# Create an LLM with streaming disabled
llm = LLM(model="gpt-4o", stream=False)
# Create an LLM with streaming disabled
llm = LLM(model="gpt-4o", stream=False)
# Call the LLM with a simple message
response = llm.call("Tell me a short joke")
# Call the LLM with a simple message
response = llm.call("Tell me a short joke")
# Verify that we didn't receive any chunks
assert len(received_chunks) == 0
# Verify that we didn't receive any chunks
assert len(received_chunks) == 0
# Verify we got a response
assert response and isinstance(response, str)
# Verify we got a response
assert response and isinstance(response, str)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -735,98 +817,105 @@ def test_streaming_fallback_to_non_streaming():
"""Test that streaming falls back to non-streaming when there's an error."""
received_chunks = []
fallback_called = False
event_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
if len(received_chunks) >= 2:
event_received.set()
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled
llm = LLM(model="gpt-4o", stream=True)
# Create an LLM with streaming enabled
llm = LLM(model="gpt-4o", stream=True)
# Store original methods
original_call = llm.call
# Store original methods
original_call = llm.call
# Create a mock call method that handles the streaming error
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
nonlocal fallback_called
# Emit a couple of chunks to simulate partial streaming
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))
# Create a mock call method that handles the streaming error
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
nonlocal fallback_called
# Emit a couple of chunks to simulate partial streaming
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))
# Mark that fallback would be called
fallback_called = True
# Mark that fallback would be called
fallback_called = True
# Return a response as if fallback succeeded
return "Fallback response after streaming error"
# Return a response as if fallback succeeded
return "Fallback response after streaming error"
# Replace the call method with our mock
llm.call = mock_call
# Replace the call method with our mock
llm.call = mock_call
try:
# Call the LLM
response = llm.call("Tell me a short joke")
wait_for_event_handlers()
try:
# Call the LLM
response = llm.call("Tell me a short joke")
assert event_received.wait(timeout=5), "Timeout waiting for stream chunks"
# Verify that we received some chunks
assert len(received_chunks) == 2
assert received_chunks[0] == "Test chunk 1"
assert received_chunks[1] == "Test chunk 2"
# Verify that we received some chunks
assert len(received_chunks) == 2
assert received_chunks[0] == "Test chunk 1"
assert received_chunks[1] == "Test chunk 2"
# Verify fallback was triggered
assert fallback_called
# Verify fallback was triggered
assert fallback_called
# Verify we got the fallback response
assert response == "Fallback response after streaming error"
# Verify we got the fallback response
assert response == "Fallback response after streaming error"
finally:
# Restore the original method
llm.call = original_call
finally:
# Restore the original method
llm.call = original_call
@pytest.mark.vcr(filter_headers=["authorization"])
def test_streaming_empty_response_handling():
"""Test that streaming handles empty responses correctly."""
received_chunks = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
if len(received_chunks) >= 3:
event_received.set()
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
received_chunks.append(event.chunk)
# Create an LLM with streaming enabled
llm = LLM(model="gpt-3.5-turbo", stream=True)
# Create an LLM with streaming enabled
llm = LLM(model="gpt-3.5-turbo", stream=True)
# Store original methods
original_call = llm.call
# Store original methods
original_call = llm.call
# Create a mock call method that simulates empty chunks
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
# Emit a few empty chunks
for _ in range(3):
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
# Create a mock call method that simulates empty chunks
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
# Emit a few empty chunks
for _ in range(3):
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
# Return the default message for empty responses
return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
# Return the default message for empty responses
return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
# Replace the call method with our mock
llm.call = mock_call
# Replace the call method with our mock
llm.call = mock_call
try:
# Call the LLM - this should handle empty response
response = llm.call("Tell me a short joke")
try:
# Call the LLM - this should handle empty response
response = llm.call("Tell me a short joke")
assert event_received.wait(timeout=5), "Timeout waiting for empty chunks"
# Verify that we received empty chunks
assert len(received_chunks) == 3
assert all(chunk == "" for chunk in received_chunks)
# Verify that we received empty chunks
assert len(received_chunks) == 3
assert all(chunk == "" for chunk in received_chunks)
# Verify the response is the default message for empty responses
assert "I apologize" in response and "couldn't generate" in response
# Verify the response is the default message for empty responses
assert "I apologize" in response and "couldn't generate" in response
finally:
# Restore the original method
llm.call = original_call
finally:
# Restore the original method
llm.call = original_call
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -835,41 +924,49 @@ def test_stream_llm_emits_event_with_task_and_agent_info():
failed_event = []
started_event = []
stream_event = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@crewai_event_bus.on(LLMCallStartedEvent)
def handle_llm_started(source, event):
started_event.append(event)
@crewai_event_bus.on(LLMCallStartedEvent)
def handle_llm_started(source, event):
started_event.append(event)
@crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_completed(source, event):
completed_event.append(event)
if len(started_event) >= 1 and len(stream_event) >= 12:
event_received.set()
@crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_completed(source, event):
completed_event.append(event)
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_llm_stream_chunk(source, event):
stream_event.append(event)
if (
len(completed_event) >= 1
and len(started_event) >= 1
and len(stream_event) >= 12
):
event_received.set()
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_llm_stream_chunk(source, event):
stream_event.append(event)
agent = Agent(
role="TestAgent",
llm=LLM(model="gpt-4o-mini", stream=True),
goal="Just say hi",
backstory="You are a helpful assistant that just says hi",
)
task = Task(
description="Just say hi",
expected_output="hi",
llm=LLM(model="gpt-4o-mini", stream=True),
agent=agent,
)
agent = Agent(
role="TestAgent",
llm=LLM(model="gpt-4o-mini", stream=True),
goal="Just say hi",
backstory="You are a helpful assistant that just says hi",
)
task = Task(
description="Just say hi",
expected_output="hi",
llm=LLM(model="gpt-4o-mini", stream=True),
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
crew.kickoff()
crew = Crew(agents=[agent], tasks=[task])
crew.kickoff()
assert event_received.wait(timeout=10), "Timeout waiting for LLM events"
assert len(completed_event) == 1
assert len(failed_event) == 0
assert len(started_event) == 1
@@ -899,28 +996,30 @@ def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task):
failed_event = []
started_event = []
stream_event = []
event_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@crewai_event_bus.on(LLMCallStartedEvent)
def handle_llm_started(source, event):
started_event.append(event)
@crewai_event_bus.on(LLMCallStartedEvent)
def handle_llm_started(source, event):
started_event.append(event)
@crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_completed(source, event):
completed_event.append(event)
if len(started_event) >= 1:
event_received.set()
@crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_completed(source, event):
completed_event.append(event)
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_llm_stream_chunk(source, event):
stream_event.append(event)
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_llm_stream_chunk(source, event):
stream_event.append(event)
crew = Crew(agents=[base_agent], tasks=[base_task])
crew.kickoff()
crew = Crew(agents=[base_agent], tasks=[base_task])
crew.kickoff()
assert event_received.wait(timeout=10), "Timeout waiting for LLM events"
assert len(completed_event) == 1
assert len(failed_event) == 0
assert len(started_event) == 1
@@ -950,32 +1049,41 @@ def test_llm_emits_event_with_lite_agent():
failed_event = []
started_event = []
stream_event = []
all_events_received = threading.Event()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@crewai_event_bus.on(LLMCallFailedEvent)
def handle_llm_failed(source, event):
failed_event.append(event)
@crewai_event_bus.on(LLMCallStartedEvent)
def handle_llm_started(source, event):
started_event.append(event)
@crewai_event_bus.on(LLMCallStartedEvent)
def handle_llm_started(source, event):
started_event.append(event)
@crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_completed(source, event):
completed_event.append(event)
if len(started_event) >= 1 and len(stream_event) >= 15:
all_events_received.set()
@crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_completed(source, event):
completed_event.append(event)
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_llm_stream_chunk(source, event):
stream_event.append(event)
if (
len(completed_event) >= 1
and len(started_event) >= 1
and len(stream_event) >= 15
):
all_events_received.set()
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_llm_stream_chunk(source, event):
stream_event.append(event)
agent = Agent(
role="Speaker",
llm=LLM(model="gpt-4o-mini", stream=True),
goal="Just say hi",
backstory="You are a helpful assistant that just says hi",
)
agent.kickoff(messages=[{"role": "user", "content": "say hi!"}])
agent = Agent(
role="Speaker",
llm=LLM(model="gpt-4o-mini", stream=True),
goal="Just say hi",
backstory="You are a helpful assistant that just says hi",
)
agent.kickoff(messages=[{"role": "user", "content": "say hi!"}])
assert all_events_received.wait(timeout=10), "Timeout waiting for all events"
assert len(completed_event) == 1
assert len(failed_event) == 0

39
lib/crewai/tests/utils.py Normal file
View File

@@ -0,0 +1,39 @@
"""Test utilities for CrewAI tests."""
import asyncio
from concurrent.futures import ThreadPoolExecutor
def wait_for_event_handlers(timeout: float = 5.0) -> None:
"""Wait for all pending event handlers to complete.
This helper ensures all sync and async handlers finish processing before
proceeding. Useful in tests to make assertions deterministic.
Args:
timeout: Maximum time to wait in seconds.
"""
from crewai.events.event_bus import crewai_event_bus
loop = getattr(crewai_event_bus, "_loop", None)
if loop and not loop.is_closed():
async def _wait_for_async_tasks() -> None:
tasks = {
t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task()
}
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
future = asyncio.run_coroutine_threadsafe(_wait_for_async_tasks(), loop)
try:
future.result(timeout=timeout)
except Exception: # noqa: S110
pass
crewai_event_bus._sync_executor.shutdown(wait=True)
crewai_event_bus._sync_executor = ThreadPoolExecutor(
max_workers=10,
thread_name_prefix="CrewAISyncHandler",
)

11
uv.lock generated
View File

@@ -465,15 +465,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/01/f3/a9d961cfba236dc85f27f2f2c6eab88e12698754aaa02459ba7dfafc5062/bedrock_agentcore-0.1.7-py3-none-any.whl", hash = "sha256:441dde64fea596e9571e47ae37ee3b033e58d8d255018f13bdcde8ae8bef2075", size = 77216, upload-time = "2025-10-01T16:18:38.153Z" },
]
[[package]]
name = "blinker"
version = "1.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/21/28/9b3f50ce0e048515135495f198351908d99540d69bfdc8c1d15b73dc55ce/blinker-1.9.0.tar.gz", hash = "sha256:b4ce2265a7abece45e7cc896e98dbebe6cead56bcf805a3d23136d145f5445bf", size = 22460, upload-time = "2024-11-08T17:25:47.436Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" },
]
[[package]]
name = "boto3"
version = "1.40.45"
@@ -987,7 +978,6 @@ name = "crewai"
source = { editable = "lib/crewai" }
dependencies = [
{ name = "appdirs" },
{ name = "blinker" },
{ name = "chromadb" },
{ name = "click" },
{ name = "instructor" },
@@ -1061,7 +1051,6 @@ watson = [
requires-dist = [
{ name = "aisuite", marker = "extra == 'aisuite'", specifier = ">=0.1.10" },
{ name = "appdirs", specifier = ">=1.4.4" },
{ name = "blinker", specifier = ">=1.9.0" },
{ name = "boto3", marker = "extra == 'aws'", specifier = ">=1.40.38" },
{ name = "boto3", marker = "extra == 'boto3'", specifier = ">=1.40.45" },
{ name = "chromadb", specifier = "~=1.1.0" },