Compare commits

...

4 Commits

Author SHA1 Message Date
Lorenze Jay
6b926b90d0 chore: update version to 1.9.0 across all relevant files (#4284)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
- Bumped the version number to 1.9.0 in pyproject.toml files and __init__.py files across the CrewAI library and its tools.
- Updated dependencies to use the new version of crewai-tools (1.9.0) for improved functionality and compatibility.
- Ensured consistency in versioning across the codebase to reflect the latest updates.
2026-01-26 16:36:35 -08:00
Lorenze Jay
fc84daadbb fix: enhance file store with fallback memory cache when aiocache is n… (#4283)
* fix: enhance file store with fallback memory cache when aiocache is not installed

- Added a simple in-memory cache implementation to serve as a fallback when the aiocache library is unavailable.
- Improved error handling for the aiocache import, ensuring that the application can still function without it.
- This change enhances the robustness of the file store utility by providing a reliable caching mechanism in various environments.

* drop fallback

* Potential fix for pull request finding 'Unused global variable'

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
2026-01-26 15:12:34 -08:00
Lorenze Jay
58b866a83d Lorenze/supporting vertex embeddings (#4282)
* feat: introduce GoogleGenAIVertexEmbeddingFunction for dual SDK support

- Added a new embedding function to support both the legacy vertexai.language_models SDK and the new google-genai SDK for Google Vertex AI.
- Updated factory methods to route to the new embedding function.
- Enhanced VertexAIProvider and related configurations to accommodate the new model options.
- Added integration tests for Google Vertex embeddings with Crew memory, ensuring compatibility and functionality with both authentication methods.

This update improves the flexibility and compatibility of Google Vertex AI embeddings within the CrewAI framework.

* fix test count

* rm comment

* regen cassettes

* regen

* drop variable from .envtest

* dreict to relevant trest only
2026-01-26 14:55:03 -08:00
Greyson LaLonde
9797567342 feat: add structured outputs and response_format support across providers (#4280)
* feat: add response_format parameter to Azure and Gemini providers

* feat: add structured outputs support to Bedrock and Anthropic providers

* chore: bump anthropic dep

* fix: use beta structured output for new models
2026-01-26 11:03:33 -08:00
25 changed files with 14742 additions and 223 deletions

View File

@@ -401,23 +401,58 @@ crew = Crew(
### Vertex AI Embeddings ### Vertex AI Embeddings
For Google Cloud users with Vertex AI access. For Google Cloud users with Vertex AI access. Supports both legacy and new embedding models with automatic SDK selection.
<Note>
**Deprecation Notice:** Legacy models (`textembedding-gecko*`) use the deprecated `vertexai.language_models` SDK which will be removed after June 24, 2026. Consider migrating to newer models like `gemini-embedding-001`. See the [Google migration guide](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk) for details.
</Note>
```python ```python
# Recommended: Using new models with google-genai SDK
crew = Crew( crew = Crew(
memory=True, memory=True,
embedder={ embedder={
"provider": "vertexai", "provider": "google-vertex",
"config": { "config": {
"project_id": "your-gcp-project-id", "project_id": "your-gcp-project-id",
"region": "us-central1", # or your preferred region "location": "us-central1",
"api_key": "your-service-account-key", "model_name": "gemini-embedding-001", # or "text-embedding-005", "text-multilingual-embedding-002"
"model_name": "textembedding-gecko" "task_type": "RETRIEVAL_DOCUMENT", # Optional
"output_dimensionality": 768 # Optional
}
}
)
# Using API key authentication (Exp)
crew = Crew(
memory=True,
embedder={
"provider": "google-vertex",
"config": {
"api_key": "your-google-api-key",
"model_name": "gemini-embedding-001"
}
}
)
# Legacy models (backwards compatible, emits deprecation warning)
crew = Crew(
memory=True,
embedder={
"provider": "google-vertex",
"config": {
"project_id": "your-gcp-project-id",
"region": "us-central1", # or "location" (region is deprecated)
"model_name": "textembedding-gecko" # Legacy model
} }
} }
) )
``` ```
**Available models:**
- **New SDK models** (recommended): `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002`
- **Legacy models** (deprecated): `textembedding-gecko`, `textembedding-gecko@001`, `textembedding-gecko-multilingual`
### Ollama Embeddings (Local) ### Ollama Embeddings (Local)
Run embeddings locally for privacy and cost savings. Run embeddings locally for privacy and cost savings.

View File

@@ -152,4 +152,4 @@ __all__ = [
"wrap_file_source", "wrap_file_source",
] ]
__version__ = "1.8.1" __version__ = "1.9.0"

View File

@@ -12,7 +12,7 @@ dependencies = [
"pytube~=15.0.0", "pytube~=15.0.0",
"requests~=2.32.5", "requests~=2.32.5",
"docker~=7.1.0", "docker~=7.1.0",
"crewai==1.8.1", "crewai==1.9.0",
"lancedb~=0.5.4", "lancedb~=0.5.4",
"tiktoken~=0.8.0", "tiktoken~=0.8.0",
"beautifulsoup4~=4.13.4", "beautifulsoup4~=4.13.4",

View File

@@ -291,4 +291,4 @@ __all__ = [
"ZapierActionTools", "ZapierActionTools",
] ]
__version__ = "1.8.1" __version__ = "1.9.0"

View File

@@ -49,7 +49,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies] [project.optional-dependencies]
tools = [ tools = [
"crewai-tools==1.8.1", "crewai-tools==1.9.0",
] ]
embeddings = [ embeddings = [
"tiktoken~=0.8.0" "tiktoken~=0.8.0"
@@ -90,7 +90,7 @@ azure-ai-inference = [
"azure-ai-inference~=1.0.0b9", "azure-ai-inference~=1.0.0b9",
] ]
anthropic = [ anthropic = [
"anthropic~=0.71.0", "anthropic~=0.73.0",
] ]
a2a = [ a2a = [
"a2a-sdk~=0.3.10", "a2a-sdk~=0.3.10",

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings() _suppress_pydantic_deprecation_warnings()
__version__ = "1.8.1" __version__ = "1.9.0"
_telemetry_submitted = False _telemetry_submitted = False

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }] authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14" requires-python = ">=3.10,<3.14"
dependencies = [ dependencies = [
"crewai[tools]==1.8.1" "crewai[tools]==1.9.0"
] ]
[project.scripts] [project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }] authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14" requires-python = ">=3.10,<3.14"
dependencies = [ dependencies = [
"crewai[tools]==1.8.1" "crewai[tools]==1.9.0"
] ]
[project.scripts] [project.scripts]

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
import json import json
import logging import logging
import os import os
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast
from anthropic.types import ThinkingBlock
from pydantic import BaseModel from pydantic import BaseModel
from crewai.events.types.llm_events import LLMCallType from crewai.events.types.llm_events import LLMCallType
@@ -22,8 +21,9 @@ if TYPE_CHECKING:
from crewai.llms.hooks.base import BaseInterceptor from crewai.llms.hooks.base import BaseInterceptor
try: try:
from anthropic import Anthropic, AsyncAnthropic from anthropic import Anthropic, AsyncAnthropic, transform_schema
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
from anthropic.types.beta import BetaMessage
import httpx import httpx
except ImportError: except ImportError:
raise ImportError( raise ImportError(
@@ -31,7 +31,62 @@ except ImportError:
) from None ) from None
ANTHROPIC_FILES_API_BETA = "files-api-2025-04-14" ANTHROPIC_FILES_API_BETA: Final = "files-api-2025-04-14"
ANTHROPIC_STRUCTURED_OUTPUTS_BETA: Final = "structured-outputs-2025-11-13"
NATIVE_STRUCTURED_OUTPUT_MODELS: Final[
tuple[
Literal["claude-sonnet-4-5"],
Literal["claude-sonnet-4.5"],
Literal["claude-opus-4-5"],
Literal["claude-opus-4.5"],
Literal["claude-opus-4-1"],
Literal["claude-opus-4.1"],
Literal["claude-haiku-4-5"],
Literal["claude-haiku-4.5"],
]
] = (
"claude-sonnet-4-5",
"claude-sonnet-4.5",
"claude-opus-4-5",
"claude-opus-4.5",
"claude-opus-4-1",
"claude-opus-4.1",
"claude-haiku-4-5",
"claude-haiku-4.5",
)
def _supports_native_structured_outputs(model: str) -> bool:
"""Check if the model supports native structured outputs.
Native structured outputs are only available for Claude 4.5 models
(Sonnet 4.5, Opus 4.5, Opus 4.1, Haiku 4.5).
Other models require the tool-based fallback approach.
Args:
model: The model name/identifier.
Returns:
True if the model supports native structured outputs.
"""
model_lower = model.lower()
return any(prefix in model_lower for prefix in NATIVE_STRUCTURED_OUTPUT_MODELS)
def _is_pydantic_model_class(obj: Any) -> TypeGuard[type[BaseModel]]:
"""Check if an object is a Pydantic model class.
This distinguishes between Pydantic model classes that support structured
outputs (have model_json_schema) and plain dicts like {"type": "json_object"}.
Args:
obj: The object to check.
Returns:
True if obj is a Pydantic model class.
"""
return isinstance(obj, type) and issubclass(obj, BaseModel)
def _contains_file_id_reference(messages: list[dict[str, Any]]) -> bool: def _contains_file_id_reference(messages: list[dict[str, Any]]) -> bool:
@@ -84,6 +139,7 @@ class AnthropicCompletion(BaseLLM):
client_params: dict[str, Any] | None = None, client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
thinking: AnthropicThinkingConfig | None = None, thinking: AnthropicThinkingConfig | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any, **kwargs: Any,
): ):
"""Initialize Anthropic chat completion client. """Initialize Anthropic chat completion client.
@@ -101,6 +157,8 @@ class AnthropicCompletion(BaseLLM):
stream: Enable streaming responses stream: Enable streaming responses
client_params: Additional parameters for the Anthropic client client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level. interceptor: HTTP interceptor for modifying requests/responses at transport level.
response_format: Pydantic model for structured output. When provided, responses
will be validated against this model schema.
**kwargs: Additional parameters **kwargs: Additional parameters
""" """
super().__init__( super().__init__(
@@ -131,6 +189,7 @@ class AnthropicCompletion(BaseLLM):
self.stop_sequences = stop_sequences or [] self.stop_sequences = stop_sequences or []
self.thinking = thinking self.thinking = thinking
self.previous_thinking_blocks: list[ThinkingBlock] = [] self.previous_thinking_blocks: list[ThinkingBlock] = []
self.response_format = response_format
# Model-specific settings # Model-specific settings
self.is_claude_3 = "claude-3" in model.lower() self.is_claude_3 = "claude-3" in model.lower()
self.supports_tools = True self.supports_tools = True
@@ -231,6 +290,8 @@ class AnthropicCompletion(BaseLLM):
formatted_messages, system_message, tools formatted_messages, system_message, tools
) )
effective_response_model = response_model or self.response_format
# Handle streaming vs non-streaming # Handle streaming vs non-streaming
if self.stream: if self.stream:
return self._handle_streaming_completion( return self._handle_streaming_completion(
@@ -238,7 +299,7 @@ class AnthropicCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
return self._handle_completion( return self._handle_completion(
@@ -246,7 +307,7 @@ class AnthropicCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
except Exception as e: except Exception as e:
@@ -298,13 +359,15 @@ class AnthropicCompletion(BaseLLM):
formatted_messages, system_message, tools formatted_messages, system_message, tools
) )
effective_response_model = response_model or self.response_format
if self.stream: if self.stream:
return await self._ahandle_streaming_completion( return await self._ahandle_streaming_completion(
completion_params, completion_params,
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
return await self._ahandle_completion( return await self._ahandle_completion(
@@ -312,7 +375,7 @@ class AnthropicCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
except Exception as e: except Exception as e:
@@ -565,22 +628,40 @@ class AnthropicCompletion(BaseLLM):
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str | Any: ) -> str | Any:
"""Handle non-streaming message completion.""" """Handle non-streaming message completion."""
if response_model: uses_file_api = _contains_file_id_reference(params.get("messages", []))
betas: list[str] = []
use_native_structured_output = False
if uses_file_api:
betas.append(ANTHROPIC_FILES_API_BETA)
extra_body: dict[str, Any] | None = None
if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = { structured_tool = {
"name": "structured_output", "name": "structured_output",
"description": "Returns structured data according to the schema", "description": "Output the structured response",
"input_schema": response_model.model_json_schema(), "input_schema": schema,
} }
params["tools"] = [structured_tool] params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"} params["tool_choice"] = {"type": "tool", "name": "structured_output"}
uses_file_api = _contains_file_id_reference(params.get("messages", []))
try: try:
if uses_file_api: if betas:
params["betas"] = [ANTHROPIC_FILES_API_BETA] params["betas"] = betas
response = self.client.beta.messages.create(**params) response = self.client.beta.messages.create(
**params, extra_body=extra_body
)
else: else:
response = self.client.messages.create(**params) response = self.client.messages.create(**params)
@@ -593,13 +674,26 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(response) usage = self._extract_anthropic_token_usage(response)
self._track_token_usage_internal(usage) self._track_token_usage_internal(usage)
if response_model and response.content: if _is_pydantic_model_class(response_model) and response.content:
tool_uses = [ if use_native_structured_output:
block for block in response.content if isinstance(block, ToolUseBlock) for block in response.content:
] if isinstance(block, TextBlock):
if tool_uses and tool_uses[0].name == "structured_output": structured_json = block.text
structured_data = tool_uses[0].input self._emit_call_completed_event(
structured_json = json.dumps(structured_data) response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
else:
for block in response.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event( self._emit_call_completed_event(
response=structured_json, response=structured_json,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
@@ -607,7 +701,6 @@ class AnthropicCompletion(BaseLLM):
from_agent=from_agent, from_agent=from_agent,
messages=params["messages"], messages=params["messages"],
) )
return structured_json return structured_json
# Check if Claude wants to use tools # Check if Claude wants to use tools
@@ -678,15 +771,29 @@ class AnthropicCompletion(BaseLLM):
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str: ) -> str | Any:
"""Handle streaming message completion.""" """Handle streaming message completion."""
if response_model: betas: list[str] = []
use_native_structured_output = False
extra_body: dict[str, Any] | None = None
if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = { structured_tool = {
"name": "structured_output", "name": "structured_output",
"description": "Returns structured data according to the schema", "description": "Output the structured response",
"input_schema": response_model.model_json_schema(), "input_schema": schema,
} }
params["tools"] = [structured_tool] params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"} params["tool_choice"] = {"type": "tool", "name": "structured_output"}
@@ -696,10 +803,17 @@ class AnthropicCompletion(BaseLLM):
# (the SDK sets it internally) # (the SDK sets it internally)
stream_params = {k: v for k, v in params.items() if k != "stream"} stream_params = {k: v for k, v in params.items() if k != "stream"}
if betas:
stream_params["betas"] = betas
current_tool_calls: dict[int, dict[str, Any]] = {} current_tool_calls: dict[int, dict[str, Any]] = {}
# Make streaming API call stream_context = (
with self.client.messages.stream(**stream_params) as stream: self.client.beta.messages.stream(**stream_params, extra_body=extra_body)
if betas
else self.client.messages.stream(**stream_params)
)
with stream_context as stream:
response_id = None response_id = None
for event in stream: for event in stream:
if hasattr(event, "message") and hasattr(event.message, "id"): if hasattr(event, "message") and hasattr(event.message, "id"):
@@ -712,7 +826,7 @@ class AnthropicCompletion(BaseLLM):
chunk=text_delta, chunk=text_delta,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
response_id=response_id response_id=response_id,
) )
if event.type == "content_block_start": if event.type == "content_block_start":
@@ -739,7 +853,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
elif event.type == "content_block_delta": elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta": if event.delta.type == "input_json_delta":
@@ -763,10 +877,10 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
final_message: Message = stream.get_final_message() final_message = stream.get_final_message()
thinking_blocks: list[ThinkingBlock] = [] thinking_blocks: list[ThinkingBlock] = []
if final_message.content: if final_message.content:
@@ -781,16 +895,22 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(final_message) usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage) self._track_token_usage_internal(usage)
if response_model and final_message.content: if _is_pydantic_model_class(response_model):
tool_uses = [ if use_native_structured_output:
block self._emit_call_completed_event(
for block in final_message.content response=full_response,
if isinstance(block, ToolUseBlock) call_type=LLMCallType.LLM_CALL,
] from_task=from_task,
if tool_uses and tool_uses[0].name == "structured_output": from_agent=from_agent,
structured_data = tool_uses[0].input messages=params["messages"],
structured_json = json.dumps(structured_data) )
return full_response
for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event( self._emit_call_completed_event(
response=structured_json, response=structured_json,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
@@ -798,7 +918,6 @@ class AnthropicCompletion(BaseLLM):
from_agent=from_agent, from_agent=from_agent,
messages=params["messages"], messages=params["messages"],
) )
return structured_json return structured_json
if final_message.content: if final_message.content:
@@ -809,11 +928,9 @@ class AnthropicCompletion(BaseLLM):
] ]
if tool_uses: if tool_uses:
# If no available_functions, return tool calls for executor to handle
if not available_functions: if not available_functions:
return list(tool_uses) return list(tool_uses)
# Handle tool use conversation flow internally
return self._handle_tool_use_conversation( return self._handle_tool_use_conversation(
final_message, final_message,
tool_uses, tool_uses,
@@ -823,10 +940,8 @@ class AnthropicCompletion(BaseLLM):
from_agent, from_agent,
) )
# Apply stop words to full response
full_response = self._apply_stop_words(full_response) full_response = self._apply_stop_words(full_response)
# Emit completion event and return full response
self._emit_call_completed_event( self._emit_call_completed_event(
response=full_response, response=full_response,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
@@ -884,7 +999,7 @@ class AnthropicCompletion(BaseLLM):
def _handle_tool_use_conversation( def _handle_tool_use_conversation(
self, self,
initial_response: Message, initial_response: Message | BetaMessage,
tool_uses: list[ToolUseBlock], tool_uses: list[ToolUseBlock],
params: dict[str, Any], params: dict[str, Any],
available_functions: dict[str, Any], available_functions: dict[str, Any],
@@ -1002,22 +1117,40 @@ class AnthropicCompletion(BaseLLM):
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str | Any: ) -> str | Any:
"""Handle non-streaming async message completion.""" """Handle non-streaming async message completion."""
if response_model: uses_file_api = _contains_file_id_reference(params.get("messages", []))
betas: list[str] = []
use_native_structured_output = False
if uses_file_api:
betas.append(ANTHROPIC_FILES_API_BETA)
extra_body: dict[str, Any] | None = None
if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = { structured_tool = {
"name": "structured_output", "name": "structured_output",
"description": "Returns structured data according to the schema", "description": "Output the structured response",
"input_schema": response_model.model_json_schema(), "input_schema": schema,
} }
params["tools"] = [structured_tool] params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"} params["tool_choice"] = {"type": "tool", "name": "structured_output"}
uses_file_api = _contains_file_id_reference(params.get("messages", []))
try: try:
if uses_file_api: if betas:
params["betas"] = [ANTHROPIC_FILES_API_BETA] params["betas"] = betas
response = await self.async_client.beta.messages.create(**params) response = await self.async_client.beta.messages.create(
**params, extra_body=extra_body
)
else: else:
response = await self.async_client.messages.create(**params) response = await self.async_client.messages.create(**params)
@@ -1030,14 +1163,26 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(response) usage = self._extract_anthropic_token_usage(response)
self._track_token_usage_internal(usage) self._track_token_usage_internal(usage)
if response_model and response.content: if _is_pydantic_model_class(response_model) and response.content:
tool_uses = [ if use_native_structured_output:
block for block in response.content if isinstance(block, ToolUseBlock) for block in response.content:
] if isinstance(block, TextBlock):
if tool_uses and tool_uses[0].name == "structured_output": structured_json = block.text
structured_data = tool_uses[0].input self._emit_call_completed_event(
structured_json = json.dumps(structured_data) response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
else:
for block in response.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event( self._emit_call_completed_event(
response=structured_json, response=structured_json,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
@@ -1045,7 +1190,6 @@ class AnthropicCompletion(BaseLLM):
from_agent=from_agent, from_agent=from_agent,
messages=params["messages"], messages=params["messages"],
) )
return structured_json return structured_json
if response.content: if response.content:
@@ -1102,15 +1246,29 @@ class AnthropicCompletion(BaseLLM):
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str: ) -> str | Any:
"""Handle async streaming message completion.""" """Handle async streaming message completion."""
if response_model: betas: list[str] = []
use_native_structured_output = False
extra_body: dict[str, Any] | None = None
if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = { structured_tool = {
"name": "structured_output", "name": "structured_output",
"description": "Returns structured data according to the schema", "description": "Output the structured response",
"input_schema": response_model.model_json_schema(), "input_schema": schema,
} }
params["tools"] = [structured_tool] params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"} params["tool_choice"] = {"type": "tool", "name": "structured_output"}
@@ -1118,9 +1276,19 @@ class AnthropicCompletion(BaseLLM):
stream_params = {k: v for k, v in params.items() if k != "stream"} stream_params = {k: v for k, v in params.items() if k != "stream"}
if betas:
stream_params["betas"] = betas
current_tool_calls: dict[int, dict[str, Any]] = {} current_tool_calls: dict[int, dict[str, Any]] = {}
async with self.async_client.messages.stream(**stream_params) as stream: stream_context = (
self.async_client.beta.messages.stream(
**stream_params, extra_body=extra_body
)
if betas
else self.async_client.messages.stream(**stream_params)
)
async with stream_context as stream:
response_id = None response_id = None
async for event in stream: async for event in stream:
if hasattr(event, "message") and hasattr(event.message, "id"): if hasattr(event, "message") and hasattr(event.message, "id"):
@@ -1133,7 +1301,7 @@ class AnthropicCompletion(BaseLLM):
chunk=text_delta, chunk=text_delta,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
response_id=response_id response_id=response_id,
) )
if event.type == "content_block_start": if event.type == "content_block_start":
@@ -1160,7 +1328,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
elif event.type == "content_block_delta": elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta": if event.delta.type == "input_json_delta":
@@ -1184,24 +1352,30 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
final_message: Message = await stream.get_final_message() final_message = await stream.get_final_message()
usage = self._extract_anthropic_token_usage(final_message) usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage) self._track_token_usage_internal(usage)
if response_model and final_message.content: if _is_pydantic_model_class(response_model):
tool_uses = [ if use_native_structured_output:
block self._emit_call_completed_event(
for block in final_message.content response=full_response,
if isinstance(block, ToolUseBlock) call_type=LLMCallType.LLM_CALL,
] from_task=from_task,
if tool_uses and tool_uses[0].name == "structured_output": from_agent=from_agent,
structured_data = tool_uses[0].input messages=params["messages"],
structured_json = json.dumps(structured_data) )
return full_response
for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event( self._emit_call_completed_event(
response=structured_json, response=structured_json,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
@@ -1209,7 +1383,6 @@ class AnthropicCompletion(BaseLLM):
from_agent=from_agent, from_agent=from_agent,
messages=params["messages"], messages=params["messages"],
) )
return structured_json return structured_json
if final_message.content: if final_message.content:
@@ -1220,7 +1393,6 @@ class AnthropicCompletion(BaseLLM):
] ]
if tool_uses: if tool_uses:
# If no available_functions, return tool calls for executor to handle
if not available_functions: if not available_functions:
return list(tool_uses) return list(tool_uses)
@@ -1247,7 +1419,7 @@ class AnthropicCompletion(BaseLLM):
async def _ahandle_tool_use_conversation( async def _ahandle_tool_use_conversation(
self, self,
initial_response: Message, initial_response: Message | BetaMessage,
tool_uses: list[ToolUseBlock], tool_uses: list[ToolUseBlock],
params: dict[str, Any], params: dict[str, Any],
available_functions: dict[str, Any], available_functions: dict[str, Any],
@@ -1356,7 +1528,9 @@ class AnthropicCompletion(BaseLLM):
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO) return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
@staticmethod @staticmethod
def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]: def _extract_anthropic_token_usage(
response: Message | BetaMessage,
) -> dict[str, Any]:
"""Extract token usage from Anthropic response.""" """Extract token usage from Anthropic response."""
if hasattr(response, "usage") and response.usage: if hasattr(response, "usage") and response.usage:
usage = response.usage usage = response.usage

View File

@@ -92,6 +92,7 @@ class AzureCompletion(BaseLLM):
stop: list[str] | None = None, stop: list[str] | None = None,
stream: bool = False, stream: bool = False,
interceptor: BaseInterceptor[Any, Any] | None = None, interceptor: BaseInterceptor[Any, Any] | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any, **kwargs: Any,
): ):
"""Initialize Azure AI Inference chat completion client. """Initialize Azure AI Inference chat completion client.
@@ -111,6 +112,9 @@ class AzureCompletion(BaseLLM):
stop: Stop sequences stop: Stop sequences
stream: Enable streaming responses stream: Enable streaming responses
interceptor: HTTP interceptor (not yet supported for Azure). interceptor: HTTP interceptor (not yet supported for Azure).
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
Only works with OpenAI models deployed on Azure.
**kwargs: Additional parameters **kwargs: Additional parameters
""" """
if interceptor is not None: if interceptor is not None:
@@ -165,6 +169,7 @@ class AzureCompletion(BaseLLM):
self.presence_penalty = presence_penalty self.presence_penalty = presence_penalty
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.stream = stream self.stream = stream
self.response_format = response_format
self.is_openai_model = any( self.is_openai_model = any(
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"] prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
@@ -298,6 +303,7 @@ class AzureCompletion(BaseLLM):
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
) )
effective_response_model = response_model or self.response_format
# Format messages for Azure # Format messages for Azure
formatted_messages = self._format_messages_for_azure(messages) formatted_messages = self._format_messages_for_azure(messages)
@@ -307,7 +313,7 @@ class AzureCompletion(BaseLLM):
# Prepare completion parameters # Prepare completion parameters
completion_params = self._prepare_completion_params( completion_params = self._prepare_completion_params(
formatted_messages, tools, response_model formatted_messages, tools, effective_response_model
) )
# Handle streaming vs non-streaming # Handle streaming vs non-streaming
@@ -317,7 +323,7 @@ class AzureCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
return self._handle_completion( return self._handle_completion(
@@ -325,7 +331,7 @@ class AzureCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
except Exception as e: except Exception as e:
@@ -364,11 +370,12 @@ class AzureCompletion(BaseLLM):
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
) )
effective_response_model = response_model or self.response_format
formatted_messages = self._format_messages_for_azure(messages) formatted_messages = self._format_messages_for_azure(messages)
completion_params = self._prepare_completion_params( completion_params = self._prepare_completion_params(
formatted_messages, tools, response_model formatted_messages, tools, effective_response_model
) )
if self.stream: if self.stream:
@@ -377,7 +384,7 @@ class AzureCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
return await self._ahandle_completion( return await self._ahandle_completion(
@@ -385,7 +392,7 @@ class AzureCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
except Exception as e: except Exception as e:
@@ -734,7 +741,7 @@ class AzureCompletion(BaseLLM):
chunk=content_delta, chunk=content_delta,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
response_id=response_id response_id=response_id,
) )
if choice.delta and choice.delta.tool_calls: if choice.delta and choice.delta.tool_calls:
@@ -769,7 +776,7 @@ class AzureCompletion(BaseLLM):
"index": idx, "index": idx,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
return full_response return full_response

View File

@@ -172,6 +172,7 @@ class BedrockCompletion(BaseLLM):
additional_model_request_fields: dict[str, Any] | None = None, additional_model_request_fields: dict[str, Any] | None = None,
additional_model_response_field_paths: list[str] | None = None, additional_model_response_field_paths: list[str] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None, interceptor: BaseInterceptor[Any, Any] | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Initialize AWS Bedrock completion client. """Initialize AWS Bedrock completion client.
@@ -192,6 +193,8 @@ class BedrockCompletion(BaseLLM):
additional_model_request_fields: Model-specific request parameters additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths additional_model_response_field_paths: Custom response field paths
interceptor: HTTP interceptor (not yet supported for Bedrock). interceptor: HTTP interceptor (not yet supported for Bedrock).
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
**kwargs: Additional parameters **kwargs: Additional parameters
""" """
if interceptor is not None: if interceptor is not None:
@@ -248,6 +251,7 @@ class BedrockCompletion(BaseLLM):
self.top_k = top_k self.top_k = top_k
self.stream = stream self.stream = stream
self.stop_sequences = stop_sequences self.stop_sequences = stop_sequences
self.response_format = response_format
# Store advanced features (optional) # Store advanced features (optional)
self.guardrail_config = guardrail_config self.guardrail_config = guardrail_config
@@ -299,6 +303,8 @@ class BedrockCompletion(BaseLLM):
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str | Any: ) -> str | Any:
"""Call AWS Bedrock Converse API.""" """Call AWS Bedrock Converse API."""
effective_response_model = response_model or self.response_format
try: try:
# Emit call started event # Emit call started event
self._emit_call_started_event( self._emit_call_started_event(
@@ -375,6 +381,7 @@ class BedrockCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
effective_response_model,
) )
return self._handle_converse( return self._handle_converse(
@@ -383,6 +390,7 @@ class BedrockCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
effective_response_model,
) )
except Exception as e: except Exception as e:
@@ -425,6 +433,8 @@ class BedrockCompletion(BaseLLM):
NotImplementedError: If aiobotocore is not installed. NotImplementedError: If aiobotocore is not installed.
LLMContextLengthExceededError: If context window is exceeded. LLMContextLengthExceededError: If context window is exceeded.
""" """
effective_response_model = response_model or self.response_format
if not AIOBOTOCORE_AVAILABLE: if not AIOBOTOCORE_AVAILABLE:
raise NotImplementedError( raise NotImplementedError(
"Async support for AWS Bedrock requires aiobotocore. " "Async support for AWS Bedrock requires aiobotocore. "
@@ -494,11 +504,21 @@ class BedrockCompletion(BaseLLM):
if self.stream: if self.stream:
return await self._ahandle_streaming_converse( return await self._ahandle_streaming_converse(
formatted_messages, body, available_functions, from_task, from_agent formatted_messages,
body,
available_functions,
from_task,
from_agent,
effective_response_model,
) )
return await self._ahandle_converse( return await self._ahandle_converse(
formatted_messages, body, available_functions, from_task, from_agent formatted_messages,
body,
available_functions,
from_task,
from_agent,
effective_response_model,
) )
except Exception as e: except Exception as e:
@@ -520,10 +540,29 @@ class BedrockCompletion(BaseLLM):
available_functions: Mapping[str, Any] | None = None, available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
) -> str: response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming converse API call following AWS best practices.""" """Handle non-streaming converse API call following AWS best practices."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
try: try:
# Validate messages format before API call
if not messages: if not messages:
raise ValueError("Messages cannot be empty") raise ValueError("Messages cannot be empty")
@@ -571,6 +610,21 @@ class BedrockCompletion(BaseLLM):
# If there are tool uses but no available_functions, return them for the executor to handle # If there are tool uses but no available_functions, return them for the executor to handle
tool_uses = [block["toolUse"] for block in content if "toolUse" in block] tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
if response_model and tool_uses:
for tool_use in tool_uses:
if tool_use.get("name") == "structured_output":
structured_data = tool_use.get("input", {})
result = response_model.model_validate(structured_data)
self._emit_call_completed_event(
response=result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages,
)
return result
if tool_uses and not available_functions: if tool_uses and not available_functions:
self._emit_call_completed_event( self._emit_call_completed_event(
response=tool_uses, response=tool_uses,
@@ -717,8 +771,28 @@ class BedrockCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None, available_functions: dict[str, Any] | None = None,
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str: ) -> str:
"""Handle streaming converse API call with comprehensive event handling.""" """Handle streaming converse API call with comprehensive event handling."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
full_response = "" full_response = ""
current_tool_use: dict[str, Any] | None = None current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None tool_use_id: str | None = None
@@ -805,7 +879,7 @@ class BedrockCompletion(BaseLLM):
"index": tool_use_index, "index": tool_use_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
elif "contentBlockStop" in event: elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream") logging.debug("Content block stopped in stream")
@@ -929,8 +1003,28 @@ class BedrockCompletion(BaseLLM):
available_functions: Mapping[str, Any] | None = None, available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
) -> str: response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle async non-streaming converse API call.""" """Handle async non-streaming converse API call."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
try: try:
if not messages: if not messages:
raise ValueError("Messages cannot be empty") raise ValueError("Messages cannot be empty")
@@ -976,6 +1070,21 @@ class BedrockCompletion(BaseLLM):
# If there are tool uses but no available_functions, return them for the executor to handle # If there are tool uses but no available_functions, return them for the executor to handle
tool_uses = [block["toolUse"] for block in content if "toolUse" in block] tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
if response_model and tool_uses:
for tool_use in tool_uses:
if tool_use.get("name") == "structured_output":
structured_data = tool_use.get("input", {})
result = response_model.model_validate(structured_data)
self._emit_call_completed_event(
response=result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages,
)
return result
if tool_uses and not available_functions: if tool_uses and not available_functions:
self._emit_call_completed_event( self._emit_call_completed_event(
response=tool_uses, response=tool_uses,
@@ -1106,8 +1215,28 @@ class BedrockCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None, available_functions: dict[str, Any] | None = None,
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str: ) -> str:
"""Handle async streaming converse API call.""" """Handle async streaming converse API call."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
full_response = "" full_response = ""
current_tool_use: dict[str, Any] | None = None current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None tool_use_id: str | None = None
@@ -1174,7 +1303,7 @@ class BedrockCompletion(BaseLLM):
chunk=text_chunk, chunk=text_chunk,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
response_id=response_id response_id=response_id,
) )
elif "toolUse" in delta and current_tool_use: elif "toolUse" in delta and current_tool_use:
tool_input = delta["toolUse"].get("input", "") tool_input = delta["toolUse"].get("input", "")

View File

@@ -56,6 +56,7 @@ class GeminiCompletion(BaseLLM):
client_params: dict[str, Any] | None = None, client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None, interceptor: BaseInterceptor[Any, Any] | None = None,
use_vertexai: bool | None = None, use_vertexai: bool | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any, **kwargs: Any,
): ):
"""Initialize Google Gemini chat completion client. """Initialize Google Gemini chat completion client.
@@ -86,6 +87,8 @@ class GeminiCompletion(BaseLLM):
- None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var - None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
When using Vertex AI with API key (Express mode), http_options with When using Vertex AI with API key (Express mode), http_options with
api_version="v1" is automatically configured. api_version="v1" is automatically configured.
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
**kwargs: Additional parameters **kwargs: Additional parameters
""" """
if interceptor is not None: if interceptor is not None:
@@ -121,6 +124,7 @@ class GeminiCompletion(BaseLLM):
self.safety_settings = safety_settings or {} self.safety_settings = safety_settings or {}
self.stop_sequences = stop_sequences or [] self.stop_sequences = stop_sequences or []
self.tools: list[dict[str, Any]] | None = None self.tools: list[dict[str, Any]] | None = None
self.response_format = response_format
# Model-specific settings # Model-specific settings
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower()) version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
@@ -292,6 +296,7 @@ class GeminiCompletion(BaseLLM):
from_agent=from_agent, from_agent=from_agent,
) )
self.tools = tools self.tools = tools
effective_response_model = response_model or self.response_format
formatted_content, system_instruction = self._format_messages_for_gemini( formatted_content, system_instruction = self._format_messages_for_gemini(
messages messages
@@ -303,7 +308,7 @@ class GeminiCompletion(BaseLLM):
raise ValueError("LLM call blocked by before_llm_call hook") raise ValueError("LLM call blocked by before_llm_call hook")
config = self._prepare_generation_config( config = self._prepare_generation_config(
system_instruction, tools, response_model system_instruction, tools, effective_response_model
) )
if self.stream: if self.stream:
@@ -313,7 +318,7 @@ class GeminiCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
return self._handle_completion( return self._handle_completion(
@@ -322,7 +327,7 @@ class GeminiCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
except APIError as e: except APIError as e:
@@ -374,13 +379,14 @@ class GeminiCompletion(BaseLLM):
from_agent=from_agent, from_agent=from_agent,
) )
self.tools = tools self.tools = tools
effective_response_model = response_model or self.response_format
formatted_content, system_instruction = self._format_messages_for_gemini( formatted_content, system_instruction = self._format_messages_for_gemini(
messages messages
) )
config = self._prepare_generation_config( config = self._prepare_generation_config(
system_instruction, tools, response_model system_instruction, tools, effective_response_model
) )
if self.stream: if self.stream:
@@ -390,7 +396,7 @@ class GeminiCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
return await self._ahandle_completion( return await self._ahandle_completion(
@@ -399,7 +405,7 @@ class GeminiCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
except APIError as e: except APIError as e:
@@ -570,10 +576,10 @@ class GeminiCompletion(BaseLLM):
types.Content(role="user", parts=[function_response_part]) types.Content(role="user", parts=[function_response_part])
) )
elif role == "assistant" and message.get("tool_calls"): elif role == "assistant" and message.get("tool_calls"):
parts: list[types.Part] = [] tool_parts: list[types.Part] = []
if text_content: if text_content:
parts.append(types.Part.from_text(text=text_content)) tool_parts.append(types.Part.from_text(text=text_content))
tool_calls: list[dict[str, Any]] = message.get("tool_calls") or [] tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
for tool_call in tool_calls: for tool_call in tool_calls:
@@ -592,11 +598,11 @@ class GeminiCompletion(BaseLLM):
else: else:
func_args = func_args_raw func_args = func_args_raw
parts.append( tool_parts.append(
types.Part.from_function_call(name=func_name, args=func_args) types.Part.from_function_call(name=func_name, args=func_args)
) )
contents.append(types.Content(role="model", parts=parts)) contents.append(types.Content(role="model", parts=tool_parts))
else: else:
# Convert role for Gemini (assistant -> model) # Convert role for Gemini (assistant -> model)
gemini_role = "model" if role == "assistant" else "user" gemini_role = "model" if role == "assistant" else "user"
@@ -800,7 +806,7 @@ class GeminiCompletion(BaseLLM):
chunk=chunk.text, chunk=chunk.text,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
response_id=response_id response_id=response_id,
) )
if chunk.candidates: if chunk.candidates:
@@ -837,7 +843,7 @@ class GeminiCompletion(BaseLLM):
"index": call_index, "index": call_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
return full_response, function_calls, usage_data return full_response, function_calls, usage_data
@@ -972,7 +978,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str: ) -> str | Any:
"""Handle streaming content generation.""" """Handle streaming content generation."""
full_response = "" full_response = ""
function_calls: dict[int, dict[str, Any]] = {} function_calls: dict[int, dict[str, Any]] = {}
@@ -1050,7 +1056,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str: ) -> str | Any:
"""Handle async streaming content generation.""" """Handle async streaming content generation."""
full_response = "" full_response = ""
function_calls: dict[int, dict[str, Any]] = {} function_calls: dict[int, dict[str, Any]] = {}

View File

@@ -18,7 +18,6 @@ if TYPE_CHECKING:
) )
from chromadb.utils.embedding_functions.google_embedding_function import ( from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction, GoogleGenerativeAiEmbeddingFunction,
GoogleVertexEmbeddingFunction,
) )
from chromadb.utils.embedding_functions.huggingface_embedding_function import ( from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction, HuggingFaceEmbeddingFunction,
@@ -52,6 +51,9 @@ if TYPE_CHECKING:
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
from crewai.rag.embeddings.providers.google.types import ( from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderSpec, GenerativeAiProviderSpec,
VertexAIProviderSpec, VertexAIProviderSpec,
@@ -163,7 +165,7 @@ def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunctio
@overload @overload
def build_embedder_from_dict( def build_embedder_from_dict(
spec: VertexAIProviderSpec, spec: VertexAIProviderSpec,
) -> GoogleVertexEmbeddingFunction: ... ) -> GoogleGenAIVertexEmbeddingFunction: ...
@overload @overload
@@ -296,7 +298,9 @@ def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload @overload
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ... def build_embedder(
spec: VertexAIProviderSpec,
) -> GoogleGenAIVertexEmbeddingFunction: ...
@overload @overload

View File

@@ -1,5 +1,8 @@
"""Google embedding providers.""" """Google embedding providers."""
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
from crewai.rag.embeddings.providers.google.generative_ai import ( from crewai.rag.embeddings.providers.google.generative_ai import (
GenerativeAiProvider, GenerativeAiProvider,
) )
@@ -18,6 +21,7 @@ __all__ = [
"GenerativeAiProvider", "GenerativeAiProvider",
"GenerativeAiProviderConfig", "GenerativeAiProviderConfig",
"GenerativeAiProviderSpec", "GenerativeAiProviderSpec",
"GoogleGenAIVertexEmbeddingFunction",
"VertexAIProvider", "VertexAIProvider",
"VertexAIProviderConfig", "VertexAIProviderConfig",
"VertexAIProviderSpec", "VertexAIProviderSpec",

View File

@@ -0,0 +1,237 @@
"""Google Vertex AI embedding function implementation.
This module supports both the new google-genai SDK and the deprecated
vertexai.language_models module for backwards compatibility.
The deprecated vertexai.language_models module will be removed after June 24, 2026.
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
"""
from typing import Any, ClassVar, cast
import warnings
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack
from crewai.rag.embeddings.providers.google.types import VertexAIProviderConfig
class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for Google Vertex AI with dual SDK support.
This class supports both:
- Legacy models (textembedding-gecko*) using the deprecated vertexai.language_models SDK
- New models (gemini-embedding-*, text-embedding-*) using the google-genai SDK
The SDK is automatically selected based on the model name. Legacy models will
emit a deprecation warning.
Supports two authentication modes:
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
2. API key: Set api_key for direct API access
Example:
# Using legacy model (will emit deprecation warning)
embedder = GoogleGenAIVertexEmbeddingFunction(
project_id="my-project",
region="us-central1",
model_name="textembedding-gecko"
)
# Using new model with google-genai SDK
embedder = GoogleGenAIVertexEmbeddingFunction(
project_id="my-project",
location="us-central1",
model_name="gemini-embedding-001"
)
# Using API key (new SDK only)
embedder = GoogleGenAIVertexEmbeddingFunction(
api_key="your-api-key",
model_name="gemini-embedding-001"
)
"""
# Models that use the legacy vertexai.language_models SDK
LEGACY_MODELS: ClassVar[set[str]] = {
"textembedding-gecko",
"textembedding-gecko@001",
"textembedding-gecko@002",
"textembedding-gecko@003",
"textembedding-gecko@latest",
"textembedding-gecko-multilingual",
"textembedding-gecko-multilingual@001",
"textembedding-gecko-multilingual@latest",
}
# Models that use the new google-genai SDK
GENAI_MODELS: ClassVar[set[str]] = {
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
}
def __init__(self, **kwargs: Unpack[VertexAIProviderConfig]) -> None:
"""Initialize Google Vertex AI embedding function.
Args:
**kwargs: Configuration parameters including:
- model_name: Model to use for embeddings (default: "textembedding-gecko")
- api_key: Optional API key for authentication (new SDK only)
- project_id: GCP project ID (for Vertex AI backend)
- location: GCP region (default: "us-central1")
- region: Deprecated alias for location
- task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT", new SDK only)
- output_dimensionality: Optional output embedding dimension (new SDK only)
"""
# Handle deprecated 'region' parameter (only if it has a value)
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
if region_value is not None:
warnings.warn(
"The 'region' parameter is deprecated, use 'location' instead. "
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
DeprecationWarning,
stacklevel=2,
)
if "location" not in kwargs or kwargs.get("location") is None:
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
self._config = kwargs
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
self._use_legacy = self._is_legacy_model(self._model_name)
if self._use_legacy:
self._init_legacy_client(**kwargs)
else:
self._init_genai_client(**kwargs)
def _is_legacy_model(self, model_name: str) -> bool:
"""Check if the model uses the legacy SDK."""
return model_name in self.LEGACY_MODELS or model_name.startswith(
"textembedding-gecko"
)
def _init_legacy_client(self, **kwargs: Any) -> None:
"""Initialize using the deprecated vertexai.language_models SDK."""
warnings.warn(
f"Model '{self._model_name}' uses the deprecated vertexai.language_models SDK "
"which will be removed after June 24, 2026. Consider migrating to newer models "
"like 'gemini-embedding-001'. "
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
DeprecationWarning,
stacklevel=3,
)
try:
import vertexai
from vertexai.language_models import TextEmbeddingModel
except ImportError as e:
raise ImportError(
"vertexai is required for legacy embedding models (textembedding-gecko*). "
"Install it with: pip install google-cloud-aiplatform"
) from e
project_id = kwargs.get("project_id")
location = str(kwargs.get("location", "us-central1"))
if not project_id:
raise ValueError(
"project_id is required for legacy models. "
"For API key authentication, use newer models like 'gemini-embedding-001'."
)
vertexai.init(project=str(project_id), location=location)
self._legacy_model = TextEmbeddingModel.from_pretrained(self._model_name)
def _init_genai_client(self, **kwargs: Any) -> None:
"""Initialize using the new google-genai SDK."""
try:
from google import genai
from google.genai.types import EmbedContentConfig
except ImportError as e:
raise ImportError(
"google-genai is required for Google Gen AI embeddings. "
"Install it with: uv add 'crewai[google-genai]'"
) from e
self._genai = genai
self._EmbedContentConfig = EmbedContentConfig
self._task_type = kwargs.get("task_type", "RETRIEVAL_DOCUMENT")
self._output_dimensionality = kwargs.get("output_dimensionality")
# Initialize client based on authentication mode
api_key = kwargs.get("api_key")
project_id = kwargs.get("project_id")
location: str = str(kwargs.get("location", "us-central1"))
if api_key:
self._client = genai.Client(api_key=api_key)
elif project_id:
self._client = genai.Client(
vertexai=True,
project=str(project_id),
location=location,
)
else:
raise ValueError(
"Either 'api_key' (for API key authentication) or 'project_id' "
"(for Vertex AI backend with ADC) must be provided."
)
@staticmethod
def name() -> str:
"""Return the name of the embedding function for ChromaDB compatibility."""
return "google-vertex"
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.
Args:
input: List of documents to embed.
Returns:
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]
if self._use_legacy:
return self._call_legacy(input)
return self._call_genai(input)
def _call_legacy(self, input: list[str]) -> Embeddings:
"""Generate embeddings using the legacy SDK."""
import numpy as np
embeddings_list = []
for text in input:
embedding_result = self._legacy_model.get_embeddings([text])
embeddings_list.append(
np.array(embedding_result[0].values, dtype=np.float32)
)
return cast(Embeddings, embeddings_list)
def _call_genai(self, input: list[str]) -> Embeddings:
"""Generate embeddings using the new google-genai SDK."""
# Build config for embed_content
config_kwargs: dict[str, Any] = {
"task_type": self._task_type,
}
if self._output_dimensionality is not None:
config_kwargs["output_dimensionality"] = self._output_dimensionality
config = self._EmbedContentConfig(**config_kwargs)
# Call the embedding API
response = self._client.models.embed_content(
model=self._model_name,
contents=input, # type: ignore[arg-type]
config=config,
)
# Extract embeddings from response
if response.embeddings is None:
raise ValueError("No embeddings returned from the API")
embeddings = [emb.values for emb in response.embeddings]
return cast(Embeddings, embeddings)

View File

@@ -34,12 +34,47 @@ class GenerativeAiProviderSpec(TypedDict):
class VertexAIProviderConfig(TypedDict, total=False): class VertexAIProviderConfig(TypedDict, total=False):
"""Configuration for Vertex AI provider.""" """Configuration for Vertex AI provider with dual SDK support.
Supports both legacy models (textembedding-gecko*) using the deprecated
vertexai.language_models SDK and new models using google-genai SDK.
Attributes:
api_key: Google API key (optional if using project_id with ADC). Only for new SDK models.
model_name: Embedding model name (default: "textembedding-gecko").
Legacy models: textembedding-gecko, textembedding-gecko@001, etc.
New models: gemini-embedding-001, text-embedding-005, text-multilingual-embedding-002
project_id: GCP project ID (required for Vertex AI backend and legacy models).
location: GCP region/location (default: "us-central1").
region: Deprecated alias for location (kept for backwards compatibility).
task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT"). Only for new SDK models.
output_dimensionality: Output embedding dimension (optional). Only for new SDK models.
"""
api_key: str api_key: str
model_name: Annotated[str, "textembedding-gecko"] model_name: Annotated[
project_id: Annotated[str, "cloud-large-language-models"] Literal[
region: Annotated[str, "us-central1"] # Legacy models (deprecated vertexai.language_models SDK)
"textembedding-gecko",
"textembedding-gecko@001",
"textembedding-gecko@002",
"textembedding-gecko@003",
"textembedding-gecko@latest",
"textembedding-gecko-multilingual",
"textembedding-gecko-multilingual@001",
"textembedding-gecko-multilingual@latest",
# New models (google-genai SDK)
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
],
"textembedding-gecko",
]
project_id: str
location: Annotated[str, "us-central1"]
region: Annotated[str, "us-central1"] # Deprecated alias for location
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
output_dimensionality: int
class VertexAIProviderSpec(TypedDict, total=False): class VertexAIProviderSpec(TypedDict, total=False):

View File

@@ -1,46 +1,126 @@
"""Google Vertex AI embeddings provider.""" """Google Vertex AI embeddings provider.
This module supports both the new google-genai SDK and the deprecated
vertexai.language_models module for backwards compatibility.
The SDK is automatically selected based on the model name:
- Legacy models (textembedding-gecko*) use vertexai.language_models (deprecated)
- New models (gemini-embedding-*, text-embedding-*) use google-genai
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
"""
from __future__ import annotations
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
from pydantic import AliasChoices, Field from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]): class VertexAIProvider(BaseEmbeddingsProvider[GoogleGenAIVertexEmbeddingFunction]):
"""Google Vertex AI embeddings provider.""" """Google Vertex AI embeddings provider with dual SDK support.
embedding_callable: type[GoogleVertexEmbeddingFunction] = Field( Supports both legacy models (textembedding-gecko*) using the deprecated
default=GoogleVertexEmbeddingFunction, vertexai.language_models SDK and new models (gemini-embedding-*, text-embedding-*)
description="Vertex AI embedding function class", using the google-genai SDK.
The SDK is automatically selected based on the model name. Legacy models will
emit a deprecation warning.
Authentication modes:
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
2. API key: Set api_key for direct API access (new SDK models only)
Example:
# Legacy model (backwards compatible, will emit deprecation warning)
provider = VertexAIProvider(
project_id="my-project",
region="us-central1", # or location="us-central1"
model_name="textembedding-gecko"
)
# New model with Vertex AI backend
provider = VertexAIProvider(
project_id="my-project",
location="us-central1",
model_name="gemini-embedding-001"
)
# New model with API key
provider = VertexAIProvider(
api_key="your-api-key",
model_name="gemini-embedding-001"
)
"""
embedding_callable: type[GoogleGenAIVertexEmbeddingFunction] = Field(
default=GoogleGenAIVertexEmbeddingFunction,
description="Google Vertex AI embedding function class",
) )
model_name: str = Field( model_name: str = Field(
default="textembedding-gecko", default="textembedding-gecko",
description="Model name to use for embeddings", description=(
"Model name to use for embeddings. Legacy models (textembedding-gecko*) "
"use the deprecated SDK. New models (gemini-embedding-001, text-embedding-005) "
"use the google-genai SDK."
),
validation_alias=AliasChoices( validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME", "EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
"GOOGLE_VERTEX_MODEL_NAME", "GOOGLE_VERTEX_MODEL_NAME",
"model", "model",
), ),
) )
api_key: str = Field( api_key: str | None = Field(
description="Google API key", default=None,
description="Google API key (optional if using project_id with Application Default Credentials)",
validation_alias=AliasChoices( validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY" "EMBEDDINGS_GOOGLE_CLOUD_API_KEY",
"GOOGLE_CLOUD_API_KEY",
"GOOGLE_API_KEY",
), ),
) )
project_id: str = Field( project_id: str | None = Field(
default="cloud-large-language-models", default=None,
description="GCP project ID", description="GCP project ID (required for Vertex AI backend and legacy models)",
validation_alias=AliasChoices( validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT" "EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
"GOOGLE_CLOUD_PROJECT",
), ),
) )
region: str = Field( location: str = Field(
default="us-central1", default="us-central1",
description="GCP region", description="GCP region/location",
validation_alias=AliasChoices( validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION" "EMBEDDINGS_GOOGLE_CLOUD_LOCATION",
"EMBEDDINGS_GOOGLE_CLOUD_REGION",
"GOOGLE_CLOUD_LOCATION",
"GOOGLE_CLOUD_REGION",
),
)
region: str | None = Field(
default=None,
description="Deprecated: Use 'location' instead. GCP region (kept for backwards compatibility)",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_REGION",
"GOOGLE_VERTEX_REGION",
),
)
task_type: str = Field(
default="RETRIEVAL_DOCUMENT",
description="Task type for embeddings (e.g., RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY). Only used with new SDK models.",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_TASK_TYPE",
"GOOGLE_VERTEX_TASK_TYPE",
),
)
output_dimensionality: int | None = Field(
default=None,
description="Output embedding dimensionality (optional). Only used with new SDK models.",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
"GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
), ),
) )

View File

@@ -5,17 +5,29 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Coroutine from collections.abc import Coroutine
import concurrent.futures import concurrent.futures
import logging
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, TypeVar
from uuid import UUID from uuid import UUID
from aiocache import Cache # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
if TYPE_CHECKING: if TYPE_CHECKING:
from aiocache import Cache
from crewai_files import FileInput from crewai_files import FileInput
logger = logging.getLogger(__name__)
_file_store: Cache | None = None
try:
from aiocache import Cache
from aiocache.serializers import PickleSerializer
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer()) _file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
except ImportError:
logger.debug(
"aiocache is not installed. File store features will be disabled. "
"Install with: uv add aiocache"
)
T = TypeVar("T") T = TypeVar("T")
@@ -59,6 +71,8 @@ async def astore_files(
files: Dictionary mapping names to file inputs. files: Dictionary mapping names to file inputs.
ttl: Time-to-live in seconds. ttl: Time-to-live in seconds.
""" """
if _file_store is None:
return
await _file_store.set(f"{_CREW_PREFIX}{execution_id}", files, ttl=ttl) await _file_store.set(f"{_CREW_PREFIX}{execution_id}", files, ttl=ttl)
@@ -71,6 +85,8 @@ async def aget_files(execution_id: UUID) -> dict[str, FileInput] | None:
Returns: Returns:
Dictionary of files or None if not found. Dictionary of files or None if not found.
""" """
if _file_store is None:
return None
result: dict[str, FileInput] | None = await _file_store.get( result: dict[str, FileInput] | None = await _file_store.get(
f"{_CREW_PREFIX}{execution_id}" f"{_CREW_PREFIX}{execution_id}"
) )
@@ -83,6 +99,8 @@ async def aclear_files(execution_id: UUID) -> None:
Args: Args:
execution_id: Unique identifier for the crew execution. execution_id: Unique identifier for the crew execution.
""" """
if _file_store is None:
return
await _file_store.delete(f"{_CREW_PREFIX}{execution_id}") await _file_store.delete(f"{_CREW_PREFIX}{execution_id}")
@@ -98,6 +116,8 @@ async def astore_task_files(
files: Dictionary mapping names to file inputs. files: Dictionary mapping names to file inputs.
ttl: Time-to-live in seconds. ttl: Time-to-live in seconds.
""" """
if _file_store is None:
return
await _file_store.set(f"{_TASK_PREFIX}{task_id}", files, ttl=ttl) await _file_store.set(f"{_TASK_PREFIX}{task_id}", files, ttl=ttl)
@@ -110,6 +130,8 @@ async def aget_task_files(task_id: UUID) -> dict[str, FileInput] | None:
Returns: Returns:
Dictionary of files or None if not found. Dictionary of files or None if not found.
""" """
if _file_store is None:
return None
result: dict[str, FileInput] | None = await _file_store.get( result: dict[str, FileInput] | None = await _file_store.get(
f"{_TASK_PREFIX}{task_id}" f"{_TASK_PREFIX}{task_id}"
) )
@@ -122,6 +144,8 @@ async def aclear_task_files(task_id: UUID) -> None:
Args: Args:
task_id: Unique identifier for the task. task_id: Unique identifier for the task.
""" """
if _file_store is None:
return
await _file_store.delete(f"{_TASK_PREFIX}{task_id}") await _file_store.delete(f"{_TASK_PREFIX}{task_id}")

View File

@@ -1,6 +1,8 @@
interactions: interactions:
- request: - request:
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Returns structured data according to the schema","input_schema":{"description":"Response model for greeting test.","properties":{"greeting":{"title":"Greeting","type":"string"},"language":{"title":"Language","type":"string"}},"required":["greeting","language"],"title":"GreetingResponse","type":"object"}}]}' body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Output
the structured response","input_schema":{"type":"object","description":"Response
model for greeting test.","title":"GreetingResponse","properties":{"greeting":{"type":"string","title":"Greeting"},"language":{"type":"string","title":"Language"}},"additionalProperties":false,"required":["greeting","language"]}}]}'
headers: headers:
User-Agent: User-Agent:
- X-USER-AGENT-XXX - X-USER-AGENT-XXX
@@ -13,7 +15,7 @@ interactions:
connection: connection:
- keep-alive - keep-alive
content-length: content-length:
- '539' - '551'
content-type: content-type:
- application/json - application/json
host: host:
@@ -29,7 +31,7 @@ interactions:
x-stainless-os: x-stainless-os:
- X-STAINLESS-OS-XXX - X-STAINLESS-OS-XXX
x-stainless-package-version: x-stainless-package-version:
- 0.75.0 - 0.76.0
x-stainless-retry-count: x-stainless-retry-count:
- '0' - '0'
x-stainless-runtime: x-stainless-runtime:
@@ -42,7 +44,7 @@ interactions:
uri: https://api.anthropic.com/v1/messages uri: https://api.anthropic.com/v1/messages
response: response:
body: body:
string: '{"model":"claude-sonnet-4-20250514","id":"msg_01XjvX2nCho1knuucbwwgCpw","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_019rfPRSDmBb7CyCTdGMv5rK","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":432,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}' string: '{"model":"claude-sonnet-4-20250514","id":"msg_01CKTyVmak15L5oQ36mv4sL9","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_0174BYmn6xiSnUwVhFD8S7EW","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":436,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
headers: headers:
CF-RAY: CF-RAY:
- CF-RAY-XXX - CF-RAY-XXX
@@ -51,7 +53,7 @@ interactions:
Content-Type: Content-Type:
- application/json - application/json
Date: Date:
- Mon, 01 Dec 2025 11:19:38 GMT - Mon, 26 Jan 2026 14:59:34 GMT
Server: Server:
- cloudflare - cloudflare
Transfer-Encoding: Transfer-Encoding:
@@ -82,12 +84,10 @@ interactions:
- DYNAMIC - DYNAMIC
request-id: request-id:
- REQUEST-ID-XXX - REQUEST-ID-XXX
retry-after:
- '24'
strict-transport-security: strict-transport-security:
- STS-XXX - STS-XXX
x-envoy-upstream-service-time: x-envoy-upstream-service-time:
- '2101' - '968'
status: status:
code: 200 code: 200
message: OK message: OK

View File

@@ -272,3 +272,100 @@ class TestEmbeddingFactory:
mock_build_from_provider.assert_called_once_with(mock_provider) mock_build_from_provider.assert_called_once_with(mock_provider)
assert result == mock_embedding_function assert result == mock_embedding_function
mock_import.assert_not_called() mock_import.assert_not_called()
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_genai_model(self, mock_import):
"""Test routing to Google Vertex provider with new genai model."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"api_key": "test-google-api-key",
"model_name": "gemini-embedding-001",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-google-api-key"
assert call_kwargs["model_name"] == "gemini-embedding-001"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_legacy_model(self, mock_import):
"""Test routing to Google Vertex provider with legacy textembedding-gecko model."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"project_id": "my-gcp-project",
"region": "us-central1",
"model_name": "textembedding-gecko",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["project_id"] == "my-gcp-project"
assert call_kwargs["region"] == "us-central1"
assert call_kwargs["model_name"] == "textembedding-gecko"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_location(self, mock_import):
"""Test routing to Google Vertex provider with location parameter."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"project_id": "my-gcp-project",
"location": "europe-west1",
"model_name": "gemini-embedding-001",
"task_type": "RETRIEVAL_DOCUMENT",
"output_dimensionality": 768,
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["project_id"] == "my-gcp-project"
assert call_kwargs["location"] == "europe-west1"
assert call_kwargs["model_name"] == "gemini-embedding-001"
assert call_kwargs["task_type"] == "RETRIEVAL_DOCUMENT"
assert call_kwargs["output_dimensionality"] == 768

View File

@@ -0,0 +1,176 @@
"""Integration tests for Google Vertex embeddings with Crew memory.
These tests make real API calls and use VCR to record/replay responses.
"""
import os
import threading
from collections import defaultdict
from unittest.mock import patch
import pytest
from crewai import Agent, Crew, Task
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
@pytest.fixture(autouse=True)
def setup_vertex_ai_env():
"""Set up environment for Vertex AI tests.
Sets GOOGLE_GENAI_USE_VERTEXAI=true to ensure the SDK uses the Vertex AI
backend (aiplatform.googleapis.com) which matches the VCR cassettes.
Also mocks GOOGLE_API_KEY if not already set.
"""
env_updates = {"GOOGLE_GENAI_USE_VERTEXAI": "true"}
# Add a mock API key if none exists
if "GOOGLE_API_KEY" not in os.environ and "GEMINI_API_KEY" not in os.environ:
env_updates["GOOGLE_API_KEY"] = "test-key"
with patch.dict(os.environ, env_updates):
yield
@pytest.fixture
def google_vertex_embedder_config():
"""Fixture providing Google Vertex embedder configuration."""
return {
"provider": "google-vertex",
"config": {
"api_key": os.getenv("GOOGLE_API_KEY", "test-key"),
"model_name": "gemini-embedding-001",
},
}
@pytest.fixture
def simple_agent():
"""Fixture providing a simple test agent."""
return Agent(
role="Research Assistant",
goal="Help with research tasks",
backstory="You are a helpful research assistant.",
verbose=False,
)
@pytest.fixture
def simple_task(simple_agent):
"""Fixture providing a simple test task."""
return Task(
description="Summarize the key points about artificial intelligence in one sentence.",
expected_output="A one sentence summary about AI.",
agent=simple_agent,
)
@pytest.mark.vcr()
@pytest.mark.timeout(120) # Longer timeout for VCR recording
def test_crew_memory_with_google_vertex_embedder(
google_vertex_embedder_config, simple_agent, simple_task
) -> None:
"""Test that Crew with memory=True works with google-vertex embedder and memory is used."""
# Track memory events
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
crew = Crew(
agents=[simple_agent],
tasks=[simple_task],
memory=True,
embedder=google_vertex_embedder_config,
verbose=False,
)
result = crew.kickoff()
assert result is not None
assert result.raw is not None
assert len(result.raw) > 0
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=10,
)
assert success, "Timeout waiting for memory save events - memory may not be working"
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
assert len(events["MemorySaveCompletedEvent"]) >= 1, "Memory save completed events"
@pytest.mark.vcr()
@pytest.mark.timeout(120)
def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) -> None:
"""Test Crew memory with Google Vertex using project_id authentication."""
project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
if not project_id:
pytest.skip("GOOGLE_CLOUD_PROJECT environment variable not set")
# Track memory events
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
embedder_config = {
"provider": "google-vertex",
"config": {
"project_id": project_id,
"location": "us-central1",
"model_name": "gemini-embedding-001",
},
}
crew = Crew(
agents=[simple_agent],
tasks=[simple_task],
memory=True,
embedder=embedder_config,
verbose=False,
)
result = crew.kickoff()
# Verify basic result
assert result is not None
assert result.raw is not None
# Wait for memory save events
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=10,
)
# Verify memory was actually used
assert success, "Timeout waiting for memory save events - memory may not be working"
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
assert len(events["MemorySaveCompletedEvent"]) >= 1, "No memory save completed events"

View File

@@ -1,3 +1,3 @@
"""CrewAI development tools.""" """CrewAI development tools."""
__version__ = "1.8.1" __version__ = "1.9.0"

8
uv.lock generated
View File

@@ -310,7 +310,7 @@ wheels = [
[[package]] [[package]]
name = "anthropic" name = "anthropic"
version = "0.71.1" version = "0.73.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "anyio" }, { name = "anyio" },
@@ -322,9 +322,9 @@ dependencies = [
{ name = "sniffio" }, { name = "sniffio" },
{ name = "typing-extensions" }, { name = "typing-extensions" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/05/4b/19620875841f692fdc35eb58bf0201c8ad8c47b8443fecbf1b225312175b/anthropic-0.71.1.tar.gz", hash = "sha256:a77d156d3e7d318b84681b59823b2dee48a8ac508a3e54e49f0ab0d074e4b0da", size = 493294, upload-time = "2025-10-28T17:28:42.213Z" } sdist = { url = "https://files.pythonhosted.org/packages/f0/07/f550112c3f5299d02f06580577f602e8a112b1988ad7c98ac1a8f7292d7e/anthropic-0.73.0.tar.gz", hash = "sha256:30f0d7d86390165f86af6ca7c3041f8720bb2e1b0e12a44525c8edfdbd2c5239", size = 425168, upload-time = "2025-11-14T18:47:52.635Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/4b/68/b2f988b13325f9ac9921b1e87f0b7994468014e1b5bd3bdbd2472f5baf45/anthropic-0.71.1-py3-none-any.whl", hash = "sha256:6ca6c579f0899a445faeeed9c0eb97aa4bdb751196262f9ccc96edfc0bb12679", size = 355020, upload-time = "2025-10-28T17:28:40.653Z" }, { url = "https://files.pythonhosted.org/packages/15/b1/5d4d3f649e151e58dc938cf19c4d0cd19fca9a986879f30fea08a7b17138/anthropic-0.73.0-py3-none-any.whl", hash = "sha256:0d56cd8b3ca3fea9c9b5162868bdfd053fbc189b8b56d4290bd2d427b56db769", size = 367839, upload-time = "2025-11-14T18:47:51.195Z" },
] ]
[[package]] [[package]]
@@ -1276,7 +1276,7 @@ requires-dist = [
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" }, { name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
{ name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" }, { name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
{ name = "aiosqlite", specifier = "~=0.21.0" }, { name = "aiosqlite", specifier = "~=0.21.0" },
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.71.0" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
{ name = "appdirs", specifier = "~=1.4.4" }, { name = "appdirs", specifier = "~=1.4.4" },
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" }, { name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
{ name = "boto3", marker = "extra == 'aws'", specifier = "~=1.40.38" }, { name = "boto3", marker = "extra == 'aws'", specifier = "~=1.40.38" },