fix: add additional stack checks

This commit is contained in:
Greyson LaLonde
2026-01-20 01:39:55 -05:00
parent 5914f8be7d
commit d78ee915dc
4 changed files with 232 additions and 43 deletions

View File

@@ -35,10 +35,15 @@ def cleanup_event_handlers() -> Generator[None, Any, None]:
def reset_event_state() -> None:
"""Reset event system state before each test for isolation."""
from crewai.events.base_events import reset_emission_counter
from crewai.events.event_context import _event_id_stack
from crewai.events.event_context import (
EventContextConfig,
_event_context_config,
_event_id_stack,
)
reset_emission_counter()
_event_id_stack.set(())
_event_context_config.set(EventContextConfig())
@pytest.fixture(autouse=True, scope="function")

View File

@@ -24,6 +24,8 @@ from crewai.events.event_context import (
VALID_EVENT_PAIRS,
get_current_parent_id,
get_enclosing_parent_id,
handle_empty_pop,
handle_mismatch,
pop_event_scope,
push_event_scope,
)
@@ -342,19 +344,12 @@ class CrewAIEventsBus:
event.parent_event_id = get_enclosing_parent_id()
popped = pop_event_scope()
if popped is None:
self._console.print(
f"[CrewAIEventsBus] Warning: Ending event '{event_type_name}' "
"emitted with empty scope stack. Missing starting event?"
)
handle_empty_pop(event_type_name)
else:
_, popped_type = popped
expected_start = VALID_EVENT_PAIRS.get(event_type_name)
if expected_start and popped_type and popped_type != expected_start:
self._console.print(
f"[CrewAIEventsBus] Warning: Event pairing mismatch. "
f"'{event_type_name}' closed '{popped_type}' "
f"(expected '{expected_start}')"
)
handle_mismatch(event_type_name, popped_type, expected_start)
elif event_type_name in SCOPE_STARTING_EVENTS:
event.parent_event_id = get_current_parent_id()
push_event_scope(event.event_id, event_type_name)

View File

@@ -3,52 +3,82 @@
from collections.abc import Generator
from contextlib import contextmanager
import contextvars
from dataclasses import dataclass
from enum import Enum
from crewai.events.utils.console_formatter import ConsoleFormatter
class MismatchBehavior(Enum):
"""Behavior when event pairs don't match."""
WARN = "warn"
RAISE = "raise"
SILENT = "silent"
@dataclass
class EventContextConfig:
"""Configuration for event context behavior."""
max_stack_depth: int = 100
mismatch_behavior: MismatchBehavior = MismatchBehavior.WARN
empty_pop_behavior: MismatchBehavior = MismatchBehavior.WARN
class StackDepthExceededError(Exception):
"""Raised when stack depth limit is exceeded."""
class EventPairingError(Exception):
"""Raised when event pairs don't match."""
class EmptyStackError(Exception):
"""Raised when popping from empty stack."""
_event_id_stack: contextvars.ContextVar[tuple[tuple[str, str], ...]] = (
contextvars.ContextVar("_event_id_stack", default=())
)
_event_context_config: contextvars.ContextVar[EventContextConfig | None] = (
contextvars.ContextVar("_event_context_config", default=None)
)
_default_config = EventContextConfig()
_console = ConsoleFormatter()
def get_current_parent_id() -> str | None:
"""Get the current parent event ID from the stack.
Returns:
The top event ID if stack is non-empty, otherwise None.
"""
"""Get the current parent event ID from the stack."""
stack = _event_id_stack.get()
return stack[-1][0] if stack else None
def get_enclosing_parent_id() -> str | None:
"""Get the parent of the current scope (stack[-2]).
Used by ending events to become siblings of their matching started events.
Returns:
The second-to-top event ID, or None if stack has fewer than 2 items.
"""
"""Get the parent of the current scope (stack[-2])."""
stack = _event_id_stack.get()
return stack[-2][0] if len(stack) >= 2 else None
def push_event_scope(event_id: str, event_type: str = "") -> None:
"""Push an event ID and type onto the scope stack.
Args:
event_id: The event ID to push.
event_type: The event type name (for pairing validation).
"""
"""Push an event ID and type onto the scope stack."""
config = _event_context_config.get() or _default_config
stack = _event_id_stack.get()
if config.max_stack_depth > 0 and len(stack) >= config.max_stack_depth:
raise StackDepthExceededError(
f"Event stack depth limit ({config.max_stack_depth}) exceeded. "
f"This usually indicates missing ending events."
)
_event_id_stack.set((*stack, (event_id, event_type)))
def pop_event_scope() -> tuple[str, str] | None:
"""Pop an event entry from the scope stack.
Returns:
Tuple of (event_id, event_type), or None if stack was empty.
"""
"""Pop an event entry from the scope stack."""
stack = _event_id_stack.get()
if not stack:
return None
@@ -56,18 +86,41 @@ def pop_event_scope() -> tuple[str, str] | None:
return stack[-1]
def handle_empty_pop(event_type_name: str) -> None:
"""Handle a pop attempt on an empty stack."""
config = _event_context_config.get() or _default_config
msg = (
f"Ending event '{event_type_name}' emitted with empty scope stack. "
"Missing starting event?"
)
if config.empty_pop_behavior == MismatchBehavior.RAISE:
raise EmptyStackError(msg)
if config.empty_pop_behavior == MismatchBehavior.WARN:
_console.print(f"[CrewAIEventsBus] Warning: {msg}")
def handle_mismatch(
event_type_name: str,
popped_type: str,
expected_start: str,
) -> None:
"""Handle a mismatched event pair."""
config = _event_context_config.get() or _default_config
msg = (
f"Event pairing mismatch. '{event_type_name}' closed '{popped_type}' "
f"(expected '{expected_start}')"
)
if config.mismatch_behavior == MismatchBehavior.RAISE:
raise EventPairingError(msg)
if config.mismatch_behavior == MismatchBehavior.WARN:
_console.print(f"[CrewAIEventsBus] Warning: {msg}")
@contextmanager
def event_scope(event_id: str, event_type: str = "") -> Generator[None, None, None]:
"""Context manager to establish a parent event scope.
Safe to use alongside emit() auto-management. If the event_id is already
on the stack (e.g., from a starting event's auto-push), this will not
double-push or double-pop.
Args:
event_id: The event ID to set as the current parent.
event_type: The event type name (for pairing validation).
"""
"""Context manager to establish a parent event scope."""
stack = _event_id_stack.get()
already_on_stack = any(entry[0] == event_id for entry in stack)
if not already_on_stack:
@@ -82,16 +135,25 @@ def event_scope(event_id: str, event_type: str = "") -> Generator[None, None, No
SCOPE_STARTING_EVENTS: frozenset[str] = frozenset(
{
"flow_started",
"method_execution_started",
"crew_kickoff_started",
"agent_execution_started",
"lite_agent_execution_started",
"task_started",
"llm_call_started",
"llm_guardrail_started",
"tool_usage_started",
"memory_retrieval_started",
"memory_save_started",
"memory_query_started",
"knowledge_retrieval_started",
"knowledge_query_started",
"a2a_delegation_started",
"a2a_conversation_started",
"a2a_polling_started",
"a2a_streaming_started",
"a2a_server_task_started",
"a2a_parallel_delegation_started",
"agent_reasoning_started",
}
)
@@ -100,14 +162,19 @@ SCOPE_ENDING_EVENTS: frozenset[str] = frozenset(
{
"flow_finished",
"flow_paused",
"method_execution_finished",
"method_execution_failed",
"crew_kickoff_completed",
"crew_kickoff_failed",
"agent_execution_completed",
"agent_execution_error",
"lite_agent_execution_completed",
"lite_agent_execution_error",
"task_completed",
"task_failed",
"llm_call_completed",
"llm_call_failed",
"llm_guardrail_completed",
"tool_usage_finished",
"tool_usage_error",
"memory_retrieval_completed",
@@ -115,8 +182,17 @@ SCOPE_ENDING_EVENTS: frozenset[str] = frozenset(
"memory_save_failed",
"memory_query_completed",
"memory_query_failed",
"knowledge_retrieval_completed",
"knowledge_query_completed",
"knowledge_query_failed",
"a2a_delegation_completed",
"a2a_conversation_completed",
"a2a_polling_completed",
"a2a_streaming_completed",
"a2a_server_task_completed",
"a2a_server_task_canceled",
"a2a_server_task_failed",
"a2a_parallel_delegation_completed",
"agent_reasoning_completed",
"agent_reasoning_failed",
}
@@ -125,14 +201,19 @@ SCOPE_ENDING_EVENTS: frozenset[str] = frozenset(
VALID_EVENT_PAIRS: dict[str, str] = {
"flow_finished": "flow_started",
"flow_paused": "flow_started",
"method_execution_finished": "method_execution_started",
"method_execution_failed": "method_execution_started",
"crew_kickoff_completed": "crew_kickoff_started",
"crew_kickoff_failed": "crew_kickoff_started",
"agent_execution_completed": "agent_execution_started",
"agent_execution_error": "agent_execution_started",
"lite_agent_execution_completed": "lite_agent_execution_started",
"lite_agent_execution_error": "lite_agent_execution_started",
"task_completed": "task_started",
"task_failed": "task_started",
"llm_call_completed": "llm_call_started",
"llm_call_failed": "llm_call_started",
"llm_guardrail_completed": "llm_guardrail_started",
"tool_usage_finished": "tool_usage_started",
"tool_usage_error": "tool_usage_started",
"memory_retrieval_completed": "memory_retrieval_started",
@@ -140,8 +221,17 @@ VALID_EVENT_PAIRS: dict[str, str] = {
"memory_save_failed": "memory_save_started",
"memory_query_completed": "memory_query_started",
"memory_query_failed": "memory_query_started",
"knowledge_retrieval_completed": "knowledge_retrieval_started",
"knowledge_query_completed": "knowledge_query_started",
"knowledge_query_failed": "knowledge_query_started",
"a2a_delegation_completed": "a2a_delegation_started",
"a2a_conversation_completed": "a2a_conversation_started",
"a2a_polling_completed": "a2a_polling_started",
"a2a_streaming_completed": "a2a_streaming_started",
"a2a_server_task_completed": "a2a_server_task_started",
"a2a_server_task_canceled": "a2a_server_task_started",
"a2a_server_task_failed": "a2a_server_task_started",
"a2a_parallel_delegation_completed": "a2a_parallel_delegation_started",
"agent_reasoning_completed": "agent_reasoning_started",
"agent_reasoning_failed": "agent_reasoning_started",
}

View File

@@ -0,0 +1,99 @@
"""Tests for event context management."""
import pytest
from crewai.events.event_context import (
SCOPE_ENDING_EVENTS,
SCOPE_STARTING_EVENTS,
VALID_EVENT_PAIRS,
EmptyStackError,
EventPairingError,
MismatchBehavior,
StackDepthExceededError,
_event_context_config,
EventContextConfig,
get_current_parent_id,
get_enclosing_parent_id,
handle_empty_pop,
handle_mismatch,
pop_event_scope,
push_event_scope,
)
class TestStackOperations:
"""Tests for stack push/pop operations."""
def test_empty_stack_returns_none(self) -> None:
assert get_current_parent_id() is None
assert get_enclosing_parent_id() is None
def test_push_and_get_parent(self) -> None:
push_event_scope("event-1", "task_started")
assert get_current_parent_id() == "event-1"
def test_nested_push(self) -> None:
push_event_scope("event-1", "crew_kickoff_started")
push_event_scope("event-2", "task_started")
assert get_current_parent_id() == "event-2"
assert get_enclosing_parent_id() == "event-1"
def test_pop_restores_parent(self) -> None:
push_event_scope("event-1", "crew_kickoff_started")
push_event_scope("event-2", "task_started")
popped = pop_event_scope()
assert popped == ("event-2", "task_started")
assert get_current_parent_id() == "event-1"
def test_pop_empty_stack_returns_none(self) -> None:
assert pop_event_scope() is None
class TestStackDepthLimit:
"""Tests for stack depth limit."""
def test_depth_limit_exceeded_raises(self) -> None:
_event_context_config.set(EventContextConfig(max_stack_depth=3))
push_event_scope("event-1", "type-1")
push_event_scope("event-2", "type-2")
push_event_scope("event-3", "type-3")
with pytest.raises(StackDepthExceededError):
push_event_scope("event-4", "type-4")
class TestMismatchHandling:
"""Tests for mismatch behavior."""
def test_handle_mismatch_raises_when_configured(self) -> None:
_event_context_config.set(
EventContextConfig(mismatch_behavior=MismatchBehavior.RAISE)
)
with pytest.raises(EventPairingError):
handle_mismatch("task_completed", "llm_call_started", "task_started")
def test_handle_empty_pop_raises_when_configured(self) -> None:
_event_context_config.set(
EventContextConfig(empty_pop_behavior=MismatchBehavior.RAISE)
)
with pytest.raises(EmptyStackError):
handle_empty_pop("task_completed")
class TestEventTypeSets:
"""Tests for event type set completeness."""
def test_all_ending_events_have_pairs(self) -> None:
for ending_event in SCOPE_ENDING_EVENTS:
assert ending_event in VALID_EVENT_PAIRS
def test_all_pairs_reference_starting_events(self) -> None:
for ending_event, starting_event in VALID_EVENT_PAIRS.items():
assert starting_event in SCOPE_STARTING_EVENTS
def test_starting_and_ending_are_disjoint(self) -> None:
overlap = SCOPE_STARTING_EVENTS & SCOPE_ENDING_EVENTS
assert not overlap