mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-25 08:08:14 +00:00
Compare commits
7 Commits
devin/1768
...
devin/1749
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1258c433d | ||
|
|
8607719841 | ||
|
|
dbd6890816 | ||
|
|
db6940b450 | ||
|
|
918971994a | ||
|
|
83f4493ff0 | ||
|
|
4c9abe3128 |
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Dict, List, Type, TypeVar, cast
|
from typing import Any, Callable, Dict, List, Type, TypeVar, cast
|
||||||
@@ -12,12 +13,34 @@ EventT = TypeVar("EventT", bound=BaseEvent)
|
|||||||
|
|
||||||
class CrewAIEventsBus:
|
class CrewAIEventsBus:
|
||||||
"""
|
"""
|
||||||
A singleton event bus that uses blinker signals for event handling.
|
Thread-safe singleton event bus for CrewAI events.
|
||||||
Allows both internal (Flow/Crew) and external event handling.
|
|
||||||
|
This class provides a centralized event handling system that allows components
|
||||||
|
to emit and listen for events throughout the CrewAI framework.
|
||||||
|
|
||||||
|
Thread Safety:
|
||||||
|
- All public methods are thread-safe
|
||||||
|
- Uses a class-level lock to ensure synchronized access to shared resources
|
||||||
|
- Safe for concurrent event emission and handler registration/deregistration
|
||||||
|
- Prevents race conditions that could cause event mixing between sessions
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@crewai_event_bus.on(SomeEvent)
|
||||||
|
def handle_event(source, event):
|
||||||
|
# Handle the event
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Emit an event
|
||||||
|
event = SomeEvent(type="example")
|
||||||
|
crewai_event_bus.emit(source_object, event)
|
||||||
|
|
||||||
|
# Deregister a handler
|
||||||
|
crewai_event_bus.deregister_handler(SomeEvent, handle_event)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
_lock = threading.Lock()
|
_lock = threading.Lock()
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
@@ -67,27 +90,61 @@ class CrewAIEventsBus:
|
|||||||
source: The object emitting the event
|
source: The object emitting the event
|
||||||
event: The event instance to emit
|
event: The event instance to emit
|
||||||
"""
|
"""
|
||||||
for event_type, handlers in self._handlers.items():
|
with CrewAIEventsBus._lock:
|
||||||
if isinstance(event, event_type):
|
for event_type, handlers in self._handlers.items():
|
||||||
for handler in handlers:
|
if isinstance(event, event_type):
|
||||||
try:
|
for handler in handlers:
|
||||||
handler(source, event)
|
try:
|
||||||
except Exception as e:
|
handler(source, event)
|
||||||
print(
|
except Exception as e:
|
||||||
f"[EventBus Error] Handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}"
|
CrewAIEventsBus._logger.error(
|
||||||
)
|
"Handler execution failed",
|
||||||
|
extra={
|
||||||
|
"handler": handler.__name__,
|
||||||
|
"event_type": event_type.__name__,
|
||||||
|
"error": str(e),
|
||||||
|
"source": str(source)
|
||||||
|
},
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
self._signal.send(source, event=event)
|
self._signal.send(source, event=event)
|
||||||
|
|
||||||
def register_handler(
|
def register_handler(
|
||||||
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register an event handler for a specific event type"""
|
"""Register an event handler for a specific event type"""
|
||||||
if event_type not in self._handlers:
|
with CrewAIEventsBus._lock:
|
||||||
self._handlers[event_type] = []
|
if event_type not in self._handlers:
|
||||||
self._handlers[event_type].append(
|
self._handlers[event_type] = []
|
||||||
cast(Callable[[Any, EventTypes], None], handler)
|
self._handlers[event_type].append(
|
||||||
)
|
cast(Callable[[Any, EventTypes], None], handler)
|
||||||
|
)
|
||||||
|
|
||||||
|
def deregister_handler(
|
||||||
|
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Deregister an event handler for a specific event type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: The event type to deregister the handler from
|
||||||
|
handler: The handler function to remove
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the handler was found and removed, False otherwise
|
||||||
|
"""
|
||||||
|
with CrewAIEventsBus._lock:
|
||||||
|
if event_type not in self._handlers:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._handlers[event_type].remove(handler)
|
||||||
|
if not self._handlers[event_type]:
|
||||||
|
del self._handlers[event_type]
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def scoped_handlers(self):
|
def scoped_handlers(self):
|
||||||
|
|||||||
112
test_fixes.py
Normal file
112
test_fixes.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Simple test script to verify the CI fixes work locally.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||||||
|
|
||||||
|
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||||
|
from crewai.utilities.events.base_events import BaseEvent
|
||||||
|
from crewai.utilities.events.llm_events import LLMStreamChunkEvent
|
||||||
|
import logging
|
||||||
|
|
||||||
|
class TestEvent(BaseEvent):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_basic_functionality():
|
||||||
|
"""Test basic event emission works"""
|
||||||
|
print("Testing basic functionality...")
|
||||||
|
|
||||||
|
received_events = []
|
||||||
|
|
||||||
|
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||||
|
def handler(source, event):
|
||||||
|
received_events.append(f"{source}: {event.chunk}")
|
||||||
|
|
||||||
|
event = LLMStreamChunkEvent(type='llm_stream_chunk', chunk='test')
|
||||||
|
crewai_event_bus.emit('test_source', event)
|
||||||
|
|
||||||
|
if len(received_events) == 1 and 'test_source: test' in received_events[0]:
|
||||||
|
print("✅ Basic event emission works")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("❌ Basic event emission failed")
|
||||||
|
print(f"Received: {received_events}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def test_error_handling():
|
||||||
|
"""Test error handling with structured logging"""
|
||||||
|
print("Testing error handling...")
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.ERROR)
|
||||||
|
|
||||||
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
@crewai_event_bus.on(BaseEvent)
|
||||||
|
def broken_handler(source, event):
|
||||||
|
raise ValueError("Simulated handler failure")
|
||||||
|
|
||||||
|
event = TestEvent(type="test_event")
|
||||||
|
crewai_event_bus.emit("source_object", event)
|
||||||
|
|
||||||
|
print("✅ Error handling test completed (check logs above)")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def test_deregistration():
|
||||||
|
"""Test handler deregistration"""
|
||||||
|
print("Testing handler deregistration...")
|
||||||
|
|
||||||
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
def test_handler(source, event):
|
||||||
|
pass
|
||||||
|
|
||||||
|
crewai_event_bus.register_handler(TestEvent, test_handler)
|
||||||
|
initial_count = len(crewai_event_bus._handlers.get(TestEvent, []))
|
||||||
|
print(f"Handlers after registration: {initial_count}")
|
||||||
|
|
||||||
|
result = crewai_event_bus.deregister_handler(TestEvent, test_handler)
|
||||||
|
final_count = len(crewai_event_bus._handlers.get(TestEvent, []))
|
||||||
|
print(f"Handlers after deregistration: {final_count}")
|
||||||
|
print(f"Deregistration result: {result}")
|
||||||
|
|
||||||
|
if result and final_count == 0:
|
||||||
|
print("✅ Handler deregistration works")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("❌ Handler deregistration failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Testing CI fixes locally")
|
||||||
|
print("=" * 40)
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
test_basic_functionality,
|
||||||
|
test_error_handling,
|
||||||
|
test_deregistration
|
||||||
|
]
|
||||||
|
|
||||||
|
passed = 0
|
||||||
|
total = len(tests)
|
||||||
|
|
||||||
|
for test in tests:
|
||||||
|
try:
|
||||||
|
if test():
|
||||||
|
passed += 1
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Test {test.__name__} failed with exception: {e}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(f"Results: {passed}/{total} tests passed")
|
||||||
|
|
||||||
|
if passed == total:
|
||||||
|
print("🎉 All local tests passed!")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("💥 Some tests failed!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = main()
|
||||||
|
sys.exit(0 if success else 1)
|
||||||
@@ -1,7 +1,11 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from crewai.utilities.events.base_events import BaseEvent
|
from crewai.utilities.events.base_events import BaseEvent
|
||||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||||
|
from crewai.utilities.events.llm_events import LLMStreamChunkEvent
|
||||||
|
|
||||||
|
|
||||||
class TestEvent(BaseEvent):
|
class TestEvent(BaseEvent):
|
||||||
@@ -34,14 +38,167 @@ def test_wildcard_event_handler():
|
|||||||
mock_handler.assert_called_once_with("source_object", event)
|
mock_handler.assert_called_once_with("source_object", event)
|
||||||
|
|
||||||
|
|
||||||
def test_event_bus_error_handling(capfd):
|
def test_event_bus_error_handling(caplog):
|
||||||
@crewai_event_bus.on(BaseEvent)
|
with crewai_event_bus.scoped_handlers():
|
||||||
def broken_handler(source, event):
|
@crewai_event_bus.on(BaseEvent)
|
||||||
raise ValueError("Simulated handler failure")
|
def broken_handler(source, event):
|
||||||
|
raise ValueError("Simulated handler failure")
|
||||||
|
|
||||||
event = TestEvent(type="test_event")
|
event = TestEvent(type="test_event")
|
||||||
crewai_event_bus.emit("source_object", event)
|
crewai_event_bus.emit("source_object", event)
|
||||||
|
|
||||||
out, err = capfd.readouterr()
|
assert any("Handler execution failed" in record.message for record in caplog.records)
|
||||||
assert "Simulated handler failure" in out
|
assert any("Simulated handler failure" in str(record.exc_info) if record.exc_info else False for record in caplog.records)
|
||||||
assert "Handler 'broken_handler' failed" in out
|
|
||||||
|
|
||||||
|
def test_concurrent_event_emission_thread_safety():
|
||||||
|
"""Test that concurrent event emission is thread-safe"""
|
||||||
|
|
||||||
|
handler1_events = []
|
||||||
|
handler2_events = []
|
||||||
|
handler_lock = threading.Lock()
|
||||||
|
|
||||||
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||||
|
def handler1(source, event: LLMStreamChunkEvent):
|
||||||
|
with handler_lock:
|
||||||
|
handler1_events.append(f"Handler1: {event.chunk}")
|
||||||
|
|
||||||
|
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||||
|
def handler2(source, event: LLMStreamChunkEvent):
|
||||||
|
with handler_lock:
|
||||||
|
handler2_events.append(f"Handler2: {event.chunk}")
|
||||||
|
|
||||||
|
def emit_events(thread_id, num_events=20):
|
||||||
|
"""Emit events from a specific thread"""
|
||||||
|
for i in range(num_events):
|
||||||
|
event = LLMStreamChunkEvent(
|
||||||
|
type="llm_stream_chunk",
|
||||||
|
chunk=f"Thread-{thread_id}-Chunk-{i}"
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(f"source-{thread_id}", event)
|
||||||
|
|
||||||
|
num_threads = 5
|
||||||
|
events_per_thread = 20
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||||
|
futures = []
|
||||||
|
for thread_id in range(num_threads):
|
||||||
|
future = executor.submit(emit_events, thread_id, events_per_thread)
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
for future in futures:
|
||||||
|
future.result()
|
||||||
|
|
||||||
|
expected_total = num_threads * events_per_thread
|
||||||
|
assert len(handler1_events) == expected_total, f"Handler1 received {len(handler1_events)} events, expected {expected_total}"
|
||||||
|
assert len(handler2_events) == expected_total, f"Handler2 received {len(handler2_events)} events, expected {expected_total}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_concurrent_handler_registration_thread_safety():
|
||||||
|
"""Test that concurrent handler registration is thread-safe"""
|
||||||
|
|
||||||
|
registered_handlers = []
|
||||||
|
|
||||||
|
def register_handler(thread_id):
|
||||||
|
"""Register a handler from a specific thread"""
|
||||||
|
def handler(source, event):
|
||||||
|
pass
|
||||||
|
|
||||||
|
handler.__name__ = f"handler_{thread_id}"
|
||||||
|
crewai_event_bus.register_handler(TestEvent, handler)
|
||||||
|
registered_handlers.append(handler)
|
||||||
|
|
||||||
|
num_threads = 10
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||||
|
futures = []
|
||||||
|
for thread_id in range(num_threads):
|
||||||
|
future = executor.submit(register_handler, thread_id)
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
for future in futures:
|
||||||
|
future.result()
|
||||||
|
|
||||||
|
assert len(registered_handlers) == num_threads
|
||||||
|
assert len(crewai_event_bus._handlers[TestEvent]) >= num_threads
|
||||||
|
|
||||||
|
|
||||||
|
def test_thread_safety_with_mixed_operations():
|
||||||
|
"""Test thread safety when mixing event emission and handler registration"""
|
||||||
|
|
||||||
|
received_events = []
|
||||||
|
event_lock = threading.Lock()
|
||||||
|
|
||||||
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
def emit_events(thread_id):
|
||||||
|
for i in range(10):
|
||||||
|
event = TestEvent(type="test_event")
|
||||||
|
crewai_event_bus.emit(f"source-{thread_id}", event)
|
||||||
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
def register_handlers(thread_id):
|
||||||
|
for i in range(5):
|
||||||
|
def handler(source, event):
|
||||||
|
with event_lock:
|
||||||
|
received_events.append(f"Handler-{thread_id}-{i}: {event.type}")
|
||||||
|
|
||||||
|
handler.__name__ = f"handler_{thread_id}_{i}"
|
||||||
|
crewai_event_bus.register_handler(TestEvent, handler)
|
||||||
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=6) as executor:
|
||||||
|
futures = []
|
||||||
|
|
||||||
|
for thread_id in range(3):
|
||||||
|
futures.append(executor.submit(emit_events, thread_id))
|
||||||
|
|
||||||
|
for thread_id in range(3):
|
||||||
|
futures.append(executor.submit(register_handlers, thread_id))
|
||||||
|
|
||||||
|
for future in futures:
|
||||||
|
future.result()
|
||||||
|
|
||||||
|
assert len(received_events) >= 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_handler_deregistration_thread_safety():
|
||||||
|
"""Test that concurrent handler deregistration is thread-safe"""
|
||||||
|
|
||||||
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
handlers_to_remove = []
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
def handler(source, event):
|
||||||
|
pass
|
||||||
|
handler.__name__ = f"handler_{i}"
|
||||||
|
crewai_event_bus.register_handler(TestEvent, handler)
|
||||||
|
handlers_to_remove.append(handler)
|
||||||
|
|
||||||
|
def deregister_handler(handler):
|
||||||
|
"""Deregister a handler from a specific thread"""
|
||||||
|
return crewai_event_bus.deregister_handler(TestEvent, handler)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||||
|
futures = []
|
||||||
|
for handler in handlers_to_remove:
|
||||||
|
future = executor.submit(deregister_handler, handler)
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
results = [future.result() for future in futures]
|
||||||
|
|
||||||
|
assert all(results), "All handlers should be successfully deregistered"
|
||||||
|
|
||||||
|
remaining_count = len(crewai_event_bus._handlers.get(TestEvent, []))
|
||||||
|
assert remaining_count == 0, f"Expected 0 handlers remaining, got {remaining_count}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_deregister_nonexistent_handler():
|
||||||
|
"""Test deregistering a handler that doesn't exist"""
|
||||||
|
|
||||||
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
def dummy_handler(source, event):
|
||||||
|
pass
|
||||||
|
|
||||||
|
result = crewai_event_bus.deregister_handler(TestEvent, dummy_handler)
|
||||||
|
assert result is False, "Deregistering non-existent handler should return False"
|
||||||
|
|||||||
156
verify_thread_safety.py
Normal file
156
verify_thread_safety.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Simple verification script for thread safety fix without pytest dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||||||
|
|
||||||
|
def test_basic_functionality():
|
||||||
|
"""Test basic event emission works"""
|
||||||
|
print("Testing basic functionality...")
|
||||||
|
|
||||||
|
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||||
|
from crewai.utilities.events.llm_events import LLMStreamChunkEvent
|
||||||
|
|
||||||
|
received_events = []
|
||||||
|
|
||||||
|
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||||
|
def handler(source, event):
|
||||||
|
received_events.append(f"{source}: {event.chunk}")
|
||||||
|
|
||||||
|
event = LLMStreamChunkEvent(type='llm_stream_chunk', chunk='test')
|
||||||
|
crewai_event_bus.emit('test_source', event)
|
||||||
|
|
||||||
|
if len(received_events) == 1 and 'test_source: test' in received_events[0]:
|
||||||
|
print("✅ Basic event emission works")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("❌ Basic event emission failed")
|
||||||
|
print(f"Received: {received_events}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def test_thread_safety():
|
||||||
|
"""Test thread safety of concurrent event emission"""
|
||||||
|
print("Testing thread safety...")
|
||||||
|
|
||||||
|
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||||
|
from crewai.utilities.events.llm_events import LLMStreamChunkEvent
|
||||||
|
|
||||||
|
handler1_events = []
|
||||||
|
handler2_events = []
|
||||||
|
handler_lock = threading.Lock()
|
||||||
|
|
||||||
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||||
|
def handler1(source, event: LLMStreamChunkEvent):
|
||||||
|
with handler_lock:
|
||||||
|
handler1_events.append(f"Handler1: {event.chunk}")
|
||||||
|
|
||||||
|
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||||
|
def handler2(source, event: LLMStreamChunkEvent):
|
||||||
|
with handler_lock:
|
||||||
|
handler2_events.append(f"Handler2: {event.chunk}")
|
||||||
|
|
||||||
|
def emit_events(thread_id, num_events=10):
|
||||||
|
"""Emit events from a specific thread"""
|
||||||
|
for i in range(num_events):
|
||||||
|
event = LLMStreamChunkEvent(
|
||||||
|
type="llm_stream_chunk",
|
||||||
|
chunk=f"Thread-{thread_id}-Chunk-{i}"
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(f"source-{thread_id}", event)
|
||||||
|
|
||||||
|
num_threads = 3
|
||||||
|
events_per_thread = 10
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||||
|
futures = []
|
||||||
|
for thread_id in range(num_threads):
|
||||||
|
future = executor.submit(emit_events, thread_id, events_per_thread)
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
for future in futures:
|
||||||
|
future.result()
|
||||||
|
|
||||||
|
expected_total = num_threads * events_per_thread
|
||||||
|
success = (len(handler1_events) == expected_total and
|
||||||
|
len(handler2_events) == expected_total)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print(f"✅ Thread safety test passed - each handler received {expected_total} events")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("❌ Thread safety test failed")
|
||||||
|
print(f"Handler1 received {len(handler1_events)} events, expected {expected_total}")
|
||||||
|
print(f"Handler2 received {len(handler2_events)} events, expected {expected_total}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def test_deregistration():
|
||||||
|
"""Test handler deregistration"""
|
||||||
|
print("Testing handler deregistration...")
|
||||||
|
|
||||||
|
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||||
|
from crewai.utilities.events.base_events import BaseEvent
|
||||||
|
|
||||||
|
class TestEvent(BaseEvent):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
def test_handler(source, event):
|
||||||
|
pass
|
||||||
|
|
||||||
|
crewai_event_bus.register_handler(TestEvent, test_handler)
|
||||||
|
initial_count = len(crewai_event_bus._handlers.get(TestEvent, []))
|
||||||
|
|
||||||
|
result = crewai_event_bus.deregister_handler(TestEvent, test_handler)
|
||||||
|
final_count = len(crewai_event_bus._handlers.get(TestEvent, []))
|
||||||
|
|
||||||
|
if result and final_count == 0 and initial_count == 1:
|
||||||
|
print("✅ Handler deregistration works")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("❌ Handler deregistration failed")
|
||||||
|
print(f"Initial count: {initial_count}, Final count: {final_count}, Result: {result}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Verifying thread safety fix for Issue #2991")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
test_basic_functionality,
|
||||||
|
test_thread_safety,
|
||||||
|
test_deregistration
|
||||||
|
]
|
||||||
|
|
||||||
|
passed = 0
|
||||||
|
total = len(tests)
|
||||||
|
|
||||||
|
for test in tests:
|
||||||
|
try:
|
||||||
|
if test():
|
||||||
|
passed += 1
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Test {test.__name__} failed with exception: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(f"Results: {passed}/{total} tests passed")
|
||||||
|
|
||||||
|
if passed == total:
|
||||||
|
print("🎉 All thread safety tests passed!")
|
||||||
|
print("The fix for Issue #2991 is working correctly.")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("💥 Some tests failed!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = main()
|
||||||
|
sys.exit(0 if success else 1)
|
||||||
Reference in New Issue
Block a user