Compare commits

...

2 Commits

Author SHA1 Message Date
Devin AI
519f8ce0eb chore: re-trigger CI checks
Co-Authored-By: João <joao@crewai.com>
2026-01-10 21:12:38 +00:00
Devin AI
802ca92e42 fix: make PickleHandler thread-safe with portalocker and atomic writes
- Add threading lock for same-process thread safety
- Use atomic write operations (write to temp file, then rename) for data integrity
- Use portalocker for cross-process read locking
- Add comprehensive thread-safety tests covering concurrent reads, writes, and mixed operations

Fixes #4215

Co-Authored-By: João <joao@crewai.com>
2026-01-10 21:09:02 +00:00
2 changed files with 286 additions and 18 deletions

View File

@@ -2,8 +2,11 @@ from datetime import datetime
import json import json
import os import os
import pickle import pickle
import tempfile
import threading
from typing import Any, TypedDict from typing import Any, TypedDict
import portalocker
from typing_extensions import Unpack from typing_extensions import Unpack
@@ -123,10 +126,15 @@ class FileHandler:
class PickleHandler: class PickleHandler:
"""Handler for saving and loading data using pickle. """Thread-safe handler for saving and loading data using pickle.
This class provides thread-safe file operations using portalocker for
cross-process file locking and atomic write operations to prevent
data corruption during concurrent access.
Attributes: Attributes:
file_path: The path to the pickle file. file_path: The path to the pickle file.
_lock: Threading lock for thread-safe operations within the same process.
""" """
def __init__(self, file_name: str) -> None: def __init__(self, file_name: str) -> None:
@@ -141,34 +149,62 @@ class PickleHandler:
file_name += ".pkl" file_name += ".pkl"
self.file_path = os.path.join(os.getcwd(), file_name) self.file_path = os.path.join(os.getcwd(), file_name)
self._lock = threading.Lock()
def initialize_file(self) -> None: def initialize_file(self) -> None:
"""Initialize the file with an empty dictionary and overwrite any existing data.""" """Initialize the file with an empty dictionary and overwrite any existing data."""
self.save({}) self.save({})
def save(self, data: Any) -> None: def save(self, data: Any) -> None:
""" """Save the data to the specified file using pickle with thread-safe atomic writes.
Save the data to the specified file using pickle.
This method uses a two-phase approach for thread safety:
1. Threading lock for same-process thread safety
2. Atomic write (write to temp file, then rename) for cross-process safety
and data integrity
Args: Args:
data: The data to be saved to the file. data: The data to be saved to the file.
""" """
with open(self.file_path, "wb") as f: with self._lock:
pickle.dump(obj=data, file=f) dir_name = os.path.dirname(self.file_path) or os.getcwd()
fd, temp_path = tempfile.mkstemp(
suffix=".pkl.tmp", prefix="pickle_", dir=dir_name
)
try:
with os.fdopen(fd, "wb") as f:
pickle.dump(obj=data, file=f)
f.flush()
os.fsync(f.fileno())
os.replace(temp_path, self.file_path)
except Exception:
if os.path.exists(temp_path):
os.unlink(temp_path)
raise
def load(self) -> Any: def load(self) -> Any:
"""Load the data from the specified file using pickle. """Load the data from the specified file using pickle with thread-safe locking.
This method uses portalocker for cross-process read locking to ensure
data consistency when multiple processes may be accessing the file.
Returns: Returns:
The data loaded from the file. The data loaded from the file, or an empty dictionary if the file
does not exist or is empty.
""" """
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: with self._lock:
return {} # Return an empty dictionary if the file does not exist or is empty if (
not os.path.exists(self.file_path)
or os.path.getsize(self.file_path) == 0
):
return {}
with open(self.file_path, "rb") as file: with portalocker.Lock(
try: self.file_path, "rb", flags=portalocker.LOCK_SH
return pickle.load(file) # noqa: S301 ) as file:
except EOFError: try:
return {} # Return an empty dictionary if the file is empty or corrupted return pickle.load(file) # noqa: S301
except Exception: except EOFError:
raise # Raise any other exceptions that occur during loading return {}
except Exception:
raise

View File

@@ -1,6 +1,8 @@
import os import os
import threading
import unittest import unittest
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest import pytest
from crewai.utilities.file_handler import PickleHandler from crewai.utilities.file_handler import PickleHandler
@@ -8,7 +10,6 @@ from crewai.utilities.file_handler import PickleHandler
class TestPickleHandler(unittest.TestCase): class TestPickleHandler(unittest.TestCase):
def setUp(self): def setUp(self):
# Use a unique file name for each test to avoid race conditions in parallel test execution
unique_id = str(uuid.uuid4()) unique_id = str(uuid.uuid4())
self.file_name = f"test_data_{unique_id}.pkl" self.file_name = f"test_data_{unique_id}.pkl"
self.file_path = os.path.join(os.getcwd(), self.file_name) self.file_path = os.path.join(os.getcwd(), self.file_name)
@@ -47,3 +48,234 @@ class TestPickleHandler(unittest.TestCase):
assert str(exc.value) == "pickle data was truncated" assert str(exc.value) == "pickle data was truncated"
assert "<class '_pickle.UnpicklingError'>" == str(exc.type) assert "<class '_pickle.UnpicklingError'>" == str(exc.type)
class TestPickleHandlerThreadSafety(unittest.TestCase):
"""Tests for thread-safety of PickleHandler operations."""
def setUp(self):
unique_id = str(uuid.uuid4())
self.file_name = f"test_thread_safe_{unique_id}.pkl"
self.file_path = os.path.join(os.getcwd(), self.file_name)
self.handler = PickleHandler(self.file_name)
def tearDown(self):
if os.path.exists(self.file_path):
os.remove(self.file_path)
def test_concurrent_writes_same_handler(self):
"""Test that concurrent writes from multiple threads using the same handler don't corrupt data."""
num_threads = 10
num_writes_per_thread = 20
errors: list[Exception] = []
write_count = 0
count_lock = threading.Lock()
def write_data(thread_id: int) -> None:
nonlocal write_count
for i in range(num_writes_per_thread):
try:
data = {"thread": thread_id, "iteration": i, "data": f"value_{thread_id}_{i}"}
self.handler.save(data)
with count_lock:
write_count += 1
except Exception as e:
errors.append(e)
threads = []
for i in range(num_threads):
t = threading.Thread(target=write_data, args=(i,))
threads.append(t)
t.start()
for t in threads:
t.join()
assert len(errors) == 0, f"Errors occurred during concurrent writes: {errors}"
assert write_count == num_threads * num_writes_per_thread
loaded_data = self.handler.load()
assert isinstance(loaded_data, dict)
assert "thread" in loaded_data
assert "iteration" in loaded_data
def test_concurrent_reads_same_handler(self):
"""Test that concurrent reads from multiple threads don't cause issues."""
test_data = {"key": "value", "nested": {"a": 1, "b": 2}}
self.handler.save(test_data)
num_threads = 20
results: list[dict] = []
errors: list[Exception] = []
results_lock = threading.Lock()
def read_data() -> None:
try:
data = self.handler.load()
with results_lock:
results.append(data)
except Exception as e:
errors.append(e)
threads = []
for _ in range(num_threads):
t = threading.Thread(target=read_data)
threads.append(t)
t.start()
for t in threads:
t.join()
assert len(errors) == 0, f"Errors occurred during concurrent reads: {errors}"
assert len(results) == num_threads
for result in results:
assert result == test_data
def test_concurrent_read_write_same_handler(self):
"""Test that concurrent reads and writes don't corrupt data or cause errors."""
initial_data = {"counter": 0}
self.handler.save(initial_data)
num_writers = 5
num_readers = 10
writes_per_thread = 10
reads_per_thread = 20
write_errors: list[Exception] = []
read_errors: list[Exception] = []
read_results: list[dict] = []
results_lock = threading.Lock()
def writer(thread_id: int) -> None:
for i in range(writes_per_thread):
try:
data = {"writer": thread_id, "write_num": i}
self.handler.save(data)
except Exception as e:
write_errors.append(e)
def reader() -> None:
for _ in range(reads_per_thread):
try:
data = self.handler.load()
with results_lock:
read_results.append(data)
except Exception as e:
read_errors.append(e)
threads = []
for i in range(num_writers):
t = threading.Thread(target=writer, args=(i,))
threads.append(t)
for _ in range(num_readers):
t = threading.Thread(target=reader)
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
assert len(write_errors) == 0, f"Write errors: {write_errors}"
assert len(read_errors) == 0, f"Read errors: {read_errors}"
for result in read_results:
assert isinstance(result, dict)
def test_atomic_write_no_partial_data(self):
"""Test that atomic writes prevent partial/corrupted data from being read."""
large_data = {"key": "x" * 100000, "numbers": list(range(10000))}
num_iterations = 50
errors: list[Exception] = []
corruption_detected = False
corruption_lock = threading.Lock()
def writer() -> None:
for _ in range(num_iterations):
try:
self.handler.save(large_data)
except Exception as e:
errors.append(e)
def reader() -> None:
nonlocal corruption_detected
for _ in range(num_iterations * 2):
try:
data = self.handler.load()
if data and data != {} and data != large_data:
with corruption_lock:
corruption_detected = True
except Exception as e:
errors.append(e)
writer_thread = threading.Thread(target=writer)
reader_thread = threading.Thread(target=reader)
writer_thread.start()
reader_thread.start()
writer_thread.join()
reader_thread.join()
assert len(errors) == 0, f"Errors occurred: {errors}"
assert not corruption_detected, "Partial/corrupted data was read"
def test_thread_pool_concurrent_operations(self):
"""Test thread safety using ThreadPoolExecutor for more realistic concurrent access."""
num_operations = 100
errors: list[Exception] = []
def operation(op_id: int) -> str:
try:
if op_id % 3 == 0:
self.handler.save({"op_id": op_id, "type": "write"})
return f"write_{op_id}"
else:
data = self.handler.load()
return f"read_{op_id}_{type(data).__name__}"
except Exception as e:
errors.append(e)
return f"error_{op_id}"
with ThreadPoolExecutor(max_workers=20) as executor:
futures = [executor.submit(operation, i) for i in range(num_operations)]
results = [f.result() for f in as_completed(futures)]
assert len(errors) == 0, f"Errors occurred: {errors}"
assert len(results) == num_operations
def test_multiple_handlers_same_file(self):
"""Test that multiple PickleHandler instances for the same file work correctly."""
handler1 = PickleHandler(self.file_name)
handler2 = PickleHandler(self.file_name)
num_operations = 50
errors: list[Exception] = []
def use_handler1() -> None:
for i in range(num_operations):
try:
handler1.save({"handler": 1, "iteration": i})
except Exception as e:
errors.append(e)
def use_handler2() -> None:
for i in range(num_operations):
try:
handler2.save({"handler": 2, "iteration": i})
except Exception as e:
errors.append(e)
t1 = threading.Thread(target=use_handler1)
t2 = threading.Thread(target=use_handler2)
t1.start()
t2.start()
t1.join()
t2.join()
assert len(errors) == 0, f"Errors occurred: {errors}"
final_data = self.handler.load()
assert isinstance(final_data, dict)
assert "handler" in final_data
assert "iteration" in final_data