Files
crewAI/tests/rag/embeddings/test_embedding_factory.py
Greyson LaLonde 12fa7e2ff1 fix: rename watson to watsonx embedding provider and prefix env vars
- prefix provider env vars with embeddings_  
- rename watson → watsonx in providers  
- add deprecation warning and alias for legacy 'watson' key (to be removed in v1.0.0)
2025-09-26 10:57:18 -04:00

245 lines
8.9 KiB
Python

"""Tests for embedding function factory."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.rag.embeddings.factory import build_embedder
class TestEmbeddingFactory:
    """Test embedding factory functions.

    Most tests patch ``import_and_validate_definition`` so no real provider
    module is imported. Each test then checks three things:

    * which dotted provider path the factory asked to import,
    * that the provider config was forwarded as constructor kwargs, and
    * that the factory returned what the provider's ``embedding_callable``
      produced.
    """

    @staticmethod
    def _mock_provider(mock_import):
        """Wire ``mock_import`` to hand back a fake provider class.

        Args:
            mock_import: The patched ``import_and_validate_definition``.

        Returns:
            A ``(provider_class, embedding_function)`` pair. Calling
            ``provider_class(...)`` yields an instance whose
            ``embedding_callable(...)`` returns ``embedding_function``.
        """
        provider_class = MagicMock()
        embedding_function = MagicMock()
        mock_import.return_value = provider_class
        provider_class.return_value.embedding_callable.return_value = (
            embedding_function
        )
        return provider_class, embedding_function

    @staticmethod
    def _assert_config_forwarded(provider_class, provider_config):
        """Assert the provider was built once and received every config entry.

        This is a subset check: it only requires that each key in
        ``provider_config`` arrived unchanged as a keyword argument.
        """
        provider_class.assert_called_once()
        call_kwargs = provider_class.call_args.kwargs
        for key, value in provider_config.items():
            assert call_kwargs[key] == value, key

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_openai(self, mock_import):
        """Test building OpenAI embedder."""
        provider_class, embedding_function = self._mock_provider(mock_import)
        config = {
            "provider": "openai",
            "config": {
                "api_key": "test-key",
                "model_name": "text-embedding-3-small",
            },
        }

        result = build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider"
        )
        self._assert_config_forwarded(provider_class, config["config"])
        assert result is embedding_function

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_azure(self, mock_import):
        """Test building Azure embedder."""
        provider_class, embedding_function = self._mock_provider(mock_import)
        config = {
            "provider": "azure",
            "config": {
                "api_key": "test-azure-key",
                "api_base": "https://test.openai.azure.com/",
                "api_type": "azure",
                "api_version": "2023-05-15",
                "model_name": "text-embedding-3-small",
                "deployment_id": "test-deployment",
            },
        }

        result = build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
        )
        self._assert_config_forwarded(provider_class, config["config"])
        assert result is embedding_function

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_ollama(self, mock_import):
        """Test building Ollama embedder."""
        provider_class, embedding_function = self._mock_provider(mock_import)
        config = {
            "provider": "ollama",
            "config": {
                "model_name": "nomic-embed-text",
                "url": "http://localhost:11434",
            },
        }

        result = build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider"
        )
        self._assert_config_forwarded(provider_class, config["config"])
        assert result is embedding_function

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_cohere(self, mock_import):
        """Test building Cohere embedder."""
        provider_class, embedding_function = self._mock_provider(mock_import)
        config = {
            "provider": "cohere",
            "config": {
                "api_key": "cohere-key",
                "model_name": "embed-english-v3.0",
            },
        }

        result = build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider"
        )
        self._assert_config_forwarded(provider_class, config["config"])
        assert result is embedding_function

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_voyageai(self, mock_import):
        """Test building VoyageAI embedder."""
        provider_class, embedding_function = self._mock_provider(mock_import)
        config = {
            "provider": "voyageai",
            "config": {
                "api_key": "voyage-key",
                "model": "voyage-2",
            },
        }

        result = build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider"
        )
        self._assert_config_forwarded(provider_class, config["config"])
        assert result is embedding_function

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_watsonx(self, mock_import):
        """Test building WatsonX embedder."""
        provider_class, embedding_function = self._mock_provider(mock_import)
        config = {
            "provider": "watsonx",
            "config": {
                "model_id": "ibm/slate-125m-english-rtrvr",
                "api_key": "watsonx-key",
                "url": "https://us-south.ml.cloud.ibm.com",
                "project_id": "test-project",
            },
        }

        result = build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider"
        )
        self._assert_config_forwarded(provider_class, config["config"])
        assert result is embedding_function

    def test_build_embedder_unknown_provider(self):
        """Test error handling for unknown provider."""
        config = {"provider": "unknown-provider", "config": {}}

        with pytest.raises(ValueError, match="Unknown provider: unknown-provider"):
            build_embedder(config)

    def test_build_embedder_missing_provider(self):
        """Test error handling for missing provider key."""
        config = {"config": {"api_key": "test-key"}}

        with pytest.raises(KeyError):
            build_embedder(config)

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_import_error(self, mock_import):
        """Test error handling when provider import fails."""
        mock_import.side_effect = ImportError("Module not found")
        config = {"provider": "openai", "config": {"api_key": "test-key"}}

        with pytest.raises(ImportError, match="Failed to import provider openai"):
            build_embedder(config)

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_custom_provider(self, mock_import):
        """Test building custom embedder."""
        provider_class, _ = self._mock_provider(mock_import)
        embedding_callable = MagicMock()
        # Mimic a custom provider that stores the user-supplied callable as-is.
        provider_class.return_value.embedding_callable = embedding_callable
        config = {
            "provider": "custom",
            "config": {"embedding_callable": embedding_callable},
        }

        result = build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider"
        )
        self._assert_config_forwarded(provider_class, config["config"])
        assert result is embedding_callable.return_value

    # Stacked @patch decorators inject mocks bottom-up, so the
    # build_embedder_from_provider mock arrives as the first argument.
    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    @patch("crewai.rag.embeddings.factory.build_embedder_from_provider")
    def test_build_embedder_with_provider_instance(
        self, mock_build_from_provider, mock_import
    ):
        """Test building embedder from provider instance."""
        from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider

        mock_provider = MagicMock(spec=BaseEmbeddingsProvider)
        mock_embedding_function = MagicMock()
        mock_build_from_provider.return_value = mock_embedding_function

        result = build_embedder(mock_provider)

        mock_build_from_provider.assert_called_once_with(mock_provider)
        assert result == mock_embedding_function
        # A ready-made provider instance must bypass dynamic import entirely.
        mock_import.assert_not_called()