diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index f1ae919bc..4f6526c59 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -7,7 +7,8 @@ from crewai.rag.chromadb.config import ChromaDBConfig from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper from crewai.rag.config.utils import get_rag_client from crewai.rag.core.base_client import BaseClient -from crewai.rag.embeddings.factory import get_embedding_function +from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function +from crewai.rag.embeddings.types import EmbeddingOptions from crewai.rag.factory import create_client from crewai.rag.storage.base_rag_storage import BaseRAGStorage from crewai.rag.types import BaseRecord @@ -25,7 +26,7 @@ class RAGStorage(BaseRAGStorage): self, type: str, allow_reset: bool = True, - embedder_config: dict[str, Any] | None = None, + embedder_config: EmbeddingOptions | EmbedderConfig | None = None, crew: Any = None, path: str | None = None, ) -> None: @@ -50,6 +51,21 @@ class RAGStorage(BaseRAGStorage): if self.embedder_config: embedding_function = get_embedding_function(self.embedder_config) + + try: + _ = embedding_function(["test"]) + except Exception as e: + provider = ( + self.embedder_config.provider + if isinstance(self.embedder_config, EmbeddingOptions) + else self.embedder_config.get("provider", "unknown") + ) + raise ValueError( + f"Failed to initialize embedder. Please check your configuration or connection.\n" + f"Provider: {provider}\n" + f"Error: {e}" + ) from e + config = ChromaDBConfig( embedding_function=cast( ChromaEmbeddingFunctionWrapper, embedding_function diff --git a/src/crewai/rag/embeddings/factory.py b/src/crewai/rag/embeddings/factory.py index 0b76ef36a..3ced72655 100644 --- a/src/crewai/rag/embeddings/factory.py +++ b/src/crewai/rag/embeddings/factory.py @@ -1,6 +1,8 @@ """Minimal embedding function factory for CrewAI.""" import os +from collections.abc import Callable, MutableMapping +from typing import Any, Final, Literal, TypedDict from chromadb import EmbeddingFunction from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( @@ -42,19 +44,116 @@ from chromadb.utils.embedding_functions.sentence_transformer_embedding_function from chromadb.utils.embedding_functions.text2vec_embedding_function import ( Text2VecEmbeddingFunction, ) +from typing_extensions import NotRequired from crewai.rag.embeddings.types import EmbeddingOptions +AllowedEmbeddingProviders = Literal[ + "openai", + "cohere", + "ollama", + "huggingface", + "sentence-transformer", + "instructor", + "google-palm", + "google-generativeai", + "google-vertex", + "amazon-bedrock", + "jina", + "roboflow", + "openclip", + "text2vec", + "onnx", +] + + +class EmbedderConfig(TypedDict): + """Configuration for embedding functions with nested format.""" + + provider: AllowedEmbeddingProviders + config: NotRequired[dict[str, Any]] + + +EMBEDDING_PROVIDERS: Final[ + dict[AllowedEmbeddingProviders, Callable[..., EmbeddingFunction]] +] = { + "openai": OpenAIEmbeddingFunction, + "cohere": CohereEmbeddingFunction, + "ollama": OllamaEmbeddingFunction, + "huggingface": HuggingFaceEmbeddingFunction, + "sentence-transformer": SentenceTransformerEmbeddingFunction, + "instructor": InstructorEmbeddingFunction, + "google-palm": GooglePalmEmbeddingFunction, + "google-generativeai": GoogleGenerativeAiEmbeddingFunction, + "google-vertex": GoogleVertexEmbeddingFunction, + "amazon-bedrock": AmazonBedrockEmbeddingFunction, + "jina": JinaEmbeddingFunction, + "roboflow": RoboflowEmbeddingFunction, + "openclip": OpenCLIPEmbeddingFunction, + "text2vec": Text2VecEmbeddingFunction, + "onnx": ONNXMiniLM_L6_V2, +} + +PROVIDER_ENV_MAPPING: Final[dict[AllowedEmbeddingProviders, tuple[str, str]]] = { + "openai": ("OPENAI_API_KEY", "api_key"), + "cohere": ("COHERE_API_KEY", "api_key"), + "huggingface": ("HUGGINGFACE_API_KEY", "api_key"), + "google-palm": ("GOOGLE_API_KEY", "api_key"), + "google-generativeai": ("GOOGLE_API_KEY", "api_key"), + "google-vertex": ("GOOGLE_API_KEY", "api_key"), + "jina": ("JINA_API_KEY", "api_key"), + "roboflow": ("ROBOFLOW_API_KEY", "api_key"), +} + + +def _inject_api_key_from_env( + provider: AllowedEmbeddingProviders, config_dict: MutableMapping[str, Any] +) -> None: + """Inject API key or other required configuration from environment if not explicitly provided. + + Args: + provider: The embedding provider name + config_dict: The configuration dictionary to modify in-place + + Raises: + ImportError: If required libraries for certain providers are not installed + ValueError: If AWS session creation fails for amazon-bedrock + """ + if provider in PROVIDER_ENV_MAPPING: + env_var_name, config_key = PROVIDER_ENV_MAPPING[provider] + if config_key not in config_dict: + env_value = os.getenv(env_var_name) + if env_value: + config_dict[config_key] = env_value + + if provider == "amazon-bedrock": + if "session" not in config_dict: + try: + import boto3 # type: ignore[import] + + config_dict["session"] = boto3.Session() + except ImportError as e: + raise ImportError( + "boto3 is required for amazon-bedrock embeddings. " + "Install it with: uv add boto3" + ) from e + except Exception as e: + raise ValueError( + f"Failed to create AWS session for amazon-bedrock. " + f"Ensure AWS credentials are configured. Error: {e}" + ) from e + def get_embedding_function( - config: EmbeddingOptions | dict | None = None, + config: EmbeddingOptions | EmbedderConfig | None = None, ) -> EmbeddingFunction: """Get embedding function - delegates to ChromaDB. Args: - config: Optional configuration - either an EmbeddingOptions object or a dict with: - - provider: The embedding provider to use (default: "openai") - - Any other provider-specific parameters + config: Optional configuration - either: + - EmbeddingOptions: Pydantic model with flat configuration + - EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict} + - None: Uses default OpenAI configuration Returns: EmbeddingFunction instance ready for use with ChromaDB @@ -81,31 +180,33 @@ def get_embedding_function( >>> embedder = get_embedding_function() # Use Cohere with dict - >>> embedder = get_embedding_function({ + >>> embedder = get_embedding_function(EmbedderConfig(**{ ... "provider": "cohere", - ... "api_key": "your-key", - ... "model_name": "embed-english-v3.0" - ... }) + ... "config": { + ... "api_key": "your-key", + ... "model_name": "embed-english-v3.0" + ... } + ... })) # Use with EmbeddingOptions >>> embedder = get_embedding_function( ... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2") ... ) - # Use local sentence transformers (no API key needed) - >>> embedder = get_embedding_function({ - ... "provider": "sentence-transformer", - ... "model_name": "all-MiniLM-L6-v2" + # Use Azure OpenAI + >>> embedder = get_embedding_function(EmbedderConfig(**{ + ... "provider": "openai", + ... "config": { + ... "api_key": "your-azure-key", + ... "api_base": "https://your-resource.openai.azure.com/", + ... "api_type": "azure", + ... "api_version": "2023-05-15", + ... "model": "text-embedding-3-small", + ... "deployment_id": "your-deployment-name" + ... } ... }) - # Use Ollama for local embeddings - >>> embedder = get_embedding_function({ - ... "provider": "ollama", - ... "model_name": "nomic-embed-text" - ... }) - - # Use ONNX (no API key needed) - >>> embedder = get_embedding_function({ + >>> embedder = get_embedding_function(EmbedderConfig(**{ ... "provider": "onnx" ... }) """ @@ -114,35 +215,33 @@ def get_embedding_function( api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" ) - # Handle EmbeddingOptions object + provider: AllowedEmbeddingProviders + config_dict: dict[str, Any] + if isinstance(config, EmbeddingOptions): config_dict = config.model_dump(exclude_none=True) + provider = config_dict["provider"] else: - config_dict = config.copy() + provider = config["provider"] + nested: dict[str, Any] = config.get("config", {}) - provider = config_dict.pop("provider", "openai") + if not nested and len(config) > 1: + raise ValueError( + "Invalid embedder configuration format. " + "Configuration must be nested under a 'config' key. " + "Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}" + ) - embedding_functions = { - "openai": OpenAIEmbeddingFunction, - "cohere": CohereEmbeddingFunction, - "ollama": OllamaEmbeddingFunction, - "huggingface": HuggingFaceEmbeddingFunction, - "sentence-transformer": SentenceTransformerEmbeddingFunction, - "instructor": InstructorEmbeddingFunction, - "google-palm": GooglePalmEmbeddingFunction, - "google-generativeai": GoogleGenerativeAiEmbeddingFunction, - "google-vertex": GoogleVertexEmbeddingFunction, - "amazon-bedrock": AmazonBedrockEmbeddingFunction, - "jina": JinaEmbeddingFunction, - "roboflow": RoboflowEmbeddingFunction, - "openclip": OpenCLIPEmbeddingFunction, - "text2vec": Text2VecEmbeddingFunction, - "onnx": ONNXMiniLM_L6_V2, - } + config_dict = dict(nested) + if "model" in config_dict and "model_name" not in config_dict: + config_dict["model_name"] = config_dict.pop("model") - if provider not in embedding_functions: + if provider not in EMBEDDING_PROVIDERS: raise ValueError( f"Unsupported provider: {provider}. " - f"Available providers: {list(embedding_functions.keys())}" + f"Available providers: {list(EMBEDDING_PROVIDERS.keys())}" ) - return embedding_functions[provider](**config_dict) + + _inject_api_key_from_env(provider, config_dict) + + return EMBEDDING_PROVIDERS[provider](**config_dict) diff --git a/src/crewai/rag/storage/base_rag_storage.py b/src/crewai/rag/storage/base_rag_storage.py index 772ed4266..59189820c 100644 --- a/src/crewai/rag/storage/base_rag_storage.py +++ b/src/crewai/rag/storage/base_rag_storage.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod from typing import Any +from crewai.rag.embeddings.factory import EmbedderConfig +from crewai.rag.embeddings.types import EmbeddingOptions + class BaseRAGStorage(ABC): """ @@ -13,7 +16,7 @@ class BaseRAGStorage(ABC): self, type: str, allow_reset: bool = True, - embedder_config: dict[str, Any] | None = None, + embedder_config: EmbeddingOptions | EmbedderConfig | None = None, crew: Any = None, ): self.type = type diff --git a/tests/rag/embeddings/test_embedding_factory.py b/tests/rag/embeddings/test_embedding_factory.py new file mode 100644 index 000000000..937e5c1e2 --- /dev/null +++ b/tests/rag/embeddings/test_embedding_factory.py @@ -0,0 +1,598 @@ +"""Enhanced tests for embedding function factory.""" + +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import SecretStr + +from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped] + get_embedding_function, +) +from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped] + + +def test_get_embedding_function_default() -> None: + """Test default embedding function when no config provided.""" + with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai: + mock_instance = MagicMock() + mock_openai.return_value = mock_instance + + with patch( + "crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key" + ): + result = get_embedding_function() + + mock_openai.assert_called_once_with( + api_key="test-api-key", model_name="text-embedding-3-small" + ) + assert result == mock_instance + + +def test_get_embedding_function_with_embedding_options() -> None: + """Test embedding function creation with EmbeddingOptions object.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_openai = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_openai + mock_providers.__contains__.return_value = True + + options = EmbeddingOptions( + provider="openai", + api_key=SecretStr("test-key"), + model_name="text-embedding-3-large", + ) + + result = get_embedding_function(options) + + call_kwargs = mock_openai.call_args.kwargs + assert "api_key" in call_kwargs + assert call_kwargs["api_key"].get_secret_value() == "test-key" + assert "model_name" in call_kwargs + assert call_kwargs["model_name"] == "text-embedding-3-large" + assert result == mock_instance + + +def test_get_embedding_function_sentence_transformer() -> None: + """Test sentence transformer embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_st = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_st + mock_providers.__contains__.return_value = True + + config = { + "provider": "sentence-transformer", + "config": {"model_name": "all-MiniLM-L6-v2"}, + } + + result = get_embedding_function(config) + + mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2") + assert result == mock_instance + + +def test_get_embedding_function_ollama() -> None: + """Test Ollama embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_ollama = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_ollama + mock_providers.__contains__.return_value = True + + config = { + "provider": "ollama", + "config": { + "model_name": "nomic-embed-text", + "url": "http://localhost:11434", + }, + } + + result = get_embedding_function(config) + + mock_ollama.assert_called_once_with( + model_name="nomic-embed-text", url="http://localhost:11434" + ) + assert result == mock_instance + + +def test_get_embedding_function_cohere() -> None: + """Test Cohere embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_cohere = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_cohere + mock_providers.__contains__.return_value = True + + config = { + "provider": "cohere", + "config": {"api_key": "cohere-key", "model_name": "embed-english-v3.0"}, + } + + result = get_embedding_function(config) + + mock_cohere.assert_called_once_with( + api_key="cohere-key", model_name="embed-english-v3.0" + ) + assert result == mock_instance + + +def test_get_embedding_function_huggingface() -> None: + """Test HuggingFace embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_hf = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_hf + mock_providers.__contains__.return_value = True + + config = { + "provider": "huggingface", + "config": { + "api_key": "hf-token", + "model_name": "sentence-transformers/all-MiniLM-L6-v2", + }, + } + + result = get_embedding_function(config) + + mock_hf.assert_called_once_with( + api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2" + ) + assert result == mock_instance + + +def test_get_embedding_function_onnx() -> None: + """Test ONNX embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_onnx = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_onnx + mock_providers.__contains__.return_value = True + + config = {"provider": "onnx"} + + result = get_embedding_function(config) + + mock_onnx.assert_called_once() + assert result == mock_instance + + +def test_get_embedding_function_google_palm() -> None: + """Test Google PaLM embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_palm = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_palm + mock_providers.__contains__.return_value = True + + config = {"provider": "google-palm", "config": {"api_key": "palm-key"}} + + result = get_embedding_function(config) + + mock_palm.assert_called_once_with(api_key="palm-key") + assert result == mock_instance + + +def test_get_embedding_function_amazon_bedrock() -> None: + """Test Amazon Bedrock embedding function with explicit session.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_bedrock = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_bedrock + mock_providers.__contains__.return_value = True + + # Provide an explicit session to avoid boto3 import + mock_session = MagicMock() + config = { + "provider": "amazon-bedrock", + "config": { + "session": mock_session, + "region_name": "us-west-2", + "model_name": "amazon.titan-embed-text-v1", + }, + } + + result = get_embedding_function(config) + + mock_bedrock.assert_called_once_with( + session=mock_session, + region_name="us-west-2", + model_name="amazon.titan-embed-text-v1", + ) + assert result == mock_instance + + +def test_get_embedding_function_jina() -> None: + """Test Jina embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_jina = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_jina + mock_providers.__contains__.return_value = True + + config = { + "provider": "jina", + "config": { + "api_key": "jina-key", + "model_name": "jina-embeddings-v2-base-en", + }, + } + + result = get_embedding_function(config) + + mock_jina.assert_called_once_with( + api_key="jina-key", model_name="jina-embeddings-v2-base-en" + ) + assert result == mock_instance + + +def test_get_embedding_function_unsupported_provider() -> None: + """Test handling of unsupported provider.""" + config = {"provider": "unsupported-provider"} + + with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"): + get_embedding_function(config) + + +def test_get_embedding_function_config_modification() -> None: + """Test that original config dict is not modified.""" + original_config = { + "provider": "openai", + "config": {"api_key": "test-key", "model": "text-embedding-3-small"}, + } + config_copy = original_config.copy() + + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_openai = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_openai + mock_providers.__contains__.return_value = True + + get_embedding_function(config_copy) + + assert config_copy == original_config + + +def test_get_embedding_function_exclude_none_values() -> None: + """Test that None values are excluded from embedding function calls.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_openai = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_openai + mock_providers.__contains__.return_value = True + + options = EmbeddingOptions( + provider="openai", api_key=SecretStr("test-key"), model_name=None + ) + + result = get_embedding_function(options) + + call_kwargs = mock_openai.call_args.kwargs + assert "api_key" in call_kwargs + assert call_kwargs["api_key"].get_secret_value() == "test-key" + assert "model_name" not in call_kwargs + assert result == mock_instance + + +def test_get_embedding_function_instructor() -> None: + """Test Instructor embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_instructor = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_instructor + mock_providers.__contains__.return_value = True + + config = { + "provider": "instructor", + "config": {"model_name": "hkunlp/instructor-large"}, + } + + result = get_embedding_function(config) + + mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large") + assert result == mock_instance + + +def test_get_embedding_function_google_generativeai() -> None: + """Test Google Generative AI embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_google = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_google + mock_providers.__contains__.return_value = True + + config = { + "provider": "google-generativeai", + "config": {"api_key": "google-key", "model_name": "models/embedding-001"}, + } + + result = get_embedding_function(config) + + mock_google.assert_called_once_with( + api_key="google-key", model_name="models/embedding-001" + ) + assert result == mock_instance + + +def test_get_embedding_function_google_vertex() -> None: + """Test Google Vertex AI embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_vertex = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_vertex + mock_providers.__contains__.return_value = True + + config = { + "provider": "google-vertex", + "config": { + "api_key": "vertex-key", + "project_id": "my-project", + "region": "us-central1", + }, + } + + result = get_embedding_function(config) + + mock_vertex.assert_called_once_with( + api_key="vertex-key", project_id="my-project", region="us-central1" + ) + assert result == mock_instance + + +def test_get_embedding_function_roboflow() -> None: + """Test Roboflow embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_roboflow = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_roboflow + mock_providers.__contains__.return_value = True + + config = { + "provider": "roboflow", + "config": { + "api_key": "roboflow-key", + "api_url": "https://infer.roboflow.com", + }, + } + + result = get_embedding_function(config) + + mock_roboflow.assert_called_once_with( + api_key="roboflow-key", api_url="https://infer.roboflow.com" + ) + assert result == mock_instance + + +def test_get_embedding_function_openclip() -> None: + """Test OpenCLIP embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_openclip = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_openclip + mock_providers.__contains__.return_value = True + + config = { + "provider": "openclip", + "config": {"model_name": "ViT-B-32", "checkpoint": "laion2b_s34b_b79k"}, + } + + result = get_embedding_function(config) + + mock_openclip.assert_called_once_with( + model_name="ViT-B-32", checkpoint="laion2b_s34b_b79k" + ) + assert result == mock_instance + + +def test_get_embedding_function_text2vec() -> None: + """Test Text2Vec embedding function.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_text2vec = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_text2vec + mock_providers.__contains__.return_value = True + + config = { + "provider": "text2vec", + "config": {"model_name": "shibing624/text2vec-base-chinese"}, + } + + result = get_embedding_function(config) + + mock_text2vec.assert_called_once_with( + model_name="shibing624/text2vec-base-chinese" + ) + assert result == mock_instance + + +def test_model_to_model_name_conversion() -> None: + """Test that 'model' field is converted to 'model_name' for nested config.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_openai = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_openai + mock_providers.__contains__.return_value = True + + config = { + "provider": "openai", + "config": {"api_key": "test-key", "model": "text-embedding-3-small"}, + } + + result = get_embedding_function(config) + + mock_openai.assert_called_once_with( + api_key="test-key", model_name="text-embedding-3-small" + ) + assert result == mock_instance + + +def test_api_key_injection_from_env_openai() -> None: + """Test that OpenAI API key is injected from environment when not provided.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_openai = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_openai + mock_providers.__contains__.return_value = True + + with patch("crewai.rag.embeddings.factory.os.getenv") as mock_getenv: + mock_getenv.return_value = "env-openai-key" + + config = { + "provider": "openai", + "config": {"model": "text-embedding-3-small"}, + } + + result = get_embedding_function(config) + + mock_getenv.assert_called_with("OPENAI_API_KEY") + mock_openai.assert_called_once_with( + api_key="env-openai-key", model_name="text-embedding-3-small" + ) + assert result == mock_instance + + +def test_api_key_injection_from_env_cohere() -> None: + """Test that Cohere API key is injected from environment when not provided.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_cohere = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_cohere + mock_providers.__contains__.return_value = True + + with patch("crewai.rag.embeddings.factory.os.getenv") as mock_getenv: + mock_getenv.return_value = "env-cohere-key" + + config = { + "provider": "cohere", + "config": {"model_name": "embed-english-v3.0"}, + } + + result = get_embedding_function(config) + + mock_getenv.assert_called_with("COHERE_API_KEY") + mock_cohere.assert_called_once_with( + api_key="env-cohere-key", model_name="embed-english-v3.0" + ) + assert result == mock_instance + + +def test_api_key_not_injected_when_provided() -> None: + """Test that API key from config takes precedence over environment.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_openai = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_openai + mock_providers.__contains__.return_value = True + + with patch("crewai.rag.embeddings.factory.os.getenv") as mock_getenv: + mock_getenv.return_value = "env-key" + + config = { + "provider": "openai", + "config": {"api_key": "config-key", "model": "text-embedding-3-small"}, + } + + result = get_embedding_function(config) + + mock_openai.assert_called_once_with( + api_key="config-key", model_name="text-embedding-3-small" + ) + assert result == mock_instance + + +def test_amazon_bedrock_session_injection() -> None: + """Test that boto3 session is automatically created for amazon-bedrock.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_bedrock = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_bedrock + mock_providers.__contains__.return_value = True + + mock_boto3 = MagicMock() + with patch.dict("sys.modules", {"boto3": mock_boto3}): + mock_session = MagicMock() + mock_boto3.Session.return_value = mock_session + + config = { + "provider": "amazon-bedrock", + "config": {"model_name": "amazon.titan-embed-text-v1"}, + } + + result = get_embedding_function(config) + + mock_boto3.Session.assert_called_once() + mock_bedrock.assert_called_once_with( + session=mock_session, model_name="amazon.titan-embed-text-v1" + ) + assert result == mock_instance + + +def test_amazon_bedrock_session_not_injected_when_provided() -> None: + """Test that provided session is used for amazon-bedrock.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_instance = MagicMock() + mock_bedrock = MagicMock(return_value=mock_instance) + mock_providers.__getitem__.return_value = mock_bedrock + mock_providers.__contains__.return_value = True + + existing_session = MagicMock() + config = { + "provider": "amazon-bedrock", + "config": { + "session": existing_session, + "model_name": "amazon.titan-embed-text-v1", + }, + } + + result = get_embedding_function(config) + + mock_bedrock.assert_called_once_with( + session=existing_session, model_name="amazon.titan-embed-text-v1" + ) + assert result == mock_instance + + +def test_amazon_bedrock_boto3_import_error() -> None: + """Test error handling when boto3 is not installed.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_providers.__contains__.return_value = True + + with patch.dict("sys.modules", {"boto3": None}): + config = { + "provider": "amazon-bedrock", + "config": {"model_name": "amazon.titan-embed-text-v1"}, + } + + with pytest.raises( + ImportError, match="boto3 is required for amazon-bedrock" + ): + get_embedding_function(config) + + +def test_amazon_bedrock_session_creation_error() -> None: + """Test error handling when AWS session creation fails.""" + with patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") as mock_providers: + mock_providers.__contains__.return_value = True + + mock_boto3 = MagicMock() + with patch.dict("sys.modules", {"boto3": mock_boto3}): + mock_boto3.Session.side_effect = Exception("AWS credentials not configured") + + config = { + "provider": "amazon-bedrock", + "config": {"model_name": "amazon.titan-embed-text-v1"}, + } + + with pytest.raises(ValueError, match="Failed to create AWS session"): + get_embedding_function(config) + + +def test_invalid_config_format() -> None: + """Test error handling for invalid config format.""" + config = { + "provider": "openai", + "api_key": "test-key", + "model": "text-embedding-3-small", + } + + with pytest.raises(ValueError, match="Invalid embedder configuration format"): + get_embedding_function(config) diff --git a/tests/rag/embeddings/test_factory_azure.py b/tests/rag/embeddings/test_factory_azure.py new file mode 100644 index 000000000..e17d2bbef --- /dev/null +++ b/tests/rag/embeddings/test_factory_azure.py @@ -0,0 +1,79 @@ +"""Test Azure embedder configuration with factory.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function + + +class TestAzureEmbedderFactory: + """Test Azure embedder configuration with factory function.""" + + @patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") + def test_azure_with_nested_config(self, mock_providers): + """Test Azure configuration with nested config key.""" + + mock_embedding = MagicMock() + mock_openai_func = MagicMock(return_value=mock_embedding) + mock_providers.__getitem__.return_value = mock_openai_func + mock_providers.__contains__.return_value = True + + embedder_config = EmbedderConfig( + provider="openai", + config={ + "api_key": "test-azure-key", + "api_base": "https://test.openai.azure.com/", + "api_type": "azure", + "api_version": "2023-05-15", + "model": "text-embedding-3-small", + "deployment_id": "test-deployment", + }, + ) + + result = get_embedding_function(embedder_config) + + mock_openai_func.assert_called_once_with( + api_key="test-azure-key", + api_base="https://test.openai.azure.com/", + api_type="azure", + api_version="2023-05-15", + model_name="text-embedding-3-small", + deployment_id="test-deployment", + ) + assert result == mock_embedding + + @patch("crewai.rag.embeddings.factory.EMBEDDING_PROVIDERS") + def test_regular_openai_with_nested_config(self, mock_providers): + """Test regular OpenAI configuration with nested config.""" + + mock_embedding = MagicMock() + mock_openai_func = MagicMock(return_value=mock_embedding) + mock_providers.__getitem__.return_value = mock_openai_func + mock_providers.__contains__.return_value = True + + embedder_config = EmbedderConfig( + provider="openai", + config={"api_key": "test-openai-key", "model": "text-embedding-3-large"}, + ) + + result = get_embedding_function(embedder_config) + + mock_openai_func.assert_called_once_with( + api_key="test-openai-key", model_name="text-embedding-3-large" + ) + assert result == mock_embedding + + def test_flat_format_raises_error(self): + """Test that flat format raises an error.""" + embedder_config = { + "provider": "openai", + "api_key": "test-key", + "model_name": "text-embedding-3-small", + } + + with pytest.raises(ValueError) as exc_info: + get_embedding_function(embedder_config) + + assert "Invalid embedder configuration format" in str(exc_info.value) + assert "nested under a 'config' key" in str(exc_info.value) diff --git a/tests/rag/embeddings/test_factory_enhanced.py b/tests/rag/embeddings/test_factory_enhanced.py deleted file mode 100644 index 489064826..000000000 --- a/tests/rag/embeddings/test_factory_enhanced.py +++ /dev/null @@ -1,250 +0,0 @@ -"""Enhanced tests for embedding function factory.""" - -from unittest.mock import MagicMock, patch - -import pytest - -from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped] - get_embedding_function, -) -from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped] - - -def test_get_embedding_function_default() -> None: - """Test default embedding function when no config provided.""" - with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai: - mock_instance = MagicMock() - mock_openai.return_value = mock_instance - - with patch( - "crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key" - ): - result = get_embedding_function() - - mock_openai.assert_called_once_with( - api_key="test-api-key", model_name="text-embedding-3-small" - ) - assert result == mock_instance - - -def test_get_embedding_function_with_embedding_options() -> None: - """Test embedding function creation with EmbeddingOptions object.""" - with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai: - mock_instance = MagicMock() - mock_openai.return_value = mock_instance - - options = EmbeddingOptions( - provider="openai", api_key="test-key", model="text-embedding-3-large" - ) - - result = get_embedding_function(options) - - call_kwargs = mock_openai.call_args.kwargs - assert "api_key" in call_kwargs - assert call_kwargs["api_key"].get_secret_value() == "test-key" - # OpenAI uses model_name parameter, not model - assert result == mock_instance - - -def test_get_embedding_function_sentence_transformer() -> None: - """Test sentence transformer embedding function.""" - with patch( - "crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction" - ) as mock_st: - mock_instance = MagicMock() - mock_st.return_value = mock_instance - - config = {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"} - - result = get_embedding_function(config) - - mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2") - assert result == mock_instance - - -def test_get_embedding_function_ollama() -> None: - """Test Ollama embedding function.""" - with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as mock_ollama: - mock_instance = MagicMock() - mock_ollama.return_value = mock_instance - - config = { - "provider": "ollama", - "model_name": "nomic-embed-text", - "url": "http://localhost:11434", - } - - result = get_embedding_function(config) - - mock_ollama.assert_called_once_with( - model_name="nomic-embed-text", url="http://localhost:11434" - ) - assert result == mock_instance - - -def test_get_embedding_function_cohere() -> None: - """Test Cohere embedding function.""" - with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as mock_cohere: - mock_instance = MagicMock() - mock_cohere.return_value = mock_instance - - config = { - "provider": "cohere", - "api_key": "cohere-key", - "model_name": "embed-english-v3.0", - } - - result = get_embedding_function(config) - - mock_cohere.assert_called_once_with( - api_key="cohere-key", model_name="embed-english-v3.0" - ) - assert result == mock_instance - - -def test_get_embedding_function_huggingface() -> None: - """Test HuggingFace embedding function.""" - with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as mock_hf: - mock_instance = MagicMock() - mock_hf.return_value = mock_instance - - config = { - "provider": "huggingface", - "api_key": "hf-token", - "model_name": "sentence-transformers/all-MiniLM-L6-v2", - } - - result = get_embedding_function(config) - - mock_hf.assert_called_once_with( - api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2" - ) - assert result == mock_instance - - -def test_get_embedding_function_onnx() -> None: - """Test ONNX embedding function.""" - with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as mock_onnx: - mock_instance = MagicMock() - mock_onnx.return_value = mock_instance - - config = {"provider": "onnx"} - - result = get_embedding_function(config) - - mock_onnx.assert_called_once() - assert result == mock_instance - - -def test_get_embedding_function_google_palm() -> None: - """Test Google PaLM embedding function.""" - with patch( - "crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction" - ) as mock_palm: - mock_instance = MagicMock() - mock_palm.return_value = mock_instance - - config = {"provider": "google-palm", "api_key": "palm-key"} - - result = get_embedding_function(config) - - mock_palm.assert_called_once_with(api_key="palm-key") - assert result == mock_instance - - -def test_get_embedding_function_amazon_bedrock() -> None: - """Test Amazon Bedrock embedding function.""" - with patch( - "crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction" - ) as mock_bedrock: - mock_instance = MagicMock() - mock_bedrock.return_value = mock_instance - - config = { - "provider": "amazon-bedrock", - "region_name": "us-west-2", - "model_name": "amazon.titan-embed-text-v1", - } - - result = get_embedding_function(config) - - mock_bedrock.assert_called_once_with( - region_name="us-west-2", model_name="amazon.titan-embed-text-v1" - ) - assert result == mock_instance - - -def test_get_embedding_function_jina() -> None: - """Test Jina embedding function.""" - with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as mock_jina: - mock_instance = MagicMock() - mock_jina.return_value = mock_instance - - config = { - "provider": "jina", - "api_key": "jina-key", - "model_name": "jina-embeddings-v2-base-en", - } - - result = get_embedding_function(config) - - mock_jina.assert_called_once_with( - api_key="jina-key", model_name="jina-embeddings-v2-base-en" - ) - assert result == mock_instance - - -def test_get_embedding_function_unsupported_provider() -> None: - """Test handling of unsupported provider.""" - config = {"provider": "unsupported-provider"} - - with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"): - get_embedding_function(config) - - -def test_get_embedding_function_config_modification() -> None: - """Test that original config dict is not modified.""" - original_config = { - "provider": "openai", - "api_key": "test-key", - "model": "text-embedding-3-small", - } - config_copy = original_config.copy() - - with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"): - get_embedding_function(config_copy) - - assert config_copy == original_config - - -def test_get_embedding_function_exclude_none_values() -> None: - """Test that None values are excluded from embedding function calls.""" - with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai: - mock_instance = MagicMock() - mock_openai.return_value = mock_instance - - options = EmbeddingOptions(provider="openai", api_key="test-key", model=None) - - result = get_embedding_function(options) - - call_kwargs = mock_openai.call_args.kwargs - assert "api_key" in call_kwargs - assert call_kwargs["api_key"].get_secret_value() == "test-key" - assert "model" not in call_kwargs - assert result == mock_instance - - -def test_get_embedding_function_instructor() -> None: - """Test Instructor embedding function.""" - with patch( - "crewai.rag.embeddings.factory.InstructorEmbeddingFunction" - ) as mock_instructor: - mock_instance = MagicMock() - mock_instructor.return_value = mock_instance - - config = {"provider": "instructor", "model_name": "hkunlp/instructor-large"} - - result = get_embedding_function(config) - - mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large") - assert result == mock_instance diff --git a/tests/utilities/test_azure_embedder_config.py b/tests/utilities/test_azure_embedder_config.py new file mode 100644 index 000000000..873c68958 --- /dev/null +++ b/tests/utilities/test_azure_embedder_config.py @@ -0,0 +1,82 @@ +"""Test Azure embedder configuration with nested format only.""" + +from unittest.mock import MagicMock, patch + +from crewai.rag.embeddings.configurator import EmbeddingConfigurator + + +class TestAzureEmbedderConfiguration: + """Test Azure embedder configuration with nested format.""" + + @patch( + "chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction" + ) + def test_azure_openai_with_nested_config(self, mock_openai_func): + """Test Azure configuration using OpenAI provider with nested config key.""" + mock_embedding = MagicMock() + mock_openai_func.return_value = mock_embedding + + configurator = EmbeddingConfigurator() + + embedder_config = { + "provider": "openai", + "config": { + "api_key": "test-azure-key", + "api_base": "https://test.openai.azure.com/", + "api_type": "azure", + "api_version": "2023-05-15", + "model": "text-embedding-3-small", + "deployment_id": "test-deployment", + }, + } + + result = configurator.configure_embedder(embedder_config) + + mock_openai_func.assert_called_once_with( + api_key="test-azure-key", + model_name="text-embedding-3-small", + api_base="https://test.openai.azure.com/", + api_type="azure", + api_version="2023-05-15", + default_headers=None, + dimensions=None, + deployment_id="test-deployment", + organization_id=None, + ) + assert result == mock_embedding + + @patch( + "chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction" + ) + def test_azure_provider_with_nested_config(self, mock_openai_func): + """Test using 'azure' as provider with nested config.""" + mock_embedding = MagicMock() + mock_openai_func.return_value = mock_embedding + + configurator = EmbeddingConfigurator() + + embedder_config = { + "provider": "azure", + "config": { + "api_key": "test-azure-key", + "api_base": "https://test.openai.azure.com/", + "api_version": "2023-05-15", + "model": "text-embedding-3-small", + "deployment_id": "test-deployment", + }, + } + + result = configurator.configure_embedder(embedder_config) + + mock_openai_func.assert_called_once_with( + api_key="test-azure-key", + api_base="https://test.openai.azure.com/", + api_type="azure", + api_version="2023-05-15", + model_name="text-embedding-3-small", + default_headers=None, + dimensions=None, + deployment_id="test-deployment", + organization_id=None, + ) + assert result == mock_embedding