mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-29 18:18:13 +00:00
fix: support nested config format for embedder configuration
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
- Support nested config format with the `EmbedderConfig` TypedDict
- Fix parsing for `model`/`model_name` compatibility
- Add validation, `typing_extensions`, and improved type hints
- Enhance the embedding factory with env var injection and provider support
- Add tests for OpenAI, Azure, and all embedding providers
- Misc fixes: test file rename, updated mocking patterns
This commit is contained in:
@@ -7,7 +7,8 @@ from crewai.rag.chromadb.config import ChromaDBConfig
|
|||||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||||
from crewai.rag.config.utils import get_rag_client
|
from crewai.rag.config.utils import get_rag_client
|
||||||
from crewai.rag.core.base_client import BaseClient
|
from crewai.rag.core.base_client import BaseClient
|
||||||
from crewai.rag.embeddings.factory import get_embedding_function
|
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
|
||||||
|
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||||
from crewai.rag.factory import create_client
|
from crewai.rag.factory import create_client
|
||||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||||
from crewai.rag.types import BaseRecord
|
from crewai.rag.types import BaseRecord
|
||||||
@@ -25,7 +26,7 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self,
|
self,
|
||||||
type: str,
|
type: str,
|
||||||
allow_reset: bool = True,
|
allow_reset: bool = True,
|
||||||
embedder_config: dict[str, Any] | None = None,
|
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||||
crew: Any = None,
|
crew: Any = None,
|
||||||
path: str | None = None,
|
path: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -50,6 +51,21 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
|
|
||||||
if self.embedder_config:
|
if self.embedder_config:
|
||||||
embedding_function = get_embedding_function(self.embedder_config)
|
embedding_function = get_embedding_function(self.embedder_config)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = embedding_function(["test"])
|
||||||
|
except Exception as e:
|
||||||
|
provider = (
|
||||||
|
self.embedder_config.provider
|
||||||
|
if isinstance(self.embedder_config, EmbeddingOptions)
|
||||||
|
else self.embedder_config.get("provider", "unknown")
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to initialize embedder. Please check your configuration or connection.\n"
|
||||||
|
f"Provider: {provider}\n"
|
||||||
|
f"Error: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
config = ChromaDBConfig(
|
config = ChromaDBConfig(
|
||||||
embedding_function=cast(
|
embedding_function=cast(
|
||||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Minimal embedding function factory for CrewAI."""
|
"""Minimal embedding function factory for CrewAI."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from collections.abc import Callable, MutableMapping
|
||||||
|
from typing import Any, Final, Literal, TypedDict
|
||||||
|
|
||||||
from chromadb import EmbeddingFunction
|
from chromadb import EmbeddingFunction
|
||||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||||
@@ -42,19 +44,116 @@ from chromadb.utils.embedding_functions.sentence_transformer_embedding_function
|
|||||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||||
Text2VecEmbeddingFunction,
|
Text2VecEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||||
|
|
||||||
|
# Closed set of provider identifiers accepted in an embedder configuration.
# The order below mirrors the order of the provider registry.
AllowedEmbeddingProviders = Literal[
    # Hosted / local model servers
    "openai", "cohere", "ollama", "huggingface",
    # Locally executed models
    "sentence-transformer", "instructor",
    # Google
    "google-palm", "google-generativeai", "google-vertex",
    # Other hosted services
    "amazon-bedrock", "jina", "roboflow",
    # Additional local backends
    "openclip", "text2vec", "onnx",
]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedderConfig(TypedDict):
    """Nested embedder configuration: a provider name plus its options.

    Shape: ``{"provider": <provider name>, "config": {<provider kwargs>}}``.
    """

    # Name of the embedding backend; restricted to AllowedEmbeddingProviders.
    provider: AllowedEmbeddingProviders
    # Provider-specific keyword arguments; the key may be left out entirely.
    config: NotRequired[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping each provider identifier to the callable that builds its
# embedding function. Insertion order is kept stable because the keys are
# listed verbatim in the "Unsupported provider" error message.
EMBEDDING_PROVIDERS: Final[
    dict[AllowedEmbeddingProviders, Callable[..., EmbeddingFunction]]
] = {
    "openai": OpenAIEmbeddingFunction,
    "cohere": CohereEmbeddingFunction,
    "ollama": OllamaEmbeddingFunction,
    "huggingface": HuggingFaceEmbeddingFunction,
    "sentence-transformer": SentenceTransformerEmbeddingFunction,
    "instructor": InstructorEmbeddingFunction,
    "google-palm": GooglePalmEmbeddingFunction,
    "google-generativeai": GoogleGenerativeAiEmbeddingFunction,
    "google-vertex": GoogleVertexEmbeddingFunction,
    "amazon-bedrock": AmazonBedrockEmbeddingFunction,
    "jina": JinaEmbeddingFunction,
    "roboflow": RoboflowEmbeddingFunction,
    "openclip": OpenCLIPEmbeddingFunction,
    "text2vec": Text2VecEmbeddingFunction,
    "onnx": ONNXMiniLM_L6_V2,
}
|
||||||
|
|
||||||
|
# For providers that take a credential: (environment variable to read,
# keyword argument to populate). Providers absent from this table get no
# environment-based injection.
PROVIDER_ENV_MAPPING: Final[dict[AllowedEmbeddingProviders, tuple[str, str]]] = {
    "openai": ("OPENAI_API_KEY", "api_key"),
    "cohere": ("COHERE_API_KEY", "api_key"),
    "huggingface": ("HUGGINGFACE_API_KEY", "api_key"),
    # All three Google providers share one environment variable.
    "google-palm": ("GOOGLE_API_KEY", "api_key"),
    "google-generativeai": ("GOOGLE_API_KEY", "api_key"),
    "google-vertex": ("GOOGLE_API_KEY", "api_key"),
    "jina": ("JINA_API_KEY", "api_key"),
    "roboflow": ("ROBOFLOW_API_KEY", "api_key"),
}
|
||||||
|
|
||||||
|
|
||||||
|
def _inject_api_key_from_env(
    provider: AllowedEmbeddingProviders, config_dict: MutableMapping[str, Any]
) -> None:
    """Fill in missing credentials for ``provider`` from the environment.

    Values already present in ``config_dict`` are never overwritten.

    Args:
        provider: The embedding provider name.
        config_dict: Provider keyword arguments; updated in place.

    Raises:
        ImportError: If ``boto3`` is needed for amazon-bedrock but not installed.
        ValueError: If creating the AWS session for amazon-bedrock fails.
    """
    env_binding = PROVIDER_ENV_MAPPING.get(provider)
    if env_binding is not None:
        env_var_name, config_key = env_binding
        if config_key not in config_dict:
            env_value = os.getenv(env_var_name)
            # Unset or empty environment variables are ignored.
            if env_value:
                config_dict[config_key] = env_value

    # Only amazon-bedrock without a caller-supplied session needs more work.
    if provider != "amazon-bedrock" or "session" in config_dict:
        return

    try:
        import boto3  # type: ignore[import]

        config_dict["session"] = boto3.Session()
    except ImportError as e:
        raise ImportError(
            "boto3 is required for amazon-bedrock embeddings. "
            "Install it with: uv add boto3"
        ) from e
    except Exception as e:
        raise ValueError(
            f"Failed to create AWS session for amazon-bedrock. "
            f"Ensure AWS credentials are configured. Error: {e}"
        ) from e
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_function(
|
def get_embedding_function(
|
||||||
config: EmbeddingOptions | dict | None = None,
|
config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||||
) -> EmbeddingFunction:
|
) -> EmbeddingFunction:
|
||||||
"""Get embedding function - delegates to ChromaDB.
|
"""Get embedding function - delegates to ChromaDB.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Optional configuration - either an EmbeddingOptions object or a dict with:
|
config: Optional configuration - either:
|
||||||
- provider: The embedding provider to use (default: "openai")
|
- EmbeddingOptions: Pydantic model with flat configuration
|
||||||
- Any other provider-specific parameters
|
- EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict}
|
||||||
|
- None: Uses default OpenAI configuration
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
EmbeddingFunction instance ready for use with ChromaDB
|
EmbeddingFunction instance ready for use with ChromaDB
|
||||||
@@ -81,31 +180,33 @@ def get_embedding_function(
|
|||||||
>>> embedder = get_embedding_function()
|
>>> embedder = get_embedding_function()
|
||||||
|
|
||||||
# Use Cohere with dict
|
# Use Cohere with dict
|
||||||
>>> embedder = get_embedding_function({
|
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||||
... "provider": "cohere",
|
... "provider": "cohere",
|
||||||
... "api_key": "your-key",
|
... "config": {
|
||||||
... "model_name": "embed-english-v3.0"
|
... "api_key": "your-key",
|
||||||
... })
|
... "model_name": "embed-english-v3.0"
|
||||||
|
... }
|
||||||
|
... }))
|
||||||
|
|
||||||
# Use with EmbeddingOptions
|
# Use with EmbeddingOptions
|
||||||
>>> embedder = get_embedding_function(
|
>>> embedder = get_embedding_function(
|
||||||
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
|
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
|
||||||
... )
|
... )
|
||||||
|
|
||||||
# Use local sentence transformers (no API key needed)
|
# Use Azure OpenAI
|
||||||
>>> embedder = get_embedding_function({
|
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||||
... "provider": "sentence-transformer",
|
... "provider": "openai",
|
||||||
... "model_name": "all-MiniLM-L6-v2"
|
... "config": {
|
||||||
|
... "api_key": "your-azure-key",
|
||||||
|
... "api_base": "https://your-resource.openai.azure.com/",
|
||||||
|
... "api_type": "azure",
|
||||||
|
... "api_version": "2023-05-15",
|
||||||
|
... "model": "text-embedding-3-small",
|
||||||
|
... "deployment_id": "your-deployment-name"
|
||||||
|
... }
|
||||||
... })
|
... }))
|
||||||
|
|
||||||
# Use Ollama for local embeddings
|
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||||
>>> embedder = get_embedding_function({
|
|
||||||
... "provider": "ollama",
|
|
||||||
... "model_name": "nomic-embed-text"
|
|
||||||
... })
|
|
||||||
|
|
||||||
# Use ONNX (no API key needed)
|
|
||||||
>>> embedder = get_embedding_function({
|
|
||||||
... "provider": "onnx"
|
... "provider": "onnx"
|
||||||
... })
|
... }))
|
||||||
"""
|
"""
|
||||||
@@ -114,35 +215,33 @@ def get_embedding_function(
|
|||||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle EmbeddingOptions object
|
provider: AllowedEmbeddingProviders
|
||||||
|
config_dict: dict[str, Any]
|
||||||
|
|
||||||
if isinstance(config, EmbeddingOptions):
|
if isinstance(config, EmbeddingOptions):
|
||||||
config_dict = config.model_dump(exclude_none=True)
|
config_dict = config.model_dump(exclude_none=True)
|
||||||
|
provider = config_dict["provider"]
|
||||||
else:
|
else:
|
||||||
config_dict = config.copy()
|
provider = config["provider"]
|
||||||
|
nested: dict[str, Any] = config.get("config", {})
|
||||||
|
|
||||||
provider = config_dict.pop("provider", "openai")
|
if not nested and len(config) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid embedder configuration format. "
|
||||||
|
"Configuration must be nested under a 'config' key. "
|
||||||
|
"Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}"
|
||||||
|
)
|
||||||
|
|
||||||
embedding_functions = {
|
config_dict = dict(nested)
|
||||||
"openai": OpenAIEmbeddingFunction,
|
if "model" in config_dict and "model_name" not in config_dict:
|
||||||
"cohere": CohereEmbeddingFunction,
|
config_dict["model_name"] = config_dict.pop("model")
|
||||||
"ollama": OllamaEmbeddingFunction,
|
|
||||||
"huggingface": HuggingFaceEmbeddingFunction,
|
|
||||||
"sentence-transformer": SentenceTransformerEmbeddingFunction,
|
|
||||||
"instructor": InstructorEmbeddingFunction,
|
|
||||||
"google-palm": GooglePalmEmbeddingFunction,
|
|
||||||
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
|
|
||||||
"google-vertex": GoogleVertexEmbeddingFunction,
|
|
||||||
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
|
|
||||||
"jina": JinaEmbeddingFunction,
|
|
||||||
"roboflow": RoboflowEmbeddingFunction,
|
|
||||||
"openclip": OpenCLIPEmbeddingFunction,
|
|
||||||
"text2vec": Text2VecEmbeddingFunction,
|
|
||||||
"onnx": ONNXMiniLM_L6_V2,
|
|
||||||
}
|
|
||||||
|
|
||||||
if provider not in embedding_functions:
|
if provider not in EMBEDDING_PROVIDERS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported provider: {provider}. "
|
f"Unsupported provider: {provider}. "
|
||||||
f"Available providers: {list(embedding_functions.keys())}"
|
f"Available providers: {list(EMBEDDING_PROVIDERS.keys())}"
|
||||||
)
|
)
|
||||||
return embedding_functions[provider](**config_dict)
|
|
||||||
|
_inject_api_key_from_env(provider, config_dict)
|
||||||
|
|
||||||
|
return EMBEDDING_PROVIDERS[provider](**config_dict)
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from crewai.rag.embeddings.factory import EmbedderConfig
|
||||||
|
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||||
|
|
||||||
|
|
||||||
class BaseRAGStorage(ABC):
|
class BaseRAGStorage(ABC):
|
||||||
"""
|
"""
|
||||||
@@ -13,7 +16,7 @@ class BaseRAGStorage(ABC):
|
|||||||
self,
|
self,
|
||||||
type: str,
|
type: str,
|
||||||
allow_reset: bool = True,
|
allow_reset: bool = True,
|
||||||
embedder_config: dict[str, Any] | None = None,
|
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||||
crew: Any = None,
|
crew: Any = None,
|
||||||
):
|
):
|
||||||
self.type = type
|
self.type = type
|
||||||
|
|||||||
598
tests/rag/embeddings/test_embedding_factory.py
Normal file
598
tests/rag/embeddings/test_embedding_factory.py
Normal file
@@ -0,0 +1,598 @@
|
|||||||
|
"""Enhanced tests for embedding function factory."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
|
||||||
|
get_embedding_function,
|
||||||
|
)
|
||||||
|
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_default() -> None:
    """With no config, the factory builds the default OpenAI embedder."""
    expected = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.OpenAIEmbeddingFunction",
        return_value=expected,
    ) as openai_cls:
        # The API key is looked up from the environment; pin it for the call.
        with patch(
            "crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
        ):
            built = get_embedding_function()

    openai_cls.assert_called_once_with(
        api_key="test-api-key", model_name="text-embedding-3-small"
    )
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_with_embedding_options() -> None:
    """An EmbeddingOptions model is unpacked into provider keyword arguments."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        # The patched registry accepts any provider and returns our fake.
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory

        built = get_embedding_function(
            EmbeddingOptions(
                provider="openai",
                api_key=SecretStr("test-key"),
                model_name="text-embedding-3-large",
            )
        )

    passed = factory.call_args.kwargs
    assert "api_key" in passed
    assert passed["api_key"].get_secret_value() == "test-key"
    assert "model_name" in passed
    assert passed["model_name"] == "text-embedding-3-large"
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_sentence_transformer() -> None:
    """Nested sentence-transformer config is forwarded to the provider."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    embedder_config = {
        "provider": "sentence-transformer",
        "config": {"model_name": "all-MiniLM-L6-v2"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(model_name="all-MiniLM-L6-v2")
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_ollama() -> None:
    """Ollama options (model and server URL) reach the provider unchanged."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    embedder_config = {
        "provider": "ollama",
        "config": {
            "model_name": "nomic-embed-text",
            "url": "http://localhost:11434",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(
        model_name="nomic-embed-text", url="http://localhost:11434"
    )
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_cohere() -> None:
    """Cohere API key and model name are passed through to the provider."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    embedder_config = {
        "provider": "cohere",
        "config": {"api_key": "cohere-key", "model_name": "embed-english-v3.0"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(
        api_key="cohere-key", model_name="embed-english-v3.0"
    )
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_huggingface() -> None:
    """HuggingFace token and model name are passed through to the provider."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    embedder_config = {
        "provider": "huggingface",
        "config": {
            "api_key": "hf-token",
            "model_name": "sentence-transformers/all-MiniLM-L6-v2",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(
        api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_onnx() -> None:
    """A provider-only config (no nested options) still builds the ONNX embedder."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function({"provider": "onnx"})

    factory.assert_called_once()
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_google_palm() -> None:
    """Google PaLM receives the explicitly configured API key."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    embedder_config = {"provider": "google-palm", "config": {"api_key": "palm-key"}}

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(api_key="palm-key")
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_amazon_bedrock() -> None:
    """Bedrock uses a caller-supplied session and forwards the other options."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    # Supplying the session up front keeps the factory from importing boto3.
    fake_session = MagicMock()
    embedder_config = {
        "provider": "amazon-bedrock",
        "config": {
            "session": fake_session,
            "region_name": "us-west-2",
            "model_name": "amazon.titan-embed-text-v1",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(
        session=fake_session,
        region_name="us-west-2",
        model_name="amazon.titan-embed-text-v1",
    )
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_jina() -> None:
    """Jina API key and model name are passed through to the provider."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    embedder_config = {
        "provider": "jina",
        "config": {
            "api_key": "jina-key",
            "model_name": "jina-embeddings-v2-base-en",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(
        api_key="jina-key", model_name="jina-embeddings-v2-base-en"
    )
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_unsupported_provider() -> None:
    """An unknown provider name is rejected with a descriptive ValueError."""
    unknown = {"provider": "unsupported-provider"}

    # No registry patching here: the real provider table must reject the name.
    with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
        get_embedding_function(unknown)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_config_modification() -> None:
    """Test that the caller's config, including its nested dict, is not modified.

    The config deliberately uses the ``model`` key so that the factory's
    ``model`` -> ``model_name`` rename runs; that rename must happen on a copy.
    """
    import copy

    config = {
        "provider": "openai",
        "config": {"api_key": "test-key", "model": "text-embedding-3-small"},
    }
    # A deep copy is required here: a shallow ``dict.copy()`` would share the
    # nested "config" dict with the argument, so an in-place mutation would
    # show up on both sides and the equality check below could never fail.
    snapshot = copy.deepcopy(config)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers:
        mock_instance = MagicMock()
        mock_openai = MagicMock(return_value=mock_instance)
        mock_providers.__getitem__.return_value = mock_openai
        mock_providers.__contains__.return_value = True

        get_embedding_function(config)

    assert config == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_exclude_none_values() -> None:
    """Fields left as None on EmbeddingOptions are not forwarded as kwargs."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory

        built = get_embedding_function(
            EmbeddingOptions(
                provider="openai", api_key=SecretStr("test-key"), model_name=None
            )
        )

    passed = factory.call_args.kwargs
    assert "api_key" in passed
    assert passed["api_key"].get_secret_value() == "test-key"
    # model_name was None, so it must be dropped rather than passed as None.
    assert "model_name" not in passed
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_instructor() -> None:
    """Instructor model name is passed through to the provider."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    embedder_config = {
        "provider": "instructor",
        "config": {"model_name": "hkunlp/instructor-large"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(model_name="hkunlp/instructor-large")
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_google_generativeai() -> None:
    """Google Generative AI key and model name are passed to the provider."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    embedder_config = {
        "provider": "google-generativeai",
        "config": {"api_key": "google-key", "model_name": "models/embedding-001"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(
        api_key="google-key", model_name="models/embedding-001"
    )
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_google_vertex() -> None:
    """Vertex AI key, project and region are passed through to the provider."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    embedder_config = {
        "provider": "google-vertex",
        "config": {
            "api_key": "vertex-key",
            "project_id": "my-project",
            "region": "us-central1",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(
        api_key="vertex-key", project_id="my-project", region="us-central1"
    )
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_roboflow() -> None:
    """Roboflow API key and endpoint URL are passed through to the provider."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    embedder_config = {
        "provider": "roboflow",
        "config": {
            "api_key": "roboflow-key",
            "api_url": "https://infer.roboflow.com",
        },
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(
        api_key="roboflow-key", api_url="https://infer.roboflow.com"
    )
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_openclip() -> None:
    """OpenCLIP model name and checkpoint are passed through to the provider."""
    expected = MagicMock()
    factory = MagicMock(return_value=expected)
    embedder_config = {
        "provider": "openclip",
        "config": {"model_name": "ViT-B-32", "checkpoint": "laion2b_s34b_b79k"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as registry:
        registry.__contains__.return_value = True
        registry.__getitem__.return_value = factory
        built = get_embedding_function(embedder_config)

    factory.assert_called_once_with(
        model_name="ViT-B-32", checkpoint="laion2b_s34b_b79k"
    )
    assert built == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_text2vec() -> None:
    """Text2Vec provider: the configured model name is passed through unchanged."""
    embedder_config = {
        "provider": "text2vec",
        "config": {"model_name": "shibing624/text2vec-base-chinese"},
    }
    expected = MagicMock()
    text2vec_factory = MagicMock(return_value=expected)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True
        providers.__getitem__.return_value = text2vec_factory

        embedding_fn = get_embedding_function(embedder_config)

        text2vec_factory.assert_called_once_with(
            model_name="shibing624/text2vec-base-chinese"
        )
        assert embedding_fn == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_to_model_name_conversion() -> None:
    """A nested 'model' entry must arrive at the provider as 'model_name'."""
    embedder_config = {
        "provider": "openai",
        "config": {"api_key": "test-key", "model": "text-embedding-3-small"},
    }
    expected = MagicMock()
    openai_factory = MagicMock(return_value=expected)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True
        providers.__getitem__.return_value = openai_factory

        embedding_fn = get_embedding_function(embedder_config)

        # The config said "model"; the provider must be called with "model_name".
        openai_factory.assert_called_once_with(
            api_key="test-key", model_name="text-embedding-3-small"
        )
        assert embedding_fn == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_key_injection_from_env_openai() -> None:
    """OpenAI without an api_key in config: the key is read from OPENAI_API_KEY."""
    embedder_config = {
        "provider": "openai",
        "config": {"model": "text-embedding-3-small"},
    }
    expected = MagicMock()
    openai_factory = MagicMock(return_value=expected)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True
        providers.__getitem__.return_value = openai_factory

        with patch("crewai.rag.embeddings.factory.os.getenv") as env_lookup:
            env_lookup.return_value = "env-openai-key"

            embedding_fn = get_embedding_function(embedder_config)

            # The most recent environment lookup must be for the OpenAI key.
            env_lookup.assert_called_with("OPENAI_API_KEY")
            openai_factory.assert_called_once_with(
                api_key="env-openai-key", model_name="text-embedding-3-small"
            )
            assert embedding_fn == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_key_injection_from_env_cohere() -> None:
    """Cohere without an api_key in config: the key is read from COHERE_API_KEY."""
    embedder_config = {
        "provider": "cohere",
        "config": {"model_name": "embed-english-v3.0"},
    }
    expected = MagicMock()
    cohere_factory = MagicMock(return_value=expected)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True
        providers.__getitem__.return_value = cohere_factory

        with patch("crewai.rag.embeddings.factory.os.getenv") as env_lookup:
            env_lookup.return_value = "env-cohere-key"

            embedding_fn = get_embedding_function(embedder_config)

            # The most recent environment lookup must be for the Cohere key.
            env_lookup.assert_called_with("COHERE_API_KEY")
            cohere_factory.assert_called_once_with(
                api_key="env-cohere-key", model_name="embed-english-v3.0"
            )
            assert embedding_fn == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_key_not_injected_when_provided() -> None:
    """An api_key given in config wins over whatever the environment returns."""
    embedder_config = {
        "provider": "openai",
        "config": {"api_key": "config-key", "model": "text-embedding-3-small"},
    }
    expected = MagicMock()
    openai_factory = MagicMock(return_value=expected)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True
        providers.__getitem__.return_value = openai_factory

        with patch("crewai.rag.embeddings.factory.os.getenv") as env_lookup:
            # The environment offers a different key; it must not be used.
            env_lookup.return_value = "env-key"

            embedding_fn = get_embedding_function(embedder_config)

            openai_factory.assert_called_once_with(
                api_key="config-key", model_name="text-embedding-3-small"
            )
            assert embedding_fn == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_amazon_bedrock_session_injection() -> None:
    """amazon-bedrock without a session: a boto3 Session is created and passed on."""
    embedder_config = {
        "provider": "amazon-bedrock",
        "config": {"model_name": "amazon.titan-embed-text-v1"},
    }
    expected = MagicMock()
    bedrock_factory = MagicMock(return_value=expected)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True
        providers.__getitem__.return_value = bedrock_factory

        fake_boto3 = MagicMock()
        # Register a fake boto3 module so an `import boto3` executed during
        # the call resolves to the mock.
        with patch.dict("sys.modules", {"boto3": fake_boto3}):
            aws_session = MagicMock()
            fake_boto3.Session.return_value = aws_session

            embedding_fn = get_embedding_function(embedder_config)

            fake_boto3.Session.assert_called_once()
            bedrock_factory.assert_called_once_with(
                session=aws_session, model_name="amazon.titan-embed-text-v1"
            )
            assert embedding_fn == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_amazon_bedrock_session_not_injected_when_provided() -> None:
    """amazon-bedrock with an explicit session: that exact session is forwarded."""
    caller_session = MagicMock()
    embedder_config = {
        "provider": "amazon-bedrock",
        "config": {
            "session": caller_session,
            "model_name": "amazon.titan-embed-text-v1",
        },
    }
    expected = MagicMock()
    bedrock_factory = MagicMock(return_value=expected)

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True
        providers.__getitem__.return_value = bedrock_factory

        embedding_fn = get_embedding_function(embedder_config)

        bedrock_factory.assert_called_once_with(
            session=caller_session, model_name="amazon.titan-embed-text-v1"
        )
        assert embedding_fn == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_amazon_bedrock_boto3_import_error() -> None:
    """amazon-bedrock with boto3 unavailable: an ImportError naming boto3 is raised."""
    embedder_config = {
        "provider": "amazon-bedrock",
        "config": {"model_name": "amazon.titan-embed-text-v1"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True

        # A None entry in sys.modules makes `import boto3` fail with ImportError.
        with patch.dict("sys.modules", {"boto3": None}):
            with pytest.raises(
                ImportError, match="boto3 is required for amazon-bedrock"
            ):
                get_embedding_function(embedder_config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_amazon_bedrock_session_creation_error() -> None:
    """A failing boto3.Session() call must surface as a ValueError."""
    embedder_config = {
        "provider": "amazon-bedrock",
        "config": {"model_name": "amazon.titan-embed-text-v1"},
    }

    with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as providers:
        providers.__contains__.return_value = True

        fake_boto3 = MagicMock()
        with patch.dict("sys.modules", {"boto3": fake_boto3}):
            # Creating the session blows up with a generic exception.
            fake_boto3.Session.side_effect = Exception(
                "AWS credentials not configured"
            )

            with pytest.raises(ValueError, match="Failed to create AWS session"):
                get_embedding_function(embedder_config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_config_format() -> None:
    """Settings placed beside 'provider' (not under 'config') are rejected."""
    # Flat layout: api_key and model sit at the top level.
    flat_config = {
        "provider": "openai",
        "api_key": "test-key",
        "model": "text-embedding-3-small",
    }

    with pytest.raises(ValueError, match="Invalid embedder configuration format"):
        get_embedding_function(flat_config)
|
||||||
79
tests/rag/embeddings/test_factory_azure.py
Normal file
79
tests/rag/embeddings/test_factory_azure.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Test Azure embedder configuration with factory."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureEmbedderFactory:
    """Checks for get_embedding_function with Azure-style and plain OpenAI configs."""

    @patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS")
    def test_azure_with_nested_config(self, mock_providers):
        """Azure settings nested under 'config' are all forwarded to the provider."""
        expected = MagicMock()
        openai_factory = MagicMock(return_value=expected)
        mock_providers.__contains__.return_value = True
        mock_providers.__getitem__.return_value = openai_factory

        azure_settings = {
            "api_key": "test-azure-key",
            "api_base": "https://test.openai.azure.com/",
            "api_type": "azure",
            "api_version": "2023-05-15",
            "model": "text-embedding-3-small",
            "deployment_id": "test-deployment",
        }
        embedder_config = EmbedderConfig(provider="openai", config=azure_settings)

        embedding_fn = get_embedding_function(embedder_config)

        # "model" from the config is expected to arrive as "model_name".
        openai_factory.assert_called_once_with(
            api_key="test-azure-key",
            api_base="https://test.openai.azure.com/",
            api_type="azure",
            api_version="2023-05-15",
            model_name="text-embedding-3-small",
            deployment_id="test-deployment",
        )
        assert embedding_fn == expected

    @patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS")
    def test_regular_openai_with_nested_config(self, mock_providers):
        """A plain OpenAI nested config yields api_key plus model_name."""
        expected = MagicMock()
        openai_factory = MagicMock(return_value=expected)
        mock_providers.__contains__.return_value = True
        mock_providers.__getitem__.return_value = openai_factory

        embedder_config = EmbedderConfig(
            provider="openai",
            config={"api_key": "test-openai-key", "model": "text-embedding-3-large"},
        )

        embedding_fn = get_embedding_function(embedder_config)

        openai_factory.assert_called_once_with(
            api_key="test-openai-key", model_name="text-embedding-3-large"
        )
        assert embedding_fn == expected

    def test_flat_format_raises_error(self):
        """A flat (non-nested) config raises ValueError with an explanatory message."""
        flat_config = {
            "provider": "openai",
            "api_key": "test-key",
            "model_name": "text-embedding-3-small",
        }

        with pytest.raises(ValueError) as exc_info:
            get_embedding_function(flat_config)

        # Both fragments must appear in the error text.
        message = str(exc_info.value)
        assert "Invalid embedder configuration format" in message
        assert "nested under a 'config' key" in message
|
||||||
@@ -1,250 +0,0 @@
|
|||||||
"""Enhanced tests for embedding function factory."""
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
|
|
||||||
get_embedding_function,
|
|
||||||
)
|
|
||||||
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_default() -> None:
    """No config: OpenAI is built with the env-supplied key and text-embedding-3-small."""
    expected = MagicMock()

    with patch(
        "crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"
    ) as openai_cls, patch(
        "crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
    ):
        openai_cls.return_value = expected

        embedding_fn = get_embedding_function()

    openai_cls.assert_called_once_with(
        api_key="test-api-key", model_name="text-embedding-3-small"
    )
    assert embedding_fn == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_with_embedding_options() -> None:
    """An EmbeddingOptions object is accepted and its API key reaches OpenAI."""
    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as openai_cls:
        expected = MagicMock()
        openai_cls.return_value = expected

        embedding_options = EmbeddingOptions(
            provider="openai", api_key="test-key", model="text-embedding-3-large"
        )

        embedding_fn = get_embedding_function(embedding_options)

        passed_kwargs = openai_cls.call_args.kwargs
        assert "api_key" in passed_kwargs
        # The key arrives wrapped; unwrap it before comparing.
        assert passed_kwargs["api_key"].get_secret_value() == "test-key"
        # OpenAI uses model_name parameter, not model
        assert embedding_fn == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_sentence_transformer() -> None:
    """sentence-transformer provider: the model name is passed to the function."""
    provider_config = {
        "provider": "sentence-transformer",
        "model_name": "all-MiniLM-L6-v2",
    }
    target = "crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction"

    with patch(target) as st_cls:
        expected = MagicMock()
        st_cls.return_value = expected

        embedding_fn = get_embedding_function(provider_config)

        st_cls.assert_called_once_with(model_name="all-MiniLM-L6-v2")
        assert embedding_fn == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_ollama() -> None:
    """Ollama provider: model name and server URL are forwarded."""
    provider_config = {
        "provider": "ollama",
        "model_name": "nomic-embed-text",
        "url": "http://localhost:11434",
    }

    with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as ollama_cls:
        expected = MagicMock()
        ollama_cls.return_value = expected

        embedding_fn = get_embedding_function(provider_config)

        ollama_cls.assert_called_once_with(
            model_name="nomic-embed-text", url="http://localhost:11434"
        )
        assert embedding_fn == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_cohere() -> None:
    """Cohere provider: API key and model name are forwarded."""
    provider_config = {
        "provider": "cohere",
        "api_key": "cohere-key",
        "model_name": "embed-english-v3.0",
    }

    with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as cohere_cls:
        expected = MagicMock()
        cohere_cls.return_value = expected

        embedding_fn = get_embedding_function(provider_config)

        cohere_cls.assert_called_once_with(
            api_key="cohere-key", model_name="embed-english-v3.0"
        )
        assert embedding_fn == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_huggingface() -> None:
    """HuggingFace provider: API token and model name are forwarded."""
    provider_config = {
        "provider": "huggingface",
        "api_key": "hf-token",
        "model_name": "sentence-transformers/all-MiniLM-L6-v2",
    }

    with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as hf_cls:
        expected = MagicMock()
        hf_cls.return_value = expected

        embedding_fn = get_embedding_function(provider_config)

        hf_cls.assert_called_once_with(
            api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        assert embedding_fn == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_onnx() -> None:
    """ONNX provider: the ONNX embedding class is instantiated exactly once."""
    provider_config = {"provider": "onnx"}

    with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as onnx_cls:
        expected = MagicMock()
        onnx_cls.return_value = expected

        embedding_fn = get_embedding_function(provider_config)

        # Only the call count is checked here, not the arguments.
        onnx_cls.assert_called_once()
        assert embedding_fn == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_google_palm() -> None:
    """Google PaLM provider: the API key is forwarded."""
    provider_config = {"provider": "google-palm", "api_key": "palm-key"}
    target = "crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction"

    with patch(target) as palm_cls:
        expected = MagicMock()
        palm_cls.return_value = expected

        embedding_fn = get_embedding_function(provider_config)

        palm_cls.assert_called_once_with(api_key="palm-key")
        assert embedding_fn == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_amazon_bedrock() -> None:
    """Amazon Bedrock provider: region and model name are forwarded."""
    provider_config = {
        "provider": "amazon-bedrock",
        "region_name": "us-west-2",
        "model_name": "amazon.titan-embed-text-v1",
    }
    target = "crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction"

    with patch(target) as bedrock_cls:
        expected = MagicMock()
        bedrock_cls.return_value = expected

        embedding_fn = get_embedding_function(provider_config)

        bedrock_cls.assert_called_once_with(
            region_name="us-west-2", model_name="amazon.titan-embed-text-v1"
        )
        assert embedding_fn == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_jina() -> None:
    """Jina provider: API key and model name are forwarded."""
    provider_config = {
        "provider": "jina",
        "api_key": "jina-key",
        "model_name": "jina-embeddings-v2-base-en",
    }

    with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as jina_cls:
        expected = MagicMock()
        jina_cls.return_value = expected

        embedding_fn = get_embedding_function(provider_config)

        jina_cls.assert_called_once_with(
            api_key="jina-key", model_name="jina-embeddings-v2-base-en"
        )
        assert embedding_fn == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_unsupported_provider() -> None:
    """An unknown provider name raises ValueError naming that provider."""
    unknown_provider_config = {"provider": "unsupported-provider"}

    with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
        get_embedding_function(unknown_provider_config)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_config_modification() -> None:
    """The dict handed to get_embedding_function must come back unchanged."""
    original_config = {
        "provider": "openai",
        "api_key": "test-key",
        "model": "text-embedding-3-small",
    }
    # Pass a shallow copy so the pristine dict stays available for comparison.
    config_copy = dict(original_config)

    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
        get_embedding_function(config_copy)

    assert config_copy == original_config
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_exclude_none_values() -> None:
    """Options set to None must not appear in the provider call's kwargs."""
    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as openai_cls:
        expected = MagicMock()
        openai_cls.return_value = expected

        embedding_options = EmbeddingOptions(
            provider="openai", api_key="test-key", model=None
        )

        embedding_fn = get_embedding_function(embedding_options)

        passed_kwargs = openai_cls.call_args.kwargs
        assert "api_key" in passed_kwargs
        assert passed_kwargs["api_key"].get_secret_value() == "test-key"
        # model was None, so no "model" keyword may have been passed.
        assert "model" not in passed_kwargs
        assert embedding_fn == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_function_instructor() -> None:
    """Instructor provider: the model name is forwarded."""
    provider_config = {
        "provider": "instructor",
        "model_name": "hkunlp/instructor-large",
    }
    target = "crewai.rag.embeddings.factory.InstructorEmbeddingFunction"

    with patch(target) as instructor_cls:
        expected = MagicMock()
        instructor_cls.return_value = expected

        embedding_fn = get_embedding_function(provider_config)

        instructor_cls.assert_called_once_with(model_name="hkunlp/instructor-large")
        assert embedding_fn == expected
|
|
||||||
82
tests/utilities/test_azure_embedder_config.py
Normal file
82
tests/utilities/test_azure_embedder_config.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""Test Azure embedder configuration with nested format only."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureEmbedderConfiguration:
    """Checks for EmbeddingConfigurator.configure_embedder with nested Azure configs."""

    @patch(
        "chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction"
    )
    def test_azure_openai_with_nested_config(self, mock_openai_func):
        """Provider 'openai' with Azure settings under 'config' builds the OpenAI function."""
        expected = MagicMock()
        mock_openai_func.return_value = expected

        azure_settings = {
            "api_key": "test-azure-key",
            "api_base": "https://test.openai.azure.com/",
            "api_type": "azure",
            "api_version": "2023-05-15",
            "model": "text-embedding-3-small",
            "deployment_id": "test-deployment",
        }
        embedder_config = {"provider": "openai", "config": azure_settings}

        embedding_fn = EmbeddingConfigurator().configure_embedder(embedder_config)

        # Settings absent from the config are expected to be passed as None.
        mock_openai_func.assert_called_once_with(
            api_key="test-azure-key",
            model_name="text-embedding-3-small",
            api_base="https://test.openai.azure.com/",
            api_type="azure",
            api_version="2023-05-15",
            default_headers=None,
            dimensions=None,
            deployment_id="test-deployment",
            organization_id=None,
        )
        assert embedding_fn == expected

    @patch(
        "chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction"
    )
    def test_azure_provider_with_nested_config(self, mock_openai_func):
        """Provider 'azure' with a nested config also builds the OpenAI function."""
        expected = MagicMock()
        mock_openai_func.return_value = expected

        # No "api_type" is supplied here; the expected call still has api_type="azure".
        azure_settings = {
            "api_key": "test-azure-key",
            "api_base": "https://test.openai.azure.com/",
            "api_version": "2023-05-15",
            "model": "text-embedding-3-small",
            "deployment_id": "test-deployment",
        }
        embedder_config = {"provider": "azure", "config": azure_settings}

        embedding_fn = EmbeddingConfigurator().configure_embedder(embedder_config)

        mock_openai_func.assert_called_once_with(
            api_key="test-azure-key",
            api_base="https://test.openai.azure.com/",
            api_type="azure",
            api_version="2023-05-15",
            model_name="text-embedding-3-small",
            default_headers=None,
            dimensions=None,
            deployment_id="test-deployment",
            organization_id=None,
        )
        assert embedding_fn == expected
|
||||||
Reference in New Issue
Block a user