refactor: consolidate ChromaDB response extraction logic

This commit is contained in:
Greyson LaLonde
2025-09-04 15:21:48 -04:00
parent 4812986f58
commit 221bfcccce

View File

@@ -3,7 +3,7 @@ import os
import shutil import shutil
import uuid import uuid
import warnings import warnings
from typing import Any, Optional from typing import Any
from chromadb import EmbeddingFunction from chromadb import EmbeddingFunction
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
@@ -16,14 +16,49 @@ from crewai.utilities.logger_utils import suppress_logging
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
def _extract_chromadb_response_item(
response_data: Any,
index: int,
expected_type: type[Any] | tuple[type[Any], ...],
) -> Any | None:
"""Extract an item from ChromaDB response data at the given index.
Args:
response_data: The response data from ChromaDB query (e.g., documents, metadatas).
index: The index of the item to extract.
expected_type: The expected type(s) of the item.
Returns:
The extracted item if it exists and matches the expected type, otherwise None.
"""
if response_data is None or not response_data:
return None
# ChromaDB sometimes returns nested lists, handle both cases
data_list = (
response_data[0]
if response_data and isinstance(response_data[0], list)
else response_data
)
if index < len(data_list):
item = data_list[index]
if isinstance(item, expected_type):
return item
return None
class RAGStorage(BaseRAGStorage): class RAGStorage(BaseRAGStorage):
""" """
Extends Storage to handle embeddings for memory entries, improving Extends Storage to handle embeddings for memory entries, improving
search efficiency. search efficiency.
Notes:
- TODO: Add type hints to EmbeddingFunction in next typing PR.
""" """
app: ClientAPI | None = None app: ClientAPI | None = None
embedder_config: EmbeddingFunction[Any] | None = None embedder_config: EmbeddingFunction[Any] | None = None # type: ignore
def __init__( def __init__(
self, self,
@@ -31,7 +66,7 @@ class RAGStorage(BaseRAGStorage):
allow_reset: bool = True, allow_reset: bool = True,
embedder_config: Any = None, embedder_config: Any = None,
crew: Any = None, crew: Any = None,
path: Optional[str] = None, path: str | None = None,
) -> None: ) -> None:
super().__init__(type, allow_reset, embedder_config, crew) super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else [] agents = crew.agents if crew else []
@@ -49,14 +84,19 @@ class RAGStorage(BaseRAGStorage):
self._initialize_app() self._initialize_app()
def _set_embedder_config(self) -> None: def _set_embedder_config(self) -> None:
configurator = EmbeddingConfigurator() """Sets the embedder_config using EmbeddingConfigurator.
Notes:
- TODO: remove the type: ignore on next typing pr.
"""
configurator = EmbeddingConfigurator() # type: ignore
# Pass the original embedder_config from __init__, not self.embedder_config # Pass the original embedder_config from __init__, not self.embedder_config
if hasattr(self, "_original_embedder_config"): if hasattr(self, "_original_embedder_config"):
self.embedder_config = configurator.configure_embedder( self.embedder_config = configurator.configure_embedder(
self._original_embedder_config self._original_embedder_config
) )
else: else:
self.embedder_config = configurator.configure_embedder(None) self.embedder_config = configurator.configure_embedder()
def _initialize_app(self) -> None: def _initialize_app(self) -> None:
from chromadb.config import Settings from chromadb.config import Settings
@@ -87,7 +127,8 @@ class RAGStorage(BaseRAGStorage):
""" """
return role.replace("\n", "").replace(" ", "_").replace("/", "_") return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def _build_storage_file_name(self, type: str, file_name: str) -> str: @staticmethod
def _build_storage_file_name(type: str, file_name: str) -> str:
""" """
Ensures file name does not exceed max allowed by OS Ensures file name does not exceed max allowed by OS
""" """
@@ -113,7 +154,7 @@ class RAGStorage(BaseRAGStorage):
self, self,
query: str, query: str,
limit: int = 3, limit: int = 3,
filter: Optional[dict[str, Any]] = None, filter: dict[str, Any] | None = None,
score_threshold: float = 0.35, score_threshold: float = 0.35,
) -> list[Any]: ) -> list[Any]:
if not hasattr(self, "app"): if not hasattr(self, "app"):
@@ -139,37 +180,22 @@ class RAGStorage(BaseRAGStorage):
) )
for i in range(len(ids_list)): for i in range(len(ids_list)):
# Handle metadatas # Handle metadatas
metadata = {} meta_item = _extract_chromadb_response_item(
if response.get("metadatas") and len(response["metadatas"]) > 0: response.get("metadatas"), i, dict
metadata_list = ( )
response["metadatas"][0] metadata: dict[str, Any] = meta_item if meta_item else {}
if isinstance(response["metadatas"][0], list)
else response["metadatas"]
)
if i < len(metadata_list):
metadata = metadata_list[i]
# Handle documents # Handle documents
context = "" doc_item = _extract_chromadb_response_item(
if response.get("documents") and len(response["documents"]) > 0: response.get("documents"), i, str
docs_list = ( )
response["documents"][0] context = doc_item if doc_item else ""
if isinstance(response["documents"][0], list)
else response["documents"]
)
if i < len(docs_list):
context = docs_list[i]
# Handle distances # Handle distances
score = 1.0 dist_item = _extract_chromadb_response_item(
if response.get("distances") and len(response["distances"]) > 0: response.get("distances"), i, (int, float)
dist_list = ( )
response["distances"][0] score = dist_item if dist_item is not None else 1.0
if isinstance(response["distances"][0], list)
else response["distances"]
)
if i < len(dist_list):
score = dist_list[i]
result = { result = {
"id": ids_list[i], "id": ids_list[i],
@@ -187,11 +213,22 @@ class RAGStorage(BaseRAGStorage):
logging.error(f"Error during {self.type} search: {str(e)}") logging.error(f"Error during {self.type} search: {str(e)}")
return [] return []
def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None: # type: ignore def _generate_embedding(
self, text: str, metadata: dict[str, Any] | None = None
) -> Any:
"""Generates and stores the embedding for the given text and metadata.
Args:
text: The text to generate an embedding for.
metadata: Optional metadata associated with the text.
Notes:
- Need to constrain the typing in the base class, this result isn't used
"""
if not hasattr(self, "app") or not hasattr(self, "collection"): if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app() self._initialize_app()
self.collection.add( return self.collection.add(
documents=[text], documents=[text],
metadatas=[metadata or {}], metadatas=[metadata or {}],
ids=[str(uuid.uuid4())], ids=[str(uuid.uuid4())],
@@ -213,7 +250,8 @@ class RAGStorage(BaseRAGStorage):
f"An error occurred while resetting the {self.type} memory: {e}" f"An error occurred while resetting the {self.type} memory: {e}"
) )
def _create_default_embedding_function(self) -> EmbeddingFunction[Any]: @staticmethod
def _create_default_embedding_function() -> EmbeddingFunction[Any]:
from chromadb.utils.embedding_functions.openai_embedding_function import ( from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction, OpenAIEmbeddingFunction,
) )