fix: achieve parity between rag package and current impl (#3418)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

- Sanitize ChromaDB collection names and use original dir naming
- Add persistent client with file locking to the ChromaDB factory
- Add upsert support to the ChromaDB client
- Suppress ChromaDB deprecation warnings for `model_fields`
- Extract `suppress_logging` into shared `logger_utils`
- Update tests to reflect upsert behavior
- Docs: add additional note
This commit is contained in:
Greyson LaLonde
2025-08-28 11:22:36 -04:00
committed by GitHub
parent 0f1b764c3e
commit ec1eff02a8
9 changed files with 186 additions and 78 deletions

View File

@@ -1,6 +1,4 @@
import contextlib
import hashlib
import io
import logging
import os
import shutil
@@ -20,23 +18,7 @@ from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path
from crewai.utilities.chromadb import create_persistent_client
@contextlib.contextmanager
def suppress_logging(
    logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
    level=logging.ERROR,
):
    """Temporarily raise a logger's threshold and discard stdout/stderr.

    While the ``with`` block runs, the named logger is set to ``level`` and
    anything written to ``sys.stdout`` / ``sys.stderr`` goes to throwaway
    in-memory buffers. A ``UserWarning`` raised as an exception inside the
    block is swallowed; any other exception propagates to the caller.

    Args:
        logger_name: Name of the logger whose level is changed.
        level: Level applied to the logger for the duration of the block.

    Yields:
        None
    """
    logger = logging.getLogger(logger_name)
    # Remember the logger's own level (possibly NOTSET), not its effective
    # level, so a logger that inherited its level keeps inheriting afterwards.
    original_level = logger.level
    logger.setLevel(level)
    try:
        with (
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
            contextlib.suppress(UserWarning),
        ):
            yield
    finally:
        # Restore the level even when the wrapped block raises; otherwise the
        # logger would stay at the suppressed level for the rest of the process.
        logger.setLevel(original_level)
from crewai.utilities.logger_utils import suppress_logging
class KnowledgeStorage(BaseKnowledgeStorage):
@@ -64,7 +46,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
with suppress_logging():
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
if self.collection:
fetched = self.collection.query(
query_texts=query,

View File

@@ -1,5 +1,3 @@
import contextlib
import io
import logging
import os
import shutil
@@ -12,26 +10,10 @@ from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
from crewai.utilities.logger_utils import suppress_logging
import warnings
@contextlib.contextmanager
def suppress_logging(
    logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
    level=logging.ERROR,
):
    """Temporarily raise a logger's threshold and discard stdout/stderr.

    While the ``with`` block runs, the named logger is set to ``level`` and
    anything written to ``sys.stdout`` / ``sys.stderr`` goes to throwaway
    in-memory buffers. A ``UserWarning`` raised as an exception inside the
    block is swallowed; any other exception propagates to the caller.

    Args:
        logger_name: Name of the logger whose level is changed.
        level: Level applied to the logger for the duration of the block.

    Yields:
        None
    """
    logger = logging.getLogger(logger_name)
    # Remember the logger's own level (possibly NOTSET), not its effective
    # level, so a logger that inherited its level keeps inheriting afterwards.
    original_level = logger.level
    logger.setLevel(level)
    try:
        with (
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
            contextlib.suppress(UserWarning),
        ):
            yield
    finally:
        # Restore the level even when the wrapped block raises; otherwise the
        # logger would stay at the suppressed level for the rest of the process.
        logger.setLevel(original_level)
class RAGStorage(BaseRAGStorage):
"""
Extends Storage to handle embeddings for memory entries, improving
@@ -122,7 +104,9 @@ class RAGStorage(BaseRAGStorage):
self._initialize_app()
try:
with suppress_logging():
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
response = self.collection.query(query_texts=query, n_results=limit)
results = []

View File

@@ -1,5 +1,6 @@
"""ChromaDB client implementation."""
import logging
from typing import Any
from chromadb.api.types import (
@@ -20,7 +21,9 @@ from crewai.rag.chromadb.utils import (
_is_sync_client,
_prepare_documents_for_chromadb,
_process_query_results,
_sanitize_collection_name,
)
from crewai.utilities.logger_utils import suppress_logging
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
@@ -97,7 +100,7 @@ class ChromaDBClient(BaseClient):
metadata["hnsw:space"] = "cosine"
self.client.create_collection(
name=kwargs["collection_name"],
name=_sanitize_collection_name(kwargs["collection_name"]),
configuration=kwargs.get("configuration"),
metadata=metadata,
embedding_function=kwargs.get(
@@ -154,7 +157,7 @@ class ChromaDBClient(BaseClient):
metadata["hnsw:space"] = "cosine"
await self.client.create_collection(
name=kwargs["collection_name"],
name=_sanitize_collection_name(kwargs["collection_name"]),
configuration=kwargs.get("configuration"),
metadata=metadata,
embedding_function=kwargs.get(
@@ -205,7 +208,7 @@ class ChromaDBClient(BaseClient):
metadata["hnsw:space"] = "cosine"
return self.client.get_or_create_collection(
name=kwargs["collection_name"],
name=_sanitize_collection_name(kwargs["collection_name"]),
configuration=kwargs.get("configuration"),
metadata=metadata,
embedding_function=kwargs.get(
@@ -258,7 +261,7 @@ class ChromaDBClient(BaseClient):
metadata["hnsw:space"] = "cosine"
return await self.client.get_or_create_collection(
name=kwargs["collection_name"],
name=_sanitize_collection_name(kwargs["collection_name"]),
configuration=kwargs.get("configuration"),
metadata=metadata,
embedding_function=kwargs.get(
@@ -298,12 +301,12 @@ class ChromaDBClient(BaseClient):
raise ValueError("Documents list cannot be empty")
collection = self.client.get_collection(
name=collection_name,
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
collection.add(
collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=prepared.metadatas,
@@ -340,11 +343,11 @@ class ChromaDBClient(BaseClient):
raise ValueError("Documents list cannot be empty")
collection = await self.client.get_collection(
name=collection_name,
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
await collection.add(
await collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=prepared.metadatas,
@@ -385,19 +388,22 @@ class ChromaDBClient(BaseClient):
params = _extract_search_params(kwargs)
collection = self.client.get_collection(
name=params.collection_name,
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
where = params.where if params.where is not None else params.metadata_filter
results: QueryResult = collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
)
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
results: QueryResult = collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
)
return _process_query_results(
collection=collection,
@@ -440,19 +446,22 @@ class ChromaDBClient(BaseClient):
params = _extract_search_params(kwargs)
collection = await self.client.get_collection(
name=params.collection_name,
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
where = params.where if params.where is not None else params.metadata_filter
results: QueryResult = await collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
)
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
results: QueryResult = await collection.query(
query_texts=[params.query],
n_results=params.limit,
where=where,
where_document=params.where_document,
include=params.include,
)
return _process_query_results(
collection=collection,
@@ -485,7 +494,7 @@ class ChromaDBClient(BaseClient):
)
collection_name = kwargs["collection_name"]
self.client.delete_collection(name=collection_name)
self.client.delete_collection(name=_sanitize_collection_name(collection_name))
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Delete a collection and all its data asynchronously.
@@ -515,7 +524,9 @@ class ChromaDBClient(BaseClient):
)
collection_name = kwargs["collection_name"]
await self.client.delete_collection(name=collection_name)
await self.client.delete_collection(
name=_sanitize_collection_name(collection_name)
)
def reset(self) -> None:
"""Reset the vector database by deleting all collections and data.

View File

@@ -23,6 +23,12 @@ warnings.filterwarnings(
module="pydantic._internal._generate_schema",
)
# Ignore warnings whose message mentions that 'model_fields' is deprecated,
# but only when they are attributed to the chromadb package or one of its
# submodules (the module regex anchors on "chromadb" followed by "." or end).
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
def _default_settings() -> Settings:
"""Create default ChromaDB settings.

View File

@@ -1,10 +1,17 @@
"""Constants for ChromaDB configuration."""
import os
import re
from typing import Final
from crewai.utilities.paths import db_storage_path
DEFAULT_TENANT: Final[str] = "default_tenant"
DEFAULT_DATABASE: Final[str] = "default_database"
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb")
DEFAULT_STORAGE_PATH: Final[str] = db_storage_path()
# Lower bound on collection-name length; the name sanitizer pads shorter names.
MIN_COLLECTION_LENGTH: Final[int] = 3
# Default upper bound on collection-name length; the sanitizer truncates to it.
MAX_COLLECTION_LENGTH: Final[int] = 63
# Collection name the sanitizer returns when it is given an empty or None name.
DEFAULT_COLLECTION: Final[str] = "default_collection"
# Matches any single character other than an ASCII letter, a digit, "_" or "-".
INVALID_CHARS_PATTERN: Final[re.Pattern[str]] = re.compile(r"[^a-zA-Z0-9_-]")
# Matches four dot-separated groups of 1-3 digits; octet values are not
# range-checked, so e.g. "999.999.999.999" also matches.
IPV4_PATTERN: Final[re.Pattern[str]] = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")

View File

@@ -1,6 +1,9 @@
"""Factory functions for creating ChromaDB clients."""
from chromadb import Client
import os
from hashlib import md5
import portalocker
from chromadb import PersistentClient
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.client import ChromaDBClient
@@ -14,11 +17,24 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
Returns:
Configured ChromaDBClient instance.
Notes:
Need to update to use chromadb.Client to support more client types in the near future.
"""
persist_dir = config.settings.persist_directory
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
with portalocker.Lock(lockfile):
client = PersistentClient(
path=persist_dir,
settings=config.settings,
tenant=config.tenant,
database=config.database,
)
return ChromaDBClient(
client=Client(
settings=config.settings, tenant=config.tenant, database=config.database
),
client=client,
embedding_function=config.embedding_function,
)

View File

@@ -12,6 +12,13 @@ from chromadb.api.types import (
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from crewai.rag.chromadb.constants import (
DEFAULT_COLLECTION,
INVALID_CHARS_PATTERN,
IPV4_PATTERN,
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
)
from crewai.rag.chromadb.types import (
ChromaDBClientType,
ChromaDBCollectionSearchParams,
@@ -216,3 +223,58 @@ def _process_query_results(
distance_metric=distance_metric,
score_threshold=params.score_threshold,
)
def _is_ipv4_pattern(name: str) -> bool:
    """Report whether ``name`` has the shape of a dotted-quad IPv4 address.

    Args:
        name: Candidate string to test.

    Returns:
        True when ``IPV4_PATTERN`` matches ``name``, False otherwise.
    """
    return IPV4_PATTERN.match(name) is not None
def _sanitize_collection_name(
    name: str | None, max_collection_length: int = MAX_COLLECTION_LENGTH
) -> str:
    """Turn an arbitrary string into a collection name ChromaDB will accept.

    The result is built in this order:
        1. An empty or ``None`` name becomes ``DEFAULT_COLLECTION``.
        2. A name shaped like an IPv4 address is prefixed with ``ip_``.
        3. Every character matched by ``INVALID_CHARS_PATTERN`` becomes ``_``.
        4. A non-alphanumeric first character gets an ``a`` put in front of
           it; a non-alphanumeric last character is replaced by ``z``.
        5. Names shorter than ``MIN_COLLECTION_LENGTH`` are padded with ``x``.
        6. Names longer than ``max_collection_length`` are cut to that length,
           and the new last character is replaced by ``z`` if it is not
           alphanumeric.

    Args:
        name: The original collection name to sanitize.
        max_collection_length: Maximum allowed length of the returned name.

    Returns:
        The sanitized collection name.
    """
    if not name:
        return DEFAULT_COLLECTION

    candidate = f"ip_{name}" if _is_ipv4_pattern(name) else name
    candidate = INVALID_CHARS_PATTERN.sub("_", candidate)

    # Both ends must be alphanumeric: prepend at the front, overwrite at the back.
    if not candidate[0].isalnum():
        candidate = f"a{candidate}"
    if not candidate[-1].isalnum():
        candidate = f"{candidate[:-1]}z"

    # ljust only pads when the name is shorter than the minimum length.
    candidate = candidate.ljust(MIN_COLLECTION_LENGTH, "x")

    if len(candidate) > max_collection_length:
        candidate = candidate[:max_collection_length]
        # Truncation can expose a separator as the new final character.
        if not candidate[-1].isalnum():
            candidate = f"{candidate[:-1]}z"

    return candidate

View File

@@ -0,0 +1,38 @@
"""Logging utility functions for CrewAI."""
import contextlib
import io
import logging
from collections.abc import Generator
@contextlib.contextmanager
def suppress_logging(
logger_name: str,
level: int | str,
) -> Generator[None, None, None]:
"""Suppress verbose logging output from specified logger.
Commonly used to suppress ChromaDB's verbose HNSW index logging.
Args:
logger_name: The logger to suppress
level: The minimum level to allow (e.g., logging.ERROR or "ERROR")
Yields:
None
Example:
with suppress_logging("chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR):
collection.query(query_texts=["test"])
"""
logger = logging.getLogger(logger_name)
original_level = logger.getEffectiveLevel()
logger.setLevel(level)
with (
contextlib.redirect_stdout(io.StringIO()),
contextlib.redirect_stderr(io.StringIO()),
contextlib.suppress(UserWarning),
):
yield
logger.setLevel(original_level)

View File

@@ -253,8 +253,8 @@ class TestChromaDBClient:
)
# Verify documents were added to collection
mock_collection.add.assert_called_once()
call_args = mock_collection.add.call_args
mock_collection.upsert.assert_called_once()
call_args = mock_collection.upsert.call_args
assert len(call_args.kwargs["ids"]) == 1
assert call_args.kwargs["documents"] == ["Test document"]
assert call_args.kwargs["metadatas"] == [{"source": "test"}]
@@ -279,7 +279,7 @@ class TestChromaDBClient:
client.add_documents(collection_name="test_collection", documents=documents)
mock_collection.add.assert_called_once_with(
mock_collection.upsert.assert_called_once_with(
ids=["custom_id_1", "custom_id_2"],
documents=["First document", "Second document"],
metadatas=[{"source": "test1"}, {"source": "test2"}],
@@ -319,8 +319,8 @@ class TestChromaDBClient:
)
# Verify documents were added to collection
mock_collection.add.assert_called_once()
call_args = mock_collection.add.call_args
mock_collection.upsert.assert_called_once()
call_args = mock_collection.upsert.call_args
assert len(call_args.kwargs["ids"]) == 1
assert call_args.kwargs["documents"] == ["Test document"]
assert call_args.kwargs["metadatas"] == [{"source": "test"}]
@@ -352,7 +352,7 @@ class TestChromaDBClient:
collection_name="test_collection", documents=documents
)
mock_collection.add.assert_called_once_with(
mock_collection.upsert.assert_called_once_with(
ids=["custom_id_1", "custom_id_2"],
documents=["First document", "Second document"],
metadatas=[{"source": "test1"}, {"source": "test2"}],