mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
Enhance embedding configuration with custom embedder support
- Add support for custom embedding functions in EmbeddingConfigurator - Update type hints for embedder configuration - Extend configuration options for various embedding providers - Add optional embedder configuration to Memory class
This commit is contained in:
@@ -8,6 +8,8 @@ class Memory:
|
|||||||
Base class for memory, now supporting agent tags and generic metadata.
|
Base class for memory, now supporting agent tags and generic metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
embedder_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
def __init__(self, storage: RAGStorage):
|
def __init__(self, storage: RAGStorage):
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
|
|||||||
self,
|
self,
|
||||||
type: str,
|
type: str,
|
||||||
allow_reset: bool = True,
|
allow_reset: bool = True,
|
||||||
embedder_config: Optional[Any] = None,
|
embedder_config: Optional[Dict[str, Any]] = None,
|
||||||
crew: Any = None,
|
crew: Any = None,
|
||||||
):
|
):
|
||||||
self.type = type
|
self.type = type
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, cast
|
from typing import Any, Dict, Union, cast
|
||||||
|
|
||||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||||
from chromadb.api.types import validate_embedding_function
|
from chromadb.api.types import validate_embedding_function
|
||||||
@@ -18,11 +18,12 @@ class EmbeddingConfigurator:
|
|||||||
"bedrock": self._configure_bedrock,
|
"bedrock": self._configure_bedrock,
|
||||||
"huggingface": self._configure_huggingface,
|
"huggingface": self._configure_huggingface,
|
||||||
"watson": self._configure_watson,
|
"watson": self._configure_watson,
|
||||||
|
"custom": self._configure_custom,
|
||||||
}
|
}
|
||||||
|
|
||||||
def configure_embedder(
|
def configure_embedder(
|
||||||
self,
|
self,
|
||||||
embedder_config: Dict[str, Any] | None = None,
|
embedder_config: Union[Dict[str, Any], None] = None,
|
||||||
) -> EmbeddingFunction:
|
) -> EmbeddingFunction:
|
||||||
"""Configures and returns an embedding function based on the provided config."""
|
"""Configures and returns an embedding function based on the provided config."""
|
||||||
if embedder_config is None:
|
if embedder_config is None:
|
||||||
@@ -30,19 +31,13 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
provider = embedder_config.get("provider")
|
provider = embedder_config.get("provider")
|
||||||
config = embedder_config.get("config", {})
|
config = embedder_config.get("config", {})
|
||||||
model_name = config.get("model")
|
model_name = config.get("model") if provider != "custom" else None
|
||||||
|
|
||||||
if isinstance(provider, EmbeddingFunction):
|
|
||||||
try:
|
|
||||||
validate_embedding_function(provider)
|
|
||||||
return provider
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
|
||||||
|
|
||||||
if provider not in self.embedding_functions:
|
if provider not in self.embedding_functions:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.embedding_functions[provider](config, model_name)
|
return self.embedding_functions[provider](config, model_name)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -64,6 +59,13 @@ class EmbeddingConfigurator:
|
|||||||
return OpenAIEmbeddingFunction(
|
return OpenAIEmbeddingFunction(
|
||||||
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
api_base=config.get("api_base", None),
|
||||||
|
api_type=config.get("api_type", None),
|
||||||
|
api_version=config.get("api_version", None),
|
||||||
|
default_headers=config.get("default_headers", None),
|
||||||
|
dimensions=config.get("dimensions", None),
|
||||||
|
deployment_id=config.get("deployment_id", None),
|
||||||
|
organization_id=config.get("organization_id", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -78,6 +80,10 @@ class EmbeddingConfigurator:
|
|||||||
api_type=config.get("api_type", "azure"),
|
api_type=config.get("api_type", "azure"),
|
||||||
api_version=config.get("api_version"),
|
api_version=config.get("api_version"),
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
default_headers=config.get("default_headers"),
|
||||||
|
dimensions=config.get("dimensions"),
|
||||||
|
deployment_id=config.get("deployment_id"),
|
||||||
|
organization_id=config.get("organization_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -100,6 +106,8 @@ class EmbeddingConfigurator:
|
|||||||
return GoogleVertexEmbeddingFunction(
|
return GoogleVertexEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.get("api_key"),
|
||||||
|
project_id=config.get("project_id"),
|
||||||
|
region=config.get("region"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -111,6 +119,7 @@ class EmbeddingConfigurator:
|
|||||||
return GoogleGenerativeAiEmbeddingFunction(
|
return GoogleGenerativeAiEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.get("api_key"),
|
||||||
|
task_type=config.get("task_type"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -195,3 +204,28 @@ class EmbeddingConfigurator:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
return WatsonEmbeddingFunction()
|
return WatsonEmbeddingFunction()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_custom(config, model_name):
|
||||||
|
custom_embedder = config.get("embedder")
|
||||||
|
if isinstance(custom_embedder, EmbeddingFunction):
|
||||||
|
try:
|
||||||
|
validate_embedding_function(custom_embedder)
|
||||||
|
return custom_embedder
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||||
|
elif callable(custom_embedder):
|
||||||
|
try:
|
||||||
|
instance = custom_embedder()
|
||||||
|
if isinstance(instance, EmbeddingFunction):
|
||||||
|
validate_embedding_function(instance)
|
||||||
|
return instance
|
||||||
|
raise ValueError(
|
||||||
|
"Custom embedder does not create an EmbeddingFunction instance"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user