mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 17:18:13 +00:00
Lorenze/supporting vertex embeddings (#4282)
* feat: introduce GoogleGenAIVertexEmbeddingFunction for dual SDK support - Added a new embedding function to support both the legacy vertexai.language_models SDK and the new google-genai SDK for Google Vertex AI. - Updated factory methods to route to the new embedding function. - Enhanced VertexAIProvider and related configurations to accommodate the new model options. - Added integration tests for Google Vertex embeddings with Crew memory, ensuring compatibility and functionality with both authentication methods. This update improves the flexibility and compatibility of Google Vertex AI embeddings within the CrewAI framework. * fix test count * rm comment * regen cassettes * regen * drop variable from .envtest * dreict to relevant trest only
This commit is contained in:
@@ -18,7 +18,6 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
@@ -52,6 +51,9 @@ if TYPE_CHECKING:
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
|
||||
GoogleGenAIVertexEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderSpec,
|
||||
@@ -163,7 +165,7 @@ def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunctio
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: VertexAIProviderSpec,
|
||||
) -> GoogleVertexEmbeddingFunction: ...
|
||||
) -> GoogleGenAIVertexEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -296,7 +298,9 @@ def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
|
||||
def build_embedder(
|
||||
spec: VertexAIProviderSpec,
|
||||
) -> GoogleGenAIVertexEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Google embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
|
||||
GoogleGenAIVertexEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.generative_ai import (
|
||||
GenerativeAiProvider,
|
||||
)
|
||||
@@ -18,6 +21,7 @@ __all__ = [
|
||||
"GenerativeAiProvider",
|
||||
"GenerativeAiProviderConfig",
|
||||
"GenerativeAiProviderSpec",
|
||||
"GoogleGenAIVertexEmbeddingFunction",
|
||||
"VertexAIProvider",
|
||||
"VertexAIProviderConfig",
|
||||
"VertexAIProviderSpec",
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
"""Google Vertex AI embedding function implementation.
|
||||
|
||||
This module supports both the new google-genai SDK and the deprecated
|
||||
vertexai.language_models module for backwards compatibility.
|
||||
|
||||
The deprecated vertexai.language_models module will be removed after June 24, 2026.
|
||||
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar, cast
|
||||
import warnings
|
||||
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.embeddings.providers.google.types import VertexAIProviderConfig
|
||||
|
||||
|
||||
class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"""Embedding function for Google Vertex AI with dual SDK support.
|
||||
|
||||
This class supports both:
|
||||
- Legacy models (textembedding-gecko*) using the deprecated vertexai.language_models SDK
|
||||
- New models (gemini-embedding-*, text-embedding-*) using the google-genai SDK
|
||||
|
||||
The SDK is automatically selected based on the model name. Legacy models will
|
||||
emit a deprecation warning.
|
||||
|
||||
Supports two authentication modes:
|
||||
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
|
||||
2. API key: Set api_key for direct API access
|
||||
|
||||
Example:
|
||||
# Using legacy model (will emit deprecation warning)
|
||||
embedder = GoogleGenAIVertexEmbeddingFunction(
|
||||
project_id="my-project",
|
||||
region="us-central1",
|
||||
model_name="textembedding-gecko"
|
||||
)
|
||||
|
||||
# Using new model with google-genai SDK
|
||||
embedder = GoogleGenAIVertexEmbeddingFunction(
|
||||
project_id="my-project",
|
||||
location="us-central1",
|
||||
model_name="gemini-embedding-001"
|
||||
)
|
||||
|
||||
# Using API key (new SDK only)
|
||||
embedder = GoogleGenAIVertexEmbeddingFunction(
|
||||
api_key="your-api-key",
|
||||
model_name="gemini-embedding-001"
|
||||
)
|
||||
"""
|
||||
|
||||
# Models that use the legacy vertexai.language_models SDK
|
||||
LEGACY_MODELS: ClassVar[set[str]] = {
|
||||
"textembedding-gecko",
|
||||
"textembedding-gecko@001",
|
||||
"textembedding-gecko@002",
|
||||
"textembedding-gecko@003",
|
||||
"textembedding-gecko@latest",
|
||||
"textembedding-gecko-multilingual",
|
||||
"textembedding-gecko-multilingual@001",
|
||||
"textembedding-gecko-multilingual@latest",
|
||||
}
|
||||
|
||||
# Models that use the new google-genai SDK
|
||||
GENAI_MODELS: ClassVar[set[str]] = {
|
||||
"gemini-embedding-001",
|
||||
"text-embedding-005",
|
||||
"text-multilingual-embedding-002",
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs: Unpack[VertexAIProviderConfig]) -> None:
|
||||
"""Initialize Google Vertex AI embedding function.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration parameters including:
|
||||
- model_name: Model to use for embeddings (default: "textembedding-gecko")
|
||||
- api_key: Optional API key for authentication (new SDK only)
|
||||
- project_id: GCP project ID (for Vertex AI backend)
|
||||
- location: GCP region (default: "us-central1")
|
||||
- region: Deprecated alias for location
|
||||
- task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT", new SDK only)
|
||||
- output_dimensionality: Optional output embedding dimension (new SDK only)
|
||||
"""
|
||||
# Handle deprecated 'region' parameter (only if it has a value)
|
||||
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
|
||||
if region_value is not None:
|
||||
warnings.warn(
|
||||
"The 'region' parameter is deprecated, use 'location' instead. "
|
||||
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if "location" not in kwargs or kwargs.get("location") is None:
|
||||
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
|
||||
|
||||
self._config = kwargs
|
||||
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
|
||||
self._use_legacy = self._is_legacy_model(self._model_name)
|
||||
|
||||
if self._use_legacy:
|
||||
self._init_legacy_client(**kwargs)
|
||||
else:
|
||||
self._init_genai_client(**kwargs)
|
||||
|
||||
def _is_legacy_model(self, model_name: str) -> bool:
|
||||
"""Check if the model uses the legacy SDK."""
|
||||
return model_name in self.LEGACY_MODELS or model_name.startswith(
|
||||
"textembedding-gecko"
|
||||
)
|
||||
|
||||
def _init_legacy_client(self, **kwargs: Any) -> None:
|
||||
"""Initialize using the deprecated vertexai.language_models SDK."""
|
||||
warnings.warn(
|
||||
f"Model '{self._model_name}' uses the deprecated vertexai.language_models SDK "
|
||||
"which will be removed after June 24, 2026. Consider migrating to newer models "
|
||||
"like 'gemini-embedding-001'. "
|
||||
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
try:
|
||||
import vertexai
|
||||
from vertexai.language_models import TextEmbeddingModel
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"vertexai is required for legacy embedding models (textembedding-gecko*). "
|
||||
"Install it with: pip install google-cloud-aiplatform"
|
||||
) from e
|
||||
|
||||
project_id = kwargs.get("project_id")
|
||||
location = str(kwargs.get("location", "us-central1"))
|
||||
|
||||
if not project_id:
|
||||
raise ValueError(
|
||||
"project_id is required for legacy models. "
|
||||
"For API key authentication, use newer models like 'gemini-embedding-001'."
|
||||
)
|
||||
|
||||
vertexai.init(project=str(project_id), location=location)
|
||||
self._legacy_model = TextEmbeddingModel.from_pretrained(self._model_name)
|
||||
|
||||
def _init_genai_client(self, **kwargs: Any) -> None:
|
||||
"""Initialize using the new google-genai SDK."""
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai.types import EmbedContentConfig
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"google-genai is required for Google Gen AI embeddings. "
|
||||
"Install it with: uv add 'crewai[google-genai]'"
|
||||
) from e
|
||||
|
||||
self._genai = genai
|
||||
self._EmbedContentConfig = EmbedContentConfig
|
||||
self._task_type = kwargs.get("task_type", "RETRIEVAL_DOCUMENT")
|
||||
self._output_dimensionality = kwargs.get("output_dimensionality")
|
||||
|
||||
# Initialize client based on authentication mode
|
||||
api_key = kwargs.get("api_key")
|
||||
project_id = kwargs.get("project_id")
|
||||
location: str = str(kwargs.get("location", "us-central1"))
|
||||
|
||||
if api_key:
|
||||
self._client = genai.Client(api_key=api_key)
|
||||
elif project_id:
|
||||
self._client = genai.Client(
|
||||
vertexai=True,
|
||||
project=str(project_id),
|
||||
location=location,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either 'api_key' (for API key authentication) or 'project_id' "
|
||||
"(for Vertex AI backend with ADC) must be provided."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||
return "google-vertex"
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""Generate embeddings for input documents.
|
||||
|
||||
Args:
|
||||
input: List of documents to embed.
|
||||
|
||||
Returns:
|
||||
List of embedding vectors.
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
if self._use_legacy:
|
||||
return self._call_legacy(input)
|
||||
return self._call_genai(input)
|
||||
|
||||
def _call_legacy(self, input: list[str]) -> Embeddings:
|
||||
"""Generate embeddings using the legacy SDK."""
|
||||
import numpy as np
|
||||
|
||||
embeddings_list = []
|
||||
for text in input:
|
||||
embedding_result = self._legacy_model.get_embeddings([text])
|
||||
embeddings_list.append(
|
||||
np.array(embedding_result[0].values, dtype=np.float32)
|
||||
)
|
||||
|
||||
return cast(Embeddings, embeddings_list)
|
||||
|
||||
def _call_genai(self, input: list[str]) -> Embeddings:
|
||||
"""Generate embeddings using the new google-genai SDK."""
|
||||
# Build config for embed_content
|
||||
config_kwargs: dict[str, Any] = {
|
||||
"task_type": self._task_type,
|
||||
}
|
||||
if self._output_dimensionality is not None:
|
||||
config_kwargs["output_dimensionality"] = self._output_dimensionality
|
||||
|
||||
config = self._EmbedContentConfig(**config_kwargs)
|
||||
|
||||
# Call the embedding API
|
||||
response = self._client.models.embed_content(
|
||||
model=self._model_name,
|
||||
contents=input, # type: ignore[arg-type]
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Extract embeddings from response
|
||||
if response.embeddings is None:
|
||||
raise ValueError("No embeddings returned from the API")
|
||||
embeddings = [emb.values for emb in response.embeddings]
|
||||
return cast(Embeddings, embeddings)
|
||||
@@ -34,12 +34,47 @@ class GenerativeAiProviderSpec(TypedDict):
|
||||
|
||||
|
||||
class VertexAIProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Vertex AI provider."""
|
||||
"""Configuration for Vertex AI provider with dual SDK support.
|
||||
|
||||
Supports both legacy models (textembedding-gecko*) using the deprecated
|
||||
vertexai.language_models SDK and new models using google-genai SDK.
|
||||
|
||||
Attributes:
|
||||
api_key: Google API key (optional if using project_id with ADC). Only for new SDK models.
|
||||
model_name: Embedding model name (default: "textembedding-gecko").
|
||||
Legacy models: textembedding-gecko, textembedding-gecko@001, etc.
|
||||
New models: gemini-embedding-001, text-embedding-005, text-multilingual-embedding-002
|
||||
project_id: GCP project ID (required for Vertex AI backend and legacy models).
|
||||
location: GCP region/location (default: "us-central1").
|
||||
region: Deprecated alias for location (kept for backwards compatibility).
|
||||
task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT"). Only for new SDK models.
|
||||
output_dimensionality: Output embedding dimension (optional). Only for new SDK models.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "textembedding-gecko"]
|
||||
project_id: Annotated[str, "cloud-large-language-models"]
|
||||
region: Annotated[str, "us-central1"]
|
||||
model_name: Annotated[
|
||||
Literal[
|
||||
# Legacy models (deprecated vertexai.language_models SDK)
|
||||
"textembedding-gecko",
|
||||
"textembedding-gecko@001",
|
||||
"textembedding-gecko@002",
|
||||
"textembedding-gecko@003",
|
||||
"textembedding-gecko@latest",
|
||||
"textembedding-gecko-multilingual",
|
||||
"textembedding-gecko-multilingual@001",
|
||||
"textembedding-gecko-multilingual@latest",
|
||||
# New models (google-genai SDK)
|
||||
"gemini-embedding-001",
|
||||
"text-embedding-005",
|
||||
"text-multilingual-embedding-002",
|
||||
],
|
||||
"textembedding-gecko",
|
||||
]
|
||||
project_id: str
|
||||
location: Annotated[str, "us-central1"]
|
||||
region: Annotated[str, "us-central1"] # Deprecated alias for location
|
||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
output_dimensionality: int
|
||||
|
||||
|
||||
class VertexAIProviderSpec(TypedDict, total=False):
|
||||
|
||||
@@ -1,46 +1,126 @@
|
||||
"""Google Vertex AI embeddings provider."""
|
||||
"""Google Vertex AI embeddings provider.
|
||||
|
||||
This module supports both the new google-genai SDK and the deprecated
|
||||
vertexai.language_models module for backwards compatibility.
|
||||
|
||||
The SDK is automatically selected based on the model name:
|
||||
- Legacy models (textembedding-gecko*) use vertexai.language_models (deprecated)
|
||||
- New models (gemini-embedding-*, text-embedding-*) use google-genai
|
||||
|
||||
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
|
||||
GoogleGenAIVertexEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
|
||||
"""Google Vertex AI embeddings provider."""
|
||||
class VertexAIProvider(BaseEmbeddingsProvider[GoogleGenAIVertexEmbeddingFunction]):
|
||||
"""Google Vertex AI embeddings provider with dual SDK support.
|
||||
|
||||
embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
|
||||
default=GoogleVertexEmbeddingFunction,
|
||||
description="Vertex AI embedding function class",
|
||||
Supports both legacy models (textembedding-gecko*) using the deprecated
|
||||
vertexai.language_models SDK and new models (gemini-embedding-*, text-embedding-*)
|
||||
using the google-genai SDK.
|
||||
|
||||
The SDK is automatically selected based on the model name. Legacy models will
|
||||
emit a deprecation warning.
|
||||
|
||||
Authentication modes:
|
||||
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
|
||||
2. API key: Set api_key for direct API access (new SDK models only)
|
||||
|
||||
Example:
|
||||
# Legacy model (backwards compatible, will emit deprecation warning)
|
||||
provider = VertexAIProvider(
|
||||
project_id="my-project",
|
||||
region="us-central1", # or location="us-central1"
|
||||
model_name="textembedding-gecko"
|
||||
)
|
||||
|
||||
# New model with Vertex AI backend
|
||||
provider = VertexAIProvider(
|
||||
project_id="my-project",
|
||||
location="us-central1",
|
||||
model_name="gemini-embedding-001"
|
||||
)
|
||||
|
||||
# New model with API key
|
||||
provider = VertexAIProvider(
|
||||
api_key="your-api-key",
|
||||
model_name="gemini-embedding-001"
|
||||
)
|
||||
"""
|
||||
|
||||
embedding_callable: type[GoogleGenAIVertexEmbeddingFunction] = Field(
|
||||
default=GoogleGenAIVertexEmbeddingFunction,
|
||||
description="Google Vertex AI embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="textembedding-gecko",
|
||||
description="Model name to use for embeddings",
|
||||
description=(
|
||||
"Model name to use for embeddings. Legacy models (textembedding-gecko*) "
|
||||
"use the deprecated SDK. New models (gemini-embedding-001, text-embedding-005) "
|
||||
"use the google-genai SDK."
|
||||
),
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
"GOOGLE_VERTEX_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key",
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="Google API key (optional if using project_id with Application Default Credentials)",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY",
|
||||
"GOOGLE_CLOUD_API_KEY",
|
||||
"GOOGLE_API_KEY",
|
||||
),
|
||||
)
|
||||
project_id: str = Field(
|
||||
default="cloud-large-language-models",
|
||||
description="GCP project ID",
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="GCP project ID (required for Vertex AI backend and legacy models)",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
|
||||
"GOOGLE_CLOUD_PROJECT",
|
||||
),
|
||||
)
|
||||
region: str = Field(
|
||||
location: str = Field(
|
||||
default="us-central1",
|
||||
description="GCP region",
|
||||
description="GCP region/location",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_LOCATION",
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_REGION",
|
||||
"GOOGLE_CLOUD_LOCATION",
|
||||
"GOOGLE_CLOUD_REGION",
|
||||
),
|
||||
)
|
||||
region: str | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use 'location' instead. GCP region (kept for backwards compatibility)",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_REGION",
|
||||
"GOOGLE_VERTEX_REGION",
|
||||
),
|
||||
)
|
||||
task_type: str = Field(
|
||||
default="RETRIEVAL_DOCUMENT",
|
||||
description="Task type for embeddings (e.g., RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY). Only used with new SDK models.",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_TASK_TYPE",
|
||||
"GOOGLE_VERTEX_TASK_TYPE",
|
||||
),
|
||||
)
|
||||
output_dimensionality: int | None = Field(
|
||||
default=None,
|
||||
description="Output embedding dimensionality (optional). Only used with new SDK models.",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
|
||||
"GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
|
||||
),
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -272,3 +272,100 @@ class TestEmbeddingFactory:
|
||||
mock_build_from_provider.assert_called_once_with(mock_provider)
|
||||
assert result == mock_embedding_function
|
||||
mock_import.assert_not_called()
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_google_vertex_with_genai_model(self, mock_import):
|
||||
"""Test routing to Google Vertex provider with new genai model."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"api_key": "test-google-api-key",
|
||||
"model_name": "gemini-embedding-001",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "test-google-api-key"
|
||||
assert call_kwargs["model_name"] == "gemini-embedding-001"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_google_vertex_with_legacy_model(self, mock_import):
|
||||
"""Test routing to Google Vertex provider with legacy textembedding-gecko model."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"project_id": "my-gcp-project",
|
||||
"region": "us-central1",
|
||||
"model_name": "textembedding-gecko",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["project_id"] == "my-gcp-project"
|
||||
assert call_kwargs["region"] == "us-central1"
|
||||
assert call_kwargs["model_name"] == "textembedding-gecko"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_google_vertex_with_location(self, mock_import):
|
||||
"""Test routing to Google Vertex provider with location parameter."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"project_id": "my-gcp-project",
|
||||
"location": "europe-west1",
|
||||
"model_name": "gemini-embedding-001",
|
||||
"task_type": "RETRIEVAL_DOCUMENT",
|
||||
"output_dimensionality": 768,
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
|
||||
)
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["project_id"] == "my-gcp-project"
|
||||
assert call_kwargs["location"] == "europe-west1"
|
||||
assert call_kwargs["model_name"] == "gemini-embedding-001"
|
||||
assert call_kwargs["task_type"] == "RETRIEVAL_DOCUMENT"
|
||||
assert call_kwargs["output_dimensionality"] == 768
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
"""Integration tests for Google Vertex embeddings with Crew memory.
|
||||
|
||||
These tests make real API calls and use VCR to record/replay responses.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_vertex_ai_env():
|
||||
"""Set up environment for Vertex AI tests.
|
||||
|
||||
Sets GOOGLE_GENAI_USE_VERTEXAI=true to ensure the SDK uses the Vertex AI
|
||||
backend (aiplatform.googleapis.com) which matches the VCR cassettes.
|
||||
Also mocks GOOGLE_API_KEY if not already set.
|
||||
"""
|
||||
env_updates = {"GOOGLE_GENAI_USE_VERTEXAI": "true"}
|
||||
|
||||
# Add a mock API key if none exists
|
||||
if "GOOGLE_API_KEY" not in os.environ and "GEMINI_API_KEY" not in os.environ:
|
||||
env_updates["GOOGLE_API_KEY"] = "test-key"
|
||||
|
||||
with patch.dict(os.environ, env_updates):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def google_vertex_embedder_config():
|
||||
"""Fixture providing Google Vertex embedder configuration."""
|
||||
return {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"api_key": os.getenv("GOOGLE_API_KEY", "test-key"),
|
||||
"model_name": "gemini-embedding-001",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_agent():
|
||||
"""Fixture providing a simple test agent."""
|
||||
return Agent(
|
||||
role="Research Assistant",
|
||||
goal="Help with research tasks",
|
||||
backstory="You are a helpful research assistant.",
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_task(simple_agent):
|
||||
"""Fixture providing a simple test task."""
|
||||
return Task(
|
||||
description="Summarize the key points about artificial intelligence in one sentence.",
|
||||
expected_output="A one sentence summary about AI.",
|
||||
agent=simple_agent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@pytest.mark.timeout(120) # Longer timeout for VCR recording
|
||||
def test_crew_memory_with_google_vertex_embedder(
|
||||
google_vertex_embedder_config, simple_agent, simple_task
|
||||
) -> None:
|
||||
"""Test that Crew with memory=True works with google-vertex embedder and memory is used."""
|
||||
# Track memory events
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
crew = Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
memory=True,
|
||||
embedder=google_vertex_embedder_config,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
assert len(result.raw) > 0
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
assert success, "Timeout waiting for memory save events - memory may not be working"
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1, "Memory save completed events"
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@pytest.mark.timeout(120)
|
||||
def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) -> None:
|
||||
"""Test Crew memory with Google Vertex using project_id authentication."""
|
||||
project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
if not project_id:
|
||||
pytest.skip("GOOGLE_CLOUD_PROJECT environment variable not set")
|
||||
|
||||
# Track memory events
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
embedder_config = {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"project_id": project_id,
|
||||
"location": "us-central1",
|
||||
"model_name": "gemini-embedding-001",
|
||||
},
|
||||
}
|
||||
|
||||
crew = Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
memory=True,
|
||||
embedder=embedder_config,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
# Verify basic result
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
|
||||
# Wait for memory save events
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# Verify memory was actually used
|
||||
assert success, "Timeout waiting for memory save events - memory may not be working"
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1, "No memory save completed events"
|
||||
Reference in New Issue
Block a user