feat: add EventRecord to RuntimeState checkpoints

This commit is contained in:
Greyson LaLonde
2026-04-03 22:31:36 +08:00
parent de9300705d
commit c653d41b89
5 changed files with 618 additions and 8 deletions

View File

@@ -480,6 +480,10 @@ class CrewAIEventsBus:
event.parent_event_id = get_current_parent_id() event.parent_event_id = get_current_parent_id()
set_last_event_id(event.event_id) set_last_event_id(event.event_id)
if self._runtime_state is not None:
self._runtime_state.event_record.add(event)
event_type = type(event) event_type = type(event)
with self._rwlock.r_locked(): with self._rwlock.r_locked():
@@ -578,6 +582,9 @@ class CrewAIEventsBus:
source: The object emitting the event source: The object emitting the event
event: The event instance to emit event: The event instance to emit
""" """
if self._runtime_state is not None:
self._runtime_state.event_record.add(event)
event_type = type(event) event_type = type(event)
with self._rwlock.r_locked(): with self._rwlock.r_locked():

View File

@@ -0,0 +1,150 @@
"""Directed record of execution events.
Stores events as nodes with typed edges for parent/child, causal, and
sequential relationships. Provides O(1) lookups and traversal.
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
from crewai.events.base_events import BaseEvent
EdgeType = Literal[
"parent",
"child",
"trigger",
"triggered_by",
"next",
"previous",
"started",
"completed_by",
]
class EventNode(BaseModel):
"""A node wrapping a single event with its adjacency lists."""
event: BaseEvent
edges: dict[EdgeType, list[str]] = Field(default_factory=dict)
def add_edge(self, edge_type: EdgeType, target_id: str) -> None:
"""Add an edge from this node to another.
Args:
edge_type: The relationship type.
target_id: The event_id of the target node.
"""
self.edges.setdefault(edge_type, []).append(target_id)
def neighbors(self, edge_type: EdgeType) -> list[str]:
"""Return neighbor IDs for a given edge type.
Args:
edge_type: The relationship type to query.
Returns:
List of event IDs connected by this edge type.
"""
return self.edges.get(edge_type, [])
class EventRecord(BaseModel):
"""Directed record of execution events with O(1) node lookup.
Events are added via :meth:`add` which automatically wires edges
based on the event's relationship fields — ``parent_event_id``,
``triggered_by_event_id``, ``previous_event_id``, ``started_event_id``.
"""
nodes: dict[str, EventNode] = Field(default_factory=dict)
def add(self, event: BaseEvent) -> EventNode:
"""Add an event to the record and wire its edges.
Args:
event: The event to insert.
Returns:
The created node.
"""
node = EventNode(event=event)
self.nodes[event.event_id] = node
if event.parent_event_id and event.parent_event_id in self.nodes:
node.add_edge("parent", event.parent_event_id)
self.nodes[event.parent_event_id].add_edge("child", event.event_id)
if event.triggered_by_event_id and event.triggered_by_event_id in self.nodes:
node.add_edge("triggered_by", event.triggered_by_event_id)
self.nodes[event.triggered_by_event_id].add_edge("trigger", event.event_id)
if event.previous_event_id and event.previous_event_id in self.nodes:
node.add_edge("previous", event.previous_event_id)
self.nodes[event.previous_event_id].add_edge("next", event.event_id)
if event.started_event_id and event.started_event_id in self.nodes:
node.add_edge("started", event.started_event_id)
self.nodes[event.started_event_id].add_edge("completed_by", event.event_id)
return node
def get(self, event_id: str) -> EventNode | None:
"""Look up a node by event ID.
Args:
event_id: The event's unique identifier.
Returns:
The node, or None if not found.
"""
return self.nodes.get(event_id)
def descendants(self, event_id: str) -> list[EventNode]:
"""Return all descendant nodes, children recursively.
Args:
event_id: The root event ID to start from.
Returns:
All descendant nodes in breadth-first order.
"""
result: list[EventNode] = []
queue = [event_id]
visited: set[str] = set()
while queue:
current_id = queue.pop(0)
if current_id in visited:
continue
visited.add(current_id)
node = self.nodes.get(current_id)
if node is None:
continue
for child_id in node.neighbors("child"):
if child_id not in visited:
child_node = self.nodes.get(child_id)
if child_node:
result.append(child_node)
queue.append(child_id)
return result
def roots(self) -> list[EventNode]:
"""Return all root nodes — events with no parent.
Returns:
List of root event nodes.
"""
return [node for node in self.nodes.values() if not node.neighbors("parent")]
def __len__(self) -> int:
return len(self.nodes)
def __contains__(self, event_id: str) -> bool:
return event_id in self.nodes

View File

@@ -12,8 +12,8 @@ from pydantic_core import CoreSchema, core_schema
class BaseProvider(Protocol): class BaseProvider(Protocol):
"""Interface for persisting and restoring runtime state checkpoints. """Interface for persisting and restoring runtime state checkpoints.
Implementations handle the storage backend (filesystem, cloud, database, Implementations handle the storage backend filesystem, cloud, database,
etc.) while ``RuntimeState`` handles serialization. etc. while ``RuntimeState`` handles serialization.
""" """
@classmethod @classmethod
@@ -39,10 +39,10 @@ class BaseProvider(Protocol):
Args: Args:
data: The serialized string to persist. data: The serialized string to persist.
directory: Logical destination (path, bucket prefix, etc.). directory: Logical destination: path, bucket prefix, etc.
Returns: Returns:
A location identifier for the saved checkpoint (e.g. file path, URI). A location identifier for the saved checkpoint, such as a file path or URI.
""" """
... ...
@@ -51,9 +51,9 @@ class BaseProvider(Protocol):
Args: Args:
data: The serialized string to persist. data: The serialized string to persist.
directory: Logical destination (path, bucket prefix, etc.). directory: Logical destination: path, bucket prefix, etc.
Returns: Returns:
A location identifier for the saved checkpoint (e.g. file path, URI). A location identifier for the saved checkpoint, such as a file path or URI.
""" """
... ...

View File

@@ -9,11 +9,19 @@ via ``RuntimeState.model_rebuild()``.
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, TypedDict
from pydantic import PrivateAttr, RootModel from pydantic import (
ModelWrapValidatorHandler,
PrivateAttr,
RootModel,
SerializerFunctionWrapHandler,
model_serializer,
model_validator,
)
from crewai.context import capture_execution_context from crewai.context import capture_execution_context
from crewai.state.event_record import EventRecord
from crewai.state.provider.core import BaseProvider from crewai.state.provider.core import BaseProvider
from crewai.state.provider.json_provider import JsonProvider from crewai.state.provider.json_provider import JsonProvider
@@ -22,6 +30,11 @@ if TYPE_CHECKING:
from crewai import Entity from crewai import Entity
class CheckpointPayload(TypedDict):
entities: list[Entity]
event_record: dict[str, Any]
def _entity_discriminator(v: dict[str, Any] | object) -> str: def _entity_discriminator(v: dict[str, Any] | object) -> str:
if isinstance(v, dict): if isinstance(v, dict):
raw = v.get("entity_type", "agent") raw = v.get("entity_type", "agent")
@@ -64,6 +77,32 @@ def _sync_checkpoint_fields(entity: object) -> None:
class RuntimeState(RootModel): # type: ignore[type-arg] class RuntimeState(RootModel): # type: ignore[type-arg]
root: list[Entity] root: list[Entity]
_provider: BaseProvider = PrivateAttr(default_factory=JsonProvider) _provider: BaseProvider = PrivateAttr(default_factory=JsonProvider)
_event_record: EventRecord = PrivateAttr(default_factory=EventRecord)
@property
def event_record(self) -> EventRecord:
"""The execution event record."""
return self._event_record
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler) -> CheckpointPayload:
return {
"entities": handler(self),
"event_record": self._event_record.model_dump(),
}
@model_validator(mode="wrap")
@classmethod
def _deserialize(
cls, data: Any, handler: ModelWrapValidatorHandler[RuntimeState]
) -> RuntimeState:
if isinstance(data, dict) and "entities" in data:
record_data = data.get("event_record")
state = handler(data["entities"])
if record_data:
state._event_record = EventRecord.model_validate(record_data)
return state
return handler(data)
def checkpoint(self, directory: str) -> str: def checkpoint(self, directory: str) -> str:
"""Write a checkpoint file to the directory. """Write a checkpoint file to the directory.

View File

@@ -0,0 +1,414 @@
"""Tests for EventRecord data structure and RuntimeState integration."""
from __future__ import annotations
import json
import pytest
from crewai.events.base_events import BaseEvent
from crewai.state.event_record import EventRecord, EventNode
# ── Helpers ──────────────────────────────────────────────────────────
def _event(type: str, **kwargs) -> BaseEvent:
return BaseEvent(type=type, **kwargs)
def _linear_record(n: int = 5) -> tuple[EventRecord, list[BaseEvent]]:
"""Build a simple chain: e0 → e1 → e2 → ... with previous_event_id."""
g = EventRecord()
events: list[BaseEvent] = []
for i in range(n):
e = _event(
f"step_{i}",
previous_event_id=events[-1].event_id if events else None,
emission_sequence=i + 1,
)
events.append(e)
g.add(e)
return g, events
def _tree_record() -> tuple[EventRecord, dict[str, BaseEvent]]:
"""Build a parent/child tree:
crew_start
├── task_start
│ ├── agent_start
│ └── agent_complete (started=agent_start)
└── task_complete (started=task_start)
"""
g = EventRecord()
crew_start = _event("crew_kickoff_started", emission_sequence=1)
task_start = _event(
"task_started",
parent_event_id=crew_start.event_id,
previous_event_id=crew_start.event_id,
emission_sequence=2,
)
agent_start = _event(
"agent_execution_started",
parent_event_id=task_start.event_id,
previous_event_id=task_start.event_id,
emission_sequence=3,
)
agent_complete = _event(
"agent_execution_completed",
parent_event_id=task_start.event_id,
previous_event_id=agent_start.event_id,
started_event_id=agent_start.event_id,
emission_sequence=4,
)
task_complete = _event(
"task_completed",
parent_event_id=crew_start.event_id,
previous_event_id=agent_complete.event_id,
started_event_id=task_start.event_id,
emission_sequence=5,
)
for e in [crew_start, task_start, agent_start, agent_complete, task_complete]:
g.add(e)
return g, {
"crew_start": crew_start,
"task_start": task_start,
"agent_start": agent_start,
"agent_complete": agent_complete,
"task_complete": task_complete,
}
# ── EventNode tests ─────────────────────────────────────────────────
class TestEventNode:
def test_add_edge(self):
node = EventNode(event=_event("test"))
node.add_edge("child", "abc")
assert node.neighbors("child") == ["abc"]
def test_neighbors_empty(self):
node = EventNode(event=_event("test"))
assert node.neighbors("parent") == []
def test_multiple_edges_same_type(self):
node = EventNode(event=_event("test"))
node.add_edge("child", "a")
node.add_edge("child", "b")
assert node.neighbors("child") == ["a", "b"]
# ── EventRecord core tests ───────────────────────────────────────────
class TestEventRecordCore:
def test_add_single_event(self):
g = EventRecord()
e = _event("test")
node = g.add(e)
assert len(g) == 1
assert e.event_id in g
assert node.event.type == "test"
def test_get_existing(self):
g = EventRecord()
e = _event("test")
g.add(e)
assert g.get(e.event_id) is not None
def test_get_missing(self):
g = EventRecord()
assert g.get("nonexistent") is None
def test_contains(self):
g = EventRecord()
e = _event("test")
g.add(e)
assert e.event_id in g
assert "missing" not in g
# ── Edge wiring tests ───────────────────────────────────────────────
class TestEdgeWiring:
def test_parent_child_bidirectional(self):
g = EventRecord()
parent = _event("parent")
child = _event("child", parent_event_id=parent.event_id)
g.add(parent)
g.add(child)
parent_node = g.get(parent.event_id)
child_node = g.get(child.event_id)
assert child.event_id in parent_node.neighbors("child")
assert parent.event_id in child_node.neighbors("parent")
def test_previous_next_bidirectional(self):
g, events = _linear_record(3)
node0 = g.get(events[0].event_id)
node1 = g.get(events[1].event_id)
node2 = g.get(events[2].event_id)
assert events[1].event_id in node0.neighbors("next")
assert events[0].event_id in node1.neighbors("previous")
assert events[2].event_id in node1.neighbors("next")
assert events[1].event_id in node2.neighbors("previous")
def test_trigger_bidirectional(self):
g = EventRecord()
cause = _event("cause")
effect = _event("effect", triggered_by_event_id=cause.event_id)
g.add(cause)
g.add(effect)
assert effect.event_id in g.get(cause.event_id).neighbors("trigger")
assert cause.event_id in g.get(effect.event_id).neighbors("triggered_by")
def test_started_completed_by_bidirectional(self):
g = EventRecord()
start = _event("start")
end = _event("end", started_event_id=start.event_id)
g.add(start)
g.add(end)
assert end.event_id in g.get(start.event_id).neighbors("completed_by")
assert start.event_id in g.get(end.event_id).neighbors("started")
def test_dangling_reference_ignored(self):
"""Edge to a non-existent node should not be wired."""
g = EventRecord()
e = _event("orphan", parent_event_id="nonexistent")
g.add(e)
node = g.get(e.event_id)
assert node.neighbors("parent") == []
# ── Edge symmetry validation ─────────────────────────────────────────
SYMMETRIC_PAIRS = [
("parent", "child"),
("previous", "next"),
("triggered_by", "trigger"),
("started", "completed_by"),
]
class TestEdgeSymmetry:
@pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS)
def test_symmetry_on_tree(self, forward, reverse):
g, _ = _tree_record()
for node_id, node in g.nodes.items():
for target_id in node.neighbors(forward):
target_node = g.get(target_id)
assert target_node is not None, f"{target_id} missing from record"
assert node_id in target_node.neighbors(reverse), (
f"Asymmetric edge: {node_id} --{forward.value}--> {target_id} "
f"but {target_id} has no {reverse.value} back to {node_id}"
)
@pytest.mark.parametrize("forward,reverse", SYMMETRIC_PAIRS)
def test_symmetry_on_linear(self, forward, reverse):
g, _ = _linear_record(10)
for node_id, node in g.nodes.items():
for target_id in node.neighbors(forward):
target_node = g.get(target_id)
assert target_node is not None
assert node_id in target_node.neighbors(reverse)
# ── Ordering tests ───────────────────────────────────────────────────
class TestOrdering:
def test_emission_sequence_monotonic(self):
g, events = _linear_record(10)
sequences = [e.emission_sequence for e in events]
assert sequences == sorted(sequences)
assert len(set(sequences)) == len(sequences), "Duplicate sequences"
def test_next_chain_follows_sequence_order(self):
g, events = _linear_record(5)
current = g.get(events[0].event_id)
visited = []
while current:
visited.append(current.event.event_id)
nexts = current.neighbors("next")
current = g.get(nexts[0]) if nexts else None
assert visited == [e.event_id for e in events]
# ── Traversal tests ─────────────────────────────────────────────────
class TestTraversal:
def test_roots_single_root(self):
g, events = _tree_record()
roots = g.roots()
assert len(roots) == 1
assert roots[0].event.type == "crew_kickoff_started"
def test_roots_multiple(self):
g = EventRecord()
g.add(_event("root1"))
g.add(_event("root2"))
assert len(g.roots()) == 2
def test_descendants_of_crew_start(self):
g, events = _tree_record()
desc = g.descendants(events["crew_start"].event_id)
desc_types = {n.event.type for n in desc}
assert desc_types == {
"task_started",
"task_completed",
"agent_execution_started",
"agent_execution_completed",
}
def test_descendants_of_leaf(self):
g, events = _tree_record()
desc = g.descendants(events["task_complete"].event_id)
assert desc == []
def test_descendants_does_not_include_self(self):
g, events = _tree_record()
desc = g.descendants(events["crew_start"].event_id)
desc_ids = {n.event.event_id for n in desc}
assert events["crew_start"].event_id not in desc_ids
# ── Serialization round-trip tests ──────────────────────────────────
class TestSerialization:
def test_empty_record_roundtrip(self):
g = EventRecord()
restored = EventRecord.model_validate_json(g.model_dump_json())
assert len(restored) == 0
def test_linear_record_roundtrip(self):
g, events = _linear_record(5)
restored = EventRecord.model_validate_json(g.model_dump_json())
assert len(restored) == 5
for e in events:
assert e.event_id in restored
def test_tree_record_roundtrip(self):
g, events = _tree_record()
restored = EventRecord.model_validate_json(g.model_dump_json())
assert len(restored) == 5
# Verify edges survived
crew_node = restored.get(events["crew_start"].event_id)
assert len(crew_node.neighbors("child")) == 2
def test_roundtrip_preserves_edge_symmetry(self):
g, _ = _tree_record()
restored = EventRecord.model_validate_json(g.model_dump_json())
for node_id, node in restored.nodes.items():
for forward, reverse in SYMMETRIC_PAIRS:
for target_id in node.neighbors(forward):
target_node = restored.get(target_id)
assert node_id in target_node.neighbors(reverse)
def test_roundtrip_preserves_event_data(self):
g = EventRecord()
e = _event(
"test",
source_type="crew",
task_id="t1",
agent_role="researcher",
emission_sequence=42,
)
g.add(e)
restored = EventRecord.model_validate_json(g.model_dump_json())
re = restored.get(e.event_id).event
assert re.type == "test"
assert re.source_type == "crew"
assert re.task_id == "t1"
assert re.agent_role == "researcher"
assert re.emission_sequence == 42
# ── RuntimeState integration tests ──────────────────────────────────
class TestRuntimeStateIntegration:
def test_runtime_state_serializes_event_record(self):
from crewai import Agent, Crew, RuntimeState
agent = Agent(
role="test", goal="test", backstory="test", llm="gpt-4o-mini"
)
crew = Crew(agents=[agent], tasks=[], verbose=False)
state = RuntimeState(root=[crew])
e1 = _event("crew_started", emission_sequence=1)
e2 = _event(
"task_started",
parent_event_id=e1.event_id,
emission_sequence=2,
)
state.event_record.add(e1)
state.event_record.add(e2)
dumped = json.loads(state.model_dump_json())
assert "entities" in dumped
assert "event_record" in dumped
assert len(dumped["event_record"]["nodes"]) == 2
def test_runtime_state_roundtrip_with_record(self):
from crewai import Agent, Crew, RuntimeState
agent = Agent(
role="test", goal="test", backstory="test", llm="gpt-4o-mini"
)
crew = Crew(agents=[agent], tasks=[], verbose=False)
state = RuntimeState(root=[crew])
e1 = _event("crew_started", emission_sequence=1)
e2 = _event(
"task_started",
parent_event_id=e1.event_id,
emission_sequence=2,
)
state.event_record.add(e1)
state.event_record.add(e2)
raw = state.model_dump_json()
restored = RuntimeState.model_validate_json(
raw, context={"from_checkpoint": True}
)
assert len(restored.event_record) == 2
assert e1.event_id in restored.event_record
assert e2.event_id in restored.event_record
# Verify edges survived
e2_node = restored.event_record.get(e2.event_id)
assert e1.event_id in e2_node.neighbors("parent")
def test_runtime_state_without_record_still_loads(self):
"""Backwards compat: a bare entity list should still validate."""
from crewai import Agent, Crew, RuntimeState
agent = Agent(
role="test", goal="test", backstory="test", llm="gpt-4o-mini"
)
crew = Crew(agents=[agent], tasks=[], verbose=False)
state = RuntimeState(root=[crew])
# Simulate old-format JSON (just the entity list)
old_json = json.dumps(
[json.loads(crew.model_dump_json())]
)
restored = RuntimeState.model_validate_json(
old_json, context={"from_checkpoint": True}
)
assert len(restored.root) == 1
assert len(restored.event_record) == 0