fix: achieve parity between rag package and current impl (#3418)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

- Sanitize ChromaDB collection names and use original dir naming
- Add persistent client with file locking to the ChromaDB factory
- Add upsert support to the ChromaDB client
- Suppress ChromaDB deprecation warnings for `model_fields`
- Extract `suppress_logging` into shared `logger_utils`
- Update tests to reflect upsert behavior
- Docs: add additional note
This commit is contained in:
Greyson LaLonde
2025-08-28 11:22:36 -04:00
committed by GitHub
parent 0f1b764c3e
commit ec1eff02a8
9 changed files with 186 additions and 78 deletions

View File

@@ -1,6 +1,4 @@
import contextlib
import hashlib import hashlib
import io
import logging import logging
import os import os
import shutil import shutil
@@ -20,23 +18,7 @@ from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
from crewai.utilities.chromadb import create_persistent_client from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.logger_utils import suppress_logging
@contextlib.contextmanager
def suppress_logging(
logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
level=logging.ERROR,
):
logger = logging.getLogger(logger_name)
original_level = logger.getEffectiveLevel()
logger.setLevel(level)
with (
contextlib.redirect_stdout(io.StringIO()),
contextlib.redirect_stderr(io.StringIO()),
contextlib.suppress(UserWarning),
):
yield
logger.setLevel(original_level)
class KnowledgeStorage(BaseKnowledgeStorage): class KnowledgeStorage(BaseKnowledgeStorage):
@@ -64,7 +46,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
filter: Optional[dict] = None, filter: Optional[dict] = None,
score_threshold: float = 0.35, score_threshold: float = 0.35,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
with suppress_logging(): with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
if self.collection: if self.collection:
fetched = self.collection.query( fetched = self.collection.query(
query_texts=query, query_texts=query,

View File

@@ -1,5 +1,3 @@
import contextlib
import io
import logging import logging
import os import os
import shutil import shutil
@@ -12,26 +10,10 @@ from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import create_persistent_client from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
from crewai.utilities.logger_utils import suppress_logging
import warnings import warnings
@contextlib.contextmanager
def suppress_logging(
logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
level=logging.ERROR,
):
logger = logging.getLogger(logger_name)
original_level = logger.getEffectiveLevel()
logger.setLevel(level)
with (
contextlib.redirect_stdout(io.StringIO()),
contextlib.redirect_stderr(io.StringIO()),
contextlib.suppress(UserWarning),
):
yield
logger.setLevel(original_level)
class RAGStorage(BaseRAGStorage): class RAGStorage(BaseRAGStorage):
""" """
Extends Storage to handle embeddings for memory entries, improving Extends Storage to handle embeddings for memory entries, improving
@@ -122,7 +104,9 @@ class RAGStorage(BaseRAGStorage):
self._initialize_app() self._initialize_app()
try: try:
with suppress_logging(): with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
response = self.collection.query(query_texts=query, n_results=limit) response = self.collection.query(query_texts=query, n_results=limit)
results = [] results = []

View File

@@ -1,5 +1,6 @@
"""ChromaDB client implementation.""" """ChromaDB client implementation."""
import logging
from typing import Any from typing import Any
from chromadb.api.types import ( from chromadb.api.types import (
@@ -20,7 +21,9 @@ from crewai.rag.chromadb.utils import (
_is_sync_client, _is_sync_client,
_prepare_documents_for_chromadb, _prepare_documents_for_chromadb,
_process_query_results, _process_query_results,
_sanitize_collection_name,
) )
from crewai.utilities.logger_utils import suppress_logging
from crewai.rag.core.base_client import ( from crewai.rag.core.base_client import (
BaseClient, BaseClient,
BaseCollectionParams, BaseCollectionParams,
@@ -97,7 +100,7 @@ class ChromaDBClient(BaseClient):
metadata["hnsw:space"] = "cosine" metadata["hnsw:space"] = "cosine"
self.client.create_collection( self.client.create_collection(
name=kwargs["collection_name"], name=_sanitize_collection_name(kwargs["collection_name"]),
configuration=kwargs.get("configuration"), configuration=kwargs.get("configuration"),
metadata=metadata, metadata=metadata,
embedding_function=kwargs.get( embedding_function=kwargs.get(
@@ -154,7 +157,7 @@ class ChromaDBClient(BaseClient):
metadata["hnsw:space"] = "cosine" metadata["hnsw:space"] = "cosine"
await self.client.create_collection( await self.client.create_collection(
name=kwargs["collection_name"], name=_sanitize_collection_name(kwargs["collection_name"]),
configuration=kwargs.get("configuration"), configuration=kwargs.get("configuration"),
metadata=metadata, metadata=metadata,
embedding_function=kwargs.get( embedding_function=kwargs.get(
@@ -205,7 +208,7 @@ class ChromaDBClient(BaseClient):
metadata["hnsw:space"] = "cosine" metadata["hnsw:space"] = "cosine"
return self.client.get_or_create_collection( return self.client.get_or_create_collection(
name=kwargs["collection_name"], name=_sanitize_collection_name(kwargs["collection_name"]),
configuration=kwargs.get("configuration"), configuration=kwargs.get("configuration"),
metadata=metadata, metadata=metadata,
embedding_function=kwargs.get( embedding_function=kwargs.get(
@@ -258,7 +261,7 @@ class ChromaDBClient(BaseClient):
metadata["hnsw:space"] = "cosine" metadata["hnsw:space"] = "cosine"
return await self.client.get_or_create_collection( return await self.client.get_or_create_collection(
name=kwargs["collection_name"], name=_sanitize_collection_name(kwargs["collection_name"]),
configuration=kwargs.get("configuration"), configuration=kwargs.get("configuration"),
metadata=metadata, metadata=metadata,
embedding_function=kwargs.get( embedding_function=kwargs.get(
@@ -298,12 +301,12 @@ class ChromaDBClient(BaseClient):
raise ValueError("Documents list cannot be empty") raise ValueError("Documents list cannot be empty")
collection = self.client.get_collection( collection = self.client.get_collection(
name=collection_name, name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function, embedding_function=self.embedding_function,
) )
prepared = _prepare_documents_for_chromadb(documents) prepared = _prepare_documents_for_chromadb(documents)
collection.add( collection.upsert(
ids=prepared.ids, ids=prepared.ids,
documents=prepared.texts, documents=prepared.texts,
metadatas=prepared.metadatas, metadatas=prepared.metadatas,
@@ -340,11 +343,11 @@ class ChromaDBClient(BaseClient):
raise ValueError("Documents list cannot be empty") raise ValueError("Documents list cannot be empty")
collection = await self.client.get_collection( collection = await self.client.get_collection(
name=collection_name, name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function, embedding_function=self.embedding_function,
) )
prepared = _prepare_documents_for_chromadb(documents) prepared = _prepare_documents_for_chromadb(documents)
await collection.add( await collection.upsert(
ids=prepared.ids, ids=prepared.ids,
documents=prepared.texts, documents=prepared.texts,
metadatas=prepared.metadatas, metadatas=prepared.metadatas,
@@ -385,19 +388,22 @@ class ChromaDBClient(BaseClient):
params = _extract_search_params(kwargs) params = _extract_search_params(kwargs)
collection = self.client.get_collection( collection = self.client.get_collection(
name=params.collection_name, name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function, embedding_function=self.embedding_function,
) )
where = params.where if params.where is not None else params.metadata_filter where = params.where if params.where is not None else params.metadata_filter
results: QueryResult = collection.query( with suppress_logging(
query_texts=[params.query], "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
n_results=params.limit, ):
where=where, results: QueryResult = collection.query(
where_document=params.where_document, query_texts=[params.query],
include=params.include, n_results=params.limit,
) where=where,
where_document=params.where_document,
include=params.include,
)
return _process_query_results( return _process_query_results(
collection=collection, collection=collection,
@@ -440,19 +446,22 @@ class ChromaDBClient(BaseClient):
params = _extract_search_params(kwargs) params = _extract_search_params(kwargs)
collection = await self.client.get_collection( collection = await self.client.get_collection(
name=params.collection_name, name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function, embedding_function=self.embedding_function,
) )
where = params.where if params.where is not None else params.metadata_filter where = params.where if params.where is not None else params.metadata_filter
results: QueryResult = await collection.query( with suppress_logging(
query_texts=[params.query], "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
n_results=params.limit, ):
where=where, results: QueryResult = await collection.query(
where_document=params.where_document, query_texts=[params.query],
include=params.include, n_results=params.limit,
) where=where,
where_document=params.where_document,
include=params.include,
)
return _process_query_results( return _process_query_results(
collection=collection, collection=collection,
@@ -485,7 +494,7 @@ class ChromaDBClient(BaseClient):
) )
collection_name = kwargs["collection_name"] collection_name = kwargs["collection_name"]
self.client.delete_collection(name=collection_name) self.client.delete_collection(name=_sanitize_collection_name(collection_name))
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data asynchronously. """Delete a collection and all its data asynchronously.
@@ -515,7 +524,9 @@ class ChromaDBClient(BaseClient):
) )
collection_name = kwargs["collection_name"] collection_name = kwargs["collection_name"]
await self.client.delete_collection(name=collection_name) await self.client.delete_collection(
name=_sanitize_collection_name(collection_name)
)
def reset(self) -> None: def reset(self) -> None:
"""Reset the vector database by deleting all collections and data. """Reset the vector database by deleting all collections and data.

View File

@@ -23,6 +23,12 @@ warnings.filterwarnings(
module="pydantic._internal._generate_schema", module="pydantic._internal._generate_schema",
) )
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
def _default_settings() -> Settings: def _default_settings() -> Settings:
"""Create default ChromaDB settings. """Create default ChromaDB settings.

View File

@@ -1,10 +1,17 @@
"""Constants for ChromaDB configuration.""" """Constants for ChromaDB configuration."""
import os import re
from typing import Final from typing import Final
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
DEFAULT_TENANT: Final[str] = "default_tenant" DEFAULT_TENANT: Final[str] = "default_tenant"
DEFAULT_DATABASE: Final[str] = "default_database" DEFAULT_DATABASE: Final[str] = "default_database"
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb") DEFAULT_STORAGE_PATH: Final[str] = db_storage_path()
MIN_COLLECTION_LENGTH: Final[int] = 3
MAX_COLLECTION_LENGTH: Final[int] = 63
DEFAULT_COLLECTION: Final[str] = "default_collection"
INVALID_CHARS_PATTERN: Final[re.Pattern[str]] = re.compile(r"[^a-zA-Z0-9_-]")
IPV4_PATTERN: Final[re.Pattern[str]] = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")

View File

@@ -1,6 +1,9 @@
"""Factory functions for creating ChromaDB clients.""" """Factory functions for creating ChromaDB clients."""
from chromadb import Client import os
from hashlib import md5
import portalocker
from chromadb import PersistentClient
from crewai.rag.chromadb.config import ChromaDBConfig from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.client import ChromaDBClient from crewai.rag.chromadb.client import ChromaDBClient
@@ -14,11 +17,24 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
Returns: Returns:
Configured ChromaDBClient instance. Configured ChromaDBClient instance.
Notes:
Need to update to use chromadb.Client to support more client types in the near future.
""" """
persist_dir = config.settings.persist_directory
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
with portalocker.Lock(lockfile):
client = PersistentClient(
path=persist_dir,
settings=config.settings,
tenant=config.tenant,
database=config.database,
)
return ChromaDBClient( return ChromaDBClient(
client=Client( client=client,
settings=config.settings, tenant=config.tenant, database=config.database
),
embedding_function=config.embedding_function, embedding_function=config.embedding_function,
) )

View File

@@ -12,6 +12,13 @@ from chromadb.api.types import (
) )
from chromadb.api.models.AsyncCollection import AsyncCollection from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection from chromadb.api.models.Collection import Collection
from crewai.rag.chromadb.constants import (
DEFAULT_COLLECTION,
INVALID_CHARS_PATTERN,
IPV4_PATTERN,
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
)
from crewai.rag.chromadb.types import ( from crewai.rag.chromadb.types import (
ChromaDBClientType, ChromaDBClientType,
ChromaDBCollectionSearchParams, ChromaDBCollectionSearchParams,
@@ -216,3 +223,58 @@ def _process_query_results(
distance_metric=distance_metric, distance_metric=distance_metric,
score_threshold=params.score_threshold, score_threshold=params.score_threshold,
) )
def _is_ipv4_pattern(name: str) -> bool:
"""Check if a string matches an IPv4 address pattern.
Args:
name: The string to check
Returns:
True if the string matches an IPv4 pattern, False otherwise
"""
return bool(IPV4_PATTERN.match(name))
def _sanitize_collection_name(
name: str | None, max_collection_length: int = MAX_COLLECTION_LENGTH
) -> str:
"""Sanitize a collection name to meet ChromaDB requirements.
Requirements:
1. 3-63 characters long
2. Starts and ends with alphanumeric character
3. Contains only alphanumeric characters, underscores, or hyphens
4. No consecutive periods
5. Not a valid IPv4 address
Args:
name: The original collection name to sanitize
max_collection_length: Maximum allowed length for the collection name
Returns:
A sanitized collection name that meets ChromaDB requirements
"""
if not name:
return DEFAULT_COLLECTION
if _is_ipv4_pattern(name):
name = f"ip_{name}"
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
if not sanitized[0].isalnum():
sanitized = "a" + sanitized
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
if len(sanitized) < MIN_COLLECTION_LENGTH:
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
if len(sanitized) > max_collection_length:
sanitized = sanitized[:max_collection_length]
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
return sanitized

View File

@@ -0,0 +1,38 @@
"""Logging utility functions for CrewAI."""
import contextlib
import io
import logging
from collections.abc import Generator
@contextlib.contextmanager
def suppress_logging(
logger_name: str,
level: int | str,
) -> Generator[None, None, None]:
"""Suppress verbose logging output from specified logger.
Commonly used to suppress ChromaDB's verbose HNSW index logging.
Args:
logger_name: The logger to suppress
level: The minimum level to allow (e.g., logging.ERROR or "ERROR")
Yields:
None
Example:
with suppress_logging("chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR):
collection.query(query_texts=["test"])
"""
logger = logging.getLogger(logger_name)
original_level = logger.getEffectiveLevel()
logger.setLevel(level)
with (
contextlib.redirect_stdout(io.StringIO()),
contextlib.redirect_stderr(io.StringIO()),
contextlib.suppress(UserWarning),
):
yield
logger.setLevel(original_level)

View File

@@ -253,8 +253,8 @@ class TestChromaDBClient:
) )
# Verify documents were added to collection # Verify documents were added to collection
mock_collection.add.assert_called_once() mock_collection.upsert.assert_called_once()
call_args = mock_collection.add.call_args call_args = mock_collection.upsert.call_args
assert len(call_args.kwargs["ids"]) == 1 assert len(call_args.kwargs["ids"]) == 1
assert call_args.kwargs["documents"] == ["Test document"] assert call_args.kwargs["documents"] == ["Test document"]
assert call_args.kwargs["metadatas"] == [{"source": "test"}] assert call_args.kwargs["metadatas"] == [{"source": "test"}]
@@ -279,7 +279,7 @@ class TestChromaDBClient:
client.add_documents(collection_name="test_collection", documents=documents) client.add_documents(collection_name="test_collection", documents=documents)
mock_collection.add.assert_called_once_with( mock_collection.upsert.assert_called_once_with(
ids=["custom_id_1", "custom_id_2"], ids=["custom_id_1", "custom_id_2"],
documents=["First document", "Second document"], documents=["First document", "Second document"],
metadatas=[{"source": "test1"}, {"source": "test2"}], metadatas=[{"source": "test1"}, {"source": "test2"}],
@@ -319,8 +319,8 @@ class TestChromaDBClient:
) )
# Verify documents were added to collection # Verify documents were added to collection
mock_collection.add.assert_called_once() mock_collection.upsert.assert_called_once()
call_args = mock_collection.add.call_args call_args = mock_collection.upsert.call_args
assert len(call_args.kwargs["ids"]) == 1 assert len(call_args.kwargs["ids"]) == 1
assert call_args.kwargs["documents"] == ["Test document"] assert call_args.kwargs["documents"] == ["Test document"]
assert call_args.kwargs["metadatas"] == [{"source": "test"}] assert call_args.kwargs["metadatas"] == [{"source": "test"}]
@@ -352,7 +352,7 @@ class TestChromaDBClient:
collection_name="test_collection", documents=documents collection_name="test_collection", documents=documents
) )
mock_collection.add.assert_called_once_with( mock_collection.upsert.assert_called_once_with(
ids=["custom_id_1", "custom_id_2"], ids=["custom_id_1", "custom_id_2"],
documents=["First document", "Second document"], documents=["First document", "Second document"],
metadatas=[{"source": "test1"}, {"source": "test2"}], metadatas=[{"source": "test1"}, {"source": "test2"}],