fix: Remove OpenAI dependency for memory reset when using alternative LLMs

- Add environment variables for default embedding provider
- Support Ollama as default embedding provider
- Add tests for memory reset with different providers
- Update documentation

Fixes #2023

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-05 10:56:01 +00:00
parent 8017ab2dfd
commit 649414805d
4 changed files with 133 additions and 26 deletions

View File

@@ -6,12 +6,17 @@ import shutil
import uuid import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI from chromadb.api import ClientAPI, Collection
from chromadb.api.types import Documents, Embeddings, Metadatas
from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities import EmbeddingConfigurator from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
from crewai.utilities.exceptions.embedding_exceptions import (
EmbeddingConfigurationError,
EmbeddingInitializationError
)
@contextlib.contextmanager @contextlib.contextmanager
@@ -32,9 +37,17 @@ def suppress_logging(
class RAGStorage(BaseRAGStorage): class RAGStorage(BaseRAGStorage):
""" """RAG-based Storage implementation using ChromaDB for vector storage and retrieval.
Extends Storage to handle embeddings for memory entries, improving
search efficiency. This class extends BaseRAGStorage to handle embeddings for memory entries,
improving search efficiency through vector similarity.
Attributes:
app: ChromaDB client instance
collection: ChromaDB collection for storing embeddings
type: Type of memory storage
allow_reset: Whether memory reset is allowed
path: Custom storage path for the database
""" """
app: ClientAPI | None = None app: ClientAPI | None = None
@@ -59,15 +72,37 @@ class RAGStorage(BaseRAGStorage):
configurator = EmbeddingConfigurator() configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config) self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self): def _initialize_app(self) -> None:
"""Initialize the ChromaDB client and collection.
Raises:
RuntimeError: If ChromaDB client initialization fails
EmbeddingConfigurationError: If embedding configuration is invalid
EmbeddingInitializationError: If embedding function fails to initialize
"""
import chromadb import chromadb
from chromadb.config import Settings from chromadb.config import Settings
self._set_embedder_config() self._set_embedder_config()
chroma_client = chromadb.PersistentClient( try:
path=self.path if self.path else self.storage_file_name, chroma_client = chromadb.PersistentClient(
settings=Settings(allow_reset=self.allow_reset), path=self.path if self.path else self.storage_file_name,
) settings=Settings(allow_reset=self.allow_reset),
)
self.app = chroma_client
if not self.app:
raise RuntimeError("Failed to initialize ChromaDB client")
try:
self.collection = self.app.get_collection(
name=self.type, embedding_function=self.embedder_config
)
except Exception:
self.collection = self.app.create_collection(
name=self.type, embedding_function=self.embedder_config
)
except Exception as e:
raise RuntimeError(f"Failed to initialize ChromaDB: {str(e)}")
self.app = chroma_client self.app = chroma_client
if not self.app: if not self.app:
@@ -151,6 +186,12 @@ class RAGStorage(BaseRAGStorage):
) )
def reset(self) -> None: def reset(self) -> None:
"""Reset the memory storage by clearing the database and removing files.
Raises:
RuntimeError: If memory reset fails and allow_reset is False
EmbeddingConfigurationError: If embedding configuration is invalid during reinitialization
"""
try: try:
if self.app: if self.app:
self.app.reset() self.app.reset()
@@ -162,9 +203,9 @@ class RAGStorage(BaseRAGStorage):
self.collection = None self.collection = None
except Exception as e: except Exception as e:
if "attempt to write a readonly database" in str(e): if "attempt to write a readonly database" in str(e):
# Ignore this specific error # Ignore this specific error as it's expected in some environments
pass pass
else: else:
raise Exception( if not self.allow_reset:
f"An error occurred while resetting the {self.type} memory: {e}" raise RuntimeError(f"Failed to reset {self.type} memory: {str(e)}")
) logging.error(f"Error during {self.type} memory reset: {str(e)}")

View File

@@ -1,9 +1,15 @@
import os import os
from typing import Any, Dict, cast from typing import Any, Dict, List, Optional, cast
from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function from chromadb.api.types import validate_embedding_function
from crewai.utilities.exceptions.embedding_exceptions import (
EmbeddingConfigurationError,
EmbeddingProviderError,
EmbeddingInitializationError
)
class EmbeddingConfigurator: class EmbeddingConfigurator:
def __init__(self): def __init__(self):
@@ -21,9 +27,21 @@ class EmbeddingConfigurator:
def configure_embedder( def configure_embedder(
self, self,
embedder_config: Dict[str, Any] | None = None, embedder_config: Optional[Dict[str, Any]] = None,
) -> EmbeddingFunction: ) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config.""" """Configures and returns an embedding function based on the provided config.
Args:
embedder_config: Configuration dictionary containing provider and settings
Returns:
EmbeddingFunction: Configured embedding function for vector storage
Raises:
EmbeddingProviderError: If the provider is not supported
EmbeddingConfigurationError: If the configuration is invalid
EmbeddingInitializationError: If the embedding function fails to initialize
"""
if embedder_config is None: if embedder_config is None:
return self._create_default_embedding_function() return self._create_default_embedding_function()
@@ -36,11 +54,11 @@ class EmbeddingConfigurator:
validate_embedding_function(provider) validate_embedding_function(provider)
return provider return provider
except Exception as e: except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}") raise EmbeddingConfigurationError(f"Invalid custom embedding function: {str(e)}")
if not provider or provider not in self.embedding_functions: if not provider or provider not in self.embedding_functions:
raise Exception( raise EmbeddingProviderError(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" str(provider), list(self.embedding_functions.keys())
) )
return self.embedding_functions[str(provider)](config, model_name) return self.embedding_functions[str(provider)](config, model_name)
@@ -57,9 +75,10 @@ class EmbeddingConfigurator:
return OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model_name=model) return OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model_name=model)
elif provider == "ollama": elif provider == "ollama":
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
return OllamaEmbeddingFunction(url=os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings"), model_name=model) url = os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings")
return OllamaEmbeddingFunction(url=url, model_name=model)
else: else:
raise ValueError(f"Unsupported default embedding provider: {provider}. Set CREWAI_EMBEDDING_PROVIDER to 'openai' or 'ollama'") raise EmbeddingProviderError(provider, ["openai", "ollama"])
@staticmethod @staticmethod
def _configure_openai(config, model_name): def _configure_openai(config, model_name):
@@ -157,9 +176,10 @@ class EmbeddingConfigurator:
from ibm_watsonx_ai import Credentials from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
except ImportError as e: except ImportError as e:
raise ImportError( raise EmbeddingConfigurationError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding." "IBM Watson dependencies are not installed. Please install them to use Watson embedding.",
) from e provider="watson"
)
class WatsonEmbeddingFunction(EmbeddingFunction): class WatsonEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings: def __call__(self, input: Documents) -> Embeddings:
@@ -184,7 +204,6 @@ class EmbeddingConfigurator:
embeddings = embedding.embed_documents(input) embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings) return cast(Embeddings, embeddings)
except Exception as e: except Exception as e:
print("Error during Watson embedding:", e) raise EmbeddingInitializationError("watson", str(e))
raise e
return WatsonEmbeddingFunction() return WatsonEmbeddingFunction()

View File

@@ -0,0 +1,20 @@
from typing import List, Optional
class EmbeddingConfigurationError(Exception):
    """Base error for problems configuring an embedding function.

    Attributes:
        message: Human-readable description of the problem.
        provider: Name of the embedding provider involved, or None if unknown.
    """

    def __init__(self, message: str, provider: Optional[str] = None):
        self.message = message
        self.provider = provider
        super().__init__(self.message)


class EmbeddingProviderError(EmbeddingConfigurationError):
    """Raised when the requested embedding provider is not supported.

    Attributes:
        provider: The provider name that was requested.
        supported_providers: The provider names that would have been accepted.
    """

    def __init__(self, provider: str, supported_providers: List[str]):
        message = f"Unsupported embedding provider: {provider}, supported providers: {supported_providers}"
        super().__init__(message, provider)
        # Keep the structured data so handlers need not parse the message.
        self.supported_providers = list(supported_providers)

    def __reduce__(self):
        # Exception.args only holds the formatted message, which does not
        # match this constructor's signature; rebuild from the real arguments
        # so copy/pickle round-trips work.
        return (
            type(self),
            (self.provider, self.supported_providers),
            self.__dict__,
        )


class EmbeddingInitializationError(EmbeddingConfigurationError):
    """Raised when an embedding function fails to initialize.

    Attributes:
        provider: The provider whose embedding function failed.
        error: Text describing the underlying failure.
    """

    def __init__(self, provider: str, error: str):
        message = f"Failed to initialize embedding function for provider {provider}: {error}"
        super().__init__(message, provider)
        # Keep the raw failure text alongside the formatted message.
        self.error = error

    def __reduce__(self):
        # Same reason as EmbeddingProviderError.__reduce__: args holds only
        # the formatted message, not the constructor arguments.
        return (type(self), (self.provider, self.error), self.__dict__)

View File

@@ -2,6 +2,10 @@ import os
import tempfile import tempfile
import pytest import pytest
from crewai.memory import ShortTermMemory, LongTermMemory, EntityMemory from crewai.memory import ShortTermMemory, LongTermMemory, EntityMemory
from crewai.utilities.exceptions.embedding_exceptions import (
EmbeddingConfigurationError,
EmbeddingProviderError
)
from crewai.utilities import EmbeddingConfigurator from crewai.utilities import EmbeddingConfigurator
@pytest.fixture @pytest.fixture
@@ -32,3 +36,26 @@ def test_memory_reset_with_openai(temp_db_dir):
] ]
for memory in memories: for memory in memories:
memory.reset() memory.reset()
def test_memory_reset_with_invalid_provider(temp_db_dir):
    """An unsupported CREWAI_EMBEDDING_PROVIDER must raise EmbeddingProviderError.

    The environment variable is patched only for the duration of the test so
    the invalid value cannot leak into tests that run afterwards.
    """
    from unittest import mock

    with mock.patch.dict(
        os.environ, {"CREWAI_EMBEDDING_PROVIDER": "invalid_provider"}
    ):
        # The error may surface while building the memories or while
        # resetting them; either satisfies the expectation.
        with pytest.raises(EmbeddingProviderError):
            memories = [
                ShortTermMemory(path=temp_db_dir),
                LongTermMemory(path=temp_db_dir),
                EntityMemory(path=temp_db_dir),
            ]
            for memory in memories:
                memory.reset()
def test_memory_reset_with_missing_ollama_url(temp_db_dir):
    """Memory reset works with the ollama provider when CREWAI_OLLAMA_URL is unset.

    All environment changes happen inside ``mock.patch.dict`` so that both the
    provider override and the removed URL variable are restored when the test
    finishes.
    """
    from unittest import mock

    with mock.patch.dict(os.environ, {"CREWAI_EMBEDDING_PROVIDER": "ollama"}):
        # patch.dict restores the whole mapping on exit, including this key.
        os.environ.pop("CREWAI_OLLAMA_URL", None)
        # Should use default URL when CREWAI_OLLAMA_URL is not set
        memories = [
            ShortTermMemory(path=temp_db_dir),
            LongTermMemory(path=temp_db_dir),
            EntityMemory(path=temp_db_dir),
        ]
        for memory in memories:
            memory.reset()