Compare commits


2 Commits

Devin AI | 519f8ce0eb | 2026-01-10 21:12:38 +00:00
chore: re-trigger CI checks
Co-Authored-By: João <joao@crewai.com>

Devin AI | 802ca92e42 | 2026-01-10 21:09:02 +00:00
fix: make PickleHandler thread-safe with portalocker and atomic writes

- Add threading lock for same-process thread safety
- Use atomic write operations (write to temp file, then rename) for data integrity
- Use portalocker for cross-process read locking
- Add comprehensive thread-safety tests covering concurrent reads, writes, and mixed operations

Fixes #4215

Co-Authored-By: João <joao@crewai.com>
4 changed files with 347 additions and 199 deletions
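
The core of the fix is the write-to-a-temporary-file-then-rename pattern listed in the commit message above. As a standalone sketch of that pattern (the function name and file layout here are illustrative, not taken from the codebase), it looks roughly like this:

    import os
    import pickle
    import tempfile

    def atomic_pickle_dump(data, path):
        """Dump pickled data to a temp file, then rename it into place.

        Readers opening `path` see either the previous contents or the new
        contents, never a partially written pickle.
        """
        dir_name = os.path.dirname(path) or os.getcwd()
        fd, temp_path = tempfile.mkstemp(suffix=".tmp", dir=dir_name)
        try:
            with os.fdopen(fd, "wb") as f:
                pickle.dump(data, f)
                f.flush()
                os.fsync(f.fileno())  # ensure bytes are on disk before the rename
            os.replace(temp_path, path)  # atomic replacement within the same filesystem
        except Exception:
            if os.path.exists(temp_path):
                os.unlink(temp_path)
            raise

The actual PickleHandler implementation in the diff below follows this shape, adding a threading lock and portalocker-based read locking on top of it.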


@@ -33,24 +33,13 @@ class LongTermMemory(Memory):
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage)
def save(
self,
value: LongTermMemoryItem,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save an item to long-term memory.
Args:
value: The LongTermMemoryItem to save.
metadata: Optional metadata dict (not used, metadata is extracted from the
LongTermMemoryItem). Included for supertype compatibility.
"""
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value.task,
metadata=value.metadata,
agent_role=value.agent,
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
@@ -59,23 +48,23 @@ class LongTermMemory(Memory):
start_time = time.time()
try:
item_metadata = value.metadata
item_metadata.update(
{"agent": value.agent, "expected_output": value.expected_output}
metadata = item.metadata
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
self.storage.save(
task_description=value.task,
score=item_metadata["quality"],
metadata=item_metadata,
datetime=value.datetime,
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
datetime=item.datetime,
)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value.task,
metadata=value.metadata,
agent_role=value.agent,
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
@@ -86,28 +75,25 @@ class LongTermMemory(Memory):
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=value.task,
metadata=value.metadata,
agent_role=value.agent,
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
error=str(e),
source_type="long_term_memory",
),
)
raise
def search(
def search( # type: ignore[override]
self,
query: str,
limit: int = 3,
score_threshold: float = 0.6,
task: str,
latest_n: int = 3,
) -> list[dict[str, Any]]:
"""Search long-term memory for relevant entries.
Args:
query: The task description to search for.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results (not used for
long-term memory, included for supertype compatibility).
task: The task description to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries.
@@ -115,8 +101,8 @@ class LongTermMemory(Memory):
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
query=task,
limit=latest_n,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
@@ -125,14 +111,14 @@ class LongTermMemory(Memory):
start_time = time.time()
try:
results = self.storage.load(query, limit)
results = self.storage.load(task, latest_n)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
query=task,
results=results,
limit=limit,
limit=latest_n,
query_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
@@ -145,32 +131,26 @@ class LongTermMemory(Memory):
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
query=task,
limit=latest_n,
error=str(e),
source_type="long_term_memory",
),
)
raise
async def asave(
self,
value: LongTermMemoryItem,
metadata: dict[str, Any] | None = None,
) -> None:
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
"""Save an item to long-term memory asynchronously.
Args:
value: The LongTermMemoryItem to save.
metadata: Optional metadata dict (not used, metadata is extracted from the
LongTermMemoryItem). Included for supertype compatibility.
item: The LongTermMemoryItem to save.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value.task,
metadata=value.metadata,
agent_role=value.agent,
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
@@ -179,23 +159,23 @@ class LongTermMemory(Memory):
start_time = time.time()
try:
item_metadata = value.metadata
item_metadata.update(
{"agent": value.agent, "expected_output": value.expected_output}
metadata = item.metadata
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
await self.storage.asave(
task_description=value.task,
score=item_metadata["quality"],
metadata=item_metadata,
datetime=value.datetime,
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
datetime=item.datetime,
)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value.task,
metadata=value.metadata,
agent_role=value.agent,
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
@@ -206,28 +186,25 @@ class LongTermMemory(Memory):
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=value.task,
metadata=value.metadata,
agent_role=value.agent,
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
error=str(e),
source_type="long_term_memory",
),
)
raise
async def asearch(
async def asearch( # type: ignore[override]
self,
query: str,
limit: int = 3,
score_threshold: float = 0.6,
task: str,
latest_n: int = 3,
) -> list[dict[str, Any]]:
"""Search long-term memory asynchronously.
Args:
query: The task description to search for.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results (not used for
long-term memory, included for supertype compatibility).
task: The task description to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries.
@@ -235,8 +212,8 @@ class LongTermMemory(Memory):
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
query=task,
limit=latest_n,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
@@ -245,14 +222,14 @@ class LongTermMemory(Memory):
start_time = time.time()
try:
results = await self.storage.aload(query, limit)
results = await self.storage.aload(task, latest_n)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
query=task,
results=results,
limit=limit,
limit=latest_n,
query_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
@@ -265,8 +242,8 @@ class LongTermMemory(Memory):
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
query=task,
limit=latest_n,
error=str(e),
source_type="long_term_memory",
),
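
For reference, a minimal sketch of how calling code uses the reverted LongTermMemory API, with item fields mirroring the values used in the tests further down; this is an illustration, not a snippet from the repository:

    from crewai.memory.long_term.long_term_memory import LongTermMemory
    from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem

    memory = LongTermMemory()

    # save() takes the LongTermMemoryItem directly; metadata is read from the item itself
    item = LongTermMemoryItem(
        agent="test_agent",
        task="test_task",
        expected_output="test_output",
        datetime="test_datetime",
        quality=0.5,
        metadata={"task": "test_task", "quality": 0.5},
    )
    memory.save(item)

    # search() is keyed by the task description and capped by latest_n (not query/limit)
    results = memory.search("test_task", latest_n=3)
    if results:
        print(results[0]["score"], results[0]["metadata"]["agent"])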


@@ -2,8 +2,11 @@ from datetime import datetime
import json
import os
import pickle
import tempfile
import threading
from typing import Any, TypedDict
import portalocker
from typing_extensions import Unpack
@@ -123,10 +126,15 @@ class FileHandler:
class PickleHandler:
"""Handler for saving and loading data using pickle.
"""Thread-safe handler for saving and loading data using pickle.
This class provides thread-safe file operations using portalocker for
cross-process file locking and atomic write operations to prevent
data corruption during concurrent access.
Attributes:
file_path: The path to the pickle file.
_lock: Threading lock for thread-safe operations within the same process.
"""
def __init__(self, file_name: str) -> None:
@@ -141,34 +149,62 @@ class PickleHandler:
file_name += ".pkl"
self.file_path = os.path.join(os.getcwd(), file_name)
self._lock = threading.Lock()
def initialize_file(self) -> None:
"""Initialize the file with an empty dictionary and overwrite any existing data."""
self.save({})
def save(self, data: Any) -> None:
"""
Save the data to the specified file using pickle.
"""Save the data to the specified file using pickle with thread-safe atomic writes.
This method uses a two-phase approach for thread safety:
1. Threading lock for same-process thread safety
2. Atomic write (write to temp file, then rename) for cross-process safety
and data integrity
Args:
data: The data to be saved to the file.
data: The data to be saved to the file.
"""
with open(self.file_path, "wb") as f:
pickle.dump(obj=data, file=f)
with self._lock:
dir_name = os.path.dirname(self.file_path) or os.getcwd()
fd, temp_path = tempfile.mkstemp(
suffix=".pkl.tmp", prefix="pickle_", dir=dir_name
)
try:
with os.fdopen(fd, "wb") as f:
pickle.dump(obj=data, file=f)
f.flush()
os.fsync(f.fileno())
os.replace(temp_path, self.file_path)
except Exception:
if os.path.exists(temp_path):
os.unlink(temp_path)
raise
def load(self) -> Any:
"""Load the data from the specified file using pickle.
"""Load the data from the specified file using pickle with thread-safe locking.
This method uses portalocker for cross-process read locking to ensure
data consistency when multiple processes may be accessing the file.
Returns:
The data loaded from the file.
The data loaded from the file, or an empty dictionary if the file
does not exist or is empty.
"""
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
return {} # Return an empty dictionary if the file does not exist or is empty
with self._lock:
if (
not os.path.exists(self.file_path)
or os.path.getsize(self.file_path) == 0
):
return {}
with open(self.file_path, "rb") as file:
try:
return pickle.load(file) # noqa: S301
except EOFError:
return {} # Return an empty dictionary if the file is empty or corrupted
except Exception:
raise # Raise any other exceptions that occur during loading
with portalocker.Lock(
self.file_path, "rb", flags=portalocker.LOCK_SH
) as file:
try:
return pickle.load(file) # noqa: S301
except EOFError:
return {}
except Exception:
raise
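
A short usage sketch for the handler above: several threads saving through the same instance, with a read afterwards (the file name is illustrative):

    import threading

    from crewai.utilities.file_handler import PickleHandler

    handler = PickleHandler("crew_state.pkl")  # resolved relative to the current working directory

    def worker(worker_id: int) -> None:
        # each save() is atomic, so a concurrent load() never sees a half-written pickle
        handler.save({"worker": worker_id, "status": "done"})

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    print(handler.load())  # the dict written by whichever thread's save landed last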


@@ -1,7 +1,5 @@
import inspect
import threading
from collections import defaultdict
from typing import Any
from unittest.mock import ANY
import pytest
@@ -15,7 +13,6 @@ from crewai.events.types.memory_events import (
)
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
@pytest.fixture
@@ -117,7 +114,7 @@ def test_long_term_memory_search_events(long_term_memory):
test_query = "test query"
long_term_memory.search(test_query, limit=5)
long_term_memory.search(test_query, latest_n=5)
with condition:
success = condition.wait_for(
@@ -177,104 +174,10 @@ def test_save_and_search(long_term_memory):
metadata={"task": "test_task", "quality": 0.5},
)
long_term_memory.save(memory)
find = long_term_memory.search("test_task", limit=5)[0]
find = long_term_memory.search("test_task", latest_n=5)[0]
assert find["score"] == 0.5
assert find["datetime"] == "test_datetime"
assert find["metadata"]["agent"] == "test_agent"
assert find["metadata"]["quality"] == 0.5
assert find["metadata"]["task"] == "test_task"
assert find["metadata"]["expected_output"] == "test_output"
class TestLongTermMemoryTypeSignatureCompatibility:
"""Tests to verify LongTermMemory method signatures are compatible with Memory base class.
These tests ensure that the Liskov Substitution Principle is maintained and that
LongTermMemory can be used polymorphically wherever Memory is expected.
"""
def test_save_signature_has_value_parameter(self):
"""Test that save() uses 'value' parameter name matching Memory base class."""
sig = inspect.signature(LongTermMemory.save)
params = list(sig.parameters.keys())
assert "value" in params, "save() should have 'value' parameter for LSP compliance"
assert "metadata" in params, "save() should have 'metadata' parameter for LSP compliance"
def test_save_signature_has_metadata_with_default(self):
"""Test that save() has metadata parameter with default value."""
sig = inspect.signature(LongTermMemory.save)
metadata_param = sig.parameters.get("metadata")
assert metadata_param is not None, "save() should have 'metadata' parameter"
assert metadata_param.default is None, "metadata should default to None"
def test_search_signature_has_query_parameter(self):
"""Test that search() uses 'query' parameter name matching Memory base class."""
sig = inspect.signature(LongTermMemory.search)
params = list(sig.parameters.keys())
assert "query" in params, "search() should have 'query' parameter for LSP compliance"
assert "limit" in params, "search() should have 'limit' parameter for LSP compliance"
assert "score_threshold" in params, "search() should have 'score_threshold' parameter for LSP compliance"
def test_search_signature_has_score_threshold_with_default(self):
"""Test that search() has score_threshold parameter with default value."""
sig = inspect.signature(LongTermMemory.search)
score_threshold_param = sig.parameters.get("score_threshold")
assert score_threshold_param is not None, "search() should have 'score_threshold' parameter"
assert score_threshold_param.default == 0.6, "score_threshold should default to 0.6"
def test_asave_signature_has_value_parameter(self):
"""Test that asave() uses 'value' parameter name matching Memory base class."""
sig = inspect.signature(LongTermMemory.asave)
params = list(sig.parameters.keys())
assert "value" in params, "asave() should have 'value' parameter for LSP compliance"
assert "metadata" in params, "asave() should have 'metadata' parameter for LSP compliance"
def test_asearch_signature_has_query_parameter(self):
"""Test that asearch() uses 'query' parameter name matching Memory base class."""
sig = inspect.signature(LongTermMemory.asearch)
params = list(sig.parameters.keys())
assert "query" in params, "asearch() should have 'query' parameter for LSP compliance"
assert "limit" in params, "asearch() should have 'limit' parameter for LSP compliance"
assert "score_threshold" in params, "asearch() should have 'score_threshold' parameter for LSP compliance"
def test_long_term_memory_is_subclass_of_memory(self):
"""Test that LongTermMemory is a proper subclass of Memory."""
assert issubclass(LongTermMemory, Memory), "LongTermMemory should be a subclass of Memory"
def test_save_with_metadata_parameter(self, long_term_memory):
"""Test that save() can be called with the metadata parameter (even if unused)."""
memory_item = LongTermMemoryItem(
agent="test_agent",
task="test_task_with_metadata",
expected_output="test_output",
datetime="test_datetime",
quality=0.8,
metadata={"task": "test_task_with_metadata", "quality": 0.8},
)
long_term_memory.save(value=memory_item, metadata={"extra": "data"})
results = long_term_memory.search(query="test_task_with_metadata", limit=1)
assert len(results) > 0
assert results[0]["metadata"]["agent"] == "test_agent"
def test_search_with_score_threshold_parameter(self, long_term_memory):
"""Test that search() can be called with the score_threshold parameter."""
memory_item = LongTermMemoryItem(
agent="test_agent",
task="test_task_score_threshold",
expected_output="test_output",
datetime="test_datetime",
quality=0.9,
metadata={"task": "test_task_score_threshold", "quality": 0.9},
)
long_term_memory.save(value=memory_item)
results = long_term_memory.search(
query="test_task_score_threshold",
limit=5,
score_threshold=0.5,
)
assert isinstance(results, list)
@pytest.fixture
def long_term_memory(self):
"""Fixture to create a LongTermMemory instance for this test class."""
return LongTermMemory()


@@ -1,6 +1,8 @@
import os
import threading
import unittest
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from crewai.utilities.file_handler import PickleHandler
@@ -8,7 +10,6 @@ from crewai.utilities.file_handler import PickleHandler
class TestPickleHandler(unittest.TestCase):
def setUp(self):
# Use a unique file name for each test to avoid race conditions in parallel test execution
unique_id = str(uuid.uuid4())
self.file_name = f"test_data_{unique_id}.pkl"
self.file_path = os.path.join(os.getcwd(), self.file_name)
@@ -47,3 +48,234 @@ class TestPickleHandler(unittest.TestCase):
assert str(exc.value) == "pickle data was truncated"
assert "<class '_pickle.UnpicklingError'>" == str(exc.type)
class TestPickleHandlerThreadSafety(unittest.TestCase):
"""Tests for thread-safety of PickleHandler operations."""
def setUp(self):
unique_id = str(uuid.uuid4())
self.file_name = f"test_thread_safe_{unique_id}.pkl"
self.file_path = os.path.join(os.getcwd(), self.file_name)
self.handler = PickleHandler(self.file_name)
def tearDown(self):
if os.path.exists(self.file_path):
os.remove(self.file_path)
def test_concurrent_writes_same_handler(self):
"""Test that concurrent writes from multiple threads using the same handler don't corrupt data."""
num_threads = 10
num_writes_per_thread = 20
errors: list[Exception] = []
write_count = 0
count_lock = threading.Lock()
def write_data(thread_id: int) -> None:
nonlocal write_count
for i in range(num_writes_per_thread):
try:
data = {"thread": thread_id, "iteration": i, "data": f"value_{thread_id}_{i}"}
self.handler.save(data)
with count_lock:
write_count += 1
except Exception as e:
errors.append(e)
threads = []
for i in range(num_threads):
t = threading.Thread(target=write_data, args=(i,))
threads.append(t)
t.start()
for t in threads:
t.join()
assert len(errors) == 0, f"Errors occurred during concurrent writes: {errors}"
assert write_count == num_threads * num_writes_per_thread
loaded_data = self.handler.load()
assert isinstance(loaded_data, dict)
assert "thread" in loaded_data
assert "iteration" in loaded_data
def test_concurrent_reads_same_handler(self):
"""Test that concurrent reads from multiple threads don't cause issues."""
test_data = {"key": "value", "nested": {"a": 1, "b": 2}}
self.handler.save(test_data)
num_threads = 20
results: list[dict] = []
errors: list[Exception] = []
results_lock = threading.Lock()
def read_data() -> None:
try:
data = self.handler.load()
with results_lock:
results.append(data)
except Exception as e:
errors.append(e)
threads = []
for _ in range(num_threads):
t = threading.Thread(target=read_data)
threads.append(t)
t.start()
for t in threads:
t.join()
assert len(errors) == 0, f"Errors occurred during concurrent reads: {errors}"
assert len(results) == num_threads
for result in results:
assert result == test_data
def test_concurrent_read_write_same_handler(self):
"""Test that concurrent reads and writes don't corrupt data or cause errors."""
initial_data = {"counter": 0}
self.handler.save(initial_data)
num_writers = 5
num_readers = 10
writes_per_thread = 10
reads_per_thread = 20
write_errors: list[Exception] = []
read_errors: list[Exception] = []
read_results: list[dict] = []
results_lock = threading.Lock()
def writer(thread_id: int) -> None:
for i in range(writes_per_thread):
try:
data = {"writer": thread_id, "write_num": i}
self.handler.save(data)
except Exception as e:
write_errors.append(e)
def reader() -> None:
for _ in range(reads_per_thread):
try:
data = self.handler.load()
with results_lock:
read_results.append(data)
except Exception as e:
read_errors.append(e)
threads = []
for i in range(num_writers):
t = threading.Thread(target=writer, args=(i,))
threads.append(t)
for _ in range(num_readers):
t = threading.Thread(target=reader)
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
assert len(write_errors) == 0, f"Write errors: {write_errors}"
assert len(read_errors) == 0, f"Read errors: {read_errors}"
for result in read_results:
assert isinstance(result, dict)
def test_atomic_write_no_partial_data(self):
"""Test that atomic writes prevent partial/corrupted data from being read."""
large_data = {"key": "x" * 100000, "numbers": list(range(10000))}
num_iterations = 50
errors: list[Exception] = []
corruption_detected = False
corruption_lock = threading.Lock()
def writer() -> None:
for _ in range(num_iterations):
try:
self.handler.save(large_data)
except Exception as e:
errors.append(e)
def reader() -> None:
nonlocal corruption_detected
for _ in range(num_iterations * 2):
try:
data = self.handler.load()
if data and data != {} and data != large_data:
with corruption_lock:
corruption_detected = True
except Exception as e:
errors.append(e)
writer_thread = threading.Thread(target=writer)
reader_thread = threading.Thread(target=reader)
writer_thread.start()
reader_thread.start()
writer_thread.join()
reader_thread.join()
assert len(errors) == 0, f"Errors occurred: {errors}"
assert not corruption_detected, "Partial/corrupted data was read"
def test_thread_pool_concurrent_operations(self):
"""Test thread safety using ThreadPoolExecutor for more realistic concurrent access."""
num_operations = 100
errors: list[Exception] = []
def operation(op_id: int) -> str:
try:
if op_id % 3 == 0:
self.handler.save({"op_id": op_id, "type": "write"})
return f"write_{op_id}"
else:
data = self.handler.load()
return f"read_{op_id}_{type(data).__name__}"
except Exception as e:
errors.append(e)
return f"error_{op_id}"
with ThreadPoolExecutor(max_workers=20) as executor:
futures = [executor.submit(operation, i) for i in range(num_operations)]
results = [f.result() for f in as_completed(futures)]
assert len(errors) == 0, f"Errors occurred: {errors}"
assert len(results) == num_operations
def test_multiple_handlers_same_file(self):
"""Test that multiple PickleHandler instances for the same file work correctly."""
handler1 = PickleHandler(self.file_name)
handler2 = PickleHandler(self.file_name)
num_operations = 50
errors: list[Exception] = []
def use_handler1() -> None:
for i in range(num_operations):
try:
handler1.save({"handler": 1, "iteration": i})
except Exception as e:
errors.append(e)
def use_handler2() -> None:
for i in range(num_operations):
try:
handler2.save({"handler": 2, "iteration": i})
except Exception as e:
errors.append(e)
t1 = threading.Thread(target=use_handler1)
t2 = threading.Thread(target=use_handler2)
t1.start()
t2.start()
t1.join()
t2.join()
assert len(errors) == 0, f"Errors occurred: {errors}"
final_data = self.handler.load()
assert isinstance(final_data, dict)
assert "handler" in final_data
assert "iteration" in final_data