"""Tests for thread safety in CrewAI event bus. This module tests concurrent event emission and handler registration. """ import threading import time from collections.abc import Callable from crewai.events.base_events import BaseEvent from crewai.events.event_bus import crewai_event_bus class ThreadSafetyTestEvent(BaseEvent): pass def test_concurrent_emit_from_multiple_threads(): received_events: list[BaseEvent] = [] lock = threading.Lock() with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(ThreadSafetyTestEvent) def handler(source: object, event: BaseEvent) -> None: with lock: received_events.append(event) threads: list[threading.Thread] = [] num_threads = 10 events_per_thread = 10 def emit_events(thread_id: int) -> None: for i in range(events_per_thread): event = ThreadSafetyTestEvent(type=f"thread_{thread_id}_event_{i}") crewai_event_bus.emit(f"source_{thread_id}", event) for i in range(num_threads): thread = threading.Thread(target=emit_events, args=(i,)) threads.append(thread) thread.start() for thread in threads: thread.join() time.sleep(0.5) assert len(received_events) == num_threads * events_per_thread def test_concurrent_handler_registration(): handlers_executed: list[int] = [] lock = threading.Lock() def create_handler(handler_id: int) -> Callable[[object, BaseEvent], None]: def handler(source: object, event: BaseEvent) -> None: with lock: handlers_executed.append(handler_id) return handler with crewai_event_bus.scoped_handlers(): threads: list[threading.Thread] = [] num_handlers = 20 def register_handler(handler_id: int) -> None: crewai_event_bus.register_handler( ThreadSafetyTestEvent, create_handler(handler_id) ) for i in range(num_handlers): thread = threading.Thread(target=register_handler, args=(i,)) threads.append(thread) thread.start() for thread in threads: thread.join() event = ThreadSafetyTestEvent(type="registration_test") crewai_event_bus.emit("test_source", event) time.sleep(0.5) assert len(handlers_executed) == num_handlers assert set(handlers_executed) == set(range(num_handlers)) def test_concurrent_emit_and_registration(): received_events: list[BaseEvent] = [] lock = threading.Lock() with crewai_event_bus.scoped_handlers(): def emit_continuously() -> None: for i in range(50): event = ThreadSafetyTestEvent(type=f"emit_event_{i}") crewai_event_bus.emit("emitter", event) time.sleep(0.001) def register_continuously() -> None: for _ in range(10): @crewai_event_bus.on(ThreadSafetyTestEvent) def handler(source: object, event: BaseEvent) -> None: with lock: received_events.append(event) time.sleep(0.005) emit_thread = threading.Thread(target=emit_continuously) register_thread = threading.Thread(target=register_continuously) emit_thread.start() register_thread.start() emit_thread.join() register_thread.join() time.sleep(0.5) assert len(received_events) > 0 def test_stress_test_rapid_emit(): received_count = [0] lock = threading.Lock() with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(ThreadSafetyTestEvent) def counter_handler(source: object, event: BaseEvent) -> None: with lock: received_count[0] += 1 num_events = 1000 for i in range(num_events): event = ThreadSafetyTestEvent(type=f"rapid_event_{i}") crewai_event_bus.emit("rapid_source", event) time.sleep(1.0) assert received_count[0] == num_events def test_multiple_event_types_concurrent(): class EventTypeA(BaseEvent): pass class EventTypeB(BaseEvent): pass received_a: list[BaseEvent] = [] received_b: list[BaseEvent] = [] lock = threading.Lock() with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(EventTypeA) def handler_a(source: object, event: BaseEvent) -> None: with lock: received_a.append(event) @crewai_event_bus.on(EventTypeB) def handler_b(source: object, event: BaseEvent) -> None: with lock: received_b.append(event) def emit_type_a() -> None: for i in range(50): crewai_event_bus.emit("source_a", EventTypeA(type=f"type_a_{i}")) def emit_type_b() -> None: for i in range(50): crewai_event_bus.emit("source_b", EventTypeB(type=f"type_b_{i}")) thread_a = threading.Thread(target=emit_type_a) thread_b = threading.Thread(target=emit_type_b) thread_a.start() thread_b.start() thread_a.join() thread_b.join() time.sleep(0.5) assert len(received_a) == 50 assert len(received_b) == 50