mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
fix: Remove OpenAI dependency for memory reset when using alternative LLMs
- Add environment variables for default embedding provider - Support Ollama as default embedding provider - Add tests for memory reset with different providers - Update documentation Fixes #2023 Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -6,12 +6,17 @@ import shutil
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from chromadb.api import ClientAPI
|
from chromadb.api import ClientAPI, Collection
|
||||||
|
from chromadb.api.types import Documents, Embeddings, Metadatas
|
||||||
|
|
||||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||||
from crewai.utilities import EmbeddingConfigurator
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
from crewai.utilities.exceptions.embedding_exceptions import (
|
||||||
|
EmbeddingConfigurationError,
|
||||||
|
EmbeddingInitializationError
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@@ -32,9 +37,17 @@ def suppress_logging(
|
|||||||
|
|
||||||
|
|
||||||
class RAGStorage(BaseRAGStorage):
|
class RAGStorage(BaseRAGStorage):
|
||||||
"""
|
"""RAG-based Storage implementation using ChromaDB for vector storage and retrieval.
|
||||||
Extends Storage to handle embeddings for memory entries, improving
|
|
||||||
search efficiency.
|
This class extends BaseRAGStorage to handle embeddings for memory entries,
|
||||||
|
improving search efficiency through vector similarity.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
app: ChromaDB client instance
|
||||||
|
collection: ChromaDB collection for storing embeddings
|
||||||
|
type: Type of memory storage
|
||||||
|
allow_reset: Whether memory reset is allowed
|
||||||
|
path: Custom storage path for the database
|
||||||
"""
|
"""
|
||||||
|
|
||||||
app: ClientAPI | None = None
|
app: ClientAPI | None = None
|
||||||
@@ -59,15 +72,37 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
configurator = EmbeddingConfigurator()
|
configurator = EmbeddingConfigurator()
|
||||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||||
|
|
||||||
def _initialize_app(self):
|
def _initialize_app(self) -> None:
|
||||||
|
"""Initialize the ChromaDB client and collection.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If ChromaDB client initialization fails
|
||||||
|
EmbeddingConfigurationError: If embedding configuration is invalid
|
||||||
|
EmbeddingInitializationError: If embedding function fails to initialize
|
||||||
|
"""
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
|
|
||||||
self._set_embedder_config()
|
self._set_embedder_config()
|
||||||
chroma_client = chromadb.PersistentClient(
|
try:
|
||||||
path=self.path if self.path else self.storage_file_name,
|
chroma_client = chromadb.PersistentClient(
|
||||||
settings=Settings(allow_reset=self.allow_reset),
|
path=self.path if self.path else self.storage_file_name,
|
||||||
)
|
settings=Settings(allow_reset=self.allow_reset),
|
||||||
|
)
|
||||||
|
self.app = chroma_client
|
||||||
|
if not self.app:
|
||||||
|
raise RuntimeError("Failed to initialize ChromaDB client")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.collection = self.app.get_collection(
|
||||||
|
name=self.type, embedding_function=self.embedder_config
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
self.collection = self.app.create_collection(
|
||||||
|
name=self.type, embedding_function=self.embedder_config
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to initialize ChromaDB: {str(e)}")
|
||||||
|
|
||||||
self.app = chroma_client
|
self.app = chroma_client
|
||||||
if not self.app:
|
if not self.app:
|
||||||
@@ -151,6 +186,12 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
"""Reset the memory storage by clearing the database and removing files.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If memory reset fails and allow_reset is False
|
||||||
|
EmbeddingConfigurationError: If embedding configuration is invalid during reinitialization
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
if self.app:
|
if self.app:
|
||||||
self.app.reset()
|
self.app.reset()
|
||||||
@@ -162,9 +203,9 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self.collection = None
|
self.collection = None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "attempt to write a readonly database" in str(e):
|
if "attempt to write a readonly database" in str(e):
|
||||||
# Ignore this specific error
|
# Ignore this specific error as it's expected in some environments
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
if not self.allow_reset:
|
||||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
raise RuntimeError(f"Failed to reset {self.type} memory: {str(e)}")
|
||||||
)
|
logging.error(f"Error during {self.type} memory reset: {str(e)}")
|
||||||
|
|||||||
@@ -1,9 +1,15 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, cast
|
from typing import Any, Dict, List, Optional, cast
|
||||||
|
|
||||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||||
from chromadb.api.types import validate_embedding_function
|
from chromadb.api.types import validate_embedding_function
|
||||||
|
|
||||||
|
from crewai.utilities.exceptions.embedding_exceptions import (
|
||||||
|
EmbeddingConfigurationError,
|
||||||
|
EmbeddingProviderError,
|
||||||
|
EmbeddingInitializationError
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfigurator:
|
class EmbeddingConfigurator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -21,9 +27,21 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
def configure_embedder(
|
def configure_embedder(
|
||||||
self,
|
self,
|
||||||
embedder_config: Dict[str, Any] | None = None,
|
embedder_config: Optional[Dict[str, Any]] = None,
|
||||||
) -> EmbeddingFunction:
|
) -> EmbeddingFunction:
|
||||||
"""Configures and returns an embedding function based on the provided config."""
|
"""Configures and returns an embedding function based on the provided config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedder_config: Configuration dictionary containing provider and settings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingFunction: Configured embedding function for vector storage
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
EmbeddingProviderError: If the provider is not supported
|
||||||
|
EmbeddingConfigurationError: If the configuration is invalid
|
||||||
|
EmbeddingInitializationError: If the embedding function fails to initialize
|
||||||
|
"""
|
||||||
if embedder_config is None:
|
if embedder_config is None:
|
||||||
return self._create_default_embedding_function()
|
return self._create_default_embedding_function()
|
||||||
|
|
||||||
@@ -36,11 +54,11 @@ class EmbeddingConfigurator:
|
|||||||
validate_embedding_function(provider)
|
validate_embedding_function(provider)
|
||||||
return provider
|
return provider
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
raise EmbeddingConfigurationError(f"Invalid custom embedding function: {str(e)}")
|
||||||
|
|
||||||
if not provider or provider not in self.embedding_functions:
|
if not provider or provider not in self.embedding_functions:
|
||||||
raise Exception(
|
raise EmbeddingProviderError(
|
||||||
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
str(provider), list(self.embedding_functions.keys())
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.embedding_functions[str(provider)](config, model_name)
|
return self.embedding_functions[str(provider)](config, model_name)
|
||||||
@@ -57,9 +75,10 @@ class EmbeddingConfigurator:
|
|||||||
return OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model_name=model)
|
return OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model_name=model)
|
||||||
elif provider == "ollama":
|
elif provider == "ollama":
|
||||||
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
|
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
|
||||||
return OllamaEmbeddingFunction(url=os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings"), model_name=model)
|
url = os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings")
|
||||||
|
return OllamaEmbeddingFunction(url=url, model_name=model)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported default embedding provider: {provider}. Set CREWAI_EMBEDDING_PROVIDER to 'openai' or 'ollama'")
|
raise EmbeddingProviderError(provider, ["openai", "ollama"])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_openai(config, model_name):
|
def _configure_openai(config, model_name):
|
||||||
@@ -157,9 +176,10 @@ class EmbeddingConfigurator:
|
|||||||
from ibm_watsonx_ai import Credentials
|
from ibm_watsonx_ai import Credentials
|
||||||
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
|
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise EmbeddingConfigurationError(
|
||||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
"IBM Watson dependencies are not installed. Please install them to use Watson embedding.",
|
||||||
) from e
|
provider="watson"
|
||||||
|
)
|
||||||
|
|
||||||
class WatsonEmbeddingFunction(EmbeddingFunction):
|
class WatsonEmbeddingFunction(EmbeddingFunction):
|
||||||
def __call__(self, input: Documents) -> Embeddings:
|
def __call__(self, input: Documents) -> Embeddings:
|
||||||
@@ -184,7 +204,6 @@ class EmbeddingConfigurator:
|
|||||||
embeddings = embedding.embed_documents(input)
|
embeddings = embedding.embed_documents(input)
|
||||||
return cast(Embeddings, embeddings)
|
return cast(Embeddings, embeddings)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error during Watson embedding:", e)
|
raise EmbeddingInitializationError("watson", str(e))
|
||||||
raise e
|
|
||||||
|
|
||||||
return WatsonEmbeddingFunction()
|
return WatsonEmbeddingFunction()
|
||||||
|
|||||||
20
src/crewai/utilities/exceptions/embedding_exceptions.py
Normal file
20
src/crewai/utilities/exceptions/embedding_exceptions.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingConfigurationError(Exception):
|
||||||
|
def __init__(self, message: str, provider: Optional[str] = None):
|
||||||
|
self.message = message
|
||||||
|
self.provider = provider
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProviderError(EmbeddingConfigurationError):
|
||||||
|
def __init__(self, provider: str, supported_providers: List[str]):
|
||||||
|
message = f"Unsupported embedding provider: {provider}, supported providers: {supported_providers}"
|
||||||
|
super().__init__(message, provider)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingInitializationError(EmbeddingConfigurationError):
|
||||||
|
def __init__(self, provider: str, error: str):
|
||||||
|
message = f"Failed to initialize embedding function for provider {provider}: {error}"
|
||||||
|
super().__init__(message, provider)
|
||||||
@@ -2,6 +2,10 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import pytest
|
import pytest
|
||||||
from crewai.memory import ShortTermMemory, LongTermMemory, EntityMemory
|
from crewai.memory import ShortTermMemory, LongTermMemory, EntityMemory
|
||||||
|
from crewai.utilities.exceptions.embedding_exceptions import (
|
||||||
|
EmbeddingConfigurationError,
|
||||||
|
EmbeddingProviderError
|
||||||
|
)
|
||||||
from crewai.utilities import EmbeddingConfigurator
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -32,3 +36,26 @@ def test_memory_reset_with_openai(temp_db_dir):
|
|||||||
]
|
]
|
||||||
for memory in memories:
|
for memory in memories:
|
||||||
memory.reset()
|
memory.reset()
|
||||||
|
|
||||||
|
def test_memory_reset_with_invalid_provider(temp_db_dir):
|
||||||
|
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "invalid_provider"
|
||||||
|
with pytest.raises(EmbeddingProviderError):
|
||||||
|
memories = [
|
||||||
|
ShortTermMemory(path=temp_db_dir),
|
||||||
|
LongTermMemory(path=temp_db_dir),
|
||||||
|
EntityMemory(path=temp_db_dir)
|
||||||
|
]
|
||||||
|
for memory in memories:
|
||||||
|
memory.reset()
|
||||||
|
|
||||||
|
def test_memory_reset_with_missing_ollama_url(temp_db_dir):
|
||||||
|
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
|
||||||
|
os.environ.pop("CREWAI_OLLAMA_URL", None)
|
||||||
|
# Should use default URL when CREWAI_OLLAMA_URL is not set
|
||||||
|
memories = [
|
||||||
|
ShortTermMemory(path=temp_db_dir),
|
||||||
|
LongTermMemory(path=temp_db_dir),
|
||||||
|
EntityMemory(path=temp_db_dir)
|
||||||
|
]
|
||||||
|
for memory in memories:
|
||||||
|
memory.reset()
|
||||||
|
|||||||
Reference in New Issue
Block a user