mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 16:22:49 +00:00
fix: Remove OpenAI dependency for memory reset when using alternative LLMs
- Add environment variables for default embedding provider - Support Ollama as default embedding provider - Add tests for memory reset with different providers - Update documentation Fixes #2023 Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1,9 +1,15 @@
|
||||
import os
|
||||
from typing import Any, Dict, cast
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
|
||||
from crewai.utilities.exceptions.embedding_exceptions import (
|
||||
EmbeddingConfigurationError,
|
||||
EmbeddingProviderError,
|
||||
EmbeddingInitializationError
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingConfigurator:
|
||||
def __init__(self):
|
||||
@@ -21,9 +27,21 @@ class EmbeddingConfigurator:
|
||||
|
||||
def configure_embedder(
|
||||
self,
|
||||
embedder_config: Dict[str, Any] | None = None,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Configures and returns an embedding function based on the provided config."""
|
||||
"""Configures and returns an embedding function based on the provided config.
|
||||
|
||||
Args:
|
||||
embedder_config: Configuration dictionary containing provider and settings
|
||||
|
||||
Returns:
|
||||
EmbeddingFunction: Configured embedding function for vector storage
|
||||
|
||||
Raises:
|
||||
EmbeddingProviderError: If the provider is not supported
|
||||
EmbeddingConfigurationError: If the configuration is invalid
|
||||
EmbeddingInitializationError: If the embedding function fails to initialize
|
||||
"""
|
||||
if embedder_config is None:
|
||||
return self._create_default_embedding_function()
|
||||
|
||||
@@ -36,11 +54,11 @@ class EmbeddingConfigurator:
|
||||
validate_embedding_function(provider)
|
||||
return provider
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||
raise EmbeddingConfigurationError(f"Invalid custom embedding function: {str(e)}")
|
||||
|
||||
if not provider or provider not in self.embedding_functions:
|
||||
raise Exception(
|
||||
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
||||
raise EmbeddingProviderError(
|
||||
str(provider), list(self.embedding_functions.keys())
|
||||
)
|
||||
|
||||
return self.embedding_functions[str(provider)](config, model_name)
|
||||
@@ -57,9 +75,10 @@ class EmbeddingConfigurator:
|
||||
return OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model_name=model)
|
||||
elif provider == "ollama":
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
|
||||
return OllamaEmbeddingFunction(url=os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings"), model_name=model)
|
||||
url = os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings")
|
||||
return OllamaEmbeddingFunction(url=url, model_name=model)
|
||||
else:
|
||||
raise ValueError(f"Unsupported default embedding provider: {provider}. Set CREWAI_EMBEDDING_PROVIDER to 'openai' or 'ollama'")
|
||||
raise EmbeddingProviderError(provider, ["openai", "ollama"])
|
||||
|
||||
@staticmethod
|
||||
def _configure_openai(config, model_name):
|
||||
@@ -157,9 +176,10 @@ class EmbeddingConfigurator:
|
||||
from ibm_watsonx_ai import Credentials
|
||||
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||
) from e
|
||||
raise EmbeddingConfigurationError(
|
||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding.",
|
||||
provider="watson"
|
||||
)
|
||||
|
||||
class WatsonEmbeddingFunction(EmbeddingFunction):
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
@@ -184,7 +204,6 @@ class EmbeddingConfigurator:
|
||||
embeddings = embedding.embed_documents(input)
|
||||
return cast(Embeddings, embeddings)
|
||||
except Exception as e:
|
||||
print("Error during Watson embedding:", e)
|
||||
raise e
|
||||
raise EmbeddingInitializationError("watson", str(e))
|
||||
|
||||
return WatsonEmbeddingFunction()
|
||||
|
||||
20
src/crewai/utilities/exceptions/embedding_exceptions.py
Normal file
20
src/crewai/utilities/exceptions/embedding_exceptions.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class EmbeddingConfigurationError(Exception):
|
||||
def __init__(self, message: str, provider: Optional[str] = None):
|
||||
self.message = message
|
||||
self.provider = provider
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class EmbeddingProviderError(EmbeddingConfigurationError):
|
||||
def __init__(self, provider: str, supported_providers: List[str]):
|
||||
message = f"Unsupported embedding provider: {provider}, supported providers: {supported_providers}"
|
||||
super().__init__(message, provider)
|
||||
|
||||
|
||||
class EmbeddingInitializationError(EmbeddingConfigurationError):
|
||||
def __init__(self, provider: str, error: str):
|
||||
message = f"Failed to initialize embedding function for provider {provider}: {error}"
|
||||
super().__init__(message, provider)
|
||||
Reference in New Issue
Block a user