diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py
index a41c231e8..7ecb5cafc 100644
--- a/src/crewai/memory/storage/rag_storage.py
+++ b/src/crewai/memory/storage/rag_storage.py
@@ -3,7 +3,7 @@ import os
 import shutil
 import uuid
 import warnings
-from typing import Any, Optional
+from typing import Any
 
 from chromadb import EmbeddingFunction
 from chromadb.api import ClientAPI
@@ -16,14 +16,49 @@ from crewai.utilities.logger_utils import suppress_logging
 from crewai.utilities.paths import db_storage_path
 
 
+def _extract_chromadb_response_item(
+    response_data: Any,
+    index: int,
+    expected_type: type[Any] | tuple[type[Any], ...],
+) -> Any | None:
+    """Extract an item from ChromaDB response data at the given index.
+
+    Args:
+        response_data: The response data from ChromaDB query (e.g., documents, metadatas).
+        index: The index of the item to extract.
+        expected_type: The expected type(s) of the item.
+
+    Returns:
+        The extracted item if it exists and matches the expected type, otherwise None.
+    """
+    if response_data is None or not response_data:
+        return None
+
+    # ChromaDB sometimes returns nested lists, handle both cases
+    data_list = (
+        response_data[0]
+        if response_data and isinstance(response_data[0], list)
+        else response_data
+    )
+
+    if index < len(data_list):
+        item = data_list[index]
+        if isinstance(item, expected_type):
+            return item
+    return None
+
+
 class RAGStorage(BaseRAGStorage):
     """
     Extends Storage to handle embeddings for memory entries, improving search
     efficiency.
+
+    Notes:
+        - TODO: Add type hints to EmbeddingFunction in next typing PR.
     """
 
     app: ClientAPI | None = None
-    embedder_config: EmbeddingFunction[Any] | None = None
+    embedder_config: EmbeddingFunction[Any] | None = None  # type: ignore
 
     def __init__(
         self,
@@ -31,7 +66,7 @@ class RAGStorage(BaseRAGStorage):
         allow_reset: bool = True,
         embedder_config: Any = None,
         crew: Any = None,
-        path: Optional[str] = None,
+        path: str | None = None,
     ) -> None:
         super().__init__(type, allow_reset, embedder_config, crew)
         agents = crew.agents if crew else []
@@ -49,14 +84,19 @@ class RAGStorage(BaseRAGStorage):
         self._initialize_app()
 
     def _set_embedder_config(self) -> None:
-        configurator = EmbeddingConfigurator()
+        """Sets the embedder_config using EmbeddingConfigurator.
+
+        Notes:
+            - TODO: remove the type: ignore in next typing PR.
+        """
+        configurator = EmbeddingConfigurator()  # type: ignore
         # Pass the original embedder_config from __init__, not self.embedder_config
         if hasattr(self, "_original_embedder_config"):
             self.embedder_config = configurator.configure_embedder(
                 self._original_embedder_config
             )
         else:
-            self.embedder_config = configurator.configure_embedder(None)
+            self.embedder_config = configurator.configure_embedder()
 
     def _initialize_app(self) -> None:
         from chromadb.config import Settings
@@ -87,7 +127,8 @@ class RAGStorage(BaseRAGStorage):
         """
         return role.replace("\n", "").replace(" ", "_").replace("/", "_")
 
-    def _build_storage_file_name(self, type: str, file_name: str) -> str:
+    @staticmethod
+    def _build_storage_file_name(type: str, file_name: str) -> str:
         """
         Ensures file name does not exceed max allowed by OS
         """
@@ -113,7 +154,7 @@ class RAGStorage(BaseRAGStorage):
         self,
         query: str,
         limit: int = 3,
-        filter: Optional[dict[str, Any]] = None,
+        filter: dict[str, Any] | None = None,
         score_threshold: float = 0.35,
     ) -> list[Any]:
         if not hasattr(self, "app"):
@@ -139,37 +180,22 @@ class RAGStorage(BaseRAGStorage):
             )
             for i in range(len(ids_list)):
                 # Handle metadatas
-                metadata = {}
-                if response.get("metadatas") and len(response["metadatas"]) > 0:
-                    metadata_list = (
-                        response["metadatas"][0]
-                        if isinstance(response["metadatas"][0], list)
-                        else response["metadatas"]
-                    )
-                    if i < len(metadata_list):
-                        metadata = metadata_list[i]
+                meta_item = _extract_chromadb_response_item(
+                    response.get("metadatas"), i, dict
+                )
+                metadata: dict[str, Any] = meta_item if meta_item else {}
 
                 # Handle documents
-                context = ""
-                if response.get("documents") and len(response["documents"]) > 0:
-                    docs_list = (
-                        response["documents"][0]
-                        if isinstance(response["documents"][0], list)
-                        else response["documents"]
-                    )
-                    if i < len(docs_list):
-                        context = docs_list[i]
+                doc_item = _extract_chromadb_response_item(
+                    response.get("documents"), i, str
+                )
+                context = doc_item if doc_item else ""
 
                 # Handle distances
-                score = 1.0
-                if response.get("distances") and len(response["distances"]) > 0:
-                    dist_list = (
-                        response["distances"][0]
-                        if isinstance(response["distances"][0], list)
-                        else response["distances"]
-                    )
-                    if i < len(dist_list):
-                        score = dist_list[i]
+                dist_item = _extract_chromadb_response_item(
+                    response.get("distances"), i, (int, float)
+                )
+                score = dist_item if dist_item is not None else 1.0
 
                 result = {
                     "id": ids_list[i],
@@ -187,11 +213,22 @@ class RAGStorage(BaseRAGStorage):
             logging.error(f"Error during {self.type} search: {str(e)}")
             return []
 
-    def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None:  # type: ignore
+    def _generate_embedding(
+        self, text: str, metadata: dict[str, Any] | None = None
+    ) -> Any:
+        """Generates and stores the embedding for the given text and metadata.
+
+        Args:
+            text: The text to generate an embedding for.
+            metadata: Optional metadata associated with the text.
+
+        Notes:
+            - Need to constrain the typing in the base class; this result isn't used.
+        """
         if not hasattr(self, "app") or not hasattr(self, "collection"):
             self._initialize_app()
 
-        self.collection.add(
+        return self.collection.add(
             documents=[text],
             metadatas=[metadata or {}],
             ids=[str(uuid.uuid4())],
@@ -213,7 +250,8 @@ class RAGStorage(BaseRAGStorage):
                     f"An error occurred while resetting the {self.type} memory: {e}"
                 )
 
-    def _create_default_embedding_function(self) -> EmbeddingFunction[Any]:
+    @staticmethod
+    def _create_default_embedding_function() -> EmbeddingFunction[Any]:
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )
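Usage sketch (not part of the patch): a minimal example of how the new _extract_chromadb_response_item helper behaves, assuming the patched module is importable and using a hand-built dict that mimics ChromaDB's nested query() output. RAGStorage.search() relies on the None fallback to substitute its defaults ("" / {} / 1.0).

    from crewai.memory.storage.rag_storage import _extract_chromadb_response_item

    # ChromaDB's query() wraps each field in one inner list per query embedding.
    response = {
        "documents": [["first doc", "second doc"]],
        "metadatas": [[{"source": "a"}, {"source": "b"}]],
        "distances": [[0.12, 0.87]],
    }

    # Extract index 1 of each field, type-checked against the expected type.
    print(_extract_chromadb_response_item(response.get("documents"), 1, str))           # second doc
    print(_extract_chromadb_response_item(response.get("metadatas"), 1, dict))          # {'source': 'b'}
    print(_extract_chromadb_response_item(response.get("distances"), 1, (int, float)))  # 0.87

    # Out-of-range indices (or type mismatches) return None instead of raising.
    print(_extract_chromadb_response_item(response.get("documents"), 5, str))           # None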