mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 23:32:39 +00:00
refactor: consolidate ChromaDB response extraction logic
This commit is contained in:
@@ -3,7 +3,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from chromadb import EmbeddingFunction
|
from chromadb import EmbeddingFunction
|
||||||
from chromadb.api import ClientAPI
|
from chromadb.api import ClientAPI
|
||||||
@@ -16,14 +16,49 @@ from crewai.utilities.logger_utils import suppress_logging
|
|||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_chromadb_response_item(
|
||||||
|
response_data: Any,
|
||||||
|
index: int,
|
||||||
|
expected_type: type[Any] | tuple[type[Any], ...],
|
||||||
|
) -> Any | None:
|
||||||
|
"""Extract an item from ChromaDB response data at the given index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response_data: The response data from ChromaDB query (e.g., documents, metadatas).
|
||||||
|
index: The index of the item to extract.
|
||||||
|
expected_type: The expected type(s) of the item.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The extracted item if it exists and matches the expected type, otherwise None.
|
||||||
|
"""
|
||||||
|
if response_data is None or not response_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ChromaDB sometimes returns nested lists, handle both cases
|
||||||
|
data_list = (
|
||||||
|
response_data[0]
|
||||||
|
if response_data and isinstance(response_data[0], list)
|
||||||
|
else response_data
|
||||||
|
)
|
||||||
|
|
||||||
|
if index < len(data_list):
|
||||||
|
item = data_list[index]
|
||||||
|
if isinstance(item, expected_type):
|
||||||
|
return item
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class RAGStorage(BaseRAGStorage):
|
class RAGStorage(BaseRAGStorage):
|
||||||
"""
|
"""
|
||||||
Extends Storage to handle embeddings for memory entries, improving
|
Extends Storage to handle embeddings for memory entries, improving
|
||||||
search efficiency.
|
search efficiency.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- TODO: Add type hints to EmbeddingFunction in next typing PR.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
app: ClientAPI | None = None
|
app: ClientAPI | None = None
|
||||||
embedder_config: EmbeddingFunction[Any] | None = None
|
embedder_config: EmbeddingFunction[Any] | None = None # type: ignore
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -31,7 +66,7 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
allow_reset: bool = True,
|
allow_reset: bool = True,
|
||||||
embedder_config: Any = None,
|
embedder_config: Any = None,
|
||||||
crew: Any = None,
|
crew: Any = None,
|
||||||
path: Optional[str] = None,
|
path: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(type, allow_reset, embedder_config, crew)
|
super().__init__(type, allow_reset, embedder_config, crew)
|
||||||
agents = crew.agents if crew else []
|
agents = crew.agents if crew else []
|
||||||
@@ -49,14 +84,19 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
def _set_embedder_config(self) -> None:
|
def _set_embedder_config(self) -> None:
|
||||||
configurator = EmbeddingConfigurator()
|
"""Sets the embedder_config using EmbeddingConfigurator.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- TODO: remove the type: ignore on next typing pr.
|
||||||
|
"""
|
||||||
|
configurator = EmbeddingConfigurator() # type: ignore
|
||||||
# Pass the original embedder_config from __init__, not self.embedder_config
|
# Pass the original embedder_config from __init__, not self.embedder_config
|
||||||
if hasattr(self, "_original_embedder_config"):
|
if hasattr(self, "_original_embedder_config"):
|
||||||
self.embedder_config = configurator.configure_embedder(
|
self.embedder_config = configurator.configure_embedder(
|
||||||
self._original_embedder_config
|
self._original_embedder_config
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.embedder_config = configurator.configure_embedder(None)
|
self.embedder_config = configurator.configure_embedder()
|
||||||
|
|
||||||
def _initialize_app(self) -> None:
|
def _initialize_app(self) -> None:
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
@@ -87,7 +127,8 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
"""
|
"""
|
||||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||||
|
|
||||||
def _build_storage_file_name(self, type: str, file_name: str) -> str:
|
@staticmethod
|
||||||
|
def _build_storage_file_name(type: str, file_name: str) -> str:
|
||||||
"""
|
"""
|
||||||
Ensures file name does not exceed max allowed by OS
|
Ensures file name does not exceed max allowed by OS
|
||||||
"""
|
"""
|
||||||
@@ -113,7 +154,7 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 3,
|
||||||
filter: Optional[dict[str, Any]] = None,
|
filter: dict[str, Any] | None = None,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.35,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
if not hasattr(self, "app"):
|
if not hasattr(self, "app"):
|
||||||
@@ -139,37 +180,22 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
)
|
)
|
||||||
for i in range(len(ids_list)):
|
for i in range(len(ids_list)):
|
||||||
# Handle metadatas
|
# Handle metadatas
|
||||||
metadata = {}
|
meta_item = _extract_chromadb_response_item(
|
||||||
if response.get("metadatas") and len(response["metadatas"]) > 0:
|
response.get("metadatas"), i, dict
|
||||||
metadata_list = (
|
)
|
||||||
response["metadatas"][0]
|
metadata: dict[str, Any] = meta_item if meta_item else {}
|
||||||
if isinstance(response["metadatas"][0], list)
|
|
||||||
else response["metadatas"]
|
|
||||||
)
|
|
||||||
if i < len(metadata_list):
|
|
||||||
metadata = metadata_list[i]
|
|
||||||
|
|
||||||
# Handle documents
|
# Handle documents
|
||||||
context = ""
|
doc_item = _extract_chromadb_response_item(
|
||||||
if response.get("documents") and len(response["documents"]) > 0:
|
response.get("documents"), i, str
|
||||||
docs_list = (
|
)
|
||||||
response["documents"][0]
|
context = doc_item if doc_item else ""
|
||||||
if isinstance(response["documents"][0], list)
|
|
||||||
else response["documents"]
|
|
||||||
)
|
|
||||||
if i < len(docs_list):
|
|
||||||
context = docs_list[i]
|
|
||||||
|
|
||||||
# Handle distances
|
# Handle distances
|
||||||
score = 1.0
|
dist_item = _extract_chromadb_response_item(
|
||||||
if response.get("distances") and len(response["distances"]) > 0:
|
response.get("distances"), i, (int, float)
|
||||||
dist_list = (
|
)
|
||||||
response["distances"][0]
|
score = dist_item if dist_item is not None else 1.0
|
||||||
if isinstance(response["distances"][0], list)
|
|
||||||
else response["distances"]
|
|
||||||
)
|
|
||||||
if i < len(dist_list):
|
|
||||||
score = dist_list[i]
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"id": ids_list[i],
|
"id": ids_list[i],
|
||||||
@@ -187,11 +213,22 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None: # type: ignore
|
def _generate_embedding(
|
||||||
|
self, text: str, metadata: dict[str, Any] | None = None
|
||||||
|
) -> Any:
|
||||||
|
"""Generates and stores the embedding for the given text and metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to generate an embedding for.
|
||||||
|
metadata: Optional metadata associated with the text.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Need to constrain the typing in the base class, this result isn't used
|
||||||
|
"""
|
||||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
self.collection.add(
|
return self.collection.add(
|
||||||
documents=[text],
|
documents=[text],
|
||||||
metadatas=[metadata or {}],
|
metadatas=[metadata or {}],
|
||||||
ids=[str(uuid.uuid4())],
|
ids=[str(uuid.uuid4())],
|
||||||
@@ -213,7 +250,8 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_default_embedding_function(self) -> EmbeddingFunction[Any]:
|
@staticmethod
|
||||||
|
def _create_default_embedding_function() -> EmbeddingFunction[Any]:
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
OpenAIEmbeddingFunction,
|
OpenAIEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user