From c653d41b89c3b0292c0b0e7ae19ec2a19ebbcd93 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 3 Apr 2026 22:31:36 +0800 Subject: [PATCH] feat: add EventRecord to RuntimeState checkpoints --- lib/crewai/src/crewai/events/event_bus.py | 7 + lib/crewai/src/crewai/state/event_record.py | 150 +++++++ lib/crewai/src/crewai/state/provider/core.py | 12 +- lib/crewai/src/crewai/state/runtime.py | 43 +- lib/crewai/tests/test_event_record.py | 414 +++++++++++++++++++ 5 files changed, 618 insertions(+), 8 deletions(-) create mode 100644 lib/crewai/src/crewai/state/event_record.py create mode 100644 lib/crewai/tests/test_event_record.py diff --git a/lib/crewai/src/crewai/events/event_bus.py b/lib/crewai/src/crewai/events/event_bus.py index 01aa1d9a6..10a612dd3 100644 --- a/lib/crewai/src/crewai/events/event_bus.py +++ b/lib/crewai/src/crewai/events/event_bus.py @@ -480,6 +480,10 @@ class CrewAIEventsBus: event.parent_event_id = get_current_parent_id() set_last_event_id(event.event_id) + + if self._runtime_state is not None: + self._runtime_state.event_record.add(event) + event_type = type(event) with self._rwlock.r_locked(): @@ -578,6 +582,9 @@ class CrewAIEventsBus: source: The object emitting the event event: The event instance to emit """ + if self._runtime_state is not None: + self._runtime_state.event_record.add(event) + event_type = type(event) with self._rwlock.r_locked(): diff --git a/lib/crewai/src/crewai/state/event_record.py b/lib/crewai/src/crewai/state/event_record.py new file mode 100644 index 000000000..bfa6b1fbb --- /dev/null +++ b/lib/crewai/src/crewai/state/event_record.py @@ -0,0 +1,150 @@ +"""Directed record of execution events. + +Stores events as nodes with typed edges for parent/child, causal, and +sequential relationships. Provides O(1) lookups and traversal. +""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field + +from crewai.events.base_events import BaseEvent + + +EdgeType = Literal[ + "parent", + "child", + "trigger", + "triggered_by", + "next", + "previous", + "started", + "completed_by", +] + + +class EventNode(BaseModel): + """A node wrapping a single event with its adjacency lists.""" + + event: BaseEvent + edges: dict[EdgeType, list[str]] = Field(default_factory=dict) + + def add_edge(self, edge_type: EdgeType, target_id: str) -> None: + """Add an edge from this node to another. + + Args: + edge_type: The relationship type. + target_id: The event_id of the target node. + """ + self.edges.setdefault(edge_type, []).append(target_id) + + def neighbors(self, edge_type: EdgeType) -> list[str]: + """Return neighbor IDs for a given edge type. + + Args: + edge_type: The relationship type to query. + + Returns: + List of event IDs connected by this edge type. + """ + return self.edges.get(edge_type, []) + + +class EventRecord(BaseModel): + """Directed record of execution events with O(1) node lookup. + + Events are added via :meth:`add` which automatically wires edges + based on the event's relationship fields — ``parent_event_id``, + ``triggered_by_event_id``, ``previous_event_id``, ``started_event_id``. + """ + + nodes: dict[str, EventNode] = Field(default_factory=dict) + + def add(self, event: BaseEvent) -> EventNode: + """Add an event to the record and wire its edges. + + Args: + event: The event to insert. + + Returns: + The created node. + """ + node = EventNode(event=event) + self.nodes[event.event_id] = node + + if event.parent_event_id and event.parent_event_id in self.nodes: + node.add_edge("parent", event.parent_event_id) + self.nodes[event.parent_event_id].add_edge("child", event.event_id) + + if event.triggered_by_event_id and event.triggered_by_event_id in self.nodes: + node.add_edge("triggered_by", event.triggered_by_event_id) + self.nodes[event.triggered_by_event_id].add_edge("trigger", event.event_id) + + if event.previous_event_id and event.previous_event_id in self.nodes: + node.add_edge("previous", event.previous_event_id) + self.nodes[event.previous_event_id].add_edge("next", event.event_id) + + if event.started_event_id and event.started_event_id in self.nodes: + node.add_edge("started", event.started_event_id) + self.nodes[event.started_event_id].add_edge("completed_by", event.event_id) + + return node + + def get(self, event_id: str) -> EventNode | None: + """Look up a node by event ID. + + Args: + event_id: The event's unique identifier. + + Returns: + The node, or None if not found. + """ + return self.nodes.get(event_id) + + def descendants(self, event_id: str) -> list[EventNode]: + """Return all descendant nodes, children recursively. + + Args: + event_id: The root event ID to start from. + + Returns: + All descendant nodes in breadth-first order. + """ + result: list[EventNode] = [] + queue = [event_id] + visited: set[str] = set() + + while queue: + current_id = queue.pop(0) + if current_id in visited: + continue + visited.add(current_id) + + node = self.nodes.get(current_id) + if node is None: + continue + + for child_id in node.neighbors("child"): + if child_id not in visited: + child_node = self.nodes.get(child_id) + if child_node: + result.append(child_node) + queue.append(child_id) + + return result + + def roots(self) -> list[EventNode]: + """Return all root nodes — events with no parent. + + Returns: + List of root event nodes. + """ + return [node for node in self.nodes.values() if not node.neighbors("parent")] + + def __len__(self) -> int: + return len(self.nodes) + + def __contains__(self, event_id: str) -> bool: + return event_id in self.nodes diff --git a/lib/crewai/src/crewai/state/provider/core.py b/lib/crewai/src/crewai/state/provider/core.py index a3f7c9c5a..71698c712 100644 --- a/lib/crewai/src/crewai/state/provider/core.py +++ b/lib/crewai/src/crewai/state/provider/core.py @@ -12,8 +12,8 @@ from pydantic_core import CoreSchema, core_schema class BaseProvider(Protocol): """Interface for persisting and restoring runtime state checkpoints. - Implementations handle the storage backend (filesystem, cloud, database, - etc.) while ``RuntimeState`` handles serialization. + Implementations handle the storage backend — filesystem, cloud, database, + etc. — while ``RuntimeState`` handles serialization. """ @classmethod @@ -39,10 +39,10 @@ class BaseProvider(Protocol): Args: data: The serialized string to persist. - directory: Logical destination (path, bucket prefix, etc.). + directory: Logical destination: path, bucket prefix, etc. Returns: - A location identifier for the saved checkpoint (e.g. file path, URI). + A location identifier for the saved checkpoint, such as a file path or URI. """ ... @@ -51,9 +51,9 @@ class BaseProvider(Protocol): Args: data: The serialized string to persist. - directory: Logical destination (path, bucket prefix, etc.). + directory: Logical destination: path, bucket prefix, etc. Returns: - A location identifier for the saved checkpoint (e.g. file path, URI). + A location identifier for the saved checkpoint, such as a file path or URI. """ ... diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index 784154c82..8fa684264 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -9,11 +9,19 @@ via ``RuntimeState.model_rebuild()``. from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict -from pydantic import PrivateAttr, RootModel +from pydantic import ( + ModelWrapValidatorHandler, + PrivateAttr, + RootModel, + SerializerFunctionWrapHandler, + model_serializer, + model_validator, +) from crewai.context import capture_execution_context +from crewai.state.event_record import EventRecord from crewai.state.provider.core import BaseProvider from crewai.state.provider.json_provider import JsonProvider @@ -22,6 +30,11 @@ if TYPE_CHECKING: from crewai import Entity +class CheckpointPayload(TypedDict): + entities: list[Entity] + event_record: dict[str, Any] + + def _entity_discriminator(v: dict[str, Any] | object) -> str: if isinstance(v, dict): raw = v.get("entity_type", "agent") @@ -64,6 +77,32 @@ def _sync_checkpoint_fields(entity: object) -> None: class RuntimeState(RootModel): # type: ignore[type-arg] root: list[Entity] _provider: BaseProvider = PrivateAttr(default_factory=JsonProvider) + _event_record: EventRecord = PrivateAttr(default_factory=EventRecord) + + @property + def event_record(self) -> EventRecord: + """The execution event record.""" + return self._event_record + + @model_serializer(mode="wrap") + def _serialize(self, handler: SerializerFunctionWrapHandler) -> CheckpointPayload: + return { + "entities": handler(self), + "event_record": self._event_record.model_dump(), + } + + @model_validator(mode="wrap") + @classmethod + def _deserialize( + cls, data: Any, handler: ModelWrapValidatorHandler[RuntimeState] + ) -> RuntimeState: + if isinstance(data, dict) and "entities" in data: + record_data = data.get("event_record") + state = handler(data["entities"]) + if record_data: + state._event_record = EventRecord.model_validate(record_data) + return state + return handler(data) def checkpoint(self, directory: str) -> str: """Write a checkpoint file to the directory. diff --git a/lib/crewai/tests/test_event_record.py b/lib/crewai/tests/test_event_record.py new file mode 100644 index 000000000..9cd384b02 --- /dev/null +++ b/lib/crewai/tests/test_event_record.py @@ -0,0 +1,414 @@ +"""Tests for EventRecord data structure and RuntimeState integration.""" + +from __future__ import annotations + +import json + +import pytest + +from crewai.events.base_events import BaseEvent +from crewai.state.event_record import EventRecord, EventNode + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _event(type: str, **kwargs) -> BaseEvent: + return BaseEvent(type=type, **kwargs) + + +def _linear_record(n: int = 5) -> tuple[EventRecord, list[BaseEvent]]: + """Build a simple chain: e0 → e1 → e2 → ... with previous_event_id.""" + g = EventRecord() + events: list[BaseEvent] = [] + for i in range(n): + e = _event( + f"step_{i}", + previous_event_id=events[-1].event_id if events else None, + emission_sequence=i + 1, + ) + events.append(e) + g.add(e) + return g, events + + +def _tree_record() -> tuple[EventRecord, dict[str, BaseEvent]]: + """Build a parent/child tree: + + crew_start + ├── task_start + │ ├── agent_start + │ └── agent_complete (started=agent_start) + └── task_complete (started=task_start) + """ + g = EventRecord() + crew_start = _event("crew_kickoff_started", emission_sequence=1) + task_start = _event( + "task_started", + parent_event_id=crew_start.event_id, + previous_event_id=crew_start.event_id, + emission_sequence=2, + ) + agent_start = _event( + "agent_execution_started", + parent_event_id=task_start.event_id, + previous_event_id=task_start.event_id, + emission_sequence=3, + ) + agent_complete = _event( + "agent_execution_completed", + parent_event_id=task_start.event_id, + previous_event_id=agent_start.event_id, + started_event_id=agent_start.event_id, + emission_sequence=4, + ) + task_complete = _event( + "task_completed", + parent_event_id=crew_start.event_id, + previous_event_id=agent_complete.event_id, + started_event_id=task_start.event_id, + emission_sequence=5, + ) + + for e in [crew_start, task_start, agent_start, agent_complete, task_complete]: + g.add(e) + + return g, { + "crew_start": crew_start, + "task_start": task_start, + "agent_start": agent_start, + "agent_complete": agent_complete, + "task_complete": task_complete, + } + + +# ── EventNode tests ───────────────────────────────────────────────── + + +class TestEventNode: + def test_add_edge(self): + node = EventNode(event=_event("test")) + node.add_edge("child", "abc") + assert node.neighbors("child") == ["abc"] + + def test_neighbors_empty(self): + node = EventNode(event=_event("test")) + assert node.neighbors("parent") == [] + + def test_multiple_edges_same_type(self): + node = EventNode(event=_event("test")) + node.add_edge("child", "a") + node.add_edge("child", "b") + assert node.neighbors("child") == ["a", "b"] + + +# ── EventRecord core tests ─────────────────────────────────────────── + + +class TestEventRecordCore: + def test_add_single_event(self): + g = EventRecord() + e = _event("test") + node = g.add(e) + assert len(g) == 1 + assert e.event_id in g + assert node.event.type == "test" + + def test_get_existing(self): + g = EventRecord() + e = _event("test") + g.add(e) + assert g.get(e.event_id) is not None + + def test_get_missing(self): + g = EventRecord() + assert g.get("nonexistent") is None + + def test_contains(self): + g = EventRecord() + e = _event("test") + g.add(e) + assert e.event_id in g + assert "missing" not in g + + +# ── Edge wiring tests ─────────────────────────────────────────────── + + +class TestEdgeWiring: + def test_parent_child_bidirectional(self): + g = EventRecord() + parent = _event("parent") + child = _event("child", parent_event_id=parent.event_id) + g.add(parent) + g.add(child) + + parent_node = g.get(parent.event_id) + child_node = g.get(child.event_id) + assert child.event_id in parent_node.neighbors("child") + assert parent.event_id in child_node.neighbors("parent") + + def test_previous_next_bidirectional(self): + g, events = _linear_record(3) + node0 = g.get(events[0].event_id) + node1 = g.get(events[1].event_id) + node2 = g.get(events[2].event_id) + + assert events[1].event_id in node0.neighbors("next") + assert events[0].event_id in node1.neighbors("previous") + assert events[2].event_id in node1.neighbors("next") + assert events[1].event_id in node2.neighbors("previous") + + def test_trigger_bidirectional(self): + g = EventRecord() + cause = _event("cause") + effect = _event("effect", triggered_by_event_id=cause.event_id) + g.add(cause) + g.add(effect) + + assert effect.event_id in g.get(cause.event_id).neighbors("trigger") + assert cause.event_id in g.get(effect.event_id).neighbors("triggered_by") + + def test_started_completed_by_bidirectional(self): + g = EventRecord() + start = _event("start") + end = _event("end", started_event_id=start.event_id) + g.add(start) + g.add(end) + + assert end.event_id in g.get(start.event_id).neighbors("completed_by") + assert start.event_id in g.get(end.event_id).neighbors("started") + + def test_dangling_reference_ignored(self): + """Edge to a non-existent node should not be wired.""" + g = EventRecord() + e = _event("orphan", parent_event_id="nonexistent") + g.add(e) + node = g.get(e.event_id) + assert node.neighbors("parent") == [] + + +# ── Edge symmetry validation ───────────────────────────────────────── + + +SYMMETRIC_PAIRS = [ + ("parent", "child"), + ("previous", "next"), + ("triggered_by", "trigger"), + ("started", "completed_by"), +] + + +class TestEdgeSymmetry: + @pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS) + def test_symmetry_on_tree(self, forward, reverse): + g, _ = _tree_record() + for node_id, node in g.nodes.items(): + for target_id in node.neighbors(forward): + target_node = g.get(target_id) + assert target_node is not None, f"{target_id} missing from record" + assert node_id in target_node.neighbors(reverse), ( + f"Asymmetric edge: {node_id} --{forward.value}--> {target_id} " + f"but {target_id} has no {reverse.value} back to {node_id}" + ) + + @pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS) + def test_symmetry_on_linear(self, forward, reverse): + g, _ = _linear_record(10) + for node_id, node in g.nodes.items(): + for target_id in node.neighbors(forward): + target_node = g.get(target_id) + assert target_node is not None + assert node_id in target_node.neighbors(reverse) + + +# ── Ordering tests ─────────────────────────────────────────────────── + + +class TestOrdering: + def test_emission_sequence_monotonic(self): + g, events = _linear_record(10) + sequences = [e.emission_sequence for e in events] + assert sequences == sorted(sequences) + assert len(set(sequences)) == len(sequences), "Duplicate sequences" + + def test_next_chain_follows_sequence_order(self): + g, events = _linear_record(5) + current = g.get(events[0].event_id) + visited = [] + while current: + visited.append(current.event.event_id) + nexts = current.neighbors("next") + current = g.get(nexts[0]) if nexts else None + assert visited == [e.event_id for e in events] + + +# ── Traversal tests ───────────────────────────────────────────────── + + +class TestTraversal: + def test_roots_single_root(self): + g, events = _tree_record() + roots = g.roots() + assert len(roots) == 1 + assert roots[0].event.type == "crew_kickoff_started" + + def test_roots_multiple(self): + g = EventRecord() + g.add(_event("root1")) + g.add(_event("root2")) + assert len(g.roots()) == 2 + + def test_descendants_of_crew_start(self): + g, events = _tree_record() + desc = g.descendants(events["crew_start"].event_id) + desc_types = {n.event.type for n in desc} + assert desc_types == { + "task_started", + "task_completed", + "agent_execution_started", + "agent_execution_completed", + } + + def test_descendants_of_leaf(self): + g, events = _tree_record() + desc = g.descendants(events["task_complete"].event_id) + assert desc == [] + + def test_descendants_does_not_include_self(self): + g, events = _tree_record() + desc = g.descendants(events["crew_start"].event_id) + desc_ids = {n.event.event_id for n in desc} + assert events["crew_start"].event_id not in desc_ids + + +# ── Serialization round-trip tests ────────────────────────────────── + + +class TestSerialization: + def test_empty_record_roundtrip(self): + g = EventRecord() + restored = EventRecord.model_validate_json(g.model_dump_json()) + assert len(restored) == 0 + + def test_linear_record_roundtrip(self): + g, events = _linear_record(5) + restored = EventRecord.model_validate_json(g.model_dump_json()) + assert len(restored) == 5 + for e in events: + assert e.event_id in restored + + def test_tree_record_roundtrip(self): + g, events = _tree_record() + restored = EventRecord.model_validate_json(g.model_dump_json()) + assert len(restored) == 5 + + # Verify edges survived + crew_node = restored.get(events["crew_start"].event_id) + assert len(crew_node.neighbors("child")) == 2 + + def test_roundtrip_preserves_edge_symmetry(self): + g, _ = _tree_record() + restored = EventRecord.model_validate_json(g.model_dump_json()) + for node_id, node in restored.nodes.items(): + for forward, reverse in SYMMETRIC_PAIRS: + for target_id in node.neighbors(forward): + target_node = restored.get(target_id) + assert node_id in target_node.neighbors(reverse) + + def test_roundtrip_preserves_event_data(self): + g = EventRecord() + e = _event( + "test", + source_type="crew", + task_id="t1", + agent_role="researcher", + emission_sequence=42, + ) + g.add(e) + restored = EventRecord.model_validate_json(g.model_dump_json()) + re = restored.get(e.event_id).event + assert re.type == "test" + assert re.source_type == "crew" + assert re.task_id == "t1" + assert re.agent_role == "researcher" + assert re.emission_sequence == 42 + + +# ── RuntimeState integration tests ────────────────────────────────── + + +class TestRuntimeStateIntegration: + def test_runtime_state_serializes_event_record(self): + from crewai import Agent, Crew, RuntimeState + + agent = Agent( + role="test", goal="test", backstory="test", llm="gpt-4o-mini" + ) + crew = Crew(agents=[agent], tasks=[], verbose=False) + state = RuntimeState(root=[crew]) + + e1 = _event("crew_started", emission_sequence=1) + e2 = _event( + "task_started", + parent_event_id=e1.event_id, + emission_sequence=2, + ) + state.event_record.add(e1) + state.event_record.add(e2) + + dumped = json.loads(state.model_dump_json()) + assert "entities" in dumped + assert "event_record" in dumped + assert len(dumped["event_record"]["nodes"]) == 2 + + def test_runtime_state_roundtrip_with_record(self): + from crewai import Agent, Crew, RuntimeState + + agent = Agent( + role="test", goal="test", backstory="test", llm="gpt-4o-mini" + ) + crew = Crew(agents=[agent], tasks=[], verbose=False) + state = RuntimeState(root=[crew]) + + e1 = _event("crew_started", emission_sequence=1) + e2 = _event( + "task_started", + parent_event_id=e1.event_id, + emission_sequence=2, + ) + state.event_record.add(e1) + state.event_record.add(e2) + + raw = state.model_dump_json() + restored = RuntimeState.model_validate_json( + raw, context={"from_checkpoint": True} + ) + + assert len(restored.event_record) == 2 + assert e1.event_id in restored.event_record + assert e2.event_id in restored.event_record + + # Verify edges survived + e2_node = restored.event_record.get(e2.event_id) + assert e1.event_id in e2_node.neighbors("parent") + + def test_runtime_state_without_record_still_loads(self): + """Backwards compat: a bare entity list should still validate.""" + from crewai import Agent, Crew, RuntimeState + + agent = Agent( + role="test", goal="test", backstory="test", llm="gpt-4o-mini" + ) + crew = Crew(agents=[agent], tasks=[], verbose=False) + state = RuntimeState(root=[crew]) + + # Simulate old-format JSON (just the entity list) + old_json = json.dumps( + [json.loads(crew.model_dump_json())] + ) + restored = RuntimeState.model_validate_json( + old_json, context={"from_checkpoint": True} + ) + assert len(restored.root) == 1 + assert len(restored.event_record) == 0 \ No newline at end of file