feat: add EventRecord to RuntimeState checkpoints

This commit is contained in:
Greyson LaLonde
2026-04-03 22:31:36 +08:00
parent de9300705d
commit c653d41b89
5 changed files with 618 additions and 8 deletions

View File

@@ -480,6 +480,10 @@ class CrewAIEventsBus:
event.parent_event_id = get_current_parent_id()
set_last_event_id(event.event_id)
if self._runtime_state is not None:
self._runtime_state.event_record.add(event)
event_type = type(event)
with self._rwlock.r_locked():
@@ -578,6 +582,9 @@ class CrewAIEventsBus:
source: The object emitting the event
event: The event instance to emit
"""
if self._runtime_state is not None:
self._runtime_state.event_record.add(event)
event_type = type(event)
with self._rwlock.r_locked():

View File

@@ -0,0 +1,150 @@
"""Directed record of execution events.
Stores events as nodes with typed edges for parent/child, causal, and
sequential relationships. Provides O(1) lookups and traversal.
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
from crewai.events.base_events import BaseEvent
EdgeType = Literal[
"parent",
"child",
"trigger",
"triggered_by",
"next",
"previous",
"started",
"completed_by",
]
class EventNode(BaseModel):
"""A node wrapping a single event with its adjacency lists."""
event: BaseEvent
edges: dict[EdgeType, list[str]] = Field(default_factory=dict)
def add_edge(self, edge_type: EdgeType, target_id: str) -> None:
"""Add an edge from this node to another.
Args:
edge_type: The relationship type.
target_id: The event_id of the target node.
"""
self.edges.setdefault(edge_type, []).append(target_id)
def neighbors(self, edge_type: EdgeType) -> list[str]:
"""Return neighbor IDs for a given edge type.
Args:
edge_type: The relationship type to query.
Returns:
List of event IDs connected by this edge type.
"""
return self.edges.get(edge_type, [])
class EventRecord(BaseModel):
"""Directed record of execution events with O(1) node lookup.
Events are added via :meth:`add` which automatically wires edges
based on the event's relationship fields — ``parent_event_id``,
``triggered_by_event_id``, ``previous_event_id``, ``started_event_id``.
"""
nodes: dict[str, EventNode] = Field(default_factory=dict)
def add(self, event: BaseEvent) -> EventNode:
"""Add an event to the record and wire its edges.
Args:
event: The event to insert.
Returns:
The created node.
"""
node = EventNode(event=event)
self.nodes[event.event_id] = node
if event.parent_event_id and event.parent_event_id in self.nodes:
node.add_edge("parent", event.parent_event_id)
self.nodes[event.parent_event_id].add_edge("child", event.event_id)
if event.triggered_by_event_id and event.triggered_by_event_id in self.nodes:
node.add_edge("triggered_by", event.triggered_by_event_id)
self.nodes[event.triggered_by_event_id].add_edge("trigger", event.event_id)
if event.previous_event_id and event.previous_event_id in self.nodes:
node.add_edge("previous", event.previous_event_id)
self.nodes[event.previous_event_id].add_edge("next", event.event_id)
if event.started_event_id and event.started_event_id in self.nodes:
node.add_edge("started", event.started_event_id)
self.nodes[event.started_event_id].add_edge("completed_by", event.event_id)
return node
def get(self, event_id: str) -> EventNode | None:
"""Look up a node by event ID.
Args:
event_id: The event's unique identifier.
Returns:
The node, or None if not found.
"""
return self.nodes.get(event_id)
def descendants(self, event_id: str) -> list[EventNode]:
"""Return all descendant nodes, children recursively.
Args:
event_id: The root event ID to start from.
Returns:
All descendant nodes in breadth-first order.
"""
result: list[EventNode] = []
queue = [event_id]
visited: set[str] = set()
while queue:
current_id = queue.pop(0)
if current_id in visited:
continue
visited.add(current_id)
node = self.nodes.get(current_id)
if node is None:
continue
for child_id in node.neighbors("child"):
if child_id not in visited:
child_node = self.nodes.get(child_id)
if child_node:
result.append(child_node)
queue.append(child_id)
return result
def roots(self) -> list[EventNode]:
"""Return all root nodes — events with no parent.
Returns:
List of root event nodes.
"""
return [node for node in self.nodes.values() if not node.neighbors("parent")]
def __len__(self) -> int:
return len(self.nodes)
def __contains__(self, event_id: str) -> bool:
return event_id in self.nodes

View File

@@ -12,8 +12,8 @@ from pydantic_core import CoreSchema, core_schema
class BaseProvider(Protocol):
"""Interface for persisting and restoring runtime state checkpoints.
Implementations handle the storage backend (filesystem, cloud, database,
etc.) while ``RuntimeState`` handles serialization.
Implementations handle the storage backend filesystem, cloud, database,
etc. while ``RuntimeState`` handles serialization.
"""
@classmethod
@@ -39,10 +39,10 @@ class BaseProvider(Protocol):
Args:
data: The serialized string to persist.
directory: Logical destination (path, bucket prefix, etc.).
directory: Logical destination: path, bucket prefix, etc.
Returns:
A location identifier for the saved checkpoint (e.g. file path, URI).
A location identifier for the saved checkpoint, such as a file path or URI.
"""
...
@@ -51,9 +51,9 @@ class BaseProvider(Protocol):
Args:
data: The serialized string to persist.
directory: Logical destination (path, bucket prefix, etc.).
directory: Logical destination: path, bucket prefix, etc.
Returns:
A location identifier for the saved checkpoint (e.g. file path, URI).
A location identifier for the saved checkpoint, such as a file path or URI.
"""
...

View File

@@ -9,11 +9,19 @@ via ``RuntimeState.model_rebuild()``.
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypedDict
from pydantic import PrivateAttr, RootModel
from pydantic import (
ModelWrapValidatorHandler,
PrivateAttr,
RootModel,
SerializerFunctionWrapHandler,
model_serializer,
model_validator,
)
from crewai.context import capture_execution_context
from crewai.state.event_record import EventRecord
from crewai.state.provider.core import BaseProvider
from crewai.state.provider.json_provider import JsonProvider
@@ -22,6 +30,11 @@ if TYPE_CHECKING:
from crewai import Entity
class CheckpointPayload(TypedDict):
entities: list[Entity]
event_record: dict[str, Any]
def _entity_discriminator(v: dict[str, Any] | object) -> str:
if isinstance(v, dict):
raw = v.get("entity_type", "agent")
@@ -64,6 +77,32 @@ def _sync_checkpoint_fields(entity: object) -> None:
class RuntimeState(RootModel): # type: ignore[type-arg]
root: list[Entity]
_provider: BaseProvider = PrivateAttr(default_factory=JsonProvider)
_event_record: EventRecord = PrivateAttr(default_factory=EventRecord)
@property
def event_record(self) -> EventRecord:
"""The execution event record."""
return self._event_record
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler) -> CheckpointPayload:
return {
"entities": handler(self),
"event_record": self._event_record.model_dump(),
}
@model_validator(mode="wrap")
@classmethod
def _deserialize(
cls, data: Any, handler: ModelWrapValidatorHandler[RuntimeState]
) -> RuntimeState:
if isinstance(data, dict) and "entities" in data:
record_data = data.get("event_record")
state = handler(data["entities"])
if record_data:
state._event_record = EventRecord.model_validate(record_data)
return state
return handler(data)
def checkpoint(self, directory: str) -> str:
"""Write a checkpoint file to the directory.

View File

@@ -0,0 +1,414 @@
"""Tests for EventRecord data structure and RuntimeState integration."""
from __future__ import annotations
import json
import pytest
from crewai.events.base_events import BaseEvent
from crewai.state.event_record import EventRecord, EventNode
# ── Helpers ──────────────────────────────────────────────────────────
def _event(type: str, **kwargs) -> BaseEvent:
return BaseEvent(type=type, **kwargs)
def _linear_record(n: int = 5) -> tuple[EventRecord, list[BaseEvent]]:
"""Build a simple chain: e0 → e1 → e2 → ... with previous_event_id."""
g = EventRecord()
events: list[BaseEvent] = []
for i in range(n):
e = _event(
f"step_{i}",
previous_event_id=events[-1].event_id if events else None,
emission_sequence=i + 1,
)
events.append(e)
g.add(e)
return g, events
def _tree_record() -> tuple[EventRecord, dict[str, BaseEvent]]:
"""Build a parent/child tree:
crew_start
├── task_start
│ ├── agent_start
│ └── agent_complete (started=agent_start)
└── task_complete (started=task_start)
"""
g = EventRecord()
crew_start = _event("crew_kickoff_started", emission_sequence=1)
task_start = _event(
"task_started",
parent_event_id=crew_start.event_id,
previous_event_id=crew_start.event_id,
emission_sequence=2,
)
agent_start = _event(
"agent_execution_started",
parent_event_id=task_start.event_id,
previous_event_id=task_start.event_id,
emission_sequence=3,
)
agent_complete = _event(
"agent_execution_completed",
parent_event_id=task_start.event_id,
previous_event_id=agent_start.event_id,
started_event_id=agent_start.event_id,
emission_sequence=4,
)
task_complete = _event(
"task_completed",
parent_event_id=crew_start.event_id,
previous_event_id=agent_complete.event_id,
started_event_id=task_start.event_id,
emission_sequence=5,
)
for e in [crew_start, task_start, agent_start, agent_complete, task_complete]:
g.add(e)
return g, {
"crew_start": crew_start,
"task_start": task_start,
"agent_start": agent_start,
"agent_complete": agent_complete,
"task_complete": task_complete,
}
# ── EventNode tests ─────────────────────────────────────────────────
class TestEventNode:
def test_add_edge(self):
node = EventNode(event=_event("test"))
node.add_edge("child", "abc")
assert node.neighbors("child") == ["abc"]
def test_neighbors_empty(self):
node = EventNode(event=_event("test"))
assert node.neighbors("parent") == []
def test_multiple_edges_same_type(self):
node = EventNode(event=_event("test"))
node.add_edge("child", "a")
node.add_edge("child", "b")
assert node.neighbors("child") == ["a", "b"]
# ── EventRecord core tests ───────────────────────────────────────────
class TestEventRecordCore:
def test_add_single_event(self):
g = EventRecord()
e = _event("test")
node = g.add(e)
assert len(g) == 1
assert e.event_id in g
assert node.event.type == "test"
def test_get_existing(self):
g = EventRecord()
e = _event("test")
g.add(e)
assert g.get(e.event_id) is not None
def test_get_missing(self):
g = EventRecord()
assert g.get("nonexistent") is None
def test_contains(self):
g = EventRecord()
e = _event("test")
g.add(e)
assert e.event_id in g
assert "missing" not in g
# ── Edge wiring tests ───────────────────────────────────────────────
class TestEdgeWiring:
def test_parent_child_bidirectional(self):
g = EventRecord()
parent = _event("parent")
child = _event("child", parent_event_id=parent.event_id)
g.add(parent)
g.add(child)
parent_node = g.get(parent.event_id)
child_node = g.get(child.event_id)
assert child.event_id in parent_node.neighbors("child")
assert parent.event_id in child_node.neighbors("parent")
def test_previous_next_bidirectional(self):
g, events = _linear_record(3)
node0 = g.get(events[0].event_id)
node1 = g.get(events[1].event_id)
node2 = g.get(events[2].event_id)
assert events[1].event_id in node0.neighbors("next")
assert events[0].event_id in node1.neighbors("previous")
assert events[2].event_id in node1.neighbors("next")
assert events[1].event_id in node2.neighbors("previous")
def test_trigger_bidirectional(self):
g = EventRecord()
cause = _event("cause")
effect = _event("effect", triggered_by_event_id=cause.event_id)
g.add(cause)
g.add(effect)
assert effect.event_id in g.get(cause.event_id).neighbors("trigger")
assert cause.event_id in g.get(effect.event_id).neighbors("triggered_by")
def test_started_completed_by_bidirectional(self):
g = EventRecord()
start = _event("start")
end = _event("end", started_event_id=start.event_id)
g.add(start)
g.add(end)
assert end.event_id in g.get(start.event_id).neighbors("completed_by")
assert start.event_id in g.get(end.event_id).neighbors("started")
def test_dangling_reference_ignored(self):
"""Edge to a non-existent node should not be wired."""
g = EventRecord()
e = _event("orphan", parent_event_id="nonexistent")
g.add(e)
node = g.get(e.event_id)
assert node.neighbors("parent") == []
# ── Edge symmetry validation ─────────────────────────────────────────
SYMMETRIC_PAIRS = [
("parent", "child"),
("previous", "next"),
("triggered_by", "trigger"),
("started", "completed_by"),
]
class TestEdgeSymmetry:
@pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS)
def test_symmetry_on_tree(self, forward, reverse):
g, _ = _tree_record()
for node_id, node in g.nodes.items():
for target_id in node.neighbors(forward):
target_node = g.get(target_id)
assert target_node is not None, f"{target_id} missing from record"
assert node_id in target_node.neighbors(reverse), (
f"Asymmetric edge: {node_id} --{forward.value}--> {target_id} "
f"but {target_id} has no {reverse.value} back to {node_id}"
)
@pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS)
def test_symmetry_on_linear(self, forward, reverse):
g, _ = _linear_record(10)
for node_id, node in g.nodes.items():
for target_id in node.neighbors(forward):
target_node = g.get(target_id)
assert target_node is not None
assert node_id in target_node.neighbors(reverse)
# ── Ordering tests ───────────────────────────────────────────────────
class TestOrdering:
def test_emission_sequence_monotonic(self):
g, events = _linear_record(10)
sequences = [e.emission_sequence for e in events]
assert sequences == sorted(sequences)
assert len(set(sequences)) == len(sequences), "Duplicate sequences"
def test_next_chain_follows_sequence_order(self):
g, events = _linear_record(5)
current = g.get(events[0].event_id)
visited = []
while current:
visited.append(current.event.event_id)
nexts = current.neighbors("next")
current = g.get(nexts[0]) if nexts else None
assert visited == [e.event_id for e in events]
# ── Traversal tests ─────────────────────────────────────────────────
class TestTraversal:
def test_roots_single_root(self):
g, events = _tree_record()
roots = g.roots()
assert len(roots) == 1
assert roots[0].event.type == "crew_kickoff_started"
def test_roots_multiple(self):
g = EventRecord()
g.add(_event("root1"))
g.add(_event("root2"))
assert len(g.roots()) == 2
def test_descendants_of_crew_start(self):
g, events = _tree_record()
desc = g.descendants(events["crew_start"].event_id)
desc_types = {n.event.type for n in desc}
assert desc_types == {
"task_started",
"task_completed",
"agent_execution_started",
"agent_execution_completed",
}
def test_descendants_of_leaf(self):
g, events = _tree_record()
desc = g.descendants(events["task_complete"].event_id)
assert desc == []
def test_descendants_does_not_include_self(self):
g, events = _tree_record()
desc = g.descendants(events["crew_start"].event_id)
desc_ids = {n.event.event_id for n in desc}
assert events["crew_start"].event_id not in desc_ids
# ── Serialization round-trip tests ──────────────────────────────────
class TestSerialization:
def test_empty_record_roundtrip(self):
g = EventRecord()
restored = EventRecord.model_validate_json(g.model_dump_json())
assert len(restored) == 0
def test_linear_record_roundtrip(self):
g, events = _linear_record(5)
restored = EventRecord.model_validate_json(g.model_dump_json())
assert len(restored) == 5
for e in events:
assert e.event_id in restored
def test_tree_record_roundtrip(self):
g, events = _tree_record()
restored = EventRecord.model_validate_json(g.model_dump_json())
assert len(restored) == 5
# Verify edges survived
crew_node = restored.get(events["crew_start"].event_id)
assert len(crew_node.neighbors("child")) == 2
def test_roundtrip_preserves_edge_symmetry(self):
g, _ = _tree_record()
restored = EventRecord.model_validate_json(g.model_dump_json())
for node_id, node in restored.nodes.items():
for forward, reverse in SYMMETRIC_PAIRS:
for target_id in node.neighbors(forward):
target_node = restored.get(target_id)
assert node_id in target_node.neighbors(reverse)
def test_roundtrip_preserves_event_data(self):
g = EventRecord()
e = _event(
"test",
source_type="crew",
task_id="t1",
agent_role="researcher",
emission_sequence=42,
)
g.add(e)
restored = EventRecord.model_validate_json(g.model_dump_json())
re = restored.get(e.event_id).event
assert re.type == "test"
assert re.source_type == "crew"
assert re.task_id == "t1"
assert re.agent_role == "researcher"
assert re.emission_sequence == 42
# ── RuntimeState integration tests ──────────────────────────────────
class TestRuntimeStateIntegration:
def test_runtime_state_serializes_event_record(self):
from crewai import Agent, Crew, RuntimeState
agent = Agent(
role="test", goal="test", backstory="test", llm="gpt-4o-mini"
)
crew = Crew(agents=[agent], tasks=[], verbose=False)
state = RuntimeState(root=[crew])
e1 = _event("crew_started", emission_sequence=1)
e2 = _event(
"task_started",
parent_event_id=e1.event_id,
emission_sequence=2,
)
state.event_record.add(e1)
state.event_record.add(e2)
dumped = json.loads(state.model_dump_json())
assert "entities" in dumped
assert "event_record" in dumped
assert len(dumped["event_record"]["nodes"]) == 2
def test_runtime_state_roundtrip_with_record(self):
from crewai import Agent, Crew, RuntimeState
agent = Agent(
role="test", goal="test", backstory="test", llm="gpt-4o-mini"
)
crew = Crew(agents=[agent], tasks=[], verbose=False)
state = RuntimeState(root=[crew])
e1 = _event("crew_started", emission_sequence=1)
e2 = _event(
"task_started",
parent_event_id=e1.event_id,
emission_sequence=2,
)
state.event_record.add(e1)
state.event_record.add(e2)
raw = state.model_dump_json()
restored = RuntimeState.model_validate_json(
raw, context={"from_checkpoint": True}
)
assert len(restored.event_record) == 2
assert e1.event_id in restored.event_record
assert e2.event_id in restored.event_record
# Verify edges survived
e2_node = restored.event_record.get(e2.event_id)
assert e1.event_id in e2_node.neighbors("parent")
def test_runtime_state_without_record_still_loads(self):
"""Backwards compat: a bare entity list should still validate."""
from crewai import Agent, Crew, RuntimeState
agent = Agent(
role="test", goal="test", backstory="test", llm="gpt-4o-mini"
)
crew = Crew(agents=[agent], tasks=[], verbose=False)
state = RuntimeState(root=[crew])
# Simulate old-format JSON (just the entity list)
old_json = json.dumps(
[json.loads(crew.model_dump_json())]
)
restored = RuntimeState.model_validate_json(
old_json, context={"from_checkpoint": True}
)
assert len(restored.root) == 1
assert len(restored.event_record) == 0