diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index db84a90f7..7e22e5a94 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -4,7 +4,7 @@ import io import logging import os import shutil -from typing import Any, Dict, List, Optional, Union, cast, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast # Type checking imports that don't cause runtime imports if TYPE_CHECKING: diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index 14ae63440..ddb2c021f 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -4,15 +4,18 @@ import logging import os import shutil import uuid -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional # Type checking imports that don't cause runtime imports if TYPE_CHECKING: + import chromadb from chromadb.api import ClientAPI + from chromadb.config import Settings from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.utilities import EmbeddingConfigurator -from crewai.utilities.constants import MAX_FILE_NAME_LENGTH +from crewai.utilities.chromadb import sanitize_collection_name +from crewai.utilities.logger import Logger from crewai.utilities.paths import db_storage_path @@ -39,77 +42,45 @@ class RAGStorage(BaseRAGStorage): search efficiency. """ + collection: Optional[Any] = None + collection_name: Optional[str] = "memory" app: Optional[Any] = None def __init__( - self, type, allow_reset=True, embedder_config=None, crew=None, path=None + self, + type: str = "memory", + allow_reset: bool = True, + embedder_config: Optional[Dict[str, Any]] = None, + crew: Any = None, + collection_name: Optional[str] = None, ): super().__init__(type, allow_reset, embedder_config, crew) - agents = crew.agents if crew else [] - agents = [self._sanitize_role(agent.role) for agent in agents] - agents = "_".join(agents) - self.agents = agents - self.storage_file_name = self._build_storage_file_name(type, agents) + self.collection_name = collection_name or type + self._set_embedder_config(embedder_config) - self.type = type + def save( + self, + value: Any, + metadata: Dict[str, Any], + ) -> None: + with suppress_logging(): + if not self.collection: + self._initialize_app() - self.allow_reset = allow_reset - self.path = path - self._initialize_app() + if isinstance(value, list): + documents = value + metadatas = [metadata] * len(value) if metadata else None + ids = [str(uuid.uuid4()) for _ in range(len(documents))] + else: + documents = [value] + metadatas = [metadata] if metadata else None + ids = [str(uuid.uuid4())] - def _set_embedder_config(self): - configurator = EmbeddingConfigurator() - self.embedder_config = configurator.configure_embedder(self.embedder_config) - - def _initialize_app(self): - # Import chromadb here to avoid importing at module level - import chromadb - from chromadb.config import Settings - - self._set_embedder_config() - chroma_client = chromadb.PersistentClient( - path=self.path if self.path else self.storage_file_name, - settings=Settings(allow_reset=self.allow_reset), - ) - - self.app = chroma_client - - try: - self.collection = self.app.get_collection( - name=self.type, embedding_function=self.embedder_config + self.collection.add( + documents=documents, + metadatas=metadatas, + ids=ids, ) - except Exception: - self.collection = self.app.create_collection( - name=self.type, embedding_function=self.embedder_config - ) - - def _sanitize_role(self, role: str) -> str: - """ - Sanitizes agent roles to ensure valid directory names. - """ - return role.replace("\n", "").replace(" ", "_").replace("/", "_") - - def _build_storage_file_name(self, type: str, file_name: str) -> str: - """ - Ensures file name does not exceed max allowed by OS - """ - base_path = f"{db_storage_path()}/{type}" - - if len(file_name) > MAX_FILE_NAME_LENGTH: - logging.warning( - f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters." - ) - file_name = file_name[:MAX_FILE_NAME_LENGTH] - - return f"{base_path}/{file_name}" - - def save(self, value: Any, metadata: Dict[str, Any]) -> None: - if not hasattr(self, "app") or not hasattr(self, "collection"): - self._initialize_app() - try: - self._generate_embedding(value, metadata) - except Exception as e: - logging.error(f"Error during {self.type} save: {str(e)}") def search( self, @@ -118,54 +89,96 @@ class RAGStorage(BaseRAGStorage): filter: Optional[dict] = None, score_threshold: float = 0.35, ) -> List[Any]: - if not hasattr(self, "app"): - self._initialize_app() + with suppress_logging(): + if not hasattr(self, "collection") or not self.collection: + self._initialize_app() - try: - with suppress_logging(): - response = self.collection.query(query_texts=query, n_results=limit) + if isinstance(query, str): + query = [query] + fetched = self.collection.query( + query_texts=query, + n_results=limit, + where=filter, + ) results = [] - for i in range(len(response["ids"][0])): + for i in range(len(fetched["ids"][0])): # type: ignore result = { - "id": response["ids"][0][i], - "metadata": response["metadatas"][0][i], - "context": response["documents"][0][i], - "score": response["distances"][0][i], + "id": fetched["ids"][0][i], # type: ignore + "metadata": fetched["metadatas"][0][i], # type: ignore + "context": fetched["documents"][0][i], # type: ignore + "score": fetched["distances"][0][i], # type: ignore } if result["score"] >= score_threshold: results.append(result) - return results - except Exception as e: - logging.error(f"Error during {self.type} search: {str(e)}") - return [] - def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore - if not hasattr(self, "app") or not hasattr(self, "collection"): + def _initialize_app(self): + # Import chromadb here to avoid importing at module level + import chromadb + from chromadb.config import Settings + + base_path = os.path.join(db_storage_path(), "memory") + chroma_client = chromadb.PersistentClient( + path=base_path, + settings=Settings(allow_reset=self.allow_reset), + ) + + self.app = chroma_client + + try: + collection_name = ( + f"memory_{self.collection_name}" + if self.collection_name + else "memory" + ) + if self.app: + self.collection = self.app.get_or_create_collection( + name=sanitize_collection_name(collection_name), + embedding_function=self.embedder, + ) + else: + raise Exception("Vector Database Client not initialized") + except Exception: + raise Exception("Failed to create or get collection") + + def initialize_rag_storage(self): + self._initialize_app() + + def reset(self) -> None: + # Import chromadb here to avoid importing at module level + import chromadb + from chromadb.config import Settings + + base_path = os.path.join(db_storage_path(), "memory") + if not self.app: + self.app = chromadb.PersistentClient( + path=base_path, + settings=Settings(allow_reset=True), + ) + + self.app.reset() + shutil.rmtree(base_path) + self.app = None + self.collection = None + + def _generate_embedding( + self, text: str, metadata: Optional[Dict[str, Any]] = None + ) -> Any: + if not hasattr(self, "collection") or not self.collection: self._initialize_app() + id = str(uuid.uuid4()) self.collection.add( documents=[text], metadatas=[metadata or {}], - ids=[str(uuid.uuid4())], + ids=[id], ) + return id - def reset(self) -> None: - try: - if self.app: - self.app.reset() - shutil.rmtree(f"{db_storage_path()}/{self.type}") - self.app = None - self.collection = None - except Exception as e: - if "attempt to write a readonly database" in str(e): - # Ignore this specific error - pass - else: - raise Exception( - f"An error occurred while resetting the {self.type} memory: {e}" - ) + def _sanitize_role(self, role: str) -> str: + """Sanitize role name for use in file names.""" + return role.lower().replace(" ", "_").replace("\n", "").replace("/", "_") def _create_default_embedding_function(self): from chromadb.utils.embedding_functions.openai_embedding_function import ( @@ -175,3 +188,20 @@ class RAGStorage(BaseRAGStorage): return OpenAIEmbeddingFunction( api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" ) + + def _set_embedder_config(self, embedder_config: Optional[Dict[str, Any]] = None) -> None: + """Set the embedding configuration for the RAG storage. + + Args: + embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder. + If None or empty, defaults to the default embedding function. + """ + self.embedder = ( + EmbeddingConfigurator().configure_embedder(embedder_config) + if embedder_config + else self._create_default_embedding_function() + ) + + def _build_storage_file_name(self, role_name: str) -> str: + """Build storage file name from role name.""" + return f"{self._sanitize_role(role_name)}_memory" diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index 83720a958..cbaa954eb 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, Optional, cast, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Optional, cast # Type checking imports that don't cause runtime imports if TYPE_CHECKING: @@ -189,7 +189,7 @@ class EmbeddingConfigurator: ) from e # Import chromadb types here to avoid importing at module level - from chromadb import Documents, Embeddings, EmbeddingFunction + from chromadb import Documents, EmbeddingFunction, Embeddings class WatsonEmbeddingFunction(EmbeddingFunction): def __call__(self, input: Documents) -> Embeddings: diff --git a/tests/utilities/test_embedding_configurator.py b/tests/utilities/test_embedding_configurator.py index f38d2a3a9..37a0f2356 100644 --- a/tests/utilities/test_embedding_configurator.py +++ b/tests/utilities/test_embedding_configurator.py @@ -1,8 +1,10 @@ import importlib import sys -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch + import pytest + class TestEmbeddingConfiguratorImports: """Test that ChromaDB is not imported at module level."""