feat: implement thread-safe event bus

This PR also added tests to ensure singleton pattern
This commit is contained in:
Lucas Gomide
2025-07-14 11:20:44 -03:00
parent b6d699f764
commit 7b1ee07b18
2 changed files with 120 additions and 2 deletions

View File

@@ -1,6 +1,6 @@
import threading
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Type, TypeVar, cast
from typing import Any, Callable, Type, TypeVar, cast
from blinker import Signal
@@ -18,6 +18,7 @@ class CrewAIEventsBus:
_instance = None
_lock = threading.Lock()
_thread_local: threading.local = threading.local()
def __new__(cls):
if cls._instance is None:
@@ -30,7 +31,18 @@ class CrewAIEventsBus:
def _initialize(self) -> None:
"""Initialize the event bus internal state"""
self._signal = Signal("crewai_event_bus")
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {}
@property
def _handlers(self) -> dict[type[BaseEvent], list[Callable]]:
if not hasattr(CrewAIEventsBus._thread_local, 'handlers'):
CrewAIEventsBus._thread_local.handlers = {}
return CrewAIEventsBus._thread_local.handlers
@_handlers.setter
def _handlers(self, value: dict[type[BaseEvent], list[Callable]]) -> None:
if not hasattr(CrewAIEventsBus._thread_local, 'handlers'):
CrewAIEventsBus._thread_local.handlers = {}
CrewAIEventsBus._thread_local.handlers = value
def on(
self, event_type: Type[EventT]

View File

@@ -1,13 +1,24 @@
import threading
from unittest.mock import Mock
import pytest
from crewai.utilities.events.base_events import BaseEvent
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
@pytest.fixture(autouse=True)
def scoped_event_handlers():
with crewai_event_bus.scoped_handlers():
yield
class TestEvent(BaseEvent):
pass
class AnotherThreadTestEvent(BaseEvent):
pass
def test_specific_event_handler():
mock_handler = Mock()
@@ -45,3 +56,98 @@ def test_event_bus_error_handling(capfd):
out, err = capfd.readouterr()
assert "Simulated handler failure" in out
assert "Handler 'broken_handler' failed" in out
def test_singleton_pattern_across_threads():
instances = []
def get_instance():
instances.append(crewai_event_bus)
threads = []
for _ in range(10):
thread = threading.Thread(target=get_instance)
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
assert len(instances) == 10
for instance in instances:
assert instance is crewai_event_bus
assert instance is instances[0]
def test_thread_local_handler_isolation():
thread_results = {}
def thread_worker(thread_id):
mock_handler = Mock()
@crewai_event_bus.on(TestEvent)
def thread_handler(source, event):
mock_handler(f"thread_{thread_id}", event)
event = TestEvent(type=f"test_event_thread_{thread_id}")
crewai_event_bus.emit(f"source_{thread_id}", event)
thread_results[thread_id] = {
'mock_handler': mock_handler,
'handler_function': thread_handler,
'event': event
}
threads = []
for i in range(5):
thread = threading.Thread(target=thread_worker, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
assert len(thread_results) == 5
for thread_id, result in thread_results.items():
result['mock_handler'].assert_called_once_with(
f"thread_{thread_id}",
result['event']
)
def test_scoped_handlers_thread_safety():
thread_results = {}
def thread_worker(thread_id):
with crewai_event_bus.scoped_handlers():
mock_handler = Mock()
@crewai_event_bus.on(AnotherThreadTestEvent)
def scoped_handler(source, event):
mock_handler(f"scoped_thread_{thread_id}", event)
scoped_event = AnotherThreadTestEvent(type=f"scoped_event_{thread_id}")
crewai_event_bus.emit(f"scoped_source_{thread_id}", scoped_event)
thread_results[thread_id] = {
'mock_handler': mock_handler,
'scoped_event': scoped_event
}
post_scoped_event = AnotherThreadTestEvent(type=f"post_scoped_{thread_id}")
crewai_event_bus.emit(f"post_source_{thread_id}", post_scoped_event)
threads = []
for i in range(5):
thread = threading.Thread(target=thread_worker, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
for thread_id, result in thread_results.items():
result['mock_handler'].assert_called_once_with(
f"scoped_thread_{thread_id}",
result['scoped_event']
)