Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-24 15:48:23 +00:00)
Compare commits
2 Commits
devin/1768 ... devin/1768

| Author | SHA1 | Date |
|---|---|---|
|  | 519f8ce0eb |  |
|  | 802ca92e42 |  |
```diff
@@ -1,10 +1,10 @@
+from functools import lru_cache
 import subprocess


 class Repository:
     def __init__(self, path: str = ".") -> None:
         self.path = path
-        self._is_git_repo_cache: bool | None = None

         if not self.is_git_installed():
             raise ValueError("Git is not installed or not found in your PATH.")
@@ -40,26 +40,22 @@ class Repository:
             encoding="utf-8",
         ).strip()

+    @lru_cache(maxsize=None)  # noqa: B019
     def is_git_repo(self) -> bool:
         """Check if the current directory is a git repository.

-        The result is cached at the instance level to avoid redundant checks
-        while allowing proper garbage collection of Repository instances.
+        Notes:
+            - TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks
         """
-        if self._is_git_repo_cache is not None:
-            return self._is_git_repo_cache
-
         try:
             subprocess.check_output(
                 ["git", "rev-parse", "--is-inside-work-tree"],  # noqa: S607
                 cwd=self.path,
                 encoding="utf-8",
             )
-            self._is_git_repo_cache = True
+            return True
         except subprocess.CalledProcessError:
-            self._is_git_repo_cache = False
-
-        return self._is_git_repo_cache
+            return False

     def has_uncommitted_changes(self) -> bool:
         """Check if the repository has uncommitted changes."""
```
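The two versions of `is_git_repo()` in the hunks above differ only in where the memoized result lives: `functools.lru_cache` keeps `self` in a cache owned by the decorator (the reason for the `noqa: B019` and the TODO about memory leaks), while the `_is_git_repo_cache` attribute keeps the result on the instance itself. A minimal standalone sketch of the two patterns, using illustrative class names rather than the crewAI API:

```python
from functools import lru_cache


class LruCached:
    """lru_cache on a method keys its entries by `self`, so the shared cache
    keeps a reference to every instance until the cache is cleared."""

    @lru_cache(maxsize=None)  # noqa: B019 - bugbear/ruff flag exactly this pattern
    def expensive_check(self) -> bool:
        return True  # stand-in for the subprocess call


class InstanceCached:
    """Storing the result on the instance gives the same memoization without
    the external reference, so the object can be garbage collected normally."""

    def __init__(self) -> None:
        self._cached: bool | None = None

    def expensive_check(self) -> bool:
        if self._cached is None:
            self._cached = True  # stand-in for the subprocess call
        return self._cached
```

The removed tests further down (weakref plus `gc.collect()`) check exactly this difference.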
```diff
@@ -2,8 +2,11 @@ from datetime import datetime
 import json
 import os
 import pickle
+import tempfile
+import threading
 from typing import Any, TypedDict

+import portalocker
 from typing_extensions import Unpack


```
```diff
@@ -123,10 +126,15 @@ class FileHandler:


 class PickleHandler:
-    """Handler for saving and loading data using pickle.
+    """Thread-safe handler for saving and loading data using pickle.

+    This class provides thread-safe file operations using portalocker for
+    cross-process file locking and atomic write operations to prevent
+    data corruption during concurrent access.
+
     Attributes:
         file_path: The path to the pickle file.
+        _lock: Threading lock for thread-safe operations within the same process.
     """

     def __init__(self, file_name: str) -> None:
```
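The new docstring above leans on two mechanisms: a `threading.Lock` for threads inside one process, and `portalocker` advisory file locks across processes. A minimal sketch of the portalocker calls the handler's `load()` (next hunk) uses, with an illustrative file name; exclusive locks admit one holder at a time, shared locks allow concurrent readers:

```python
import portalocker

# Exclusive lock: only one process may hold it (shown for contrast;
# the handler's save() relies on atomic renames instead).
with portalocker.Lock("state.pkl", "wb", flags=portalocker.LOCK_EX) as fh:
    fh.write(b"payload")

# Shared lock: many readers may hold it at once, as load() does below.
with portalocker.Lock("state.pkl", "rb", flags=portalocker.LOCK_SH) as fh:
    raw = fh.read()
```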
```diff
@@ -141,34 +149,62 @@ class PickleHandler:
             file_name += ".pkl"

         self.file_path = os.path.join(os.getcwd(), file_name)
+        self._lock = threading.Lock()

     def initialize_file(self) -> None:
         """Initialize the file with an empty dictionary and overwrite any existing data."""
         self.save({})

     def save(self, data: Any) -> None:
-        """
-        Save the data to the specified file using pickle.
+        """Save the data to the specified file using pickle with thread-safe atomic writes.

+        This method uses a two-phase approach for thread safety:
+        1. Threading lock for same-process thread safety
+        2. Atomic write (write to temp file, then rename) for cross-process safety
+           and data integrity
+
         Args:
             data: The data to be saved to the file.
         """
-        with open(self.file_path, "wb") as f:
-            pickle.dump(obj=data, file=f)
+        with self._lock:
+            dir_name = os.path.dirname(self.file_path) or os.getcwd()
+            fd, temp_path = tempfile.mkstemp(
+                suffix=".pkl.tmp", prefix="pickle_", dir=dir_name
+            )
+            try:
+                with os.fdopen(fd, "wb") as f:
+                    pickle.dump(obj=data, file=f)
+                    f.flush()
+                    os.fsync(f.fileno())
+                os.replace(temp_path, self.file_path)
+            except Exception:
+                if os.path.exists(temp_path):
+                    os.unlink(temp_path)
+                raise

     def load(self) -> Any:
-        """Load the data from the specified file using pickle.
+        """Load the data from the specified file using pickle with thread-safe locking.

+        This method uses portalocker for cross-process read locking to ensure
+        data consistency when multiple processes may be accessing the file.
+
         Returns:
-            The data loaded from the file.
+            The data loaded from the file, or an empty dictionary if the file
+            does not exist or is empty.
         """
-        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
-            return {}  # Return an empty dictionary if the file does not exist or is empty
+        with self._lock:
+            if (
+                not os.path.exists(self.file_path)
+                or os.path.getsize(self.file_path) == 0
+            ):
+                return {}

-        with open(self.file_path, "rb") as file:
-            try:
-                return pickle.load(file)  # noqa: S301
-            except EOFError:
-                return {}  # Return an empty dictionary if the file is empty or corrupted
-            except Exception:
-                raise  # Raise any other exceptions that occur during loading
+            with portalocker.Lock(
+                self.file_path, "rb", flags=portalocker.LOCK_SH
+            ) as file:
+                try:
+                    return pickle.load(file)  # noqa: S301
+                except EOFError:
+                    return {}
+                except Exception:
+                    raise
```
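The `save()` above pairs the in-process lock with an atomic write: the pickle is written to a temporary file in the same directory, flushed and fsynced, then moved over the target with `os.replace`, so a concurrent `load()` sees either the old file or the new one but never a partial pickle. A standalone sketch of that write pattern under the same assumptions (the function name is illustrative, not part of the library):

```python
import os
import tempfile


def atomic_write_bytes(path: str, payload: bytes) -> None:
    """Replace `path` with `payload` without ever exposing a partial file."""
    dir_name = os.path.dirname(path) or os.getcwd()
    # The temp file must live on the same filesystem as the target,
    # otherwise os.replace() stops being an atomic rename.
    fd, temp_path = tempfile.mkstemp(suffix=".tmp", dir=dir_name)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(payload)
            f.flush()
            os.fsync(f.fileno())  # ensure the bytes hit disk before the rename
        os.replace(temp_path, path)  # atomic swap on POSIX and Windows
    except Exception:
        if os.path.exists(temp_path):
            os.unlink(temp_path)
        raise
```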
```diff
@@ -1,8 +1,4 @@
-import gc
-import weakref
-
 import pytest

 from crewai.cli.git import Repository

-
```
```diff
@@ -103,82 +99,3 @@ def test_origin_url(fp, repository):
         stdout="https://github.com/user/repo.git\n",
     )
     assert repository.origin_url() == "https://github.com/user/repo.git"
-
-
-def test_repository_garbage_collection(fp):
-    """Test that Repository instances can be garbage collected.
-
-    This test verifies the fix for the memory leak issue where using
-    @lru_cache on the is_git_repo() method prevented garbage collection
-    of Repository instances.
-    """
-    fp.register(["git", "--version"], stdout="git version 2.30.0\n")
-    fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
-    fp.register(["git", "fetch"], stdout="")
-
-    repo = Repository(path=".")
-    weak_ref = weakref.ref(repo)
-
-    assert weak_ref() is not None
-
-    del repo
-    gc.collect()
-
-    assert weak_ref() is None, (
-        "Repository instance was not garbage collected. "
-        "This indicates a memory leak, likely from @lru_cache on instance methods."
-    )
-
-
-def test_is_git_repo_caching(fp):
-    """Test that is_git_repo() result is cached at the instance level.
-
-    This verifies that the instance-level caching works correctly,
-    only calling the subprocess once per instance.
-    """
-    fp.register(["git", "--version"], stdout="git version 2.30.0\n")
-    fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
-    fp.register(["git", "fetch"], stdout="")
-
-    repo = Repository(path=".")
-
-    assert repo._is_git_repo_cache is True
-
-    result1 = repo.is_git_repo()
-    result2 = repo.is_git_repo()
-
-    assert result1 is True
-    assert result2 is True
-    assert repo._is_git_repo_cache is True
-
-
-def test_multiple_repository_instances_independent_caches(fp):
-    """Test that multiple Repository instances have independent caches.
-
-    This verifies that the instance-level caching doesn't share state
-    between different Repository instances.
-    """
-    fp.register(["git", "--version"], stdout="git version 2.30.0\n")
-    fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
-    fp.register(["git", "fetch"], stdout="")
-
-    fp.register(["git", "--version"], stdout="git version 2.30.0\n")
-    fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
-    fp.register(["git", "fetch"], stdout="")
-
-    repo1 = Repository(path=".")
-    repo2 = Repository(path=".")
-
-    assert repo1._is_git_repo_cache is True
-    assert repo2._is_git_repo_cache is True
-
-    assert repo1._is_git_repo_cache is not repo2._is_git_repo_cache or (
-        repo1._is_git_repo_cache == repo2._is_git_repo_cache
-    )
-
-    weak_ref1 = weakref.ref(repo1)
-    del repo1
-    gc.collect()
-
-    assert weak_ref1() is None
-    assert repo2._is_git_repo_cache is True
```
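The removed tests above use the standard weakref-plus-gc recipe for asserting that an object is actually collectable; the same check works for any suspected reference leak. A minimal generic version, with illustrative names:

```python
import gc
import weakref


class Probe:
    pass


def assert_collectable() -> None:
    obj = Probe()
    ref = weakref.ref(obj)
    del obj
    gc.collect()
    assert ref() is None, "instance is still referenced somewhere (possible leak)"


assert_collectable()
```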
```diff
@@ -1,6 +1,8 @@
 import os
+import threading
 import unittest
 import uuid
+from concurrent.futures import ThreadPoolExecutor, as_completed

 import pytest
 from crewai.utilities.file_handler import PickleHandler
@@ -8,7 +10,6 @@ from crewai.utilities.file_handler import PickleHandler

 class TestPickleHandler(unittest.TestCase):
     def setUp(self):
-        # Use a unique file name for each test to avoid race conditions in parallel test execution
         unique_id = str(uuid.uuid4())
         self.file_name = f"test_data_{unique_id}.pkl"
         self.file_path = os.path.join(os.getcwd(), self.file_name)
@@ -47,3 +48,234 @@ class TestPickleHandler(unittest.TestCase):

         assert str(exc.value) == "pickle data was truncated"
         assert "<class '_pickle.UnpicklingError'>" == str(exc.type)
+
+
+class TestPickleHandlerThreadSafety(unittest.TestCase):
+    """Tests for thread-safety of PickleHandler operations."""
+
+    def setUp(self):
+        unique_id = str(uuid.uuid4())
+        self.file_name = f"test_thread_safe_{unique_id}.pkl"
+        self.file_path = os.path.join(os.getcwd(), self.file_name)
+        self.handler = PickleHandler(self.file_name)
+
+    def tearDown(self):
+        if os.path.exists(self.file_path):
+            os.remove(self.file_path)
+
+    def test_concurrent_writes_same_handler(self):
+        """Test that concurrent writes from multiple threads using the same handler don't corrupt data."""
+        num_threads = 10
+        num_writes_per_thread = 20
+        errors: list[Exception] = []
+        write_count = 0
+        count_lock = threading.Lock()
+
+        def write_data(thread_id: int) -> None:
+            nonlocal write_count
+            for i in range(num_writes_per_thread):
+                try:
+                    data = {"thread": thread_id, "iteration": i, "data": f"value_{thread_id}_{i}"}
+                    self.handler.save(data)
+                    with count_lock:
+                        write_count += 1
+                except Exception as e:
+                    errors.append(e)
+
+        threads = []
+        for i in range(num_threads):
+            t = threading.Thread(target=write_data, args=(i,))
+            threads.append(t)
+            t.start()
+
+        for t in threads:
+            t.join()
+
+        assert len(errors) == 0, f"Errors occurred during concurrent writes: {errors}"
+        assert write_count == num_threads * num_writes_per_thread
+        loaded_data = self.handler.load()
+        assert isinstance(loaded_data, dict)
+        assert "thread" in loaded_data
+        assert "iteration" in loaded_data
+
+    def test_concurrent_reads_same_handler(self):
+        """Test that concurrent reads from multiple threads don't cause issues."""
+        test_data = {"key": "value", "nested": {"a": 1, "b": 2}}
+        self.handler.save(test_data)
+
+        num_threads = 20
+        results: list[dict] = []
+        errors: list[Exception] = []
+        results_lock = threading.Lock()
+
+        def read_data() -> None:
+            try:
+                data = self.handler.load()
+                with results_lock:
+                    results.append(data)
+            except Exception as e:
+                errors.append(e)
+
+        threads = []
+        for _ in range(num_threads):
+            t = threading.Thread(target=read_data)
+            threads.append(t)
+            t.start()
+
+        for t in threads:
+            t.join()
+
+        assert len(errors) == 0, f"Errors occurred during concurrent reads: {errors}"
+        assert len(results) == num_threads
+        for result in results:
+            assert result == test_data
+
+    def test_concurrent_read_write_same_handler(self):
+        """Test that concurrent reads and writes don't corrupt data or cause errors."""
+        initial_data = {"counter": 0}
+        self.handler.save(initial_data)
+
+        num_writers = 5
+        num_readers = 10
+        writes_per_thread = 10
+        reads_per_thread = 20
+        write_errors: list[Exception] = []
+        read_errors: list[Exception] = []
+        read_results: list[dict] = []
+        results_lock = threading.Lock()
+
+        def writer(thread_id: int) -> None:
+            for i in range(writes_per_thread):
+                try:
+                    data = {"writer": thread_id, "write_num": i}
+                    self.handler.save(data)
+                except Exception as e:
+                    write_errors.append(e)
+
+        def reader() -> None:
+            for _ in range(reads_per_thread):
+                try:
+                    data = self.handler.load()
+                    with results_lock:
+                        read_results.append(data)
+                except Exception as e:
+                    read_errors.append(e)
+
+        threads = []
+        for i in range(num_writers):
+            t = threading.Thread(target=writer, args=(i,))
+            threads.append(t)
+
+        for _ in range(num_readers):
+            t = threading.Thread(target=reader)
+            threads.append(t)
+
+        for t in threads:
+            t.start()
+
+        for t in threads:
+            t.join()
+
+        assert len(write_errors) == 0, f"Write errors: {write_errors}"
+        assert len(read_errors) == 0, f"Read errors: {read_errors}"
+        for result in read_results:
+            assert isinstance(result, dict)
+
+    def test_atomic_write_no_partial_data(self):
+        """Test that atomic writes prevent partial/corrupted data from being read."""
+        large_data = {"key": "x" * 100000, "numbers": list(range(10000))}
+        num_iterations = 50
+        errors: list[Exception] = []
+        corruption_detected = False
+        corruption_lock = threading.Lock()
+
+        def writer() -> None:
+            for _ in range(num_iterations):
+                try:
+                    self.handler.save(large_data)
+                except Exception as e:
+                    errors.append(e)
+
+        def reader() -> None:
+            nonlocal corruption_detected
+            for _ in range(num_iterations * 2):
+                try:
+                    data = self.handler.load()
+                    if data and data != {} and data != large_data:
+                        with corruption_lock:
+                            corruption_detected = True
+                except Exception as e:
+                    errors.append(e)
+
+        writer_thread = threading.Thread(target=writer)
+        reader_thread = threading.Thread(target=reader)
+
+        writer_thread.start()
+        reader_thread.start()
+
+        writer_thread.join()
+        reader_thread.join()
+
+        assert len(errors) == 0, f"Errors occurred: {errors}"
+        assert not corruption_detected, "Partial/corrupted data was read"
+
+    def test_thread_pool_concurrent_operations(self):
+        """Test thread safety using ThreadPoolExecutor for more realistic concurrent access."""
+        num_operations = 100
+        errors: list[Exception] = []
+
+        def operation(op_id: int) -> str:
+            try:
+                if op_id % 3 == 0:
+                    self.handler.save({"op_id": op_id, "type": "write"})
+                    return f"write_{op_id}"
+                else:
+                    data = self.handler.load()
+                    return f"read_{op_id}_{type(data).__name__}"
+            except Exception as e:
+                errors.append(e)
+                return f"error_{op_id}"
+
+        with ThreadPoolExecutor(max_workers=20) as executor:
+            futures = [executor.submit(operation, i) for i in range(num_operations)]
+            results = [f.result() for f in as_completed(futures)]
+
+        assert len(errors) == 0, f"Errors occurred: {errors}"
+        assert len(results) == num_operations
+
+    def test_multiple_handlers_same_file(self):
+        """Test that multiple PickleHandler instances for the same file work correctly."""
+        handler1 = PickleHandler(self.file_name)
+        handler2 = PickleHandler(self.file_name)
+
+        num_operations = 50
+        errors: list[Exception] = []
+
+        def use_handler1() -> None:
+            for i in range(num_operations):
+                try:
+                    handler1.save({"handler": 1, "iteration": i})
+                except Exception as e:
+                    errors.append(e)
+
+        def use_handler2() -> None:
+            for i in range(num_operations):
+                try:
+                    handler2.save({"handler": 2, "iteration": i})
+                except Exception as e:
+                    errors.append(e)
+
+        t1 = threading.Thread(target=use_handler1)
+        t2 = threading.Thread(target=use_handler2)
+
+        t1.start()
+        t2.start()
+
+        t1.join()
+        t2.join()
+
+        assert len(errors) == 0, f"Errors occurred: {errors}"
+        final_data = self.handler.load()
+        assert isinstance(final_data, dict)
+        assert "handler" in final_data
+        assert "iteration" in final_data
```