Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-11 00:58:30 +00:00

feat: improve event bus thread safety and async support

Add a thread-safe, async-compatible event bus with read–write locking and handler dependency ordering. Remove the blinker dependency and implement direct dispatch. Improve type safety, error handling, and deterministic event synchronization. Refactor tests to wait for async handlers, ensure clean teardown, and add comprehensive concurrency coverage. Replace thread-local state in AgentEvaluator with instance-based locking for correct cross-thread access. Enhance tracing reliability and event finalization.
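The test changes below assume two behaviors of the reworked bus: handlers may be async, and `emit()` hands back something the caller can wait on. As a minimal, hedged sketch only (the real `crewai_event_bus` API is richer; names such as `MiniEventBus` and `_loop` are assumptions, not crewAI internals), the dispatch core could look like this:

```python
# Minimal sketch of direct dispatch with async support -- illustrative only,
# not the actual crewAI event bus implementation.
import asyncio
import inspect
import threading
from collections import defaultdict
from concurrent.futures import Future


class MiniEventBus:
    def __init__(self) -> None:
        self._handlers = defaultdict(list)
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    def on(self, event_type):
        def decorator(handler):
            self._handlers[event_type].append(handler)
            return handler
        return decorator

    def emit(self, source, event) -> Future | None:
        """Run sync handlers inline; schedule async handlers on a background loop.

        Returns a Future the caller can wait on (the tests below synchronize via
        asyncio.wrap_future(...) or a threading.Event set inside a handler).
        """
        async_handlers = []
        for handler in self._handlers[type(event)]:
            if inspect.iscoroutinefunction(handler):
                async_handlers.append(handler)
            else:
                handler(source, event)  # direct dispatch, no blinker signal layer
        if not async_handlers:
            return None

        async def _run_async() -> None:
            await asyncio.gather(*(h(source, event) for h in async_handlers))

        return asyncio.run_coroutine_threadsafe(_run_async(), self._loop)
```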
@@ -1,6 +1,7 @@
"""Test Agent creation and execution basic functionality."""

import os
import threading
from unittest import mock
from unittest.mock import MagicMock, patch

@@ -185,14 +186,17 @@ def test_agent_execution_with_tools():
        expected_output="The result of the multiplication.",
    )
    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(ToolUsageFinishedEvent)
    def handle_tool_end(source, event):
        received_events.append(event)
        event_received.set()

    output = agent.execute_task(task)
    assert output == "The result of the multiplication is 12."

    assert event_received.wait(timeout=5), "Timeout waiting for tool usage event"
    assert len(received_events) == 1
    assert isinstance(received_events[0], ToolUsageFinishedEvent)
    assert received_events[0].tool_name == "multiplier"
@@ -284,10 +288,12 @@ def test_cache_hitting():
        'multiplier-{"first_number": 12, "second_number": 3}': 36,
    }
    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(ToolUsageFinishedEvent)
    def handle_tool_end(source, event):
        received_events.append(event)
        event_received.set()

    with (
        patch.object(CacheHandler, "read") as read,
@@ -303,6 +309,7 @@ def test_cache_hitting():
    read.assert_called_with(
        tool="multiplier", input='{"first_number": 2, "second_number": 6}'
    )
    assert event_received.wait(timeout=5), "Timeout waiting for tool usage event"
    assert len(received_events) == 1
    assert isinstance(received_events[0], ToolUsageFinishedEvent)
    assert received_events[0].from_cache
@@ -1,4 +1,5 @@
# mypy: ignore-errors
import threading
from collections import defaultdict
from typing import cast
from unittest.mock import Mock, patch

@@ -156,14 +157,17 @@ def test_lite_agent_with_tools():
    )

    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(ToolUsageStartedEvent)
    def event_handler(source, event):
        received_events.append(event)
        event_received.set()

    agent.kickoff("What are the effects of climate change on coral reefs?")

    # Verify tool usage events were emitted
    assert event_received.wait(timeout=5), "Timeout waiting for tool usage events"
    assert len(received_events) > 0, "Tool usage events should be emitted"
    event = received_events[0]
    assert isinstance(event, ToolUsageStartedEvent)
@@ -316,15 +320,18 @@ def test_sets_parent_flow_when_inside_flow():
        return agent.kickoff("Test query")

    flow = MyFlow()
    with crewai_event_bus.scoped_handlers():
        event_received = threading.Event()

        @crewai_event_bus.on(LiteAgentExecutionStartedEvent)
        def capture_agent(source, event):
            nonlocal captured_agent
            captured_agent = source
            event_received.set()

        flow.kickoff()

        assert event_received.wait(timeout=5), "Timeout waiting for agent execution event"
        assert captured_agent.parent_flow is flow


@pytest.mark.vcr(filter_headers=["authorization"])
@@ -342,30 +349,43 @@ def test_guardrail_is_called_using_string():
        guardrail="""Only include Brazilian players, both women and men""",
    )

    with crewai_event_bus.scoped_handlers():
        all_events_received = threading.Event()

        @crewai_event_bus.on(LLMGuardrailStartedEvent)
        def capture_guardrail_started(source, event):
            assert isinstance(source, LiteAgent)
            assert source.original_agent == agent
            guardrail_events["started"].append(event)
            if (
                len(guardrail_events["started"]) == 2
                and len(guardrail_events["completed"]) == 2
            ):
                all_events_received.set()

        @crewai_event_bus.on(LLMGuardrailCompletedEvent)
        def capture_guardrail_completed(source, event):
            assert isinstance(source, LiteAgent)
            assert source.original_agent == agent
            guardrail_events["completed"].append(event)
            if (
                len(guardrail_events["started"]) == 2
                and len(guardrail_events["completed"]) == 2
            ):
                all_events_received.set()

        result = agent.kickoff(messages="Top 10 best players in the world?")

        assert all_events_received.wait(timeout=10), (
            "Timeout waiting for all guardrail events"
        )
        assert len(guardrail_events["started"]) == 2
        assert len(guardrail_events["completed"]) == 2
        assert not guardrail_events["completed"][0].success
        assert guardrail_events["completed"][1].success
        assert (
            "Here are the top 10 best soccer players in the world, focusing exclusively on Brazilian players"
            in result.raw
        )


@pytest.mark.vcr(filter_headers=["authorization"])
@@ -376,29 +396,42 @@ def test_guardrail_is_called_using_callable():
        LLMGuardrailStartedEvent,
    )

    with crewai_event_bus.scoped_handlers():
        all_events_received = threading.Event()

        @crewai_event_bus.on(LLMGuardrailStartedEvent)
        def capture_guardrail_started(source, event):
            guardrail_events["started"].append(event)
            if (
                len(guardrail_events["started"]) == 1
                and len(guardrail_events["completed"]) == 1
            ):
                all_events_received.set()

        @crewai_event_bus.on(LLMGuardrailCompletedEvent)
        def capture_guardrail_completed(source, event):
            guardrail_events["completed"].append(event)
            if (
                len(guardrail_events["started"]) == 1
                and len(guardrail_events["completed"]) == 1
            ):
                all_events_received.set()

        agent = Agent(
            role="Sports Analyst",
            goal="Gather information about the best soccer players",
            backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
            guardrail=lambda output: (True, "Pelé - Santos, 1958"),
        )

        result = agent.kickoff(messages="Top 1 best players in the world?")

        assert all_events_received.wait(timeout=10), (
            "Timeout waiting for all guardrail events"
        )
        assert len(guardrail_events["started"]) == 1
        assert len(guardrail_events["completed"]) == 1
        assert guardrail_events["completed"][0].success
        assert "Pelé - Santos, 1958" in result.raw


@pytest.mark.vcr(filter_headers=["authorization"])
@@ -409,37 +442,50 @@ def test_guardrail_reached_attempt_limit():
        LLMGuardrailStartedEvent,
    )

    with crewai_event_bus.scoped_handlers():
        all_events_received = threading.Event()

        @crewai_event_bus.on(LLMGuardrailStartedEvent)
        def capture_guardrail_started(source, event):
            guardrail_events["started"].append(event)
            if (
                len(guardrail_events["started"]) == 3
                and len(guardrail_events["completed"]) == 3
            ):
                all_events_received.set()

        @crewai_event_bus.on(LLMGuardrailCompletedEvent)
        def capture_guardrail_completed(source, event):
            guardrail_events["completed"].append(event)
            if (
                len(guardrail_events["started"]) == 3
                and len(guardrail_events["completed"]) == 3
            ):
                all_events_received.set()

        agent = Agent(
            role="Sports Analyst",
            goal="Gather information about the best soccer players",
            backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
            guardrail=lambda output: (
                False,
                "You are not allowed to include Brazilian players",
            ),
            guardrail_max_retries=2,
        )

        with pytest.raises(
            Exception, match="Agent's guardrail failed validation after 2 retries"
        ):
            agent.kickoff(messages="Top 10 best players in the world?")

        assert all_events_received.wait(timeout=10), (
            "Timeout waiting for all guardrail events"
        )
        assert len(guardrail_events["started"]) == 3  # 2 retries + 1 initial call
        assert len(guardrail_events["completed"]) == 3  # 2 retries + 1 initial call
        assert not guardrail_events["completed"][0].success
        assert not guardrail_events["completed"][1].success
        assert not guardrail_events["completed"][2].success


@pytest.mark.vcr(filter_headers=["authorization"])
@@ -33,7 +33,7 @@ def setup_test_environment():
        except (OSError, IOError) as e:
            raise RuntimeError(
                f"Test storage directory {storage_dir} is not writable: {e}"
            ) from e

    os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
    os.environ["CREWAI_TESTING"] = "true"
@@ -159,6 +159,29 @@ def mock_opentelemetry_components():
    }


@pytest.fixture(autouse=True)
def clear_event_bus_handlers():
    """Clear event bus handlers after each test for isolation.

    Handlers registered during the test are allowed to run, then cleaned up
    after the test completes.
    """
    from crewai.events.event_bus import crewai_event_bus
    from crewai.experimental.evaluation.evaluation_listener import (
        EvaluationTraceCallback,
    )

    yield

    crewai_event_bus.shutdown(wait=True)
    crewai_event_bus._initialize()

    callback = EvaluationTraceCallback()
    callback.traces.clear()
    callback.current_agent_id = None
    callback.current_task_id = None


@pytest.fixture(scope="module")
def vcr_config(request) -> dict:
    import os
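The autouse fixture above tears the bus down with `shutdown(wait=True)` and then calls `_initialize()`. Purely as a hedged illustration of what that pair of calls could do (the method names come from the fixture; the internals here are assumptions), the idea is to drain in-flight handler work and rebuild fresh registries:

```python
# Illustrative sketch of shutdown/_initialize semantics -- assumed internals,
# not the real crewAI event bus.
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor


class SketchBus:
    def _initialize(self) -> None:
        """Recreate empty handler registries and a fresh executor."""
        self._sync_handlers = defaultdict(list)
        self._async_handlers = defaultdict(list)
        self._handler_dependencies = {}
        self._execution_plan_cache = {}
        self._executor = ThreadPoolExecutor(max_workers=4)

    def shutdown(self, wait: bool = True) -> None:
        """Stop accepting new work; optionally block until queued handlers finish."""
        self._executor.shutdown(wait=wait)
```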
lib/crewai/tests/events/test_depends.py (new file, 286 lines)
@@ -0,0 +1,286 @@
"""Tests for FastAPI-style dependency injection in event handlers."""

import asyncio

import pytest

from crewai.events import Depends, crewai_event_bus
from crewai.events.base_events import BaseEvent


class DependsTestEvent(BaseEvent):
    """Test event for dependency tests."""

    value: int = 0
    type: str = "test_event"


@pytest.mark.asyncio
async def test_basic_dependency():
    """Test that handler with dependency runs after its dependency."""
    execution_order = []

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(DependsTestEvent)
        def setup(source, event: DependsTestEvent):
            execution_order.append("setup")

        @crewai_event_bus.on(DependsTestEvent, Depends(setup))
        def process(source, event: DependsTestEvent):
            execution_order.append("process")

        event = DependsTestEvent(value=1)
        future = crewai_event_bus.emit("test_source", event)

        if future:
            await asyncio.wrap_future(future)

        assert execution_order == ["setup", "process"]


@pytest.mark.asyncio
async def test_multiple_dependencies():
    """Test handler with multiple dependencies."""
    execution_order = []

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(DependsTestEvent)
        def setup_a(source, event: DependsTestEvent):
            execution_order.append("setup_a")

        @crewai_event_bus.on(DependsTestEvent)
        def setup_b(source, event: DependsTestEvent):
            execution_order.append("setup_b")

        @crewai_event_bus.on(
            DependsTestEvent, depends_on=[Depends(setup_a), Depends(setup_b)]
        )
        def process(source, event: DependsTestEvent):
            execution_order.append("process")

        event = DependsTestEvent(value=1)
        future = crewai_event_bus.emit("test_source", event)

        if future:
            await asyncio.wrap_future(future)

        # setup_a and setup_b can run in any order (same level)
        assert "process" in execution_order
        assert execution_order.index("process") > execution_order.index("setup_a")
        assert execution_order.index("process") > execution_order.index("setup_b")


@pytest.mark.asyncio
async def test_chain_of_dependencies():
    """Test chain of dependencies (A -> B -> C)."""
    execution_order = []

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(DependsTestEvent)
        def handler_a(source, event: DependsTestEvent):
            execution_order.append("handler_a")

        @crewai_event_bus.on(DependsTestEvent, depends_on=Depends(handler_a))
        def handler_b(source, event: DependsTestEvent):
            execution_order.append("handler_b")

        @crewai_event_bus.on(DependsTestEvent, depends_on=Depends(handler_b))
        def handler_c(source, event: DependsTestEvent):
            execution_order.append("handler_c")

        event = DependsTestEvent(value=1)
        future = crewai_event_bus.emit("test_source", event)

        if future:
            await asyncio.wrap_future(future)

        assert execution_order == ["handler_a", "handler_b", "handler_c"]


@pytest.mark.asyncio
async def test_async_handler_with_dependency():
    """Test async handler with dependency on sync handler."""
    execution_order = []

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(DependsTestEvent)
        def sync_setup(source, event: DependsTestEvent):
            execution_order.append("sync_setup")

        @crewai_event_bus.on(DependsTestEvent, depends_on=Depends(sync_setup))
        async def async_process(source, event: DependsTestEvent):
            await asyncio.sleep(0.01)
            execution_order.append("async_process")

        event = DependsTestEvent(value=1)
        future = crewai_event_bus.emit("test_source", event)

        if future:
            await asyncio.wrap_future(future)

        assert execution_order == ["sync_setup", "async_process"]


@pytest.mark.asyncio
async def test_mixed_handlers_with_dependencies():
    """Test mix of sync and async handlers with dependencies."""
    execution_order = []

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(DependsTestEvent)
        def setup(source, event: DependsTestEvent):
            execution_order.append("setup")

        @crewai_event_bus.on(DependsTestEvent, depends_on=Depends(setup))
        def sync_process(source, event: DependsTestEvent):
            execution_order.append("sync_process")

        @crewai_event_bus.on(DependsTestEvent, depends_on=Depends(setup))
        async def async_process(source, event: DependsTestEvent):
            await asyncio.sleep(0.01)
            execution_order.append("async_process")

        @crewai_event_bus.on(
            DependsTestEvent, depends_on=[Depends(sync_process), Depends(async_process)]
        )
        def finalize(source, event: DependsTestEvent):
            execution_order.append("finalize")

        event = DependsTestEvent(value=1)
        future = crewai_event_bus.emit("test_source", event)

        if future:
            await asyncio.wrap_future(future)

        # Verify execution order
        assert execution_order[0] == "setup"
        assert "finalize" in execution_order
        assert execution_order.index("finalize") > execution_order.index("sync_process")
        assert execution_order.index("finalize") > execution_order.index("async_process")


@pytest.mark.asyncio
async def test_independent_handlers_run_concurrently():
    """Test that handlers without dependencies can run concurrently."""
    execution_order = []

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(DependsTestEvent)
        async def handler_a(source, event: DependsTestEvent):
            await asyncio.sleep(0.01)
            execution_order.append("handler_a")

        @crewai_event_bus.on(DependsTestEvent)
        async def handler_b(source, event: DependsTestEvent):
            await asyncio.sleep(0.01)
            execution_order.append("handler_b")

        event = DependsTestEvent(value=1)
        future = crewai_event_bus.emit("test_source", event)

        if future:
            await asyncio.wrap_future(future)

        # Both handlers should have executed
        assert len(execution_order) == 2
        assert "handler_a" in execution_order
        assert "handler_b" in execution_order


@pytest.mark.asyncio
async def test_circular_dependency_detection():
    """Test that circular dependencies are detected and raise an error."""
    from crewai.events.handler_graph import CircularDependencyError, build_execution_plan

    # Create circular dependency: handler_a -> handler_b -> handler_c -> handler_a
    def handler_a(source, event: DependsTestEvent):
        pass

    def handler_b(source, event: DependsTestEvent):
        pass

    def handler_c(source, event: DependsTestEvent):
        pass

    # Build a dependency graph with a cycle
    handlers = [handler_a, handler_b, handler_c]
    dependencies = {
        handler_a: [Depends(handler_b)],
        handler_b: [Depends(handler_c)],
        handler_c: [Depends(handler_a)],  # Creates the cycle
    }

    # Should raise CircularDependencyError about circular dependency
    with pytest.raises(CircularDependencyError, match="Circular dependency"):
        build_execution_plan(handlers, dependencies)


@pytest.mark.asyncio
async def test_handler_without_dependency_runs_normally():
    """Test that handlers without dependencies still work as before."""
    execution_order = []

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(DependsTestEvent)
        def simple_handler(source, event: DependsTestEvent):
            execution_order.append("simple_handler")

        event = DependsTestEvent(value=1)
        future = crewai_event_bus.emit("test_source", event)

        if future:
            await asyncio.wrap_future(future)

        assert execution_order == ["simple_handler"]


@pytest.mark.asyncio
async def test_depends_equality():
    """Test Depends equality and hashing."""

    def handler_a(source, event):
        pass

    def handler_b(source, event):
        pass

    dep_a1 = Depends(handler_a)
    dep_a2 = Depends(handler_a)
    dep_b = Depends(handler_b)

    # Same handler should be equal
    assert dep_a1 == dep_a2
    assert hash(dep_a1) == hash(dep_a2)

    # Different handlers should not be equal
    assert dep_a1 != dep_b
    assert hash(dep_a1) != hash(dep_b)


@pytest.mark.asyncio
async def test_aemit_ignores_dependencies():
    """Test that aemit only processes async handlers (no dependency support yet)."""
    execution_order = []

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(DependsTestEvent)
        def sync_handler(source, event: DependsTestEvent):
            execution_order.append("sync_handler")

        @crewai_event_bus.on(DependsTestEvent)
        async def async_handler(source, event: DependsTestEvent):
            execution_order.append("async_handler")

        event = DependsTestEvent(value=1)
        await crewai_event_bus.aemit("test_source", event)

        # Only async handler should execute
        assert execution_order == ["async_handler"]
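The circular-dependency test above imports `build_execution_plan` and `CircularDependencyError` from `crewai.events.handler_graph`. As a hedged sketch only (the real module's signatures and internals may differ, and the `.handler` attribute on `Depends` is an assumption), a dependency-ordered execution plan can be built with a plain topological sort:

```python
# Illustrative sketch -- not crewAI's actual handler_graph implementation.
from graphlib import CycleError, TopologicalSorter


class CircularDependencyError(Exception):
    """Raised when handler dependencies form a cycle."""


def build_execution_plan(handlers, dependencies):
    """Group handlers into levels; each level depends only on earlier levels."""
    sorter = TopologicalSorter()
    for handler in handlers:
        # `.handler` is an assumed attribute exposing the wrapped callable.
        deps = [dep.handler for dep in dependencies.get(handler, [])]
        sorter.add(handler, *deps)
    try:
        sorter.prepare()
    except CycleError as e:
        raise CircularDependencyError(f"Circular dependency detected: {e}") from e

    plan = []
    while sorter.is_active():
        level = list(sorter.get_ready())
        plan.append(level)  # handlers within one level may run concurrently
        for handler in level:
            sorter.done(handler)
    return plan
```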
@@ -1,3 +1,5 @@
import threading

import pytest

from crewai.agent import Agent
from crewai.crew import Crew
@@ -19,7 +21,10 @@ from crewai.experimental.evaluation import (
    create_default_evaluator,
)
from crewai.experimental.evaluation.agent_evaluator import AgentEvaluator
from crewai.experimental.evaluation.base_evaluator import (
    AgentEvaluationResult,
    BaseEvaluator,
)
from crewai.task import Task


@@ -51,12 +56,25 @@ class TestAgentEvaluator:

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_evaluate_current_iteration(self, mock_crew):
        from crewai.events.types.task_events import TaskCompletedEvent

        agent_evaluator = AgentEvaluator(
            agents=mock_crew.agents, evaluators=[GoalAlignmentEvaluator()]
        )

        task_completed_event = threading.Event()

        @crewai_event_bus.on(TaskCompletedEvent)
        async def on_task_completed(source, event):
            # TaskCompletedEvent fires AFTER evaluation results are stored
            task_completed_event.set()

        mock_crew.kickoff()

        assert task_completed_event.wait(timeout=5), (
            "Timeout waiting for task completion"
        )

        results = agent_evaluator.get_evaluation_results()

        assert isinstance(results, dict)
@@ -98,73 +116,15 @@ class TestAgentEvaluator:
        ]

        assert len(agent_evaluator.evaluators) == len(expected_types)
        for evaluator, expected_type in zip(
            agent_evaluator.evaluators, expected_types, strict=False
        ):
            assert isinstance(evaluator, expected_type)

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_eval_lite_agent(self):
        agent = Agent(
            role="Test Agent",
            goal="Complete test tasks successfully",
            backstory="An agent created for testing purposes",
        )

        with crewai_event_bus.scoped_handlers():
            events = {}

            @crewai_event_bus.on(AgentEvaluationStartedEvent)
            def capture_started(source, event):
                events["started"] = event

            @crewai_event_bus.on(AgentEvaluationCompletedEvent)
            def capture_completed(source, event):
                events["completed"] = event

            @crewai_event_bus.on(AgentEvaluationFailedEvent)
            def capture_failed(source, event):
                events["failed"] = event

            agent_evaluator = AgentEvaluator(
                agents=[agent], evaluators=[GoalAlignmentEvaluator()]
            )

            agent.kickoff(messages="Complete this task successfully")

            assert events.keys() == {"started", "completed"}
            assert events["started"].agent_id == str(agent.id)
            assert events["started"].agent_role == agent.role
            assert events["started"].task_id is None
            assert events["started"].iteration == 1

            assert events["completed"].agent_id == str(agent.id)
            assert events["completed"].agent_role == agent.role
            assert events["completed"].task_id is None
            assert events["completed"].iteration == 1
            assert events["completed"].metric_category == MetricCategory.GOAL_ALIGNMENT
            assert isinstance(events["completed"].score, EvaluationScore)
            assert events["completed"].score.score == 2.0

            results = agent_evaluator.get_evaluation_results()

            assert isinstance(results, dict)

            (result,) = results[agent.role]
            assert isinstance(result, AgentEvaluationResult)

            assert result.agent_id == str(agent.id)
            assert result.task_id == "lite_task"

            (goal_alignment,) = result.metrics.values()
            assert goal_alignment.score == 2.0

            expected_feedback = "The agent did not demonstrate a clear understanding of the task goal, which is to complete test tasks successfully"
            assert expected_feedback in goal_alignment.feedback

            assert goal_alignment.raw_response is not None
            assert '"score": 2' in goal_alignment.raw_response

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_eval_specific_agents_from_crew(self, mock_crew):
        from crewai.events.types.task_events import TaskCompletedEvent

        agent = Agent(
            role="Test Agent Eval",
            goal="Complete test tasks successfully",
@@ -178,111 +138,132 @@ class TestAgentEvaluator:
        mock_crew.agents.append(agent)
        mock_crew.tasks.append(task)

        with crewai_event_bus.scoped_handlers():
            events = {}
            started_event = threading.Event()
            completed_event = threading.Event()
            task_completed_event = threading.Event()

            agent_evaluator = AgentEvaluator(
                agents=[agent], evaluators=[GoalAlignmentEvaluator()]
            )

            @crewai_event_bus.on(AgentEvaluationStartedEvent)
            async def capture_started(source, event):
                if event.agent_id == str(agent.id):
                    events["started"] = event
                    started_event.set()

            @crewai_event_bus.on(AgentEvaluationCompletedEvent)
            async def capture_completed(source, event):
                if event.agent_id == str(agent.id):
                    events["completed"] = event
                    completed_event.set()

            @crewai_event_bus.on(AgentEvaluationFailedEvent)
            def capture_failed(source, event):
                events["failed"] = event

            @crewai_event_bus.on(TaskCompletedEvent)
            async def on_task_completed(source, event):
                # TaskCompletedEvent fires AFTER evaluation results are stored
                if event.task and event.task.id == task.id:
                    task_completed_event.set()

            mock_crew.kickoff()

            assert started_event.wait(timeout=5), "Timeout waiting for started event"
            assert completed_event.wait(timeout=5), "Timeout waiting for completed event"
            assert task_completed_event.wait(timeout=5), (
                "Timeout waiting for task completion"
            )

            assert events.keys() == {"started", "completed"}
            assert events["started"].agent_id == str(agent.id)
            assert events["started"].agent_role == agent.role
            assert events["started"].task_id == str(task.id)
            assert events["started"].iteration == 1

            assert events["completed"].agent_id == str(agent.id)
            assert events["completed"].agent_role == agent.role
            assert events["completed"].task_id == str(task.id)
            assert events["completed"].iteration == 1
            assert events["completed"].metric_category == MetricCategory.GOAL_ALIGNMENT
            assert isinstance(events["completed"].score, EvaluationScore)
            assert events["completed"].score.score == 5.0

            results = agent_evaluator.get_evaluation_results()

            assert isinstance(results, dict)
            assert len(results.keys()) == 1
            (result,) = results[agent.role]
            assert isinstance(result, AgentEvaluationResult)

            assert result.agent_id == str(agent.id)
            assert result.task_id == str(task.id)

            (goal_alignment,) = result.metrics.values()
            assert goal_alignment.score == 5.0

            expected_feedback = "The agent provided a thorough guide on how to conduct a test task but failed to produce specific expected output"
            assert expected_feedback in goal_alignment.feedback

            assert goal_alignment.raw_response is not None
            assert '"score": 5' in goal_alignment.raw_response

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_failed_evaluation(self, mock_crew):
        (agent,) = mock_crew.agents
        (task,) = mock_crew.tasks

        with crewai_event_bus.scoped_handlers():
            events = {}
            started_event = threading.Event()
            failed_event = threading.Event()

            @crewai_event_bus.on(AgentEvaluationStartedEvent)
            def capture_started(source, event):
                events["started"] = event
                started_event.set()

            @crewai_event_bus.on(AgentEvaluationCompletedEvent)
            def capture_completed(source, event):
                events["completed"] = event

            @crewai_event_bus.on(AgentEvaluationFailedEvent)
            def capture_failed(source, event):
                events["failed"] = event
                failed_event.set()

            # Create a mock evaluator that will raise an exception
            from crewai.experimental.evaluation import MetricCategory

            class FailingEvaluator(BaseEvaluator):
                metric_category = MetricCategory.GOAL_ALIGNMENT

                def evaluate(self, agent, task, execution_trace, final_output):
                    raise ValueError("Forced evaluation failure")

            agent_evaluator = AgentEvaluator(
                agents=[agent], evaluators=[FailingEvaluator()]
            )
            mock_crew.kickoff()

            assert started_event.wait(timeout=5), "Timeout waiting for started event"
            assert failed_event.wait(timeout=5), "Timeout waiting for failed event"

            assert events.keys() == {"started", "failed"}
            assert events["started"].agent_id == str(agent.id)
            assert events["started"].agent_role == agent.role
            assert events["started"].task_id == str(task.id)
            assert events["started"].iteration == 1

            assert events["failed"].agent_id == str(agent.id)
            assert events["failed"].agent_role == agent.role
            assert events["failed"].task_id == str(task.id)
            assert events["failed"].iteration == 1
            assert events["failed"].error == "Forced evaluation failure"

            results = agent_evaluator.get_evaluation_results()
            (result,) = results[agent.role]
            assert isinstance(result, AgentEvaluationResult)

            assert result.agent_id == str(agent.id)
            assert result.task_id == str(task.id)

            assert result.metrics == {}
@@ -1,23 +1,36 @@
import threading
from collections import defaultdict
from unittest.mock import ANY, MagicMock, patch

import pytest
from mem0.memory.main import Memory

from crewai.agent import Agent
from crewai.crew import Crew, Process
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryQueryCompletedEvent,
    MemoryQueryStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveStartedEvent,
)
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.storage.interface import Storage
from crewai.task import Task


@pytest.fixture(autouse=True)
def cleanup_event_handlers():
    """Cleanup event handlers after each test"""
    yield
    with crewai_event_bus._rwlock.w_locked():
        crewai_event_bus._sync_handlers = {}
        crewai_event_bus._async_handlers = {}
        crewai_event_bus._handler_dependencies = {}
        crewai_event_bus._execution_plan_cache = {}


@pytest.fixture
def mock_mem0_memory():
    mock_memory = MagicMock(spec=Memory)
@@ -238,24 +251,26 @@ def test_external_memory_search_events(
    custom_storage, external_memory_with_mocked_config
):
    events = defaultdict(list)
    event_received = threading.Event()

    external_memory_with_mocked_config.storage = custom_storage
    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(MemoryQueryStartedEvent)
        def on_search_started(source, event):
            events["MemoryQueryStartedEvent"].append(event)

        @crewai_event_bus.on(MemoryQueryCompletedEvent)
        def on_search_completed(source, event):
            events["MemoryQueryCompletedEvent"].append(event)
            event_received.set()

        external_memory_with_mocked_config.search(
            query="test value",
            limit=3,
            score_threshold=0.35,
        )

        assert event_received.wait(timeout=5), "Timeout waiting for search events"
        assert len(events["MemoryQueryStartedEvent"]) == 1
        assert len(events["MemoryQueryCompletedEvent"]) == 1

@@ -300,24 +315,25 @@ def test_external_memory_save_events(
    custom_storage, external_memory_with_mocked_config
):
    events = defaultdict(list)
    event_received = threading.Event()

    external_memory_with_mocked_config.storage = custom_storage

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(MemorySaveStartedEvent)
        def on_save_started(source, event):
            events["MemorySaveStartedEvent"].append(event)

        @crewai_event_bus.on(MemorySaveCompletedEvent)
        def on_save_completed(source, event):
            events["MemorySaveCompletedEvent"].append(event)
            event_received.set()

        external_memory_with_mocked_config.save(
            value="saving value",
            metadata={"task": "test_task"},
        )

        assert event_received.wait(timeout=5), "Timeout waiting for save events"
        assert len(events["MemorySaveStartedEvent"]) == 1
        assert len(events["MemorySaveCompletedEvent"]) == 1
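The `cleanup_event_handlers` fixture above resets internal state under `crewai_event_bus._rwlock.w_locked()`. As an assumption-labeled sketch (the real lock class and method names may differ), a read–write lock exposing `r_locked()`/`w_locked()` context managers can be written like this:

```python
# Sketch of a reader-writer lock with r_locked()/w_locked() context managers.
# Assumed shape only; not the lock class actually used by crewAI.
import threading
from contextlib import contextmanager


class ReadWriteLock:
    """Many concurrent readers (event emission), exclusive writers (registration)."""

    def __init__(self) -> None:
        self._readers = 0
        self._lock = threading.Lock()
        self._no_readers = threading.Condition(self._lock)

    @contextmanager
    def r_locked(self):
        with self._lock:
            self._readers += 1
        try:
            yield
        finally:
            with self._lock:
                self._readers -= 1
                if self._readers == 0:
                    self._no_readers.notify_all()

    @contextmanager
    def w_locked(self):
        # Hold the mutex for the whole write so no reader can enter meanwhile.
        with self._lock:
            while self._readers:
                self._no_readers.wait()
            yield
```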
@@ -1,7 +1,9 @@
import threading
from collections import defaultdict
from unittest.mock import ANY

import pytest

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryQueryCompletedEvent,
@@ -21,27 +23,37 @@ def long_term_memory():

def test_long_term_memory_save_events(long_term_memory):
    events = defaultdict(list)
    all_events_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(MemorySaveStartedEvent)
        def on_save_started(source, event):
            events["MemorySaveStartedEvent"].append(event)
            if (
                len(events["MemorySaveStartedEvent"]) == 1
                and len(events["MemorySaveCompletedEvent"]) == 1
            ):
                all_events_received.set()

        @crewai_event_bus.on(MemorySaveCompletedEvent)
        def on_save_completed(source, event):
            events["MemorySaveCompletedEvent"].append(event)
            if (
                len(events["MemorySaveStartedEvent"]) == 1
                and len(events["MemorySaveCompletedEvent"]) == 1
            ):
                all_events_received.set()

        memory = LongTermMemoryItem(
            agent="test_agent",
            task="test_task",
            expected_output="test_output",
            datetime="test_datetime",
            quality=0.5,
            metadata={"task": "test_task", "quality": 0.5},
        )
        long_term_memory.save(memory)

        assert all_events_received.wait(timeout=5), "Timeout waiting for save events"
        assert len(events["MemorySaveStartedEvent"]) == 1
        assert len(events["MemorySaveCompletedEvent"]) == 1
        assert len(events["MemorySaveFailedEvent"]) == 0
@@ -86,21 +98,31 @@ def test_long_term_memory_save_events(long_term_memory):

def test_long_term_memory_search_events(long_term_memory):
    events = defaultdict(list)
    all_events_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(MemoryQueryStartedEvent)
        def on_search_started(source, event):
            events["MemoryQueryStartedEvent"].append(event)
            if (
                len(events["MemoryQueryStartedEvent"]) == 1
                and len(events["MemoryQueryCompletedEvent"]) == 1
            ):
                all_events_received.set()

        @crewai_event_bus.on(MemoryQueryCompletedEvent)
        def on_search_completed(source, event):
            events["MemoryQueryCompletedEvent"].append(event)
            if (
                len(events["MemoryQueryStartedEvent"]) == 1
                and len(events["MemoryQueryCompletedEvent"]) == 1
            ):
                all_events_received.set()

        test_query = "test query"

        long_term_memory.search(test_query, latest_n=5)

        assert all_events_received.wait(timeout=5), "Timeout waiting for search events"
        assert len(events["MemoryQueryStartedEvent"]) == 1
        assert len(events["MemoryQueryCompletedEvent"]) == 1
        assert len(events["MemoryQueryFailedEvent"]) == 0
@@ -1,3 +1,4 @@
import threading
from collections import defaultdict
from unittest.mock import ANY, patch

@@ -37,24 +38,33 @@ def short_term_memory():

def test_short_term_memory_search_events(short_term_memory):
    events = defaultdict(list)
    search_started = threading.Event()
    search_completed = threading.Event()

    with patch.object(short_term_memory.storage, "search", return_value=[]):
        with crewai_event_bus.scoped_handlers():

            @crewai_event_bus.on(MemoryQueryStartedEvent)
            def on_search_started(source, event):
                events["MemoryQueryStartedEvent"].append(event)
                search_started.set()

            @crewai_event_bus.on(MemoryQueryCompletedEvent)
            def on_search_completed(source, event):
                events["MemoryQueryCompletedEvent"].append(event)
                search_completed.set()

            # Call the search method
            short_term_memory.search(
                query="test value",
                limit=3,
                score_threshold=0.35,
            )

            assert search_started.wait(timeout=2), (
                "Timeout waiting for search started event"
            )
            assert search_completed.wait(timeout=2), (
                "Timeout waiting for search completed event"
            )

            assert len(events["MemoryQueryStartedEvent"]) == 1
            assert len(events["MemoryQueryCompletedEvent"]) == 1
@@ -98,20 +108,26 @@ def test_short_term_memory_search_events(short_term_memory):

def test_short_term_memory_save_events(short_term_memory):
    events = defaultdict(list)
    with crewai_event_bus.scoped_handlers():
        save_started = threading.Event()
        save_completed = threading.Event()

        @crewai_event_bus.on(MemorySaveStartedEvent)
        def on_save_started(source, event):
            events["MemorySaveStartedEvent"].append(event)
            save_started.set()

        @crewai_event_bus.on(MemorySaveCompletedEvent)
        def on_save_completed(source, event):
            events["MemorySaveCompletedEvent"].append(event)
            save_completed.set()

        short_term_memory.save(
            value="test value",
            metadata={"task": "test_task"},
        )

        assert save_started.wait(timeout=2), "Timeout waiting for save started event"
        assert save_completed.wait(timeout=2), "Timeout waiting for save completed event"

        assert len(events["MemorySaveStartedEvent"]) == 1
        assert len(events["MemorySaveCompletedEvent"]) == 1
@@ -1,9 +1,10 @@
"""Test Agent creation and execution basic functionality."""

import json
import re
import threading
from collections import defaultdict
from concurrent.futures import Future
from hashlib import md5
from unittest import mock
from unittest.mock import ANY, MagicMock, patch
@@ -2476,62 +2477,63 @@ def test_using_contextual_memory():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_memory_events_are_emitted():
    events = defaultdict(list)
    event_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(MemorySaveStartedEvent)
        def handle_memory_save_started(source, event):
            events["MemorySaveStartedEvent"].append(event)

        @crewai_event_bus.on(MemorySaveCompletedEvent)
        def handle_memory_save_completed(source, event):
            events["MemorySaveCompletedEvent"].append(event)

        @crewai_event_bus.on(MemorySaveFailedEvent)
        def handle_memory_save_failed(source, event):
            events["MemorySaveFailedEvent"].append(event)

        @crewai_event_bus.on(MemoryQueryStartedEvent)
        def handle_memory_query_started(source, event):
            events["MemoryQueryStartedEvent"].append(event)

        @crewai_event_bus.on(MemoryQueryCompletedEvent)
        def handle_memory_query_completed(source, event):
            events["MemoryQueryCompletedEvent"].append(event)

        @crewai_event_bus.on(MemoryQueryFailedEvent)
        def handle_memory_query_failed(source, event):
            events["MemoryQueryFailedEvent"].append(event)

        @crewai_event_bus.on(MemoryRetrievalStartedEvent)
        def handle_memory_retrieval_started(source, event):
            events["MemoryRetrievalStartedEvent"].append(event)

        @crewai_event_bus.on(MemoryRetrievalCompletedEvent)
        def handle_memory_retrieval_completed(source, event):
            events["MemoryRetrievalCompletedEvent"].append(event)
            event_received.set()

        math_researcher = Agent(
            role="Researcher",
            goal="You research about math.",
            backstory="You're an expert in research and you love to learn new things.",
            allow_delegation=False,
        )

        task1 = Task(
            description="Research a topic to teach a kid aged 6 about math.",
            expected_output="A topic, explanation, angle, and examples.",
            agent=math_researcher,
        )

        crew = Crew(
            agents=[math_researcher],
            tasks=[task1],
            memory=True,
        )

        crew.kickoff()

        assert event_received.wait(timeout=5), "Timeout waiting for memory events"
        assert len(events["MemorySaveStartedEvent"]) == 3
        assert len(events["MemorySaveCompletedEvent"]) == 3
        assert len(events["MemorySaveFailedEvent"]) == 0
@@ -2907,19 +2909,29 @@ def test_crew_train_success(
    copy_mock.return_value = crew

    received_events = []
    lock = threading.Lock()
    all_events_received = threading.Event()

    @crewai_event_bus.on(CrewTrainStartedEvent)
    def on_crew_train_started(source, event: CrewTrainStartedEvent):
        with lock:
            received_events.append(event)
            if len(received_events) == 2:
                all_events_received.set()

    @crewai_event_bus.on(CrewTrainCompletedEvent)
    def on_crew_train_completed(source, event: CrewTrainCompletedEvent):
        with lock:
            received_events.append(event)
            if len(received_events) == 2:
                all_events_received.set()

    crew.train(
        n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
    )

    assert all_events_received.wait(timeout=5), "Timeout waiting for all train events"

    # Ensure kickoff is called on the copied crew
    kickoff_mock.assert_has_calls(
        [mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
@@ -3726,17 +3738,27 @@ def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator, research
    llm_instance = LLM("gpt-4o-mini")

    received_events = []
    lock = threading.Lock()
    all_events_received = threading.Event()

    @crewai_event_bus.on(CrewTestStartedEvent)
    def on_crew_test_started(source, event: CrewTestStartedEvent):
        with lock:
            received_events.append(event)
            if len(received_events) == 2:
                all_events_received.set()

    @crewai_event_bus.on(CrewTestCompletedEvent)
    def on_crew_test_completed(source, event: CrewTestCompletedEvent):
        with lock:
            received_events.append(event)
            if len(received_events) == 2:
                all_events_received.set()

    crew.test(n_iterations, llm_instance, inputs={"topic": "AI"})

    assert all_events_received.wait(timeout=5), "Timeout waiting for all test events"

    # Ensure kickoff is called on the copied crew
    kickoff_mock.assert_has_calls(
        [mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
@@ -1,9 +1,12 @@
|
||||
"""Test Flow creation and execution basic functionality."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowFinishedEvent,
|
||||
@@ -13,7 +16,6 @@ from crewai.events.types.flow_events import (
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def test_simple_sequential_flow():
|
||||
@@ -439,20 +441,42 @@ def test_unstructured_flow_event_emission():
|
||||
|
||||
flow = PoemFlow()
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
all_events_received = threading.Event()
|
||||
expected_event_count = (
|
||||
7 # 1 FlowStarted + 5 MethodExecutionStarted + 1 FlowFinished
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def handle_flow_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def handle_method_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def handle_flow_end(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
flow.kickoff(inputs={"separator": ", "})
|
||||
|
||||
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
|
||||
|
||||
# Sort events by timestamp to ensure deterministic order
|
||||
# (async handlers may append out of order)
|
||||
with lock:
|
||||
received_events.sort(key=lambda e: e.timestamp)
|
||||
|
||||
assert isinstance(received_events[0], FlowStartedEvent)
|
||||
assert received_events[0].flow_name == "PoemFlow"
|
||||
assert received_events[0].inputs == {"separator": ", "}
|
||||
@@ -642,28 +666,48 @@ def test_structured_flow_event_emission():
|
||||
return f"Welcome, {self.state.name}!"
|
||||
|
||||
flow = OnboardingFlow()
|
||||
flow.kickoff(inputs={"name": "Anakin"})
|
||||
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
all_events_received = threading.Event()
|
||||
expected_event_count = 6 # 1 FlowStarted + 2 MethodExecutionStarted + 2 MethodExecutionFinished + 1 FlowFinished
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def handle_flow_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def handle_method_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def handle_method_end(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def handle_flow_end(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
flow.kickoff(inputs={"name": "Anakin"})
|
||||
|
||||
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
|
||||
|
||||
# Sort events by timestamp to ensure deterministic order
|
||||
with lock:
|
||||
received_events.sort(key=lambda e: e.timestamp)
|
||||
|
||||
assert isinstance(received_events[0], FlowStartedEvent)
|
||||
assert received_events[0].flow_name == "OnboardingFlow"
|
||||
assert received_events[0].inputs == {"name": "Anakin"}
|
||||
@@ -711,25 +755,46 @@ def test_stateless_flow_event_emission():
|
||||
|
||||
flow = StatelessFlow()
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
all_events_received = threading.Event()
|
||||
expected_event_count = 6 # 1 FlowStarted + 2 MethodExecutionStarted + 2 MethodExecutionFinished + 1 FlowFinished
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def handle_flow_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def handle_method_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def handle_method_end(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def handle_flow_end(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
flow.kickoff()
|
||||
|
||||
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
|
||||
|
||||
# Sort events by timestamp to ensure deterministic order
|
||||
with lock:
|
||||
received_events.sort(key=lambda e: e.timestamp)
|
||||
|
||||
assert isinstance(received_events[0], FlowStartedEvent)
|
||||
assert received_events[0].flow_name == "StatelessFlow"
|
||||
assert received_events[0].inputs is None
|
||||
@@ -769,13 +834,16 @@ def test_flow_plotting():
|
||||
flow = StatelessFlow()
|
||||
flow.kickoff()
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(FlowPlotEvent)
|
||||
def handle_flow_plot(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
flow.plot("test_flow")
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for plot event"
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0], FlowPlotEvent)
|
||||
assert received_events[0].flow_name == "StatelessFlow"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -175,78 +176,92 @@ def test_task_guardrail_process_output(task_output):
|
||||
def test_guardrail_emits_events(sample_agent):
|
||||
started_guardrail = []
|
||||
completed_guardrail = []
|
||||
all_events_received = threading.Event()
|
||||
expected_started = 3 # 2 from first task, 1 from second
|
||||
expected_completed = 3 # 2 from first task, 1 from second
|
||||
|
||||
task = Task(
|
||||
task1 = Task(
|
||||
description="Gather information about available books on the First World War",
|
||||
agent=sample_agent,
|
||||
expected_output="A list of available books on the First World War",
|
||||
guardrail="Ensure the authors are from Italy",
|
||||
)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def handle_guardrail_started(source, event):
|
||||
assert source == task
|
||||
started_guardrail.append(
|
||||
{"guardrail": event.guardrail, "retry_count": event.retry_count}
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def handle_guardrail_completed(source, event):
|
||||
assert source == task
|
||||
completed_guardrail.append(
|
||||
{
|
||||
"success": event.success,
|
||||
"result": event.result,
|
||||
"error": event.error,
|
||||
"retry_count": event.retry_count,
|
||||
}
|
||||
)
|
||||
|
||||
result = task.execute_sync(agent=sample_agent)
|
||||
|
||||
def custom_guardrail(result: TaskOutput):
|
||||
return (True, "good result from callable function")
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Output",
|
||||
guardrail=custom_guardrail,
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def handle_guardrail_started(source, event):
|
||||
started_guardrail.append(
|
||||
{"guardrail": event.guardrail, "retry_count": event.retry_count}
|
||||
)
|
||||
if (
|
||||
len(started_guardrail) >= expected_started
|
||||
and len(completed_guardrail) >= expected_completed
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
task.execute_sync(agent=sample_agent)
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def handle_guardrail_completed(source, event):
|
||||
completed_guardrail.append(
|
||||
{
|
||||
"success": event.success,
|
||||
"result": event.result,
|
||||
"error": event.error,
|
||||
"retry_count": event.retry_count,
|
||||
}
|
||||
)
|
||||
if (
|
||||
len(started_guardrail) >= expected_started
|
||||
and len(completed_guardrail) >= expected_completed
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
expected_started_events = [
|
||||
{"guardrail": "Ensure the authors are from Italy", "retry_count": 0},
|
||||
{"guardrail": "Ensure the authors are from Italy", "retry_count": 1},
|
||||
{
|
||||
"guardrail": """def custom_guardrail(result: TaskOutput):
|
||||
return (True, "good result from callable function")""",
|
||||
"retry_count": 0,
|
||||
},
|
||||
]
|
||||
result = task1.execute_sync(agent=sample_agent)
|
||||
|
||||
expected_completed_events = [
|
||||
{
|
||||
"success": False,
|
||||
"result": None,
|
||||
"error": "The task result does not comply with the guardrail because none of "
|
||||
"the listed authors are from Italy. All authors mentioned are from "
|
||||
"different countries, including Germany, the UK, the USA, and others, "
|
||||
"which violates the requirement that authors must be Italian.",
|
||||
"retry_count": 0,
|
||||
},
|
||||
{"success": True, "result": result.raw, "error": None, "retry_count": 1},
|
||||
{
|
||||
"success": True,
|
||||
"result": "good result from callable function",
|
||||
"error": None,
|
||||
"retry_count": 0,
|
||||
},
|
||||
]
|
||||
assert started_guardrail == expected_started_events
|
||||
assert completed_guardrail == expected_completed_events
|
||||
def custom_guardrail(result: TaskOutput):
|
||||
return (True, "good result from callable function")
|
||||
|
||||
task2 = Task(
|
||||
description="Test task",
|
||||
expected_output="Output",
|
||||
guardrail=custom_guardrail,
|
||||
)
|
||||
|
||||
task2.execute_sync(agent=sample_agent)
|
||||
|
||||
# Wait for all events to be received
|
||||
assert all_events_received.wait(timeout=10), (
|
||||
"Timeout waiting for all guardrail events"
|
||||
)
|
||||
|
||||
expected_started_events = [
|
||||
{"guardrail": "Ensure the authors are from Italy", "retry_count": 0},
|
||||
{"guardrail": "Ensure the authors are from Italy", "retry_count": 1},
|
||||
{
|
||||
"guardrail": """def custom_guardrail(result: TaskOutput):
|
||||
return (True, "good result from callable function")""",
|
||||
"retry_count": 0,
|
||||
},
|
||||
]
|
||||
|
||||
expected_completed_events = [
|
||||
{
|
||||
"success": False,
|
||||
"result": None,
|
||||
"error": "The task result does not comply with the guardrail because none of "
|
||||
"the listed authors are from Italy. All authors mentioned are from "
|
||||
"different countries, including Germany, the UK, the USA, and others, "
|
||||
"which violates the requirement that authors must be Italian.",
|
||||
"retry_count": 0,
|
||||
},
|
||||
{"success": True, "result": result.raw, "error": None, "retry_count": 1},
|
||||
{
|
||||
"success": True,
|
||||
"result": "good result from callable function",
|
||||
"error": None,
|
||||
"retry_count": 0,
|
||||
},
|
||||
]
|
||||
assert started_guardrail == expected_started_events
|
||||
assert completed_guardrail == expected_completed_events
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import datetime
|
||||
import json
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -32,7 +33,7 @@ class RandomNumberTool(BaseTool):
|
||||
args_schema: type[BaseModel] = RandomNumberToolInput
|
||||
|
||||
def _run(self, min_value: int, max_value: int) -> int:
|
||||
return random.randint(min_value, max_value)
|
||||
return random.randint(min_value, max_value) # noqa: S311
|
||||
|
||||
|
||||
# Example agent and task
|
||||
@@ -470,13 +471,21 @@ def test_tool_selection_error_event_direct():
|
||||
)
|
||||
|
||||
received_events = []
|
||||
first_event_received = threading.Event()
|
||||
second_event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolSelectionErrorEvent)
|
||||
def event_handler(source, event):
|
||||
received_events.append(event)
|
||||
if event.tool_name == "Non Existent Tool":
|
||||
first_event_received.set()
|
||||
elif event.tool_name == "":
|
||||
second_event_received.set()
|
||||
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
tool_usage._select_tool("Non Existent Tool")
|
||||
|
||||
assert first_event_received.wait(timeout=5), "Timeout waiting for first event"
|
||||
assert len(received_events) == 1
|
||||
event = received_events[0]
|
||||
assert isinstance(event, ToolSelectionErrorEvent)
|
||||
@@ -488,12 +497,12 @@ def test_tool_selection_error_event_direct():
|
||||
assert "A test tool" in event.tool_class
|
||||
assert "don't exist" in event.error
|
||||
|
||||
received_events.clear()
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
tool_usage._select_tool("")
|
||||
|
||||
assert len(received_events) == 1
|
||||
event = received_events[0]
|
||||
assert second_event_received.wait(timeout=5), "Timeout waiting for second event"
|
||||
assert len(received_events) == 2
|
||||
event = received_events[1]
|
||||
assert isinstance(event, ToolSelectionErrorEvent)
|
||||
assert event.agent_key == "test_key"
|
||||
assert event.agent_role == "test_role"
|
||||
@@ -562,7 +571,7 @@ def test_tool_validate_input_error_event():
|
||||
|
||||
# Test invalid input
|
||||
invalid_input = "invalid json {[}"
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
tool_usage._validate_tool_input(invalid_input)
|
||||
|
||||
# Verify event was emitted
|
||||
@@ -616,12 +625,13 @@ def test_tool_usage_finished_event_with_result():
|
||||
action=MagicMock(),
|
||||
)
|
||||
|
||||
# Track received events
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolUsageFinishedEvent)
|
||||
def event_handler(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
# Call on_tool_use_finished with test data
|
||||
started_at = time.time()
|
||||
@@ -634,7 +644,7 @@ def test_tool_usage_finished_event_with_result():
|
||||
result=result,
|
||||
)
|
||||
|
||||
# Verify event was emitted
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for event"
|
||||
assert len(received_events) == 1, "Expected one event to be emitted"
|
||||
event = received_events[0]
|
||||
assert isinstance(event, ToolUsageFinishedEvent)
|
||||
@@ -695,12 +705,13 @@ def test_tool_usage_finished_event_with_cached_result():
|
||||
action=MagicMock(),
|
||||
)
|
||||
|
||||
# Track received events
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolUsageFinishedEvent)
|
||||
def event_handler(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
# Call on_tool_use_finished with test data and from_cache=True
|
||||
started_at = time.time()
|
||||
@@ -713,7 +724,7 @@ def test_tool_usage_finished_event_with_cached_result():
|
||||
result=result,
|
||||
)
|
||||
|
||||
# Verify event was emitted
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for event"
|
||||
assert len(received_events) == 1, "Expected one event to be emitted"
|
||||
event = received_events[0]
|
||||
assert isinstance(event, ToolUsageFinishedEvent)
|
||||
|
||||
@@ -14,6 +14,7 @@ from crewai.events.listeners.tracing.trace_listener import (
|
||||
)
|
||||
from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.flow.flow import Flow, start
|
||||
from tests.utils import wait_for_event_handlers
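The wait_for_event_handlers helper imported above is not shown in this diff. A minimal sketch of what such a helper might look like, assuming the bus tracks its in-flight handler work in a collection (the _pending_tasks attribute below is a hypothetical name used only for illustration):

import time

from crewai.events.event_bus import crewai_event_bus


def wait_for_event_handlers(timeout: float = 5.0) -> None:
    """Block until the event bus has no pending handler work or the timeout expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        pending = getattr(crewai_event_bus, "_pending_tasks", set())  # assumed attribute
        if not pending:
            return
        time.sleep(0.01)
    raise TimeoutError("event handlers did not finish within the timeout")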
|
||||
|
||||
|
||||
class TestTraceListenerSetup:
|
||||
@@ -39,38 +40,44 @@ class TestTraceListenerSetup:
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_event_bus(self):
|
||||
"""Clear event bus listeners before and after each test"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
# Store original handlers
|
||||
original_handlers = crewai_event_bus._handlers.copy()
|
||||
|
||||
# Clear for test
|
||||
crewai_event_bus._handlers.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Restore original state
|
||||
crewai_event_bus._handlers.clear()
|
||||
crewai_event_bus._handlers.update(original_handlers)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_tracing_singletons(self):
|
||||
"""Reset tracing singleton instances between tests"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
|
||||
# Clear event bus handlers BEFORE creating any new singletons
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
crewai_event_bus._handler_dependencies = {}
|
||||
crewai_event_bus._execution_plan_cache = {}
|
||||
|
||||
# Reset TraceCollectionListener singleton
|
||||
if hasattr(TraceCollectionListener, "_instance"):
|
||||
TraceCollectionListener._instance = None
|
||||
TraceCollectionListener._initialized = False
|
||||
|
||||
# Reset EventListener singleton
|
||||
if hasattr(EventListener, "_instance"):
|
||||
EventListener._instance = None
|
||||
|
||||
yield
|
||||
|
||||
# Clean up after test
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
crewai_event_bus._handler_dependencies = {}
|
||||
crewai_event_bus._execution_plan_cache = {}
|
||||
|
||||
if hasattr(TraceCollectionListener, "_instance"):
|
||||
TraceCollectionListener._instance = None
|
||||
TraceCollectionListener._initialized = False
|
||||
|
||||
if hasattr(EventListener, "_instance"):
|
||||
EventListener._instance = None
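The fixture above also clears _handler_dependencies and _execution_plan_cache, which suggests the bus orders handlers by declared dependencies and caches the resulting execution plan per event type. A rough sketch of that idea, purely illustrative and not taken from this commit:

from graphlib import TopologicalSorter


def build_execution_plan(handlers, dependencies):
    # dependencies maps a handler to the set of handlers that must run before it;
    # the return value is one valid execution order honoring those constraints.
    graph = {handler: dependencies.get(handler, set()) for handler in handlers}
    return list(TopologicalSorter(graph).static_order())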
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_plus_api_calls(self):
|
||||
"""Mock all PlusAPI HTTP calls to avoid network requests"""
|
||||
@@ -167,15 +174,26 @@ class TestTraceListenerSetup:
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
trace_listener = None
|
||||
for handler_list in crewai_event_bus._handlers.values():
|
||||
for handler in handler_list:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_listener = handler.__self__
|
||||
with crewai_event_bus._rwlock.r_locked():
|
||||
for handler_set in crewai_event_bus._sync_handlers.values():
|
||||
for handler in handler_set:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_listener = handler.__self__
|
||||
break
|
||||
if trace_listener:
|
||||
break
|
||||
if trace_listener:
|
||||
break
|
||||
if not trace_listener:
|
||||
for handler_set in crewai_event_bus._async_handlers.values():
|
||||
for handler in handler_set:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_listener = handler.__self__
|
||||
break
|
||||
if trace_listener:
|
||||
break
|
||||
|
||||
if not trace_listener:
|
||||
pytest.skip(
|
||||
@@ -221,6 +239,7 @@ class TestTraceListenerSetup:
|
||||
wraps=trace_listener.batch_manager.add_event,
|
||||
) as add_event_mock:
|
||||
crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
|
||||
assert add_event_mock.call_count >= 2
|
||||
|
||||
@@ -267,24 +286,22 @@ class TestTraceListenerSetup:
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
trace_handlers = []
|
||||
for handlers in crewai_event_bus._handlers.values():
|
||||
for handler in handlers:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_handlers.append(handler)
|
||||
elif hasattr(handler, "__name__") and any(
|
||||
trace_name in handler.__name__
|
||||
for trace_name in [
|
||||
"on_crew_started",
|
||||
"on_crew_completed",
|
||||
"on_flow_started",
|
||||
]
|
||||
):
|
||||
trace_handlers.append(handler)
|
||||
with crewai_event_bus._rwlock.r_locked():
|
||||
for handlers in crewai_event_bus._sync_handlers.values():
|
||||
for handler in handlers:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_handlers.append(handler)
|
||||
for handlers in crewai_event_bus._async_handlers.values():
|
||||
for handler in handlers:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_handlers.append(handler)
|
||||
|
||||
assert len(trace_handlers) == 0, (
|
||||
f"Found {len(trace_handlers)} trace handlers when tracing should be disabled"
|
||||
f"Found {len(trace_handlers)} TraceCollectionListener handlers when tracing should be disabled"
|
||||
)
|
||||
|
||||
def test_trace_listener_setup_correctly_for_crew(self):
|
||||
@@ -385,6 +402,7 @@ class TestTraceListenerSetup:
|
||||
):
|
||||
crew = Crew(agents=[agent], tasks=[task], tracing=True)
|
||||
crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
|
||||
mock_plus_api_class.assert_called_with(api_key="mock_token_12345")
|
||||
|
||||
@@ -396,15 +414,33 @@ class TestTraceListenerSetup:
|
||||
def teardown_method(self):
|
||||
"""Cleanup after each test method"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
|
||||
crewai_event_bus._handlers.clear()
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
crewai_event_bus._handler_dependencies = {}
|
||||
crewai_event_bus._execution_plan_cache = {}
|
||||
|
||||
# Reset EventListener singleton
|
||||
if hasattr(EventListener, "_instance"):
|
||||
EventListener._instance = None
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
"""Final cleanup after all tests in this class"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
|
||||
crewai_event_bus._handlers.clear()
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
crewai_event_bus._handler_dependencies = {}
|
||||
crewai_event_bus._execution_plan_cache = {}
|
||||
|
||||
# Reset EventListener singleton
|
||||
if hasattr(EventListener, "_instance"):
|
||||
EventListener._instance = None
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_first_time_user_trace_collection_with_timeout(self, mock_plus_api_calls):
|
||||
@@ -466,6 +502,7 @@ class TestTraceListenerSetup:
|
||||
) as mock_add_event,
|
||||
):
|
||||
result = crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
assert result is not None
|
||||
|
||||
assert mock_handle_completion.call_count >= 1
|
||||
@@ -543,6 +580,7 @@ class TestTraceListenerSetup:
|
||||
)
|
||||
|
||||
crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
|
||||
assert mock_handle_completion.call_count >= 1, (
|
||||
"handle_execution_completion should be called"
|
||||
@@ -561,7 +599,6 @@ class TestTraceListenerSetup:
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_first_time_user_trace_consolidation_logic(self, mock_plus_api_calls):
|
||||
"""Test the consolidation logic for first-time users vs regular tracing"""
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
|
||||
patch(
|
||||
@@ -579,7 +616,9 @@ class TestTraceListenerSetup:
|
||||
):
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
crewai_event_bus._handlers.clear()
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
|
||||
trace_listener = TraceCollectionListener()
|
||||
trace_listener.setup_listeners(crewai_event_bus)
|
||||
@@ -600,6 +639,9 @@ class TestTraceListenerSetup:
|
||||
with patch.object(TraceBatchManager, "initialize_batch") as mock_initialize:
|
||||
result = crew.kickoff()
|
||||
|
||||
assert trace_listener.batch_manager.wait_for_pending_events(timeout=5.0), (
|
||||
"Timeout waiting for trace event handlers to complete"
|
||||
)
|
||||
assert mock_initialize.call_count >= 1
|
||||
assert mock_initialize.call_args_list[0][1]["use_ephemeral"] is True
|
||||
assert result is not None
|
||||
@@ -700,6 +742,7 @@ class TestTraceListenerSetup:
|
||||
) as mock_mark_failed,
|
||||
):
|
||||
crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
|
||||
mock_mark_failed.assert_called_once()
|
||||
call_args = mock_mark_failed.call_args_list[0]
|
||||
|
||||
lib/crewai/tests/utilities/events/test_async_event_bus.py (new file, 206 lines)
@@ -0,0 +1,206 @@
|
||||
"""Tests for async event handling in CrewAI event bus.
|
||||
|
||||
This module tests async handler registration, execution, and the aemit method.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
|
||||
class AsyncTestEvent(BaseEvent):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_handler_execution():
|
||||
received_events = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
received_events.append(event)
|
||||
|
||||
event = AsyncTestEvent(type="async_test")
|
||||
crewai_event_bus.emit("test_source", event)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0] == event
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aemit_with_async_handlers():
|
||||
received_events = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
received_events.append(event)
|
||||
|
||||
event = AsyncTestEvent(type="async_test")
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0] == event
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_async_handlers():
|
||||
received_events_1 = []
|
||||
received_events_2 = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def handler_1(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
received_events_1.append(event)
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def handler_2(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.02)
|
||||
received_events_2.append(event)
|
||||
|
||||
event = AsyncTestEvent(type="async_test")
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
assert len(received_events_1) == 1
|
||||
assert len(received_events_2) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_sync_and_async_handlers():
|
||||
sync_events = []
|
||||
async_events = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
def sync_handler(source: object, event: BaseEvent) -> None:
|
||||
sync_events.append(event)
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
async_events.append(event)
|
||||
|
||||
event = AsyncTestEvent(type="mixed_test")
|
||||
crewai_event_bus.emit("test_source", event)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(sync_events) == 1
|
||||
assert len(async_events) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_handler_error_handling():
|
||||
successful_handler_called = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def failing_handler(source: object, event: BaseEvent) -> None:
|
||||
raise ValueError("Async handler error")
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def successful_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
successful_handler_called.append(True)
|
||||
|
||||
event = AsyncTestEvent(type="error_test")
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
assert len(successful_handler_called) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aemit_with_no_handlers():
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
event = AsyncTestEvent(type="no_handlers")
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_handler_registration_via_register_handler():
|
||||
received_events = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
async def custom_async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
received_events.append(event)
|
||||
|
||||
crewai_event_bus.register_handler(AsyncTestEvent, custom_async_handler)
|
||||
|
||||
event = AsyncTestEvent(type="register_test")
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0] == event
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_async_handlers_fire_and_forget():
|
||||
received_events = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def slow_async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.05)
|
||||
received_events.append(event)
|
||||
|
||||
event = AsyncTestEvent(type="fire_forget_test")
|
||||
crewai_event_bus.emit("test_source", event)
|
||||
|
||||
assert len(received_events) == 0
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(received_events) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoped_handlers_with_async():
|
||||
received_before = []
|
||||
received_during = []
|
||||
received_after = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def before_handler(source: object, event: BaseEvent) -> None:
|
||||
received_before.append(event)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def scoped_handler(source: object, event: BaseEvent) -> None:
|
||||
received_during.append(event)
|
||||
|
||||
event1 = AsyncTestEvent(type="during_scope")
|
||||
await crewai_event_bus.aemit("test_source", event1)
|
||||
|
||||
assert len(received_before) == 0
|
||||
assert len(received_during) == 1
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def after_handler(source: object, event: BaseEvent) -> None:
|
||||
received_after.append(event)
|
||||
|
||||
event2 = AsyncTestEvent(type="after_scope")
|
||||
await crewai_event_bus.aemit("test_source", event2)
|
||||
|
||||
assert len(received_before) == 1
|
||||
assert len(received_during) == 1
|
||||
assert len(received_after) == 1
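The scoped_handlers() context manager used throughout these tests is not part of this diff. Judging from the fixtures earlier in the diff that touch _sync_handlers and _async_handlers, the idea could look like this sketch (the snapshot/restore details are assumptions):

from contextlib import contextmanager


@contextmanager
def scoped_handlers_sketch(bus):
    # Snapshot the current registries, run the block with a clean slate,
    # then restore whatever was registered before the scope was entered.
    saved_sync = {k: set(v) for k, v in bus._sync_handlers.items()}
    saved_async = {k: set(v) for k, v in bus._async_handlers.items()}
    bus._sync_handlers, bus._async_handlers = {}, {}
    try:
        yield
    finally:
        bus._sync_handlers, bus._async_handlers = saved_sync, saved_async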
|
||||
@@ -1,3 +1,4 @@
|
||||
import threading
|
||||
from unittest.mock import Mock
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
@@ -21,27 +22,42 @@ def test_specific_event_handler():
|
||||
mock_handler.assert_called_once_with("source_object", event)
|
||||
|
||||
|
||||
def test_wildcard_event_handler():
|
||||
mock_handler = Mock()
|
||||
def test_multiple_handlers_same_event():
|
||||
"""Test that multiple handlers can be registered for the same event type."""
|
||||
mock_handler1 = Mock()
|
||||
mock_handler2 = Mock()
|
||||
|
||||
@crewai_event_bus.on(BaseEvent)
|
||||
def handler(source, event):
|
||||
mock_handler(source, event)
|
||||
@crewai_event_bus.on(TestEvent)
|
||||
def handler1(source, event):
|
||||
mock_handler1(source, event)
|
||||
|
||||
@crewai_event_bus.on(TestEvent)
|
||||
def handler2(source, event):
|
||||
mock_handler2(source, event)
|
||||
|
||||
event = TestEvent(type="test_event")
|
||||
crewai_event_bus.emit("source_object", event)
|
||||
|
||||
mock_handler.assert_called_once_with("source_object", event)
|
||||
mock_handler1.assert_called_once_with("source_object", event)
|
||||
mock_handler2.assert_called_once_with("source_object", event)
|
||||
|
||||
|
||||
def test_event_bus_error_handling(capfd):
|
||||
@crewai_event_bus.on(BaseEvent)
|
||||
def test_event_bus_error_handling():
|
||||
"""Test that handler exceptions are caught and don't break the event bus."""
|
||||
called = threading.Event()
|
||||
error_caught = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(TestEvent)
|
||||
def broken_handler(source, event):
|
||||
called.set()
|
||||
raise ValueError("Simulated handler failure")
|
||||
|
||||
@crewai_event_bus.on(TestEvent)
|
||||
def working_handler(source, event):
|
||||
error_caught.set()
|
||||
|
||||
event = TestEvent(type="test_event")
|
||||
crewai_event_bus.emit("source_object", event)
|
||||
|
||||
out, err = capfd.readouterr()
|
||||
assert "Simulated handler failure" in out
|
||||
assert "Handler 'broken_handler' failed" in out
|
||||
assert called.wait(timeout=2), "Broken handler was never called"
|
||||
assert error_caught.wait(timeout=2), "Working handler was never called after error"
|
||||
|
||||
lib/crewai/tests/utilities/events/test_rw_lock.py (new file, 264 lines)
@@ -0,0 +1,264 @@
|
||||
"""Tests for read-write lock implementation.
|
||||
|
||||
This module tests the RWLock class for correct concurrent read and write behavior.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from crewai.events.utils.rw_lock import RWLock
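The RWLock implementation itself is not included in this diff. The following writer-exclusive sketch matches the API the tests below exercise (r_acquire/r_release, w_acquire/w_release, and the r_locked/w_locked context managers); the internals are an assumption, not the shipped code:

import threading
from contextlib import contextmanager


class RWLockSketch:
    """Illustrative reader-writer lock; the real crewai RWLock may differ."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._readers = 0      # readers currently holding the lock
        self._writer = False   # True while a writer holds the lock

    def r_acquire(self) -> None:
        with self._cond:
            while self._writer:
                self._cond.wait()
            self._readers += 1

    def r_release(self) -> None:
        with self._cond:
            self._readers -= 1
            if self._readers == 0:
                self._cond.notify_all()

    def w_acquire(self) -> None:
        with self._cond:
            while self._writer or self._readers:
                self._cond.wait()
            self._writer = True

    def w_release(self) -> None:
        with self._cond:
            self._writer = False
            self._cond.notify_all()

    @contextmanager
    def r_locked(self):
        self.r_acquire()
        try:
            yield
        finally:
            self.r_release()

    @contextmanager
    def w_locked(self):
        self.w_acquire()
        try:
            yield
        finally:
            self.w_release()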
|
||||
|
||||
|
||||
def test_multiple_readers_concurrent():
|
||||
lock = RWLock()
|
||||
active_readers = [0]
|
||||
max_concurrent_readers = [0]
|
||||
lock_for_counters = threading.Lock()
|
||||
|
||||
def reader(reader_id: int) -> None:
|
||||
with lock.r_locked():
|
||||
with lock_for_counters:
|
||||
active_readers[0] += 1
|
||||
max_concurrent_readers[0] = max(
|
||||
max_concurrent_readers[0], active_readers[0]
|
||||
)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
with lock_for_counters:
|
||||
active_readers[0] -= 1
|
||||
|
||||
threads = [threading.Thread(target=reader, args=(i,)) for i in range(5)]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert max_concurrent_readers[0] == 5
|
||||
|
||||
|
||||
def test_writer_blocks_readers():
|
||||
lock = RWLock()
|
||||
writer_holding_lock = [False]
|
||||
reader_accessed_during_write = [False]
|
||||
|
||||
def writer() -> None:
|
||||
with lock.w_locked():
|
||||
writer_holding_lock[0] = True
|
||||
time.sleep(0.2)
|
||||
writer_holding_lock[0] = False
|
||||
|
||||
def reader() -> None:
|
||||
time.sleep(0.05)
|
||||
with lock.r_locked():
|
||||
if writer_holding_lock[0]:
|
||||
reader_accessed_during_write[0] = True
|
||||
|
||||
writer_thread = threading.Thread(target=writer)
|
||||
reader_thread = threading.Thread(target=reader)
|
||||
|
||||
writer_thread.start()
|
||||
reader_thread.start()
|
||||
|
||||
writer_thread.join()
|
||||
reader_thread.join()
|
||||
|
||||
assert not reader_accessed_during_write[0]
|
||||
|
||||
|
||||
def test_writer_blocks_other_writers():
|
||||
lock = RWLock()
|
||||
execution_order: list[int] = []
|
||||
lock_for_order = threading.Lock()
|
||||
|
||||
def writer(writer_id: int) -> None:
|
||||
with lock.w_locked():
|
||||
with lock_for_order:
|
||||
execution_order.append(writer_id)
|
||||
time.sleep(0.1)
|
||||
|
||||
threads = [threading.Thread(target=writer, args=(i,)) for i in range(3)]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert len(execution_order) == 3
|
||||
assert len(set(execution_order)) == 3
|
||||
|
||||
|
||||
def test_readers_block_writers():
|
||||
lock = RWLock()
|
||||
reader_count = [0]
|
||||
writer_accessed_during_read = [False]
|
||||
lock_for_counters = threading.Lock()
|
||||
|
||||
def reader() -> None:
|
||||
with lock.r_locked():
|
||||
with lock_for_counters:
|
||||
reader_count[0] += 1
|
||||
time.sleep(0.2)
|
||||
with lock_for_counters:
|
||||
reader_count[0] -= 1
|
||||
|
||||
def writer() -> None:
|
||||
time.sleep(0.05)
|
||||
with lock.w_locked():
|
||||
with lock_for_counters:
|
||||
if reader_count[0] > 0:
|
||||
writer_accessed_during_read[0] = True
|
||||
|
||||
reader_thread = threading.Thread(target=reader)
|
||||
writer_thread = threading.Thread(target=writer)
|
||||
|
||||
reader_thread.start()
|
||||
writer_thread.start()
|
||||
|
||||
reader_thread.join()
|
||||
writer_thread.join()
|
||||
|
||||
assert not writer_accessed_during_read[0]
|
||||
|
||||
|
||||
def test_alternating_readers_and_writers():
|
||||
lock = RWLock()
|
||||
operations: list[str] = []
|
||||
lock_for_operations = threading.Lock()
|
||||
|
||||
def reader(reader_id: int) -> None:
|
||||
with lock.r_locked():
|
||||
with lock_for_operations:
|
||||
operations.append(f"r{reader_id}_start")
|
||||
time.sleep(0.05)
|
||||
with lock_for_operations:
|
||||
operations.append(f"r{reader_id}_end")
|
||||
|
||||
def writer(writer_id: int) -> None:
|
||||
with lock.w_locked():
|
||||
with lock_for_operations:
|
||||
operations.append(f"w{writer_id}_start")
|
||||
time.sleep(0.05)
|
||||
with lock_for_operations:
|
||||
operations.append(f"w{writer_id}_end")
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=reader, args=(0,)),
|
||||
threading.Thread(target=writer, args=(0,)),
|
||||
threading.Thread(target=reader, args=(1,)),
|
||||
threading.Thread(target=writer, args=(1,)),
|
||||
threading.Thread(target=reader, args=(2,)),
|
||||
]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert len(operations) == 10
|
||||
|
||||
start_ops = [op for op in operations if "_start" in op]
|
||||
end_ops = [op for op in operations if "_end" in op]
|
||||
assert len(start_ops) == 5
|
||||
assert len(end_ops) == 5
|
||||
|
||||
|
||||
def test_context_manager_releases_on_exception():
|
||||
lock = RWLock()
|
||||
exception_raised = False
|
||||
|
||||
try:
|
||||
with lock.r_locked():
|
||||
raise ValueError("Test exception")
|
||||
except ValueError:
|
||||
exception_raised = True
|
||||
|
||||
assert exception_raised
|
||||
|
||||
acquired = False
|
||||
with lock.w_locked():
|
||||
acquired = True
|
||||
|
||||
assert acquired
|
||||
|
||||
|
||||
def test_write_lock_releases_on_exception():
|
||||
lock = RWLock()
|
||||
exception_raised = False
|
||||
|
||||
try:
|
||||
with lock.w_locked():
|
||||
raise ValueError("Test exception")
|
||||
except ValueError:
|
||||
exception_raised = True
|
||||
|
||||
assert exception_raised
|
||||
|
||||
acquired = False
|
||||
with lock.r_locked():
|
||||
acquired = True
|
||||
|
||||
assert acquired
|
||||
|
||||
|
||||
def test_stress_many_readers_few_writers():
|
||||
lock = RWLock()
|
||||
read_count = [0]
|
||||
write_count = [0]
|
||||
lock_for_counters = threading.Lock()
|
||||
|
||||
def reader() -> None:
|
||||
for _ in range(10):
|
||||
with lock.r_locked():
|
||||
with lock_for_counters:
|
||||
read_count[0] += 1
|
||||
time.sleep(0.001)
|
||||
|
||||
def writer() -> None:
|
||||
for _ in range(5):
|
||||
with lock.w_locked():
|
||||
with lock_for_counters:
|
||||
write_count[0] += 1
|
||||
time.sleep(0.01)
|
||||
|
||||
reader_threads = [threading.Thread(target=reader) for _ in range(10)]
|
||||
writer_threads = [threading.Thread(target=writer) for _ in range(2)]
|
||||
|
||||
all_threads = reader_threads + writer_threads
|
||||
|
||||
for thread in all_threads:
|
||||
thread.start()
|
||||
|
||||
for thread in all_threads:
|
||||
thread.join()
|
||||
|
||||
assert read_count[0] == 100
|
||||
assert write_count[0] == 10
|
||||
|
||||
|
||||
def test_nested_read_locks_same_thread():
|
||||
lock = RWLock()
|
||||
nested_acquired = False
|
||||
|
||||
with lock.r_locked():
|
||||
with lock.r_locked():
|
||||
nested_acquired = True
|
||||
|
||||
assert nested_acquired
|
||||
|
||||
|
||||
def test_manual_acquire_release():
|
||||
lock = RWLock()
|
||||
|
||||
lock.r_acquire()
|
||||
lock.r_release()
|
||||
|
||||
lock.w_acquire()
|
||||
lock.w_release()
|
||||
|
||||
with lock.r_locked():
|
||||
pass
|
||||
lib/crewai/tests/utilities/events/test_shutdown.py (new file, 247 lines)
@@ -0,0 +1,247 @@
|
||||
"""Tests for event bus shutdown and cleanup behavior.
|
||||
|
||||
This module tests graceful shutdown, task completion, and cleanup operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
|
||||
|
||||
class ShutdownTestEvent(BaseEvent):
|
||||
pass
|
||||
|
||||
|
||||
def test_shutdown_prevents_new_events():
|
||||
bus = CrewAIEventsBus()
|
||||
received_events = []
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
received_events.append(event)
|
||||
|
||||
bus._shutting_down = True
|
||||
|
||||
event = ShutdownTestEvent(type="after_shutdown")
|
||||
bus.emit("test_source", event)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(received_events) == 0
|
||||
|
||||
bus._shutting_down = False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aemit_during_shutdown():
|
||||
bus = CrewAIEventsBus()
|
||||
received_events = []
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
async def handler(source: object, event: BaseEvent) -> None:
|
||||
received_events.append(event)
|
||||
|
||||
bus._shutting_down = True
|
||||
|
||||
event = ShutdownTestEvent(type="aemit_during_shutdown")
|
||||
await bus.aemit("test_source", event)
|
||||
|
||||
assert len(received_events) == 0
|
||||
|
||||
bus._shutting_down = False
|
||||
|
||||
|
||||
def test_shutdown_flag_prevents_emit():
|
||||
bus = CrewAIEventsBus()
|
||||
emitted_count = [0]
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
emitted_count[0] += 1
|
||||
|
||||
event1 = ShutdownTestEvent(type="before_shutdown")
|
||||
bus.emit("test_source", event1)
|
||||
|
||||
time.sleep(0.1)
|
||||
assert emitted_count[0] == 1
|
||||
|
||||
bus._shutting_down = True
|
||||
|
||||
event2 = ShutdownTestEvent(type="during_shutdown")
|
||||
bus.emit("test_source", event2)
|
||||
|
||||
time.sleep(0.1)
|
||||
assert emitted_count[0] == 1
|
||||
|
||||
bus._shutting_down = False
|
||||
|
||||
|
||||
def test_concurrent_access_during_shutdown_flag():
|
||||
bus = CrewAIEventsBus()
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
|
||||
def emit_events() -> None:
|
||||
for i in range(10):
|
||||
event = ShutdownTestEvent(type=f"event_{i}")
|
||||
bus.emit("source", event)
|
||||
time.sleep(0.01)
|
||||
|
||||
def set_shutdown_flag() -> None:
|
||||
time.sleep(0.05)
|
||||
bus._shutting_down = True
|
||||
|
||||
emit_thread = threading.Thread(target=emit_events)
|
||||
shutdown_thread = threading.Thread(target=set_shutdown_flag)
|
||||
|
||||
emit_thread.start()
|
||||
shutdown_thread.start()
|
||||
|
||||
emit_thread.join()
|
||||
shutdown_thread.join()
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
assert len(received_events) < 10
|
||||
assert len(received_events) > 0
|
||||
|
||||
bus._shutting_down = False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_handlers_complete_before_shutdown_flag():
|
||||
bus = CrewAIEventsBus()
|
||||
completed_handlers = []
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
async def async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.05)
|
||||
if not bus._shutting_down:
|
||||
completed_handlers.append(event)
|
||||
|
||||
for i in range(5):
|
||||
event = ShutdownTestEvent(type=f"event_{i}")
|
||||
bus.emit("source", event)
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
assert len(completed_handlers) == 5
|
||||
|
||||
|
||||
def test_scoped_handlers_cleanup():
|
||||
bus = CrewAIEventsBus()
|
||||
received_before = []
|
||||
received_during = []
|
||||
received_after = []
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def before_handler(source: object, event: BaseEvent) -> None:
|
||||
received_before.append(event)
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def during_handler(source: object, event: BaseEvent) -> None:
|
||||
received_during.append(event)
|
||||
|
||||
event1 = ShutdownTestEvent(type="during")
|
||||
bus.emit("source", event1)
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(received_before) == 0
|
||||
assert len(received_during) == 1
|
||||
|
||||
event2 = ShutdownTestEvent(type="after_inner_scope")
|
||||
bus.emit("source", event2)
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(received_before) == 1
|
||||
assert len(received_during) == 1
|
||||
|
||||
event3 = ShutdownTestEvent(type="after_outer_scope")
|
||||
bus.emit("source", event3)
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(received_before) == 1
|
||||
assert len(received_during) == 1
|
||||
assert len(received_after) == 0
|
||||
|
||||
|
||||
def test_handler_registration_thread_safety():
|
||||
bus = CrewAIEventsBus()
|
||||
handlers_registered = [0]
|
||||
lock = threading.Lock()
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
def register_handlers() -> None:
|
||||
for _ in range(20):
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
pass
|
||||
|
||||
with lock:
|
||||
handlers_registered[0] += 1
|
||||
|
||||
time.sleep(0.001)
|
||||
|
||||
threads = [threading.Thread(target=register_handlers) for _ in range(3)]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert handlers_registered[0] == 60
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_sync_async_handler_execution():
|
||||
bus = CrewAIEventsBus()
|
||||
sync_executed = []
|
||||
async_executed = []
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def sync_handler(source: object, event: BaseEvent) -> None:
|
||||
time.sleep(0.01)
|
||||
sync_executed.append(event)
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
async def async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
async_executed.append(event)
|
||||
|
||||
for i in range(5):
|
||||
event = ShutdownTestEvent(type=f"event_{i}")
|
||||
bus.emit("source", event)
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
assert len(sync_executed) == 5
|
||||
assert len(async_executed) == 5
|
||||
lib/crewai/tests/utilities/events/test_thread_safety.py (new file, 189 lines)
@@ -0,0 +1,189 @@
|
||||
"""Tests for thread safety in CrewAI event bus.
|
||||
|
||||
This module tests concurrent event emission and handler registration.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
|
||||
class ThreadSafetyTestEvent(BaseEvent):
|
||||
pass
|
||||
|
||||
|
||||
def test_concurrent_emit_from_multiple_threads():
|
||||
received_events: list[BaseEvent] = []
|
||||
lock = threading.Lock()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(ThreadSafetyTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
|
||||
threads: list[threading.Thread] = []
|
||||
num_threads = 10
|
||||
events_per_thread = 10
|
||||
|
||||
def emit_events(thread_id: int) -> None:
|
||||
for i in range(events_per_thread):
|
||||
event = ThreadSafetyTestEvent(type=f"thread_{thread_id}_event_{i}")
|
||||
crewai_event_bus.emit(f"source_{thread_id}", event)
|
||||
|
||||
for i in range(num_threads):
|
||||
thread = threading.Thread(target=emit_events, args=(i,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
assert len(received_events) == num_threads * events_per_thread
|
||||
|
||||
|
||||
def test_concurrent_handler_registration():
|
||||
handlers_executed: list[int] = []
|
||||
lock = threading.Lock()
|
||||
|
||||
def create_handler(handler_id: int) -> Callable[[object, BaseEvent], None]:
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
handlers_executed.append(handler_id)
|
||||
|
||||
return handler
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
threads: list[threading.Thread] = []
|
||||
num_handlers = 20
|
||||
|
||||
def register_handler(handler_id: int) -> None:
|
||||
crewai_event_bus.register_handler(
|
||||
ThreadSafetyTestEvent, create_handler(handler_id)
|
||||
)
|
||||
|
||||
for i in range(num_handlers):
|
||||
thread = threading.Thread(target=register_handler, args=(i,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
event = ThreadSafetyTestEvent(type="registration_test")
|
||||
crewai_event_bus.emit("test_source", event)
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
assert len(handlers_executed) == num_handlers
|
||||
assert set(handlers_executed) == set(range(num_handlers))
|
||||
|
||||
|
||||
def test_concurrent_emit_and_registration():
|
||||
received_events: list[BaseEvent] = []
|
||||
lock = threading.Lock()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
def emit_continuously() -> None:
|
||||
for i in range(50):
|
||||
event = ThreadSafetyTestEvent(type=f"emit_event_{i}")
|
||||
crewai_event_bus.emit("emitter", event)
|
||||
time.sleep(0.001)
|
||||
|
||||
def register_continuously() -> None:
|
||||
for _ in range(10):
|
||||
|
||||
@crewai_event_bus.on(ThreadSafetyTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
emit_thread = threading.Thread(target=emit_continuously)
|
||||
register_thread = threading.Thread(target=register_continuously)
|
||||
|
||||
emit_thread.start()
|
||||
register_thread.start()
|
||||
|
||||
emit_thread.join()
|
||||
register_thread.join()
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
assert len(received_events) > 0
|
||||
|
||||
|
||||
def test_stress_test_rapid_emit():
|
||||
received_count = [0]
|
||||
lock = threading.Lock()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(ThreadSafetyTestEvent)
|
||||
def counter_handler(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_count[0] += 1
|
||||
|
||||
num_events = 1000
|
||||
|
||||
for i in range(num_events):
|
||||
event = ThreadSafetyTestEvent(type=f"rapid_event_{i}")
|
||||
crewai_event_bus.emit("rapid_source", event)
|
||||
|
||||
time.sleep(1.0)
|
||||
|
||||
assert received_count[0] == num_events
|
||||
|
||||
|
||||
def test_multiple_event_types_concurrent():
|
||||
class EventTypeA(BaseEvent):
|
||||
pass
|
||||
|
||||
class EventTypeB(BaseEvent):
|
||||
pass
|
||||
|
||||
received_a: list[BaseEvent] = []
|
||||
received_b: list[BaseEvent] = []
|
||||
lock = threading.Lock()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(EventTypeA)
|
||||
def handler_a(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_a.append(event)
|
||||
|
||||
@crewai_event_bus.on(EventTypeB)
|
||||
def handler_b(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_b.append(event)
|
||||
|
||||
def emit_type_a() -> None:
|
||||
for i in range(50):
|
||||
crewai_event_bus.emit("source_a", EventTypeA(type=f"type_a_{i}"))
|
||||
|
||||
def emit_type_b() -> None:
|
||||
for i in range(50):
|
||||
crewai_event_bus.emit("source_b", EventTypeB(type=f"type_b_{i}"))
|
||||
|
||||
thread_a = threading.Thread(target=emit_type_a)
|
||||
thread_b = threading.Thread(target=emit_type_b)
|
||||
|
||||
thread_a.start()
|
||||
thread_b.start()
|
||||
|
||||
thread_a.join()
|
||||
thread_b.join()
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
assert len(received_a) == 50
|
||||
assert len(received_b) == 50
|
||||
@@ -1,3 +1,4 @@
|
||||
import threading
|
||||
from datetime import datetime
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
@@ -49,6 +50,8 @@ from crewai.tools.base_tool import BaseTool
|
||||
from pydantic import Field
|
||||
import pytest
|
||||
|
||||
from ..utils import wait_for_event_handlers
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_config(request) -> dict:
|
||||
@@ -118,6 +121,7 @@ def test_crew_emits_start_kickoff_event(
|
||||
# Now when Crew creates EventListener, it will use our mocked telemetry
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
|
||||
mock_telemetry.crew_execution_span.assert_called_once_with(crew, None)
|
||||
mock_telemetry.end_crew.assert_called_once_with(crew, "hi")
|
||||
@@ -118,6 +121,7 @@ def test_crew_emits_start_kickoff_event(
        # Now when Crew creates EventListener, it will use our mocked telemetry
        crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
        crew.kickoff()
        wait_for_event_handlers()

        mock_telemetry.crew_execution_span.assert_called_once_with(crew, None)
        mock_telemetry.end_crew.assert_called_once_with(crew, "hi")
@@ -131,15 +135,20 @@ def test_crew_emits_start_kickoff_event(
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_end_kickoff_event(base_agent, base_task):
    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(CrewKickoffCompletedEvent)
    def handle_crew_end(source, event):
        received_events.append(event)
        event_received.set()

    crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")

    crew.kickoff()

    assert event_received.wait(timeout=5), (
        "Timeout waiting for crew kickoff completed event"
    )
    assert len(received_events) == 1
    assert received_events[0].crew_name == "TestCrew"
    assert isinstance(received_events[0].timestamp, datetime)
@@ -165,6 +174,7 @@ def test_crew_emits_test_kickoff_type_event(base_agent, base_task):
    eval_llm = LLM(model="gpt-4o-mini")
    crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
    crew.test(n_iterations=1, eval_llm=eval_llm)
    wait_for_event_handlers()

    assert len(received_events) == 3
    assert received_events[0].crew_name == "TestCrew"
@@ -181,40 +191,44 @@ def test_crew_emits_test_kickoff_type_event(base_agent, base_task):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_kickoff_failed_event(base_agent, base_task):
    received_events = []
    event_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(CrewKickoffFailedEvent)
        def handle_crew_failed(source, event):
            received_events.append(event)
            event_received.set()

        crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")

        with patch.object(Crew, "_execute_tasks") as mock_execute:
            error_message = "Simulated crew kickoff failure"
            mock_execute.side_effect = Exception(error_message)

            with pytest.raises(Exception):  # noqa: B017
                crew.kickoff()

            assert event_received.wait(timeout=5), "Timeout waiting for failed event"
            assert len(received_events) == 1
            assert received_events[0].error == error_message
            assert isinstance(received_events[0].timestamp, datetime)
            assert received_events[0].type == "crew_kickoff_failed"
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_emits_start_task_event(base_agent, base_task):
    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(TaskStartedEvent)
    def handle_task_start(source, event):
        received_events.append(event)
        event_received.set()

    crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")

    crew.kickoff()

    assert event_received.wait(timeout=5), "Timeout waiting for task started event"
    assert len(received_events) == 1
    assert isinstance(received_events[0].timestamp, datetime)
    assert received_events[0].type == "task_started"
@@ -225,10 +239,12 @@ def test_crew_emits_end_task_event(
    base_agent, base_task, reset_event_listener_singleton
):
    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(TaskCompletedEvent)
    def handle_task_end(source, event):
        received_events.append(event)
        event_received.set()

    mock_span = Mock()

@@ -246,6 +262,7 @@ def test_crew_emits_end_task_event(
        mock_telemetry.task_started.assert_called_once_with(crew=crew, task=base_task)
        mock_telemetry.task_ended.assert_called_once_with(mock_span, base_task, crew)

    assert event_received.wait(timeout=5), "Timeout waiting for task completed event"
    assert len(received_events) == 1
    assert isinstance(received_events[0].timestamp, datetime)
    assert received_events[0].type == "task_completed"
@@ -255,11 +272,13 @@ def test_crew_emits_end_task_event(
def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
    received_events = []
    received_sources = []
    event_received = threading.Event()

    @crewai_event_bus.on(TaskFailedEvent)
    def handle_task_failed(source, event):
        received_events.append(event)
        received_sources.append(source)
        event_received.set()

    with patch.object(
        Task,
@@ -281,6 +300,9 @@ def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
        with pytest.raises(Exception):  # noqa: B017
            agent.execute_task(task=task)

        assert event_received.wait(timeout=5), (
            "Timeout waiting for task failed event"
        )
        assert len(received_events) == 1
        assert received_sources[0] == task
        assert received_events[0].error == error_message
@@ -291,17 +313,27 @@ def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_emits_execution_started_and_completed_events(base_agent, base_task):
    received_events = []
    lock = threading.Lock()
    all_events_received = threading.Event()

    @crewai_event_bus.on(AgentExecutionStartedEvent)
    def handle_agent_start(source, event):
        with lock:
            received_events.append(event)

    @crewai_event_bus.on(AgentExecutionCompletedEvent)
    def handle_agent_completed(source, event):
        with lock:
            received_events.append(event)
            if len(received_events) >= 2:
                all_events_received.set()

    crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
    crew.kickoff()

    assert all_events_received.wait(timeout=5), (
        "Timeout waiting for agent execution events"
    )
    assert len(received_events) == 2
    assert received_events[0].agent == base_agent
    assert received_events[0].task == base_task
@@ -320,10 +352,12 @@ def test_agent_emits_execution_started_and_completed_events(base_agent, base_tas
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_emits_execution_error_event(base_agent, base_task):
    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(AgentExecutionErrorEvent)
    def handle_agent_start(source, event):
        received_events.append(event)
        event_received.set()

    error_message = "Error happening while sending prompt to model."
    base_agent.max_retry_limit = 0
@@ -337,6 +371,9 @@ def test_agent_emits_execution_error_event(base_agent, base_task):
            task=base_task,
        )

    assert event_received.wait(timeout=5), (
        "Timeout waiting for agent execution error event"
    )
    assert len(received_events) == 1
    assert received_events[0].agent == base_agent
    assert received_events[0].task == base_task
@@ -358,10 +395,12 @@ class SayHiTool(BaseTool):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_tools_emits_finished_events():
    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(ToolUsageFinishedEvent)
    def handle_tool_end(source, event):
        received_events.append(event)
        event_received.set()

    agent = Agent(
        role="base_agent",
@@ -377,6 +416,10 @@ def test_tools_emits_finished_events():
    )
    crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
    crew.kickoff()

    assert event_received.wait(timeout=5), (
        "Timeout waiting for tool usage finished event"
    )
    assert len(received_events) == 1
    assert received_events[0].agent_key == agent.key
    assert received_events[0].agent_role == agent.role
@@ -389,10 +432,15 @@ def test_tools_emits_finished_events():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_tools_emits_error_events():
    received_events = []
    lock = threading.Lock()
    all_events_received = threading.Event()

    @crewai_event_bus.on(ToolUsageErrorEvent)
    def handle_tool_end(source, event):
        with lock:
            received_events.append(event)
            if len(received_events) >= 48:
                all_events_received.set()

    class ErrorTool(BaseTool):
        name: str = Field(
@@ -423,6 +471,9 @@ def test_tools_emits_error_events():
    crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
    crew.kickoff()

    assert all_events_received.wait(timeout=5), (
        "Timeout waiting for tool usage error events"
    )
    assert len(received_events) == 48
    assert received_events[0].agent_key == agent.key
    assert received_events[0].agent_role == agent.role
@@ -435,11 +486,13 @@ def test_tools_emits_error_events():

def test_flow_emits_start_event(reset_event_listener_singleton):
    received_events = []
    event_received = threading.Event()
    mock_span = Mock()

    @crewai_event_bus.on(FlowStartedEvent)
    def handle_flow_start(source, event):
        received_events.append(event)
        event_received.set()

    class TestFlow(Flow[dict]):
        @start()
@@ -458,6 +511,7 @@ def test_flow_emits_start_event(reset_event_listener_singleton):
        flow = TestFlow()
        flow.kickoff()

        assert event_received.wait(timeout=5), "Timeout waiting for flow started event"
        mock_telemetry.flow_execution_span.assert_called_once_with("TestFlow", ["begin"])
        assert len(received_events) == 1
        assert received_events[0].flow_name == "TestFlow"
@@ -466,6 +520,7 @@ def test_flow_emits_start_event(reset_event_listener_singleton):

def test_flow_name_emitted_to_event_bus():
    received_events = []
    event_received = threading.Event()

    class MyFlowClass(Flow):
        name = "PRODUCTION_FLOW"
@@ -477,118 +532,133 @@ def test_flow_name_emitted_to_event_bus():
    @crewai_event_bus.on(FlowStartedEvent)
    def handle_flow_start(source, event):
        received_events.append(event)
        event_received.set()

    flow = MyFlowClass()
    flow.kickoff()

    assert event_received.wait(timeout=5), "Timeout waiting for flow started event"
    assert len(received_events) == 1
    assert received_events[0].flow_name == "PRODUCTION_FLOW"


def test_flow_emits_finish_event():
    received_events = []
    event_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(FlowFinishedEvent)
        def handle_flow_finish(source, event):
            received_events.append(event)
            event_received.set()

        class TestFlow(Flow[dict]):
            @start()
            def begin(self):
                return "completed"

        flow = TestFlow()
        result = flow.kickoff()

        assert event_received.wait(timeout=5), "Timeout waiting for finish event"
        assert len(received_events) == 1
        assert received_events[0].flow_name == "TestFlow"
        assert received_events[0].type == "flow_finished"
        assert received_events[0].result == "completed"
        assert result == "completed"
def test_flow_emits_method_execution_started_event():
    received_events = []
    lock = threading.Lock()
    second_event_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(MethodExecutionStartedEvent)
        async def handle_method_start(source, event):
            with lock:
                received_events.append(event)
                if event.method_name == "second_method":
                    second_event_received.set()

        class TestFlow(Flow[dict]):
            @start()
            def begin(self):
                return "started"

            @listen("begin")
            def second_method(self):
                return "executed"

        flow = TestFlow()
        flow.kickoff()

        assert second_event_received.wait(timeout=5), (
            "Timeout waiting for second_method event"
        )
        assert len(received_events) == 2

        # Events may arrive in any order due to async handlers, so check both are present
        method_names = {event.method_name for event in received_events}
        assert method_names == {"begin", "second_method"}

        for event in received_events:
            assert event.flow_name == "TestFlow"
            assert event.type == "method_execution_started"
@pytest.mark.vcr(filter_headers=["authorization"])
def test_register_handler_adds_new_handler(base_agent, base_task):
    received_events = []
    event_received = threading.Event()

    def custom_handler(source, event):
        received_events.append(event)
        event_received.set()

    with crewai_event_bus.scoped_handlers():
        crewai_event_bus.register_handler(CrewKickoffStartedEvent, custom_handler)

        crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
        crew.kickoff()

        assert event_received.wait(timeout=5), "Timeout waiting for handler event"
        assert len(received_events) == 1
        assert isinstance(received_events[0].timestamp, datetime)
        assert received_events[0].type == "crew_kickoff_started"


@pytest.mark.vcr(filter_headers=["authorization"])
def test_multiple_handlers_for_same_event(base_agent, base_task):
    received_events_1 = []
    received_events_2 = []
    event_received = threading.Event()

    def handler_1(source, event):
        received_events_1.append(event)

    def handler_2(source, event):
        received_events_2.append(event)
        event_received.set()

    with crewai_event_bus.scoped_handlers():
        crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_1)
        crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_2)

        crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
        crew.kickoff()

        assert event_received.wait(timeout=5), "Timeout waiting for handler events"
        assert len(received_events_1) == 1
        assert len(received_events_2) == 1
        assert received_events_1[0].type == "crew_kickoff_started"
        assert received_events_2[0].type == "crew_kickoff_started"
def test_flow_emits_created_event():
    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(FlowCreatedEvent)
    def handle_flow_created(source, event):
        received_events.append(event)
        event_received.set()

    class TestFlow(Flow[dict]):
        @start()
@@ -598,6 +668,7 @@ def test_flow_emits_created_event():
    flow = TestFlow()
    flow.kickoff()

    assert event_received.wait(timeout=5), "Timeout waiting for flow created event"
    assert len(received_events) == 1
    assert received_events[0].flow_name == "TestFlow"
    assert received_events[0].type == "flow_created"
@@ -605,11 +676,13 @@ def test_flow_emits_created_event():

def test_flow_emits_method_execution_failed_event():
    received_events = []
    event_received = threading.Event()
    error = Exception("Simulated method failure")

    @crewai_event_bus.on(MethodExecutionFailedEvent)
    def handle_method_failed(source, event):
        received_events.append(event)
        event_received.set()

    class TestFlow(Flow[dict]):
        @start()
@@ -620,6 +693,9 @@ def test_flow_emits_method_execution_failed_event():
    with pytest.raises(Exception):  # noqa: B017
        flow.kickoff()

    assert event_received.wait(timeout=5), (
        "Timeout waiting for method execution failed event"
    )
    assert len(received_events) == 1
    assert received_events[0].method_name == "begin"
    assert received_events[0].flow_name == "TestFlow"
@@ -641,6 +717,7 @@ def test_llm_emits_call_started_event():

    llm = LLM(model="gpt-4o-mini")
    llm.call("Hello, how are you?")
    wait_for_event_handlers()

    assert len(received_events) == 2
    assert received_events[0].type == "llm_call_started"
@@ -656,10 +733,12 @@ def test_llm_emits_call_started_event():
@pytest.mark.isolated
def test_llm_emits_call_failed_event():
    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(LLMCallFailedEvent)
    def handle_llm_call_failed(source, event):
        received_events.append(event)
        event_received.set()

    error_message = "OpenAI API call failed: Simulated API failure"

@@ -673,6 +752,7 @@ def test_llm_emits_call_failed_event():
            llm.call("Hello, how are you?")

        assert str(exc_info.value) == "Simulated API failure"
        assert event_received.wait(timeout=5), "Timeout waiting for failed event"
        assert len(received_events) == 1
        assert received_events[0].type == "llm_call_failed"
        assert received_events[0].error == error_message
@@ -686,24 +766,28 @@ def test_llm_emits_call_failed_event():
def test_llm_emits_stream_chunk_events():
    """Test that LLM emits stream chunk events when streaming is enabled."""
    received_chunks = []
    event_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(LLMStreamChunkEvent)
        def handle_stream_chunk(source, event):
            received_chunks.append(event.chunk)
            if len(received_chunks) >= 1:
                event_received.set()

        # Create an LLM with streaming enabled
        llm = LLM(model="gpt-4o", stream=True)

        # Call the LLM with a simple message
        response = llm.call("Tell me a short joke")

        # Wait for at least one chunk
        assert event_received.wait(timeout=5), "Timeout waiting for stream chunks"

        # Verify that we received chunks
        assert len(received_chunks) > 0

        # Verify that concatenating all chunks equals the final response
        assert "".join(received_chunks) == response
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -711,23 +795,21 @@ def test_llm_no_stream_chunks_when_streaming_disabled():
    """Test that LLM doesn't emit stream chunk events when streaming is disabled."""
    received_chunks = []

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(LLMStreamChunkEvent)
        def handle_stream_chunk(source, event):
            received_chunks.append(event.chunk)

        # Create an LLM with streaming disabled
        llm = LLM(model="gpt-4o", stream=False)

        # Call the LLM with a simple message
        response = llm.call("Tell me a short joke")

        # Verify that we didn't receive any chunks
        assert len(received_chunks) == 0

        # Verify we got a response
        assert response and isinstance(response, str)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -735,98 +817,105 @@ def test_streaming_fallback_to_non_streaming():
    """Test that streaming falls back to non-streaming when there's an error."""
    received_chunks = []
    fallback_called = False
    event_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(LLMStreamChunkEvent)
        def handle_stream_chunk(source, event):
            received_chunks.append(event.chunk)
            if len(received_chunks) >= 2:
                event_received.set()

        # Create an LLM with streaming enabled
        llm = LLM(model="gpt-4o", stream=True)

        # Store original methods
        original_call = llm.call

        # Create a mock call method that handles the streaming error
        def mock_call(messages, tools=None, callbacks=None, available_functions=None):
            nonlocal fallback_called
            # Emit a couple of chunks to simulate partial streaming
            crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
            crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))

            # Mark that fallback would be called
            fallback_called = True

            # Return a response as if fallback succeeded
            return "Fallback response after streaming error"

        # Replace the call method with our mock
        llm.call = mock_call

        try:
            # Call the LLM
            response = llm.call("Tell me a short joke")
            wait_for_event_handlers()
            assert event_received.wait(timeout=5), "Timeout waiting for stream chunks"

            # Verify that we received some chunks
            assert len(received_chunks) == 2
            assert received_chunks[0] == "Test chunk 1"
            assert received_chunks[1] == "Test chunk 2"

            # Verify fallback was triggered
            assert fallback_called

            # Verify we got the fallback response
            assert response == "Fallback response after streaming error"

        finally:
            # Restore the original method
            llm.call = original_call
@pytest.mark.vcr(filter_headers=["authorization"])
def test_streaming_empty_response_handling():
    """Test that streaming handles empty responses correctly."""
    received_chunks = []
    event_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(LLMStreamChunkEvent)
        def handle_stream_chunk(source, event):
            received_chunks.append(event.chunk)
            if len(received_chunks) >= 3:
                event_received.set()

        # Create an LLM with streaming enabled
        llm = LLM(model="gpt-3.5-turbo", stream=True)

        # Store original methods
        original_call = llm.call

        # Create a mock call method that simulates empty chunks
        def mock_call(messages, tools=None, callbacks=None, available_functions=None):
            # Emit a few empty chunks
            for _ in range(3):
                crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))

            # Return the default message for empty responses
            return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."

        # Replace the call method with our mock
        llm.call = mock_call

        try:
            # Call the LLM - this should handle empty response
            response = llm.call("Tell me a short joke")
            assert event_received.wait(timeout=5), "Timeout waiting for empty chunks"

            # Verify that we received empty chunks
            assert len(received_chunks) == 3
            assert all(chunk == "" for chunk in received_chunks)

            # Verify the response is the default message for empty responses
            assert "I apologize" in response and "couldn't generate" in response

        finally:
            # Restore the original method
            llm.call = original_call
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -835,41 +924,49 @@ def test_stream_llm_emits_event_with_task_and_agent_info():
    failed_event = []
    started_event = []
    stream_event = []
    event_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(LLMCallFailedEvent)
        def handle_llm_failed(source, event):
            failed_event.append(event)

        @crewai_event_bus.on(LLMCallStartedEvent)
        def handle_llm_started(source, event):
            started_event.append(event)

        @crewai_event_bus.on(LLMCallCompletedEvent)
        def handle_llm_completed(source, event):
            completed_event.append(event)
            if len(started_event) >= 1 and len(stream_event) >= 12:
                event_received.set()

        @crewai_event_bus.on(LLMStreamChunkEvent)
        def handle_llm_stream_chunk(source, event):
            stream_event.append(event)
            if (
                len(completed_event) >= 1
                and len(started_event) >= 1
                and len(stream_event) >= 12
            ):
                event_received.set()

        agent = Agent(
            role="TestAgent",
            llm=LLM(model="gpt-4o-mini", stream=True),
            goal="Just say hi",
            backstory="You are a helpful assistant that just says hi",
        )
        task = Task(
            description="Just say hi",
            expected_output="hi",
            llm=LLM(model="gpt-4o-mini", stream=True),
            agent=agent,
        )

        crew = Crew(agents=[agent], tasks=[task])
        crew.kickoff()

        assert event_received.wait(timeout=10), "Timeout waiting for LLM events"
        assert len(completed_event) == 1
        assert len(failed_event) == 0
        assert len(started_event) == 1
@@ -899,28 +996,30 @@ def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task):
    failed_event = []
    started_event = []
    stream_event = []
    event_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(LLMCallFailedEvent)
        def handle_llm_failed(source, event):
            failed_event.append(event)

        @crewai_event_bus.on(LLMCallStartedEvent)
        def handle_llm_started(source, event):
            started_event.append(event)

        @crewai_event_bus.on(LLMCallCompletedEvent)
        def handle_llm_completed(source, event):
            completed_event.append(event)
            if len(started_event) >= 1:
                event_received.set()

        @crewai_event_bus.on(LLMStreamChunkEvent)
        def handle_llm_stream_chunk(source, event):
            stream_event.append(event)

        crew = Crew(agents=[base_agent], tasks=[base_task])
        crew.kickoff()

        assert event_received.wait(timeout=10), "Timeout waiting for LLM events"
        assert len(completed_event) == 1
        assert len(failed_event) == 0
        assert len(started_event) == 1
@@ -950,32 +1049,41 @@ def test_llm_emits_event_with_lite_agent():
    failed_event = []
    started_event = []
    stream_event = []
    all_events_received = threading.Event()

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(LLMCallFailedEvent)
        def handle_llm_failed(source, event):
            failed_event.append(event)

        @crewai_event_bus.on(LLMCallStartedEvent)
        def handle_llm_started(source, event):
            started_event.append(event)

        @crewai_event_bus.on(LLMCallCompletedEvent)
        def handle_llm_completed(source, event):
            completed_event.append(event)
            if len(started_event) >= 1 and len(stream_event) >= 15:
                all_events_received.set()

        @crewai_event_bus.on(LLMStreamChunkEvent)
        def handle_llm_stream_chunk(source, event):
            stream_event.append(event)
            if (
                len(completed_event) >= 1
                and len(started_event) >= 1
                and len(stream_event) >= 15
            ):
                all_events_received.set()

        agent = Agent(
            role="Speaker",
            llm=LLM(model="gpt-4o-mini", stream=True),
            goal="Just say hi",
            backstory="You are a helpful assistant that just says hi",
        )
        agent.kickoff(messages=[{"role": "user", "content": "say hi!"}])

        assert all_events_received.wait(timeout=10), "Timeout waiting for all events"

        assert len(completed_event) == 1
        assert len(failed_event) == 0
39  lib/crewai/tests/utils.py  Normal file
@@ -0,0 +1,39 @@
"""Test utilities for CrewAI tests."""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
def wait_for_event_handlers(timeout: float = 5.0) -> None:
|
||||
"""Wait for all pending event handlers to complete.
|
||||
|
||||
This helper ensures all sync and async handlers finish processing before
|
||||
proceeding. Useful in tests to make assertions deterministic.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds.
|
||||
"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
loop = getattr(crewai_event_bus, "_loop", None)
|
||||
|
||||
if loop and not loop.is_closed():
|
||||
|
||||
async def _wait_for_async_tasks() -> None:
|
||||
tasks = {
|
||||
t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task()
|
||||
}
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(_wait_for_async_tasks(), loop)
|
||||
try:
|
||||
future.result(timeout=timeout)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
crewai_event_bus._sync_executor.shutdown(wait=True)
|
||||
crewai_event_bus._sync_executor = ThreadPoolExecutor(
|
||||
max_workers=10,
|
||||
thread_name_prefix="CrewAISyncHandler",
|
||||
)
|
||||
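A minimal usage sketch of the new helper, not part of the commit itself: the event class and the BaseEvent import path below are assumed for illustration (the test modules above already have BaseEvent in scope and import the helper as `from ..utils import wait_for_event_handlers`). The idea is simply that a test emits on crewai_event_bus and then drains the sync thread pool and async loop before asserting.

# Hypothetical sketch only; names marked below are assumptions, not part of this commit.
from crewai.events.event_bus import crewai_event_bus
from crewai.events.base_events import BaseEvent  # import path assumed

from ..utils import wait_for_event_handlers  # as in the test modules above


class ExampleEvent(BaseEvent):
    """Stand-in event type, mirroring ThreadSafetyTestEvent in the tests above."""


def test_assertions_run_after_handlers() -> None:
    received: list[BaseEvent] = []

    with crewai_event_bus.scoped_handlers():

        @crewai_event_bus.on(ExampleEvent)
        def record(source: object, event: BaseEvent) -> None:
            received.append(event)

        crewai_event_bus.emit("example_source", ExampleEvent(type="example"))

        # Handlers may run on the bus's sync thread pool or its async loop;
        # draining them here makes the assertion below deterministic.
        wait_for_event_handlers()

        assert len(received) == 1

The threading.Event-based waits used throughout the tests above serve the same purpose when only a specific handler needs to be observed; wait_for_event_handlers() is the blunter tool that waits for everything currently in flight.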