From 802ca92e421fcd675f339c70a8bfc273c443274d Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 10 Jan 2026 21:09:02 +0000 Subject: [PATCH] fix: make PickleHandler thread-safe with portalocker and atomic writes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add threading lock for same-process thread safety - Use atomic write operations (write to temp file, then rename) for data integrity - Use portalocker for cross-process read locking - Add comprehensive thread-safety tests covering concurrent reads, writes, and mixed operations Fixes #4215 Co-Authored-By: João --- .../src/crewai/utilities/file_handler.py | 70 ++++-- .../tests/utilities/test_file_handler.py | 234 +++++++++++++++++- 2 files changed, 286 insertions(+), 18 deletions(-) diff --git a/lib/crewai/src/crewai/utilities/file_handler.py b/lib/crewai/src/crewai/utilities/file_handler.py index ff50197a1..78af84cf9 100644 --- a/lib/crewai/src/crewai/utilities/file_handler.py +++ b/lib/crewai/src/crewai/utilities/file_handler.py @@ -2,8 +2,11 @@ from datetime import datetime import json import os import pickle +import tempfile +import threading from typing import Any, TypedDict +import portalocker from typing_extensions import Unpack @@ -123,10 +126,15 @@ class FileHandler: class PickleHandler: - """Handler for saving and loading data using pickle. + """Thread-safe handler for saving and loading data using pickle. + + This class provides thread-safe file operations using portalocker for + cross-process file locking and atomic write operations to prevent + data corruption during concurrent access. Attributes: file_path: The path to the pickle file. + _lock: Threading lock for thread-safe operations within the same process. """ def __init__(self, file_name: str) -> None: @@ -141,34 +149,62 @@ class PickleHandler: file_name += ".pkl" self.file_path = os.path.join(os.getcwd(), file_name) + self._lock = threading.Lock() def initialize_file(self) -> None: """Initialize the file with an empty dictionary and overwrite any existing data.""" self.save({}) def save(self, data: Any) -> None: - """ - Save the data to the specified file using pickle. + """Save the data to the specified file using pickle with thread-safe atomic writes. + + This method uses a two-phase approach for thread safety: + 1. Threading lock for same-process thread safety + 2. Atomic write (write to temp file, then rename) for cross-process safety + and data integrity Args: - data: The data to be saved to the file. + data: The data to be saved to the file. """ - with open(self.file_path, "wb") as f: - pickle.dump(obj=data, file=f) + with self._lock: + dir_name = os.path.dirname(self.file_path) or os.getcwd() + fd, temp_path = tempfile.mkstemp( + suffix=".pkl.tmp", prefix="pickle_", dir=dir_name + ) + try: + with os.fdopen(fd, "wb") as f: + pickle.dump(obj=data, file=f) + f.flush() + os.fsync(f.fileno()) + os.replace(temp_path, self.file_path) + except Exception: + if os.path.exists(temp_path): + os.unlink(temp_path) + raise def load(self) -> Any: - """Load the data from the specified file using pickle. + """Load the data from the specified file using pickle with thread-safe locking. + + This method uses portalocker for cross-process read locking to ensure + data consistency when multiple processes may be accessing the file. Returns: - The data loaded from the file. + The data loaded from the file, or an empty dictionary if the file + does not exist or is empty. 
""" - if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: - return {} # Return an empty dictionary if the file does not exist or is empty + with self._lock: + if ( + not os.path.exists(self.file_path) + or os.path.getsize(self.file_path) == 0 + ): + return {} - with open(self.file_path, "rb") as file: - try: - return pickle.load(file) # noqa: S301 - except EOFError: - return {} # Return an empty dictionary if the file is empty or corrupted - except Exception: - raise # Raise any other exceptions that occur during loading + with portalocker.Lock( + self.file_path, "rb", flags=portalocker.LOCK_SH + ) as file: + try: + return pickle.load(file) # noqa: S301 + except EOFError: + return {} + except Exception: + raise diff --git a/lib/crewai/tests/utilities/test_file_handler.py b/lib/crewai/tests/utilities/test_file_handler.py index 1e1cbfba8..b28f1187b 100644 --- a/lib/crewai/tests/utilities/test_file_handler.py +++ b/lib/crewai/tests/utilities/test_file_handler.py @@ -1,6 +1,8 @@ import os +import threading import unittest import uuid +from concurrent.futures import ThreadPoolExecutor, as_completed import pytest from crewai.utilities.file_handler import PickleHandler @@ -8,7 +10,6 @@ from crewai.utilities.file_handler import PickleHandler class TestPickleHandler(unittest.TestCase): def setUp(self): - # Use a unique file name for each test to avoid race conditions in parallel test execution unique_id = str(uuid.uuid4()) self.file_name = f"test_data_{unique_id}.pkl" self.file_path = os.path.join(os.getcwd(), self.file_name) @@ -47,3 +48,234 @@ class TestPickleHandler(unittest.TestCase): assert str(exc.value) == "pickle data was truncated" assert "" == str(exc.type) + + +class TestPickleHandlerThreadSafety(unittest.TestCase): + """Tests for thread-safety of PickleHandler operations.""" + + def setUp(self): + unique_id = str(uuid.uuid4()) + self.file_name = f"test_thread_safe_{unique_id}.pkl" + self.file_path = os.path.join(os.getcwd(), self.file_name) + self.handler = PickleHandler(self.file_name) + + def tearDown(self): + if os.path.exists(self.file_path): + os.remove(self.file_path) + + def test_concurrent_writes_same_handler(self): + """Test that concurrent writes from multiple threads using the same handler don't corrupt data.""" + num_threads = 10 + num_writes_per_thread = 20 + errors: list[Exception] = [] + write_count = 0 + count_lock = threading.Lock() + + def write_data(thread_id: int) -> None: + nonlocal write_count + for i in range(num_writes_per_thread): + try: + data = {"thread": thread_id, "iteration": i, "data": f"value_{thread_id}_{i}"} + self.handler.save(data) + with count_lock: + write_count += 1 + except Exception as e: + errors.append(e) + + threads = [] + for i in range(num_threads): + t = threading.Thread(target=write_data, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + assert len(errors) == 0, f"Errors occurred during concurrent writes: {errors}" + assert write_count == num_threads * num_writes_per_thread + loaded_data = self.handler.load() + assert isinstance(loaded_data, dict) + assert "thread" in loaded_data + assert "iteration" in loaded_data + + def test_concurrent_reads_same_handler(self): + """Test that concurrent reads from multiple threads don't cause issues.""" + test_data = {"key": "value", "nested": {"a": 1, "b": 2}} + self.handler.save(test_data) + + num_threads = 20 + results: list[dict] = [] + errors: list[Exception] = [] + results_lock = threading.Lock() + + def read_data() -> None: + try: + data = 
self.handler.load() + with results_lock: + results.append(data) + except Exception as e: + errors.append(e) + + threads = [] + for _ in range(num_threads): + t = threading.Thread(target=read_data) + threads.append(t) + t.start() + + for t in threads: + t.join() + + assert len(errors) == 0, f"Errors occurred during concurrent reads: {errors}" + assert len(results) == num_threads + for result in results: + assert result == test_data + + def test_concurrent_read_write_same_handler(self): + """Test that concurrent reads and writes don't corrupt data or cause errors.""" + initial_data = {"counter": 0} + self.handler.save(initial_data) + + num_writers = 5 + num_readers = 10 + writes_per_thread = 10 + reads_per_thread = 20 + write_errors: list[Exception] = [] + read_errors: list[Exception] = [] + read_results: list[dict] = [] + results_lock = threading.Lock() + + def writer(thread_id: int) -> None: + for i in range(writes_per_thread): + try: + data = {"writer": thread_id, "write_num": i} + self.handler.save(data) + except Exception as e: + write_errors.append(e) + + def reader() -> None: + for _ in range(reads_per_thread): + try: + data = self.handler.load() + with results_lock: + read_results.append(data) + except Exception as e: + read_errors.append(e) + + threads = [] + for i in range(num_writers): + t = threading.Thread(target=writer, args=(i,)) + threads.append(t) + + for _ in range(num_readers): + t = threading.Thread(target=reader) + threads.append(t) + + for t in threads: + t.start() + + for t in threads: + t.join() + + assert len(write_errors) == 0, f"Write errors: {write_errors}" + assert len(read_errors) == 0, f"Read errors: {read_errors}" + for result in read_results: + assert isinstance(result, dict) + + def test_atomic_write_no_partial_data(self): + """Test that atomic writes prevent partial/corrupted data from being read.""" + large_data = {"key": "x" * 100000, "numbers": list(range(10000))} + num_iterations = 50 + errors: list[Exception] = [] + corruption_detected = False + corruption_lock = threading.Lock() + + def writer() -> None: + for _ in range(num_iterations): + try: + self.handler.save(large_data) + except Exception as e: + errors.append(e) + + def reader() -> None: + nonlocal corruption_detected + for _ in range(num_iterations * 2): + try: + data = self.handler.load() + if data and data != {} and data != large_data: + with corruption_lock: + corruption_detected = True + except Exception as e: + errors.append(e) + + writer_thread = threading.Thread(target=writer) + reader_thread = threading.Thread(target=reader) + + writer_thread.start() + reader_thread.start() + + writer_thread.join() + reader_thread.join() + + assert len(errors) == 0, f"Errors occurred: {errors}" + assert not corruption_detected, "Partial/corrupted data was read" + + def test_thread_pool_concurrent_operations(self): + """Test thread safety using ThreadPoolExecutor for more realistic concurrent access.""" + num_operations = 100 + errors: list[Exception] = [] + + def operation(op_id: int) -> str: + try: + if op_id % 3 == 0: + self.handler.save({"op_id": op_id, "type": "write"}) + return f"write_{op_id}" + else: + data = self.handler.load() + return f"read_{op_id}_{type(data).__name__}" + except Exception as e: + errors.append(e) + return f"error_{op_id}" + + with ThreadPoolExecutor(max_workers=20) as executor: + futures = [executor.submit(operation, i) for i in range(num_operations)] + results = [f.result() for f in as_completed(futures)] + + assert len(errors) == 0, f"Errors occurred: {errors}" + assert 
len(results) == num_operations + + def test_multiple_handlers_same_file(self): + """Test that multiple PickleHandler instances for the same file work correctly.""" + handler1 = PickleHandler(self.file_name) + handler2 = PickleHandler(self.file_name) + + num_operations = 50 + errors: list[Exception] = [] + + def use_handler1() -> None: + for i in range(num_operations): + try: + handler1.save({"handler": 1, "iteration": i}) + except Exception as e: + errors.append(e) + + def use_handler2() -> None: + for i in range(num_operations): + try: + handler2.save({"handler": 2, "iteration": i}) + except Exception as e: + errors.append(e) + + t1 = threading.Thread(target=use_handler1) + t2 = threading.Thread(target=use_handler2) + + t1.start() + t2.start() + + t1.join() + t2.join() + + assert len(errors) == 0, f"Errors occurred: {errors}" + final_data = self.handler.load() + assert isinstance(final_data, dict) + assert "handler" in final_data + assert "iteration" in final_data
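
Reviewer note (not part of the patch): the test suite above covers the concurrency guarantees, but a minimal standalone sketch may help when exercising the change by hand. The `scratch_state.pkl` file name below is an arbitrary example; everything else uses only the `PickleHandler` API introduced in this diff (`initialize_file`, `save`, `load`).

```python
import threading

from crewai.utilities.file_handler import PickleHandler

# Arbitrary example file name; PickleHandler resolves it relative to os.getcwd().
handler = PickleHandler("scratch_state.pkl")
handler.initialize_file()  # start from an empty dict

def writer(worker_id: int) -> None:
    # Each save() writes to a temp file and atomically renames it into place,
    # so readers should never observe a half-written pickle.
    for i in range(100):
        handler.save({"worker": worker_id, "iteration": i})

def reader() -> None:
    # load() takes a shared portalocker lock; an empty dict means the file
    # did not exist or had no content yet.
    for _ in range(100):
        state = handler.load()
        assert isinstance(state, dict)

threads = [threading.Thread(target=writer, args=(n,)) for n in range(4)]
threads += [threading.Thread(target=reader) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(handler.load())  # last completed write wins; always a complete dict
```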