Lorenze/supporting vertex embeddings (#4282)

* feat: introduce GoogleGenAIVertexEmbeddingFunction for dual SDK support

- Added a new embedding function to support both the legacy vertexai.language_models SDK and the new google-genai SDK for Google Vertex AI.
- Updated factory methods to route to the new embedding function.
- Enhanced VertexAIProvider and related configurations to accommodate the new model options.
- Added integration tests for Google Vertex embeddings with Crew memory, ensuring compatibility and functionality with both authentication methods.

This update improves the flexibility and compatibility of Google Vertex AI embeddings within the CrewAI framework.

* fix test count

* rm comment

* regen cassettes

* regen

* drop variable from .envtest

* dreict to relevant trest only
This commit is contained in:
Lorenze Jay
2026-01-26 14:55:03 -08:00
committed by GitHub
parent 9797567342
commit 58b866a83d
10 changed files with 14215 additions and 36 deletions

View File

@@ -18,7 +18,6 @@ if TYPE_CHECKING:
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
@@ -52,6 +51,9 @@ if TYPE_CHECKING:
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderSpec,
VertexAIProviderSpec,
@@ -163,7 +165,7 @@ def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunctio
@overload
def build_embedder_from_dict(
spec: VertexAIProviderSpec,
) -> GoogleVertexEmbeddingFunction: ...
) -> GoogleGenAIVertexEmbeddingFunction: ...
@overload
@@ -296,7 +298,9 @@ def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
def build_embedder(
spec: VertexAIProviderSpec,
) -> GoogleGenAIVertexEmbeddingFunction: ...
@overload

View File

@@ -1,5 +1,8 @@
"""Google embedding providers."""
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
from crewai.rag.embeddings.providers.google.generative_ai import (
GenerativeAiProvider,
)
@@ -18,6 +21,7 @@ __all__ = [
"GenerativeAiProvider",
"GenerativeAiProviderConfig",
"GenerativeAiProviderSpec",
"GoogleGenAIVertexEmbeddingFunction",
"VertexAIProvider",
"VertexAIProviderConfig",
"VertexAIProviderSpec",

View File

@@ -0,0 +1,237 @@
"""Google Vertex AI embedding function implementation.
This module supports both the new google-genai SDK and the deprecated
vertexai.language_models module for backwards compatibility.
The deprecated vertexai.language_models module will be removed after June 24, 2026.
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
"""
from typing import Any, ClassVar, cast
import warnings
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack
from crewai.rag.embeddings.providers.google.types import VertexAIProviderConfig
class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for Google Vertex AI with dual SDK support.
This class supports both:
- Legacy models (textembedding-gecko*) using the deprecated vertexai.language_models SDK
- New models (gemini-embedding-*, text-embedding-*) using the google-genai SDK
The SDK is automatically selected based on the model name. Legacy models will
emit a deprecation warning.
Supports two authentication modes:
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
2. API key: Set api_key for direct API access
Example:
# Using legacy model (will emit deprecation warning)
embedder = GoogleGenAIVertexEmbeddingFunction(
project_id="my-project",
region="us-central1",
model_name="textembedding-gecko"
)
# Using new model with google-genai SDK
embedder = GoogleGenAIVertexEmbeddingFunction(
project_id="my-project",
location="us-central1",
model_name="gemini-embedding-001"
)
# Using API key (new SDK only)
embedder = GoogleGenAIVertexEmbeddingFunction(
api_key="your-api-key",
model_name="gemini-embedding-001"
)
"""
# Models that use the legacy vertexai.language_models SDK
LEGACY_MODELS: ClassVar[set[str]] = {
"textembedding-gecko",
"textembedding-gecko@001",
"textembedding-gecko@002",
"textembedding-gecko@003",
"textembedding-gecko@latest",
"textembedding-gecko-multilingual",
"textembedding-gecko-multilingual@001",
"textembedding-gecko-multilingual@latest",
}
# Models that use the new google-genai SDK
GENAI_MODELS: ClassVar[set[str]] = {
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
}
def __init__(self, **kwargs: Unpack[VertexAIProviderConfig]) -> None:
"""Initialize Google Vertex AI embedding function.
Args:
**kwargs: Configuration parameters including:
- model_name: Model to use for embeddings (default: "textembedding-gecko")
- api_key: Optional API key for authentication (new SDK only)
- project_id: GCP project ID (for Vertex AI backend)
- location: GCP region (default: "us-central1")
- region: Deprecated alias for location
- task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT", new SDK only)
- output_dimensionality: Optional output embedding dimension (new SDK only)
"""
# Handle deprecated 'region' parameter (only if it has a value)
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
if region_value is not None:
warnings.warn(
"The 'region' parameter is deprecated, use 'location' instead. "
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
DeprecationWarning,
stacklevel=2,
)
if "location" not in kwargs or kwargs.get("location") is None:
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
self._config = kwargs
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
self._use_legacy = self._is_legacy_model(self._model_name)
if self._use_legacy:
self._init_legacy_client(**kwargs)
else:
self._init_genai_client(**kwargs)
def _is_legacy_model(self, model_name: str) -> bool:
"""Check if the model uses the legacy SDK."""
return model_name in self.LEGACY_MODELS or model_name.startswith(
"textembedding-gecko"
)
def _init_legacy_client(self, **kwargs: Any) -> None:
"""Initialize using the deprecated vertexai.language_models SDK."""
warnings.warn(
f"Model '{self._model_name}' uses the deprecated vertexai.language_models SDK "
"which will be removed after June 24, 2026. Consider migrating to newer models "
"like 'gemini-embedding-001'. "
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
DeprecationWarning,
stacklevel=3,
)
try:
import vertexai
from vertexai.language_models import TextEmbeddingModel
except ImportError as e:
raise ImportError(
"vertexai is required for legacy embedding models (textembedding-gecko*). "
"Install it with: pip install google-cloud-aiplatform"
) from e
project_id = kwargs.get("project_id")
location = str(kwargs.get("location", "us-central1"))
if not project_id:
raise ValueError(
"project_id is required for legacy models. "
"For API key authentication, use newer models like 'gemini-embedding-001'."
)
vertexai.init(project=str(project_id), location=location)
self._legacy_model = TextEmbeddingModel.from_pretrained(self._model_name)
def _init_genai_client(self, **kwargs: Any) -> None:
"""Initialize using the new google-genai SDK."""
try:
from google import genai
from google.genai.types import EmbedContentConfig
except ImportError as e:
raise ImportError(
"google-genai is required for Google Gen AI embeddings. "
"Install it with: uv add 'crewai[google-genai]'"
) from e
self._genai = genai
self._EmbedContentConfig = EmbedContentConfig
self._task_type = kwargs.get("task_type", "RETRIEVAL_DOCUMENT")
self._output_dimensionality = kwargs.get("output_dimensionality")
# Initialize client based on authentication mode
api_key = kwargs.get("api_key")
project_id = kwargs.get("project_id")
location: str = str(kwargs.get("location", "us-central1"))
if api_key:
self._client = genai.Client(api_key=api_key)
elif project_id:
self._client = genai.Client(
vertexai=True,
project=str(project_id),
location=location,
)
else:
raise ValueError(
"Either 'api_key' (for API key authentication) or 'project_id' "
"(for Vertex AI backend with ADC) must be provided."
)
@staticmethod
def name() -> str:
"""Return the name of the embedding function for ChromaDB compatibility."""
return "google-vertex"
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.
Args:
input: List of documents to embed.
Returns:
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]
if self._use_legacy:
return self._call_legacy(input)
return self._call_genai(input)
def _call_legacy(self, input: list[str]) -> Embeddings:
"""Generate embeddings using the legacy SDK."""
import numpy as np
embeddings_list = []
for text in input:
embedding_result = self._legacy_model.get_embeddings([text])
embeddings_list.append(
np.array(embedding_result[0].values, dtype=np.float32)
)
return cast(Embeddings, embeddings_list)
def _call_genai(self, input: list[str]) -> Embeddings:
"""Generate embeddings using the new google-genai SDK."""
# Build config for embed_content
config_kwargs: dict[str, Any] = {
"task_type": self._task_type,
}
if self._output_dimensionality is not None:
config_kwargs["output_dimensionality"] = self._output_dimensionality
config = self._EmbedContentConfig(**config_kwargs)
# Call the embedding API
response = self._client.models.embed_content(
model=self._model_name,
contents=input, # type: ignore[arg-type]
config=config,
)
# Extract embeddings from response
if response.embeddings is None:
raise ValueError("No embeddings returned from the API")
embeddings = [emb.values for emb in response.embeddings]
return cast(Embeddings, embeddings)

View File

@@ -34,12 +34,47 @@ class GenerativeAiProviderSpec(TypedDict):
class VertexAIProviderConfig(TypedDict, total=False):
"""Configuration for Vertex AI provider."""
"""Configuration for Vertex AI provider with dual SDK support.
Supports both legacy models (textembedding-gecko*) using the deprecated
vertexai.language_models SDK and new models using google-genai SDK.
Attributes:
api_key: Google API key (optional if using project_id with ADC). Only for new SDK models.
model_name: Embedding model name (default: "textembedding-gecko").
Legacy models: textembedding-gecko, textembedding-gecko@001, etc.
New models: gemini-embedding-001, text-embedding-005, text-multilingual-embedding-002
project_id: GCP project ID (required for Vertex AI backend and legacy models).
location: GCP region/location (default: "us-central1").
region: Deprecated alias for location (kept for backwards compatibility).
task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT"). Only for new SDK models.
output_dimensionality: Output embedding dimension (optional). Only for new SDK models.
"""
api_key: str
model_name: Annotated[str, "textembedding-gecko"]
project_id: Annotated[str, "cloud-large-language-models"]
region: Annotated[str, "us-central1"]
model_name: Annotated[
Literal[
# Legacy models (deprecated vertexai.language_models SDK)
"textembedding-gecko",
"textembedding-gecko@001",
"textembedding-gecko@002",
"textembedding-gecko@003",
"textembedding-gecko@latest",
"textembedding-gecko-multilingual",
"textembedding-gecko-multilingual@001",
"textembedding-gecko-multilingual@latest",
# New models (google-genai SDK)
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
],
"textembedding-gecko",
]
project_id: str
location: Annotated[str, "us-central1"]
region: Annotated[str, "us-central1"] # Deprecated alias for location
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
output_dimensionality: int
class VertexAIProviderSpec(TypedDict, total=False):

View File

@@ -1,46 +1,126 @@
"""Google Vertex AI embeddings provider."""
"""Google Vertex AI embeddings provider.
This module supports both the new google-genai SDK and the deprecated
vertexai.language_models module for backwards compatibility.
The SDK is automatically selected based on the model name:
- Legacy models (textembedding-gecko*) use vertexai.language_models (deprecated)
- New models (gemini-embedding-*, text-embedding-*) use google-genai
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
"""
from __future__ import annotations
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
"""Google Vertex AI embeddings provider."""
class VertexAIProvider(BaseEmbeddingsProvider[GoogleGenAIVertexEmbeddingFunction]):
"""Google Vertex AI embeddings provider with dual SDK support.
embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
default=GoogleVertexEmbeddingFunction,
description="Vertex AI embedding function class",
Supports both legacy models (textembedding-gecko*) using the deprecated
vertexai.language_models SDK and new models (gemini-embedding-*, text-embedding-*)
using the google-genai SDK.
The SDK is automatically selected based on the model name. Legacy models will
emit a deprecation warning.
Authentication modes:
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
2. API key: Set api_key for direct API access (new SDK models only)
Example:
# Legacy model (backwards compatible, will emit deprecation warning)
provider = VertexAIProvider(
project_id="my-project",
region="us-central1", # or location="us-central1"
model_name="textembedding-gecko"
)
# New model with Vertex AI backend
provider = VertexAIProvider(
project_id="my-project",
location="us-central1",
model_name="gemini-embedding-001"
)
# New model with API key
provider = VertexAIProvider(
api_key="your-api-key",
model_name="gemini-embedding-001"
)
"""
embedding_callable: type[GoogleGenAIVertexEmbeddingFunction] = Field(
default=GoogleGenAIVertexEmbeddingFunction,
description="Google Vertex AI embedding function class",
)
model_name: str = Field(
default="textembedding-gecko",
description="Model name to use for embeddings",
description=(
"Model name to use for embeddings. Legacy models (textembedding-gecko*) "
"use the deprecated SDK. New models (gemini-embedding-001, text-embedding-005) "
"use the google-genai SDK."
),
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
"GOOGLE_VERTEX_MODEL_NAME",
"model",
),
)
api_key: str = Field(
description="Google API key",
api_key: str | None = Field(
default=None,
description="Google API key (optional if using project_id with Application Default Credentials)",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY",
"GOOGLE_CLOUD_API_KEY",
"GOOGLE_API_KEY",
),
)
project_id: str = Field(
default="cloud-large-language-models",
description="GCP project ID",
project_id: str | None = Field(
default=None,
description="GCP project ID (required for Vertex AI backend and legacy models)",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
"GOOGLE_CLOUD_PROJECT",
),
)
region: str = Field(
location: str = Field(
default="us-central1",
description="GCP region",
description="GCP region/location",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
"EMBEDDINGS_GOOGLE_CLOUD_LOCATION",
"EMBEDDINGS_GOOGLE_CLOUD_REGION",
"GOOGLE_CLOUD_LOCATION",
"GOOGLE_CLOUD_REGION",
),
)
region: str | None = Field(
default=None,
description="Deprecated: Use 'location' instead. GCP region (kept for backwards compatibility)",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_REGION",
"GOOGLE_VERTEX_REGION",
),
)
task_type: str = Field(
default="RETRIEVAL_DOCUMENT",
description="Task type for embeddings (e.g., RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY). Only used with new SDK models.",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_TASK_TYPE",
"GOOGLE_VERTEX_TASK_TYPE",
),
)
output_dimensionality: int | None = Field(
default=None,
description="Output embedding dimensionality (optional). Only used with new SDK models.",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
"GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
),
)

View File

@@ -272,3 +272,100 @@ class TestEmbeddingFactory:
mock_build_from_provider.assert_called_once_with(mock_provider)
assert result == mock_embedding_function
mock_import.assert_not_called()
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_genai_model(self, mock_import):
"""Test routing to Google Vertex provider with new genai model."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"api_key": "test-google-api-key",
"model_name": "gemini-embedding-001",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-google-api-key"
assert call_kwargs["model_name"] == "gemini-embedding-001"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_legacy_model(self, mock_import):
"""Test routing to Google Vertex provider with legacy textembedding-gecko model."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"project_id": "my-gcp-project",
"region": "us-central1",
"model_name": "textembedding-gecko",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["project_id"] == "my-gcp-project"
assert call_kwargs["region"] == "us-central1"
assert call_kwargs["model_name"] == "textembedding-gecko"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_location(self, mock_import):
"""Test routing to Google Vertex provider with location parameter."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"project_id": "my-gcp-project",
"location": "europe-west1",
"model_name": "gemini-embedding-001",
"task_type": "RETRIEVAL_DOCUMENT",
"output_dimensionality": 768,
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["project_id"] == "my-gcp-project"
assert call_kwargs["location"] == "europe-west1"
assert call_kwargs["model_name"] == "gemini-embedding-001"
assert call_kwargs["task_type"] == "RETRIEVAL_DOCUMENT"
assert call_kwargs["output_dimensionality"] == 768

View File

@@ -0,0 +1,176 @@
"""Integration tests for Google Vertex embeddings with Crew memory.
These tests make real API calls and use VCR to record/replay responses.
"""
import os
import threading
from collections import defaultdict
from unittest.mock import patch
import pytest
from crewai import Agent, Crew, Task
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
@pytest.fixture(autouse=True)
def setup_vertex_ai_env():
"""Set up environment for Vertex AI tests.
Sets GOOGLE_GENAI_USE_VERTEXAI=true to ensure the SDK uses the Vertex AI
backend (aiplatform.googleapis.com) which matches the VCR cassettes.
Also mocks GOOGLE_API_KEY if not already set.
"""
env_updates = {"GOOGLE_GENAI_USE_VERTEXAI": "true"}
# Add a mock API key if none exists
if "GOOGLE_API_KEY" not in os.environ and "GEMINI_API_KEY" not in os.environ:
env_updates["GOOGLE_API_KEY"] = "test-key"
with patch.dict(os.environ, env_updates):
yield
@pytest.fixture
def google_vertex_embedder_config():
"""Fixture providing Google Vertex embedder configuration."""
return {
"provider": "google-vertex",
"config": {
"api_key": os.getenv("GOOGLE_API_KEY", "test-key"),
"model_name": "gemini-embedding-001",
},
}
@pytest.fixture
def simple_agent():
"""Fixture providing a simple test agent."""
return Agent(
role="Research Assistant",
goal="Help with research tasks",
backstory="You are a helpful research assistant.",
verbose=False,
)
@pytest.fixture
def simple_task(simple_agent):
"""Fixture providing a simple test task."""
return Task(
description="Summarize the key points about artificial intelligence in one sentence.",
expected_output="A one sentence summary about AI.",
agent=simple_agent,
)
@pytest.mark.vcr()
@pytest.mark.timeout(120) # Longer timeout for VCR recording
def test_crew_memory_with_google_vertex_embedder(
google_vertex_embedder_config, simple_agent, simple_task
) -> None:
"""Test that Crew with memory=True works with google-vertex embedder and memory is used."""
# Track memory events
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
crew = Crew(
agents=[simple_agent],
tasks=[simple_task],
memory=True,
embedder=google_vertex_embedder_config,
verbose=False,
)
result = crew.kickoff()
assert result is not None
assert result.raw is not None
assert len(result.raw) > 0
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=10,
)
assert success, "Timeout waiting for memory save events - memory may not be working"
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
assert len(events["MemorySaveCompletedEvent"]) >= 1, "Memory save completed events"
@pytest.mark.vcr()
@pytest.mark.timeout(120)
def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) -> None:
"""Test Crew memory with Google Vertex using project_id authentication."""
project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
if not project_id:
pytest.skip("GOOGLE_CLOUD_PROJECT environment variable not set")
# Track memory events
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
embedder_config = {
"provider": "google-vertex",
"config": {
"project_id": project_id,
"location": "us-central1",
"model_name": "gemini-embedding-001",
},
}
crew = Crew(
agents=[simple_agent],
tasks=[simple_task],
memory=True,
embedder=embedder_config,
verbose=False,
)
result = crew.kickoff()
# Verify basic result
assert result is not None
assert result.raw is not None
# Wait for memory save events
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=10,
)
# Verify memory was actually used
assert success, "Timeout waiting for memory save events - memory may not be working"
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
assert len(events["MemorySaveCompletedEvent"]) >= 1, "No memory save completed events"