mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-15 20:08:29 +00:00
Use file lock around Chroma client initialization (#3181)
This commit fixes a bug with concurrent processes and Chroma, where errors such as `table collections already exists` were raised. See https://cookbook.chromadb.dev/core/system_constraints/
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -26,4 +26,5 @@ test_flow.html
|
||||
crewairules.mdc
|
||||
plan.md
|
||||
conceptual_plan.md
|
||||
build_image
|
||||
build_image
|
||||
chromadb-*.lock
|
||||
|
||||
@@ -39,6 +39,7 @@ dependencies = [
|
||||
"tomli>=2.0.2",
|
||||
"blinker>=1.9.0",
|
||||
"json5>=0.10.0",
|
||||
"portalocker==2.7.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -18,6 +18,7 @@ from crewai.utilities.chromadb import sanitize_collection_name
|
||||
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -84,14 +85,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
raise Exception("Collection not initialized")
|
||||
|
||||
def initialize_knowledge_storage(self):
|
||||
base_path = os.path.join(db_storage_path(), "knowledge")
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=base_path,
|
||||
self.app = create_persistent_client(
|
||||
path=os.path.join(db_storage_path(), "knowledge"),
|
||||
settings=Settings(allow_reset=True),
|
||||
)
|
||||
|
||||
self.app = chroma_client
|
||||
|
||||
try:
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
@@ -111,9 +109,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
def reset(self):
|
||||
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
|
||||
if not self.app:
|
||||
self.app = chromadb.PersistentClient(
|
||||
path=base_path,
|
||||
settings=Settings(allow_reset=True),
|
||||
self.app = create_persistent_client(
|
||||
path=base_path, settings=Settings(allow_reset=True)
|
||||
)
|
||||
|
||||
self.app.reset()
|
||||
|
||||
@@ -4,12 +4,12 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from chromadb.api import ClientAPI
|
||||
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
@@ -60,17 +60,15 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
|
||||
def _initialize_app(self):
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
self._set_embedder_config()
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
|
||||
self.app = create_persistent_client(
|
||||
path=self.path if self.path else self.storage_file_name,
|
||||
settings=Settings(allow_reset=self.allow_reset),
|
||||
)
|
||||
|
||||
self.app = chroma_client
|
||||
|
||||
self.collection = self.app.get_or_create_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import re
|
||||
import portalocker
|
||||
from chromadb import PersistentClient
|
||||
from hashlib import md5
|
||||
from typing import Optional
|
||||
|
||||
|
||||
MIN_COLLECTION_LENGTH = 3
|
||||
MAX_COLLECTION_LENGTH = 63
|
||||
DEFAULT_COLLECTION = "default_collection"
|
||||
@@ -60,3 +64,16 @@ def sanitize_collection_name(name: Optional[str], max_collection_length: int = M
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def create_persistent_client(path: str, **kwargs):
    """
    Build a ChromaDB ``PersistentClient`` while holding a file lock.

    The lock file name is derived from the MD5 digest of ``path``, so every
    caller that passes the same ``path`` string contends on the same lock
    file. The name is a bare file name, so it is resolved against the
    current working directory. The lock is held only while the client is
    being constructed and is released before the client is returned.

    Args:
        path: Storage path forwarded to ``PersistentClient``; also the input
            used to derive the lock file name.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``PersistentClient``.

    Returns:
        The newly created ``PersistentClient`` instance.
    """
    # MD5 is used only to get a stable, filename-safe key for the path,
    # hence usedforsecurity=False.
    path_digest = md5(path.encode(), usedforsecurity=False).hexdigest()
    lock_path = f"chromadb-{path_digest}.lock"

    with portalocker.Lock(lock_path):
        return PersistentClient(path=path, **kwargs)
|
||||
|
||||
@@ -1,16 +1,27 @@
|
||||
import multiprocessing
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import pytest
|
||||
from chromadb.config import Settings
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai.utilities.chromadb import (
|
||||
MAX_COLLECTION_LENGTH,
|
||||
MIN_COLLECTION_LENGTH,
|
||||
is_ipv4_pattern,
|
||||
sanitize_collection_name,
|
||||
create_persistent_client,
|
||||
)
|
||||
|
||||
|
||||
def persistent_client_worker(path, queue):
    """
    Process target: create a persistent client and report the outcome.

    Puts ``None`` on ``queue`` when ``create_persistent_client`` succeeds,
    otherwise puts the exception that was raised so the parent process can
    inspect it.
    """
    try:
        create_persistent_client(path=path)
        # Success marker read by the parent process.
        queue.put(None)
    except Exception as error:
        # Hand the failure to the parent instead of raising in the child.
        queue.put(error)
|
||||
|
||||
|
||||
class TestChromadbUtils(unittest.TestCase):
|
||||
def test_sanitize_collection_name_long_name(self):
|
||||
"""Test sanitizing a very long collection name."""
|
||||
@@ -79,3 +90,34 @@ class TestChromadbUtils(unittest.TestCase):
|
||||
self.assertLessEqual(len(sanitized), MAX_COLLECTION_LENGTH)
|
||||
self.assertTrue(sanitized[0].isalnum())
|
||||
self.assertTrue(sanitized[-1].isalnum())
|
||||
|
||||
def test_create_persistent_client_passes_args(self):
    """create_persistent_client forwards its arguments and returns the client."""
    with patch("crewai.utilities.chromadb.PersistentClient") as client_factory:
        with tempfile.TemporaryDirectory() as storage_dir:
            # The patched constructor hands back this sentinel object.
            expected_client = MagicMock()
            client_factory.return_value = expected_client

            chroma_settings = Settings(allow_reset=True)
            returned_client = create_persistent_client(
                path=storage_dir, settings=chroma_settings
            )

            # Path and extra keyword arguments must reach PersistentClient.
            client_factory.assert_called_once_with(
                path=storage_dir, settings=chroma_settings
            )
            self.assertIs(returned_client, expected_client)
|
||||
|
||||
def test_create_persistent_client_process_safe(self):
    """Concurrent processes can create a client for the same path without errors.

    Five worker processes call ``create_persistent_client`` on one shared
    temporary directory; each reports ``None`` on success or the raised
    exception on failure through a ``multiprocessing.Queue``.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        queue = multiprocessing.Queue()
        processes = [
            multiprocessing.Process(
                target=persistent_client_worker, args=(tmpdir, queue)
            )
            for _ in range(5)
        ]

        # Plain loops: these calls are made for their side effects only.
        for process in processes:
            process.start()

        try:
            # Drain the queue BEFORE joining. A child that has put items on
            # a multiprocessing.Queue may not exit until they are consumed,
            # so join-then-get can deadlock. The timeout is generous because
            # the reads now also wait for the workers to finish their work.
            errors = [queue.get(timeout=60) for _ in processes]
        finally:
            # Always reap the workers, even if a queue read timed out.
            for process in processes:
                process.join()

        # assertEqual shows the actual exceptions in the failure message.
        self.assertEqual(errors, [None] * len(processes))
|
||||
|
||||
Reference in New Issue
Block a user