fix: add RWLock to EventRecord for thread-safe concurrent access

This commit is contained in:
Greyson LaLonde
2026-04-06 20:44:55 +08:00
parent 470af2f9e1
commit ef7654d7d5

View File

@@ -8,9 +8,10 @@ from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from crewai.events.base_events import BaseEvent
from crewai.utilities.rw_lock import RWLock
EdgeType = Literal[
@@ -61,6 +62,7 @@ class EventRecord(BaseModel):
"""
nodes: dict[str, EventNode] = Field(default_factory=dict)
_lock: RWLock = PrivateAttr(default_factory=RWLock)
def add(self, event: BaseEvent) -> EventNode:
"""Add an event to the record and wire its edges.
@@ -71,26 +73,34 @@ class EventRecord(BaseModel):
Returns:
The created node.
"""
node = EventNode(event=event)
self.nodes[event.event_id] = node
with self._lock.w_locked():
node = EventNode(event=event)
self.nodes[event.event_id] = node
if event.parent_event_id and event.parent_event_id in self.nodes:
node.add_edge("parent", event.parent_event_id)
self.nodes[event.parent_event_id].add_edge("child", event.event_id)
if event.parent_event_id and event.parent_event_id in self.nodes:
node.add_edge("parent", event.parent_event_id)
self.nodes[event.parent_event_id].add_edge("child", event.event_id)
if event.triggered_by_event_id and event.triggered_by_event_id in self.nodes:
node.add_edge("triggered_by", event.triggered_by_event_id)
self.nodes[event.triggered_by_event_id].add_edge("trigger", event.event_id)
if (
event.triggered_by_event_id
and event.triggered_by_event_id in self.nodes
):
node.add_edge("triggered_by", event.triggered_by_event_id)
self.nodes[event.triggered_by_event_id].add_edge(
"trigger", event.event_id
)
if event.previous_event_id and event.previous_event_id in self.nodes:
node.add_edge("previous", event.previous_event_id)
self.nodes[event.previous_event_id].add_edge("next", event.event_id)
if event.previous_event_id and event.previous_event_id in self.nodes:
node.add_edge("previous", event.previous_event_id)
self.nodes[event.previous_event_id].add_edge("next", event.event_id)
if event.started_event_id and event.started_event_id in self.nodes:
node.add_edge("started", event.started_event_id)
self.nodes[event.started_event_id].add_edge("completed_by", event.event_id)
if event.started_event_id and event.started_event_id in self.nodes:
node.add_edge("started", event.started_event_id)
self.nodes[event.started_event_id].add_edge(
"completed_by", event.event_id
)
return node
return node
def get(self, event_id: str) -> EventNode | None:
"""Look up a node by event ID.
@@ -101,7 +111,8 @@ class EventRecord(BaseModel):
Returns:
The node, or None if not found.
"""
return self.nodes.get(event_id)
with self._lock.r_locked():
return self.nodes.get(event_id)
def descendants(self, event_id: str) -> list[EventNode]:
"""Return all descendant nodes, children recursively.
@@ -112,28 +123,29 @@ class EventRecord(BaseModel):
Returns:
All descendant nodes in breadth-first order.
"""
result: list[EventNode] = []
queue = [event_id]
visited: set[str] = set()
with self._lock.r_locked():
result: list[EventNode] = []
queue = [event_id]
visited: set[str] = set()
while queue:
current_id = queue.pop(0)
if current_id in visited:
continue
visited.add(current_id)
while queue:
current_id = queue.pop(0)
if current_id in visited:
continue
visited.add(current_id)
node = self.nodes.get(current_id)
if node is None:
continue
node = self.nodes.get(current_id)
if node is None:
continue
for child_id in node.neighbors("child"):
if child_id not in visited:
child_node = self.nodes.get(child_id)
if child_node:
result.append(child_node)
queue.append(child_id)
for child_id in node.neighbors("child"):
if child_id not in visited:
child_node = self.nodes.get(child_id)
if child_node:
result.append(child_node)
queue.append(child_id)
return result
return result
def roots(self) -> list[EventNode]:
"""Return all root nodes — events with no parent.
@@ -141,10 +153,15 @@ class EventRecord(BaseModel):
Returns:
List of root event nodes.
"""
return [node for node in self.nodes.values() if not node.neighbors("parent")]
with self._lock.r_locked():
return [
node for node in self.nodes.values() if not node.neighbors("parent")
]
def __len__(self) -> int:
return len(self.nodes)
with self._lock.r_locked():
return len(self.nodes)
def __contains__(self, event_id: str) -> bool:
return event_id in self.nodes
with self._lock.r_locked():
return event_id in self.nodes