mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
feat: implement thread-safe event bus
This PR also added tests to ensure singleton pattern
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import threading
|
import threading
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Dict, List, Type, TypeVar, cast
|
from typing import Any, Callable, Type, TypeVar, cast
|
||||||
|
|
||||||
from blinker import Signal
|
from blinker import Signal
|
||||||
|
|
||||||
@@ -18,6 +18,7 @@ class CrewAIEventsBus:
|
|||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
_lock = threading.Lock()
|
_lock = threading.Lock()
|
||||||
|
_thread_local: threading.local = threading.local()
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
@@ -30,7 +31,18 @@ class CrewAIEventsBus:
|
|||||||
def _initialize(self) -> None:
|
def _initialize(self) -> None:
|
||||||
"""Initialize the event bus internal state"""
|
"""Initialize the event bus internal state"""
|
||||||
self._signal = Signal("crewai_event_bus")
|
self._signal = Signal("crewai_event_bus")
|
||||||
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {}
|
|
||||||
|
@property
|
||||||
|
def _handlers(self) -> dict[type[BaseEvent], list[Callable]]:
|
||||||
|
if not hasattr(CrewAIEventsBus._thread_local, 'handlers'):
|
||||||
|
CrewAIEventsBus._thread_local.handlers = {}
|
||||||
|
return CrewAIEventsBus._thread_local.handlers
|
||||||
|
|
||||||
|
@_handlers.setter
|
||||||
|
def _handlers(self, value: dict[type[BaseEvent], list[Callable]]) -> None:
|
||||||
|
if not hasattr(CrewAIEventsBus._thread_local, 'handlers'):
|
||||||
|
CrewAIEventsBus._thread_local.handlers = {}
|
||||||
|
CrewAIEventsBus._thread_local.handlers = value
|
||||||
|
|
||||||
def on(
|
def on(
|
||||||
self, event_type: Type[EventT]
|
self, event_type: Type[EventT]
|
||||||
|
|||||||
@@ -1,13 +1,24 @@
|
|||||||
|
import threading
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
import pytest
|
||||||
|
|
||||||
from crewai.utilities.events.base_events import BaseEvent
|
from crewai.utilities.events.base_events import BaseEvent
|
||||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def scoped_event_handlers():
|
||||||
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
class TestEvent(BaseEvent):
|
class TestEvent(BaseEvent):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AnotherThreadTestEvent(BaseEvent):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_specific_event_handler():
|
def test_specific_event_handler():
|
||||||
mock_handler = Mock()
|
mock_handler = Mock()
|
||||||
|
|
||||||
@@ -45,3 +56,98 @@ def test_event_bus_error_handling(capfd):
|
|||||||
out, err = capfd.readouterr()
|
out, err = capfd.readouterr()
|
||||||
assert "Simulated handler failure" in out
|
assert "Simulated handler failure" in out
|
||||||
assert "Handler 'broken_handler' failed" in out
|
assert "Handler 'broken_handler' failed" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_singleton_pattern_across_threads():
|
||||||
|
instances = []
|
||||||
|
|
||||||
|
def get_instance():
|
||||||
|
instances.append(crewai_event_bus)
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for _ in range(10):
|
||||||
|
thread = threading.Thread(target=get_instance)
|
||||||
|
threads.append(thread)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
assert len(instances) == 10
|
||||||
|
for instance in instances:
|
||||||
|
assert instance is crewai_event_bus
|
||||||
|
assert instance is instances[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_thread_local_handler_isolation():
|
||||||
|
thread_results = {}
|
||||||
|
|
||||||
|
def thread_worker(thread_id):
|
||||||
|
mock_handler = Mock()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(TestEvent)
|
||||||
|
def thread_handler(source, event):
|
||||||
|
mock_handler(f"thread_{thread_id}", event)
|
||||||
|
|
||||||
|
event = TestEvent(type=f"test_event_thread_{thread_id}")
|
||||||
|
crewai_event_bus.emit(f"source_{thread_id}", event)
|
||||||
|
|
||||||
|
thread_results[thread_id] = {
|
||||||
|
'mock_handler': mock_handler,
|
||||||
|
'handler_function': thread_handler,
|
||||||
|
'event': event
|
||||||
|
}
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for i in range(5):
|
||||||
|
thread = threading.Thread(target=thread_worker, args=(i,))
|
||||||
|
threads.append(thread)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
assert len(thread_results) == 5
|
||||||
|
|
||||||
|
for thread_id, result in thread_results.items():
|
||||||
|
result['mock_handler'].assert_called_once_with(
|
||||||
|
f"thread_{thread_id}",
|
||||||
|
result['event']
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scoped_handlers_thread_safety():
|
||||||
|
thread_results = {}
|
||||||
|
|
||||||
|
def thread_worker(thread_id):
|
||||||
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
mock_handler = Mock()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(AnotherThreadTestEvent)
|
||||||
|
def scoped_handler(source, event):
|
||||||
|
mock_handler(f"scoped_thread_{thread_id}", event)
|
||||||
|
|
||||||
|
scoped_event = AnotherThreadTestEvent(type=f"scoped_event_{thread_id}")
|
||||||
|
crewai_event_bus.emit(f"scoped_source_{thread_id}", scoped_event)
|
||||||
|
|
||||||
|
thread_results[thread_id] = {
|
||||||
|
'mock_handler': mock_handler,
|
||||||
|
'scoped_event': scoped_event
|
||||||
|
}
|
||||||
|
|
||||||
|
post_scoped_event = AnotherThreadTestEvent(type=f"post_scoped_{thread_id}")
|
||||||
|
crewai_event_bus.emit(f"post_source_{thread_id}", post_scoped_event)
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for i in range(5):
|
||||||
|
thread = threading.Thread(target=thread_worker, args=(i,))
|
||||||
|
threads.append(thread)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
for thread_id, result in thread_results.items():
|
||||||
|
result['mock_handler'].assert_called_once_with(
|
||||||
|
f"scoped_thread_{thread_id}",
|
||||||
|
result['scoped_event']
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user