Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-11 17:18:29 +00:00)
This commit fixes the validation error that occurred when using the
google-generativeai embedder provider with a flat configuration format.
Changes:
1. Made the 'config' field optional in GenerativeAiProviderSpec by adding
'total=False' and marking 'provider' as Required, consistent with other
provider specs like VertexAIProviderSpec.
2. Added normalization in the Crew class to automatically convert flat
embedder configs to nested format before validation. This allows users
to use either format:
- Flat: {'provider': 'google-generativeai', 'api_key': '...', 'model_name': '...'}
- Nested: {'provider': 'google-generativeai', 'config': {'api_key': '...', 'model_name': '...'}}
3. Updated the embedder factory to support both flat and nested config
formats by checking for the presence of a 'config' key and extracting
the config fields accordingly.
4. Added comprehensive tests to verify both formats work correctly:
- Test for flat config format (the issue reported in #3741)
- Test for nested config format (recommended format)
- Test for TypedDict validation
Fixes #3741
Co-Authored-By: João <joao@crewai.com>
108 lines · 3.3 KiB · Python
"""Tests for Google Generative AI embedder configuration (issue #3741)."""
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from crewai import Agent, Crew, Task
|
|
|
|
|
|
class TestGoogleGenerativeAIEmbedder:
    """Test Google Generative AI embedder configuration formats.

    ``Crew`` must accept both the flat form (credentials next to
    ``provider``) and the nested form (credentials under ``config``). The
    flat form is expected to be normalized into the nested one.
    """

    @staticmethod
    def _build_crew(embedder_config):
        """Build a one-agent, one-task ``Crew`` using *embedder_config*.

        The two Crew-level tests share this helper, so each one differs only
        in the embedder configuration under test.
        """
        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
        )
        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=agent,
        )
        return Crew(agents=[agent], tasks=[task], embedder=embedder_config)

    # The memory/knowledge classes referenced from ``crewai.crew`` are
    # replaced with mocks while the test runs (presumably so that building a
    # Crew with an embedder does not set up real storage; the mocks are not
    # asserted on). Stacked @patch decorators apply bottom-up, so the
    # EntityMemory mock arrives as the first mock argument.
    @patch("crewai.crew.Knowledge")
    @patch("crewai.crew.ShortTermMemory")
    @patch("crewai.crew.LongTermMemory")
    @patch("crewai.crew.EntityMemory")
    def test_crew_with_google_generativeai_flat_config(
        self,
        mock_entity_memory,
        mock_long_term_memory,
        mock_short_term_memory,
        mock_knowledge,
    ):
        """Test that Crew accepts google-generativeai embedder with flat config format (issue #3741)."""
        embedder_config = {
            "provider": "google-generativeai",
            "api_key": "test-gemini-key",
            "model_name": "models/text-embedding-004",
        }

        crew = self._build_crew(embedder_config)

        # Flat credentials must be moved under a nested "config" key.
        expected_normalized_config = {
            "provider": "google-generativeai",
            "config": {
                "api_key": "test-gemini-key",
                "model_name": "models/text-embedding-004",
            },
        }
        assert crew.embedder == expected_normalized_config

    # Same patches (and argument order) as the flat-config test above.
    @patch("crewai.crew.Knowledge")
    @patch("crewai.crew.ShortTermMemory")
    @patch("crewai.crew.LongTermMemory")
    @patch("crewai.crew.EntityMemory")
    def test_crew_with_google_generativeai_nested_config(
        self,
        mock_entity_memory,
        mock_long_term_memory,
        mock_short_term_memory,
        mock_knowledge,
    ):
        """Test that Crew accepts google-generativeai embedder with nested config format."""
        embedder_config = {
            "provider": "google-generativeai",
            "config": {
                "api_key": "test-gemini-key",
                "model_name": "models/text-embedding-004",
            },
        }

        crew = self._build_crew(embedder_config)

        # An already-nested config must pass through unchanged.
        assert crew.embedder == embedder_config

    def test_generativeai_provider_spec_validation(self):
        """Test that GenerativeAiProviderSpec requires ``provider`` and makes ``config`` optional.

        TypedDict annotations are not enforced at runtime, so assigning dict
        literals alone cannot fail. The spec's declared required and optional
        key sets are inspected to pin the schema itself.
        """
        from crewai.rag.embeddings.types import GenerativeAiProviderSpec

        assert "provider" in GenerativeAiProviderSpec.__required_keys__
        assert "config" in GenerativeAiProviderSpec.__optional_keys__

        # A spec with only the required key is a valid shape.
        minimal_spec: GenerativeAiProviderSpec = {
            "provider": "google-generativeai",
        }
        assert minimal_spec["provider"] == "google-generativeai"

        nested_spec: GenerativeAiProviderSpec = {
            "provider": "google-generativeai",
            "config": {
                "api_key": "test-key",
                "model_name": "models/text-embedding-004",
            },
        }
        assert nested_spec["provider"] == "google-generativeai"
        assert nested_spec["config"]["api_key"] == "test-key"