fix: support nested config format for embedder configuration
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled

- support nested config format with embedderconfig typeddict  
- fix parsing for model/model_name compatibility  
- add validation, typing_extensions, and improved type hints  
- enhance embedding factory with env var injection and provider support  
- add tests for openai, azure, and all embedding providers  
- misc fixes: test file rename, updated mocking patterns
This commit is contained in:
Greyson LaLonde
2025-09-23 11:57:46 -04:00
committed by GitHub
parent 3e97393f58
commit 4ac65eb0a6
7 changed files with 923 additions and 296 deletions

View File

@@ -7,7 +7,8 @@ from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
from crewai.rag.embeddings.types import EmbeddingOptions
from crewai.rag.factory import create_client from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.types import BaseRecord from crewai.rag.types import BaseRecord
@@ -25,7 +26,7 @@ class RAGStorage(BaseRAGStorage):
self, self,
type: str, type: str,
allow_reset: bool = True, allow_reset: bool = True,
embedder_config: dict[str, Any] | None = None, embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
crew: Any = None, crew: Any = None,
path: str | None = None, path: str | None = None,
) -> None: ) -> None:
@@ -50,6 +51,21 @@ class RAGStorage(BaseRAGStorage):
if self.embedder_config: if self.embedder_config:
embedding_function = get_embedding_function(self.embedder_config) embedding_function = get_embedding_function(self.embedder_config)
try:
_ = embedding_function(["test"])
except Exception as e:
provider = (
self.embedder_config.provider
if isinstance(self.embedder_config, EmbeddingOptions)
else self.embedder_config.get("provider", "unknown")
)
raise ValueError(
f"Failed to initialize embedder. Please check your configuration or connection.\n"
f"Provider: {provider}\n"
f"Error: {e}"
) from e
config = ChromaDBConfig( config = ChromaDBConfig(
embedding_function=cast( embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function ChromaEmbeddingFunctionWrapper, embedding_function

View File

@@ -1,6 +1,8 @@
"""Minimal embedding function factory for CrewAI.""" """Minimal embedding function factory for CrewAI."""
import os import os
from collections.abc import Callable, MutableMapping
from typing import Any, Final, Literal, TypedDict
from chromadb import EmbeddingFunction from chromadb import EmbeddingFunction
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
@@ -42,19 +44,116 @@ from chromadb.utils.embedding_functions.sentence_transformer_embedding_function
from chromadb.utils.embedding_functions.text2vec_embedding_function import ( from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction, Text2VecEmbeddingFunction,
) )
from typing_extensions import NotRequired
from crewai.rag.embeddings.types import EmbeddingOptions from crewai.rag.embeddings.types import EmbeddingOptions
AllowedEmbeddingProviders = Literal[
"openai",
"cohere",
"ollama",
"huggingface",
"sentence-transformer",
"instructor",
"google-palm",
"google-generativeai",
"google-vertex",
"amazon-bedrock",
"jina",
"roboflow",
"openclip",
"text2vec",
"onnx",
]
class EmbedderConfig(TypedDict):
"""Configuration for embedding functions with nested format."""
provider: AllowedEmbeddingProviders
config: NotRequired[dict[str, Any]]
EMBEDDING_PROVIDERS: Final[
dict[AllowedEmbeddingProviders, Callable[..., EmbeddingFunction]]
] = {
"openai": OpenAIEmbeddingFunction,
"cohere": CohereEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
"huggingface": HuggingFaceEmbeddingFunction,
"sentence-transformer": SentenceTransformerEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"google-palm": GooglePalmEmbeddingFunction,
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
"google-vertex": GoogleVertexEmbeddingFunction,
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"roboflow": RoboflowEmbeddingFunction,
"openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2,
}
PROVIDER_ENV_MAPPING: Final[dict[AllowedEmbeddingProviders, tuple[str, str]]] = {
"openai": ("OPENAI_API_KEY", "api_key"),
"cohere": ("COHERE_API_KEY", "api_key"),
"huggingface": ("HUGGINGFACE_API_KEY", "api_key"),
"google-palm": ("GOOGLE_API_KEY", "api_key"),
"google-generativeai": ("GOOGLE_API_KEY", "api_key"),
"google-vertex": ("GOOGLE_API_KEY", "api_key"),
"jina": ("JINA_API_KEY", "api_key"),
"roboflow": ("ROBOFLOW_API_KEY", "api_key"),
}
def _inject_api_key_from_env(
provider: AllowedEmbeddingProviders, config_dict: MutableMapping[str, Any]
) -> None:
"""Inject API key or other required configuration from environment if not explicitly provided.
Args:
provider: The embedding provider name
config_dict: The configuration dictionary to modify in-place
Raises:
ImportError: If required libraries for certain providers are not installed
ValueError: If AWS session creation fails for amazon-bedrock
"""
if provider in PROVIDER_ENV_MAPPING:
env_var_name, config_key = PROVIDER_ENV_MAPPING[provider]
if config_key not in config_dict:
env_value = os.getenv(env_var_name)
if env_value:
config_dict[config_key] = env_value
if provider == "amazon-bedrock":
if "session" not in config_dict:
try:
import boto3 # type: ignore[import]
config_dict["session"] = boto3.Session()
except ImportError as e:
raise ImportError(
"boto3 is required for amazon-bedrock embeddings. "
"Install it with: uv add boto3"
) from e
except Exception as e:
raise ValueError(
f"Failed to create AWS session for amazon-bedrock. "
f"Ensure AWS credentials are configured. Error: {e}"
) from e
def get_embedding_function( def get_embedding_function(
config: EmbeddingOptions | dict | None = None, config: EmbeddingOptions | EmbedderConfig | None = None,
) -> EmbeddingFunction: ) -> EmbeddingFunction:
"""Get embedding function - delegates to ChromaDB. """Get embedding function - delegates to ChromaDB.
Args: Args:
config: Optional configuration - either an EmbeddingOptions object or a dict with: config: Optional configuration - either:
- provider: The embedding provider to use (default: "openai") - EmbeddingOptions: Pydantic model with flat configuration
- Any other provider-specific parameters - EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict}
- None: Uses default OpenAI configuration
Returns: Returns:
EmbeddingFunction instance ready for use with ChromaDB EmbeddingFunction instance ready for use with ChromaDB
@@ -81,31 +180,33 @@ def get_embedding_function(
>>> embedder = get_embedding_function() >>> embedder = get_embedding_function()
# Use Cohere with dict # Use Cohere with dict
>>> embedder = get_embedding_function({ >>> embedder = get_embedding_function(EmbedderConfig(**{
... "provider": "cohere", ... "provider": "cohere",
... "api_key": "your-key", ... "config": {
... "model_name": "embed-english-v3.0" ... "api_key": "your-key",
... }) ... "model_name": "embed-english-v3.0"
... }
... }))
# Use with EmbeddingOptions # Use with EmbeddingOptions
>>> embedder = get_embedding_function( >>> embedder = get_embedding_function(
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2") ... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
... ) ... )
# Use local sentence transformers (no API key needed) # Use Azure OpenAI
>>> embedder = get_embedding_function({ >>> embedder = get_embedding_function(EmbedderConfig(**{
... "provider": "sentence-transformer", ... "provider": "openai",
... "model_name": "all-MiniLM-L6-v2" ... "config": {
... "api_key": "your-azure-key",
... "api_base": "https://your-resource.openai.azure.com/",
... "api_type": "azure",
... "api_version": "2023-05-15",
... "model": "text-embedding-3-small",
... "deployment_id": "your-deployment-name"
... }
... }) ... })
# Use Ollama for local embeddings >>> embedder = get_embedding_function(EmbedderConfig(**{
>>> embedder = get_embedding_function({
... "provider": "ollama",
... "model_name": "nomic-embed-text"
... })
# Use ONNX (no API key needed)
>>> embedder = get_embedding_function({
... "provider": "onnx" ... "provider": "onnx"
... }) ... })
""" """
@@ -114,35 +215,33 @@ def get_embedding_function(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
) )
# Handle EmbeddingOptions object provider: AllowedEmbeddingProviders
config_dict: dict[str, Any]
if isinstance(config, EmbeddingOptions): if isinstance(config, EmbeddingOptions):
config_dict = config.model_dump(exclude_none=True) config_dict = config.model_dump(exclude_none=True)
provider = config_dict["provider"]
else: else:
config_dict = config.copy() provider = config["provider"]
nested: dict[str, Any] = config.get("config", {})
provider = config_dict.pop("provider", "openai") if not nested and len(config) > 1:
raise ValueError(
"Invalid embedder configuration format. "
"Configuration must be nested under a 'config' key. "
"Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}"
)
embedding_functions = { config_dict = dict(nested)
"openai": OpenAIEmbeddingFunction, if "model" in config_dict and "model_name" not in config_dict:
"cohere": CohereEmbeddingFunction, config_dict["model_name"] = config_dict.pop("model")
"ollama": OllamaEmbeddingFunction,
"huggingface": HuggingFaceEmbeddingFunction,
"sentence-transformer": SentenceTransformerEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"google-palm": GooglePalmEmbeddingFunction,
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
"google-vertex": GoogleVertexEmbeddingFunction,
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"roboflow": RoboflowEmbeddingFunction,
"openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2,
}
if provider not in embedding_functions: if provider not in EMBEDDING_PROVIDERS:
raise ValueError( raise ValueError(
f"Unsupported provider: {provider}. " f"Unsupported provider: {provider}. "
f"Available providers: {list(embedding_functions.keys())}" f"Available providers: {list(EMBEDDING_PROVIDERS.keys())}"
) )
return embedding_functions[provider](**config_dict)
_inject_api_key_from_env(provider, config_dict)
return EMBEDDING_PROVIDERS[provider](**config_dict)

View File

@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any
from crewai.rag.embeddings.factory import EmbedderConfig
from crewai.rag.embeddings.types import EmbeddingOptions
class BaseRAGStorage(ABC): class BaseRAGStorage(ABC):
""" """
@@ -13,7 +16,7 @@ class BaseRAGStorage(ABC):
self, self,
type: str, type: str,
allow_reset: bool = True, allow_reset: bool = True,
embedder_config: dict[str, Any] | None = None, embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
crew: Any = None, crew: Any = None,
): ):
self.type = type self.type = type

View File

@@ -0,0 +1,598 @@
"""Enhanced tests for embedding function factory."""
from unittest.mock import MagicMock, patch
import pytest
from pydantic import SecretStr
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
get_embedding_function,
)
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
def test_get_embedding_function_default() -> None:
"""Test default embedding function when no config provided."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
with patch(
"crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
):
result = get_embedding_function()
mock_openai.assert_called_once_with(
api_key="test-api-key", model_name="text-embedding-3-small"
)
assert result == mock_instance
def test_get_embedding_function_with_embedding_options() -> None:
"""Test embedding function creation with EmbeddingOptions object."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_openai = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_openai
mock_providers.__contains__.return_value = True
options = EmbeddingOptions(
provider="openai",
api_key=SecretStr("test-key"),
model_name="text-embedding-3-large",
)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
assert "model_name" in call_kwargs
assert call_kwargs["model_name"] == "text-embedding-3-large"
assert result == mock_instance
def test_get_embedding_function_sentence_transformer() -> None:
"""Test sentence transformer embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_st = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_st
mock_providers.__contains__.return_value = True
config = {
"provider": "sentence-transformer",
"config": {"model_name": "all-MiniLM-L6-v2"},
}
result = get_embedding_function(config)
mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2")
assert result == mock_instance
def test_get_embedding_function_ollama() -> None:
"""Test Ollama embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_ollama = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_ollama
mock_providers.__contains__.return_value = True
config = {
"provider": "ollama",
"config": {
"model_name": "nomic-embed-text",
"url": "http://localhost:11434",
},
}
result = get_embedding_function(config)
mock_ollama.assert_called_once_with(
model_name="nomic-embed-text", url="http://localhost:11434"
)
assert result == mock_instance
def test_get_embedding_function_cohere() -> None:
"""Test Cohere embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_cohere = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_cohere
mock_providers.__contains__.return_value = True
config = {
"provider": "cohere",
"config": {"api_key": "cohere-key", "model_name": "embed-english-v3.0"},
}
result = get_embedding_function(config)
mock_cohere.assert_called_once_with(
api_key="cohere-key", model_name="embed-english-v3.0"
)
assert result == mock_instance
def test_get_embedding_function_huggingface() -> None:
"""Test HuggingFace embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_hf = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_hf
mock_providers.__contains__.return_value = True
config = {
"provider": "huggingface",
"config": {
"api_key": "hf-token",
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
},
}
result = get_embedding_function(config)
mock_hf.assert_called_once_with(
api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
)
assert result == mock_instance
def test_get_embedding_function_onnx() -> None:
"""Test ONNX embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_onnx = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_onnx
mock_providers.__contains__.return_value = True
config = {"provider": "onnx"}
result = get_embedding_function(config)
mock_onnx.assert_called_once()
assert result == mock_instance
def test_get_embedding_function_google_palm() -> None:
"""Test Google PaLM embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_palm = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_palm
mock_providers.__contains__.return_value = True
config = {"provider": "google-palm", "config": {"api_key": "palm-key"}}
result = get_embedding_function(config)
mock_palm.assert_called_once_with(api_key="palm-key")
assert result == mock_instance
def test_get_embedding_function_amazon_bedrock() -> None:
"""Test Amazon Bedrock embedding function with explicit session."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_bedrock = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_bedrock
mock_providers.__contains__.return_value = True
# Provide an explicit session to avoid boto3 import
mock_session = MagicMock()
config = {
"provider": "amazon-bedrock",
"config": {
"session": mock_session,
"region_name": "us-west-2",
"model_name": "amazon.titan-embed-text-v1",
},
}
result = get_embedding_function(config)
mock_bedrock.assert_called_once_with(
session=mock_session,
region_name="us-west-2",
model_name="amazon.titan-embed-text-v1",
)
assert result == mock_instance
def test_get_embedding_function_jina() -> None:
"""Test Jina embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_jina = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_jina
mock_providers.__contains__.return_value = True
config = {
"provider": "jina",
"config": {
"api_key": "jina-key",
"model_name": "jina-embeddings-v2-base-en",
},
}
result = get_embedding_function(config)
mock_jina.assert_called_once_with(
api_key="jina-key", model_name="jina-embeddings-v2-base-en"
)
assert result == mock_instance
def test_get_embedding_function_unsupported_provider() -> None:
"""Test handling of unsupported provider."""
config = {"provider": "unsupported-provider"}
with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
get_embedding_function(config)
def test_get_embedding_function_config_modification() -> None:
"""Test that original config dict is not modified."""
original_config = {
"provider": "openai",
"config": {"api_key": "test-key", "model": "text-embedding-3-small"},
}
config_copy = original_config.copy()
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_openai = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_openai
mock_providers.__contains__.return_value = True
get_embedding_function(config_copy)
assert config_copy == original_config
def test_get_embedding_function_exclude_none_values() -> None:
"""Test that None values are excluded from embedding function calls."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_openai = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_openai
mock_providers.__contains__.return_value = True
options = EmbeddingOptions(
provider="openai", api_key=SecretStr("test-key"), model_name=None
)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
assert "model_name" not in call_kwargs
assert result == mock_instance
def test_get_embedding_function_instructor() -> None:
"""Test Instructor embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_instructor = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_instructor
mock_providers.__contains__.return_value = True
config = {
"provider": "instructor",
"config": {"model_name": "hkunlp/instructor-large"},
}
result = get_embedding_function(config)
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
assert result == mock_instance
def test_get_embedding_function_google_generativeai() -> None:
"""Test Google Generative AI embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_google = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_google
mock_providers.__contains__.return_value = True
config = {
"provider": "google-generativeai",
"config": {"api_key": "google-key", "model_name": "models/embedding-001"},
}
result = get_embedding_function(config)
mock_google.assert_called_once_with(
api_key="google-key", model_name="models/embedding-001"
)
assert result == mock_instance
def test_get_embedding_function_google_vertex() -> None:
"""Test Google Vertex AI embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_vertex = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_vertex
mock_providers.__contains__.return_value = True
config = {
"provider": "google-vertex",
"config": {
"api_key": "vertex-key",
"project_id": "my-project",
"region": "us-central1",
},
}
result = get_embedding_function(config)
mock_vertex.assert_called_once_with(
api_key="vertex-key", project_id="my-project", region="us-central1"
)
assert result == mock_instance
def test_get_embedding_function_roboflow() -> None:
"""Test Roboflow embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_roboflow = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_roboflow
mock_providers.__contains__.return_value = True
config = {
"provider": "roboflow",
"config": {
"api_key": "roboflow-key",
"api_url": "https://infer.roboflow.com",
},
}
result = get_embedding_function(config)
mock_roboflow.assert_called_once_with(
api_key="roboflow-key", api_url="https://infer.roboflow.com"
)
assert result == mock_instance
def test_get_embedding_function_openclip() -> None:
"""Test OpenCLIP embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_openclip = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_openclip
mock_providers.__contains__.return_value = True
config = {
"provider": "openclip",
"config": {"model_name": "ViT-B-32", "checkpoint": "laion2b_s34b_b79k"},
}
result = get_embedding_function(config)
mock_openclip.assert_called_once_with(
model_name="ViT-B-32", checkpoint="laion2b_s34b_b79k"
)
assert result == mock_instance
def test_get_embedding_function_text2vec() -> None:
"""Test Text2Vec embedding function."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_text2vec = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_text2vec
mock_providers.__contains__.return_value = True
config = {
"provider": "text2vec",
"config": {"model_name": "shibing624/text2vec-base-chinese"},
}
result = get_embedding_function(config)
mock_text2vec.assert_called_once_with(
model_name="shibing624/text2vec-base-chinese"
)
assert result == mock_instance
def test_model_to_model_name_conversion() -> None:
"""Test that 'model' field is converted to 'model_name' for nested config."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_openai = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_openai
mock_providers.__contains__.return_value = True
config = {
"provider": "openai",
"config": {"api_key": "test-key", "model": "text-embedding-3-small"},
}
result = get_embedding_function(config)
mock_openai.assert_called_once_with(
api_key="test-key", model_name="text-embedding-3-small"
)
assert result == mock_instance
def test_api_key_injection_from_env_openai() -> None:
"""Test that OpenAI API key is injected from environment when not provided."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_openai = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_openai
mock_providers.__contains__.return_value = True
with patch("crewai.rag.embeddings.factory.os.getenv") as mock_getenv:
mock_getenv.return_value = "env-openai-key"
config = {
"provider": "openai",
"config": {"model": "text-embedding-3-small"},
}
result = get_embedding_function(config)
mock_getenv.assert_called_with("OPENAI_API_KEY")
mock_openai.assert_called_once_with(
api_key="env-openai-key", model_name="text-embedding-3-small"
)
assert result == mock_instance
def test_api_key_injection_from_env_cohere() -> None:
"""Test that Cohere API key is injected from environment when not provided."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_cohere = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_cohere
mock_providers.__contains__.return_value = True
with patch("crewai.rag.embeddings.factory.os.getenv") as mock_getenv:
mock_getenv.return_value = "env-cohere-key"
config = {
"provider": "cohere",
"config": {"model_name": "embed-english-v3.0"},
}
result = get_embedding_function(config)
mock_getenv.assert_called_with("COHERE_API_KEY")
mock_cohere.assert_called_once_with(
api_key="env-cohere-key", model_name="embed-english-v3.0"
)
assert result == mock_instance
def test_api_key_not_injected_when_provided() -> None:
"""Test that API key from config takes precedence over environment."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_openai = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_openai
mock_providers.__contains__.return_value = True
with patch("crewai.rag.embeddings.factory.os.getenv") as mock_getenv:
mock_getenv.return_value = "env-key"
config = {
"provider": "openai",
"config": {"api_key": "config-key", "model": "text-embedding-3-small"},
}
result = get_embedding_function(config)
mock_openai.assert_called_once_with(
api_key="config-key", model_name="text-embedding-3-small"
)
assert result == mock_instance
def test_amazon_bedrock_session_injection() -> None:
"""Test that boto3 session is automatically created for amazon-bedrock."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_bedrock = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_bedrock
mock_providers.__contains__.return_value = True
mock_boto3 = MagicMock()
with patch.dict("sys.modules", {"boto3": mock_boto3}):
mock_session = MagicMock()
mock_boto3.Session.return_value = mock_session
config = {
"provider": "amazon-bedrock",
"config": {"model_name": "amazon.titan-embed-text-v1"},
}
result = get_embedding_function(config)
mock_boto3.Session.assert_called_once()
mock_bedrock.assert_called_once_with(
session=mock_session, model_name="amazon.titan-embed-text-v1"
)
assert result == mock_instance
def test_amazon_bedrock_session_not_injected_when_provided() -> None:
"""Test that provided session is used for amazon-bedrock."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_instance = MagicMock()
mock_bedrock = MagicMock(return_value=mock_instance)
mock_providers.__getitem__.return_value = mock_bedrock
mock_providers.__contains__.return_value = True
existing_session = MagicMock()
config = {
"provider": "amazon-bedrock",
"config": {
"session": existing_session,
"model_name": "amazon.titan-embed-text-v1",
},
}
result = get_embedding_function(config)
mock_bedrock.assert_called_once_with(
session=existing_session, model_name="amazon.titan-embed-text-v1"
)
assert result == mock_instance
def test_amazon_bedrock_boto3_import_error() -> None:
"""Test error handling when boto3 is not installed."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_providers.__contains__.return_value = True
with patch.dict("sys.modules", {"boto3": None}):
config = {
"provider": "amazon-bedrock",
"config": {"model_name": "amazon.titan-embed-text-v1"},
}
with pytest.raises(
ImportError, match="boto3 is required for amazon-bedrock"
):
get_embedding_function(config)
def test_amazon_bedrock_session_creation_error() -> None:
"""Test error handling when AWS session creation fails."""
with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
mock_providers.__contains__.return_value = True
mock_boto3 = MagicMock()
with patch.dict("sys.modules", {"boto3": mock_boto3}):
mock_boto3.Session.side_effect = Exception("AWS credentials not configured")
config = {
"provider": "amazon-bedrock",
"config": {"model_name": "amazon.titan-embed-text-v1"},
}
with pytest.raises(ValueError, match="Failed to create AWS session"):
get_embedding_function(config)
def test_invalid_config_format() -> None:
"""Test error handling for invalid config format."""
config = {
"provider": "openai",
"api_key": "test-key",
"model": "text-embedding-3-small",
}
with pytest.raises(ValueError, match="Invalid embedder configuration format"):
get_embedding_function(config)

View File

@@ -0,0 +1,79 @@
"""Test Azure embedder configuration with factory."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
class TestAzureEmbedderFactory:
"""Test Azure embedder configuration with factory function."""
@patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS")
def test_azure_with_nested_config(self, mock_providers):
"""Test Azure configuration with nested config key."""
mock_embedding = MagicMock()
mock_openai_func = MagicMock(return_value=mock_embedding)
mock_providers.__getitem__.return_value = mock_openai_func
mock_providers.__contains__.return_value = True
embedder_config = EmbedderConfig(
provider="openai",
config={
"api_key": "test-azure-key",
"api_base": "https://test.openai.azure.com/",
"api_type": "azure",
"api_version": "2023-05-15",
"model": "text-embedding-3-small",
"deployment_id": "test-deployment",
},
)
result = get_embedding_function(embedder_config)
mock_openai_func.assert_called_once_with(
api_key="test-azure-key",
api_base="https://test.openai.azure.com/",
api_type="azure",
api_version="2023-05-15",
model_name="text-embedding-3-small",
deployment_id="test-deployment",
)
assert result == mock_embedding
@patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS")
def test_regular_openai_with_nested_config(self, mock_providers):
"""Test regular OpenAI configuration with nested config."""
mock_embedding = MagicMock()
mock_openai_func = MagicMock(return_value=mock_embedding)
mock_providers.__getitem__.return_value = mock_openai_func
mock_providers.__contains__.return_value = True
embedder_config = EmbedderConfig(
provider="openai",
config={"api_key": "test-openai-key", "model": "text-embedding-3-large"},
)
result = get_embedding_function(embedder_config)
mock_openai_func.assert_called_once_with(
api_key="test-openai-key", model_name="text-embedding-3-large"
)
assert result == mock_embedding
def test_flat_format_raises_error(self):
"""Test that flat format raises an error."""
embedder_config = {
"provider": "openai",
"api_key": "test-key",
"model_name": "text-embedding-3-small",
}
with pytest.raises(ValueError) as exc_info:
get_embedding_function(embedder_config)
assert "Invalid embedder configuration format" in str(exc_info.value)
assert "nested under a 'config' key" in str(exc_info.value)

View File

@@ -1,250 +0,0 @@
"""Enhanced tests for embedding function factory."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
get_embedding_function,
)
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
def test_get_embedding_function_default() -> None:
"""Test default embedding function when no config provided."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
with patch(
"crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
):
result = get_embedding_function()
mock_openai.assert_called_once_with(
api_key="test-api-key", model_name="text-embedding-3-small"
)
assert result == mock_instance
def test_get_embedding_function_with_embedding_options() -> None:
"""Test embedding function creation with EmbeddingOptions object."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
options = EmbeddingOptions(
provider="openai", api_key="test-key", model="text-embedding-3-large"
)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
# OpenAI uses model_name parameter, not model
assert result == mock_instance
def test_get_embedding_function_sentence_transformer() -> None:
"""Test sentence transformer embedding function."""
with patch(
"crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction"
) as mock_st:
mock_instance = MagicMock()
mock_st.return_value = mock_instance
config = {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"}
result = get_embedding_function(config)
mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2")
assert result == mock_instance
def test_get_embedding_function_ollama() -> None:
"""Test Ollama embedding function."""
with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as mock_ollama:
mock_instance = MagicMock()
mock_ollama.return_value = mock_instance
config = {
"provider": "ollama",
"model_name": "nomic-embed-text",
"url": "http://localhost:11434",
}
result = get_embedding_function(config)
mock_ollama.assert_called_once_with(
model_name="nomic-embed-text", url="http://localhost:11434"
)
assert result == mock_instance
def test_get_embedding_function_cohere() -> None:
"""Test Cohere embedding function."""
with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as mock_cohere:
mock_instance = MagicMock()
mock_cohere.return_value = mock_instance
config = {
"provider": "cohere",
"api_key": "cohere-key",
"model_name": "embed-english-v3.0",
}
result = get_embedding_function(config)
mock_cohere.assert_called_once_with(
api_key="cohere-key", model_name="embed-english-v3.0"
)
assert result == mock_instance
def test_get_embedding_function_huggingface() -> None:
"""Test HuggingFace embedding function."""
with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as mock_hf:
mock_instance = MagicMock()
mock_hf.return_value = mock_instance
config = {
"provider": "huggingface",
"api_key": "hf-token",
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
}
result = get_embedding_function(config)
mock_hf.assert_called_once_with(
api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
)
assert result == mock_instance
def test_get_embedding_function_onnx() -> None:
"""Test ONNX embedding function."""
with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as mock_onnx:
mock_instance = MagicMock()
mock_onnx.return_value = mock_instance
config = {"provider": "onnx"}
result = get_embedding_function(config)
mock_onnx.assert_called_once()
assert result == mock_instance
def test_get_embedding_function_google_palm() -> None:
"""Test Google PaLM embedding function."""
with patch(
"crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction"
) as mock_palm:
mock_instance = MagicMock()
mock_palm.return_value = mock_instance
config = {"provider": "google-palm", "api_key": "palm-key"}
result = get_embedding_function(config)
mock_palm.assert_called_once_with(api_key="palm-key")
assert result == mock_instance
def test_get_embedding_function_amazon_bedrock() -> None:
"""Test Amazon Bedrock embedding function."""
with patch(
"crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction"
) as mock_bedrock:
mock_instance = MagicMock()
mock_bedrock.return_value = mock_instance
config = {
"provider": "amazon-bedrock",
"region_name": "us-west-2",
"model_name": "amazon.titan-embed-text-v1",
}
result = get_embedding_function(config)
mock_bedrock.assert_called_once_with(
region_name="us-west-2", model_name="amazon.titan-embed-text-v1"
)
assert result == mock_instance
def test_get_embedding_function_jina() -> None:
"""Test Jina embedding function."""
with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as mock_jina:
mock_instance = MagicMock()
mock_jina.return_value = mock_instance
config = {
"provider": "jina",
"api_key": "jina-key",
"model_name": "jina-embeddings-v2-base-en",
}
result = get_embedding_function(config)
mock_jina.assert_called_once_with(
api_key="jina-key", model_name="jina-embeddings-v2-base-en"
)
assert result == mock_instance
def test_get_embedding_function_unsupported_provider() -> None:
"""Test handling of unsupported provider."""
config = {"provider": "unsupported-provider"}
with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
get_embedding_function(config)
def test_get_embedding_function_config_modification() -> None:
"""Test that original config dict is not modified."""
original_config = {
"provider": "openai",
"api_key": "test-key",
"model": "text-embedding-3-small",
}
config_copy = original_config.copy()
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
get_embedding_function(config_copy)
assert config_copy == original_config
def test_get_embedding_function_exclude_none_values() -> None:
"""Test that None values are excluded from embedding function calls."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
options = EmbeddingOptions(provider="openai", api_key="test-key", model=None)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
assert "model" not in call_kwargs
assert result == mock_instance
def test_get_embedding_function_instructor() -> None:
"""Test Instructor embedding function."""
with patch(
"crewai.rag.embeddings.factory.InstructorEmbeddingFunction"
) as mock_instructor:
mock_instance = MagicMock()
mock_instructor.return_value = mock_instance
config = {"provider": "instructor", "model_name": "hkunlp/instructor-large"}
result = get_embedding_function(config)
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
assert result == mock_instance

View File

@@ -0,0 +1,82 @@
"""Test Azure embedder configuration with nested format only."""
from unittest.mock import MagicMock, patch
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
class TestAzureEmbedderConfiguration:
"""Test Azure embedder configuration with nested format."""
@patch(
"chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction"
)
def test_azure_openai_with_nested_config(self, mock_openai_func):
"""Test Azure configuration using OpenAI provider with nested config key."""
mock_embedding = MagicMock()
mock_openai_func.return_value = mock_embedding
configurator = EmbeddingConfigurator()
embedder_config = {
"provider": "openai",
"config": {
"api_key": "test-azure-key",
"api_base": "https://test.openai.azure.com/",
"api_type": "azure",
"api_version": "2023-05-15",
"model": "text-embedding-3-small",
"deployment_id": "test-deployment",
},
}
result = configurator.configure_embedder(embedder_config)
mock_openai_func.assert_called_once_with(
api_key="test-azure-key",
model_name="text-embedding-3-small",
api_base="https://test.openai.azure.com/",
api_type="azure",
api_version="2023-05-15",
default_headers=None,
dimensions=None,
deployment_id="test-deployment",
organization_id=None,
)
assert result == mock_embedding
@patch(
"chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction"
)
def test_azure_provider_with_nested_config(self, mock_openai_func):
"""Test using 'azure' as provider with nested config."""
mock_embedding = MagicMock()
mock_openai_func.return_value = mock_embedding
configurator = EmbeddingConfigurator()
embedder_config = {
"provider": "azure",
"config": {
"api_key": "test-azure-key",
"api_base": "https://test.openai.azure.com/",
"api_version": "2023-05-15",
"model": "text-embedding-3-small",
"deployment_id": "test-deployment",
},
}
result = configurator.configure_embedder(embedder_config)
mock_openai_func.assert_called_once_with(
api_key="test-azure-key",
api_base="https://test.openai.azure.com/",
api_type="azure",
api_version="2023-05-15",
model_name="text-embedding-3-small",
default_headers=None,
dimensions=None,
deployment_id="test-deployment",
organization_id=None,
)
assert result == mock_embedding