diff --git a/lib/crewai/src/crewai/state/event_record.py b/lib/crewai/src/crewai/state/event_record.py index bfa6b1fbb..a06dec398 100644 --- a/lib/crewai/src/crewai/state/event_record.py +++ b/lib/crewai/src/crewai/state/event_record.py @@ -8,9 +8,10 @@ from __future__ import annotations from typing import Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from crewai.events.base_events import BaseEvent +from crewai.utilities.rw_lock import RWLock EdgeType = Literal[ @@ -61,6 +62,7 @@ class EventRecord(BaseModel): """ nodes: dict[str, EventNode] = Field(default_factory=dict) + _lock: RWLock = PrivateAttr(default_factory=RWLock) def add(self, event: BaseEvent) -> EventNode: """Add an event to the record and wire its edges. @@ -71,26 +73,34 @@ class EventRecord(BaseModel): Returns: The created node. """ - node = EventNode(event=event) - self.nodes[event.event_id] = node + with self._lock.w_locked(): + node = EventNode(event=event) + self.nodes[event.event_id] = node - if event.parent_event_id and event.parent_event_id in self.nodes: - node.add_edge("parent", event.parent_event_id) - self.nodes[event.parent_event_id].add_edge("child", event.event_id) + if event.parent_event_id and event.parent_event_id in self.nodes: + node.add_edge("parent", event.parent_event_id) + self.nodes[event.parent_event_id].add_edge("child", event.event_id) - if event.triggered_by_event_id and event.triggered_by_event_id in self.nodes: - node.add_edge("triggered_by", event.triggered_by_event_id) - self.nodes[event.triggered_by_event_id].add_edge("trigger", event.event_id) + if ( + event.triggered_by_event_id + and event.triggered_by_event_id in self.nodes + ): + node.add_edge("triggered_by", event.triggered_by_event_id) + self.nodes[event.triggered_by_event_id].add_edge( + "trigger", event.event_id + ) - if event.previous_event_id and event.previous_event_id in self.nodes: - node.add_edge("previous", event.previous_event_id) - self.nodes[event.previous_event_id].add_edge("next", event.event_id) + if event.previous_event_id and event.previous_event_id in self.nodes: + node.add_edge("previous", event.previous_event_id) + self.nodes[event.previous_event_id].add_edge("next", event.event_id) - if event.started_event_id and event.started_event_id in self.nodes: - node.add_edge("started", event.started_event_id) - self.nodes[event.started_event_id].add_edge("completed_by", event.event_id) + if event.started_event_id and event.started_event_id in self.nodes: + node.add_edge("started", event.started_event_id) + self.nodes[event.started_event_id].add_edge( + "completed_by", event.event_id + ) - return node + return node def get(self, event_id: str) -> EventNode | None: """Look up a node by event ID. @@ -101,7 +111,8 @@ class EventRecord(BaseModel): Returns: The node, or None if not found. """ - return self.nodes.get(event_id) + with self._lock.r_locked(): + return self.nodes.get(event_id) def descendants(self, event_id: str) -> list[EventNode]: """Return all descendant nodes, children recursively. @@ -112,28 +123,29 @@ class EventRecord(BaseModel): Returns: All descendant nodes in breadth-first order. """ - result: list[EventNode] = [] - queue = [event_id] - visited: set[str] = set() + with self._lock.r_locked(): + result: list[EventNode] = [] + queue = [event_id] + visited: set[str] = set() - while queue: - current_id = queue.pop(0) - if current_id in visited: - continue - visited.add(current_id) + while queue: + current_id = queue.pop(0) + if current_id in visited: + continue + visited.add(current_id) - node = self.nodes.get(current_id) - if node is None: - continue + node = self.nodes.get(current_id) + if node is None: + continue - for child_id in node.neighbors("child"): - if child_id not in visited: - child_node = self.nodes.get(child_id) - if child_node: - result.append(child_node) - queue.append(child_id) + for child_id in node.neighbors("child"): + if child_id not in visited: + child_node = self.nodes.get(child_id) + if child_node: + result.append(child_node) + queue.append(child_id) - return result + return result def roots(self) -> list[EventNode]: """Return all root nodes — events with no parent. @@ -141,10 +153,15 @@ class EventRecord(BaseModel): Returns: List of root event nodes. """ - return [node for node in self.nodes.values() if not node.neighbors("parent")] + with self._lock.r_locked(): + return [ + node for node in self.nodes.values() if not node.neighbors("parent") + ] def __len__(self) -> int: - return len(self.nodes) + with self._lock.r_locked(): + return len(self.nodes) def __contains__(self, event_id: str) -> bool: - return event_id in self.nodes + with self._lock.r_locked(): + return event_id in self.nodes