mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Compare commits
2 Commits
1.2.0
...
devin/1760
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
32a013eb2f | ||
|
|
36673f89e7 |
@@ -283,6 +283,30 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"may_not_set_field", "The 'id' field cannot be set by the user.", {}
|
||||
)
|
||||
|
||||
@field_validator("embedder", mode="before")
|
||||
@classmethod
|
||||
def normalize_embedder_config(
|
||||
cls, v: dict[str, Any] | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Normalize embedder config to support both flat and nested formats.
|
||||
|
||||
Args:
|
||||
v: The embedder config to be normalized.
|
||||
|
||||
Returns:
|
||||
The normalized embedder config with nested structure.
|
||||
"""
|
||||
if v is None or not isinstance(v, dict):
|
||||
return v
|
||||
|
||||
if "provider" in v and "config" not in v:
|
||||
provider = v["provider"]
|
||||
config_fields = {k: val for k, val in v.items() if k != "provider"}
|
||||
if config_fields:
|
||||
return {"provider": provider, "config": config_fields}
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
|
||||
|
||||
@@ -228,8 +228,11 @@ def build_embedder_from_dict(spec):
|
||||
"""Build an embedding function instance from a dictionary specification.
|
||||
|
||||
Args:
|
||||
spec: A dictionary with 'provider' and 'config' keys.
|
||||
Example: {
|
||||
spec: A dictionary with 'provider' and optionally 'config' keys.
|
||||
Supports two formats:
|
||||
|
||||
Nested format (recommended):
|
||||
{
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "sk-...",
|
||||
@@ -237,6 +240,13 @@ def build_embedder_from_dict(spec):
|
||||
}
|
||||
}
|
||||
|
||||
Flat format (for backward compatibility):
|
||||
{
|
||||
"provider": "openai",
|
||||
"api_key": "sk-...",
|
||||
"model_name": "text-embedding-3-small"
|
||||
}
|
||||
|
||||
Returns:
|
||||
An instance of the appropriate embedding function.
|
||||
|
||||
@@ -266,7 +276,10 @@ def build_embedder_from_dict(spec):
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
|
||||
|
||||
provider_config = spec.get("config", {})
|
||||
if "config" in spec:
|
||||
provider_config = spec["config"]
|
||||
else:
|
||||
provider_config = {k: v for k, v in spec.items() if k != "provider"}
|
||||
|
||||
if provider_name == "custom" and "embedding_callable" not in provider_config:
|
||||
raise ValueError("Custom provider requires 'embedding_callable' in config")
|
||||
|
||||
@@ -13,10 +13,10 @@ class GenerativeAiProviderConfig(TypedDict, total=False):
|
||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
|
||||
|
||||
class GenerativeAiProviderSpec(TypedDict):
|
||||
class GenerativeAiProviderSpec(TypedDict, total=False):
|
||||
"""Google Generative AI provider specification."""
|
||||
|
||||
provider: Literal["google-generativeai"]
|
||||
provider: Required[Literal["google-generativeai"]]
|
||||
config: GenerativeAiProviderConfig
|
||||
|
||||
|
||||
|
||||
@@ -242,3 +242,61 @@ class TestEmbeddingFactory:
|
||||
mock_build_from_provider.assert_called_once_with(mock_provider)
|
||||
assert result == mock_embedding_function
|
||||
mock_import.assert_not_called()
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_google_generativeai_nested_config(self, mock_import):
|
||||
"""Test building Google Generative AI embedder with nested config format."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "google-generativeai",
|
||||
"config": {
|
||||
"api_key": "test-gemini-key",
|
||||
"model_name": "models/text-embedding-004",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "test-gemini-key"
|
||||
assert call_kwargs["model_name"] == "models/text-embedding-004"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_google_generativeai_flat_config(self, mock_import):
|
||||
"""Test building Google Generative AI embedder with flat config format (issue #3741)."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "google-generativeai",
|
||||
"api_key": "test-gemini-key",
|
||||
"model_name": "models/text-embedding-004",
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "test-gemini-key"
|
||||
assert call_kwargs["model_name"] == "models/text-embedding-004"
|
||||
|
||||
107
tests/rag/embeddings/test_google_generativeai_embedder.py
Normal file
107
tests/rag/embeddings/test_google_generativeai_embedder.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Tests for Google Generative AI embedder configuration (issue #3741)."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
|
||||
class TestGoogleGenerativeAIEmbedder:
|
||||
"""Test Google Generative AI embedder configuration formats."""
|
||||
|
||||
@patch("crewai.crew.Knowledge")
|
||||
@patch("crewai.crew.ShortTermMemory")
|
||||
@patch("crewai.crew.LongTermMemory")
|
||||
@patch("crewai.crew.EntityMemory")
|
||||
def test_crew_with_google_generativeai_flat_config(
|
||||
self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
|
||||
):
|
||||
"""Test that Crew accepts google-generativeai embedder with flat config format (issue #3741)."""
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
embedder_config = {
|
||||
"provider": "google-generativeai",
|
||||
"api_key": "test-gemini-key",
|
||||
"model_name": "models/text-embedding-004",
|
||||
}
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
embedder=embedder_config,
|
||||
)
|
||||
|
||||
expected_normalized_config = {
|
||||
"provider": "google-generativeai",
|
||||
"config": {
|
||||
"api_key": "test-gemini-key",
|
||||
"model_name": "models/text-embedding-004",
|
||||
},
|
||||
}
|
||||
assert crew.embedder == expected_normalized_config
|
||||
|
||||
@patch("crewai.crew.Knowledge")
|
||||
@patch("crewai.crew.ShortTermMemory")
|
||||
@patch("crewai.crew.LongTermMemory")
|
||||
@patch("crewai.crew.EntityMemory")
|
||||
def test_crew_with_google_generativeai_nested_config(
|
||||
self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
|
||||
):
|
||||
"""Test that Crew accepts google-generativeai embedder with nested config format."""
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
embedder_config = {
|
||||
"provider": "google-generativeai",
|
||||
"config": {
|
||||
"api_key": "test-gemini-key",
|
||||
"model_name": "models/text-embedding-004",
|
||||
},
|
||||
}
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
embedder=embedder_config,
|
||||
)
|
||||
|
||||
assert crew.embedder == embedder_config
|
||||
|
||||
def test_generativeai_provider_spec_validation(self):
|
||||
"""Test that GenerativeAiProviderSpec validates correctly with optional config."""
|
||||
from crewai.rag.embeddings.types import GenerativeAiProviderSpec
|
||||
|
||||
flat_spec: GenerativeAiProviderSpec = {
|
||||
"provider": "google-generativeai",
|
||||
}
|
||||
assert flat_spec["provider"] == "google-generativeai"
|
||||
|
||||
nested_spec: GenerativeAiProviderSpec = {
|
||||
"provider": "google-generativeai",
|
||||
"config": {
|
||||
"api_key": "test-key",
|
||||
"model_name": "models/text-embedding-004",
|
||||
},
|
||||
}
|
||||
assert nested_spec["provider"] == "google-generativeai"
|
||||
assert nested_spec["config"]["api_key"] == "test-key"
|
||||
Reference in New Issue
Block a user