Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-27 00:58:13 +00:00
Lorenze/supporting vertex embeddings (#4282)
* feat: introduce GoogleGenAIVertexEmbeddingFunction for dual SDK support
  - Added a new embedding function to support both the legacy vertexai.language_models SDK and the new google-genai SDK for Google Vertex AI.
  - Updated factory methods to route to the new embedding function.
  - Enhanced VertexAIProvider and related configurations to accommodate the new model options.
  - Added integration tests for Google Vertex embeddings with Crew memory, ensuring compatibility and functionality with both authentication methods.
  This update improves the flexibility and compatibility of Google Vertex AI embeddings within the CrewAI framework.
* fix test count
* rm comment
* regen cassettes
* regen
* drop variable from .envtest
* direct to relevant test only
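The core mechanism this commit adds is model-based SDK selection: gecko-family model names route to the deprecated `vertexai.language_models` SDK, everything else to `google-genai`. Below is a minimal, self-contained sketch of that routing rule with simplified names; the authoritative implementation is `GoogleGenAIVertexEmbeddingFunction._is_legacy_model` in the diff further down.

```python
# Simplified sketch of the SDK-selection rule introduced by this commit.
# The real logic lives in GoogleGenAIVertexEmbeddingFunction below.
LEGACY_PREFIX = "textembedding-gecko"

def uses_legacy_sdk(model_name: str) -> bool:
    # Gecko-family models go through vertexai.language_models (deprecated);
    # everything else goes through the google-genai SDK.
    return model_name.startswith(LEGACY_PREFIX)

assert uses_legacy_sdk("textembedding-gecko@003")
assert not uses_legacy_sdk("gemini-embedding-001")
```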
@@ -401,23 +401,58 @@ crew = Crew(

### Vertex AI Embeddings

-For Google Cloud users with Vertex AI access.
+For Google Cloud users with Vertex AI access. Supports both legacy and new embedding models with automatic SDK selection.
+
+<Note>
+**Deprecation Notice:** Legacy models (`textembedding-gecko*`) use the deprecated `vertexai.language_models` SDK, which will be removed after June 24, 2026. Consider migrating to newer models like `gemini-embedding-001`. See the [Google migration guide](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk) for details.
+</Note>

```python
+# Recommended: Using new models with the google-genai SDK
crew = Crew(
    memory=True,
    embedder={
-        "provider": "vertexai",
+        "provider": "google-vertex",
        "config": {
            "project_id": "your-gcp-project-id",
-            "region": "us-central1",  # or your preferred region
-            "api_key": "your-service-account-key",
-            "model_name": "textembedding-gecko"
+            "location": "us-central1",
+            "model_name": "gemini-embedding-001",  # or "text-embedding-005", "text-multilingual-embedding-002"
+            "task_type": "RETRIEVAL_DOCUMENT",  # Optional
+            "output_dimensionality": 768  # Optional
        }
    }
)
+
+# Using API key authentication (experimental)
+crew = Crew(
+    memory=True,
+    embedder={
+        "provider": "google-vertex",
+        "config": {
+            "api_key": "your-google-api-key",
+            "model_name": "gemini-embedding-001"
+        }
+    }
+)
+
+# Legacy models (backwards compatible, emits a deprecation warning)
+crew = Crew(
+    memory=True,
+    embedder={
+        "provider": "google-vertex",
+        "config": {
+            "project_id": "your-gcp-project-id",
+            "region": "us-central1",  # or "location" ("region" is deprecated)
+            "model_name": "textembedding-gecko"  # Legacy model
+        }
+    }
+)
```
+
+**Available models:**
+- **New SDK models** (recommended): `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002`
+- **Legacy models** (deprecated): `textembedding-gecko`, `textembedding-gecko@001`, `textembedding-gecko-multilingual`

### Ollama Embeddings (Local)

Run embeddings locally for privacy and cost savings.
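The embedder dictionaries documented above can also be exercised outside of a Crew through the embeddings factory; `build_embedder` from `crewai.rag.embeddings.factory` is the same entry point the unit tests later in this diff call. A hedged sketch of standalone use (the API key value is a placeholder):

```python
# Hedged sketch: building the embedder directly from the documented config dict.
# build_embedder and its module path are taken from the tests later in this diff;
# the API key value is a placeholder.
from crewai.rag.embeddings.factory import build_embedder

embedder = build_embedder({
    "provider": "google-vertex",
    "config": {
        "api_key": "your-google-api-key",
        "model_name": "gemini-embedding-001",
    },
})
vectors = embedder(["What is CrewAI?"])  # one embedding vector per input document
```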
@@ -569,7 +604,7 @@ mem0_client_embedder_config = {
        "project_id": "my_project_id",  # Optional
        "api_key": "custom-api-key",  # Optional - overrides env var
        "run_id": "my_run_id",  # Optional - for short-term memory
        "includes": "include1",  # Optional
        "excludes": "exclude1",  # Optional
        "infer": True,  # Optional - defaults to True
        "custom_categories": new_categories  # Optional - custom categories for user memory
@@ -591,7 +626,7 @@ crew = Crew(

### Choosing the Right Embedding Provider

When selecting an embedding provider, consider factors like performance, privacy, cost, and integration needs.
Below is a comparison to help you decide:

| Provider | Best For | Pros | Cons |
@@ -749,7 +784,7 @@ Entity Memory supports batching when saving multiple entities at once. When you

This improves performance and observability when writing many entities in one operation.

## 2. External Memory

External Memory provides a standalone memory system that operates independently from the crew's built-in memory. This is ideal for specialized memory providers or cross-application memory sharing.

### Basic External Memory with Mem0
@@ -819,7 +854,7 @@ external_memory = ExternalMemory(
        "project_id": "my_project_id",  # Optional
        "api_key": "custom-api-key",  # Optional - overrides env var
        "run_id": "my_run_id",  # Optional - for short-term memory
        "includes": "include1",  # Optional
        "excludes": "exclude1",  # Optional
        "infer": True,  # Optional - defaults to True
        "custom_categories": new_categories  # Optional - custom categories for user memory
@@ -18,7 +18,6 @@ if TYPE_CHECKING:
    )
    from chromadb.utils.embedding_functions.google_embedding_function import (
        GoogleGenerativeAiEmbeddingFunction,
-        GoogleVertexEmbeddingFunction,
    )
    from chromadb.utils.embedding_functions.huggingface_embedding_function import (
        HuggingFaceEmbeddingFunction,
@@ -52,6 +51,9 @@ if TYPE_CHECKING:
    from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
    from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
    from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
+    from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
+        GoogleGenAIVertexEmbeddingFunction,
+    )
    from crewai.rag.embeddings.providers.google.types import (
        GenerativeAiProviderSpec,
        VertexAIProviderSpec,
@@ -163,7 +165,7 @@ def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunctio
@overload
def build_embedder_from_dict(
    spec: VertexAIProviderSpec,
-) -> GoogleVertexEmbeddingFunction: ...
+) -> GoogleGenAIVertexEmbeddingFunction: ...


@overload
@@ -296,7 +298,9 @@ def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...


@overload
-def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
+def build_embedder(
+    spec: VertexAIProviderSpec,
+) -> GoogleGenAIVertexEmbeddingFunction: ...


@overload
@@ -1,5 +1,8 @@
"""Google embedding providers."""

+from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
+    GoogleGenAIVertexEmbeddingFunction,
+)
from crewai.rag.embeddings.providers.google.generative_ai import (
    GenerativeAiProvider,
)
@@ -18,6 +21,7 @@ __all__ = [
    "GenerativeAiProvider",
    "GenerativeAiProviderConfig",
    "GenerativeAiProviderSpec",
+    "GoogleGenAIVertexEmbeddingFunction",
    "VertexAIProvider",
    "VertexAIProviderConfig",
    "VertexAIProviderSpec",
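With the export added above, the new embedding function is importable directly from the google providers package, and its `name()` staticmethod (defined in the new module below) identifies it to ChromaDB. A small usage check, assuming the package is installed:

```python
# Assumes crewai with the Google extras is installed; the import path follows
# the __all__ export added in this hunk.
from crewai.rag.embeddings.providers.google import GoogleGenAIVertexEmbeddingFunction

# The ChromaDB-facing name matches the "google-vertex" provider key used in the docs.
assert GoogleGenAIVertexEmbeddingFunction.name() == "google-vertex"
```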
@@ -0,0 +1,237 @@
"""Google Vertex AI embedding function implementation.

This module supports both the new google-genai SDK and the deprecated
vertexai.language_models module for backwards compatibility.

The deprecated vertexai.language_models module will be removed after June 24, 2026.
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
"""

from typing import Any, ClassVar, cast
import warnings

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack

from crewai.rag.embeddings.providers.google.types import VertexAIProviderConfig


class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function for Google Vertex AI with dual SDK support.

    This class supports both:
    - Legacy models (textembedding-gecko*) using the deprecated vertexai.language_models SDK
    - New models (gemini-embedding-*, text-embedding-*) using the google-genai SDK

    The SDK is automatically selected based on the model name. Legacy models will
    emit a deprecation warning.

    Supports two authentication modes:
    1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
    2. API key: Set api_key for direct API access

    Example:
        # Using legacy model (will emit deprecation warning)
        embedder = GoogleGenAIVertexEmbeddingFunction(
            project_id="my-project",
            region="us-central1",
            model_name="textembedding-gecko"
        )

        # Using new model with google-genai SDK
        embedder = GoogleGenAIVertexEmbeddingFunction(
            project_id="my-project",
            location="us-central1",
            model_name="gemini-embedding-001"
        )

        # Using API key (new SDK only)
        embedder = GoogleGenAIVertexEmbeddingFunction(
            api_key="your-api-key",
            model_name="gemini-embedding-001"
        )
    """

    # Models that use the legacy vertexai.language_models SDK
    LEGACY_MODELS: ClassVar[set[str]] = {
        "textembedding-gecko",
        "textembedding-gecko@001",
        "textembedding-gecko@002",
        "textembedding-gecko@003",
        "textembedding-gecko@latest",
        "textembedding-gecko-multilingual",
        "textembedding-gecko-multilingual@001",
        "textembedding-gecko-multilingual@latest",
    }

    # Models that use the new google-genai SDK
    GENAI_MODELS: ClassVar[set[str]] = {
        "gemini-embedding-001",
        "text-embedding-005",
        "text-multilingual-embedding-002",
    }

    def __init__(self, **kwargs: Unpack[VertexAIProviderConfig]) -> None:
        """Initialize Google Vertex AI embedding function.

        Args:
            **kwargs: Configuration parameters including:
                - model_name: Model to use for embeddings (default: "textembedding-gecko")
                - api_key: Optional API key for authentication (new SDK only)
                - project_id: GCP project ID (for Vertex AI backend)
                - location: GCP region (default: "us-central1")
                - region: Deprecated alias for location
                - task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT", new SDK only)
                - output_dimensionality: Optional output embedding dimension (new SDK only)
        """
        # Handle deprecated 'region' parameter (only if it has a value)
        region_value = kwargs.pop("region", None)  # type: ignore[typeddict-item]
        if region_value is not None:
            warnings.warn(
                "The 'region' parameter is deprecated, use 'location' instead. "
                "See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
                DeprecationWarning,
                stacklevel=2,
            )
            if "location" not in kwargs or kwargs.get("location") is None:
                kwargs["location"] = region_value  # type: ignore[typeddict-unknown-key]

        self._config = kwargs
        self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
        self._use_legacy = self._is_legacy_model(self._model_name)

        if self._use_legacy:
            self._init_legacy_client(**kwargs)
        else:
            self._init_genai_client(**kwargs)

    def _is_legacy_model(self, model_name: str) -> bool:
        """Check if the model uses the legacy SDK."""
        return model_name in self.LEGACY_MODELS or model_name.startswith(
            "textembedding-gecko"
        )

    def _init_legacy_client(self, **kwargs: Any) -> None:
        """Initialize using the deprecated vertexai.language_models SDK."""
        warnings.warn(
            f"Model '{self._model_name}' uses the deprecated vertexai.language_models SDK "
            "which will be removed after June 24, 2026. Consider migrating to newer models "
            "like 'gemini-embedding-001'. "
            "See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
            DeprecationWarning,
            stacklevel=3,
        )

        try:
            import vertexai
            from vertexai.language_models import TextEmbeddingModel
        except ImportError as e:
            raise ImportError(
                "vertexai is required for legacy embedding models (textembedding-gecko*). "
                "Install it with: pip install google-cloud-aiplatform"
            ) from e

        project_id = kwargs.get("project_id")
        location = str(kwargs.get("location", "us-central1"))

        if not project_id:
            raise ValueError(
                "project_id is required for legacy models. "
                "For API key authentication, use newer models like 'gemini-embedding-001'."
            )

        vertexai.init(project=str(project_id), location=location)
        self._legacy_model = TextEmbeddingModel.from_pretrained(self._model_name)

    def _init_genai_client(self, **kwargs: Any) -> None:
        """Initialize using the new google-genai SDK."""
        try:
            from google import genai
            from google.genai.types import EmbedContentConfig
        except ImportError as e:
            raise ImportError(
                "google-genai is required for Google Gen AI embeddings. "
                "Install it with: uv add 'crewai[google-genai]'"
            ) from e

        self._genai = genai
        self._EmbedContentConfig = EmbedContentConfig
        self._task_type = kwargs.get("task_type", "RETRIEVAL_DOCUMENT")
        self._output_dimensionality = kwargs.get("output_dimensionality")

        # Initialize client based on authentication mode
        api_key = kwargs.get("api_key")
        project_id = kwargs.get("project_id")
        location: str = str(kwargs.get("location", "us-central1"))

        if api_key:
            self._client = genai.Client(api_key=api_key)
        elif project_id:
            self._client = genai.Client(
                vertexai=True,
                project=str(project_id),
                location=location,
            )
        else:
            raise ValueError(
                "Either 'api_key' (for API key authentication) or 'project_id' "
                "(for Vertex AI backend with ADC) must be provided."
            )

    @staticmethod
    def name() -> str:
        """Return the name of the embedding function for ChromaDB compatibility."""
        return "google-vertex"

    def __call__(self, input: Documents) -> Embeddings:
        """Generate embeddings for input documents.

        Args:
            input: List of documents to embed.

        Returns:
            List of embedding vectors.
        """
        if isinstance(input, str):
            input = [input]

        if self._use_legacy:
            return self._call_legacy(input)
        return self._call_genai(input)

    def _call_legacy(self, input: list[str]) -> Embeddings:
        """Generate embeddings using the legacy SDK."""
        import numpy as np

        embeddings_list = []
        for text in input:
            embedding_result = self._legacy_model.get_embeddings([text])
            embeddings_list.append(
                np.array(embedding_result[0].values, dtype=np.float32)
            )

        return cast(Embeddings, embeddings_list)

    def _call_genai(self, input: list[str]) -> Embeddings:
        """Generate embeddings using the new google-genai SDK."""
        # Build config for embed_content
        config_kwargs: dict[str, Any] = {
            "task_type": self._task_type,
        }
        if self._output_dimensionality is not None:
            config_kwargs["output_dimensionality"] = self._output_dimensionality

        config = self._EmbedContentConfig(**config_kwargs)

        # Call the embedding API
        response = self._client.models.embed_content(
            model=self._model_name,
            contents=input,  # type: ignore[arg-type]
            config=config,
        )

        # Extract embeddings from response
        if response.embeddings is None:
            raise ValueError("No embeddings returned from the API")
        embeddings = [emb.values for emb in response.embeddings]
        return cast(Embeddings, embeddings)
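For reference, `_init_genai_client` and `_call_genai` above boil down to the following google-genai calls; this standalone sketch mirrors them. It needs the `google-genai` package and valid GCP credentials, so treat it as illustrative rather than a test, and the project and location values as placeholders.

```python
# Standalone sketch of the google-genai calls wrapped by the class above.
# Requires the google-genai package and Application Default Credentials;
# project and location values are placeholders.
from google import genai
from google.genai.types import EmbedContentConfig

client = genai.Client(vertexai=True, project="my-project", location="us-central1")
response = client.models.embed_content(
    model="gemini-embedding-001",
    contents=["CrewAI stores task output in memory."],
    config=EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT", output_dimensionality=768),
)
vectors = [e.values for e in response.embeddings]  # one vector per input document
```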
@@ -34,12 +34,47 @@ class GenerativeAiProviderSpec(TypedDict):


class VertexAIProviderConfig(TypedDict, total=False):
-    """Configuration for Vertex AI provider."""
+    """Configuration for Vertex AI provider with dual SDK support.
+
+    Supports both legacy models (textembedding-gecko*) using the deprecated
+    vertexai.language_models SDK and new models using google-genai SDK.
+
+    Attributes:
+        api_key: Google API key (optional if using project_id with ADC). Only for new SDK models.
+        model_name: Embedding model name (default: "textembedding-gecko").
+            Legacy models: textembedding-gecko, textembedding-gecko@001, etc.
+            New models: gemini-embedding-001, text-embedding-005, text-multilingual-embedding-002
+        project_id: GCP project ID (required for Vertex AI backend and legacy models).
+        location: GCP region/location (default: "us-central1").
+        region: Deprecated alias for location (kept for backwards compatibility).
+        task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT"). Only for new SDK models.
+        output_dimensionality: Output embedding dimension (optional). Only for new SDK models.
+    """

    api_key: str
-    model_name: Annotated[str, "textembedding-gecko"]
-    project_id: Annotated[str, "cloud-large-language-models"]
-    region: Annotated[str, "us-central1"]
+    model_name: Annotated[
+        Literal[
+            # Legacy models (deprecated vertexai.language_models SDK)
+            "textembedding-gecko",
+            "textembedding-gecko@001",
+            "textembedding-gecko@002",
+            "textembedding-gecko@003",
+            "textembedding-gecko@latest",
+            "textembedding-gecko-multilingual",
+            "textembedding-gecko-multilingual@001",
+            "textembedding-gecko-multilingual@latest",
+            # New models (google-genai SDK)
+            "gemini-embedding-001",
+            "text-embedding-005",
+            "text-multilingual-embedding-002",
+        ],
+        "textembedding-gecko",
+    ]
+    project_id: str
+    location: Annotated[str, "us-central1"]
+    region: Annotated[str, "us-central1"]  # Deprecated alias for location
+    task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
+    output_dimensionality: int


class VertexAIProviderSpec(TypedDict, total=False):
@@ -1,46 +1,126 @@
-"""Google Vertex AI embeddings provider."""
+"""Google Vertex AI embeddings provider.
+
+This module supports both the new google-genai SDK and the deprecated
+vertexai.language_models module for backwards compatibility.
+
+The SDK is automatically selected based on the model name:
+- Legacy models (textembedding-gecko*) use vertexai.language_models (deprecated)
+- New models (gemini-embedding-*, text-embedding-*) use google-genai
+
+Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
+"""

from __future__ import annotations

-from chromadb.utils.embedding_functions.google_embedding_function import (
-    GoogleVertexEmbeddingFunction,
-)
from pydantic import AliasChoices, Field

from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
+from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
+    GoogleGenAIVertexEmbeddingFunction,
+)


-class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
-    """Google Vertex AI embeddings provider."""
+class VertexAIProvider(BaseEmbeddingsProvider[GoogleGenAIVertexEmbeddingFunction]):
+    """Google Vertex AI embeddings provider with dual SDK support.

-    embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
-        default=GoogleVertexEmbeddingFunction,
-        description="Vertex AI embedding function class",
+    Supports both legacy models (textembedding-gecko*) using the deprecated
+    vertexai.language_models SDK and new models (gemini-embedding-*, text-embedding-*)
+    using the google-genai SDK.
+
+    The SDK is automatically selected based on the model name. Legacy models will
+    emit a deprecation warning.
+
+    Authentication modes:
+    1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
+    2. API key: Set api_key for direct API access (new SDK models only)
+
+    Example:
+        # Legacy model (backwards compatible, will emit deprecation warning)
+        provider = VertexAIProvider(
+            project_id="my-project",
+            region="us-central1",  # or location="us-central1"
+            model_name="textembedding-gecko"
+        )
+
+        # New model with Vertex AI backend
+        provider = VertexAIProvider(
+            project_id="my-project",
+            location="us-central1",
+            model_name="gemini-embedding-001"
+        )
+
+        # New model with API key
+        provider = VertexAIProvider(
+            api_key="your-api-key",
+            model_name="gemini-embedding-001"
+        )
+    """
+
+    embedding_callable: type[GoogleGenAIVertexEmbeddingFunction] = Field(
+        default=GoogleGenAIVertexEmbeddingFunction,
+        description="Google Vertex AI embedding function class",
    )
    model_name: str = Field(
        default="textembedding-gecko",
-        description="Model name to use for embeddings",
+        description=(
+            "Model name to use for embeddings. Legacy models (textembedding-gecko*) "
+            "use the deprecated SDK. New models (gemini-embedding-001, text-embedding-005) "
+            "use the google-genai SDK."
+        ),
        validation_alias=AliasChoices(
            "EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
            "GOOGLE_VERTEX_MODEL_NAME",
            "model",
        ),
    )
-    api_key: str = Field(
-        description="Google API key",
+    api_key: str | None = Field(
+        default=None,
+        description="Google API key (optional if using project_id with Application Default Credentials)",
        validation_alias=AliasChoices(
-            "EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
+            "EMBEDDINGS_GOOGLE_CLOUD_API_KEY",
+            "GOOGLE_CLOUD_API_KEY",
+            "GOOGLE_API_KEY",
        ),
    )
-    project_id: str = Field(
-        default="cloud-large-language-models",
-        description="GCP project ID",
+    project_id: str | None = Field(
+        default=None,
+        description="GCP project ID (required for Vertex AI backend and legacy models)",
        validation_alias=AliasChoices(
-            "EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
+            "EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
+            "GOOGLE_CLOUD_PROJECT",
        ),
    )
-    region: str = Field(
+    location: str = Field(
        default="us-central1",
-        description="GCP region",
+        description="GCP region/location",
        validation_alias=AliasChoices(
-            "EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
+            "EMBEDDINGS_GOOGLE_CLOUD_LOCATION",
+            "EMBEDDINGS_GOOGLE_CLOUD_REGION",
+            "GOOGLE_CLOUD_LOCATION",
+            "GOOGLE_CLOUD_REGION",
        ),
    )
+    region: str | None = Field(
+        default=None,
+        description="Deprecated: Use 'location' instead. GCP region (kept for backwards compatibility)",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_GOOGLE_VERTEX_REGION",
+            "GOOGLE_VERTEX_REGION",
+        ),
+    )
+    task_type: str = Field(
+        default="RETRIEVAL_DOCUMENT",
+        description="Task type for embeddings (e.g., RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY). Only used with new SDK models.",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_GOOGLE_VERTEX_TASK_TYPE",
+            "GOOGLE_VERTEX_TASK_TYPE",
+        ),
+    )
+    output_dimensionality: int | None = Field(
+        default=None,
+        description="Output embedding dimensionality (optional). Only used with new SDK models.",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
+            "GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
+        ),
+    )
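The `AliasChoices` entries above are environment-variable style names, which suggests the provider can also be configured from the environment rather than an explicit config dict; the exact population mechanics depend on `BaseEmbeddingsProvider`, so the sketch below is an assumption rather than something this diff confirms.

```python
# Hedged sketch: configuring VertexAIProvider via environment variables,
# assuming BaseEmbeddingsProvider resolves the validation aliases shown above
# (e.g. pydantic-settings style). Values are placeholders.
import os

from crewai.rag.embeddings.providers.google.vertex import VertexAIProvider

os.environ["GOOGLE_CLOUD_PROJECT"] = "my-gcp-project"
os.environ["GOOGLE_CLOUD_LOCATION"] = "europe-west1"
os.environ["GOOGLE_VERTEX_MODEL_NAME"] = "gemini-embedding-001"

provider = VertexAIProvider()  # project_id, location, model_name picked up from env
```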
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -272,3 +272,100 @@ class TestEmbeddingFactory:
        mock_build_from_provider.assert_called_once_with(mock_provider)
        assert result == mock_embedding_function
        mock_import.assert_not_called()

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_google_vertex_with_genai_model(self, mock_import):
        """Test routing to Google Vertex provider with new genai model."""
        mock_provider_class = MagicMock()
        mock_provider_instance = MagicMock()
        mock_embedding_function = MagicMock()

        mock_import.return_value = mock_provider_class
        mock_provider_class.return_value = mock_provider_instance
        mock_provider_instance.embedding_callable.return_value = mock_embedding_function

        config = {
            "provider": "google-vertex",
            "config": {
                "api_key": "test-google-api-key",
                "model_name": "gemini-embedding-001",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
        )
        mock_provider_class.assert_called_once()

        call_kwargs = mock_provider_class.call_args.kwargs
        assert call_kwargs["api_key"] == "test-google-api-key"
        assert call_kwargs["model_name"] == "gemini-embedding-001"

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_google_vertex_with_legacy_model(self, mock_import):
        """Test routing to Google Vertex provider with legacy textembedding-gecko model."""
        mock_provider_class = MagicMock()
        mock_provider_instance = MagicMock()
        mock_embedding_function = MagicMock()

        mock_import.return_value = mock_provider_class
        mock_provider_class.return_value = mock_provider_instance
        mock_provider_instance.embedding_callable.return_value = mock_embedding_function

        config = {
            "provider": "google-vertex",
            "config": {
                "project_id": "my-gcp-project",
                "region": "us-central1",
                "model_name": "textembedding-gecko",
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
        )
        mock_provider_class.assert_called_once()

        call_kwargs = mock_provider_class.call_args.kwargs
        assert call_kwargs["project_id"] == "my-gcp-project"
        assert call_kwargs["region"] == "us-central1"
        assert call_kwargs["model_name"] == "textembedding-gecko"

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_google_vertex_with_location(self, mock_import):
        """Test routing to Google Vertex provider with location parameter."""
        mock_provider_class = MagicMock()
        mock_provider_instance = MagicMock()
        mock_embedding_function = MagicMock()

        mock_import.return_value = mock_provider_class
        mock_provider_class.return_value = mock_provider_instance
        mock_provider_instance.embedding_callable.return_value = mock_embedding_function

        config = {
            "provider": "google-vertex",
            "config": {
                "project_id": "my-gcp-project",
                "location": "europe-west1",
                "model_name": "gemini-embedding-001",
                "task_type": "RETRIEVAL_DOCUMENT",
                "output_dimensionality": 768,
            },
        }

        build_embedder(config)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
        )

        call_kwargs = mock_provider_class.call_args.kwargs
        assert call_kwargs["project_id"] == "my-gcp-project"
        assert call_kwargs["location"] == "europe-west1"
        assert call_kwargs["model_name"] == "gemini-embedding-001"
        assert call_kwargs["task_type"] == "RETRIEVAL_DOCUMENT"
        assert call_kwargs["output_dimensionality"] == 768
@@ -0,0 +1,176 @@
"""Integration tests for Google Vertex embeddings with Crew memory.

These tests make real API calls and use VCR to record/replay responses.
"""

import os
import threading
from collections import defaultdict
from unittest.mock import patch

import pytest

from crewai import Agent, Crew, Task
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemorySaveCompletedEvent,
    MemorySaveStartedEvent,
)


@pytest.fixture(autouse=True)
def setup_vertex_ai_env():
    """Set up environment for Vertex AI tests.

    Sets GOOGLE_GENAI_USE_VERTEXAI=true to ensure the SDK uses the Vertex AI
    backend (aiplatform.googleapis.com) which matches the VCR cassettes.
    Also mocks GOOGLE_API_KEY if not already set.
    """
    env_updates = {"GOOGLE_GENAI_USE_VERTEXAI": "true"}

    # Add a mock API key if none exists
    if "GOOGLE_API_KEY" not in os.environ and "GEMINI_API_KEY" not in os.environ:
        env_updates["GOOGLE_API_KEY"] = "test-key"

    with patch.dict(os.environ, env_updates):
        yield


@pytest.fixture
def google_vertex_embedder_config():
    """Fixture providing Google Vertex embedder configuration."""
    return {
        "provider": "google-vertex",
        "config": {
            "api_key": os.getenv("GOOGLE_API_KEY", "test-key"),
            "model_name": "gemini-embedding-001",
        },
    }


@pytest.fixture
def simple_agent():
    """Fixture providing a simple test agent."""
    return Agent(
        role="Research Assistant",
        goal="Help with research tasks",
        backstory="You are a helpful research assistant.",
        verbose=False,
    )


@pytest.fixture
def simple_task(simple_agent):
    """Fixture providing a simple test task."""
    return Task(
        description="Summarize the key points about artificial intelligence in one sentence.",
        expected_output="A one sentence summary about AI.",
        agent=simple_agent,
    )


@pytest.mark.vcr()
@pytest.mark.timeout(120)  # Longer timeout for VCR recording
def test_crew_memory_with_google_vertex_embedder(
    google_vertex_embedder_config, simple_agent, simple_task
) -> None:
    """Test that Crew with memory=True works with google-vertex embedder and memory is used."""
    # Track memory events
    events: dict[str, list] = defaultdict(list)
    condition = threading.Condition()

    @crewai_event_bus.on(MemorySaveStartedEvent)
    def on_save_started(source, event):
        with condition:
            events["MemorySaveStartedEvent"].append(event)
            condition.notify()

    @crewai_event_bus.on(MemorySaveCompletedEvent)
    def on_save_completed(source, event):
        with condition:
            events["MemorySaveCompletedEvent"].append(event)
            condition.notify()

    crew = Crew(
        agents=[simple_agent],
        tasks=[simple_task],
        memory=True,
        embedder=google_vertex_embedder_config,
        verbose=False,
    )

    result = crew.kickoff()

    assert result is not None
    assert result.raw is not None
    assert len(result.raw) > 0

    with condition:
        success = condition.wait_for(
            lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
            timeout=10,
        )

    assert success, "Timeout waiting for memory save events - memory may not be working"
    assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
    assert len(events["MemorySaveCompletedEvent"]) >= 1, "No memory save completed events"


@pytest.mark.vcr()
@pytest.mark.timeout(120)
def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) -> None:
    """Test Crew memory with Google Vertex using project_id authentication."""
    project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
    if not project_id:
        pytest.skip("GOOGLE_CLOUD_PROJECT environment variable not set")

    # Track memory events
    events: dict[str, list] = defaultdict(list)
    condition = threading.Condition()

    @crewai_event_bus.on(MemorySaveStartedEvent)
    def on_save_started(source, event):
        with condition:
            events["MemorySaveStartedEvent"].append(event)
            condition.notify()

    @crewai_event_bus.on(MemorySaveCompletedEvent)
    def on_save_completed(source, event):
        with condition:
            events["MemorySaveCompletedEvent"].append(event)
            condition.notify()

    embedder_config = {
        "provider": "google-vertex",
        "config": {
            "project_id": project_id,
            "location": "us-central1",
            "model_name": "gemini-embedding-001",
        },
    }

    crew = Crew(
        agents=[simple_agent],
        tasks=[simple_task],
        memory=True,
        embedder=embedder_config,
        verbose=False,
    )

    result = crew.kickoff()

    # Verify basic result
    assert result is not None
    assert result.raw is not None

    # Wait for memory save events
    with condition:
        success = condition.wait_for(
            lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
            timeout=10,
        )

    # Verify memory was actually used
    assert success, "Timeout waiting for memory save events - memory may not be working"
    assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
    assert len(events["MemorySaveCompletedEvent"]) >= 1, "No memory save completed events"