mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-13 14:32:47 +00:00
fix: support nested config format for embedder configuration
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
- Support nested config format with the `EmbedderConfig` TypedDict
- Fix parsing for `model`/`model_name` compatibility
- Add validation, `typing_extensions`, and improved type hints
- Enhance the embedding factory with env var injection and provider support
- Add tests for OpenAI, Azure, and all embedding providers
- Misc fixes: test file rename, updated mocking patterns
This commit is contained in:
@@ -7,7 +7,8 @@ from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.rag.types import BaseRecord
|
||||
@@ -25,7 +26,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
@@ -50,6 +51,21 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
if self.embedder_config:
|
||||
embedding_function = get_embedding_function(self.embedder_config)
|
||||
|
||||
try:
|
||||
_ = embedding_function(["test"])
|
||||
except Exception as e:
|
||||
provider = (
|
||||
self.embedder_config.provider
|
||||
if isinstance(self.embedder_config, EmbeddingOptions)
|
||||
else self.embedder_config.get("provider", "unknown")
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to initialize embedder. Please check your configuration or connection.\n"
|
||||
f"Provider: {provider}\n"
|
||||
f"Error: {e}"
|
||||
) from e
|
||||
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Minimal embedding function factory for CrewAI."""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable, MutableMapping
|
||||
from typing import Any, Final, Literal, TypedDict
|
||||
|
||||
from chromadb import EmbeddingFunction
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
@@ -42,19 +44,116 @@ from chromadb.utils.embedding_functions.sentence_transformer_embedding_function
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
AllowedEmbeddingProviders = Literal[
|
||||
"openai",
|
||||
"cohere",
|
||||
"ollama",
|
||||
"huggingface",
|
||||
"sentence-transformer",
|
||||
"instructor",
|
||||
"google-palm",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"amazon-bedrock",
|
||||
"jina",
|
||||
"roboflow",
|
||||
"openclip",
|
||||
"text2vec",
|
||||
"onnx",
|
||||
]
|
||||
|
||||
|
||||
class EmbedderConfig(TypedDict):
    """Nested embedder configuration: a provider name plus optional settings.

    Expected shape:
    ``{"provider": "openai", "config": {"api_key": "...", "model": "..."}}``
    """

    # Name of the embedding backend; restricted to AllowedEmbeddingProviders.
    provider: AllowedEmbeddingProviders
    # Provider-specific settings; the key itself may be omitted.
    config: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
EMBEDDING_PROVIDERS: Final[
|
||||
dict[AllowedEmbeddingProviders, Callable[..., EmbeddingFunction]]
|
||||
] = {
|
||||
"openai": OpenAIEmbeddingFunction,
|
||||
"cohere": CohereEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"huggingface": HuggingFaceEmbeddingFunction,
|
||||
"sentence-transformer": SentenceTransformerEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"google-palm": GooglePalmEmbeddingFunction,
|
||||
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google-vertex": GoogleVertexEmbeddingFunction,
|
||||
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"roboflow": RoboflowEmbeddingFunction,
|
||||
"openclip": OpenCLIPEmbeddingFunction,
|
||||
"text2vec": Text2VecEmbeddingFunction,
|
||||
"onnx": ONNXMiniLM_L6_V2,
|
||||
}
|
||||
|
||||
PROVIDER_ENV_MAPPING: Final[dict[AllowedEmbeddingProviders, tuple[str, str]]] = {
|
||||
"openai": ("OPENAI_API_KEY", "api_key"),
|
||||
"cohere": ("COHERE_API_KEY", "api_key"),
|
||||
"huggingface": ("HUGGINGFACE_API_KEY", "api_key"),
|
||||
"google-palm": ("GOOGLE_API_KEY", "api_key"),
|
||||
"google-generativeai": ("GOOGLE_API_KEY", "api_key"),
|
||||
"google-vertex": ("GOOGLE_API_KEY", "api_key"),
|
||||
"jina": ("JINA_API_KEY", "api_key"),
|
||||
"roboflow": ("ROBOFLOW_API_KEY", "api_key"),
|
||||
}
|
||||
|
||||
|
||||
def _inject_api_key_from_env(
    provider: AllowedEmbeddingProviders, config_dict: MutableMapping[str, Any]
) -> None:
    """Fill in credentials the caller left out of ``config_dict``.

    For providers listed in ``PROVIDER_ENV_MAPPING`` the mapped environment
    variable is copied into the mapped config key, but only when that key is
    absent and the variable holds a non-empty value. For ``amazon-bedrock`` a
    ``boto3.Session()`` is stored under ``"session"`` when none was supplied.

    Args:
        provider: The embedding provider name.
        config_dict: The provider configuration; modified in place.

    Raises:
        ImportError: If boto3 is needed for amazon-bedrock but is not installed.
        ValueError: If creating the AWS session fails for any other reason.
    """
    env_binding = PROVIDER_ENV_MAPPING.get(provider)
    if env_binding is not None:
        env_var_name, config_key = env_binding
        if config_key not in config_dict:
            value_from_env = os.getenv(env_var_name)
            # Unset and empty variables are both ignored.
            if value_from_env:
                config_dict[config_key] = value_from_env

    # Only amazon-bedrock needs a session, and an explicit one always wins.
    if provider != "amazon-bedrock" or "session" in config_dict:
        return

    try:
        # Imported lazily so boto3 stays optional for every other provider.
        import boto3  # type: ignore[import]

        config_dict["session"] = boto3.Session()
    except ImportError as e:
        raise ImportError(
            "boto3 is required for amazon-bedrock embeddings. "
            "Install it with: uv add boto3"
        ) from e
    except Exception as e:
        raise ValueError(
            f"Failed to create AWS session for amazon-bedrock. "
            f"Ensure AWS credentials are configured. Error: {e}"
        ) from e
|
||||
|
||||
|
||||
def get_embedding_function(
|
||||
config: EmbeddingOptions | dict | None = None,
|
||||
config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Get embedding function - delegates to ChromaDB.
|
||||
|
||||
Args:
|
||||
config: Optional configuration - either an EmbeddingOptions object or a dict with:
|
||||
- provider: The embedding provider to use (default: "openai")
|
||||
- Any other provider-specific parameters
|
||||
config: Optional configuration - either:
|
||||
- EmbeddingOptions: Pydantic model with flat configuration
|
||||
- EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict}
|
||||
- None: Uses default OpenAI configuration
|
||||
|
||||
Returns:
|
||||
EmbeddingFunction instance ready for use with ChromaDB
|
||||
@@ -81,31 +180,33 @@ def get_embedding_function(
|
||||
>>> embedder = get_embedding_function()
|
||||
|
||||
# Use Cohere with dict
|
||||
>>> embedder = get_embedding_function({
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "cohere",
|
||||
... "api_key": "your-key",
|
||||
... "model_name": "embed-english-v3.0"
|
||||
... })
|
||||
... "config": {
|
||||
... "api_key": "your-key",
|
||||
... "model_name": "embed-english-v3.0"
|
||||
... }
|
||||
... }))
|
||||
|
||||
# Use with EmbeddingOptions
|
||||
>>> embedder = get_embedding_function(
|
||||
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
|
||||
... )
|
||||
|
||||
# Use local sentence transformers (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "sentence-transformer",
|
||||
... "model_name": "all-MiniLM-L6-v2"
|
||||
# Use Azure OpenAI
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "openai",
|
||||
... "config": {
|
||||
... "api_key": "your-azure-key",
|
||||
... "api_base": "https://your-resource.openai.azure.com/",
|
||||
... "api_type": "azure",
|
||||
... "api_version": "2023-05-15",
|
||||
... "model": "text-embedding-3-small",
|
||||
... "deployment_id": "your-deployment-name"
|
||||
... }
|
||||
... })
|
||||
|
||||
# Use Ollama for local embeddings
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "ollama",
|
||||
... "model_name": "nomic-embed-text"
|
||||
... })
|
||||
|
||||
# Use ONNX (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "onnx"
|
||||
... })
|
||||
"""
|
||||
@@ -114,35 +215,33 @@ def get_embedding_function(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
# Handle EmbeddingOptions object
|
||||
provider: AllowedEmbeddingProviders
|
||||
config_dict: dict[str, Any]
|
||||
|
||||
if isinstance(config, EmbeddingOptions):
|
||||
config_dict = config.model_dump(exclude_none=True)
|
||||
provider = config_dict["provider"]
|
||||
else:
|
||||
config_dict = config.copy()
|
||||
provider = config["provider"]
|
||||
nested: dict[str, Any] = config.get("config", {})
|
||||
|
||||
provider = config_dict.pop("provider", "openai")
|
||||
if not nested and len(config) > 1:
|
||||
raise ValueError(
|
||||
"Invalid embedder configuration format. "
|
||||
"Configuration must be nested under a 'config' key. "
|
||||
"Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}"
|
||||
)
|
||||
|
||||
embedding_functions = {
|
||||
"openai": OpenAIEmbeddingFunction,
|
||||
"cohere": CohereEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"huggingface": HuggingFaceEmbeddingFunction,
|
||||
"sentence-transformer": SentenceTransformerEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"google-palm": GooglePalmEmbeddingFunction,
|
||||
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google-vertex": GoogleVertexEmbeddingFunction,
|
||||
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"roboflow": RoboflowEmbeddingFunction,
|
||||
"openclip": OpenCLIPEmbeddingFunction,
|
||||
"text2vec": Text2VecEmbeddingFunction,
|
||||
"onnx": ONNXMiniLM_L6_V2,
|
||||
}
|
||||
config_dict = dict(nested)
|
||||
if "model" in config_dict and "model_name" not in config_dict:
|
||||
config_dict["model_name"] = config_dict.pop("model")
|
||||
|
||||
if provider not in embedding_functions:
|
||||
if provider not in EMBEDDING_PROVIDERS:
|
||||
raise ValueError(
|
||||
f"Unsupported provider: {provider}. "
|
||||
f"Available providers: {list(embedding_functions.keys())}"
|
||||
f"Available providers: {list(EMBEDDING_PROVIDERS.keys())}"
|
||||
)
|
||||
return embedding_functions[provider](**config_dict)
|
||||
|
||||
_inject_api_key_from_env(provider, config_dict)
|
||||
|
||||
return EMBEDDING_PROVIDERS[provider](**config_dict)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
|
||||
class BaseRAGStorage(ABC):
|
||||
"""
|
||||
@@ -13,7 +16,7 @@ class BaseRAGStorage(ABC):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
self.type = type
|
||||
|
||||
598
tests/rag/embeddings/test_embedding_factory.py
Normal file
598
tests/rag/embeddings/test_embedding_factory.py
Normal file
@@ -0,0 +1,598 @@
|
||||
"""Enhanced tests for embedding function factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
|
||||
get_embedding_function,
|
||||
)
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
|
||||
|
||||
|
||||
def test_get_embedding_function_default() -> None:
    """With no config, the factory builds the default OpenAI embedder."""
    expected_embedder = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.OpenAIEmbeddingFunction",
        return_value=expected_embedder,
    ) as openai_cls:
        with patch(
            "crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
        ):
            built = get_embedding_function()

        openai_cls.assert_called_once_with(
            api_key="test-api-key", model_name="text-embedding-3-small"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_with_embedding_options() -> None:
    """An EmbeddingOptions object is turned into provider keyword arguments."""
    expected_embedder = MagicMock()
    openai_factory = MagicMock(return_value=expected_embedder)
    options = EmbeddingOptions(
        provider="openai",
        api_key=SecretStr("test-key"),
        model_name="text-embedding-3-large",
    )

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = openai_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(options)

        passed_kwargs = openai_factory.call_args.kwargs
        assert "api_key" in passed_kwargs
        assert passed_kwargs["api_key"].get_secret_value() == "test-key"
        assert "model_name" in passed_kwargs
        assert passed_kwargs["model_name"] == "text-embedding-3-large"
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_sentence_transformer() -> None:
    """Sentence-transformer: the nested model_name reaches the provider."""
    expected_embedder = MagicMock()
    st_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "sentence-transformer",
        "config": {"model_name": "all-MiniLM-L6-v2"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = st_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        st_factory.assert_called_once_with(model_name="all-MiniLM-L6-v2")
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_ollama() -> None:
    """Ollama: nested model_name and url are forwarded as keyword arguments."""
    expected_embedder = MagicMock()
    ollama_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "ollama",
        "config": {
            "model_name": "nomic-embed-text",
            "url": "http://localhost:11434",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = ollama_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        ollama_factory.assert_called_once_with(
            model_name="nomic-embed-text", url="http://localhost:11434"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_cohere() -> None:
    """Cohere: nested api_key and model_name are forwarded to the provider."""
    expected_embedder = MagicMock()
    cohere_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "cohere",
        "config": {"api_key": "cohere-key", "model_name": "embed-english-v3.0"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = cohere_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        cohere_factory.assert_called_once_with(
            api_key="cohere-key", model_name="embed-english-v3.0"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_huggingface() -> None:
    """HuggingFace: nested api_key and model_name are forwarded to the provider."""
    expected_embedder = MagicMock()
    hf_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "huggingface",
        "config": {
            "api_key": "hf-token",
            "model_name": "sentence-transformers/all-MiniLM-L6-v2",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = hf_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        hf_factory.assert_called_once_with(
            api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_onnx() -> None:
    """Test ONNX embedding function is built with no arguments.

    The config carries only a provider, so the provider callable must be
    invoked exactly once and without any positional or keyword arguments.
    """
    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
        mock_instance = MagicMock()
        mock_onnx = MagicMock(return_value=mock_instance)
        mock_providers.__getitem__.return_value = mock_onnx
        mock_providers.__contains__.return_value = True

        config = {"provider": "onnx"}

        result = get_embedding_function(config)

        # Stricter than assert_called_once(): also fails if stray kwargs
        # (e.g. an injected api_key) are passed to the ONNX constructor.
        mock_onnx.assert_called_once_with()
        assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_google_palm() -> None:
    """Google PaLM: the nested api_key is forwarded to the provider."""
    expected_embedder = MagicMock()
    palm_factory = MagicMock(return_value=expected_embedder)
    nested_config = {"provider": "google-palm", "config": {"api_key": "palm-key"}}

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = palm_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        palm_factory.assert_called_once_with(api_key="palm-key")
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_amazon_bedrock() -> None:
    """Amazon Bedrock: an explicitly supplied session is passed through."""
    expected_embedder = MagicMock()
    bedrock_factory = MagicMock(return_value=expected_embedder)
    # Supplying the session up front keeps the factory from importing boto3.
    aws_session = MagicMock()
    nested_config = {
        "provider": "amazon-bedrock",
        "config": {
            "session": aws_session,
            "region_name": "us-west-2",
            "model_name": "amazon.titan-embed-text-v1",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = bedrock_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        bedrock_factory.assert_called_once_with(
            session=aws_session,
            region_name="us-west-2",
            model_name="amazon.titan-embed-text-v1",
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_jina() -> None:
    """Jina: nested api_key and model_name are forwarded to the provider."""
    expected_embedder = MagicMock()
    jina_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "jina",
        "config": {
            "api_key": "jina-key",
            "model_name": "jina-embeddings-v2-base-en",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = jina_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        jina_factory.assert_called_once_with(
            api_key="jina-key", model_name="jina-embeddings-v2-base-en"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_unsupported_provider() -> None:
    """An unknown provider name is rejected with a ValueError that names it."""
    bad_config = {"provider": "unsupported-provider"}
    expected_message = "Unsupported provider: unsupported-provider"

    with pytest.raises(ValueError, match=expected_message):
        get_embedding_function(bad_config)
|
||||
|
||||
|
||||
def test_get_embedding_function_config_modification() -> None:
    """Test that the caller's config, including its nested dict, is not modified.

    The factory renames ``model`` to ``model_name`` and may inject an API key;
    both must happen on an internal copy, never on the dict the caller passed.
    """
    import copy

    original_config = {
        "provider": "openai",
        "config": {"api_key": "test-key", "model": "text-embedding-3-small"},
    }
    # A shallow .copy() would share the nested "config" dict with the
    # original, so a mutation of it would change both sides of the comparison
    # and go unnoticed. deepcopy keeps an independent snapshot.
    config_copy = copy.deepcopy(original_config)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
        mock_instance = MagicMock()
        mock_openai = MagicMock(return_value=mock_instance)
        mock_providers.__getitem__.return_value = mock_openai
        mock_providers.__contains__.return_value = True

        get_embedding_function(config_copy)

    assert config_copy == original_config
|
||||
|
||||
|
||||
def test_get_embedding_function_exclude_none_values() -> None:
    """Fields left as None on EmbeddingOptions are not passed to the provider."""
    expected_embedder = MagicMock()
    openai_factory = MagicMock(return_value=expected_embedder)
    options = EmbeddingOptions(
        provider="openai", api_key=SecretStr("test-key"), model_name=None
    )

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = openai_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(options)

        passed_kwargs = openai_factory.call_args.kwargs
        assert "api_key" in passed_kwargs
        assert passed_kwargs["api_key"].get_secret_value() == "test-key"
        assert "model_name" not in passed_kwargs
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_instructor() -> None:
    """Instructor: the nested model_name is forwarded to the provider."""
    expected_embedder = MagicMock()
    instructor_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "instructor",
        "config": {"model_name": "hkunlp/instructor-large"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = instructor_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        instructor_factory.assert_called_once_with(
            model_name="hkunlp/instructor-large"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_google_generativeai() -> None:
    """Google Generative AI: nested api_key and model_name are forwarded."""
    expected_embedder = MagicMock()
    google_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "google-generativeai",
        "config": {"api_key": "google-key", "model_name": "models/embedding-001"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = google_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        google_factory.assert_called_once_with(
            api_key="google-key", model_name="models/embedding-001"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_google_vertex() -> None:
    """Google Vertex AI: api_key, project_id and region are forwarded."""
    expected_embedder = MagicMock()
    vertex_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "google-vertex",
        "config": {
            "api_key": "vertex-key",
            "project_id": "my-project",
            "region": "us-central1",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = vertex_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        vertex_factory.assert_called_once_with(
            api_key="vertex-key", project_id="my-project", region="us-central1"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_roboflow() -> None:
    """Roboflow: nested api_key and api_url are forwarded to the provider."""
    expected_embedder = MagicMock()
    roboflow_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "roboflow",
        "config": {
            "api_key": "roboflow-key",
            "api_url": "https://infer.roboflow.com",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = roboflow_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        roboflow_factory.assert_called_once_with(
            api_key="roboflow-key", api_url="https://infer.roboflow.com"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_openclip() -> None:
    """OpenCLIP: nested model_name and checkpoint are forwarded to the provider."""
    expected_embedder = MagicMock()
    openclip_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "openclip",
        "config": {"model_name": "ViT-B-32", "checkpoint": "laion2b_s34b_b79k"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = openclip_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        openclip_factory.assert_called_once_with(
            model_name="ViT-B-32", checkpoint="laion2b_s34b_b79k"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_text2vec() -> None:
    """Text2Vec: the nested model_name is forwarded to the provider."""
    expected_embedder = MagicMock()
    text2vec_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "text2vec",
        "config": {"model_name": "shibing624/text2vec-base-chinese"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = text2vec_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        text2vec_factory.assert_called_once_with(
            model_name="shibing624/text2vec-base-chinese"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_model_to_model_name_conversion() -> None:
    """A nested 'model' key reaches the provider as 'model_name'."""
    expected_embedder = MagicMock()
    openai_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "openai",
        "config": {"api_key": "test-key", "model": "text-embedding-3-small"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = openai_factory
        registry.__contains__.return_value = True

        built = get_embedding_function(nested_config)

        openai_factory.assert_called_once_with(
            api_key="test-key", model_name="text-embedding-3-small"
        )
        assert built == expected_embedder
|
||||
|
||||
|
||||
def test_api_key_injection_from_env_openai() -> None:
    """OpenAI: a missing api_key is filled in from OPENAI_API_KEY."""
    expected_embedder = MagicMock()
    openai_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "openai",
        "config": {"model": "text-embedding-3-small"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = openai_factory
        registry.__contains__.return_value = True

        with patch(
            "crewai.rag.embeddings.factory.os.getenv",
            return_value="env-openai-key",
        ) as env_lookup:
            built = get_embedding_function(nested_config)

            env_lookup.assert_called_with("OPENAI_API_KEY")
            openai_factory.assert_called_once_with(
                api_key="env-openai-key", model_name="text-embedding-3-small"
            )
            assert built == expected_embedder
|
||||
|
||||
|
||||
def test_api_key_injection_from_env_cohere() -> None:
    """Cohere: a missing api_key is filled in from COHERE_API_KEY."""
    expected_embedder = MagicMock()
    cohere_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "cohere",
        "config": {"model_name": "embed-english-v3.0"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = cohere_factory
        registry.__contains__.return_value = True

        with patch(
            "crewai.rag.embeddings.factory.os.getenv",
            return_value="env-cohere-key",
        ) as env_lookup:
            built = get_embedding_function(nested_config)

            env_lookup.assert_called_with("COHERE_API_KEY")
            cohere_factory.assert_called_once_with(
                api_key="env-cohere-key", model_name="embed-english-v3.0"
            )
            assert built == expected_embedder
|
||||
|
||||
|
||||
def test_api_key_not_injected_when_provided() -> None:
    """An api_key given in the config wins over the environment value."""
    expected_embedder = MagicMock()
    openai_factory = MagicMock(return_value=expected_embedder)
    nested_config = {
        "provider": "openai",
        "config": {"api_key": "config-key", "model": "text-embedding-3-small"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__getitem__.return_value = openai_factory
        registry.__contains__.return_value = True

        with patch(
            "crewai.rag.embeddings.factory.os.getenv", return_value="env-key"
        ):
            built = get_embedding_function(nested_config)

            openai_factory.assert_called_once_with(
                api_key="config-key", model_name="text-embedding-3-small"
            )
            assert built == expected_embedder
|
||||
|
||||
|
||||
def test_amazon_bedrock_session_injection() -> None:
    """A boto3 session is created and passed along when none is configured.

    ``boto3`` is replaced in ``sys.modules`` by a mock so that the test can
    observe the ``Session()`` call and control the object it returns.
    """
    embedder = MagicMock()
    bedrock_factory = MagicMock(return_value=embedder)
    aws_session = MagicMock()
    fake_boto3 = MagicMock()
    fake_boto3.Session.return_value = aws_session
    embedder_config = {
        "provider": "amazon-bedrock",
        "config": {"model_name": "amazon.titan-embed-text-v1"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True
        providers.__getitem__.return_value = bedrock_factory

        with patch.dict("sys.modules", {"boto3": fake_boto3}):
            produced = get_embedding_function(embedder_config)

    fake_boto3.Session.assert_called_once()
    bedrock_factory.assert_called_once_with(
        session=aws_session, model_name="amazon.titan-embed-text-v1"
    )
    assert produced == embedder
|
||||
|
||||
|
||||
def test_amazon_bedrock_session_not_injected_when_provided() -> None:
    """A session supplied in the config is forwarded to the provider as-is."""
    embedder = MagicMock()
    bedrock_factory = MagicMock(return_value=embedder)
    caller_session = MagicMock()
    embedder_config = {
        "provider": "amazon-bedrock",
        "config": {
            "session": caller_session,
            "model_name": "amazon.titan-embed-text-v1",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True
        providers.__getitem__.return_value = bedrock_factory
        produced = get_embedding_function(embedder_config)

    bedrock_factory.assert_called_once_with(
        session=caller_session, model_name="amazon.titan-embed-text-v1"
    )
    assert produced == embedder
|
||||
|
||||
|
||||
def test_amazon_bedrock_boto3_import_error() -> None:
    """A missing ``boto3`` surfaces as ``ImportError`` with a helpful message.

    Mapping ``"boto3"`` to ``None`` in ``sys.modules`` makes any
    ``import boto3`` fail, which simulates the package not being installed.
    """
    embedder_config = {
        "provider": "amazon-bedrock",
        "config": {"model_name": "amazon.titan-embed-text-v1"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True

        with patch.dict("sys.modules", {"boto3": None}), pytest.raises(
            ImportError, match="boto3 is required for amazon-bedrock"
        ):
            get_embedding_function(embedder_config)
|
||||
|
||||
|
||||
def test_amazon_bedrock_session_creation_error() -> None:
    """A failing ``boto3.Session()`` call is reported as a ``ValueError``."""
    fake_boto3 = MagicMock()
    # Any attempt to build a session blows up, as with missing credentials.
    fake_boto3.Session.side_effect = Exception("AWS credentials not configured")
    embedder_config = {
        "provider": "amazon-bedrock",
        "config": {"model_name": "amazon.titan-embed-text-v1"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True

        with patch.dict("sys.modules", {"boto3": fake_boto3}), pytest.raises(
            ValueError, match="Failed to create AWS session"
        ):
            get_embedding_function(embedder_config)
|
||||
|
||||
|
||||
def test_invalid_config_format() -> None:
    """A flat config (options not nested under ``config``) is rejected."""
    flat_config = dict(
        provider="openai",
        api_key="test-key",
        model="text-embedding-3-small",
    )

    with pytest.raises(ValueError, match="Invalid embedder configuration format"):
        get_embedding_function(flat_config)
|
||||
79
tests/rag/embeddings/test_factory_azure.py
Normal file
79
tests/rag/embeddings/test_factory_azure.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Test Azure embedder configuration with factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
|
||||
|
||||
|
||||
class TestAzureEmbedderFactory:
    """Factory-level checks for Azure-style and plain OpenAI embedder configs."""

    @patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS")
    def test_azure_with_nested_config(self, mock_providers):
        """Azure options nested under ``config`` reach the provider callable.

        Every option is expected to be forwarded unchanged, except ``model``,
        which must arrive as ``model_name``.
        """
        embedder = MagicMock()
        openai_factory = MagicMock(return_value=embedder)
        mock_providers.__contains__.return_value = True
        mock_providers.__getitem__.return_value = openai_factory

        azure_options = {
            "api_key": "test-azure-key",
            "api_base": "https://test.openai.azure.com/",
            "api_type": "azure",
            "api_version": "2023-05-15",
            "model": "text-embedding-3-small",
            "deployment_id": "test-deployment",
        }
        embedder_config = EmbedderConfig(provider="openai", config=azure_options)

        produced = get_embedding_function(embedder_config)

        openai_factory.assert_called_once_with(
            api_key="test-azure-key",
            api_base="https://test.openai.azure.com/",
            api_type="azure",
            api_version="2023-05-15",
            model_name="text-embedding-3-small",
            deployment_id="test-deployment",
        )
        assert produced == embedder

    @patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS")
    def test_regular_openai_with_nested_config(self, mock_providers):
        """A plain OpenAI nested config is forwarded with ``model`` renamed."""
        embedder = MagicMock()
        openai_factory = MagicMock(return_value=embedder)
        mock_providers.__contains__.return_value = True
        mock_providers.__getitem__.return_value = openai_factory

        openai_options = {
            "api_key": "test-openai-key",
            "model": "text-embedding-3-large",
        }
        embedder_config = EmbedderConfig(provider="openai", config=openai_options)

        produced = get_embedding_function(embedder_config)

        openai_factory.assert_called_once_with(
            api_key="test-openai-key", model_name="text-embedding-3-large"
        )
        assert produced == embedder

    def test_flat_format_raises_error(self):
        """A flat config is rejected with a message pointing at ``config``."""
        flat_config = dict(
            provider="openai",
            api_key="test-key",
            model_name="text-embedding-3-small",
        )

        with pytest.raises(ValueError) as exc_info:
            get_embedding_function(flat_config)

        message = str(exc_info.value)
        assert "Invalid embedder configuration format" in message
        assert "nested under a 'config' key" in message
|
||||
@@ -1,250 +0,0 @@
|
||||
"""Enhanced tests for embedding function factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
|
||||
get_embedding_function,
|
||||
)
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
|
||||
|
||||
|
||||
def test_get_embedding_function_default() -> None:
    """Without a config, an OpenAI embedder is built from environment defaults.

    ``os.getenv`` is patched to supply the API key; the expected default model
    name is ``text-embedding-3-small``.
    """
    embedder = MagicMock()

    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as openai_cls:
        openai_cls.return_value = embedder

        getenv_patch = patch(
            "crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
        )
        with getenv_patch:
            produced = get_embedding_function()

    openai_cls.assert_called_once_with(
        api_key="test-api-key", model_name="text-embedding-3-small"
    )
    assert produced == embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_with_embedding_options() -> None:
    """An ``EmbeddingOptions`` object is accepted in place of a dict.

    The API key is expected to arrive wrapped in an object exposing
    ``get_secret_value()``.
    """
    embedder = MagicMock()
    options = EmbeddingOptions(
        provider="openai", api_key="test-key", model="text-embedding-3-large"
    )

    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as openai_cls:
        openai_cls.return_value = embedder
        produced = get_embedding_function(options)

    passed_kwargs = openai_cls.call_args.kwargs
    assert "api_key" in passed_kwargs
    assert passed_kwargs["api_key"].get_secret_value() == "test-key"
    # OpenAI takes ``model_name`` rather than ``model``; the model kwarg itself
    # is not asserted on here.
    assert produced == embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_sentence_transformer() -> None:
    """The ``sentence-transformer`` provider receives only ``model_name``."""
    embedder = MagicMock()
    target = "crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction"
    flat_config = dict(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")

    with patch(target) as st_cls:
        st_cls.return_value = embedder
        produced = get_embedding_function(flat_config)

    st_cls.assert_called_once_with(model_name="all-MiniLM-L6-v2")
    assert produced == embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_ollama() -> None:
    """The ``ollama`` provider is called with the model name and server URL."""
    embedder = MagicMock()
    flat_config = dict(
        provider="ollama",
        model_name="nomic-embed-text",
        url="http://localhost:11434",
    )

    with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as ollama_cls:
        ollama_cls.return_value = embedder
        produced = get_embedding_function(flat_config)

    ollama_cls.assert_called_once_with(
        model_name="nomic-embed-text", url="http://localhost:11434"
    )
    assert produced == embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_cohere() -> None:
    """The ``cohere`` provider is called with the API key and model name."""
    embedder = MagicMock()
    flat_config = dict(
        provider="cohere",
        api_key="cohere-key",
        model_name="embed-english-v3.0",
    )

    with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as cohere_cls:
        cohere_cls.return_value = embedder
        produced = get_embedding_function(flat_config)

    cohere_cls.assert_called_once_with(
        api_key="cohere-key", model_name="embed-english-v3.0"
    )
    assert produced == embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_huggingface() -> None:
    """The ``huggingface`` provider is called with the token and model name."""
    embedder = MagicMock()
    flat_config = dict(
        provider="huggingface",
        api_key="hf-token",
        model_name="sentence-transformers/all-MiniLM-L6-v2",
    )

    with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as hf_cls:
        hf_cls.return_value = embedder
        produced = get_embedding_function(flat_config)

    hf_cls.assert_called_once_with(
        api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    assert produced == embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_onnx() -> None:
    """The ``onnx`` provider is instantiated exactly once."""
    embedder = MagicMock()

    with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as onnx_cls:
        onnx_cls.return_value = embedder
        produced = get_embedding_function({"provider": "onnx"})

    onnx_cls.assert_called_once()
    assert produced == embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_google_palm() -> None:
    """The ``google-palm`` provider is called with just the API key."""
    embedder = MagicMock()
    target = "crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction"
    flat_config = dict(provider="google-palm", api_key="palm-key")

    with patch(target) as palm_cls:
        palm_cls.return_value = embedder
        produced = get_embedding_function(flat_config)

    palm_cls.assert_called_once_with(api_key="palm-key")
    assert produced == embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_amazon_bedrock() -> None:
    """The ``amazon-bedrock`` provider is called with region and model name."""
    embedder = MagicMock()
    target = "crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction"
    flat_config = dict(
        provider="amazon-bedrock",
        region_name="us-west-2",
        model_name="amazon.titan-embed-text-v1",
    )

    with patch(target) as bedrock_cls:
        bedrock_cls.return_value = embedder
        produced = get_embedding_function(flat_config)

    bedrock_cls.assert_called_once_with(
        region_name="us-west-2", model_name="amazon.titan-embed-text-v1"
    )
    assert produced == embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_jina() -> None:
    """The ``jina`` provider is called with the API key and model name."""
    embedder = MagicMock()
    flat_config = dict(
        provider="jina",
        api_key="jina-key",
        model_name="jina-embeddings-v2-base-en",
    )

    with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as jina_cls:
        jina_cls.return_value = embedder
        produced = get_embedding_function(flat_config)

    jina_cls.assert_called_once_with(
        api_key="jina-key", model_name="jina-embeddings-v2-base-en"
    )
    assert produced == embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_unsupported_provider() -> None:
    """An unknown provider name raises ``ValueError`` naming that provider."""
    unknown_config = dict(provider="unsupported-provider")

    with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
        get_embedding_function(unknown_config)
|
||||
|
||||
|
||||
def test_get_embedding_function_config_modification() -> None:
    """The dict handed to the factory is left unmodified by the call."""
    pristine = {
        "provider": "openai",
        "api_key": "test-key",
        "model": "text-embedding-3-small",
    }
    # Pass a shallow copy so it can be compared against the untouched original.
    passed_in = pristine.copy()

    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
        get_embedding_function(passed_in)

    assert passed_in == pristine
|
||||
|
||||
|
||||
def test_get_embedding_function_exclude_none_values() -> None:
    """Options set to ``None`` are not forwarded to the embedding function.

    ``model`` is ``None`` here, so the provider must not receive a ``model``
    keyword; the API key is still passed, wrapped in an object exposing
    ``get_secret_value()``.
    """
    embedder = MagicMock()
    options = EmbeddingOptions(provider="openai", api_key="test-key", model=None)

    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as openai_cls:
        openai_cls.return_value = embedder
        produced = get_embedding_function(options)

    passed_kwargs = openai_cls.call_args.kwargs
    assert "api_key" in passed_kwargs
    assert passed_kwargs["api_key"].get_secret_value() == "test-key"
    assert "model" not in passed_kwargs
    assert produced == embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_instructor() -> None:
    """The ``instructor`` provider receives only ``model_name``."""
    embedder = MagicMock()
    target = "crewai.rag.embeddings.factory.InstructorEmbeddingFunction"
    flat_config = dict(provider="instructor", model_name="hkunlp/instructor-large")

    with patch(target) as instructor_cls:
        instructor_cls.return_value = embedder
        produced = get_embedding_function(flat_config)

    instructor_cls.assert_called_once_with(model_name="hkunlp/instructor-large")
    assert produced == embedder
|
||||
82
tests/utilities/test_azure_embedder_config.py
Normal file
82
tests/utilities/test_azure_embedder_config.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Test Azure embedder configuration with nested format only."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||
|
||||
|
||||
class TestAzureEmbedderConfiguration:
    """``EmbeddingConfigurator`` checks for Azure settings in nested format."""

    @patch(
        "chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction"
    )
    def test_azure_openai_with_nested_config(self, mock_openai_func):
        """``provider="openai"`` with ``api_type="azure"`` in nested config.

        The configurator is expected to forward the Azure options, rename
        ``model`` to ``model_name``, and pass ``None`` for the options the
        config does not set.
        """
        embedder = MagicMock()
        mock_openai_func.return_value = embedder

        nested_options = {
            "api_key": "test-azure-key",
            "api_base": "https://test.openai.azure.com/",
            "api_type": "azure",
            "api_version": "2023-05-15",
            "model": "text-embedding-3-small",
            "deployment_id": "test-deployment",
        }
        embedder_config = {"provider": "openai", "config": nested_options}

        produced = EmbeddingConfigurator().configure_embedder(embedder_config)

        mock_openai_func.assert_called_once_with(
            api_key="test-azure-key",
            model_name="text-embedding-3-small",
            api_base="https://test.openai.azure.com/",
            api_type="azure",
            api_version="2023-05-15",
            default_headers=None,
            dimensions=None,
            deployment_id="test-deployment",
            organization_id=None,
        )
        assert produced == embedder

    @patch(
        "chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction"
    )
    def test_azure_provider_with_nested_config(self, mock_openai_func):
        """``provider="azure"`` with nested config and no explicit ``api_type``.

        The call is still expected to carry ``api_type="azure"`` even though
        the config itself does not contain that key.
        """
        embedder = MagicMock()
        mock_openai_func.return_value = embedder

        nested_options = {
            "api_key": "test-azure-key",
            "api_base": "https://test.openai.azure.com/",
            "api_version": "2023-05-15",
            "model": "text-embedding-3-small",
            "deployment_id": "test-deployment",
        }
        embedder_config = {"provider": "azure", "config": nested_options}

        produced = EmbeddingConfigurator().configure_embedder(embedder_config)

        mock_openai_func.assert_called_once_with(
            api_key="test-azure-key",
            api_base="https://test.openai.azure.com/",
            api_type="azure",
            api_version="2023-05-15",
            model_name="text-embedding-3-small",
            default_headers=None,
            dimensions=None,
            deployment_id="test-deployment",
            organization_id=None,
        )
        assert produced == embedder
|
||||
Reference in New Issue
Block a user