mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
fix: achieve parity between rag package and current impl (#3418)
- Sanitize ChromaDB collection names and use original dir naming - Add persistent client with file locking to the ChromaDB factory - Add upsert support to the ChromaDB client - Suppress ChromaDB deprecation warnings for `model_fields` - Extract `suppress_logging` into shared `logger_utils` - Update tests to reflect upsert behavior - Docs: add additional note
This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
import contextlib
|
||||
import hashlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
@@ -20,23 +18,7 @@ from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(
|
||||
logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
|
||||
level=logging.ERROR,
|
||||
):
|
||||
logger = logging.getLogger(logger_name)
|
||||
original_level = logger.getEffectiveLevel()
|
||||
logger.setLevel(level)
|
||||
with (
|
||||
contextlib.redirect_stdout(io.StringIO()),
|
||||
contextlib.redirect_stderr(io.StringIO()),
|
||||
contextlib.suppress(UserWarning),
|
||||
):
|
||||
yield
|
||||
logger.setLevel(original_level)
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
|
||||
|
||||
class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
@@ -64,7 +46,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Dict[str, Any]]:
|
||||
with suppress_logging():
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
if self.collection:
|
||||
fetched = self.collection.query(
|
||||
query_texts=query,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
@@ -12,26 +10,10 @@ from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
import warnings
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(
|
||||
logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
|
||||
level=logging.ERROR,
|
||||
):
|
||||
logger = logging.getLogger(logger_name)
|
||||
original_level = logger.getEffectiveLevel()
|
||||
logger.setLevel(level)
|
||||
with (
|
||||
contextlib.redirect_stdout(io.StringIO()),
|
||||
contextlib.redirect_stderr(io.StringIO()),
|
||||
contextlib.suppress(UserWarning),
|
||||
):
|
||||
yield
|
||||
logger.setLevel(original_level)
|
||||
|
||||
|
||||
class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
@@ -122,7 +104,9 @@ class RAGStorage(BaseRAGStorage):
|
||||
self._initialize_app()
|
||||
|
||||
try:
|
||||
with suppress_logging():
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
response = self.collection.query(query_texts=query, n_results=limit)
|
||||
|
||||
results = []
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""ChromaDB client implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from chromadb.api.types import (
|
||||
@@ -20,7 +21,9 @@ from crewai.rag.chromadb.utils import (
|
||||
_is_sync_client,
|
||||
_prepare_documents_for_chromadb,
|
||||
_process_query_results,
|
||||
_sanitize_collection_name,
|
||||
)
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
@@ -97,7 +100,7 @@ class ChromaDBClient(BaseClient):
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
self.client.create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
@@ -154,7 +157,7 @@ class ChromaDBClient(BaseClient):
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
await self.client.create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
@@ -205,7 +208,7 @@ class ChromaDBClient(BaseClient):
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
return self.client.get_or_create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
@@ -258,7 +261,7 @@ class ChromaDBClient(BaseClient):
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
return await self.client.get_or_create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
@@ -298,12 +301,12 @@ class ChromaDBClient(BaseClient):
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = self.client.get_collection(
|
||||
name=collection_name,
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
collection.add(
|
||||
collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=prepared.metadatas,
|
||||
@@ -340,11 +343,11 @@ class ChromaDBClient(BaseClient):
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
name=collection_name,
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
await collection.add(
|
||||
await collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=prepared.metadatas,
|
||||
@@ -385,19 +388,22 @@ class ChromaDBClient(BaseClient):
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = self.client.get_collection(
|
||||
name=params.collection_name,
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
results: QueryResult = collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
results: QueryResult = collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
@@ -440,19 +446,22 @@ class ChromaDBClient(BaseClient):
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
name=params.collection_name,
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
results: QueryResult = await collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
results: QueryResult = await collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
@@ -485,7 +494,7 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
self.client.delete_collection(name=collection_name)
|
||||
self.client.delete_collection(name=_sanitize_collection_name(collection_name))
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
@@ -515,7 +524,9 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
await self.client.delete_collection(name=collection_name)
|
||||
await self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
@@ -23,6 +23,12 @@ warnings.filterwarnings(
|
||||
module="pydantic._internal._generate_schema",
|
||||
)
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r".*'model_fields'.*is deprecated.*",
|
||||
module=r"^chromadb(\.|$)",
|
||||
)
|
||||
|
||||
|
||||
def _default_settings() -> Settings:
|
||||
"""Create default ChromaDB settings.
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
"""Constants for ChromaDB configuration."""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Final
|
||||
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
DEFAULT_TENANT: Final[str] = "default_tenant"
|
||||
DEFAULT_DATABASE: Final[str] = "default_database"
|
||||
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb")
|
||||
DEFAULT_STORAGE_PATH: Final[str] = db_storage_path()
|
||||
|
||||
MIN_COLLECTION_LENGTH: Final[int] = 3
|
||||
MAX_COLLECTION_LENGTH: Final[int] = 63
|
||||
DEFAULT_COLLECTION: Final[str] = "default_collection"
|
||||
|
||||
INVALID_CHARS_PATTERN: Final[re.Pattern[str]] = re.compile(r"[^a-zA-Z0-9_-]")
|
||||
IPV4_PATTERN: Final[re.Pattern[str]] = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
"""Factory functions for creating ChromaDB clients."""
|
||||
|
||||
from chromadb import Client
|
||||
import os
|
||||
from hashlib import md5
|
||||
import portalocker
|
||||
from chromadb import PersistentClient
|
||||
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
@@ -14,11 +17,24 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
|
||||
Returns:
|
||||
Configured ChromaDBClient instance.
|
||||
|
||||
Notes:
|
||||
Need to update to use chromadb.Client to support more client types in the near future.
|
||||
"""
|
||||
|
||||
persist_dir = config.settings.persist_directory
|
||||
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
|
||||
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
|
||||
|
||||
with portalocker.Lock(lockfile):
|
||||
client = PersistentClient(
|
||||
path=persist_dir,
|
||||
settings=config.settings,
|
||||
tenant=config.tenant,
|
||||
database=config.database,
|
||||
)
|
||||
|
||||
return ChromaDBClient(
|
||||
client=Client(
|
||||
settings=config.settings, tenant=config.tenant, database=config.database
|
||||
),
|
||||
client=client,
|
||||
embedding_function=config.embedding_function,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,13 @@ from chromadb.api.types import (
|
||||
)
|
||||
from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from crewai.rag.chromadb.constants import (
|
||||
DEFAULT_COLLECTION,
|
||||
INVALID_CHARS_PATTERN,
|
||||
IPV4_PATTERN,
|
||||
MAX_COLLECTION_LENGTH,
|
||||
MIN_COLLECTION_LENGTH,
|
||||
)
|
||||
from crewai.rag.chromadb.types import (
|
||||
ChromaDBClientType,
|
||||
ChromaDBCollectionSearchParams,
|
||||
@@ -216,3 +223,58 @@ def _process_query_results(
|
||||
distance_metric=distance_metric,
|
||||
score_threshold=params.score_threshold,
|
||||
)
|
||||
|
||||
|
||||
def _is_ipv4_pattern(name: str) -> bool:
|
||||
"""Check if a string matches an IPv4 address pattern.
|
||||
|
||||
Args:
|
||||
name: The string to check
|
||||
|
||||
Returns:
|
||||
True if the string matches an IPv4 pattern, False otherwise
|
||||
"""
|
||||
return bool(IPV4_PATTERN.match(name))
|
||||
|
||||
|
||||
def _sanitize_collection_name(
|
||||
name: str | None, max_collection_length: int = MAX_COLLECTION_LENGTH
|
||||
) -> str:
|
||||
"""Sanitize a collection name to meet ChromaDB requirements.
|
||||
|
||||
Requirements:
|
||||
1. 3-63 characters long
|
||||
2. Starts and ends with alphanumeric character
|
||||
3. Contains only alphanumeric characters, underscores, or hyphens
|
||||
4. No consecutive periods
|
||||
5. Not a valid IPv4 address
|
||||
|
||||
Args:
|
||||
name: The original collection name to sanitize
|
||||
max_collection_length: Maximum allowed length for the collection name
|
||||
|
||||
Returns:
|
||||
A sanitized collection name that meets ChromaDB requirements
|
||||
"""
|
||||
if not name:
|
||||
return DEFAULT_COLLECTION
|
||||
|
||||
if _is_ipv4_pattern(name):
|
||||
name = f"ip_{name}"
|
||||
|
||||
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
|
||||
|
||||
if not sanitized[0].isalnum():
|
||||
sanitized = "a" + sanitized
|
||||
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
if len(sanitized) < MIN_COLLECTION_LENGTH:
|
||||
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
|
||||
if len(sanitized) > max_collection_length:
|
||||
sanitized = sanitized[:max_collection_length]
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
return sanitized
|
||||
|
||||
38
src/crewai/utilities/logger_utils.py
Normal file
38
src/crewai/utilities/logger_utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Logging utility functions for CrewAI."""
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(
|
||||
logger_name: str,
|
||||
level: int | str,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Suppress verbose logging output from specified logger.
|
||||
|
||||
Commonly used to suppress ChromaDB's verbose HNSW index logging.
|
||||
|
||||
Args:
|
||||
logger_name: The logger to suppress
|
||||
level: The minimum level to allow (e.g., logging.ERROR or "ERROR")
|
||||
|
||||
Yields:
|
||||
None
|
||||
|
||||
Example:
|
||||
with suppress_logging("chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR):
|
||||
collection.query(query_texts=["test"])
|
||||
"""
|
||||
logger = logging.getLogger(logger_name)
|
||||
original_level = logger.getEffectiveLevel()
|
||||
logger.setLevel(level)
|
||||
with (
|
||||
contextlib.redirect_stdout(io.StringIO()),
|
||||
contextlib.redirect_stderr(io.StringIO()),
|
||||
contextlib.suppress(UserWarning),
|
||||
):
|
||||
yield
|
||||
logger.setLevel(original_level)
|
||||
@@ -253,8 +253,8 @@ class TestChromaDBClient:
|
||||
)
|
||||
|
||||
# Verify documents were added to collection
|
||||
mock_collection.add.assert_called_once()
|
||||
call_args = mock_collection.add.call_args
|
||||
mock_collection.upsert.assert_called_once()
|
||||
call_args = mock_collection.upsert.call_args
|
||||
assert len(call_args.kwargs["ids"]) == 1
|
||||
assert call_args.kwargs["documents"] == ["Test document"]
|
||||
assert call_args.kwargs["metadatas"] == [{"source": "test"}]
|
||||
@@ -279,7 +279,7 @@ class TestChromaDBClient:
|
||||
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
mock_collection.add.assert_called_once_with(
|
||||
mock_collection.upsert.assert_called_once_with(
|
||||
ids=["custom_id_1", "custom_id_2"],
|
||||
documents=["First document", "Second document"],
|
||||
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
||||
@@ -319,8 +319,8 @@ class TestChromaDBClient:
|
||||
)
|
||||
|
||||
# Verify documents were added to collection
|
||||
mock_collection.add.assert_called_once()
|
||||
call_args = mock_collection.add.call_args
|
||||
mock_collection.upsert.assert_called_once()
|
||||
call_args = mock_collection.upsert.call_args
|
||||
assert len(call_args.kwargs["ids"]) == 1
|
||||
assert call_args.kwargs["documents"] == ["Test document"]
|
||||
assert call_args.kwargs["metadatas"] == [{"source": "test"}]
|
||||
@@ -352,7 +352,7 @@ class TestChromaDBClient:
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
mock_collection.add.assert_called_once_with(
|
||||
mock_collection.upsert.assert_called_once_with(
|
||||
ids=["custom_id_1", "custom_id_2"],
|
||||
documents=["First document", "Second document"],
|
||||
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
||||
|
||||
Reference in New Issue
Block a user