Files
crewAI/lib/crewai/tests/rag/embeddings/test_embedding_factory.py
Lorenze Jay 58b866a83d Lorenze/supporting vertex embeddings (#4282)
* feat: introduce GoogleGenAIVertexEmbeddingFunction for dual SDK support

- Added a new embedding function to support both the legacy vertexai.language_models SDK and the new google-genai SDK for Google Vertex AI.
- Updated factory methods to route to the new embedding function.
- Enhanced VertexAIProvider and related configurations to accommodate the new model options.
- Added integration tests for Google Vertex embeddings with Crew memory, ensuring compatibility and functionality with both authentication methods.

This update improves the flexibility and compatibility of Google Vertex AI embeddings within the CrewAI framework.

* fix test count

* rm comment

* regen cassettes

* regen

* drop variable from .envtest

* direct to relevant test only
2026-01-26 14:55:03 -08:00

372 lines
14 KiB
Python

"""Tests for embedding function factory."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.rag.embeddings.factory import build_embedder
class TestEmbeddingFactory:
    """Test embedding factory functions.

    Most tests patch ``import_and_validate_definition`` so that no real
    provider module is imported. They then check two things: which provider
    path the factory resolved, and which keyword arguments it passed to the
    (mocked) provider class.
    """

    @staticmethod
    def _wire_provider_mocks(mock_import):
        """Make ``mock_import`` return a mock provider class.

        Args:
            mock_import: The patched ``import_and_validate_definition``.

        Returns:
            A ``(provider_class, provider_instance, embedding_function)`` triple.
            In it:
            - calling ``provider_class`` yields ``provider_instance``;
            - calling ``provider_instance.embedding_callable`` yields
              ``embedding_function``.
        """
        provider_class = MagicMock()
        provider_instance = MagicMock()
        embedding_function = MagicMock()
        mock_import.return_value = provider_class
        provider_class.return_value = provider_instance
        provider_instance.embedding_callable.return_value = embedding_function
        return provider_class, provider_instance, embedding_function

    @staticmethod
    def _assert_config_forwarded(provider_class, provider_config):
        """Assert the provider class was built once with every config entry.

        Args:
            provider_class: The mock returned by the patched importer.
            provider_config: The ``"config"`` mapping passed to ``build_embedder``.
                Each of its entries must appear as a keyword argument with an
                equal value.

        Raises:
            AssertionError: If the class was not called exactly once, or if any
                config entry is missing or differs.
        """
        provider_class.assert_called_once()
        call_kwargs = provider_class.call_args.kwargs
        for key, expected in provider_config.items():
            assert key in call_kwargs, f"config key {key!r} was not forwarded"
            assert call_kwargs[key] == expected, (
                f"config key {key!r}: expected {expected!r}, got {call_kwargs[key]!r}"
            )

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_openai(self, mock_import):
        """Test building OpenAI embedder."""
        mock_provider_class, _, _ = self._wire_provider_mocks(mock_import)
        config = {
            "provider": "openai",
            "config": {
                "api_key": "test-key",
                "model_name": "text-embedding-3-small",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_azure(self, mock_import):
        """Test building Azure embedder."""
        mock_provider_class, _, _ = self._wire_provider_mocks(mock_import)
        config = {
            "provider": "azure",
            "config": {
                "api_key": "test-azure-key",
                "api_base": "https://test.openai.azure.com/",
                "api_type": "azure",
                "api_version": "2023-05-15",
                "model_name": "text-embedding-3-small",
                "deployment_id": "test-deployment",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_ollama(self, mock_import):
        """Test building Ollama embedder."""
        mock_provider_class, _, _ = self._wire_provider_mocks(mock_import)
        config = {
            "provider": "ollama",
            "config": {
                "model_name": "nomic-embed-text",
                "url": "http://localhost:11434",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_huggingface(self, mock_import):
        """Test building HuggingFace embedder."""
        mock_provider_class, _, _ = self._wire_provider_mocks(mock_import)
        config = {
            "provider": "huggingface",
            "config": {
                "api_key": "hf-test-key",
                "model": "sentence-transformers/all-MiniLM-L6-v2",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_cohere(self, mock_import):
        """Test building Cohere embedder."""
        mock_provider_class, _, _ = self._wire_provider_mocks(mock_import)
        config = {
            "provider": "cohere",
            "config": {
                "api_key": "cohere-key",
                "model_name": "embed-english-v3.0",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_voyageai(self, mock_import):
        """Test building VoyageAI embedder."""
        mock_provider_class, _, _ = self._wire_provider_mocks(mock_import)
        config = {
            "provider": "voyageai",
            "config": {
                "api_key": "voyage-key",
                "model": "voyage-2",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_watsonx(self, mock_import):
        """Test building WatsonX embedder."""
        mock_provider_class, _, _ = self._wire_provider_mocks(mock_import)
        config = {
            "provider": "watsonx",
            "config": {
                "model_id": "ibm/slate-125m-english-rtrvr",
                "api_key": "watsonx-key",
                "url": "https://us-south.ml.cloud.ibm.com",
                "project_id": "test-project",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    def test_build_embedder_unknown_provider(self):
        """Test error handling for unknown provider."""
        config = {"provider": "unknown-provider", "config": {}}

        with pytest.raises(ValueError, match="Unknown provider: unknown-provider"):
            build_embedder(config)

    def test_build_embedder_missing_provider(self):
        """Test error handling for missing provider key."""
        config = {"config": {"api_key": "test-key"}}

        with pytest.raises(KeyError):
            build_embedder(config)

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_import_error(self, mock_import):
        """Test error handling when provider import fails."""
        mock_import.side_effect = ImportError("Module not found")
        config = {"provider": "openai", "config": {"api_key": "test-key"}}

        with pytest.raises(ImportError, match="Failed to import provider openai"):
            build_embedder(config)

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_custom_provider(self, mock_import):
        """Test building custom embedder."""
        mock_provider_class = MagicMock()
        mock_provider_instance = MagicMock()
        mock_embedding_callable = MagicMock()
        mock_import.return_value = mock_provider_class
        mock_provider_class.return_value = mock_provider_instance
        # The custom provider carries the user's callable itself, so the same
        # mock is both the config value and the instance attribute.
        mock_provider_instance.embedding_callable = mock_embedding_callable
        config = {
            "provider": "custom",
            "config": {"embedding_callable": mock_embedding_callable},
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    @patch("crewai.rag.embeddings.factory.build_embedder_from_provider")
    def test_build_embedder_with_provider_instance(
        self, mock_build_from_provider, mock_import
    ):
        """Test building embedder from provider instance."""
        from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider

        mock_provider = MagicMock(spec=BaseEmbeddingsProvider)
        mock_embedding_function = MagicMock()
        mock_build_from_provider.return_value = mock_embedding_function

        result = build_embedder(mock_provider)

        mock_build_from_provider.assert_called_once_with(mock_provider)
        assert result == mock_embedding_function
        # A ready-made provider instance must bypass the import-by-path route.
        mock_import.assert_not_called()

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_google_vertex_with_genai_model(self, mock_import):
        """Test routing to Google Vertex provider with new genai model."""
        mock_provider_class, _, _ = self._wire_provider_mocks(mock_import)
        config = {
            "provider": "google-vertex",
            "config": {
                "api_key": "test-google-api-key",
                "model_name": "gemini-embedding-001",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_google_vertex_with_legacy_model(self, mock_import):
        """Test routing to Google Vertex provider with legacy textembedding-gecko model."""
        mock_provider_class, _, _ = self._wire_provider_mocks(mock_import)
        config = {
            "provider": "google-vertex",
            "config": {
                "project_id": "my-gcp-project",
                "region": "us-central1",
                "model_name": "textembedding-gecko",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_google_vertex_with_location(self, mock_import):
        """Test routing to Google Vertex provider with location parameter."""
        mock_provider_class, _, _ = self._wire_provider_mocks(mock_import)
        config = {
            "provider": "google-vertex",
            "config": {
                "project_id": "my-gcp-project",
                "location": "europe-west1",
                "model_name": "gemini-embedding-001",
                "task_type": "RETRIEVAL_DOCUMENT",
                "output_dimensionality": 768,
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])