mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-30 11:18:31 +00:00
Fix google-generativeai embedder validation error (issue #3741)
This commit fixes the validation error that occurred when using the
google-generativeai embedder provider with a flat configuration format.
Changes:
1. Made the 'config' field optional in GenerativeAiProviderSpec by adding
'total=False' and marking 'provider' as Required, consistent with other
provider specs like VertexAIProviderSpec.
2. Added normalization in the Crew class to automatically convert flat
embedder configs to nested format before validation. This allows users
to use either format:
- Flat: {'provider': 'google-generativeai', 'api_key': '...', 'model_name': '...'}
- Nested: {'provider': 'google-generativeai', 'config': {'api_key': '...', 'model_name': '...'}}
3. Updated the embedder factory to support both flat and nested config
formats by checking for the presence of a 'config' key and extracting
the config fields accordingly.
4. Added comprehensive tests to verify both formats work correctly:
- Test for flat config format (the issue reported in #3741)
- Test for nested config format (recommended format)
- Test for TypedDict validation
Fixes #3741
Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -283,6 +283,30 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
"may_not_set_field", "The 'id' field cannot be set by the user.", {}
|
"may_not_set_field", "The 'id' field cannot be set by the user.", {}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@field_validator("embedder", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def normalize_embedder_config(
|
||||||
|
cls, v: dict[str, Any] | None
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Normalize embedder config to support both flat and nested formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
v: The embedder config to be normalized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The normalized embedder config with nested structure.
|
||||||
|
"""
|
||||||
|
if v is None or not isinstance(v, dict):
|
||||||
|
return v
|
||||||
|
|
||||||
|
if "provider" in v and "config" not in v:
|
||||||
|
provider = v["provider"]
|
||||||
|
config_fields = {k: val for k, val in v.items() if k != "provider"}
|
||||||
|
if config_fields:
|
||||||
|
return {"provider": provider, "config": config_fields}
|
||||||
|
|
||||||
|
return v
|
||||||
|
|
||||||
@field_validator("config", mode="before")
|
@field_validator("config", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
|
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
|
||||||
|
|||||||
@@ -228,14 +228,24 @@ def build_embedder_from_dict(spec):
|
|||||||
"""Build an embedding function instance from a dictionary specification.
|
"""Build an embedding function instance from a dictionary specification.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
spec: A dictionary with 'provider' and 'config' keys.
|
spec: A dictionary with 'provider' and optionally 'config' keys.
|
||||||
Example: {
|
Supports two formats:
|
||||||
|
|
||||||
|
Nested format (recommended):
|
||||||
|
{
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
"config": {
|
"config": {
|
||||||
"api_key": "sk-...",
|
"api_key": "sk-...",
|
||||||
"model_name": "text-embedding-3-small"
|
"model_name": "text-embedding-3-small"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Flat format (for backward compatibility):
|
||||||
|
{
|
||||||
|
"provider": "openai",
|
||||||
|
"api_key": "sk-...",
|
||||||
|
"model_name": "text-embedding-3-small"
|
||||||
|
}
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instance of the appropriate embedding function.
|
An instance of the appropriate embedding function.
|
||||||
@@ -266,7 +276,10 @@ def build_embedder_from_dict(spec):
|
|||||||
except (ImportError, AttributeError, ValueError) as e:
|
except (ImportError, AttributeError, ValueError) as e:
|
||||||
raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
|
raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
|
||||||
|
|
||||||
provider_config = spec.get("config", {})
|
if "config" in spec:
|
||||||
|
provider_config = spec["config"]
|
||||||
|
else:
|
||||||
|
provider_config = {k: v for k, v in spec.items() if k != "provider"}
|
||||||
|
|
||||||
if provider_name == "custom" and "embedding_callable" not in provider_config:
|
if provider_name == "custom" and "embedding_callable" not in provider_config:
|
||||||
raise ValueError("Custom provider requires 'embedding_callable' in config")
|
raise ValueError("Custom provider requires 'embedding_callable' in config")
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ class GenerativeAiProviderConfig(TypedDict, total=False):
|
|||||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||||
|
|
||||||
|
|
||||||
class GenerativeAiProviderSpec(TypedDict):
|
class GenerativeAiProviderSpec(TypedDict, total=False):
|
||||||
"""Google Generative AI provider specification."""
|
"""Google Generative AI provider specification."""
|
||||||
|
|
||||||
provider: Literal["google-generativeai"]
|
provider: Required[Literal["google-generativeai"]]
|
||||||
config: GenerativeAiProviderConfig
|
config: GenerativeAiProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -242,3 +242,61 @@ class TestEmbeddingFactory:
|
|||||||
mock_build_from_provider.assert_called_once_with(mock_provider)
|
mock_build_from_provider.assert_called_once_with(mock_provider)
|
||||||
assert result == mock_embedding_function
|
assert result == mock_embedding_function
|
||||||
mock_import.assert_not_called()
|
mock_import.assert_not_called()
|
||||||
|
|
||||||
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||||
|
def test_build_embedder_google_generativeai_nested_config(self, mock_import):
    """Nested-format spec: provider options reach the provider class as kwargs."""
    # Wire the mock chain: import -> provider class -> instance -> callable.
    embedding_fn = MagicMock()
    provider_instance = MagicMock()
    provider_instance.embedding_callable.return_value = embedding_fn
    provider_cls = MagicMock()
    provider_cls.return_value = provider_instance
    mock_import.return_value = provider_cls

    provider_options = {
        "api_key": "test-gemini-key",
        "model_name": "models/text-embedding-004",
    }
    spec = {"provider": "google-generativeai", "config": provider_options}

    build_embedder(spec)

    mock_import.assert_called_once_with(
        "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
    )
    provider_cls.assert_called_once()

    # The provider class must have been constructed with the nested values.
    passed_kwargs = provider_cls.call_args.kwargs
    assert passed_kwargs["api_key"] == "test-gemini-key"
    assert passed_kwargs["model_name"] == "models/text-embedding-004"
|
||||||
|
|
||||||
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||||
|
def test_build_embedder_google_generativeai_flat_config(self, mock_import):
    """Flat-format spec (issue #3741): sibling keys reach the provider class as kwargs."""
    # Wire the mock chain: import -> provider class -> instance -> callable.
    embedding_fn = MagicMock()
    provider_instance = MagicMock()
    provider_instance.embedding_callable.return_value = embedding_fn
    provider_cls = MagicMock()
    provider_cls.return_value = provider_instance
    mock_import.return_value = provider_cls

    # Options sit next to 'provider' instead of under a 'config' key.
    spec = {
        "provider": "google-generativeai",
        "api_key": "test-gemini-key",
        "model_name": "models/text-embedding-004",
    }

    build_embedder(spec)

    mock_import.assert_called_once_with(
        "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
    )
    provider_cls.assert_called_once()

    passed_kwargs = provider_cls.call_args.kwargs
    assert passed_kwargs["api_key"] == "test-gemini-key"
    assert passed_kwargs["model_name"] == "models/text-embedding-004"
|
||||||
|
|||||||
107
tests/rag/embeddings/test_google_generativeai_embedder.py
Normal file
107
tests/rag/embeddings/test_google_generativeai_embedder.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
"""Tests for Google Generative AI embedder configuration (issue #3741)."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai import Agent, Crew, Task
|
||||||
|
|
||||||
|
|
||||||
|
class TestGoogleGenerativeAIEmbedder:
    """Exercise the flat and nested google-generativeai embedder config formats."""

    @staticmethod
    def _crew_with_embedder(embedder):
        """Build a single-agent, single-task Crew using the given embedder config."""
        worker = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
        )
        job = Task(
            description="Test task",
            expected_output="Test output",
            agent=worker,
        )
        return Crew(agents=[worker], tasks=[job], embedder=embedder)

    @patch("crewai.crew.Knowledge")
    @patch("crewai.crew.ShortTermMemory")
    @patch("crewai.crew.LongTermMemory")
    @patch("crewai.crew.EntityMemory")
    def test_crew_with_google_generativeai_flat_config(
        self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
    ):
        """A flat embedder config (issue #3741) is accepted and stored in nested form."""
        flat_config = {
            "provider": "google-generativeai",
            "api_key": "test-gemini-key",
            "model_name": "models/text-embedding-004",
        }

        crew = self._crew_with_embedder(flat_config)

        # Every non-provider key is expected to end up under 'config'.
        assert crew.embedder == {
            "provider": "google-generativeai",
            "config": {
                "api_key": "test-gemini-key",
                "model_name": "models/text-embedding-004",
            },
        }

    @patch("crewai.crew.Knowledge")
    @patch("crewai.crew.ShortTermMemory")
    @patch("crewai.crew.LongTermMemory")
    @patch("crewai.crew.EntityMemory")
    def test_crew_with_google_generativeai_nested_config(
        self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
    ):
        """A nested embedder config is accepted and stored unchanged."""
        nested_config = {
            "provider": "google-generativeai",
            "config": {
                "api_key": "test-gemini-key",
                "model_name": "models/text-embedding-004",
            },
        }

        crew = self._crew_with_embedder(nested_config)

        assert crew.embedder == nested_config

    def test_generativeai_provider_spec_validation(self):
        """GenerativeAiProviderSpec can be written with or without a 'config' entry."""
        from crewai.rag.embeddings.types import GenerativeAiProviderSpec

        # Provider-only spec: no 'config' key at all.
        provider_only: GenerativeAiProviderSpec = {"provider": "google-generativeai"}
        assert provider_only["provider"] == "google-generativeai"

        # Spec that also carries a nested 'config'.
        with_config: GenerativeAiProviderSpec = {
            "provider": "google-generativeai",
            "config": {
                "api_key": "test-key",
                "model_name": "models/text-embedding-004",
            },
        }
        assert with_config["provider"] == "google-generativeai"
        assert with_config["config"]["api_key"] == "test-key"
|
||||||
Reference in New Issue
Block a user