Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-08 15:48:29 +00:00.
This commit fixes the validation error that occurred when using the
google-generativeai embedder provider with a flat configuration format.
Changes:
1. Made the 'config' field optional in GenerativeAiProviderSpec by adding
'total=False' and marking 'provider' as Required, consistent with other
provider specs like VertexAIProviderSpec.
2. Added normalization in the Crew class to automatically convert flat
embedder configs to nested format before validation. This allows users
to use either format:
- Flat: {'provider': 'google-generativeai', 'api_key': '...', 'model_name': '...'}
- Nested: {'provider': 'google-generativeai', 'config': {'api_key': '...', 'model_name': '...'}}
3. Updated the embedder factory to support both flat and nested config
formats by checking for the presence of a 'config' key and extracting the
config fields accordingly.
4. Added comprehensive tests to verify both formats work correctly:
- Test for flat config format (the issue reported in #3741)
- Test for nested config format (recommended format)
- Test for TypedDict validation
Fixes #3741
Co-Authored-By: João <joao@crewai.com>
303 lines, 11 KiB, Python
"""Tests for embedding function factory."""
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from crewai.rag.embeddings.factory import build_embedder
|
|
|
|
|
|
class TestEmbeddingFactory:
    """Test embedding factory functions.

    Provider tests patch ``import_and_validate_definition`` in
    ``crewai.rag.embeddings.factory`` so no real provider package is loaded.
    They then check which provider path was requested and which keyword
    arguments the (mock) provider class was constructed with.
    """

    @staticmethod
    def _mock_provider_class(mock_import: MagicMock) -> MagicMock:
        """Make ``mock_import`` return a mock provider class, and return that class.

        Instances of the returned class expose an ``embedding_callable`` that
        returns a mock embedding function (the same wiring every provider test
        previously set up inline).
        """
        mock_provider_class = MagicMock()
        mock_provider_instance = MagicMock()
        mock_import.return_value = mock_provider_class
        mock_provider_class.return_value = mock_provider_instance
        mock_provider_instance.embedding_callable.return_value = MagicMock()
        return mock_provider_class

    @staticmethod
    def _assert_config_forwarded(
        mock_provider_class: MagicMock, expected: dict
    ) -> None:
        """Assert the provider class was built once with every ``expected`` kwarg.

        Only the keys listed in ``expected`` are checked, so any extra kwargs
        the factory may pass do not fail the assertion.
        """
        mock_provider_class.assert_called_once()
        call_kwargs = mock_provider_class.call_args.kwargs
        for key, value in expected.items():
            assert key in call_kwargs, f"kwarg {key!r} was not forwarded"
            assert call_kwargs[key] == value, f"kwarg {key!r} has wrong value"

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_openai(self, mock_import):
        """Test building OpenAI embedder."""
        mock_provider_class = self._mock_provider_class(mock_import)
        config = {
            "provider": "openai",
            "config": {
                "api_key": "test-key",
                "model_name": "text-embedding-3-small",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_azure(self, mock_import):
        """Test building Azure embedder."""
        mock_provider_class = self._mock_provider_class(mock_import)
        config = {
            "provider": "azure",
            "config": {
                "api_key": "test-azure-key",
                "api_base": "https://test.openai.azure.com/",
                "api_type": "azure",
                "api_version": "2023-05-15",
                "model_name": "text-embedding-3-small",
                "deployment_id": "test-deployment",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_ollama(self, mock_import):
        """Test building Ollama embedder."""
        mock_provider_class = self._mock_provider_class(mock_import)
        config = {
            "provider": "ollama",
            "config": {
                "model_name": "nomic-embed-text",
                "url": "http://localhost:11434",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_cohere(self, mock_import):
        """Test building Cohere embedder."""
        mock_provider_class = self._mock_provider_class(mock_import)
        config = {
            "provider": "cohere",
            "config": {
                "api_key": "cohere-key",
                "model_name": "embed-english-v3.0",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_voyageai(self, mock_import):
        """Test building VoyageAI embedder."""
        mock_provider_class = self._mock_provider_class(mock_import)
        config = {
            "provider": "voyageai",
            "config": {
                "api_key": "voyage-key",
                "model": "voyage-2",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_watsonx(self, mock_import):
        """Test building WatsonX embedder."""
        mock_provider_class = self._mock_provider_class(mock_import)
        config = {
            "provider": "watsonx",
            "config": {
                "model_id": "ibm/slate-125m-english-rtrvr",
                "api_key": "watsonx-key",
                "url": "https://us-south.ml.cloud.ibm.com",
                "project_id": "test-project",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider"
        )
        self._assert_config_forwarded(mock_provider_class, config["config"])

    def test_build_embedder_unknown_provider(self):
        """Test error handling for unknown provider."""
        config = {"provider": "unknown-provider", "config": {}}

        with pytest.raises(ValueError, match="Unknown provider: unknown-provider"):
            build_embedder(config)

    def test_build_embedder_missing_provider(self):
        """Test error handling for missing provider key."""
        config = {"config": {"api_key": "test-key"}}

        with pytest.raises(KeyError):
            build_embedder(config)

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_import_error(self, mock_import):
        """Test error handling when provider import fails."""
        mock_import.side_effect = ImportError("Module not found")
        config = {"provider": "openai", "config": {"api_key": "test-key"}}

        with pytest.raises(ImportError, match="Failed to import provider openai"):
            build_embedder(config)

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_custom_provider(self, mock_import):
        """Test building custom embedder."""
        # Custom providers are wired by hand: the instance's embedding_callable
        # is the very callable supplied in the config, not an auto-created mock.
        mock_provider_class = MagicMock()
        mock_provider_instance = MagicMock()
        mock_embedding_callable = MagicMock()
        mock_import.return_value = mock_provider_class
        mock_provider_class.return_value = mock_provider_instance
        mock_provider_instance.embedding_callable = mock_embedding_callable
        config = {
            "provider": "custom",
            "config": {"embedding_callable": mock_embedding_callable},
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider"
        )
        self._assert_config_forwarded(
            mock_provider_class, {"embedding_callable": mock_embedding_callable}
        )

    # Stacked @patch decorators inject mocks bottom-up: the innermost patch
    # (build_embedder_from_provider) arrives as the first mock argument.
    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    @patch("crewai.rag.embeddings.factory.build_embedder_from_provider")
    def test_build_embedder_with_provider_instance(
        self, mock_build_from_provider, mock_import
    ):
        """Test building embedder from provider instance."""
        from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider

        mock_provider = MagicMock(spec=BaseEmbeddingsProvider)
        mock_embedding_function = MagicMock()
        mock_build_from_provider.return_value = mock_embedding_function

        result = build_embedder(mock_provider)

        mock_build_from_provider.assert_called_once_with(mock_provider)
        assert result is mock_embedding_function
        # A ready-made provider instance must bypass the dynamic import path.
        mock_import.assert_not_called()

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_google_generativeai_nested_config(self, mock_import):
        """Test building Google Generative AI embedder with nested config format."""
        mock_provider_class = self._mock_provider_class(mock_import)
        config = {
            "provider": "google-generativeai",
            "config": {
                "api_key": "test-gemini-key",
                "model_name": "models/text-embedding-004",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
        )
        self._assert_config_forwarded(
            mock_provider_class,
            {
                "api_key": "test-gemini-key",
                "model_name": "models/text-embedding-004",
            },
        )

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_google_generativeai_flat_config(self, mock_import):
        """Test building Google Generative AI embedder with flat config format (issue #3741)."""
        mock_provider_class = self._mock_provider_class(mock_import)
        # Flat format: provider settings sit beside "provider" instead of
        # under a nested "config" key.
        config = {
            "provider": "google-generativeai",
            "api_key": "test-gemini-key",
            "model_name": "models/text-embedding-004",
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
        )
        self._assert_config_forwarded(
            mock_provider_class,
            {
                "api_key": "test-gemini-key",
                "model_name": "models/text-embedding-004",
            },
        )
|