Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-28 09:38:17 +00:00

Comparing commits: devin/1768...devin/1768 (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 519f8ce0eb |  |
|  | 802ca92e42 |  |
Changes to `LongTermMemory` (module `crewai.memory.long_term.long_term_memory`):

```diff
@@ -33,24 +33,13 @@ class LongTermMemory(Memory):
         storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
         super().__init__(storage=storage)
 
-    def save(
-        self,
-        value: LongTermMemoryItem,
-        metadata: dict[str, Any] | None = None,
-    ) -> None:
-        """Save an item to long-term memory.
-
-        Args:
-            value: The LongTermMemoryItem to save.
-            metadata: Optional metadata dict (not used, metadata is extracted from the
-                LongTermMemoryItem). Included for supertype compatibility.
-        """
+    def save(self, item: LongTermMemoryItem) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
         crewai_event_bus.emit(
             self,
             event=MemorySaveStartedEvent(
-                value=value.task,
-                metadata=value.metadata,
-                agent_role=value.agent,
+                value=item.task,
+                metadata=item.metadata,
+                agent_role=item.agent,
                 source_type="long_term_memory",
                 from_agent=self.agent,
                 from_task=self.task,
```
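The `# BUG?` marker restored above flags the core issue: the override narrows the base-class signature. A minimal sketch of the breakage, assuming `Memory` declares `save(self, value, metadata=None)` as the removed docstring's "supertype compatibility" wording implies (the classes here are illustrative stand-ins, not the crewAI source):

```python
class Memory:
    def save(self, value: object, metadata: dict | None = None) -> None: ...

class LongTermMemory(Memory):
    # Narrower override: mypy reports "Signature of 'save' incompatible
    # with supertype 'Memory'", hence the type: ignore in the diff.
    def save(self, item: object) -> None: ...  # type: ignore[override]

def persist(memory: Memory, payload: object) -> None:
    # Legal for any Memory per the base signature, but with the override
    # above this raises at runtime for a LongTermMemory instance:
    # TypeError: save() got an unexpected keyword argument 'value'
    memory.save(value=payload)
```

The same narrowing recurs for `search`, `asave`, and `asearch` in the hunks that follow.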
```diff
@@ -59,23 +48,23 @@ class LongTermMemory(Memory):
 
         start_time = time.time()
         try:
-            item_metadata = value.metadata
-            item_metadata.update(
-                {"agent": value.agent, "expected_output": value.expected_output}
+            metadata = item.metadata
+            metadata.update(
+                {"agent": item.agent, "expected_output": item.expected_output}
             )
             self.storage.save(
-                task_description=value.task,
-                score=item_metadata["quality"],
-                metadata=item_metadata,
-                datetime=value.datetime,
+                task_description=item.task,
+                score=metadata["quality"],
+                metadata=metadata,
+                datetime=item.datetime,
             )
 
             crewai_event_bus.emit(
                 self,
                 event=MemorySaveCompletedEvent(
-                    value=value.task,
-                    metadata=value.metadata,
-                    agent_role=value.agent,
+                    value=item.task,
+                    metadata=item.metadata,
+                    agent_role=item.agent,
                     save_time_ms=(time.time() - start_time) * 1000,
                     source_type="long_term_memory",
                     from_agent=self.agent,
@@ -86,28 +75,25 @@ class LongTermMemory(Memory):
             crewai_event_bus.emit(
                 self,
                 event=MemorySaveFailedEvent(
-                    value=value.task,
-                    metadata=value.metadata,
-                    agent_role=value.agent,
+                    value=item.task,
+                    metadata=item.metadata,
+                    agent_role=item.agent,
                     error=str(e),
                     source_type="long_term_memory",
                 ),
             )
             raise
 
-    def search(
+    def search(  # type: ignore[override]
         self,
-        query: str,
-        limit: int = 3,
-        score_threshold: float = 0.6,
+        task: str,
+        latest_n: int = 3,
     ) -> list[dict[str, Any]]:
         """Search long-term memory for relevant entries.
 
         Args:
-            query: The task description to search for.
-            limit: Maximum number of results to return.
-            score_threshold: Minimum similarity score for results (not used for
-                long-term memory, included for supertype compatibility).
+            task: The task description to search for.
+            latest_n: Maximum number of results to return.
 
         Returns:
             List of matching memory entries.
```
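With the restored signature, callers go through the subclass's own parameter names. A hedged usage sketch (the query string is borrowed from the tests further down):

```python
from crewai.memory.long_term.long_term_memory import LongTermMemory

ltm = LongTermMemory()

# Restored call styles: positional, or the subclass keyword names.
results = ltm.search("test_task", latest_n=5)

# The base-class keywords from the removed signature no longer work:
# ltm.search(query="test_task", limit=5)  -> TypeError: unexpected keyword argument
```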
```diff
@@ -115,8 +101,8 @@ class LongTermMemory(Memory):
         crewai_event_bus.emit(
             self,
             event=MemoryQueryStartedEvent(
-                query=query,
-                limit=limit,
+                query=task,
+                limit=latest_n,
                 source_type="long_term_memory",
                 from_agent=self.agent,
                 from_task=self.task,
@@ -125,14 +111,14 @@ class LongTermMemory(Memory):
 
         start_time = time.time()
         try:
-            results = self.storage.load(query, limit)
+            results = self.storage.load(task, latest_n)
 
             crewai_event_bus.emit(
                 self,
                 event=MemoryQueryCompletedEvent(
-                    query=query,
+                    query=task,
                     results=results,
-                    limit=limit,
+                    limit=latest_n,
                     query_time_ms=(time.time() - start_time) * 1000,
                     source_type="long_term_memory",
                     from_agent=self.agent,
@@ -145,32 +131,26 @@ class LongTermMemory(Memory):
             crewai_event_bus.emit(
                 self,
                 event=MemoryQueryFailedEvent(
-                    query=query,
-                    limit=limit,
+                    query=task,
+                    limit=latest_n,
                     error=str(e),
                     source_type="long_term_memory",
                 ),
             )
             raise
 
-    async def asave(
-        self,
-        value: LongTermMemoryItem,
-        metadata: dict[str, Any] | None = None,
-    ) -> None:
+    async def asave(self, item: LongTermMemoryItem) -> None:  # type: ignore[override]
         """Save an item to long-term memory asynchronously.
 
         Args:
-            value: The LongTermMemoryItem to save.
-            metadata: Optional metadata dict (not used, metadata is extracted from the
-                LongTermMemoryItem). Included for supertype compatibility.
+            item: The LongTermMemoryItem to save.
         """
         crewai_event_bus.emit(
             self,
             event=MemorySaveStartedEvent(
-                value=value.task,
-                metadata=value.metadata,
-                agent_role=value.agent,
+                value=item.task,
+                metadata=item.metadata,
+                agent_role=item.agent,
                 source_type="long_term_memory",
                 from_agent=self.agent,
                 from_task=self.task,
@@ -179,23 +159,23 @@ class LongTermMemory(Memory):
 
         start_time = time.time()
         try:
-            item_metadata = value.metadata
-            item_metadata.update(
-                {"agent": value.agent, "expected_output": value.expected_output}
+            metadata = item.metadata
+            metadata.update(
+                {"agent": item.agent, "expected_output": item.expected_output}
             )
             await self.storage.asave(
-                task_description=value.task,
-                score=item_metadata["quality"],
-                metadata=item_metadata,
-                datetime=value.datetime,
+                task_description=item.task,
+                score=metadata["quality"],
+                metadata=metadata,
+                datetime=item.datetime,
             )
 
             crewai_event_bus.emit(
                 self,
                 event=MemorySaveCompletedEvent(
-                    value=value.task,
-                    metadata=value.metadata,
-                    agent_role=value.agent,
+                    value=item.task,
+                    metadata=item.metadata,
+                    agent_role=item.agent,
                     save_time_ms=(time.time() - start_time) * 1000,
                     source_type="long_term_memory",
                     from_agent=self.agent,
@@ -206,28 +186,25 @@ class LongTermMemory(Memory):
             crewai_event_bus.emit(
                 self,
                 event=MemorySaveFailedEvent(
-                    value=value.task,
-                    metadata=value.metadata,
-                    agent_role=value.agent,
+                    value=item.task,
+                    metadata=item.metadata,
+                    agent_role=item.agent,
                     error=str(e),
                     source_type="long_term_memory",
                 ),
             )
             raise
 
-    async def asearch(
+    async def asearch(  # type: ignore[override]
         self,
-        query: str,
-        limit: int = 3,
-        score_threshold: float = 0.6,
+        task: str,
+        latest_n: int = 3,
     ) -> list[dict[str, Any]]:
         """Search long-term memory asynchronously.
 
         Args:
-            query: The task description to search for.
-            limit: Maximum number of results to return.
-            score_threshold: Minimum similarity score for results (not used for
-                long-term memory, included for supertype compatibility).
+            task: The task description to search for.
+            latest_n: Maximum number of results to return.
 
         Returns:
             List of matching memory entries.
@@ -235,8 +212,8 @@ class LongTermMemory(Memory):
         crewai_event_bus.emit(
             self,
             event=MemoryQueryStartedEvent(
-                query=query,
-                limit=limit,
+                query=task,
+                limit=latest_n,
                 source_type="long_term_memory",
                 from_agent=self.agent,
                 from_task=self.task,
@@ -245,14 +222,14 @@ class LongTermMemory(Memory):
 
         start_time = time.time()
         try:
-            results = await self.storage.aload(query, limit)
+            results = await self.storage.aload(task, latest_n)
 
             crewai_event_bus.emit(
                 self,
                 event=MemoryQueryCompletedEvent(
-                    query=query,
+                    query=task,
                     results=results,
-                    limit=limit,
+                    limit=latest_n,
                     query_time_ms=(time.time() - start_time) * 1000,
                     source_type="long_term_memory",
                     from_agent=self.agent,
@@ -265,8 +242,8 @@ class LongTermMemory(Memory):
             crewai_event_bus.emit(
                 self,
                 event=MemoryQueryFailedEvent(
-                    query=query,
-                    limit=limit,
+                    query=task,
+                    limit=latest_n,
                     error=str(e),
                     source_type="long_term_memory",
                 ),
```
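Taken together, the reverted surface is `save(item)` / `search(task, latest_n)` plus their async twins. A sketch of driving the async pair end to end; the field values are invented, and the constructor keywords follow the test fixtures later in this diff:

```python
import asyncio

from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem

async def main() -> None:
    ltm = LongTermMemory()
    item = LongTermMemoryItem(
        agent="test_agent",
        task="test_task",
        expected_output="test_output",
        datetime="test_datetime",
        quality=0.7,
        metadata={"task": "test_task", "quality": 0.7},  # "quality" feeds the score
    )
    await ltm.asave(item)
    hits = await ltm.asearch(task="test_task", latest_n=3)
    print(hits)

asyncio.run(main())
```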
Changes to `FileHandler`/`PickleHandler` (module `crewai.utilities.file_handler`):

```diff
@@ -2,8 +2,11 @@ from datetime import datetime
 import json
 import os
 import pickle
+import tempfile
+import threading
 from typing import Any, TypedDict
 
+import portalocker
 from typing_extensions import Unpack
 
 
@@ -123,10 +126,15 @@ class FileHandler:
 
 
 class PickleHandler:
-    """Handler for saving and loading data using pickle.
+    """Thread-safe handler for saving and loading data using pickle.
 
+    This class provides thread-safe file operations using portalocker for
+    cross-process file locking and atomic write operations to prevent
+    data corruption during concurrent access.
+
     Attributes:
         file_path: The path to the pickle file.
+        _lock: Threading lock for thread-safe operations within the same process.
     """
 
     def __init__(self, file_name: str) -> None:
@@ -141,34 +149,62 @@ class PickleHandler:
             file_name += ".pkl"
 
         self.file_path = os.path.join(os.getcwd(), file_name)
+        self._lock = threading.Lock()
 
     def initialize_file(self) -> None:
         """Initialize the file with an empty dictionary and overwrite any existing data."""
         self.save({})
 
     def save(self, data: Any) -> None:
-        """
-        Save the data to the specified file using pickle.
+        """Save the data to the specified file using pickle with thread-safe atomic writes.
+
+        This method uses a two-phase approach for thread safety:
+        1. Threading lock for same-process thread safety
+        2. Atomic write (write to temp file, then rename) for cross-process safety
+           and data integrity
 
         Args:
-            data: The data to be saved to the file.
+            data: The data to be saved to the file.
         """
-        with open(self.file_path, "wb") as f:
-            pickle.dump(obj=data, file=f)
+        with self._lock:
+            dir_name = os.path.dirname(self.file_path) or os.getcwd()
+            fd, temp_path = tempfile.mkstemp(
+                suffix=".pkl.tmp", prefix="pickle_", dir=dir_name
+            )
+            try:
+                with os.fdopen(fd, "wb") as f:
+                    pickle.dump(obj=data, file=f)
+                    f.flush()
+                    os.fsync(f.fileno())
+                os.replace(temp_path, self.file_path)
+            except Exception:
+                if os.path.exists(temp_path):
+                    os.unlink(temp_path)
+                raise
 
     def load(self) -> Any:
-        """Load the data from the specified file using pickle.
+        """Load the data from the specified file using pickle with thread-safe locking.
+
+        This method uses portalocker for cross-process read locking to ensure
+        data consistency when multiple processes may be accessing the file.
 
         Returns:
-            The data loaded from the file.
+            The data loaded from the file, or an empty dictionary if the file
+            does not exist or is empty.
         """
-        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
-            return {}  # Return an empty dictionary if the file does not exist or is empty
+        with self._lock:
+            if (
+                not os.path.exists(self.file_path)
+                or os.path.getsize(self.file_path) == 0
+            ):
+                return {}
 
-        with open(self.file_path, "rb") as file:
-            try:
-                return pickle.load(file)  # noqa: S301
-            except EOFError:
-                return {}  # Return an empty dictionary if the file is empty or corrupted
-            except Exception:
-                raise  # Raise any other exceptions that occur during loading
+            with portalocker.Lock(
+                self.file_path, "rb", flags=portalocker.LOCK_SH
+            ) as file:
+                try:
+                    return pickle.load(file)  # noqa: S301
+                except EOFError:
+                    return {}
+                except Exception:
+                    raise
```
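A hedged usage sketch of the hardened handler; the file name is invented, and `PickleHandler` resolves it relative to `os.getcwd()`:

```python
from crewai.utilities.file_handler import PickleHandler

handler = PickleHandler("crew_state")  # ".pkl" is appended if missing
handler.save({"run_id": 42, "status": "complete"})  # atomic temp-file write + os.replace
state = handler.load()  # shared portalocker lock; {} if the file is missing or empty
assert state["run_id"] == 42
```

Creating the temp file in the destination directory (`dir=dir_name`) is deliberate: `os.replace` is only atomic within a single filesystem, so the temp file and the target must live on the same volume.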
Changes to the LongTermMemory tests:

```diff
@@ -1,7 +1,5 @@
-import inspect
 import threading
 from collections import defaultdict
-from typing import Any
 from unittest.mock import ANY
 
 import pytest
@@ -15,7 +13,6 @@ from crewai.events.types.memory_events import (
 )
 from crewai.memory.long_term.long_term_memory import LongTermMemory
 from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
-from crewai.memory.memory import Memory
 
 
 @pytest.fixture
@@ -117,7 +114,7 @@ def test_long_term_memory_search_events(long_term_memory):
 
     test_query = "test query"
 
-    long_term_memory.search(test_query, limit=5)
+    long_term_memory.search(test_query, latest_n=5)
 
     with condition:
         success = condition.wait_for(
@@ -177,104 +174,10 @@ def test_save_and_search(long_term_memory):
         metadata={"task": "test_task", "quality": 0.5},
     )
     long_term_memory.save(memory)
-    find = long_term_memory.search("test_task", limit=5)[0]
+    find = long_term_memory.search("test_task", latest_n=5)[0]
     assert find["score"] == 0.5
     assert find["datetime"] == "test_datetime"
     assert find["metadata"]["agent"] == "test_agent"
     assert find["metadata"]["quality"] == 0.5
     assert find["metadata"]["task"] == "test_task"
     assert find["metadata"]["expected_output"] == "test_output"
-
-
-class TestLongTermMemoryTypeSignatureCompatibility:
-    """Tests to verify LongTermMemory method signatures are compatible with Memory base class.
-
-    These tests ensure that the Liskov Substitution Principle is maintained and that
-    LongTermMemory can be used polymorphically wherever Memory is expected.
-    """
-
-    def test_save_signature_has_value_parameter(self):
-        """Test that save() uses 'value' parameter name matching Memory base class."""
-        sig = inspect.signature(LongTermMemory.save)
-        params = list(sig.parameters.keys())
-        assert "value" in params, "save() should have 'value' parameter for LSP compliance"
-        assert "metadata" in params, "save() should have 'metadata' parameter for LSP compliance"
-
-    def test_save_signature_has_metadata_with_default(self):
-        """Test that save() has metadata parameter with default value."""
-        sig = inspect.signature(LongTermMemory.save)
-        metadata_param = sig.parameters.get("metadata")
-        assert metadata_param is not None, "save() should have 'metadata' parameter"
-        assert metadata_param.default is None, "metadata should default to None"
-
-    def test_search_signature_has_query_parameter(self):
-        """Test that search() uses 'query' parameter name matching Memory base class."""
-        sig = inspect.signature(LongTermMemory.search)
-        params = list(sig.parameters.keys())
-        assert "query" in params, "search() should have 'query' parameter for LSP compliance"
-        assert "limit" in params, "search() should have 'limit' parameter for LSP compliance"
-        assert "score_threshold" in params, "search() should have 'score_threshold' parameter for LSP compliance"
-
-    def test_search_signature_has_score_threshold_with_default(self):
-        """Test that search() has score_threshold parameter with default value."""
-        sig = inspect.signature(LongTermMemory.search)
-        score_threshold_param = sig.parameters.get("score_threshold")
-        assert score_threshold_param is not None, "search() should have 'score_threshold' parameter"
-        assert score_threshold_param.default == 0.6, "score_threshold should default to 0.6"
-
-    def test_asave_signature_has_value_parameter(self):
-        """Test that asave() uses 'value' parameter name matching Memory base class."""
-        sig = inspect.signature(LongTermMemory.asave)
-        params = list(sig.parameters.keys())
-        assert "value" in params, "asave() should have 'value' parameter for LSP compliance"
-        assert "metadata" in params, "asave() should have 'metadata' parameter for LSP compliance"
-
-    def test_asearch_signature_has_query_parameter(self):
-        """Test that asearch() uses 'query' parameter name matching Memory base class."""
-        sig = inspect.signature(LongTermMemory.asearch)
-        params = list(sig.parameters.keys())
-        assert "query" in params, "asearch() should have 'query' parameter for LSP compliance"
-        assert "limit" in params, "asearch() should have 'limit' parameter for LSP compliance"
-        assert "score_threshold" in params, "asearch() should have 'score_threshold' parameter for LSP compliance"
-
-    def test_long_term_memory_is_subclass_of_memory(self):
-        """Test that LongTermMemory is a proper subclass of Memory."""
-        assert issubclass(LongTermMemory, Memory), "LongTermMemory should be a subclass of Memory"
-
-    def test_save_with_metadata_parameter(self, long_term_memory):
-        """Test that save() can be called with the metadata parameter (even if unused)."""
-        memory_item = LongTermMemoryItem(
-            agent="test_agent",
-            task="test_task_with_metadata",
-            expected_output="test_output",
-            datetime="test_datetime",
-            quality=0.8,
-            metadata={"task": "test_task_with_metadata", "quality": 0.8},
-        )
-        long_term_memory.save(value=memory_item, metadata={"extra": "data"})
-        results = long_term_memory.search(query="test_task_with_metadata", limit=1)
-        assert len(results) > 0
-        assert results[0]["metadata"]["agent"] == "test_agent"
-
-    def test_search_with_score_threshold_parameter(self, long_term_memory):
-        """Test that search() can be called with the score_threshold parameter."""
-        memory_item = LongTermMemoryItem(
-            agent="test_agent",
-            task="test_task_score_threshold",
-            expected_output="test_output",
-            datetime="test_datetime",
-            quality=0.9,
-            metadata={"task": "test_task_score_threshold", "quality": 0.9},
-        )
-        long_term_memory.save(value=memory_item)
-        results = long_term_memory.search(
-            query="test_task_score_threshold",
-            limit=5,
-            score_threshold=0.5,
-        )
-        assert isinstance(results, list)
-
-    @pytest.fixture
-    def long_term_memory(self):
-        """Fixture to create a LongTermMemory instance for this test class."""
-        return LongTermMemory()
```
Changes to the PickleHandler tests:

```diff
@@ -1,6 +1,8 @@
 import os
+import threading
 import unittest
 import uuid
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import pytest
 from crewai.utilities.file_handler import PickleHandler
@@ -8,7 +10,6 @@ from crewai.utilities.file_handler import PickleHandler
 
 class TestPickleHandler(unittest.TestCase):
     def setUp(self):
-        # Use a unique file name for each test to avoid race conditions in parallel test execution
        unique_id = str(uuid.uuid4())
        self.file_name = f"test_data_{unique_id}.pkl"
        self.file_path = os.path.join(os.getcwd(), self.file_name)
@@ -47,3 +48,234 @@ class TestPickleHandler(unittest.TestCase):
 
         assert str(exc.value) == "pickle data was truncated"
         assert "<class '_pickle.UnpicklingError'>" == str(exc.type)
+
+
+class TestPickleHandlerThreadSafety(unittest.TestCase):
+    """Tests for thread-safety of PickleHandler operations."""
+
+    def setUp(self):
+        unique_id = str(uuid.uuid4())
+        self.file_name = f"test_thread_safe_{unique_id}.pkl"
+        self.file_path = os.path.join(os.getcwd(), self.file_name)
+        self.handler = PickleHandler(self.file_name)
+
+    def tearDown(self):
+        if os.path.exists(self.file_path):
+            os.remove(self.file_path)
+
+    def test_concurrent_writes_same_handler(self):
+        """Test that concurrent writes from multiple threads using the same handler don't corrupt data."""
+        num_threads = 10
+        num_writes_per_thread = 20
+        errors: list[Exception] = []
+        write_count = 0
+        count_lock = threading.Lock()
+
+        def write_data(thread_id: int) -> None:
+            nonlocal write_count
+            for i in range(num_writes_per_thread):
+                try:
+                    data = {"thread": thread_id, "iteration": i, "data": f"value_{thread_id}_{i}"}
+                    self.handler.save(data)
+                    with count_lock:
+                        write_count += 1
+                except Exception as e:
+                    errors.append(e)
+
+        threads = []
+        for i in range(num_threads):
+            t = threading.Thread(target=write_data, args=(i,))
+            threads.append(t)
+            t.start()
+
+        for t in threads:
+            t.join()
+
+        assert len(errors) == 0, f"Errors occurred during concurrent writes: {errors}"
+        assert write_count == num_threads * num_writes_per_thread
+        loaded_data = self.handler.load()
+        assert isinstance(loaded_data, dict)
+        assert "thread" in loaded_data
+        assert "iteration" in loaded_data
+
+    def test_concurrent_reads_same_handler(self):
+        """Test that concurrent reads from multiple threads don't cause issues."""
+        test_data = {"key": "value", "nested": {"a": 1, "b": 2}}
+        self.handler.save(test_data)
+
+        num_threads = 20
+        results: list[dict] = []
+        errors: list[Exception] = []
+        results_lock = threading.Lock()
+
+        def read_data() -> None:
+            try:
+                data = self.handler.load()
+                with results_lock:
+                    results.append(data)
+            except Exception as e:
+                errors.append(e)
+
+        threads = []
+        for _ in range(num_threads):
+            t = threading.Thread(target=read_data)
+            threads.append(t)
+            t.start()
+
+        for t in threads:
+            t.join()
+
+        assert len(errors) == 0, f"Errors occurred during concurrent reads: {errors}"
+        assert len(results) == num_threads
+        for result in results:
+            assert result == test_data
+
+    def test_concurrent_read_write_same_handler(self):
+        """Test that concurrent reads and writes don't corrupt data or cause errors."""
+        initial_data = {"counter": 0}
+        self.handler.save(initial_data)
+
+        num_writers = 5
+        num_readers = 10
+        writes_per_thread = 10
+        reads_per_thread = 20
+        write_errors: list[Exception] = []
+        read_errors: list[Exception] = []
+        read_results: list[dict] = []
+        results_lock = threading.Lock()
+
+        def writer(thread_id: int) -> None:
+            for i in range(writes_per_thread):
+                try:
+                    data = {"writer": thread_id, "write_num": i}
+                    self.handler.save(data)
+                except Exception as e:
+                    write_errors.append(e)
+
+        def reader() -> None:
+            for _ in range(reads_per_thread):
+                try:
+                    data = self.handler.load()
+                    with results_lock:
+                        read_results.append(data)
+                except Exception as e:
+                    read_errors.append(e)
+
+        threads = []
+        for i in range(num_writers):
+            t = threading.Thread(target=writer, args=(i,))
+            threads.append(t)
+
+        for _ in range(num_readers):
+            t = threading.Thread(target=reader)
+            threads.append(t)
+
+        for t in threads:
+            t.start()
+
+        for t in threads:
+            t.join()
+
+        assert len(write_errors) == 0, f"Write errors: {write_errors}"
+        assert len(read_errors) == 0, f"Read errors: {read_errors}"
+        for result in read_results:
+            assert isinstance(result, dict)
+
+    def test_atomic_write_no_partial_data(self):
+        """Test that atomic writes prevent partial/corrupted data from being read."""
+        large_data = {"key": "x" * 100000, "numbers": list(range(10000))}
+        num_iterations = 50
+        errors: list[Exception] = []
+        corruption_detected = False
+        corruption_lock = threading.Lock()
+
+        def writer() -> None:
+            for _ in range(num_iterations):
+                try:
+                    self.handler.save(large_data)
+                except Exception as e:
+                    errors.append(e)
+
+        def reader() -> None:
+            nonlocal corruption_detected
+            for _ in range(num_iterations * 2):
+                try:
+                    data = self.handler.load()
+                    if data and data != {} and data != large_data:
+                        with corruption_lock:
+                            corruption_detected = True
+                except Exception as e:
+                    errors.append(e)
+
+        writer_thread = threading.Thread(target=writer)
+        reader_thread = threading.Thread(target=reader)
+
+        writer_thread.start()
+        reader_thread.start()
+
+        writer_thread.join()
+        reader_thread.join()
+
+        assert len(errors) == 0, f"Errors occurred: {errors}"
+        assert not corruption_detected, "Partial/corrupted data was read"
+
+    def test_thread_pool_concurrent_operations(self):
+        """Test thread safety using ThreadPoolExecutor for more realistic concurrent access."""
+        num_operations = 100
+        errors: list[Exception] = []
+
+        def operation(op_id: int) -> str:
+            try:
+                if op_id % 3 == 0:
+                    self.handler.save({"op_id": op_id, "type": "write"})
+                    return f"write_{op_id}"
+                else:
+                    data = self.handler.load()
+                    return f"read_{op_id}_{type(data).__name__}"
+            except Exception as e:
+                errors.append(e)
+                return f"error_{op_id}"
+
+        with ThreadPoolExecutor(max_workers=20) as executor:
+            futures = [executor.submit(operation, i) for i in range(num_operations)]
+            results = [f.result() for f in as_completed(futures)]
+
+        assert len(errors) == 0, f"Errors occurred: {errors}"
+        assert len(results) == num_operations
+
+    def test_multiple_handlers_same_file(self):
+        """Test that multiple PickleHandler instances for the same file work correctly."""
+        handler1 = PickleHandler(self.file_name)
+        handler2 = PickleHandler(self.file_name)
+
+        num_operations = 50
+        errors: list[Exception] = []
+
+        def use_handler1() -> None:
+            for i in range(num_operations):
+                try:
+                    handler1.save({"handler": 1, "iteration": i})
+                except Exception as e:
+                    errors.append(e)
+
+        def use_handler2() -> None:
+            for i in range(num_operations):
+                try:
+                    handler2.save({"handler": 2, "iteration": i})
+                except Exception as e:
+                    errors.append(e)
+
+        t1 = threading.Thread(target=use_handler1)
+        t2 = threading.Thread(target=use_handler2)
+
+        t1.start()
+        t2.start()
+
+        t1.join()
+        t2.join()
+
+        assert len(errors) == 0, f"Errors occurred: {errors}"
+        final_data = self.handler.load()
+        assert isinstance(final_data, dict)
+        assert "handler" in final_data
+        assert "iteration" in final_data
```
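For contrast, a minimal sketch (not part of the diff) of the failure mode these tests guard against when writes are not atomic:

```python
import pickle

def unsafe_save(path: str, data: object) -> None:
    # open(..., "wb") truncates the file immediately, then the pickle bytes
    # arrive over several writes; a concurrent pickle.load() in this window
    # sees a truncated stream and raises EOFError or UnpicklingError.
    with open(path, "wb") as f:
        pickle.dump(data, f)

# The patched PickleHandler.save avoids this by writing to a temp file and
# swapping it in with os.replace, so readers observe either the old file or
# the new one, never a partially written mix.
```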