From ec1eff02a8b98851230b27c7b5351b857cf6ccb7 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 28 Aug 2025 11:22:36 -0400 Subject: [PATCH] fix: achieve parity between rag package and current impl (#3418) - Sanitize ChromaDB collection names and use original dir naming - Add persistent client with file locking to the ChromaDB factory - Add upsert support to the ChromaDB client - Suppress ChromaDB deprecation warnings for `model_fields` - Extract `suppress_logging` into shared `logger_utils` - Update tests to reflect upsert behavior - Docs: add additional note --- .../knowledge/storage/knowledge_storage.py | 24 ++----- src/crewai/memory/storage/rag_storage.py | 24 ++----- src/crewai/rag/chromadb/client.py | 63 +++++++++++-------- src/crewai/rag/chromadb/config.py | 6 ++ src/crewai/rag/chromadb/constants.py | 11 +++- src/crewai/rag/chromadb/factory.py | 24 +++++-- src/crewai/rag/chromadb/utils.py | 62 ++++++++++++++++++ src/crewai/utilities/logger_utils.py | 38 +++++++++++ tests/rag/chromadb/test_client.py | 12 ++-- 9 files changed, 186 insertions(+), 78 deletions(-) create mode 100644 src/crewai/utilities/logger_utils.py diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index 347d6990b..3629dc7ce 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -1,6 +1,4 @@ -import contextlib import hashlib -import io import logging import os import shutil @@ -20,23 +18,7 @@ from crewai.utilities.constants import KNOWLEDGE_DIRECTORY from crewai.utilities.logger import Logger from crewai.utilities.paths import db_storage_path from crewai.utilities.chromadb import create_persistent_client - - -@contextlib.contextmanager -def suppress_logging( - logger_name="chromadb.segment.impl.vector.local_persistent_hnsw", - level=logging.ERROR, -): - logger = logging.getLogger(logger_name) - original_level = logger.getEffectiveLevel() - logger.setLevel(level) - with ( - contextlib.redirect_stdout(io.StringIO()), - contextlib.redirect_stderr(io.StringIO()), - contextlib.suppress(UserWarning), - ): - yield - logger.setLevel(original_level) +from crewai.utilities.logger_utils import suppress_logging class KnowledgeStorage(BaseKnowledgeStorage): @@ -64,7 +46,9 @@ class KnowledgeStorage(BaseKnowledgeStorage): filter: Optional[dict] = None, score_threshold: float = 0.35, ) -> List[Dict[str, Any]]: - with suppress_logging(): + with suppress_logging( + "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR + ): if self.collection: fetched = self.collection.query( query_texts=query, diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index 5e8974f6e..504da2fce 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -1,5 +1,3 @@ -import contextlib -import io import logging import os import shutil @@ -12,26 +10,10 @@ from crewai.rag.embeddings.configurator import EmbeddingConfigurator from crewai.utilities.chromadb import create_persistent_client from crewai.utilities.constants import MAX_FILE_NAME_LENGTH from crewai.utilities.paths import db_storage_path +from crewai.utilities.logger_utils import suppress_logging import warnings -@contextlib.contextmanager -def suppress_logging( - logger_name="chromadb.segment.impl.vector.local_persistent_hnsw", - level=logging.ERROR, -): - logger = logging.getLogger(logger_name) - original_level = logger.getEffectiveLevel() - logger.setLevel(level) - with ( - contextlib.redirect_stdout(io.StringIO()), - contextlib.redirect_stderr(io.StringIO()), - contextlib.suppress(UserWarning), - ): - yield - logger.setLevel(original_level) - - class RAGStorage(BaseRAGStorage): """ Extends Storage to handle embeddings for memory entries, improving @@ -122,7 +104,9 @@ class RAGStorage(BaseRAGStorage): self._initialize_app() try: - with suppress_logging(): + with suppress_logging( + "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR + ): response = self.collection.query(query_texts=query, n_results=limit) results = [] diff --git a/src/crewai/rag/chromadb/client.py b/src/crewai/rag/chromadb/client.py index b61f85e36..ca3ae62b8 100644 --- a/src/crewai/rag/chromadb/client.py +++ b/src/crewai/rag/chromadb/client.py @@ -1,5 +1,6 @@ """ChromaDB client implementation.""" +import logging from typing import Any from chromadb.api.types import ( @@ -20,7 +21,9 @@ from crewai.rag.chromadb.utils import ( _is_sync_client, _prepare_documents_for_chromadb, _process_query_results, + _sanitize_collection_name, ) +from crewai.utilities.logger_utils import suppress_logging from crewai.rag.core.base_client import ( BaseClient, BaseCollectionParams, @@ -97,7 +100,7 @@ class ChromaDBClient(BaseClient): metadata["hnsw:space"] = "cosine" self.client.create_collection( - name=kwargs["collection_name"], + name=_sanitize_collection_name(kwargs["collection_name"]), configuration=kwargs.get("configuration"), metadata=metadata, embedding_function=kwargs.get( @@ -154,7 +157,7 @@ class ChromaDBClient(BaseClient): metadata["hnsw:space"] = "cosine" await self.client.create_collection( - name=kwargs["collection_name"], + name=_sanitize_collection_name(kwargs["collection_name"]), configuration=kwargs.get("configuration"), metadata=metadata, embedding_function=kwargs.get( @@ -205,7 +208,7 @@ class ChromaDBClient(BaseClient): metadata["hnsw:space"] = "cosine" return self.client.get_or_create_collection( - name=kwargs["collection_name"], + name=_sanitize_collection_name(kwargs["collection_name"]), configuration=kwargs.get("configuration"), metadata=metadata, embedding_function=kwargs.get( @@ -258,7 +261,7 @@ class ChromaDBClient(BaseClient): metadata["hnsw:space"] = "cosine" return await self.client.get_or_create_collection( - name=kwargs["collection_name"], + name=_sanitize_collection_name(kwargs["collection_name"]), configuration=kwargs.get("configuration"), metadata=metadata, embedding_function=kwargs.get( @@ -298,12 +301,12 @@ class ChromaDBClient(BaseClient): raise ValueError("Documents list cannot be empty") collection = self.client.get_collection( - name=collection_name, + name=_sanitize_collection_name(collection_name), embedding_function=self.embedding_function, ) prepared = _prepare_documents_for_chromadb(documents) - collection.add( + collection.upsert( ids=prepared.ids, documents=prepared.texts, metadatas=prepared.metadatas, @@ -340,11 +343,11 @@ class ChromaDBClient(BaseClient): raise ValueError("Documents list cannot be empty") collection = await self.client.get_collection( - name=collection_name, + name=_sanitize_collection_name(collection_name), embedding_function=self.embedding_function, ) prepared = _prepare_documents_for_chromadb(documents) - await collection.add( + await collection.upsert( ids=prepared.ids, documents=prepared.texts, metadatas=prepared.metadatas, @@ -385,19 +388,22 @@ class ChromaDBClient(BaseClient): params = _extract_search_params(kwargs) collection = self.client.get_collection( - name=params.collection_name, + name=_sanitize_collection_name(params.collection_name), embedding_function=self.embedding_function, ) where = params.where if params.where is not None else params.metadata_filter - results: QueryResult = collection.query( - query_texts=[params.query], - n_results=params.limit, - where=where, - where_document=params.where_document, - include=params.include, - ) + with suppress_logging( + "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR + ): + results: QueryResult = collection.query( + query_texts=[params.query], + n_results=params.limit, + where=where, + where_document=params.where_document, + include=params.include, + ) return _process_query_results( collection=collection, @@ -440,19 +446,22 @@ class ChromaDBClient(BaseClient): params = _extract_search_params(kwargs) collection = await self.client.get_collection( - name=params.collection_name, + name=_sanitize_collection_name(params.collection_name), embedding_function=self.embedding_function, ) where = params.where if params.where is not None else params.metadata_filter - results: QueryResult = await collection.query( - query_texts=[params.query], - n_results=params.limit, - where=where, - where_document=params.where_document, - include=params.include, - ) + with suppress_logging( + "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR + ): + results: QueryResult = await collection.query( + query_texts=[params.query], + n_results=params.limit, + where=where, + where_document=params.where_document, + include=params.include, + ) return _process_query_results( collection=collection, @@ -485,7 +494,7 @@ class ChromaDBClient(BaseClient): ) collection_name = kwargs["collection_name"] - self.client.delete_collection(name=collection_name) + self.client.delete_collection(name=_sanitize_collection_name(collection_name)) async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: """Delete a collection and all its data asynchronously. @@ -515,7 +524,9 @@ class ChromaDBClient(BaseClient): ) collection_name = kwargs["collection_name"] - await self.client.delete_collection(name=collection_name) + await self.client.delete_collection( + name=_sanitize_collection_name(collection_name) + ) def reset(self) -> None: """Reset the vector database by deleting all collections and data. diff --git a/src/crewai/rag/chromadb/config.py b/src/crewai/rag/chromadb/config.py index 1f536dcf6..33a3ed9ae 100644 --- a/src/crewai/rag/chromadb/config.py +++ b/src/crewai/rag/chromadb/config.py @@ -23,6 +23,12 @@ warnings.filterwarnings( module="pydantic._internal._generate_schema", ) +warnings.filterwarnings( + "ignore", + message=r".*'model_fields'.*is deprecated.*", + module=r"^chromadb(\.|$)", +) + def _default_settings() -> Settings: """Create default ChromaDB settings. diff --git a/src/crewai/rag/chromadb/constants.py b/src/crewai/rag/chromadb/constants.py index 8dba23fd0..8082356c6 100644 --- a/src/crewai/rag/chromadb/constants.py +++ b/src/crewai/rag/chromadb/constants.py @@ -1,10 +1,17 @@ """Constants for ChromaDB configuration.""" -import os +import re from typing import Final from crewai.utilities.paths import db_storage_path DEFAULT_TENANT: Final[str] = "default_tenant" DEFAULT_DATABASE: Final[str] = "default_database" -DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb") +DEFAULT_STORAGE_PATH: Final[str] = db_storage_path() + +MIN_COLLECTION_LENGTH: Final[int] = 3 +MAX_COLLECTION_LENGTH: Final[int] = 63 +DEFAULT_COLLECTION: Final[str] = "default_collection" + +INVALID_CHARS_PATTERN: Final[re.Pattern[str]] = re.compile(r"[^a-zA-Z0-9_-]") +IPV4_PATTERN: Final[re.Pattern[str]] = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$") diff --git a/src/crewai/rag/chromadb/factory.py b/src/crewai/rag/chromadb/factory.py index 4d3844910..60bf69131 100644 --- a/src/crewai/rag/chromadb/factory.py +++ b/src/crewai/rag/chromadb/factory.py @@ -1,6 +1,9 @@ """Factory functions for creating ChromaDB clients.""" -from chromadb import Client +import os +from hashlib import md5 +import portalocker +from chromadb import PersistentClient from crewai.rag.chromadb.config import ChromaDBConfig from crewai.rag.chromadb.client import ChromaDBClient @@ -14,11 +17,24 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient: Returns: Configured ChromaDBClient instance. + + Notes: + Need to update to use chromadb.Client to support more client types in the near future. """ + persist_dir = config.settings.persist_directory + lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest() + lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock") + + with portalocker.Lock(lockfile): + client = PersistentClient( + path=persist_dir, + settings=config.settings, + tenant=config.tenant, + database=config.database, + ) + return ChromaDBClient( - client=Client( - settings=config.settings, tenant=config.tenant, database=config.database - ), + client=client, embedding_function=config.embedding_function, ) diff --git a/src/crewai/rag/chromadb/utils.py b/src/crewai/rag/chromadb/utils.py index 1226be80c..23f66f4c0 100644 --- a/src/crewai/rag/chromadb/utils.py +++ b/src/crewai/rag/chromadb/utils.py @@ -12,6 +12,13 @@ from chromadb.api.types import ( ) from chromadb.api.models.AsyncCollection import AsyncCollection from chromadb.api.models.Collection import Collection +from crewai.rag.chromadb.constants import ( + DEFAULT_COLLECTION, + INVALID_CHARS_PATTERN, + IPV4_PATTERN, + MAX_COLLECTION_LENGTH, + MIN_COLLECTION_LENGTH, +) from crewai.rag.chromadb.types import ( ChromaDBClientType, ChromaDBCollectionSearchParams, @@ -216,3 +223,58 @@ def _process_query_results( distance_metric=distance_metric, score_threshold=params.score_threshold, ) + + +def _is_ipv4_pattern(name: str) -> bool: + """Check if a string matches an IPv4 address pattern. + + Args: + name: The string to check + + Returns: + True if the string matches an IPv4 pattern, False otherwise + """ + return bool(IPV4_PATTERN.match(name)) + + +def _sanitize_collection_name( + name: str | None, max_collection_length: int = MAX_COLLECTION_LENGTH +) -> str: + """Sanitize a collection name to meet ChromaDB requirements. + + Requirements: + 1. 3-63 characters long + 2. Starts and ends with alphanumeric character + 3. Contains only alphanumeric characters, underscores, or hyphens + 4. No consecutive periods + 5. Not a valid IPv4 address + + Args: + name: The original collection name to sanitize + max_collection_length: Maximum allowed length for the collection name + + Returns: + A sanitized collection name that meets ChromaDB requirements + """ + if not name: + return DEFAULT_COLLECTION + + if _is_ipv4_pattern(name): + name = f"ip_{name}" + + sanitized = INVALID_CHARS_PATTERN.sub("_", name) + + if not sanitized[0].isalnum(): + sanitized = "a" + sanitized + + if not sanitized[-1].isalnum(): + sanitized = sanitized[:-1] + "z" + + if len(sanitized) < MIN_COLLECTION_LENGTH: + sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized)) + if len(sanitized) > max_collection_length: + sanitized = sanitized[:max_collection_length] + if not sanitized[-1].isalnum(): + sanitized = sanitized[:-1] + "z" + + return sanitized diff --git a/src/crewai/utilities/logger_utils.py b/src/crewai/utilities/logger_utils.py new file mode 100644 index 000000000..b8af01289 --- /dev/null +++ b/src/crewai/utilities/logger_utils.py @@ -0,0 +1,38 @@ +"""Logging utility functions for CrewAI.""" + +import contextlib +import io +import logging +from collections.abc import Generator + + +@contextlib.contextmanager +def suppress_logging( + logger_name: str, + level: int | str, +) -> Generator[None, None, None]: + """Suppress verbose logging output from specified logger. + + Commonly used to suppress ChromaDB's verbose HNSW index logging. + + Args: + logger_name: The logger to suppress + level: The minimum level to allow (e.g., logging.ERROR or "ERROR") + + Yields: + None + + Example: + with suppress_logging("chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR): + collection.query(query_texts=["test"]) + """ + logger = logging.getLogger(logger_name) + original_level = logger.getEffectiveLevel() + logger.setLevel(level) + with ( + contextlib.redirect_stdout(io.StringIO()), + contextlib.redirect_stderr(io.StringIO()), + contextlib.suppress(UserWarning), + ): + yield + logger.setLevel(original_level) diff --git a/tests/rag/chromadb/test_client.py b/tests/rag/chromadb/test_client.py index 4006b981e..88742a711 100644 --- a/tests/rag/chromadb/test_client.py +++ b/tests/rag/chromadb/test_client.py @@ -253,8 +253,8 @@ class TestChromaDBClient: ) # Verify documents were added to collection - mock_collection.add.assert_called_once() - call_args = mock_collection.add.call_args + mock_collection.upsert.assert_called_once() + call_args = mock_collection.upsert.call_args assert len(call_args.kwargs["ids"]) == 1 assert call_args.kwargs["documents"] == ["Test document"] assert call_args.kwargs["metadatas"] == [{"source": "test"}] @@ -279,7 +279,7 @@ class TestChromaDBClient: client.add_documents(collection_name="test_collection", documents=documents) - mock_collection.add.assert_called_once_with( + mock_collection.upsert.assert_called_once_with( ids=["custom_id_1", "custom_id_2"], documents=["First document", "Second document"], metadatas=[{"source": "test1"}, {"source": "test2"}], @@ -319,8 +319,8 @@ class TestChromaDBClient: ) # Verify documents were added to collection - mock_collection.add.assert_called_once() - call_args = mock_collection.add.call_args + mock_collection.upsert.assert_called_once() + call_args = mock_collection.upsert.call_args assert len(call_args.kwargs["ids"]) == 1 assert call_args.kwargs["documents"] == ["Test document"] assert call_args.kwargs["metadatas"] == [{"source": "test"}] @@ -352,7 +352,7 @@ class TestChromaDBClient: collection_name="test_collection", documents=documents ) - mock_collection.add.assert_called_once_with( + mock_collection.upsert.assert_called_once_with( ids=["custom_id_1", "custom_id_2"], documents=["First document", "Second document"], metadatas=[{"source": "test1"}, {"source": "test2"}],