Compare commits

...

2 Commits

Author SHA1 Message Date
Lucas Gomide
eaf550521c global handlers(?) WIP 2025-07-15 13:47:53 -03:00
Lucas Gomide
7b1ee07b18 feat: implement thread-safe event bus
This PR also added tests to ensure singleton pattern
2025-07-14 11:42:31 -03:00
2 changed files with 604 additions and 19 deletions

View File

@@ -1,6 +1,6 @@
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Type, TypeVar, cast from typing import Any, Callable, Type, TypeVar, cast
from blinker import Signal from blinker import Signal
@@ -14,10 +14,13 @@ class CrewAIEventsBus:
""" """
A singleton event bus that uses blinker signals for event handling. A singleton event bus that uses blinker signals for event handling.
Allows both internal (Flow/Crew) and external event handling. Allows both internal (Flow/Crew) and external event handling.
Handlers are global by default for cross-thread communication,
with optional thread-local isolation for testing scenarios.
""" """
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
_thread_local: threading.local = threading.local()
def __new__(cls): def __new__(cls):
if cls._instance is None: if cls._instance is None:
@@ -30,7 +33,46 @@ class CrewAIEventsBus:
def _initialize(self) -> None: def _initialize(self) -> None:
"""Initialize the event bus internal state""" """Initialize the event bus internal state"""
self._signal = Signal("crewai_event_bus") self._signal = Signal("crewai_event_bus")
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {} self._global_handlers: dict[type[BaseEvent], list[Callable]] = {}
@property
def _handlers(self) -> dict[type[BaseEvent], list[Callable]]:
if not hasattr(CrewAIEventsBus._thread_local, "handlers"):
CrewAIEventsBus._thread_local.handlers = {}
return CrewAIEventsBus._thread_local.handlers
@_handlers.setter
def _handlers(self, value: dict[type[BaseEvent], list[Callable]]) -> None:
if not hasattr(CrewAIEventsBus._thread_local, "handlers"):
CrewAIEventsBus._thread_local.handlers = {}
CrewAIEventsBus._thread_local.handlers = value
def _add_handler_with_deduplication(
self, handlers_dict: dict, event_type: Type[BaseEvent], handler: Callable
) -> bool:
"""
Add a handler to the specified handlers dictionary with deduplication.
Args:
handlers_dict: The dictionary to add the handler to
event_type: The event type
handler: The handler function to add
Returns:
bool: True if handler was added, False if it was already present
"""
if event_type not in handlers_dict:
handlers_dict[event_type] = []
# Check if handler is already registered
for existing_handler in handlers_dict[event_type]:
if existing_handler is handler:
# Handler already exists, don't add duplicate
return False
# Add the handler
handlers_dict[event_type].append(handler)
return True
def on( def on(
self, event_type: Type[EventT] self, event_type: Type[EventT]
@@ -38,6 +80,13 @@ class CrewAIEventsBus:
""" """
Decorator to register an event handler for a specific event type. Decorator to register an event handler for a specific event type.
Handlers registered with this decorator are global by default,
allowing cross-thread event communication. Use scoped_handlers()
for thread-local isolation in testing scenarios.
Duplicate handlers are automatically prevented - the same handler
function will only be registered once per event type.
Usage: Usage:
@crewai_event_bus.on(AgentExecutionCompletedEvent) @crewai_event_bus.on(AgentExecutionCompletedEvent)
def on_agent_execution_completed( def on_agent_execution_completed(
@@ -50,23 +99,38 @@ class CrewAIEventsBus:
def decorator( def decorator(
handler: Callable[[Any, EventT], None], handler: Callable[[Any, EventT], None],
) -> Callable[[Any, EventT], None]: ) -> Callable[[Any, EventT], None]:
if event_type not in self._handlers: was_added = self._add_handler_with_deduplication(
self._handlers[event_type] = [] self._global_handlers, event_type, handler
self._handlers[event_type].append(
cast(Callable[[Any, EventT], None], handler)
) )
if not was_added:
# Log that duplicate was prevented (optional)
print(
f"[EventBus Info] Handler '{handler.__name__}' already registered for {event_type.__name__}"
)
return handler return handler
return decorator return decorator
def emit(self, source: Any, event: BaseEvent) -> None: def emit(self, source: Any, event: BaseEvent) -> None:
""" """
Emit an event to all registered handlers Emit an event to all registered handlers (both global and thread-local)
Args: Args:
source: The object emitting the event source: The object emitting the event
event: The event instance to emit event: The event instance to emit
""" """
# Call global handlers (default behavior, cross-thread)
for event_type, handlers in self._global_handlers.items():
if isinstance(event, event_type):
for handler in handlers:
try:
handler(source, event)
except Exception as e:
print(
f"[EventBus Error] Global handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}"
)
# Call thread-local handlers (for testing isolation)
for event_type, handlers in self._handlers.items(): for event_type, handlers in self._handlers.items():
if isinstance(event, event_type): if isinstance(event, event_type):
for handler in handlers: for handler in handlers:
@@ -74,32 +138,76 @@ class CrewAIEventsBus:
handler(source, event) handler(source, event)
except Exception as e: except Exception as e:
print( print(
f"[EventBus Error] Handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}" f"[EventBus Error] Thread-local handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}"
) )
# Send to blinker signal (existing mechanism)
self._signal.send(source, event=event) self._signal.send(source, event=event)
def register_handler( def register_handler(
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None] self, event_type: Type[BaseEvent], handler: Callable[[Any, BaseEvent], None]
) -> None: ) -> bool:
"""Register an event handler for a specific event type""" """
if event_type not in self._handlers: Register an event handler for a specific event type (global)
self._handlers[event_type] = []
self._handlers[event_type].append( Args:
cast(Callable[[Any, EventTypes], None], handler) event_type: The event type to handle
handler: The handler function to register
Returns:
bool: True if handler was added, False if it was already present
"""
return self._add_handler_with_deduplication(
self._global_handlers, event_type, handler
) )
def unregister_handler(
self, event_type: Type[BaseEvent], handler: Callable[[Any, BaseEvent], None]
) -> bool:
"""
Unregister an event handler for a specific event type (global)
Args:
event_type: The event type
handler: The handler function to unregister
Returns:
bool: True if handler was removed, False if it wasn't found
"""
if event_type in self._global_handlers:
try:
self._global_handlers[event_type].remove(handler)
return True
except ValueError:
return False
return False
def get_handler_count(self, event_type: Type[BaseEvent]) -> int:
"""
Get the number of handlers registered for a specific event type
Args:
event_type: The event type to check
Returns:
int: Number of handlers registered for this event type
"""
return len(self._global_handlers.get(event_type, []))
@contextmanager @contextmanager
def scoped_handlers(self): def scoped_handlers(self):
""" """
Context manager for temporary event handling scope. Context manager for temporary thread-local event handling scope.
Useful for testing or temporary event handling. Useful for testing or temporary event handling with thread isolation.
This creates thread-local handlers that are isolated from global handlers,
making it useful for testing scenarios where you want to avoid interference.
Usage: Usage:
with crewai_event_bus.scoped_handlers(): with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(CrewKickoffStarted) @crewai_event_bus.on(CrewKickoffStarted)
def temp_handler(source, event): def temp_handler(source, event):
print("Temporary handler") print("Temporary thread-local handler")
# Do stuff... # Do stuff...
# Handlers are cleared after the context # Handlers are cleared after the context
""" """
@@ -110,6 +218,25 @@ class CrewAIEventsBus:
finally: finally:
self._handlers = previous_handlers self._handlers = previous_handlers
@contextmanager
def scoped_global_handlers(self):
"""
Context manager for temporary global event handling scope.
Useful for testing or temporary global event handling.
Usage:
with crewai_event_bus.scoped_global_handlers():
crewai_event_bus.register_handler(CrewKickoffStarted, temp_handler)
# Do stuff...
# Global handlers are cleared after the context
"""
previous_global_handlers = self._global_handlers.copy()
self._global_handlers.clear()
try:
yield
finally:
self._global_handlers = previous_global_handlers
# Global instance # Global instance
crewai_event_bus = CrewAIEventsBus() crewai_event_bus = CrewAIEventsBus()

View File

@@ -1,13 +1,31 @@
import threading
from typing import Any, Callable, cast
from unittest.mock import Mock from unittest.mock import Mock
import pytest
from crewai.utilities.events.base_events import BaseEvent from crewai.utilities.events.base_events import BaseEvent
from crewai.utilities.events.crewai_event_bus import crewai_event_bus from crewai.utilities.events.crewai_event_bus import crewai_event_bus
@pytest.fixture(autouse=True)
def scoped_event_handlers():
with crewai_event_bus.scoped_handlers():
yield
class TestEvent(BaseEvent): class TestEvent(BaseEvent):
pass pass
class AnotherThreadTestEvent(BaseEvent):
pass
class CrossThreadTestEvent(BaseEvent):
pass
def test_specific_event_handler(): def test_specific_event_handler():
mock_handler = Mock() mock_handler = Mock()
@@ -44,4 +62,444 @@ def test_event_bus_error_handling(capfd):
out, err = capfd.readouterr() out, err = capfd.readouterr()
assert "Simulated handler failure" in out assert "Simulated handler failure" in out
assert "Handler 'broken_handler' failed" in out assert "Global handler 'broken_handler' failed" in out
def test_singleton_pattern_across_threads():
instances = []
def get_instance():
instances.append(crewai_event_bus)
threads = []
for _ in range(10):
thread = threading.Thread(target=get_instance)
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
assert len(instances) == 10
for instance in instances:
assert instance is crewai_event_bus
assert instance is instances[0]
def test_default_handlers_are_global():
"""Test that handlers registered with @crewai_event_bus.on() are global by default."""
received_events = []
mock_handler = Mock()
@crewai_event_bus.on(CrossThreadTestEvent)
def global_handler(source, event):
received_events.append((source, event))
mock_handler(source, event)
def thread_worker(thread_id):
# Emit event from a different thread
event = CrossThreadTestEvent(type=f"cross_thread_event_{thread_id}")
crewai_event_bus.emit(f"thread_source_{thread_id}", event)
# Start multiple threads that emit events
threads = []
for i in range(3):
thread = threading.Thread(target=thread_worker, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# Verify that the global handler received all events from different threads
assert len(received_events) == 3
assert mock_handler.call_count == 3
# Check that events from different threads were received
for i in range(3):
source, event = received_events[i]
assert source == f"thread_source_{i}"
assert event.type == f"cross_thread_event_{i}"
def test_scoped_handlers_thread_isolation():
"""Test that scoped_handlers() provides thread-local isolation for testing."""
global_events = []
scoped_events = []
# Register a global handler
@crewai_event_bus.on(CrossThreadTestEvent)
def global_handler(source, event):
global_events.append((source, event))
# Emit an event - should be received by global handler
event1 = CrossThreadTestEvent(type="event_1")
crewai_event_bus.emit("source_1", event1)
assert len(global_events) == 1
# Use scoped handlers for testing isolation
with crewai_event_bus.scoped_handlers():
# Register a handler in the scoped context (thread-local)
@crewai_event_bus.on(CrossThreadTestEvent)
def scoped_handler(source, event):
scoped_events.append((source, event))
# Emit event - should be received by scoped handler only
event2 = CrossThreadTestEvent(type="event_2")
crewai_event_bus.emit("source_2", event2)
# After scope, emit another event - should be received by global handler only
event3 = CrossThreadTestEvent(type="event_3")
crewai_event_bus.emit("source_3", event3)
# Verify events
assert len(global_events) == 2 # event_1 and event_3
assert len(scoped_events) == 1 # only event_2
assert global_events[0] == ("source_1", event1)
assert scoped_events[0] == ("source_2", event2)
assert global_events[1] == ("source_3", event3)
def test_scoped_handlers_thread_safety():
"""Test that scoped handlers work correctly across multiple threads."""
thread_results = {}
def thread_worker(thread_id):
with crewai_event_bus.scoped_handlers():
mock_handler = Mock()
@crewai_event_bus.on(AnotherThreadTestEvent)
def scoped_handler(source, event):
mock_handler(f"scoped_thread_{thread_id}", event)
scoped_event = AnotherThreadTestEvent(type=f"scoped_event_{thread_id}")
crewai_event_bus.emit(f"scoped_source_{thread_id}", scoped_event)
thread_results[thread_id] = {
"mock_handler": mock_handler,
"scoped_event": scoped_event,
}
# After scope, emit event - should not be received by scoped handler
post_scoped_event = AnotherThreadTestEvent(type=f"post_scoped_{thread_id}")
crewai_event_bus.emit(f"post_source_{thread_id}", post_scoped_event)
threads = []
for i in range(5):
thread = threading.Thread(target=thread_worker, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
for thread_id, result in thread_results.items():
result["mock_handler"].assert_called_once_with(
f"scoped_thread_{thread_id}", result["scoped_event"]
)
def test_register_handler_method():
"""Test the register_handler method works with global handlers."""
received_events = []
def handler(source, event):
received_events.append((source, event))
# Register handler using the method
crewai_event_bus.register_handler(CrossThreadTestEvent, handler)
# Emit event from different thread
def thread_worker():
event = CrossThreadTestEvent(type="test_event")
crewai_event_bus.emit("thread_source", event)
thread = threading.Thread(target=thread_worker)
thread.start()
thread.join()
# Verify handler received the event
assert len(received_events) == 1
assert received_events[0] == (
"thread_source",
CrossThreadTestEvent(type="test_event"),
)
def test_scoped_global_handlers():
"""Test the scoped_global_handlers context manager."""
global_events = []
def global_handler(source, event):
global_events.append((source, event))
# Register a global handler
crewai_event_bus.register_handler(CrossThreadTestEvent, global_handler)
# Emit an event - should be received
event1 = CrossThreadTestEvent(type="event_1")
crewai_event_bus.emit("source_1", event1)
assert len(global_events) == 1
# Use scoped global handlers
with crewai_event_bus.scoped_global_handlers():
# Register a different handler in scope
def scoped_handler(source, event):
global_events.append(("scoped", source, event))
crewai_event_bus.register_handler(CrossThreadTestEvent, scoped_handler)
# Emit event - should be received by scoped handler
event2 = CrossThreadTestEvent(type="event_2")
crewai_event_bus.emit("source_2", event2)
# After scope, original handler should be restored
event3 = CrossThreadTestEvent(type="event_3")
crewai_event_bus.emit("source_3", event3)
# Verify events
assert len(global_events) == 3
assert global_events[0] == ("source_1", event1)
assert global_events[1] == ("scoped", "source_2", event2)
assert global_events[2] == ("source_3", event3)
def test_handler_duplication_scenarios():
"""Test various scenarios where handler duplication can occur."""
call_counts = []
def handler(source, event):
call_counts.append(1)
# Scenario 1: Register the same handler multiple times
crewai_event_bus.register_handler(TestEvent, handler)
crewai_event_bus.register_handler(TestEvent, handler) # Duplicate registration
# Scenario 2: Use decorator multiple times on the same function
@crewai_event_bus.on(TestEvent)
def decorated_handler1(source, event):
call_counts.append(1)
@crewai_event_bus.on(TestEvent)
def decorated_handler2(source, event): # Same function name, different instance
call_counts.append(1)
# Emit an event
event = TestEvent(type="test_event")
crewai_event_bus.emit("source", event)
# Currently, all handlers are called (including duplicates)
# This shows the current behavior - handlers can be duplicated
assert len(call_counts) >= 4 # At least 4 calls (2 direct + 2 decorated)
def test_module_reload_duplication():
"""Test duplication that could occur from module reloading."""
call_counts = []
def create_handler():
def handler(source, event):
call_counts.append(1)
return handler
# Simulate module reload scenario
handler1 = create_handler()
handler2 = create_handler() # Same function, different instance
crewai_event_bus.register_handler(TestEvent, handler1)
crewai_event_bus.register_handler(TestEvent, handler2)
event = TestEvent(type="test_event")
crewai_event_bus.emit("source", event)
# Both handlers are called (duplication)
assert len(call_counts) == 2
def test_listener_class_duplication():
"""Test duplication from multiple listener class instances."""
call_counts = []
class TestListener:
def __init__(self):
@crewai_event_bus.on(TestEvent)
def handler(source, event):
call_counts.append(1)
# Create multiple instances (simulating multiple imports)
listener1 = TestListener()
listener2 = TestListener()
event = TestEvent(type="test_event")
crewai_event_bus.emit("source", event)
# Both instances register handlers (duplication)
assert len(call_counts) == 2
def test_handler_deduplication():
"""Test that duplicate handlers are automatically prevented."""
call_counts = []
def handler(source, event):
call_counts.append(1)
# Register the same handler multiple times
result1 = crewai_event_bus.register_handler(TestEvent, handler)
result2 = crewai_event_bus.register_handler(
TestEvent, handler
) # Duplicate registration
# First registration should succeed, second should fail
assert result1 is True
assert result2 is False
# Emit an event
event = TestEvent(type="test_event")
crewai_event_bus.emit("source", event)
# Handler should only be called once (no duplication)
assert len(call_counts) == 1
def test_decorator_deduplication():
"""Test that decorator prevents duplicate registrations."""
call_counts = []
# Define the same handler function
def handler(source, event):
call_counts.append(1)
# Register using decorator
@crewai_event_bus.on(TestEvent)
def decorated_handler(source, event):
call_counts.append(1)
# Try to register the same function again using register_handler
result = crewai_event_bus.register_handler(
TestEvent, cast(Callable[[Any, BaseEvent], None], decorated_handler)
)
# Should fail because it's already registered
assert result is False
# Emit an event
event = TestEvent(type="test_event")
crewai_event_bus.emit("source", event)
# Should only be called once
assert len(call_counts) == 1
def test_handler_unregistration():
"""Test that handlers can be unregistered."""
call_counts = []
def handler(source, event):
call_counts.append(1)
# Register handler
crewai_event_bus.register_handler(TestEvent, handler)
# Verify it's registered
assert crewai_event_bus.get_handler_count(TestEvent) == 1
# Emit event - should be called
event = TestEvent(type="test_event")
crewai_event_bus.emit("source", event)
assert len(call_counts) == 1
# Unregister handler
result = crewai_event_bus.unregister_handler(TestEvent, handler)
assert result is True
assert crewai_event_bus.get_handler_count(TestEvent) == 0
# Emit event again - should not be called
crewai_event_bus.emit("source", event)
assert len(call_counts) == 1 # Still only 1 call
def test_handler_count_tracking():
"""Test that handler counts are tracked correctly."""
def handler1(source, event):
pass
def handler2(source, event):
pass
# Initially no handlers
assert crewai_event_bus.get_handler_count(TestEvent) == 0
# Register first handler
crewai_event_bus.register_handler(TestEvent, handler1)
assert crewai_event_bus.get_handler_count(TestEvent) == 1
# Register second handler
crewai_event_bus.register_handler(TestEvent, handler2)
assert crewai_event_bus.get_handler_count(TestEvent) == 2
# Try to register first handler again (should fail)
crewai_event_bus.register_handler(TestEvent, handler1)
assert crewai_event_bus.get_handler_count(TestEvent) == 2 # Count unchanged
# Unregister first handler
crewai_event_bus.unregister_handler(TestEvent, handler1)
assert crewai_event_bus.get_handler_count(TestEvent) == 1
# Unregister second handler
crewai_event_bus.unregister_handler(TestEvent, handler2)
assert crewai_event_bus.get_handler_count(TestEvent) == 0
def test_different_event_types_dont_conflict():
"""Test that handlers for different event types don't interfere."""
test_event_calls = []
cross_thread_calls = []
def test_event_handler(source, event):
test_event_calls.append(1)
def cross_thread_handler(source, event):
cross_thread_calls.append(1)
# Register handlers for different event types
crewai_event_bus.register_handler(TestEvent, test_event_handler)
crewai_event_bus.register_handler(CrossThreadTestEvent, cross_thread_handler)
# Emit TestEvent
test_event = TestEvent(type="test")
crewai_event_bus.emit("source", test_event)
assert len(test_event_calls) == 1
assert len(cross_thread_calls) == 0
# Emit CrossThreadTestEvent
cross_thread_event = CrossThreadTestEvent(type="cross_thread")
crewai_event_bus.emit("source", cross_thread_event)
assert len(test_event_calls) == 1 # Unchanged
assert len(cross_thread_calls) == 1
def test_scoped_handlers_with_deduplication():
"""Test that deduplication works within scoped handlers."""
call_counts = []
def handler(source, event):
call_counts.append(1)
# Register global handler
crewai_event_bus.register_handler(TestEvent, handler)
# Use scoped handlers
with crewai_event_bus.scoped_handlers():
# Try to register the same handler in scoped context
@crewai_event_bus.on(TestEvent)
def scoped_handler(source, event):
call_counts.append(1)
# Emit event - should be called by both global and scoped handlers
event = TestEvent(type="test_event")
crewai_event_bus.emit("source", event)
# Should have 2 calls (1 global + 1 scoped)
assert len(call_counts) == 2