mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
fix: achieve parity between rag package and current impl (#3418)
- Sanitize ChromaDB collection names and use original dir naming - Add persistent client with file locking to the ChromaDB factory - Add upsert support to the ChromaDB client - Suppress ChromaDB deprecation warnings for `model_fields` - Extract `suppress_logging` into shared `logger_utils` - Update tests to reflect upsert behavior - Docs: add additional note
This commit is contained in:
@@ -1,6 +1,4 @@
|
|||||||
import contextlib
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
@@ -20,23 +18,7 @@ from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
|||||||
from crewai.utilities.logger import Logger
|
from crewai.utilities.logger import Logger
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
from crewai.utilities.chromadb import create_persistent_client
|
from crewai.utilities.chromadb import create_persistent_client
|
||||||
|
from crewai.utilities.logger_utils import suppress_logging
|
||||||
|
|
||||||
@contextlib.contextmanager
def suppress_logging(
    logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
    level=logging.ERROR,
):
    """Temporarily raise a logger's threshold and silence stdout/stderr.

    Args:
        logger_name: Name of the logger whose level is changed for the
            duration of the ``with`` block.
        level: Level applied to that logger while the block runs.

    Yields:
        None

    The logger's own level is always put back, including when the wrapped
    block raises.
    """
    logger = logging.getLogger(logger_name)
    # Save the logger's *own* level rather than the effective one: restoring
    # the effective level would pin a NOTSET logger to an explicit value and
    # stop it from following later changes on its parent loggers.
    original_level = logger.level
    logger.setLevel(level)
    try:
        with (
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
            # NOTE: this only swallows a UserWarning that is *raised* as an
            # exception; output of warnings.warn() is hidden by the stderr
            # redirect above, not by this entry.
            contextlib.suppress(UserWarning),
        ):
            yield
    finally:
        # Runs on every exit path so an exception in the body cannot leave
        # the logger stuck at the suppressed level.
        logger.setLevel(original_level)
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeStorage(BaseKnowledgeStorage):
|
class KnowledgeStorage(BaseKnowledgeStorage):
|
||||||
@@ -64,7 +46,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
filter: Optional[dict] = None,
|
filter: Optional[dict] = None,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.35,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
with suppress_logging():
|
with suppress_logging(
|
||||||
|
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||||
|
):
|
||||||
if self.collection:
|
if self.collection:
|
||||||
fetched = self.collection.query(
|
fetched = self.collection.query(
|
||||||
query_texts=query,
|
query_texts=query,
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import contextlib
|
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
@@ -12,26 +10,10 @@ from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
|||||||
from crewai.utilities.chromadb import create_persistent_client
|
from crewai.utilities.chromadb import create_persistent_client
|
||||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
from crewai.utilities.logger_utils import suppress_logging
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
def suppress_logging(
    logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
    level=logging.ERROR,
):
    """Temporarily raise a logger's threshold and silence stdout/stderr.

    Args:
        logger_name: Name of the logger whose level is changed for the
            duration of the ``with`` block.
        level: Level applied to that logger while the block runs.

    Yields:
        None

    The logger's own level is always put back, including when the wrapped
    block raises.
    """
    logger = logging.getLogger(logger_name)
    # Save the logger's *own* level rather than the effective one: restoring
    # the effective level would pin a NOTSET logger to an explicit value and
    # stop it from following later changes on its parent loggers.
    original_level = logger.level
    logger.setLevel(level)
    try:
        with (
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
            # NOTE: this only swallows a UserWarning that is *raised* as an
            # exception; output of warnings.warn() is hidden by the stderr
            # redirect above, not by this entry.
            contextlib.suppress(UserWarning),
        ):
            yield
    finally:
        # Runs on every exit path so an exception in the body cannot leave
        # the logger stuck at the suppressed level.
        logger.setLevel(original_level)
|
|
||||||
|
|
||||||
|
|
||||||
class RAGStorage(BaseRAGStorage):
|
class RAGStorage(BaseRAGStorage):
|
||||||
"""
|
"""
|
||||||
Extends Storage to handle embeddings for memory entries, improving
|
Extends Storage to handle embeddings for memory entries, improving
|
||||||
@@ -122,7 +104,9 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with suppress_logging():
|
with suppress_logging(
|
||||||
|
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||||
|
):
|
||||||
response = self.collection.query(query_texts=query, n_results=limit)
|
response = self.collection.query(query_texts=query, n_results=limit)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""ChromaDB client implementation."""
|
"""ChromaDB client implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from chromadb.api.types import (
|
from chromadb.api.types import (
|
||||||
@@ -20,7 +21,9 @@ from crewai.rag.chromadb.utils import (
|
|||||||
_is_sync_client,
|
_is_sync_client,
|
||||||
_prepare_documents_for_chromadb,
|
_prepare_documents_for_chromadb,
|
||||||
_process_query_results,
|
_process_query_results,
|
||||||
|
_sanitize_collection_name,
|
||||||
)
|
)
|
||||||
|
from crewai.utilities.logger_utils import suppress_logging
|
||||||
from crewai.rag.core.base_client import (
|
from crewai.rag.core.base_client import (
|
||||||
BaseClient,
|
BaseClient,
|
||||||
BaseCollectionParams,
|
BaseCollectionParams,
|
||||||
@@ -97,7 +100,7 @@ class ChromaDBClient(BaseClient):
|
|||||||
metadata["hnsw:space"] = "cosine"
|
metadata["hnsw:space"] = "cosine"
|
||||||
|
|
||||||
self.client.create_collection(
|
self.client.create_collection(
|
||||||
name=kwargs["collection_name"],
|
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||||
configuration=kwargs.get("configuration"),
|
configuration=kwargs.get("configuration"),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
embedding_function=kwargs.get(
|
embedding_function=kwargs.get(
|
||||||
@@ -154,7 +157,7 @@ class ChromaDBClient(BaseClient):
|
|||||||
metadata["hnsw:space"] = "cosine"
|
metadata["hnsw:space"] = "cosine"
|
||||||
|
|
||||||
await self.client.create_collection(
|
await self.client.create_collection(
|
||||||
name=kwargs["collection_name"],
|
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||||
configuration=kwargs.get("configuration"),
|
configuration=kwargs.get("configuration"),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
embedding_function=kwargs.get(
|
embedding_function=kwargs.get(
|
||||||
@@ -205,7 +208,7 @@ class ChromaDBClient(BaseClient):
|
|||||||
metadata["hnsw:space"] = "cosine"
|
metadata["hnsw:space"] = "cosine"
|
||||||
|
|
||||||
return self.client.get_or_create_collection(
|
return self.client.get_or_create_collection(
|
||||||
name=kwargs["collection_name"],
|
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||||
configuration=kwargs.get("configuration"),
|
configuration=kwargs.get("configuration"),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
embedding_function=kwargs.get(
|
embedding_function=kwargs.get(
|
||||||
@@ -258,7 +261,7 @@ class ChromaDBClient(BaseClient):
|
|||||||
metadata["hnsw:space"] = "cosine"
|
metadata["hnsw:space"] = "cosine"
|
||||||
|
|
||||||
return await self.client.get_or_create_collection(
|
return await self.client.get_or_create_collection(
|
||||||
name=kwargs["collection_name"],
|
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||||
configuration=kwargs.get("configuration"),
|
configuration=kwargs.get("configuration"),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
embedding_function=kwargs.get(
|
embedding_function=kwargs.get(
|
||||||
@@ -298,12 +301,12 @@ class ChromaDBClient(BaseClient):
|
|||||||
raise ValueError("Documents list cannot be empty")
|
raise ValueError("Documents list cannot be empty")
|
||||||
|
|
||||||
collection = self.client.get_collection(
|
collection = self.client.get_collection(
|
||||||
name=collection_name,
|
name=_sanitize_collection_name(collection_name),
|
||||||
embedding_function=self.embedding_function,
|
embedding_function=self.embedding_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
prepared = _prepare_documents_for_chromadb(documents)
|
prepared = _prepare_documents_for_chromadb(documents)
|
||||||
collection.add(
|
collection.upsert(
|
||||||
ids=prepared.ids,
|
ids=prepared.ids,
|
||||||
documents=prepared.texts,
|
documents=prepared.texts,
|
||||||
metadatas=prepared.metadatas,
|
metadatas=prepared.metadatas,
|
||||||
@@ -340,11 +343,11 @@ class ChromaDBClient(BaseClient):
|
|||||||
raise ValueError("Documents list cannot be empty")
|
raise ValueError("Documents list cannot be empty")
|
||||||
|
|
||||||
collection = await self.client.get_collection(
|
collection = await self.client.get_collection(
|
||||||
name=collection_name,
|
name=_sanitize_collection_name(collection_name),
|
||||||
embedding_function=self.embedding_function,
|
embedding_function=self.embedding_function,
|
||||||
)
|
)
|
||||||
prepared = _prepare_documents_for_chromadb(documents)
|
prepared = _prepare_documents_for_chromadb(documents)
|
||||||
await collection.add(
|
await collection.upsert(
|
||||||
ids=prepared.ids,
|
ids=prepared.ids,
|
||||||
documents=prepared.texts,
|
documents=prepared.texts,
|
||||||
metadatas=prepared.metadatas,
|
metadatas=prepared.metadatas,
|
||||||
@@ -385,19 +388,22 @@ class ChromaDBClient(BaseClient):
|
|||||||
params = _extract_search_params(kwargs)
|
params = _extract_search_params(kwargs)
|
||||||
|
|
||||||
collection = self.client.get_collection(
|
collection = self.client.get_collection(
|
||||||
name=params.collection_name,
|
name=_sanitize_collection_name(params.collection_name),
|
||||||
embedding_function=self.embedding_function,
|
embedding_function=self.embedding_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
where = params.where if params.where is not None else params.metadata_filter
|
where = params.where if params.where is not None else params.metadata_filter
|
||||||
|
|
||||||
results: QueryResult = collection.query(
|
with suppress_logging(
|
||||||
query_texts=[params.query],
|
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||||
n_results=params.limit,
|
):
|
||||||
where=where,
|
results: QueryResult = collection.query(
|
||||||
where_document=params.where_document,
|
query_texts=[params.query],
|
||||||
include=params.include,
|
n_results=params.limit,
|
||||||
)
|
where=where,
|
||||||
|
where_document=params.where_document,
|
||||||
|
include=params.include,
|
||||||
|
)
|
||||||
|
|
||||||
return _process_query_results(
|
return _process_query_results(
|
||||||
collection=collection,
|
collection=collection,
|
||||||
@@ -440,19 +446,22 @@ class ChromaDBClient(BaseClient):
|
|||||||
params = _extract_search_params(kwargs)
|
params = _extract_search_params(kwargs)
|
||||||
|
|
||||||
collection = await self.client.get_collection(
|
collection = await self.client.get_collection(
|
||||||
name=params.collection_name,
|
name=_sanitize_collection_name(params.collection_name),
|
||||||
embedding_function=self.embedding_function,
|
embedding_function=self.embedding_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
where = params.where if params.where is not None else params.metadata_filter
|
where = params.where if params.where is not None else params.metadata_filter
|
||||||
|
|
||||||
results: QueryResult = await collection.query(
|
with suppress_logging(
|
||||||
query_texts=[params.query],
|
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||||
n_results=params.limit,
|
):
|
||||||
where=where,
|
results: QueryResult = await collection.query(
|
||||||
where_document=params.where_document,
|
query_texts=[params.query],
|
||||||
include=params.include,
|
n_results=params.limit,
|
||||||
)
|
where=where,
|
||||||
|
where_document=params.where_document,
|
||||||
|
include=params.include,
|
||||||
|
)
|
||||||
|
|
||||||
return _process_query_results(
|
return _process_query_results(
|
||||||
collection=collection,
|
collection=collection,
|
||||||
@@ -485,7 +494,7 @@ class ChromaDBClient(BaseClient):
|
|||||||
)
|
)
|
||||||
|
|
||||||
collection_name = kwargs["collection_name"]
|
collection_name = kwargs["collection_name"]
|
||||||
self.client.delete_collection(name=collection_name)
|
self.client.delete_collection(name=_sanitize_collection_name(collection_name))
|
||||||
|
|
||||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||||
"""Delete a collection and all its data asynchronously.
|
"""Delete a collection and all its data asynchronously.
|
||||||
@@ -515,7 +524,9 @@ class ChromaDBClient(BaseClient):
|
|||||||
)
|
)
|
||||||
|
|
||||||
collection_name = kwargs["collection_name"]
|
collection_name = kwargs["collection_name"]
|
||||||
await self.client.delete_collection(name=collection_name)
|
await self.client.delete_collection(
|
||||||
|
name=_sanitize_collection_name(collection_name)
|
||||||
|
)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset the vector database by deleting all collections and data.
|
"""Reset the vector database by deleting all collections and data.
|
||||||
|
|||||||
@@ -23,6 +23,12 @@ warnings.filterwarnings(
|
|||||||
module="pydantic._internal._generate_schema",
|
module="pydantic._internal._generate_schema",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r".*'model_fields'.*is deprecated.*",
|
||||||
|
module=r"^chromadb(\.|$)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _default_settings() -> Settings:
|
def _default_settings() -> Settings:
|
||||||
"""Create default ChromaDB settings.
|
"""Create default ChromaDB settings.
|
||||||
|
|||||||
@@ -1,10 +1,17 @@
|
|||||||
"""Constants for ChromaDB configuration."""
|
"""Constants for ChromaDB configuration."""
|
||||||
|
|
||||||
import os
|
import re
|
||||||
from typing import Final
|
from typing import Final
|
||||||
|
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
|
||||||
DEFAULT_TENANT: Final[str] = "default_tenant"
|
DEFAULT_TENANT: Final[str] = "default_tenant"
|
||||||
DEFAULT_DATABASE: Final[str] = "default_database"
|
DEFAULT_DATABASE: Final[str] = "default_database"
|
||||||
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb")
|
DEFAULT_STORAGE_PATH: Final[str] = db_storage_path()
|
||||||
|
|
||||||
|
MIN_COLLECTION_LENGTH: Final[int] = 3
|
||||||
|
MAX_COLLECTION_LENGTH: Final[int] = 63
|
||||||
|
DEFAULT_COLLECTION: Final[str] = "default_collection"
|
||||||
|
|
||||||
|
INVALID_CHARS_PATTERN: Final[re.Pattern[str]] = re.compile(r"[^a-zA-Z0-9_-]")
|
||||||
|
IPV4_PATTERN: Final[re.Pattern[str]] = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
"""Factory functions for creating ChromaDB clients."""
|
"""Factory functions for creating ChromaDB clients."""
|
||||||
|
|
||||||
from chromadb import Client
|
import os
|
||||||
|
from hashlib import md5
|
||||||
|
import portalocker
|
||||||
|
from chromadb import PersistentClient
|
||||||
|
|
||||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||||
from crewai.rag.chromadb.client import ChromaDBClient
|
from crewai.rag.chromadb.client import ChromaDBClient
|
||||||
@@ -14,11 +17,24 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Configured ChromaDBClient instance.
|
Configured ChromaDBClient instance.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Need to update to use chromadb.Client to support more client types in the near future.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
persist_dir = config.settings.persist_directory
|
||||||
|
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
|
||||||
|
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
|
||||||
|
|
||||||
|
with portalocker.Lock(lockfile):
|
||||||
|
client = PersistentClient(
|
||||||
|
path=persist_dir,
|
||||||
|
settings=config.settings,
|
||||||
|
tenant=config.tenant,
|
||||||
|
database=config.database,
|
||||||
|
)
|
||||||
|
|
||||||
return ChromaDBClient(
|
return ChromaDBClient(
|
||||||
client=Client(
|
client=client,
|
||||||
settings=config.settings, tenant=config.tenant, database=config.database
|
|
||||||
),
|
|
||||||
embedding_function=config.embedding_function,
|
embedding_function=config.embedding_function,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,6 +12,13 @@ from chromadb.api.types import (
|
|||||||
)
|
)
|
||||||
from chromadb.api.models.AsyncCollection import AsyncCollection
|
from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||||
from chromadb.api.models.Collection import Collection
|
from chromadb.api.models.Collection import Collection
|
||||||
|
from crewai.rag.chromadb.constants import (
|
||||||
|
DEFAULT_COLLECTION,
|
||||||
|
INVALID_CHARS_PATTERN,
|
||||||
|
IPV4_PATTERN,
|
||||||
|
MAX_COLLECTION_LENGTH,
|
||||||
|
MIN_COLLECTION_LENGTH,
|
||||||
|
)
|
||||||
from crewai.rag.chromadb.types import (
|
from crewai.rag.chromadb.types import (
|
||||||
ChromaDBClientType,
|
ChromaDBClientType,
|
||||||
ChromaDBCollectionSearchParams,
|
ChromaDBCollectionSearchParams,
|
||||||
@@ -216,3 +223,58 @@ def _process_query_results(
|
|||||||
distance_metric=distance_metric,
|
distance_metric=distance_metric,
|
||||||
score_threshold=params.score_threshold,
|
score_threshold=params.score_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ipv4_pattern(name: str) -> bool:
    """Report whether ``name`` has the textual shape of an IPv4 address.

    Args:
        name: The string to test against ``IPV4_PATTERN``.

    Returns:
        True when ``IPV4_PATTERN`` matches ``name``, False otherwise.
    """
    return IPV4_PATTERN.match(name) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_collection_name(
    name: str | None, max_collection_length: int = MAX_COLLECTION_LENGTH
) -> str:
    """Coerce an arbitrary name into a form usable as a ChromaDB collection name.

    The name is transformed step by step so that the result:
        1. Is padded up to ``MIN_COLLECTION_LENGTH`` characters and cut down
           to ``max_collection_length`` characters
        2. Starts and ends with an alphanumeric character
        3. Contains only alphanumeric characters, underscores, or hyphens
           (every other character, periods included, becomes ``_``)
        4. Does not look like an IPv4 address (such names get an ``ip_`` prefix)

    Args:
        name: The original collection name; ``None`` or an empty string
            yields ``DEFAULT_COLLECTION``.
        max_collection_length: Upper bound on the length of the returned name.

    Returns:
        The sanitized collection name.
    """
    if not name:
        return DEFAULT_COLLECTION

    # Prefix IPv4-looking names first so the result can no longer match.
    candidate = f"ip_{name}" if _is_ipv4_pattern(name) else name
    cleaned = INVALID_CHARS_PATTERN.sub("_", candidate)

    # Both ends must be alphanumeric: grow at the front, overwrite at the back.
    if not cleaned[0].isalnum():
        cleaned = f"a{cleaned}"
    if not cleaned[-1].isalnum():
        cleaned = f"{cleaned[:-1]}z"

    shortfall = MIN_COLLECTION_LENGTH - len(cleaned)
    if shortfall > 0:
        cleaned += "x" * shortfall

    if len(cleaned) > max_collection_length:
        cleaned = cleaned[:max_collection_length]
        # Truncation may have exposed a non-alphanumeric character at the end.
        if not cleaned[-1].isalnum():
            cleaned = f"{cleaned[:-1]}z"

    return cleaned
|
||||||
|
|||||||
38
src/crewai/utilities/logger_utils.py
Normal file
38
src/crewai/utilities/logger_utils.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Logging utility functions for CrewAI."""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def suppress_logging(
|
||||||
|
logger_name: str,
|
||||||
|
level: int | str,
|
||||||
|
) -> Generator[None, None, None]:
|
||||||
|
"""Suppress verbose logging output from specified logger.
|
||||||
|
|
||||||
|
Commonly used to suppress ChromaDB's verbose HNSW index logging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger_name: The logger to suppress
|
||||||
|
level: The minimum level to allow (e.g., logging.ERROR or "ERROR")
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
None
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with suppress_logging("chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR):
|
||||||
|
collection.query(query_texts=["test"])
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
original_level = logger.getEffectiveLevel()
|
||||||
|
logger.setLevel(level)
|
||||||
|
with (
|
||||||
|
contextlib.redirect_stdout(io.StringIO()),
|
||||||
|
contextlib.redirect_stderr(io.StringIO()),
|
||||||
|
contextlib.suppress(UserWarning),
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
logger.setLevel(original_level)
|
||||||
@@ -253,8 +253,8 @@ class TestChromaDBClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Verify documents were added to collection
|
# Verify documents were added to collection
|
||||||
mock_collection.add.assert_called_once()
|
mock_collection.upsert.assert_called_once()
|
||||||
call_args = mock_collection.add.call_args
|
call_args = mock_collection.upsert.call_args
|
||||||
assert len(call_args.kwargs["ids"]) == 1
|
assert len(call_args.kwargs["ids"]) == 1
|
||||||
assert call_args.kwargs["documents"] == ["Test document"]
|
assert call_args.kwargs["documents"] == ["Test document"]
|
||||||
assert call_args.kwargs["metadatas"] == [{"source": "test"}]
|
assert call_args.kwargs["metadatas"] == [{"source": "test"}]
|
||||||
@@ -279,7 +279,7 @@ class TestChromaDBClient:
|
|||||||
|
|
||||||
client.add_documents(collection_name="test_collection", documents=documents)
|
client.add_documents(collection_name="test_collection", documents=documents)
|
||||||
|
|
||||||
mock_collection.add.assert_called_once_with(
|
mock_collection.upsert.assert_called_once_with(
|
||||||
ids=["custom_id_1", "custom_id_2"],
|
ids=["custom_id_1", "custom_id_2"],
|
||||||
documents=["First document", "Second document"],
|
documents=["First document", "Second document"],
|
||||||
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
||||||
@@ -319,8 +319,8 @@ class TestChromaDBClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Verify documents were added to collection
|
# Verify documents were added to collection
|
||||||
mock_collection.add.assert_called_once()
|
mock_collection.upsert.assert_called_once()
|
||||||
call_args = mock_collection.add.call_args
|
call_args = mock_collection.upsert.call_args
|
||||||
assert len(call_args.kwargs["ids"]) == 1
|
assert len(call_args.kwargs["ids"]) == 1
|
||||||
assert call_args.kwargs["documents"] == ["Test document"]
|
assert call_args.kwargs["documents"] == ["Test document"]
|
||||||
assert call_args.kwargs["metadatas"] == [{"source": "test"}]
|
assert call_args.kwargs["metadatas"] == [{"source": "test"}]
|
||||||
@@ -352,7 +352,7 @@ class TestChromaDBClient:
|
|||||||
collection_name="test_collection", documents=documents
|
collection_name="test_collection", documents=documents
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_collection.add.assert_called_once_with(
|
mock_collection.upsert.assert_called_once_with(
|
||||||
ids=["custom_id_1", "custom_id_2"],
|
ids=["custom_id_1", "custom_id_2"],
|
||||||
documents=["First document", "Second document"],
|
documents=["First document", "Second document"],
|
||||||
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
metadatas=[{"source": "test1"}, {"source": "test2"}],
|
||||||
|
|||||||
Reference in New Issue
Block a user