Compare commits

...

5 Commits

Author SHA1 Message Date
Greyson LaLonde
39e6618dab chore: add missing change logs 2026-01-26 19:52:00 -05:00
Lorenze Jay
6b926b90d0 chore: update version to 1.9.0 across all relevant files (#4284)
Some checks are pending
Build uv cache / build-cache (3.10) (push) Waiting to run
Build uv cache / build-cache (3.11) (push) Waiting to run
Build uv cache / build-cache (3.12) (push) Waiting to run
Build uv cache / build-cache (3.13) (push) Waiting to run
Check Documentation Broken Links / Check broken links (push) Waiting to run
CodeQL Advanced / Analyze (actions) (push) Waiting to run
CodeQL Advanced / Analyze (python) (push) Waiting to run
Notify Downstream / notify-downstream (push) Waiting to run
- Bumped the version number to 1.9.0 in pyproject.toml files and __init__.py files across the CrewAI library and its tools.
- Updated dependencies to use the new version of crewai-tools (1.9.0) for improved functionality and compatibility.
- Ensured consistency in versioning across the codebase to reflect the latest updates.
2026-01-26 16:36:35 -08:00
Lorenze Jay
fc84daadbb fix: enhance file store with fallback memory cache when aiocache is n… (#4283)
* fix: enhance file store with fallback memory cache when aiocache is not installed

- Added a simple in-memory cache implementation to serve as a fallback when the aiocache library is unavailable.
- Improved error handling for the aiocache import, ensuring that the application can still function without it.
- This change enhances the robustness of the file store utility by providing a reliable caching mechanism in various environments.

* drop fallback

* Potential fix for pull request finding 'Unused global variable'

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
2026-01-26 15:12:34 -08:00
Lorenze Jay
58b866a83d Lorenze/supporting vertex embeddings (#4282)
* feat: introduce GoogleGenAIVertexEmbeddingFunction for dual SDK support

- Added a new embedding function to support both the legacy vertexai.language_models SDK and the new google-genai SDK for Google Vertex AI.
- Updated factory methods to route to the new embedding function.
- Enhanced VertexAIProvider and related configurations to accommodate the new model options.
- Added integration tests for Google Vertex embeddings with Crew memory, ensuring compatibility and functionality with both authentication methods.

This update improves the flexibility and compatibility of Google Vertex AI embeddings within the CrewAI framework.

* fix test count

* rm comment

* regen cassettes

* regen

* drop variable from .envtest

* dreict to relevant trest only
2026-01-26 14:55:03 -08:00
Greyson LaLonde
9797567342 feat: add structured outputs and response_format support across providers (#4280)
* feat: add response_format parameter to Azure and Gemini providers

* feat: add structured outputs support to Bedrock and Anthropic providers

* chore: bump anthropic dep

* fix: use beta structured output for new models
2026-01-26 11:03:33 -08:00
26 changed files with 14810 additions and 223 deletions

View File

@@ -4,6 +4,74 @@ description: "Product updates, improvements, and bug fixes for CrewAI"
icon: "clock"
mode: "wide"
---
<Update label="Jan 26, 2026">
## v1.9.0
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.9.0)
## What's Changed
### Features
- Add structured outputs and response_format support across providers
- Add response ID to streaming responses
- Add event ordering with parent-child hierarchies
- Add Keycloak SSO authentication support
- Add multimodal file handling capabilities
- Add native OpenAI responses API support
- Add A2A task execution utilities
- Add A2A server configuration and agent card generation
- Enhance event system and expand transport options
- Improve tool calling mechanisms
### Bug Fixes
- Enhance file store with fallback memory cache when aiocache is not available
- Ensure document list is not empty
- Handle Bedrock stop sequences properly
- Add Google Vertex API key support
- Enhance Azure model stop word detection
- Improve error handling for HumanFeedbackPending in flow execution
- Fix execution span task unlinking
### Documentation
- Add native file handling documentation
- Add OpenAI responses API documentation
- Add agent card implementation guidance
- Refine A2A documentation
- Update changelog for v1.8.0
### Contributors
@Anaisdg, @GininDenis, @Vidit-Ostwal, @greysonlalonde, @heitorado, @joaomdmoura, @koushiv777, @lorenzejay, @nicoferdi96, @vinibrsl
</Update>
<Update label="Jan 15, 2026">
## v1.8.1
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.8.1)
## What's Changed
### Features
- Add A2A task execution utilities
- Add A2A server configuration and agent card generation
- Add additional transport mechanisms
- Add Galileo integration support
### Bug Fixes
- Improve Azure model compatibility
- Expand frame inspection depth to detect parent_flow
- Resolve task execution span management issues
- Enhance error handling for human feedback scenarios during flow execution
### Documentation
- Add A2A agent card documentation
- Add PII redaction feature documentation
### Contributors
@Anaisdg, @GininDenis, @greysonlalonde, @joaomdmoura, @koushiv777, @lorenzejay, @vinibrsl
</Update>
<Update label="Jan 08, 2026">
## v1.8.0

View File

@@ -401,23 +401,58 @@ crew = Crew(
### Vertex AI Embeddings
For Google Cloud users with Vertex AI access.
For Google Cloud users with Vertex AI access. Supports both legacy and new embedding models with automatic SDK selection.
<Note>
**Deprecation Notice:** Legacy models (`textembedding-gecko*`) use the deprecated `vertexai.language_models` SDK which will be removed after June 24, 2026. Consider migrating to newer models like `gemini-embedding-001`. See the [Google migration guide](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk) for details.
</Note>
```python
# Recommended: Using new models with google-genai SDK
crew = Crew(
memory=True,
embedder={
"provider": "vertexai",
"provider": "google-vertex",
"config": {
"project_id": "your-gcp-project-id",
"region": "us-central1", # or your preferred region
"api_key": "your-service-account-key",
"model_name": "textembedding-gecko"
"location": "us-central1",
"model_name": "gemini-embedding-001", # or "text-embedding-005", "text-multilingual-embedding-002"
"task_type": "RETRIEVAL_DOCUMENT", # Optional
"output_dimensionality": 768 # Optional
}
}
)
# Using API key authentication (Exp)
crew = Crew(
memory=True,
embedder={
"provider": "google-vertex",
"config": {
"api_key": "your-google-api-key",
"model_name": "gemini-embedding-001"
}
}
)
# Legacy models (backwards compatible, emits deprecation warning)
crew = Crew(
memory=True,
embedder={
"provider": "google-vertex",
"config": {
"project_id": "your-gcp-project-id",
"region": "us-central1", # or "location" (region is deprecated)
"model_name": "textembedding-gecko" # Legacy model
}
}
)
```
**Available models:**
- **New SDK models** (recommended): `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002`
- **Legacy models** (deprecated): `textembedding-gecko`, `textembedding-gecko@001`, `textembedding-gecko-multilingual`
### Ollama Embeddings (Local)
Run embeddings locally for privacy and cost savings.
@@ -569,7 +604,7 @@ mem0_client_embedder_config = {
"project_id": "my_project_id", # Optional
"api_key": "custom-api-key" # Optional - overrides env var
"run_id": "my_run_id", # Optional - for short-term memory
"includes": "include1", # Optional
"includes": "include1", # Optional
"excludes": "exclude1", # Optional
"infer": True # Optional defaults to True
"custom_categories": new_categories # Optional - custom categories for user memory
@@ -591,7 +626,7 @@ crew = Crew(
### Choosing the Right Embedding Provider
When selecting an embedding provider, consider factors like performance, privacy, cost, and integration needs.
When selecting an embedding provider, consider factors like performance, privacy, cost, and integration needs.
Below is a comparison to help you decide:
| Provider | Best For | Pros | Cons |
@@ -749,7 +784,7 @@ Entity Memory supports batching when saving multiple entities at once. When you
This improves performance and observability when writing many entities in one operation.
## 2. External Memory
## 2. External Memory
External Memory provides a standalone memory system that operates independently from the crew's built-in memory. This is ideal for specialized memory providers or cross-application memory sharing.
### Basic External Memory with Mem0
@@ -819,7 +854,7 @@ external_memory = ExternalMemory(
"project_id": "my_project_id", # Optional
"api_key": "custom-api-key" # Optional - overrides env var
"run_id": "my_run_id", # Optional - for short-term memory
"includes": "include1", # Optional
"includes": "include1", # Optional
"excludes": "exclude1", # Optional
"infer": True # Optional defaults to True
"custom_categories": new_categories # Optional - custom categories for user memory

View File

@@ -152,4 +152,4 @@ __all__ = [
"wrap_file_source",
]
__version__ = "1.8.1"
__version__ = "1.9.0"

View File

@@ -12,7 +12,7 @@ dependencies = [
"pytube~=15.0.0",
"requests~=2.32.5",
"docker~=7.1.0",
"crewai==1.8.1",
"crewai==1.9.0",
"lancedb~=0.5.4",
"tiktoken~=0.8.0",
"beautifulsoup4~=4.13.4",

View File

@@ -291,4 +291,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.8.1"
__version__ = "1.9.0"

View File

@@ -49,7 +49,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.8.1",
"crewai-tools==1.9.0",
]
embeddings = [
"tiktoken~=0.8.0"
@@ -90,7 +90,7 @@ azure-ai-inference = [
"azure-ai-inference~=1.0.0b9",
]
anthropic = [
"anthropic~=0.71.0",
"anthropic~=0.73.0",
]
a2a = [
"a2a-sdk~=0.3.10",

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.8.1"
__version__ = "1.9.0"
_telemetry_submitted = False

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.8.1"
"crewai[tools]==1.9.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.8.1"
"crewai[tools]==1.9.0"
]
[project.scripts]

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
import json
import logging
import os
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast
from anthropic.types import ThinkingBlock
from pydantic import BaseModel
from crewai.events.types.llm_events import LLMCallType
@@ -22,8 +21,9 @@ if TYPE_CHECKING:
from crewai.llms.hooks.base import BaseInterceptor
try:
from anthropic import Anthropic, AsyncAnthropic
from anthropic import Anthropic, AsyncAnthropic, transform_schema
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
from anthropic.types.beta import BetaMessage
import httpx
except ImportError:
raise ImportError(
@@ -31,7 +31,62 @@ except ImportError:
) from None
ANTHROPIC_FILES_API_BETA = "files-api-2025-04-14"
ANTHROPIC_FILES_API_BETA: Final = "files-api-2025-04-14"
ANTHROPIC_STRUCTURED_OUTPUTS_BETA: Final = "structured-outputs-2025-11-13"
NATIVE_STRUCTURED_OUTPUT_MODELS: Final[
tuple[
Literal["claude-sonnet-4-5"],
Literal["claude-sonnet-4.5"],
Literal["claude-opus-4-5"],
Literal["claude-opus-4.5"],
Literal["claude-opus-4-1"],
Literal["claude-opus-4.1"],
Literal["claude-haiku-4-5"],
Literal["claude-haiku-4.5"],
]
] = (
"claude-sonnet-4-5",
"claude-sonnet-4.5",
"claude-opus-4-5",
"claude-opus-4.5",
"claude-opus-4-1",
"claude-opus-4.1",
"claude-haiku-4-5",
"claude-haiku-4.5",
)
def _supports_native_structured_outputs(model: str) -> bool:
"""Check if the model supports native structured outputs.
Native structured outputs are only available for Claude 4.5 models
(Sonnet 4.5, Opus 4.5, Opus 4.1, Haiku 4.5).
Other models require the tool-based fallback approach.
Args:
model: The model name/identifier.
Returns:
True if the model supports native structured outputs.
"""
model_lower = model.lower()
return any(prefix in model_lower for prefix in NATIVE_STRUCTURED_OUTPUT_MODELS)
def _is_pydantic_model_class(obj: Any) -> TypeGuard[type[BaseModel]]:
"""Check if an object is a Pydantic model class.
This distinguishes between Pydantic model classes that support structured
outputs (have model_json_schema) and plain dicts like {"type": "json_object"}.
Args:
obj: The object to check.
Returns:
True if obj is a Pydantic model class.
"""
return isinstance(obj, type) and issubclass(obj, BaseModel)
def _contains_file_id_reference(messages: list[dict[str, Any]]) -> bool:
@@ -84,6 +139,7 @@ class AnthropicCompletion(BaseLLM):
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
thinking: AnthropicThinkingConfig | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
):
"""Initialize Anthropic chat completion client.
@@ -101,6 +157,8 @@ class AnthropicCompletion(BaseLLM):
stream: Enable streaming responses
client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level.
response_format: Pydantic model for structured output. When provided, responses
will be validated against this model schema.
**kwargs: Additional parameters
"""
super().__init__(
@@ -131,6 +189,7 @@ class AnthropicCompletion(BaseLLM):
self.stop_sequences = stop_sequences or []
self.thinking = thinking
self.previous_thinking_blocks: list[ThinkingBlock] = []
self.response_format = response_format
# Model-specific settings
self.is_claude_3 = "claude-3" in model.lower()
self.supports_tools = True
@@ -231,6 +290,8 @@ class AnthropicCompletion(BaseLLM):
formatted_messages, system_message, tools
)
effective_response_model = response_model or self.response_format
# Handle streaming vs non-streaming
if self.stream:
return self._handle_streaming_completion(
@@ -238,7 +299,7 @@ class AnthropicCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
return self._handle_completion(
@@ -246,7 +307,7 @@ class AnthropicCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
except Exception as e:
@@ -298,13 +359,15 @@ class AnthropicCompletion(BaseLLM):
formatted_messages, system_message, tools
)
effective_response_model = response_model or self.response_format
if self.stream:
return await self._ahandle_streaming_completion(
completion_params,
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
return await self._ahandle_completion(
@@ -312,7 +375,7 @@ class AnthropicCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
except Exception as e:
@@ -565,22 +628,40 @@ class AnthropicCompletion(BaseLLM):
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming message completion."""
if response_model:
structured_tool = {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
uses_file_api = _contains_file_id_reference(params.get("messages", []))
betas: list[str] = []
use_native_structured_output = False
if uses_file_api:
betas.append(ANTHROPIC_FILES_API_BETA)
extra_body: dict[str, Any] | None = None
if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": schema,
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
try:
if uses_file_api:
params["betas"] = [ANTHROPIC_FILES_API_BETA]
response = self.client.beta.messages.create(**params)
if betas:
params["betas"] = betas
response = self.client.beta.messages.create(
**params, extra_body=extra_body
)
else:
response = self.client.messages.create(**params)
@@ -593,22 +674,34 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(response)
self._track_token_usage_internal(usage)
if response_model and response.content:
tool_uses = [
block for block in response.content if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if _is_pydantic_model_class(response_model) and response.content:
if use_native_structured_output:
for block in response.content:
if isinstance(block, TextBlock):
structured_json = block.text
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
else:
for block in response.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
# Check if Claude wants to use tools
if response.content:
@@ -678,17 +771,31 @@ class AnthropicCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
) -> str | Any:
"""Handle streaming message completion."""
if response_model:
structured_tool = {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
betas: list[str] = []
use_native_structured_output = False
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
extra_body: dict[str, Any] | None = None
if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": schema,
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
full_response = ""
@@ -696,15 +803,22 @@ class AnthropicCompletion(BaseLLM):
# (the SDK sets it internally)
stream_params = {k: v for k, v in params.items() if k != "stream"}
if betas:
stream_params["betas"] = betas
current_tool_calls: dict[int, dict[str, Any]] = {}
# Make streaming API call
with self.client.messages.stream(**stream_params) as stream:
stream_context = (
self.client.beta.messages.stream(**stream_params, extra_body=extra_body)
if betas
else self.client.messages.stream(**stream_params)
)
with stream_context as stream:
response_id = None
for event in stream:
if hasattr(event, "message") and hasattr(event.message, "id"):
response_id = event.message.id
if hasattr(event, "delta") and hasattr(event.delta, "text"):
text_delta = event.delta.text
full_response += text_delta
@@ -712,7 +826,7 @@ class AnthropicCompletion(BaseLLM):
chunk=text_delta,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
response_id=response_id,
)
if event.type == "content_block_start":
@@ -739,7 +853,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
response_id=response_id,
)
elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta":
@@ -763,10 +877,10 @@ class AnthropicCompletion(BaseLLM):
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
response_id=response_id,
)
final_message: Message = stream.get_final_message()
final_message = stream.get_final_message()
thinking_blocks: list[ThinkingBlock] = []
if final_message.content:
@@ -781,25 +895,30 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage)
if response_model and final_message.content:
tool_uses = [
block
for block in final_message.content
if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
if _is_pydantic_model_class(response_model):
if use_native_structured_output:
self._emit_call_completed_event(
response=structured_json,
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return full_response
for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if final_message.content:
tool_uses = [
@@ -809,11 +928,9 @@ class AnthropicCompletion(BaseLLM):
]
if tool_uses:
# If no available_functions, return tool calls for executor to handle
if not available_functions:
return list(tool_uses)
# Handle tool use conversation flow internally
return self._handle_tool_use_conversation(
final_message,
tool_uses,
@@ -823,10 +940,8 @@ class AnthropicCompletion(BaseLLM):
from_agent,
)
# Apply stop words to full response
full_response = self._apply_stop_words(full_response)
# Emit completion event and return full response
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
@@ -884,7 +999,7 @@ class AnthropicCompletion(BaseLLM):
def _handle_tool_use_conversation(
self,
initial_response: Message,
initial_response: Message | BetaMessage,
tool_uses: list[ToolUseBlock],
params: dict[str, Any],
available_functions: dict[str, Any],
@@ -1002,22 +1117,40 @@ class AnthropicCompletion(BaseLLM):
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming async message completion."""
if response_model:
structured_tool = {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
uses_file_api = _contains_file_id_reference(params.get("messages", []))
betas: list[str] = []
use_native_structured_output = False
if uses_file_api:
betas.append(ANTHROPIC_FILES_API_BETA)
extra_body: dict[str, Any] | None = None
if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": schema,
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
try:
if uses_file_api:
params["betas"] = [ANTHROPIC_FILES_API_BETA]
response = await self.async_client.beta.messages.create(**params)
if betas:
params["betas"] = betas
response = await self.async_client.beta.messages.create(
**params, extra_body=extra_body
)
else:
response = await self.async_client.messages.create(**params)
@@ -1030,23 +1163,34 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(response)
self._track_token_usage_internal(usage)
if response_model and response.content:
tool_uses = [
block for block in response.content if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if _is_pydantic_model_class(response_model) and response.content:
if use_native_structured_output:
for block in response.content:
if isinstance(block, TextBlock):
structured_json = block.text
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
else:
for block in response.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if response.content:
tool_uses = [
@@ -1102,25 +1246,49 @@ class AnthropicCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
) -> str | Any:
"""Handle async streaming message completion."""
if response_model:
structured_tool = {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
betas: list[str] = []
use_native_structured_output = False
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
extra_body: dict[str, Any] | None = None
if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": schema,
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
full_response = ""
stream_params = {k: v for k, v in params.items() if k != "stream"}
if betas:
stream_params["betas"] = betas
current_tool_calls: dict[int, dict[str, Any]] = {}
async with self.async_client.messages.stream(**stream_params) as stream:
stream_context = (
self.async_client.beta.messages.stream(
**stream_params, extra_body=extra_body
)
if betas
else self.async_client.messages.stream(**stream_params)
)
async with stream_context as stream:
response_id = None
async for event in stream:
if hasattr(event, "message") and hasattr(event.message, "id"):
@@ -1133,7 +1301,7 @@ class AnthropicCompletion(BaseLLM):
chunk=text_delta,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
response_id=response_id,
)
if event.type == "content_block_start":
@@ -1160,7 +1328,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
response_id=response_id,
)
elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta":
@@ -1184,33 +1352,38 @@ class AnthropicCompletion(BaseLLM):
"index": block_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
response_id=response_id,
)
final_message: Message = await stream.get_final_message()
final_message = await stream.get_final_message()
usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage)
if response_model and final_message.content:
tool_uses = [
block
for block in final_message.content
if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
if _is_pydantic_model_class(response_model):
if use_native_structured_output:
self._emit_call_completed_event(
response=structured_json,
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return full_response
for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if final_message.content:
tool_uses = [
@@ -1220,7 +1393,6 @@ class AnthropicCompletion(BaseLLM):
]
if tool_uses:
# If no available_functions, return tool calls for executor to handle
if not available_functions:
return list(tool_uses)
@@ -1247,7 +1419,7 @@ class AnthropicCompletion(BaseLLM):
async def _ahandle_tool_use_conversation(
self,
initial_response: Message,
initial_response: Message | BetaMessage,
tool_uses: list[ToolUseBlock],
params: dict[str, Any],
available_functions: dict[str, Any],
@@ -1356,7 +1528,9 @@ class AnthropicCompletion(BaseLLM):
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
@staticmethod
def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
def _extract_anthropic_token_usage(
response: Message | BetaMessage,
) -> dict[str, Any]:
"""Extract token usage from Anthropic response."""
if hasattr(response, "usage") and response.usage:
usage = response.usage

View File

@@ -92,6 +92,7 @@ class AzureCompletion(BaseLLM):
stop: list[str] | None = None,
stream: bool = False,
interceptor: BaseInterceptor[Any, Any] | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
):
"""Initialize Azure AI Inference chat completion client.
@@ -111,6 +112,9 @@ class AzureCompletion(BaseLLM):
stop: Stop sequences
stream: Enable streaming responses
interceptor: HTTP interceptor (not yet supported for Azure).
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
Only works with OpenAI models deployed on Azure.
**kwargs: Additional parameters
"""
if interceptor is not None:
@@ -165,6 +169,7 @@ class AzureCompletion(BaseLLM):
self.presence_penalty = presence_penalty
self.max_tokens = max_tokens
self.stream = stream
self.response_format = response_format
self.is_openai_model = any(
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
@@ -298,6 +303,7 @@ class AzureCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
)
effective_response_model = response_model or self.response_format
# Format messages for Azure
formatted_messages = self._format_messages_for_azure(messages)
@@ -307,7 +313,7 @@ class AzureCompletion(BaseLLM):
# Prepare completion parameters
completion_params = self._prepare_completion_params(
formatted_messages, tools, response_model
formatted_messages, tools, effective_response_model
)
# Handle streaming vs non-streaming
@@ -317,7 +323,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
return self._handle_completion(
@@ -325,7 +331,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
except Exception as e:
@@ -364,11 +370,12 @@ class AzureCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
)
effective_response_model = response_model or self.response_format
formatted_messages = self._format_messages_for_azure(messages)
completion_params = self._prepare_completion_params(
formatted_messages, tools, response_model
formatted_messages, tools, effective_response_model
)
if self.stream:
@@ -377,7 +384,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
return await self._ahandle_completion(
@@ -385,7 +392,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
except Exception as e:
@@ -726,7 +733,7 @@ class AzureCompletion(BaseLLM):
"""
if update.choices:
choice = update.choices[0]
response_id = update.id if hasattr(update,"id") else None
response_id = update.id if hasattr(update, "id") else None
if choice.delta and choice.delta.content:
content_delta = choice.delta.content
full_response += content_delta
@@ -734,7 +741,7 @@ class AzureCompletion(BaseLLM):
chunk=content_delta,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
response_id=response_id,
)
if choice.delta and choice.delta.tool_calls:
@@ -769,7 +776,7 @@ class AzureCompletion(BaseLLM):
"index": idx,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
response_id=response_id,
)
return full_response

View File

@@ -172,6 +172,7 @@ class BedrockCompletion(BaseLLM):
additional_model_request_fields: dict[str, Any] | None = None,
additional_model_response_field_paths: list[str] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
) -> None:
"""Initialize AWS Bedrock completion client.
@@ -192,6 +193,8 @@ class BedrockCompletion(BaseLLM):
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
interceptor: HTTP interceptor (not yet supported for Bedrock).
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
**kwargs: Additional parameters
"""
if interceptor is not None:
@@ -248,6 +251,7 @@ class BedrockCompletion(BaseLLM):
self.top_k = top_k
self.stream = stream
self.stop_sequences = stop_sequences
self.response_format = response_format
# Store advanced features (optional)
self.guardrail_config = guardrail_config
@@ -299,6 +303,8 @@ class BedrockCompletion(BaseLLM):
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call AWS Bedrock Converse API."""
effective_response_model = response_model or self.response_format
try:
# Emit call started event
self._emit_call_started_event(
@@ -375,6 +381,7 @@ class BedrockCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
effective_response_model,
)
return self._handle_converse(
@@ -383,6 +390,7 @@ class BedrockCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
effective_response_model,
)
except Exception as e:
@@ -425,6 +433,8 @@ class BedrockCompletion(BaseLLM):
NotImplementedError: If aiobotocore is not installed.
LLMContextLengthExceededError: If context window is exceeded.
"""
effective_response_model = response_model or self.response_format
if not AIOBOTOCORE_AVAILABLE:
raise NotImplementedError(
"Async support for AWS Bedrock requires aiobotocore. "
@@ -494,11 +504,21 @@ class BedrockCompletion(BaseLLM):
if self.stream:
return await self._ahandle_streaming_converse(
formatted_messages, body, available_functions, from_task, from_agent
formatted_messages,
body,
available_functions,
from_task,
from_agent,
effective_response_model,
)
return await self._ahandle_converse(
formatted_messages, body, available_functions, from_task, from_agent
formatted_messages,
body,
available_functions,
from_task,
from_agent,
effective_response_model,
)
except Exception as e:
@@ -520,10 +540,29 @@ class BedrockCompletion(BaseLLM):
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming converse API call following AWS best practices."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
try:
# Validate messages format before API call
if not messages:
raise ValueError("Messages cannot be empty")
@@ -571,6 +610,21 @@ class BedrockCompletion(BaseLLM):
# If there are tool uses but no available_functions, return them for the executor to handle
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
if response_model and tool_uses:
for tool_use in tool_uses:
if tool_use.get("name") == "structured_output":
structured_data = tool_use.get("input", {})
result = response_model.model_validate(structured_data)
self._emit_call_completed_event(
response=result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages,
)
return result
if tool_uses and not available_functions:
self._emit_call_completed_event(
response=tool_uses,
@@ -717,8 +771,28 @@ class BedrockCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming converse API call with comprehensive event handling."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
full_response = ""
current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None
@@ -805,7 +879,7 @@ class BedrockCompletion(BaseLLM):
"index": tool_use_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
response_id=response_id,
)
elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream")
@@ -929,8 +1003,28 @@ class BedrockCompletion(BaseLLM):
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle async non-streaming converse API call."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
try:
if not messages:
raise ValueError("Messages cannot be empty")
@@ -976,6 +1070,21 @@ class BedrockCompletion(BaseLLM):
# If there are tool uses but no available_functions, return them for the executor to handle
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
if response_model and tool_uses:
for tool_use in tool_uses:
if tool_use.get("name") == "structured_output":
structured_data = tool_use.get("input", {})
result = response_model.model_validate(structured_data)
self._emit_call_completed_event(
response=result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages,
)
return result
if tool_uses and not available_functions:
self._emit_call_completed_event(
response=tool_uses,
@@ -1106,8 +1215,28 @@ class BedrockCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle async streaming converse API call."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
full_response = ""
current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None
@@ -1174,7 +1303,7 @@ class BedrockCompletion(BaseLLM):
chunk=text_chunk,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
response_id=response_id,
)
elif "toolUse" in delta and current_tool_use:
tool_input = delta["toolUse"].get("input", "")

View File

@@ -56,6 +56,7 @@ class GeminiCompletion(BaseLLM):
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None,
use_vertexai: bool | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
):
"""Initialize Google Gemini chat completion client.
@@ -86,6 +87,8 @@ class GeminiCompletion(BaseLLM):
- None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
When using Vertex AI with API key (Express mode), http_options with
api_version="v1" is automatically configured.
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
**kwargs: Additional parameters
"""
if interceptor is not None:
@@ -121,6 +124,7 @@ class GeminiCompletion(BaseLLM):
self.safety_settings = safety_settings or {}
self.stop_sequences = stop_sequences or []
self.tools: list[dict[str, Any]] | None = None
self.response_format = response_format
# Model-specific settings
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
@@ -292,6 +296,7 @@ class GeminiCompletion(BaseLLM):
from_agent=from_agent,
)
self.tools = tools
effective_response_model = response_model or self.response_format
formatted_content, system_instruction = self._format_messages_for_gemini(
messages
@@ -303,7 +308,7 @@ class GeminiCompletion(BaseLLM):
raise ValueError("LLM call blocked by before_llm_call hook")
config = self._prepare_generation_config(
system_instruction, tools, response_model
system_instruction, tools, effective_response_model
)
if self.stream:
@@ -313,7 +318,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
return self._handle_completion(
@@ -322,7 +327,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
except APIError as e:
@@ -374,13 +379,14 @@ class GeminiCompletion(BaseLLM):
from_agent=from_agent,
)
self.tools = tools
effective_response_model = response_model or self.response_format
formatted_content, system_instruction = self._format_messages_for_gemini(
messages
)
config = self._prepare_generation_config(
system_instruction, tools, response_model
system_instruction, tools, effective_response_model
)
if self.stream:
@@ -390,7 +396,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
return await self._ahandle_completion(
@@ -399,7 +405,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
except APIError as e:
@@ -570,10 +576,10 @@ class GeminiCompletion(BaseLLM):
types.Content(role="user", parts=[function_response_part])
)
elif role == "assistant" and message.get("tool_calls"):
parts: list[types.Part] = []
tool_parts: list[types.Part] = []
if text_content:
parts.append(types.Part.from_text(text=text_content))
tool_parts.append(types.Part.from_text(text=text_content))
tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
for tool_call in tool_calls:
@@ -592,11 +598,11 @@ class GeminiCompletion(BaseLLM):
else:
func_args = func_args_raw
parts.append(
tool_parts.append(
types.Part.from_function_call(name=func_name, args=func_args)
)
contents.append(types.Content(role="model", parts=parts))
contents.append(types.Content(role="model", parts=tool_parts))
else:
# Convert role for Gemini (assistant -> model)
gemini_role = "model" if role == "assistant" else "user"
@@ -790,7 +796,7 @@ class GeminiCompletion(BaseLLM):
Returns:
Tuple of (updated full_response, updated function_calls, updated usage_data)
"""
response_id=chunk.response_id if hasattr(chunk,"response_id") else None
response_id = chunk.response_id if hasattr(chunk, "response_id") else None
if chunk.usage_metadata:
usage_data = self._extract_token_usage(chunk)
@@ -800,7 +806,7 @@ class GeminiCompletion(BaseLLM):
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
response_id=response_id,
)
if chunk.candidates:
@@ -837,7 +843,7 @@ class GeminiCompletion(BaseLLM):
"index": call_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
response_id=response_id,
)
return full_response, function_calls, usage_data
@@ -972,7 +978,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
) -> str | Any:
"""Handle streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}
@@ -1050,7 +1056,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
) -> str | Any:
"""Handle async streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}

View File

@@ -18,7 +18,6 @@ if TYPE_CHECKING:
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
@@ -52,6 +51,9 @@ if TYPE_CHECKING:
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderSpec,
VertexAIProviderSpec,
@@ -163,7 +165,7 @@ def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunctio
@overload
def build_embedder_from_dict(
spec: VertexAIProviderSpec,
) -> GoogleVertexEmbeddingFunction: ...
) -> GoogleGenAIVertexEmbeddingFunction: ...
@overload
@@ -296,7 +298,9 @@ def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
def build_embedder(
spec: VertexAIProviderSpec,
) -> GoogleGenAIVertexEmbeddingFunction: ...
@overload

View File

@@ -1,5 +1,8 @@
"""Google embedding providers."""
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
from crewai.rag.embeddings.providers.google.generative_ai import (
GenerativeAiProvider,
)
@@ -18,6 +21,7 @@ __all__ = [
"GenerativeAiProvider",
"GenerativeAiProviderConfig",
"GenerativeAiProviderSpec",
"GoogleGenAIVertexEmbeddingFunction",
"VertexAIProvider",
"VertexAIProviderConfig",
"VertexAIProviderSpec",

View File

@@ -0,0 +1,237 @@
"""Google Vertex AI embedding function implementation.
This module supports both the new google-genai SDK and the deprecated
vertexai.language_models module for backwards compatibility.
The deprecated vertexai.language_models module will be removed after June 24, 2026.
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
"""
from typing import Any, ClassVar, cast
import warnings
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack
from crewai.rag.embeddings.providers.google.types import VertexAIProviderConfig
class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for Google Vertex AI with dual SDK support.
This class supports both:
- Legacy models (textembedding-gecko*) using the deprecated vertexai.language_models SDK
- New models (gemini-embedding-*, text-embedding-*) using the google-genai SDK
The SDK is automatically selected based on the model name. Legacy models will
emit a deprecation warning.
Supports two authentication modes:
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
2. API key: Set api_key for direct API access
Example:
# Using legacy model (will emit deprecation warning)
embedder = GoogleGenAIVertexEmbeddingFunction(
project_id="my-project",
region="us-central1",
model_name="textembedding-gecko"
)
# Using new model with google-genai SDK
embedder = GoogleGenAIVertexEmbeddingFunction(
project_id="my-project",
location="us-central1",
model_name="gemini-embedding-001"
)
# Using API key (new SDK only)
embedder = GoogleGenAIVertexEmbeddingFunction(
api_key="your-api-key",
model_name="gemini-embedding-001"
)
"""
# Models that use the legacy vertexai.language_models SDK
LEGACY_MODELS: ClassVar[set[str]] = {
"textembedding-gecko",
"textembedding-gecko@001",
"textembedding-gecko@002",
"textembedding-gecko@003",
"textembedding-gecko@latest",
"textembedding-gecko-multilingual",
"textembedding-gecko-multilingual@001",
"textembedding-gecko-multilingual@latest",
}
# Models that use the new google-genai SDK
GENAI_MODELS: ClassVar[set[str]] = {
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
}
def __init__(self, **kwargs: Unpack[VertexAIProviderConfig]) -> None:
"""Initialize Google Vertex AI embedding function.
Args:
**kwargs: Configuration parameters including:
- model_name: Model to use for embeddings (default: "textembedding-gecko")
- api_key: Optional API key for authentication (new SDK only)
- project_id: GCP project ID (for Vertex AI backend)
- location: GCP region (default: "us-central1")
- region: Deprecated alias for location
- task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT", new SDK only)
- output_dimensionality: Optional output embedding dimension (new SDK only)
"""
# Handle deprecated 'region' parameter (only if it has a value)
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
if region_value is not None:
warnings.warn(
"The 'region' parameter is deprecated, use 'location' instead. "
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
DeprecationWarning,
stacklevel=2,
)
if "location" not in kwargs or kwargs.get("location") is None:
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
self._config = kwargs
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
self._use_legacy = self._is_legacy_model(self._model_name)
if self._use_legacy:
self._init_legacy_client(**kwargs)
else:
self._init_genai_client(**kwargs)
def _is_legacy_model(self, model_name: str) -> bool:
"""Check if the model uses the legacy SDK."""
return model_name in self.LEGACY_MODELS or model_name.startswith(
"textembedding-gecko"
)
def _init_legacy_client(self, **kwargs: Any) -> None:
"""Initialize using the deprecated vertexai.language_models SDK."""
warnings.warn(
f"Model '{self._model_name}' uses the deprecated vertexai.language_models SDK "
"which will be removed after June 24, 2026. Consider migrating to newer models "
"like 'gemini-embedding-001'. "
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
DeprecationWarning,
stacklevel=3,
)
try:
import vertexai
from vertexai.language_models import TextEmbeddingModel
except ImportError as e:
raise ImportError(
"vertexai is required for legacy embedding models (textembedding-gecko*). "
"Install it with: pip install google-cloud-aiplatform"
) from e
project_id = kwargs.get("project_id")
location = str(kwargs.get("location", "us-central1"))
if not project_id:
raise ValueError(
"project_id is required for legacy models. "
"For API key authentication, use newer models like 'gemini-embedding-001'."
)
vertexai.init(project=str(project_id), location=location)
self._legacy_model = TextEmbeddingModel.from_pretrained(self._model_name)
def _init_genai_client(self, **kwargs: Any) -> None:
"""Initialize using the new google-genai SDK."""
try:
from google import genai
from google.genai.types import EmbedContentConfig
except ImportError as e:
raise ImportError(
"google-genai is required for Google Gen AI embeddings. "
"Install it with: uv add 'crewai[google-genai]'"
) from e
self._genai = genai
self._EmbedContentConfig = EmbedContentConfig
self._task_type = kwargs.get("task_type", "RETRIEVAL_DOCUMENT")
self._output_dimensionality = kwargs.get("output_dimensionality")
# Initialize client based on authentication mode
api_key = kwargs.get("api_key")
project_id = kwargs.get("project_id")
location: str = str(kwargs.get("location", "us-central1"))
if api_key:
self._client = genai.Client(api_key=api_key)
elif project_id:
self._client = genai.Client(
vertexai=True,
project=str(project_id),
location=location,
)
else:
raise ValueError(
"Either 'api_key' (for API key authentication) or 'project_id' "
"(for Vertex AI backend with ADC) must be provided."
)
@staticmethod
def name() -> str:
"""Return the name of the embedding function for ChromaDB compatibility."""
return "google-vertex"
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.
Args:
input: List of documents to embed.
Returns:
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]
if self._use_legacy:
return self._call_legacy(input)
return self._call_genai(input)
def _call_legacy(self, input: list[str]) -> Embeddings:
"""Generate embeddings using the legacy SDK."""
import numpy as np
embeddings_list = []
for text in input:
embedding_result = self._legacy_model.get_embeddings([text])
embeddings_list.append(
np.array(embedding_result[0].values, dtype=np.float32)
)
return cast(Embeddings, embeddings_list)
def _call_genai(self, input: list[str]) -> Embeddings:
"""Generate embeddings using the new google-genai SDK."""
# Build config for embed_content
config_kwargs: dict[str, Any] = {
"task_type": self._task_type,
}
if self._output_dimensionality is not None:
config_kwargs["output_dimensionality"] = self._output_dimensionality
config = self._EmbedContentConfig(**config_kwargs)
# Call the embedding API
response = self._client.models.embed_content(
model=self._model_name,
contents=input, # type: ignore[arg-type]
config=config,
)
# Extract embeddings from response
if response.embeddings is None:
raise ValueError("No embeddings returned from the API")
embeddings = [emb.values for emb in response.embeddings]
return cast(Embeddings, embeddings)

View File

@@ -34,12 +34,47 @@ class GenerativeAiProviderSpec(TypedDict):
class VertexAIProviderConfig(TypedDict, total=False):
"""Configuration for Vertex AI provider."""
"""Configuration for Vertex AI provider with dual SDK support.
Supports both legacy models (textembedding-gecko*) using the deprecated
vertexai.language_models SDK and new models using google-genai SDK.
Attributes:
api_key: Google API key (optional if using project_id with ADC). Only for new SDK models.
model_name: Embedding model name (default: "textembedding-gecko").
Legacy models: textembedding-gecko, textembedding-gecko@001, etc.
New models: gemini-embedding-001, text-embedding-005, text-multilingual-embedding-002
project_id: GCP project ID (required for Vertex AI backend and legacy models).
location: GCP region/location (default: "us-central1").
region: Deprecated alias for location (kept for backwards compatibility).
task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT"). Only for new SDK models.
output_dimensionality: Output embedding dimension (optional). Only for new SDK models.
"""
api_key: str
model_name: Annotated[str, "textembedding-gecko"]
project_id: Annotated[str, "cloud-large-language-models"]
region: Annotated[str, "us-central1"]
model_name: Annotated[
Literal[
# Legacy models (deprecated vertexai.language_models SDK)
"textembedding-gecko",
"textembedding-gecko@001",
"textembedding-gecko@002",
"textembedding-gecko@003",
"textembedding-gecko@latest",
"textembedding-gecko-multilingual",
"textembedding-gecko-multilingual@001",
"textembedding-gecko-multilingual@latest",
# New models (google-genai SDK)
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
],
"textembedding-gecko",
]
project_id: str
location: Annotated[str, "us-central1"]
region: Annotated[str, "us-central1"] # Deprecated alias for location
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
output_dimensionality: int
class VertexAIProviderSpec(TypedDict, total=False):

View File

@@ -1,46 +1,126 @@
"""Google Vertex AI embeddings provider."""
"""Google Vertex AI embeddings provider.
This module supports both the new google-genai SDK and the deprecated
vertexai.language_models module for backwards compatibility.
The SDK is automatically selected based on the model name:
- Legacy models (textembedding-gecko*) use vertexai.language_models (deprecated)
- New models (gemini-embedding-*, text-embedding-*) use google-genai
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
"""
from __future__ import annotations
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
"""Google Vertex AI embeddings provider."""
class VertexAIProvider(BaseEmbeddingsProvider[GoogleGenAIVertexEmbeddingFunction]):
"""Google Vertex AI embeddings provider with dual SDK support.
embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
default=GoogleVertexEmbeddingFunction,
description="Vertex AI embedding function class",
Supports both legacy models (textembedding-gecko*) using the deprecated
vertexai.language_models SDK and new models (gemini-embedding-*, text-embedding-*)
using the google-genai SDK.
The SDK is automatically selected based on the model name. Legacy models will
emit a deprecation warning.
Authentication modes:
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
2. API key: Set api_key for direct API access (new SDK models only)
Example:
# Legacy model (backwards compatible, will emit deprecation warning)
provider = VertexAIProvider(
project_id="my-project",
region="us-central1", # or location="us-central1"
model_name="textembedding-gecko"
)
# New model with Vertex AI backend
provider = VertexAIProvider(
project_id="my-project",
location="us-central1",
model_name="gemini-embedding-001"
)
# New model with API key
provider = VertexAIProvider(
api_key="your-api-key",
model_name="gemini-embedding-001"
)
"""
embedding_callable: type[GoogleGenAIVertexEmbeddingFunction] = Field(
default=GoogleGenAIVertexEmbeddingFunction,
description="Google Vertex AI embedding function class",
)
model_name: str = Field(
default="textembedding-gecko",
description="Model name to use for embeddings",
description=(
"Model name to use for embeddings. Legacy models (textembedding-gecko*) "
"use the deprecated SDK. New models (gemini-embedding-001, text-embedding-005) "
"use the google-genai SDK."
),
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
"GOOGLE_VERTEX_MODEL_NAME",
"model",
),
)
api_key: str = Field(
description="Google API key",
api_key: str | None = Field(
default=None,
description="Google API key (optional if using project_id with Application Default Credentials)",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY",
"GOOGLE_CLOUD_API_KEY",
"GOOGLE_API_KEY",
),
)
project_id: str = Field(
default="cloud-large-language-models",
description="GCP project ID",
project_id: str | None = Field(
default=None,
description="GCP project ID (required for Vertex AI backend and legacy models)",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
"GOOGLE_CLOUD_PROJECT",
),
)
region: str = Field(
location: str = Field(
default="us-central1",
description="GCP region",
description="GCP region/location",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
"EMBEDDINGS_GOOGLE_CLOUD_LOCATION",
"EMBEDDINGS_GOOGLE_CLOUD_REGION",
"GOOGLE_CLOUD_LOCATION",
"GOOGLE_CLOUD_REGION",
),
)
region: str | None = Field(
default=None,
description="Deprecated: Use 'location' instead. GCP region (kept for backwards compatibility)",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_REGION",
"GOOGLE_VERTEX_REGION",
),
)
task_type: str = Field(
default="RETRIEVAL_DOCUMENT",
description="Task type for embeddings (e.g., RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY). Only used with new SDK models.",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_TASK_TYPE",
"GOOGLE_VERTEX_TASK_TYPE",
),
)
output_dimensionality: int | None = Field(
default=None,
description="Output embedding dimensionality (optional). Only used with new SDK models.",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
"GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
),
)

View File

@@ -5,17 +5,29 @@ from __future__ import annotations
import asyncio
from collections.abc import Coroutine
import concurrent.futures
import logging
from typing import TYPE_CHECKING, TypeVar
from uuid import UUID
from aiocache import Cache # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
if TYPE_CHECKING:
from aiocache import Cache
from crewai_files import FileInput
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
logger = logging.getLogger(__name__)
_file_store: Cache | None = None
try:
from aiocache import Cache
from aiocache.serializers import PickleSerializer
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
except ImportError:
logger.debug(
"aiocache is not installed. File store features will be disabled. "
"Install with: uv add aiocache"
)
T = TypeVar("T")
@@ -59,6 +71,8 @@ async def astore_files(
files: Dictionary mapping names to file inputs.
ttl: Time-to-live in seconds.
"""
if _file_store is None:
return
await _file_store.set(f"{_CREW_PREFIX}{execution_id}", files, ttl=ttl)
@@ -71,6 +85,8 @@ async def aget_files(execution_id: UUID) -> dict[str, FileInput] | None:
Returns:
Dictionary of files or None if not found.
"""
if _file_store is None:
return None
result: dict[str, FileInput] | None = await _file_store.get(
f"{_CREW_PREFIX}{execution_id}"
)
@@ -83,6 +99,8 @@ async def aclear_files(execution_id: UUID) -> None:
Args:
execution_id: Unique identifier for the crew execution.
"""
if _file_store is None:
return
await _file_store.delete(f"{_CREW_PREFIX}{execution_id}")
@@ -98,6 +116,8 @@ async def astore_task_files(
files: Dictionary mapping names to file inputs.
ttl: Time-to-live in seconds.
"""
if _file_store is None:
return
await _file_store.set(f"{_TASK_PREFIX}{task_id}", files, ttl=ttl)
@@ -110,6 +130,8 @@ async def aget_task_files(task_id: UUID) -> dict[str, FileInput] | None:
Returns:
Dictionary of files or None if not found.
"""
if _file_store is None:
return None
result: dict[str, FileInput] | None = await _file_store.get(
f"{_TASK_PREFIX}{task_id}"
)
@@ -122,6 +144,8 @@ async def aclear_task_files(task_id: UUID) -> None:
Args:
task_id: Unique identifier for the task.
"""
if _file_store is None:
return
await _file_store.delete(f"{_TASK_PREFIX}{task_id}")

View File

@@ -1,6 +1,8 @@
interactions:
- request:
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Returns structured data according to the schema","input_schema":{"description":"Response model for greeting test.","properties":{"greeting":{"title":"Greeting","type":"string"},"language":{"title":"Language","type":"string"}},"required":["greeting","language"],"title":"GreetingResponse","type":"object"}}]}'
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Output
the structured response","input_schema":{"type":"object","description":"Response
model for greeting test.","title":"GreetingResponse","properties":{"greeting":{"type":"string","title":"Greeting"},"language":{"type":"string","title":"Language"}},"additionalProperties":false,"required":["greeting","language"]}}]}'
headers:
User-Agent:
- X-USER-AGENT-XXX
@@ -13,7 +15,7 @@ interactions:
connection:
- keep-alive
content-length:
- '539'
- '551'
content-type:
- application/json
host:
@@ -29,7 +31,7 @@ interactions:
x-stainless-os:
- X-STAINLESS-OS-XXX
x-stainless-package-version:
- 0.75.0
- 0.76.0
x-stainless-retry-count:
- '0'
x-stainless-runtime:
@@ -42,7 +44,7 @@ interactions:
uri: https://api.anthropic.com/v1/messages
response:
body:
string: '{"model":"claude-sonnet-4-20250514","id":"msg_01XjvX2nCho1knuucbwwgCpw","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_019rfPRSDmBb7CyCTdGMv5rK","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":432,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
string: '{"model":"claude-sonnet-4-20250514","id":"msg_01CKTyVmak15L5oQ36mv4sL9","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_0174BYmn6xiSnUwVhFD8S7EW","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":436,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
headers:
CF-RAY:
- CF-RAY-XXX
@@ -51,7 +53,7 @@ interactions:
Content-Type:
- application/json
Date:
- Mon, 01 Dec 2025 11:19:38 GMT
- Mon, 26 Jan 2026 14:59:34 GMT
Server:
- cloudflare
Transfer-Encoding:
@@ -82,12 +84,10 @@ interactions:
- DYNAMIC
request-id:
- REQUEST-ID-XXX
retry-after:
- '24'
strict-transport-security:
- STS-XXX
x-envoy-upstream-service-time:
- '2101'
- '968'
status:
code: 200
message: OK

View File

@@ -272,3 +272,100 @@ class TestEmbeddingFactory:
mock_build_from_provider.assert_called_once_with(mock_provider)
assert result == mock_embedding_function
mock_import.assert_not_called()
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_genai_model(self, mock_import):
"""Test routing to Google Vertex provider with new genai model."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"api_key": "test-google-api-key",
"model_name": "gemini-embedding-001",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-google-api-key"
assert call_kwargs["model_name"] == "gemini-embedding-001"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_legacy_model(self, mock_import):
"""Test routing to Google Vertex provider with legacy textembedding-gecko model."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"project_id": "my-gcp-project",
"region": "us-central1",
"model_name": "textembedding-gecko",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["project_id"] == "my-gcp-project"
assert call_kwargs["region"] == "us-central1"
assert call_kwargs["model_name"] == "textembedding-gecko"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_location(self, mock_import):
"""Test routing to Google Vertex provider with location parameter."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"project_id": "my-gcp-project",
"location": "europe-west1",
"model_name": "gemini-embedding-001",
"task_type": "RETRIEVAL_DOCUMENT",
"output_dimensionality": 768,
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["project_id"] == "my-gcp-project"
assert call_kwargs["location"] == "europe-west1"
assert call_kwargs["model_name"] == "gemini-embedding-001"
assert call_kwargs["task_type"] == "RETRIEVAL_DOCUMENT"
assert call_kwargs["output_dimensionality"] == 768

View File

@@ -0,0 +1,176 @@
"""Integration tests for Google Vertex embeddings with Crew memory.
These tests make real API calls and use VCR to record/replay responses.
"""
import os
import threading
from collections import defaultdict
from unittest.mock import patch
import pytest
from crewai import Agent, Crew, Task
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
@pytest.fixture(autouse=True)
def setup_vertex_ai_env():
"""Set up environment for Vertex AI tests.
Sets GOOGLE_GENAI_USE_VERTEXAI=true to ensure the SDK uses the Vertex AI
backend (aiplatform.googleapis.com) which matches the VCR cassettes.
Also mocks GOOGLE_API_KEY if not already set.
"""
env_updates = {"GOOGLE_GENAI_USE_VERTEXAI": "true"}
# Add a mock API key if none exists
if "GOOGLE_API_KEY" not in os.environ and "GEMINI_API_KEY" not in os.environ:
env_updates["GOOGLE_API_KEY"] = "test-key"
with patch.dict(os.environ, env_updates):
yield
@pytest.fixture
def google_vertex_embedder_config():
"""Fixture providing Google Vertex embedder configuration."""
return {
"provider": "google-vertex",
"config": {
"api_key": os.getenv("GOOGLE_API_KEY", "test-key"),
"model_name": "gemini-embedding-001",
},
}
@pytest.fixture
def simple_agent():
"""Fixture providing a simple test agent."""
return Agent(
role="Research Assistant",
goal="Help with research tasks",
backstory="You are a helpful research assistant.",
verbose=False,
)
@pytest.fixture
def simple_task(simple_agent):
"""Fixture providing a simple test task."""
return Task(
description="Summarize the key points about artificial intelligence in one sentence.",
expected_output="A one sentence summary about AI.",
agent=simple_agent,
)
@pytest.mark.vcr()
@pytest.mark.timeout(120) # Longer timeout for VCR recording
def test_crew_memory_with_google_vertex_embedder(
google_vertex_embedder_config, simple_agent, simple_task
) -> None:
"""Test that Crew with memory=True works with google-vertex embedder and memory is used."""
# Track memory events
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
crew = Crew(
agents=[simple_agent],
tasks=[simple_task],
memory=True,
embedder=google_vertex_embedder_config,
verbose=False,
)
result = crew.kickoff()
assert result is not None
assert result.raw is not None
assert len(result.raw) > 0
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=10,
)
assert success, "Timeout waiting for memory save events - memory may not be working"
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
assert len(events["MemorySaveCompletedEvent"]) >= 1, "Memory save completed events"
@pytest.mark.vcr()
@pytest.mark.timeout(120)
def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) -> None:
"""Test Crew memory with Google Vertex using project_id authentication."""
project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
if not project_id:
pytest.skip("GOOGLE_CLOUD_PROJECT environment variable not set")
# Track memory events
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
embedder_config = {
"provider": "google-vertex",
"config": {
"project_id": project_id,
"location": "us-central1",
"model_name": "gemini-embedding-001",
},
}
crew = Crew(
agents=[simple_agent],
tasks=[simple_task],
memory=True,
embedder=embedder_config,
verbose=False,
)
result = crew.kickoff()
# Verify basic result
assert result is not None
assert result.raw is not None
# Wait for memory save events
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=10,
)
# Verify memory was actually used
assert success, "Timeout waiting for memory save events - memory may not be working"
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
assert len(events["MemorySaveCompletedEvent"]) >= 1, "No memory save completed events"

View File

@@ -1,3 +1,3 @@
"""CrewAI development tools."""
__version__ = "1.8.1"
__version__ = "1.9.0"

8
uv.lock generated
View File

@@ -310,7 +310,7 @@ wheels = [
[[package]]
name = "anthropic"
version = "0.71.1"
version = "0.73.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -322,9 +322,9 @@ dependencies = [
{ name = "sniffio" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/05/4b/19620875841f692fdc35eb58bf0201c8ad8c47b8443fecbf1b225312175b/anthropic-0.71.1.tar.gz", hash = "sha256:a77d156d3e7d318b84681b59823b2dee48a8ac508a3e54e49f0ab0d074e4b0da", size = 493294, upload-time = "2025-10-28T17:28:42.213Z" }
sdist = { url = "https://files.pythonhosted.org/packages/f0/07/f550112c3f5299d02f06580577f602e8a112b1988ad7c98ac1a8f7292d7e/anthropic-0.73.0.tar.gz", hash = "sha256:30f0d7d86390165f86af6ca7c3041f8720bb2e1b0e12a44525c8edfdbd2c5239", size = 425168, upload-time = "2025-11-14T18:47:52.635Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4b/68/b2f988b13325f9ac9921b1e87f0b7994468014e1b5bd3bdbd2472f5baf45/anthropic-0.71.1-py3-none-any.whl", hash = "sha256:6ca6c579f0899a445faeeed9c0eb97aa4bdb751196262f9ccc96edfc0bb12679", size = 355020, upload-time = "2025-10-28T17:28:40.653Z" },
{ url = "https://files.pythonhosted.org/packages/15/b1/5d4d3f649e151e58dc938cf19c4d0cd19fca9a986879f30fea08a7b17138/anthropic-0.73.0-py3-none-any.whl", hash = "sha256:0d56cd8b3ca3fea9c9b5162868bdfd053fbc189b8b56d4290bd2d427b56db769", size = 367839, upload-time = "2025-11-14T18:47:51.195Z" },
]
[[package]]
@@ -1276,7 +1276,7 @@ requires-dist = [
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
{ name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
{ name = "aiosqlite", specifier = "~=0.21.0" },
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.71.0" },
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
{ name = "appdirs", specifier = "~=1.4.4" },
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
{ name = "boto3", marker = "extra == 'aws'", specifier = "~=1.40.38" },