fix: support nested config format for embedder configuration

- support nested config format via the EmbedderConfig TypedDict (see the sketch below)
- fix parsing so both "model" and "model_name" keys are accepted
- add validation, a typing_extensions import, and improved type hints
- enhance the embedding factory with env var injection and a module-level provider registry
- add tests for OpenAI, Azure, and all embedding providers
- misc fixes: test file rename, updated mocking patterns
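
A minimal sketch of the nested format, assuming OPENAI_API_KEY is exported; the call site below is illustrative, not taken from the diff:

    from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function

    # Provider at the top level; provider-specific options nested under "config".
    # "model" is aliased to "model_name", and a missing api_key falls back to OPENAI_API_KEY.
    embedder: EmbedderConfig = {
        "provider": "openai",
        "config": {"model": "text-embedding-3-small"},
    }
    embedding_function = get_embedding_function(embedder)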
Author: Greyson LaLonde
Date: 2025-09-23 11:57:46 -04:00
Committed by: GitHub
Parent: 3e97393f58
Commit: 4ac65eb0a6
7 changed files with 923 additions and 296 deletions


@@ -7,7 +7,8 @@ from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
from crewai.rag.embeddings.types import EmbeddingOptions
from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.types import BaseRecord
@@ -25,7 +26,7 @@ class RAGStorage(BaseRAGStorage):
self,
type: str,
allow_reset: bool = True,
embedder_config: dict[str, Any] | None = None,
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
crew: Any = None,
path: str | None = None,
) -> None:
@@ -50,6 +51,21 @@ class RAGStorage(BaseRAGStorage):
if self.embedder_config:
embedding_function = get_embedding_function(self.embedder_config)
try:
_ = embedding_function(["test"])
except Exception as e:
provider = (
self.embedder_config.provider
if isinstance(self.embedder_config, EmbeddingOptions)
else self.embedder_config.get("provider", "unknown")
)
raise ValueError(
f"Failed to initialize embedder. Please check your configuration or connection.\n"
f"Provider: {provider}\n"
f"Error: {e}"
) from e
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
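
A brief sketch of the failure path added above; the import path, storage type, and key value are assumptions, not taken from this diff:

    from crewai.rag.storage.rag_storage import RAGStorage  # module path assumed

    try:
        storage = RAGStorage(
            type="short_term",  # hypothetical collection type
            embedder_config={
                "provider": "openai",
                "config": {"api_key": "not-a-real-key"},
            },
        )
    except ValueError as err:
        # The probe call embedding_function(["test"]) failed, so __init__ re-raises a
        # ValueError that names the provider ("openai") and wraps the original error.
        print(err)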


@@ -1,6 +1,8 @@
"""Minimal embedding function factory for CrewAI."""
import os
from collections.abc import Callable, MutableMapping
from typing import Any, Final, Literal, TypedDict
from chromadb import EmbeddingFunction
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
@@ -42,19 +44,116 @@ from chromadb.utils.embedding_functions.sentence_transformer_embedding_function
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from typing_extensions import NotRequired
from crewai.rag.embeddings.types import EmbeddingOptions
AllowedEmbeddingProviders = Literal[
"openai",
"cohere",
"ollama",
"huggingface",
"sentence-transformer",
"instructor",
"google-palm",
"google-generativeai",
"google-vertex",
"amazon-bedrock",
"jina",
"roboflow",
"openclip",
"text2vec",
"onnx",
]
class EmbedderConfig(TypedDict):
"""Configuration for embedding functions with nested format."""
provider: AllowedEmbeddingProviders
config: NotRequired[dict[str, Any]]
EMBEDDING_PROVIDERS: Final[
dict[AllowedEmbeddingProviders, Callable[..., EmbeddingFunction]]
] = {
"openai": OpenAIEmbeddingFunction,
"cohere": CohereEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
"huggingface": HuggingFaceEmbeddingFunction,
"sentence-transformer": SentenceTransformerEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"google-palm": GooglePalmEmbeddingFunction,
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
"google-vertex": GoogleVertexEmbeddingFunction,
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"roboflow": RoboflowEmbeddingFunction,
"openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2,
}
PROVIDER_ENV_MAPPING: Final[dict[AllowedEmbeddingProviders, tuple[str, str]]] = {
"openai": ("OPENAI_API_KEY", "api_key"),
"cohere": ("COHERE_API_KEY", "api_key"),
"huggingface": ("HUGGINGFACE_API_KEY", "api_key"),
"google-palm": ("GOOGLE_API_KEY", "api_key"),
"google-generativeai": ("GOOGLE_API_KEY", "api_key"),
"google-vertex": ("GOOGLE_API_KEY", "api_key"),
"jina": ("JINA_API_KEY", "api_key"),
"roboflow": ("ROBOFLOW_API_KEY", "api_key"),
}
def _inject_api_key_from_env(
provider: AllowedEmbeddingProviders, config_dict: MutableMapping[str, Any]
) -> None:
"""Inject API key or other required configuration from environment if not explicitly provided.
Args:
provider: The embedding provider name
config_dict: The configuration dictionary to modify in-place
Raises:
ImportError: If required libraries for certain providers are not installed
ValueError: If AWS session creation fails for amazon-bedrock
"""
if provider in PROVIDER_ENV_MAPPING:
env_var_name, config_key = PROVIDER_ENV_MAPPING[provider]
if config_key not in config_dict:
env_value = os.getenv(env_var_name)
if env_value:
config_dict[config_key] = env_value
if provider == "amazon-bedrock":
if "session" not in config_dict:
try:
import boto3 # type: ignore[import]
config_dict["session"] = boto3.Session()
except ImportError as e:
raise ImportError(
"boto3 is required for amazon-bedrock embeddings. "
"Install it with: uv add boto3"
) from e
except Exception as e:
raise ValueError(
f"Failed to create AWS session for amazon-bedrock. "
f"Ensure AWS credentials are configured. Error: {e}"
) from e
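# Illustrative aside, not part of this change: with OPENAI_API_KEY exported and no
# explicit key in the nested config, the helper above fills it in, e.g.
#   cfg = {"model_name": "text-embedding-3-small"}
#   _inject_api_key_from_env("openai", cfg)   # cfg now also carries api_key from the env
# An explicit api_key in the config always wins; the environment is only a fallback.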
def get_embedding_function(
config: EmbeddingOptions | dict | None = None,
config: EmbeddingOptions | EmbedderConfig | None = None,
) -> EmbeddingFunction:
"""Get embedding function - delegates to ChromaDB.
Args:
config: Optional configuration - either an EmbeddingOptions object or a dict with:
- provider: The embedding provider to use (default: "openai")
- Any other provider-specific parameters
config: Optional configuration - either:
- EmbeddingOptions: Pydantic model with flat configuration
- EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict}
- None: Uses default OpenAI configuration
Returns:
EmbeddingFunction instance ready for use with ChromaDB
@@ -81,31 +180,33 @@ def get_embedding_function(
>>> embedder = get_embedding_function()
# Use Cohere with dict
>>> embedder = get_embedding_function({
>>> embedder = get_embedding_function(EmbedderConfig(**{
... "provider": "cohere",
... "api_key": "your-key",
... "model_name": "embed-english-v3.0"
... })
... "config": {
... "api_key": "your-key",
... "model_name": "embed-english-v3.0"
... }
... }))
# Use with EmbeddingOptions
>>> embedder = get_embedding_function(
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
... )
# Use local sentence transformers (no API key needed)
>>> embedder = get_embedding_function({
... "provider": "sentence-transformer",
... "model_name": "all-MiniLM-L6-v2"
# Use Azure OpenAI
>>> embedder = get_embedding_function(EmbedderConfig(**{
... "provider": "openai",
... "config": {
... "api_key": "your-azure-key",
... "api_base": "https://your-resource.openai.azure.com/",
... "api_type": "azure",
... "api_version": "2023-05-15",
... "model": "text-embedding-3-small",
... "deployment_id": "your-deployment-name"
... }
... })
# Use Ollama for local embeddings
>>> embedder = get_embedding_function({
... "provider": "ollama",
... "model_name": "nomic-embed-text"
... })
# Use ONNX (no API key needed)
>>> embedder = get_embedding_function({
>>> embedder = get_embedding_function(EmbedderConfig(**{
... "provider": "onnx"
... })
"""
@@ -114,35 +215,33 @@ def get_embedding_function(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
# Handle EmbeddingOptions object
provider: AllowedEmbeddingProviders
config_dict: dict[str, Any]
if isinstance(config, EmbeddingOptions):
config_dict = config.model_dump(exclude_none=True)
provider = config_dict["provider"]
else:
config_dict = config.copy()
provider = config["provider"]
nested: dict[str, Any] = config.get("config", {})
provider = config_dict.pop("provider", "openai")
if not nested and len(config) > 1:
raise ValueError(
"Invalid embedder configuration format. "
"Configuration must be nested under a 'config' key. "
"Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}"
)
embedding_functions = {
"openai": OpenAIEmbeddingFunction,
"cohere": CohereEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
"huggingface": HuggingFaceEmbeddingFunction,
"sentence-transformer": SentenceTransformerEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"google-palm": GooglePalmEmbeddingFunction,
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
"google-vertex": GoogleVertexEmbeddingFunction,
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"roboflow": RoboflowEmbeddingFunction,
"openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2,
}
config_dict = dict(nested)
if "model" in config_dict and "model_name" not in config_dict:
config_dict["model_name"] = config_dict.pop("model")
if provider not in embedding_functions:
if provider not in EMBEDDING_PROVIDERS:
raise ValueError(
f"Unsupported provider: {provider}. "
f"Available providers: {list(embedding_functions.keys())}"
f"Available providers: {list(EMBEDDING_PROVIDERS.keys())}"
)
return embedding_functions[provider](**config_dict)
_inject_api_key_from_env(provider, config_dict)
return EMBEDDING_PROVIDERS[provider](**config_dict)
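
Taken together, the new factory path resolves a nested config as sketched below; the model and key values are illustrative, and COHERE_API_KEY is assumed to be exported:

    from crewai.rag.embeddings.factory import get_embedding_function

    # "model" is rewritten to "model_name" before the provider class is constructed,
    # the Cohere key is injected from COHERE_API_KEY, and an unsupported provider
    # raises ValueError listing the EMBEDDING_PROVIDERS keys.
    embedder = get_embedding_function({
        "provider": "cohere",
        "config": {"model": "embed-english-v3.0"},
    })
    vectors = embedder(["hello world"])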


@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any
from crewai.rag.embeddings.factory import EmbedderConfig
from crewai.rag.embeddings.types import EmbeddingOptions
class BaseRAGStorage(ABC):
"""
@@ -13,7 +16,7 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: dict[str, Any] | None = None,
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
crew: Any = None,
):
self.type = type
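
The widened embedder_config annotation accepts both configuration styles; a brief illustration, reusing values from the factory docstring above:

    from crewai.rag.embeddings.factory import EmbedderConfig
    from crewai.rag.embeddings.types import EmbeddingOptions

    # Flat, validated Pydantic form...
    flat = EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")

    # ...or the nested TypedDict form ("config" is optional / NotRequired).
    nested: EmbedderConfig = {"provider": "onnx"}

Either value can be passed as embedder_config when subclassing BaseRAGStorage.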