mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Use file lock around Chroma client initialization (#3181)
This commit fixes a bug with concurrent processess and Chroma where `table collections already exists` (and similar) were raised. https://cookbook.chromadb.dev/core/system_constraints/
This commit is contained in:
@@ -1,16 +1,27 @@
|
||||
import multiprocessing
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import pytest
|
||||
from chromadb.config import Settings
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai.utilities.chromadb import (
|
||||
MAX_COLLECTION_LENGTH,
|
||||
MIN_COLLECTION_LENGTH,
|
||||
is_ipv4_pattern,
|
||||
sanitize_collection_name,
|
||||
create_persistent_client,
|
||||
)
|
||||
|
||||
|
||||
def persistent_client_worker(path, queue):
|
||||
try:
|
||||
create_persistent_client(path=path)
|
||||
queue.put(None)
|
||||
except Exception as e:
|
||||
queue.put(e)
|
||||
|
||||
|
||||
class TestChromadbUtils(unittest.TestCase):
|
||||
def test_sanitize_collection_name_long_name(self):
|
||||
"""Test sanitizing a very long collection name."""
|
||||
@@ -79,3 +90,34 @@ class TestChromadbUtils(unittest.TestCase):
|
||||
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
|
||||
self.assertTrue(sanitized[0].isalnum())
|
||||
self.assertTrue(sanitized[-1].isalnum())
|
||||
|
||||
def test_create_persistent_client_passes_args(self):
|
||||
with patch(
|
||||
"crewai.utilities.chromadb.PersistentClient"
|
||||
) as mock_persistent_client, tempfile.TemporaryDirectory() as tmpdir:
|
||||
mock_instance = MagicMock()
|
||||
mock_persistent_client.return_value = mock_instance
|
||||
|
||||
settings = Settings(allow_reset=True)
|
||||
client = create_persistent_client(path=tmpdir, settings=settings)
|
||||
|
||||
mock_persistent_client.assert_called_once_with(
|
||||
path=tmpdir, settings=settings
|
||||
)
|
||||
self.assertIs(client, mock_instance)
|
||||
|
||||
def test_create_persistent_client_process_safe(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
queue = multiprocessing.Queue()
|
||||
processes = [
|
||||
multiprocessing.Process(
|
||||
target=persistent_client_worker, args=(tmpdir, queue)
|
||||
)
|
||||
for _ in range(5)
|
||||
]
|
||||
|
||||
[p.start() for p in processes]
|
||||
[p.join() for p in processes]
|
||||
|
||||
errors = [queue.get(timeout=5) for _ in processes]
|
||||
self.assertTrue(all(err is None for err in errors))
|
||||
|
||||
Reference in New Issue
Block a user