Compare commits

...

2 Commits

Author SHA1 Message Date
Devin AI
32a013eb2f Fix trailing whitespace to satisfy ruff lint
Co-Authored-By: João <joao@crewai.com>
2025-10-20 19:30:30 +00:00
Devin AI
36673f89e7 Fix google-generativeai embedder validation error (issue #3741)
This commit fixes the validation error that occurred when using the
google-generativeai embedder provider with a flat configuration format.

Changes:
1. Made the 'config' field optional in GenerativeAiProviderSpec by adding
   'total=False' and marking 'provider' as Required, consistent with other
   provider specs like VertexAIProviderSpec.

2. Added normalization in the Crew class to automatically convert flat
   embedder configs to nested format before validation. This allows users
   to use either format:
   - Flat: {'provider': 'google-generativeai', 'api_key': '...', 'model_name': '...'}
   - Nested: {'provider': 'google-generativeai', 'config': {'api_key': '...', 'model_name': '...'}}

3. Updated the embedder factory to support both flat and nested config
   formats by checking for the presence of 'config' key and extracting
   config fields accordingly.

4. Added comprehensive tests to verify both formats work correctly:
   - Test for flat config format (the issue reported in #3741)
   - Test for nested config format (recommended format)
   - Test for TypedDict validation

Fixes #3741

Co-Authored-By: João <joao@crewai.com>
2025-10-20 19:26:08 +00:00
6 changed files with 3570 additions and 3459 deletions

View File

@@ -283,6 +283,30 @@ class Crew(FlowTrackable, BaseModel):
"may_not_set_field", "The 'id' field cannot be set by the user.", {}
)
@field_validator("embedder", mode="before")
@classmethod
def normalize_embedder_config(
cls, v: dict[str, Any] | None
) -> dict[str, Any] | None:
"""Normalize embedder config to support both flat and nested formats.
Args:
v: The embedder config to be normalized.
Returns:
The normalized embedder config with nested structure.
"""
if v is None or not isinstance(v, dict):
return v
if "provider" in v and "config" not in v:
provider = v["provider"]
config_fields = {k: val for k, val in v.items() if k != "provider"}
if config_fields:
return {"provider": provider, "config": config_fields}
return v
@field_validator("config", mode="before")
@classmethod
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:

View File

@@ -228,8 +228,11 @@ def build_embedder_from_dict(spec):
"""Build an embedding function instance from a dictionary specification.
Args:
spec: A dictionary with 'provider' and 'config' keys.
Example: {
spec: A dictionary with 'provider' and optionally 'config' keys.
Supports two formats:
Nested format (recommended):
{
"provider": "openai",
"config": {
"api_key": "sk-...",
@@ -237,6 +240,13 @@ def build_embedder_from_dict(spec):
}
}
Flat format (for backward compatibility):
{
"provider": "openai",
"api_key": "sk-...",
"model_name": "text-embedding-3-small"
}
Returns:
An instance of the appropriate embedding function.
@@ -266,7 +276,10 @@ def build_embedder_from_dict(spec):
except (ImportError, AttributeError, ValueError) as e:
raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
provider_config = spec.get("config", {})
if "config" in spec:
provider_config = spec["config"]
else:
provider_config = {k: v for k, v in spec.items() if k != "provider"}
if provider_name == "custom" and "embedding_callable" not in provider_config:
raise ValueError("Custom provider requires 'embedding_callable' in config")

View File

@@ -13,10 +13,10 @@ class GenerativeAiProviderConfig(TypedDict, total=False):
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
class GenerativeAiProviderSpec(TypedDict):
class GenerativeAiProviderSpec(TypedDict, total=False):
"""Google Generative AI provider specification."""
provider: Literal["google-generativeai"]
provider: Required[Literal["google-generativeai"]]
config: GenerativeAiProviderConfig

View File

@@ -242,3 +242,61 @@ class TestEmbeddingFactory:
mock_build_from_provider.assert_called_once_with(mock_provider)
assert result == mock_embedding_function
mock_import.assert_not_called()
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_generativeai_nested_config(self, mock_import):
"""Test building Google Generative AI embedder with nested config format."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-generativeai",
"config": {
"api_key": "test-gemini-key",
"model_name": "models/text-embedding-004",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-gemini-key"
assert call_kwargs["model_name"] == "models/text-embedding-004"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_generativeai_flat_config(self, mock_import):
"""Test building Google Generative AI embedder with flat config format (issue #3741)."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-generativeai",
"api_key": "test-gemini-key",
"model_name": "models/text-embedding-004",
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-gemini-key"
assert call_kwargs["model_name"] == "models/text-embedding-004"

View File

@@ -0,0 +1,107 @@
"""Tests for Google Generative AI embedder configuration (issue #3741)."""
from unittest.mock import MagicMock, patch
import pytest
from crewai import Agent, Crew, Task
class TestGoogleGenerativeAIEmbedder:
"""Test Google Generative AI embedder configuration formats."""
@patch("crewai.crew.Knowledge")
@patch("crewai.crew.ShortTermMemory")
@patch("crewai.crew.LongTermMemory")
@patch("crewai.crew.EntityMemory")
def test_crew_with_google_generativeai_flat_config(
self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
):
"""Test that Crew accepts google-generativeai embedder with flat config format (issue #3741)."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
)
task = Task(
description="Test task",
expected_output="Test output",
agent=agent,
)
embedder_config = {
"provider": "google-generativeai",
"api_key": "test-gemini-key",
"model_name": "models/text-embedding-004",
}
crew = Crew(
agents=[agent],
tasks=[task],
embedder=embedder_config,
)
expected_normalized_config = {
"provider": "google-generativeai",
"config": {
"api_key": "test-gemini-key",
"model_name": "models/text-embedding-004",
},
}
assert crew.embedder == expected_normalized_config
@patch("crewai.crew.Knowledge")
@patch("crewai.crew.ShortTermMemory")
@patch("crewai.crew.LongTermMemory")
@patch("crewai.crew.EntityMemory")
def test_crew_with_google_generativeai_nested_config(
self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
):
"""Test that Crew accepts google-generativeai embedder with nested config format."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
)
task = Task(
description="Test task",
expected_output="Test output",
agent=agent,
)
embedder_config = {
"provider": "google-generativeai",
"config": {
"api_key": "test-gemini-key",
"model_name": "models/text-embedding-004",
},
}
crew = Crew(
agents=[agent],
tasks=[task],
embedder=embedder_config,
)
assert crew.embedder == embedder_config
def test_generativeai_provider_spec_validation(self):
"""Test that GenerativeAiProviderSpec validates correctly with optional config."""
from crewai.rag.embeddings.types import GenerativeAiProviderSpec
flat_spec: GenerativeAiProviderSpec = {
"provider": "google-generativeai",
}
assert flat_spec["provider"] == "google-generativeai"
nested_spec: GenerativeAiProviderSpec = {
"provider": "google-generativeai",
"config": {
"api_key": "test-key",
"model_name": "models/text-embedding-004",
},
}
assert nested_spec["provider"] == "google-generativeai"
assert nested_spec["config"]["api_key"] == "test-key"

6817
uv.lock generated

File diff suppressed because it is too large Load Diff