feat: implement thread-safe event bus

This PR also added tests to ensure singleton pattern
This commit is contained in:
Lucas Gomide
2025-07-14 11:20:44 -03:00
parent b6d699f764
commit 7b1ee07b18
2 changed files with 120 additions and 2 deletions

View File

@@ -1,6 +1,6 @@
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Type, TypeVar, cast from typing import Any, Callable, Type, TypeVar, cast
from blinker import Signal from blinker import Signal
@@ -18,6 +18,7 @@ class CrewAIEventsBus:
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
_thread_local: threading.local = threading.local()
def __new__(cls): def __new__(cls):
if cls._instance is None: if cls._instance is None:
@@ -30,7 +31,18 @@ class CrewAIEventsBus:
def _initialize(self) -> None: def _initialize(self) -> None:
"""Initialize the event bus internal state""" """Initialize the event bus internal state"""
self._signal = Signal("crewai_event_bus") self._signal = Signal("crewai_event_bus")
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {}
@property
def _handlers(self) -> dict[type[BaseEvent], list[Callable]]:
if not hasattr(CrewAIEventsBus._thread_local, 'handlers'):
CrewAIEventsBus._thread_local.handlers = {}
return CrewAIEventsBus._thread_local.handlers
@_handlers.setter
def _handlers(self, value: dict[type[BaseEvent], list[Callable]]) -> None:
if not hasattr(CrewAIEventsBus._thread_local, 'handlers'):
CrewAIEventsBus._thread_local.handlers = {}
CrewAIEventsBus._thread_local.handlers = value
def on( def on(
self, event_type: Type[EventT] self, event_type: Type[EventT]

View File

@@ -1,13 +1,24 @@
import threading
from unittest.mock import Mock from unittest.mock import Mock
import pytest
from crewai.utilities.events.base_events import BaseEvent from crewai.utilities.events.base_events import BaseEvent
from crewai.utilities.events.crewai_event_bus import crewai_event_bus from crewai.utilities.events.crewai_event_bus import crewai_event_bus
@pytest.fixture(autouse=True)
def scoped_event_handlers():
with crewai_event_bus.scoped_handlers():
yield
class TestEvent(BaseEvent): class TestEvent(BaseEvent):
pass pass
class AnotherThreadTestEvent(BaseEvent):
pass
def test_specific_event_handler(): def test_specific_event_handler():
mock_handler = Mock() mock_handler = Mock()
@@ -45,3 +56,98 @@ def test_event_bus_error_handling(capfd):
out, err = capfd.readouterr() out, err = capfd.readouterr()
assert "Simulated handler failure" in out assert "Simulated handler failure" in out
assert "Handler 'broken_handler' failed" in out assert "Handler 'broken_handler' failed" in out
def test_singleton_pattern_across_threads():
instances = []
def get_instance():
instances.append(crewai_event_bus)
threads = []
for _ in range(10):
thread = threading.Thread(target=get_instance)
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
assert len(instances) == 10
for instance in instances:
assert instance is crewai_event_bus
assert instance is instances[0]
def test_thread_local_handler_isolation():
thread_results = {}
def thread_worker(thread_id):
mock_handler = Mock()
@crewai_event_bus.on(TestEvent)
def thread_handler(source, event):
mock_handler(f"thread_{thread_id}", event)
event = TestEvent(type=f"test_event_thread_{thread_id}")
crewai_event_bus.emit(f"source_{thread_id}", event)
thread_results[thread_id] = {
'mock_handler': mock_handler,
'handler_function': thread_handler,
'event': event
}
threads = []
for i in range(5):
thread = threading.Thread(target=thread_worker, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
assert len(thread_results) == 5
for thread_id, result in thread_results.items():
result['mock_handler'].assert_called_once_with(
f"thread_{thread_id}",
result['event']
)
def test_scoped_handlers_thread_safety():
thread_results = {}
def thread_worker(thread_id):
with crewai_event_bus.scoped_handlers():
mock_handler = Mock()
@crewai_event_bus.on(AnotherThreadTestEvent)
def scoped_handler(source, event):
mock_handler(f"scoped_thread_{thread_id}", event)
scoped_event = AnotherThreadTestEvent(type=f"scoped_event_{thread_id}")
crewai_event_bus.emit(f"scoped_source_{thread_id}", scoped_event)
thread_results[thread_id] = {
'mock_handler': mock_handler,
'scoped_event': scoped_event
}
post_scoped_event = AnotherThreadTestEvent(type=f"post_scoped_{thread_id}")
crewai_event_bus.emit(f"post_source_{thread_id}", post_scoped_event)
threads = []
for i in range(5):
thread = threading.Thread(target=thread_worker, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
for thread_id, result in thread_results.items():
result['mock_handler'].assert_called_once_with(
f"scoped_thread_{thread_id}",
result['scoped_event']
)