mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-29 02:38:29 +00:00
Compare commits
2 Commits
1.2.0
...
devin/1760
| Author | SHA1 | Date | |
|---|---|---|---|
| | 32a013eb2f | | |
| | 36673f89e7 | | |
@@ -283,6 +283,30 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
"may_not_set_field", "The 'id' field cannot be set by the user.", {}
|
"may_not_set_field", "The 'id' field cannot be set by the user.", {}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@field_validator("embedder", mode="before")
@classmethod
def normalize_embedder_config(
    cls, v: dict[str, Any] | None
) -> dict[str, Any] | None:
    """Rewrite a flat embedder mapping into the nested provider/config shape.

    Args:
        v: The raw embedder value received by the validator.

    Returns:
        ``{"provider": ..., "config": {...}}`` when ``v`` is a dict that has a
        "provider" key, no "config" key, and at least one other key.
        In every other case ``v`` is returned unchanged.
    """
    # Anything that is not a dict (None included) is passed through as-is.
    if not isinstance(v, dict):
        return v

    # Already nested, or no provider to key on: nothing to restructure.
    if "config" in v or "provider" not in v:
        return v

    extras = {key: value for key, value in v.items() if key != "provider"}
    if not extras:
        # Provider-only mapping: keep it as given, without a "config" entry.
        return v

    return {"provider": v["provider"], "config": extras}
|
||||||
|
|
||||||
@field_validator("config", mode="before")
|
@field_validator("config", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
|
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
|
||||||
|
|||||||
@@ -228,8 +228,11 @@ def build_embedder_from_dict(spec):
|
|||||||
"""Build an embedding function instance from a dictionary specification.
|
"""Build an embedding function instance from a dictionary specification.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
spec: A dictionary with 'provider' and 'config' keys.
|
spec: A dictionary with 'provider' and optionally 'config' keys.
|
||||||
Example: {
|
Supports two formats:
|
||||||
|
|
||||||
|
Nested format (recommended):
|
||||||
|
{
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
"config": {
|
"config": {
|
||||||
"api_key": "sk-...",
|
"api_key": "sk-...",
|
||||||
@@ -237,6 +240,13 @@ def build_embedder_from_dict(spec):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Flat format (for backward compatibility):
|
||||||
|
{
|
||||||
|
"provider": "openai",
|
||||||
|
"api_key": "sk-...",
|
||||||
|
"model_name": "text-embedding-3-small"
|
||||||
|
}
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instance of the appropriate embedding function.
|
An instance of the appropriate embedding function.
|
||||||
|
|
||||||
@@ -266,7 +276,10 @@ def build_embedder_from_dict(spec):
|
|||||||
except (ImportError, AttributeError, ValueError) as e:
|
except (ImportError, AttributeError, ValueError) as e:
|
||||||
raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
|
raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
|
||||||
|
|
||||||
provider_config = spec.get("config", {})
|
if "config" in spec:
|
||||||
|
provider_config = spec["config"]
|
||||||
|
else:
|
||||||
|
provider_config = {k: v for k, v in spec.items() if k != "provider"}
|
||||||
|
|
||||||
if provider_name == "custom" and "embedding_callable" not in provider_config:
|
if provider_name == "custom" and "embedding_callable" not in provider_config:
|
||||||
raise ValueError("Custom provider requires 'embedding_callable' in config")
|
raise ValueError("Custom provider requires 'embedding_callable' in config")
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ class GenerativeAiProviderConfig(TypedDict, total=False):
|
|||||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||||
|
|
||||||
|
|
||||||
class GenerativeAiProviderSpec(TypedDict):
|
class GenerativeAiProviderSpec(TypedDict, total=False):
|
||||||
"""Google Generative AI provider specification."""
|
"""Google Generative AI provider specification."""
|
||||||
|
|
||||||
provider: Literal["google-generativeai"]
|
provider: Required[Literal["google-generativeai"]]
|
||||||
config: GenerativeAiProviderConfig
|
config: GenerativeAiProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -242,3 +242,61 @@ class TestEmbeddingFactory:
|
|||||||
mock_build_from_provider.assert_called_once_with(mock_provider)
|
mock_build_from_provider.assert_called_once_with(mock_provider)
|
||||||
assert result == mock_embedding_function
|
assert result == mock_embedding_function
|
||||||
mock_import.assert_not_called()
|
mock_import.assert_not_called()
|
||||||
|
|
||||||
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_generativeai_nested_config(self, mock_import):
    """Nested provider/config spec: config entries arrive as provider kwargs."""
    provider_cls = MagicMock()
    provider_obj = MagicMock()
    embedding_fn = MagicMock()

    # Wire the mocks: import -> provider class -> instance -> embedding fn.
    provider_obj.embedding_callable.return_value = embedding_fn
    provider_cls.return_value = provider_obj
    mock_import.return_value = provider_cls

    nested_spec = {
        "provider": "google-generativeai",
        "config": {
            "api_key": "test-gemini-key",
            "model_name": "models/text-embedding-004",
        },
    }

    build_embedder(nested_spec)

    mock_import.assert_called_once_with(
        "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
    )
    provider_cls.assert_called_once()

    passed_kwargs = provider_cls.call_args.kwargs
    assert passed_kwargs["api_key"] == "test-gemini-key"
    assert passed_kwargs["model_name"] == "models/text-embedding-004"
|
||||||
|
|
||||||
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_generativeai_flat_config(self, mock_import):
    """Flat spec (issue #3741): non-provider keys arrive as provider kwargs."""
    provider_cls = MagicMock()
    provider_obj = MagicMock()
    embedding_fn = MagicMock()

    # Wire the mocks: import -> provider class -> instance -> embedding fn.
    provider_obj.embedding_callable.return_value = embedding_fn
    provider_cls.return_value = provider_obj
    mock_import.return_value = provider_cls

    flat_spec = {
        "provider": "google-generativeai",
        "api_key": "test-gemini-key",
        "model_name": "models/text-embedding-004",
    }

    build_embedder(flat_spec)

    mock_import.assert_called_once_with(
        "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
    )
    provider_cls.assert_called_once()

    passed_kwargs = provider_cls.call_args.kwargs
    assert passed_kwargs["api_key"] == "test-gemini-key"
    assert passed_kwargs["model_name"] == "models/text-embedding-004"
|
||||||
|
|||||||
107
tests/rag/embeddings/test_google_generativeai_embedder.py
Normal file
107
tests/rag/embeddings/test_google_generativeai_embedder.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
"""Tests for Google Generative AI embedder configuration (issue #3741)."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai import Agent, Crew, Task
|
||||||
|
|
||||||
|
|
||||||
|
class TestGoogleGenerativeAIEmbedder:
    """Exercise the embedder config shapes accepted for google-generativeai."""

    @staticmethod
    def _crew_with_embedder(embedder):
        """Build a one-agent, one-task Crew configured with ``embedder``."""
        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
        )
        task = Task(
            description="Test task",
            expected_output="Test output",
            agent=agent,
        )
        return Crew(agents=[agent], tasks=[task], embedder=embedder)

    @patch("crewai.crew.Knowledge")
    @patch("crewai.crew.ShortTermMemory")
    @patch("crewai.crew.LongTermMemory")
    @patch("crewai.crew.EntityMemory")
    def test_crew_with_google_generativeai_flat_config(
        self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
    ):
        """Flat embedder config is exposed by Crew in nested form (issue #3741)."""
        flat_config = {
            "provider": "google-generativeai",
            "api_key": "test-gemini-key",
            "model_name": "models/text-embedding-004",
        }

        crew = self._crew_with_embedder(flat_config)

        # Every non-provider key is expected under "config" after normalization.
        assert crew.embedder == {
            "provider": "google-generativeai",
            "config": {
                "api_key": "test-gemini-key",
                "model_name": "models/text-embedding-004",
            },
        }

    @patch("crewai.crew.Knowledge")
    @patch("crewai.crew.ShortTermMemory")
    @patch("crewai.crew.LongTermMemory")
    @patch("crewai.crew.EntityMemory")
    def test_crew_with_google_generativeai_nested_config(
        self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
    ):
        """Nested embedder config is exposed by Crew unchanged."""
        nested_config = {
            "provider": "google-generativeai",
            "config": {
                "api_key": "test-gemini-key",
                "model_name": "models/text-embedding-004",
            },
        }

        crew = self._crew_with_embedder(nested_config)

        assert crew.embedder == nested_config

    def test_generativeai_provider_spec_validation(self):
        """GenerativeAiProviderSpec dicts can be written with or without "config"."""
        from crewai.rag.embeddings.types import GenerativeAiProviderSpec

        provider_only: GenerativeAiProviderSpec = {
            "provider": "google-generativeai",
        }
        assert provider_only["provider"] == "google-generativeai"

        with_config: GenerativeAiProviderSpec = {
            "provider": "google-generativeai",
            "config": {
                "api_key": "test-key",
                "model_name": "models/text-embedding-004",
            },
        }
        assert with_config["provider"] == "google-generativeai"
        assert with_config["config"]["api_key"] == "test-key"
|
||||||
Reference in New Issue
Block a user