mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
fix: add RWLock to EventRecord for thread-safe concurrent access
This commit is contained in:
@@ -8,9 +8,10 @@ from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.utilities.rw_lock import RWLock
|
||||
|
||||
|
||||
EdgeType = Literal[
|
||||
@@ -61,6 +62,7 @@ class EventRecord(BaseModel):
|
||||
"""
|
||||
|
||||
nodes: dict[str, EventNode] = Field(default_factory=dict)
|
||||
_lock: RWLock = PrivateAttr(default_factory=RWLock)
|
||||
|
||||
def add(self, event: BaseEvent) -> EventNode:
|
||||
"""Add an event to the record and wire its edges.
|
||||
@@ -71,26 +73,34 @@ class EventRecord(BaseModel):
|
||||
Returns:
|
||||
The created node.
|
||||
"""
|
||||
node = EventNode(event=event)
|
||||
self.nodes[event.event_id] = node
|
||||
with self._lock.w_locked():
|
||||
node = EventNode(event=event)
|
||||
self.nodes[event.event_id] = node
|
||||
|
||||
if event.parent_event_id and event.parent_event_id in self.nodes:
|
||||
node.add_edge("parent", event.parent_event_id)
|
||||
self.nodes[event.parent_event_id].add_edge("child", event.event_id)
|
||||
if event.parent_event_id and event.parent_event_id in self.nodes:
|
||||
node.add_edge("parent", event.parent_event_id)
|
||||
self.nodes[event.parent_event_id].add_edge("child", event.event_id)
|
||||
|
||||
if event.triggered_by_event_id and event.triggered_by_event_id in self.nodes:
|
||||
node.add_edge("triggered_by", event.triggered_by_event_id)
|
||||
self.nodes[event.triggered_by_event_id].add_edge("trigger", event.event_id)
|
||||
if (
|
||||
event.triggered_by_event_id
|
||||
and event.triggered_by_event_id in self.nodes
|
||||
):
|
||||
node.add_edge("triggered_by", event.triggered_by_event_id)
|
||||
self.nodes[event.triggered_by_event_id].add_edge(
|
||||
"trigger", event.event_id
|
||||
)
|
||||
|
||||
if event.previous_event_id and event.previous_event_id in self.nodes:
|
||||
node.add_edge("previous", event.previous_event_id)
|
||||
self.nodes[event.previous_event_id].add_edge("next", event.event_id)
|
||||
if event.previous_event_id and event.previous_event_id in self.nodes:
|
||||
node.add_edge("previous", event.previous_event_id)
|
||||
self.nodes[event.previous_event_id].add_edge("next", event.event_id)
|
||||
|
||||
if event.started_event_id and event.started_event_id in self.nodes:
|
||||
node.add_edge("started", event.started_event_id)
|
||||
self.nodes[event.started_event_id].add_edge("completed_by", event.event_id)
|
||||
if event.started_event_id and event.started_event_id in self.nodes:
|
||||
node.add_edge("started", event.started_event_id)
|
||||
self.nodes[event.started_event_id].add_edge(
|
||||
"completed_by", event.event_id
|
||||
)
|
||||
|
||||
return node
|
||||
return node
|
||||
|
||||
def get(self, event_id: str) -> EventNode | None:
|
||||
"""Look up a node by event ID.
|
||||
@@ -101,7 +111,8 @@ class EventRecord(BaseModel):
|
||||
Returns:
|
||||
The node, or None if not found.
|
||||
"""
|
||||
return self.nodes.get(event_id)
|
||||
with self._lock.r_locked():
|
||||
return self.nodes.get(event_id)
|
||||
|
||||
def descendants(self, event_id: str) -> list[EventNode]:
|
||||
"""Return all descendant nodes, children recursively.
|
||||
@@ -112,28 +123,29 @@ class EventRecord(BaseModel):
|
||||
Returns:
|
||||
All descendant nodes in breadth-first order.
|
||||
"""
|
||||
result: list[EventNode] = []
|
||||
queue = [event_id]
|
||||
visited: set[str] = set()
|
||||
with self._lock.r_locked():
|
||||
result: list[EventNode] = []
|
||||
queue = [event_id]
|
||||
visited: set[str] = set()
|
||||
|
||||
while queue:
|
||||
current_id = queue.pop(0)
|
||||
if current_id in visited:
|
||||
continue
|
||||
visited.add(current_id)
|
||||
while queue:
|
||||
current_id = queue.pop(0)
|
||||
if current_id in visited:
|
||||
continue
|
||||
visited.add(current_id)
|
||||
|
||||
node = self.nodes.get(current_id)
|
||||
if node is None:
|
||||
continue
|
||||
node = self.nodes.get(current_id)
|
||||
if node is None:
|
||||
continue
|
||||
|
||||
for child_id in node.neighbors("child"):
|
||||
if child_id not in visited:
|
||||
child_node = self.nodes.get(child_id)
|
||||
if child_node:
|
||||
result.append(child_node)
|
||||
queue.append(child_id)
|
||||
for child_id in node.neighbors("child"):
|
||||
if child_id not in visited:
|
||||
child_node = self.nodes.get(child_id)
|
||||
if child_node:
|
||||
result.append(child_node)
|
||||
queue.append(child_id)
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
def roots(self) -> list[EventNode]:
|
||||
"""Return all root nodes — events with no parent.
|
||||
@@ -141,10 +153,15 @@ class EventRecord(BaseModel):
|
||||
Returns:
|
||||
List of root event nodes.
|
||||
"""
|
||||
return [node for node in self.nodes.values() if not node.neighbors("parent")]
|
||||
with self._lock.r_locked():
|
||||
return [
|
||||
node for node in self.nodes.values() if not node.neighbors("parent")
|
||||
]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.nodes)
|
||||
with self._lock.r_locked():
|
||||
return len(self.nodes)
|
||||
|
||||
def __contains__(self, event_id: str) -> bool:
|
||||
return event_id in self.nodes
|
||||
with self._lock.r_locked():
|
||||
return event_id in self.nodes
|
||||
|
||||
Reference in New Issue
Block a user