Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-27 00:58:13 +00:00

Compare commits: 5 commits (llm-event-... vs gl/chore/a...)

| SHA1 |
|---|
| 39e6618dab |
| 6b926b90d0 |
| fc84daadbb |
| 58b866a83d |
| 9797567342 |
@@ -4,6 +4,74 @@ description: "Product updates, improvements, and bug fixes for CrewAI"
icon: "clock"
mode: "wide"
---

<Update label="Jan 26, 2026">
## v1.9.0

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.9.0)

## What's Changed

### Features
- Add structured outputs and response_format support across providers
- Add response ID to streaming responses
- Add event ordering with parent-child hierarchies
- Add Keycloak SSO authentication support
- Add multimodal file handling capabilities
- Add native OpenAI responses API support
- Add A2A task execution utilities
- Add A2A server configuration and agent card generation
- Enhance event system and expand transport options
- Improve tool calling mechanisms

### Bug Fixes
- Enhance file store with fallback memory cache when aiocache is not available
- Ensure document list is not empty
- Handle Bedrock stop sequences properly
- Add Google Vertex API key support
- Enhance Azure model stop word detection
- Improve error handling for HumanFeedbackPending in flow execution
- Fix execution span task unlinking

### Documentation
- Add native file handling documentation
- Add OpenAI responses API documentation
- Add agent card implementation guidance
- Refine A2A documentation
- Update changelog for v1.8.0

### Contributors
@Anaisdg, @GininDenis, @Vidit-Ostwal, @greysonlalonde, @heitorado, @joaomdmoura, @koushiv777, @lorenzejay, @nicoferdi96, @vinibrsl

</Update>

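As a rough illustration of the structured-outputs feature listed above (not taken from the release notes themselves): passing a Pydantic model as `response_format` should return schema-shaped JSON, with each provider choosing native structured outputs or a tool-based fallback internally. The `TicketSummary` schema and the model id are placeholders, and the exact call surface may differ per provider.

```python
from pydantic import BaseModel

from crewai import LLM


class TicketSummary(BaseModel):
    """Hypothetical schema used only to illustrate structured outputs."""

    title: str
    priority: str


# response_format is the structured-output hook added across providers in v1.9.0.
llm = LLM(model="anthropic/claude-haiku-4-5", response_format=TicketSummary)

raw = llm.call("Summarize: the login page returns a 500 error for SSO users.")
print(TicketSummary.model_validate_json(raw))
```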
<Update label="Jan 15, 2026">
|
||||
## v1.8.1
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.8.1)
|
||||
|
||||
## What's Changed
|
||||
|
||||
### Features
|
||||
- Add A2A task execution utilities
|
||||
- Add A2A server configuration and agent card generation
|
||||
- Add additional transport mechanisms
|
||||
- Add Galileo integration support
|
||||
|
||||
### Bug Fixes
|
||||
- Improve Azure model compatibility
|
||||
- Expand frame inspection depth to detect parent_flow
|
||||
- Resolve task execution span management issues
|
||||
- Enhance error handling for human feedback scenarios during flow execution
|
||||
|
||||
### Documentation
|
||||
- Add A2A agent card documentation
|
||||
- Add PII redaction feature documentation
|
||||
|
||||
### Contributors
|
||||
@Anaisdg, @GininDenis, @greysonlalonde, @joaomdmoura, @koushiv777, @lorenzejay, @vinibrsl
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Jan 08, 2026">
|
||||
## v1.8.0
|
||||
|
||||
|
||||
@@ -401,23 +401,58 @@ crew = Crew(

### Vertex AI Embeddings

For Google Cloud users with Vertex AI access.
For Google Cloud users with Vertex AI access. Supports both legacy and new embedding models with automatic SDK selection.

<Note>
**Deprecation Notice:** Legacy models (`textembedding-gecko*`) use the deprecated `vertexai.language_models` SDK, which will be removed after June 24, 2026. Consider migrating to newer models like `gemini-embedding-001`. See the [Google migration guide](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk) for details.
</Note>

```python
# Recommended: Using new models with the google-genai SDK
crew = Crew(
    memory=True,
    embedder={
        "provider": "vertexai",
        "provider": "google-vertex",
        "config": {
            "project_id": "your-gcp-project-id",
            "region": "us-central1",  # or your preferred region
            "api_key": "your-service-account-key",
            "model_name": "textembedding-gecko"
            "location": "us-central1",
            "model_name": "gemini-embedding-001",  # or "text-embedding-005", "text-multilingual-embedding-002"
            "task_type": "RETRIEVAL_DOCUMENT",  # Optional
            "output_dimensionality": 768  # Optional
        }
    }
)

# Using API key authentication (Express mode)
crew = Crew(
    memory=True,
    embedder={
        "provider": "google-vertex",
        "config": {
            "api_key": "your-google-api-key",
            "model_name": "gemini-embedding-001"
        }
    }
)

# Legacy models (backwards compatible, emits a deprecation warning)
crew = Crew(
    memory=True,
    embedder={
        "provider": "google-vertex",
        "config": {
            "project_id": "your-gcp-project-id",
            "region": "us-central1",  # or "location" (region is deprecated)
            "model_name": "textembedding-gecko"  # Legacy model
        }
    }
)
```

**Available models:**
- **New SDK models** (recommended): `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002`
- **Legacy models** (deprecated): `textembedding-gecko`, `textembedding-gecko@001`, `textembedding-gecko-multilingual`

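For completeness, the same Vertex AI provider can also be driven outside of a `Crew` through the `GoogleGenAIVertexEmbeddingFunction` introduced later in this changeset. This is a minimal sketch assuming Application Default Credentials are configured; the project ID is a placeholder, and the call follows the standard Chroma `EmbeddingFunction` interface.

```python
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
    GoogleGenAIVertexEmbeddingFunction,
)

# The model name drives SDK selection: gemini-embedding-001 uses the google-genai
# SDK, while textembedding-gecko* falls back to the deprecated legacy SDK.
embedder = GoogleGenAIVertexEmbeddingFunction(
    project_id="your-gcp-project-id",  # placeholder
    location="us-central1",
    model_name="gemini-embedding-001",
)

vectors = embedder(["CrewAI agents collaborate on tasks."])
print(len(vectors), len(vectors[0]))  # number of documents, embedding dimension
```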
### Ollama Embeddings (Local)

Run embeddings locally for privacy and cost savings.

@@ -569,7 +604,7 @@ mem0_client_embedder_config = {
    "project_id": "my_project_id", # Optional
    "api_key": "custom-api-key" # Optional - overrides env var
    "run_id": "my_run_id", # Optional - for short-term memory
    "includes": "include1", # Optional
    "excludes": "exclude1", # Optional
    "infer": True # Optional defaults to True
    "custom_categories": new_categories # Optional - custom categories for user memory

@@ -591,7 +626,7 @@ crew = Crew(

### Choosing the Right Embedding Provider

When selecting an embedding provider, consider factors like performance, privacy, cost, and integration needs.
Below is a comparison to help you decide:

| Provider | Best For | Pros | Cons |

@@ -749,7 +784,7 @@ Entity Memory supports batching when saving multiple entities at once. When you

This improves performance and observability when writing many entities in one operation.

## 2. External Memory
External Memory provides a standalone memory system that operates independently from the crew's built-in memory. This is ideal for specialized memory providers or cross-application memory sharing.

### Basic External Memory with Mem0
@@ -819,7 +854,7 @@ external_memory = ExternalMemory(
    "project_id": "my_project_id", # Optional
    "api_key": "custom-api-key" # Optional - overrides env var
    "run_id": "my_run_id", # Optional - for short-term memory
    "includes": "include1", # Optional
    "excludes": "exclude1", # Optional
    "infer": True # Optional defaults to True
    "custom_categories": new_categories # Optional - custom categories for user memory

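To make the External Memory section above concrete, here is a minimal sketch of attaching a standalone Mem0-backed store to a crew. The import path and the `embedder_config` shape follow the surrounding docs; the agent and task lists are placeholders, and the `user_id` is illustrative.

```python
from crewai import Crew, Process
from crewai.memory.external.external_memory import ExternalMemory

# Standalone memory, managed independently of the crew's built-in memory.
external_memory = ExternalMemory(
    embedder_config={
        "provider": "mem0",
        "config": {"user_id": "john"},  # illustrative user id
    }
)

crew = Crew(
    agents=[...],  # your agents
    tasks=[...],   # your tasks
    external_memory=external_memory,
    process=Process.sequential,
)
```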
@@ -152,4 +152,4 @@ __all__ = [
    "wrap_file_source",
]

__version__ = "1.8.1"
__version__ = "1.9.0"

@@ -12,7 +12,7 @@ dependencies = [
    "pytube~=15.0.0",
    "requests~=2.32.5",
    "docker~=7.1.0",
    "crewai==1.8.1",
    "crewai==1.9.0",
    "lancedb~=0.5.4",
    "tiktoken~=0.8.0",
    "beautifulsoup4~=4.13.4",

@@ -291,4 +291,4 @@ __all__ = [
    "ZapierActionTools",
]

__version__ = "1.8.1"
__version__ = "1.9.0"

@@ -49,7 +49,7 @@ Repository = "https://github.com/crewAIInc/crewAI"

[project.optional-dependencies]
tools = [
    "crewai-tools==1.8.1",
    "crewai-tools==1.9.0",
]
embeddings = [
    "tiktoken~=0.8.0"

@@ -90,7 +90,7 @@ azure-ai-inference = [
    "azure-ai-inference~=1.0.0b9",
]
anthropic = [
    "anthropic~=0.71.0",
    "anthropic~=0.73.0",
]
a2a = [
    "a2a-sdk~=0.3.10",

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:

_suppress_pydantic_deprecation_warnings()

__version__ = "1.8.1"
__version__ = "1.9.0"
_telemetry_submitted = False

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
    "crewai[tools]==1.8.1"
    "crewai[tools]==1.9.0"
]

[project.scripts]

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
    "crewai[tools]==1.8.1"
    "crewai[tools]==1.9.0"
]

[project.scripts]

@@ -3,9 +3,8 @@ from __future__ import annotations

import json
import logging
import os
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast

from anthropic.types import ThinkingBlock
from pydantic import BaseModel

from crewai.events.types.llm_events import LLMCallType

@@ -22,8 +21,9 @@ if TYPE_CHECKING:
    from crewai.llms.hooks.base import BaseInterceptor

try:
    from anthropic import Anthropic, AsyncAnthropic
    from anthropic import Anthropic, AsyncAnthropic, transform_schema
    from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
    from anthropic.types.beta import BetaMessage
    import httpx
except ImportError:
    raise ImportError(

@@ -31,7 +31,62 @@ except ImportError:
    ) from None


ANTHROPIC_FILES_API_BETA = "files-api-2025-04-14"
ANTHROPIC_FILES_API_BETA: Final = "files-api-2025-04-14"
ANTHROPIC_STRUCTURED_OUTPUTS_BETA: Final = "structured-outputs-2025-11-13"

NATIVE_STRUCTURED_OUTPUT_MODELS: Final[
    tuple[
        Literal["claude-sonnet-4-5"],
        Literal["claude-sonnet-4.5"],
        Literal["claude-opus-4-5"],
        Literal["claude-opus-4.5"],
        Literal["claude-opus-4-1"],
        Literal["claude-opus-4.1"],
        Literal["claude-haiku-4-5"],
        Literal["claude-haiku-4.5"],
    ]
] = (
    "claude-sonnet-4-5",
    "claude-sonnet-4.5",
    "claude-opus-4-5",
    "claude-opus-4.5",
    "claude-opus-4-1",
    "claude-opus-4.1",
    "claude-haiku-4-5",
    "claude-haiku-4.5",
)


def _supports_native_structured_outputs(model: str) -> bool:
    """Check if the model supports native structured outputs.

    Native structured outputs are only available for Claude 4.5 models
    (Sonnet 4.5, Opus 4.5, Opus 4.1, Haiku 4.5).
    Other models require the tool-based fallback approach.

    Args:
        model: The model name/identifier.

    Returns:
        True if the model supports native structured outputs.
    """
    model_lower = model.lower()
    return any(prefix in model_lower for prefix in NATIVE_STRUCTURED_OUTPUT_MODELS)


def _is_pydantic_model_class(obj: Any) -> TypeGuard[type[BaseModel]]:
    """Check if an object is a Pydantic model class.

    This distinguishes between Pydantic model classes that support structured
    outputs (have model_json_schema) and plain dicts like {"type": "json_object"}.

    Args:
        obj: The object to check.

    Returns:
        True if obj is a Pydantic model class.
    """
    return isinstance(obj, type) and issubclass(obj, BaseModel)

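# Illustration (not part of the diff): how the two helpers above route structured
# outputs. The model ids and the schema are examples only; the check is a substring
# match against NATIVE_STRUCTURED_OUTPUT_MODELS.
#
#     from pydantic import BaseModel
#
#     class Answer(BaseModel):
#         text: str
#
#     _supports_native_structured_outputs("claude-haiku-4-5-20251001")  # True -> native output_format beta
#     _supports_native_structured_outputs("claude-3-7-sonnet")          # False -> structured_output tool fallback
#     _is_pydantic_model_class(Answer)                   # True -> schema-based structured output applies
#     _is_pydantic_model_class({"type": "json_object"})  # False -> plain dict, not treated as a schema
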
def _contains_file_id_reference(messages: list[dict[str, Any]]) -> bool:
|
||||
@@ -84,6 +139,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
@@ -101,6 +157,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
stream: Enable streaming responses
|
||||
client_params: Additional parameters for the Anthropic client
|
||||
interceptor: HTTP interceptor for modifying requests/responses at transport level.
|
||||
response_format: Pydantic model for structured output. When provided, responses
|
||||
will be validated against this model schema.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -131,6 +189,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self.thinking = thinking
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
self.response_format = response_format
|
||||
# Model-specific settings
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = True
|
||||
@@ -231,6 +290,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
formatted_messages, system_message, tools
|
||||
)
|
||||
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
# Handle streaming vs non-streaming
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
@@ -238,7 +299,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
return self._handle_completion(
|
||||
@@ -246,7 +307,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -298,13 +359,15 @@ class AnthropicCompletion(BaseLLM):
|
||||
formatted_messages, system_message, tools
|
||||
)
|
||||
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_completion(
|
||||
@@ -312,7 +375,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -565,22 +628,40 @@ class AnthropicCompletion(BaseLLM):
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming message completion."""
|
||||
if response_model:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"input_schema": response_model.model_json_schema(),
|
||||
}
|
||||
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||
betas: list[str] = []
|
||||
use_native_structured_output = False
|
||||
|
||||
if uses_file_api:
|
||||
betas.append(ANTHROPIC_FILES_API_BETA)
|
||||
|
||||
extra_body: dict[str, Any] | None = None
|
||||
if _is_pydantic_model_class(response_model):
|
||||
schema = transform_schema(response_model.model_json_schema())
|
||||
if _supports_native_structured_outputs(self.model):
|
||||
use_native_structured_output = True
|
||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||
extra_body = {
|
||||
"output_format": {
|
||||
"type": "json_schema",
|
||||
"schema": schema,
|
||||
}
|
||||
}
|
||||
else:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Output the structured response",
|
||||
"input_schema": schema,
|
||||
}
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
try:
|
||||
if uses_file_api:
|
||||
params["betas"] = [ANTHROPIC_FILES_API_BETA]
|
||||
response = self.client.beta.messages.create(**params)
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = self.client.beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = self.client.messages.create(**params)
|
||||
|
||||
@@ -593,22 +674,34 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage = self._extract_anthropic_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and response.content:
|
||||
tool_uses = [
|
||||
block for block in response.content if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
if _is_pydantic_model_class(response_model) and response.content:
|
||||
if use_native_structured_output:
|
||||
for block in response.content:
|
||||
if isinstance(block, TextBlock):
|
||||
structured_json = block.text
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
else:
|
||||
for block in response.content:
|
||||
if (
|
||||
isinstance(block, ToolUseBlock)
|
||||
and block.name == "structured_output"
|
||||
):
|
||||
structured_json = json.dumps(block.input)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
|
||||
# Check if Claude wants to use tools
|
||||
if response.content:
|
||||
@@ -678,17 +771,31 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
) -> str | Any:
|
||||
"""Handle streaming message completion."""
|
||||
if response_model:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"input_schema": response_model.model_json_schema(),
|
||||
}
|
||||
betas: list[str] = []
|
||||
use_native_structured_output = False
|
||||
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
extra_body: dict[str, Any] | None = None
|
||||
if _is_pydantic_model_class(response_model):
|
||||
schema = transform_schema(response_model.model_json_schema())
|
||||
if _supports_native_structured_outputs(self.model):
|
||||
use_native_structured_output = True
|
||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||
extra_body = {
|
||||
"output_format": {
|
||||
"type": "json_schema",
|
||||
"schema": schema,
|
||||
}
|
||||
}
|
||||
else:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Output the structured response",
|
||||
"input_schema": schema,
|
||||
}
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
full_response = ""
|
||||
|
||||
@@ -696,15 +803,22 @@ class AnthropicCompletion(BaseLLM):
|
||||
# (the SDK sets it internally)
|
||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||
|
||||
if betas:
|
||||
stream_params["betas"] = betas
|
||||
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
# Make streaming API call
|
||||
with self.client.messages.stream(**stream_params) as stream:
|
||||
stream_context = (
|
||||
self.client.beta.messages.stream(**stream_params, extra_body=extra_body)
|
||||
if betas
|
||||
else self.client.messages.stream(**stream_params)
|
||||
)
|
||||
with stream_context as stream:
|
||||
response_id = None
|
||||
for event in stream:
|
||||
if hasattr(event, "message") and hasattr(event.message, "id"):
|
||||
response_id = event.message.id
|
||||
|
||||
|
||||
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
||||
text_delta = event.delta.text
|
||||
full_response += text_delta
|
||||
@@ -712,7 +826,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
chunk=text_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
if event.type == "content_block_start":
|
||||
@@ -739,7 +853,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
"index": block_index,
|
||||
},
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
elif event.type == "content_block_delta":
|
||||
if event.delta.type == "input_json_delta":
|
||||
@@ -763,10 +877,10 @@ class AnthropicCompletion(BaseLLM):
|
||||
"index": block_index,
|
||||
},
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
final_message: Message = stream.get_final_message()
|
||||
final_message = stream.get_final_message()
|
||||
|
||||
thinking_blocks: list[ThinkingBlock] = []
|
||||
if final_message.content:
|
||||
@@ -781,25 +895,30 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and final_message.content:
|
||||
tool_uses = [
|
||||
block
|
||||
for block in final_message.content
|
||||
if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
|
||||
if _is_pydantic_model_class(response_model):
|
||||
if use_native_structured_output:
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
return full_response
|
||||
for block in final_message.content:
|
||||
if (
|
||||
isinstance(block, ToolUseBlock)
|
||||
and block.name == "structured_output"
|
||||
):
|
||||
structured_json = json.dumps(block.input)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
|
||||
if final_message.content:
|
||||
tool_uses = [
|
||||
@@ -809,11 +928,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
# If no available_functions, return tool calls for executor to handle
|
||||
if not available_functions:
|
||||
return list(tool_uses)
|
||||
|
||||
# Handle tool use conversation flow internally
|
||||
return self._handle_tool_use_conversation(
|
||||
final_message,
|
||||
tool_uses,
|
||||
@@ -823,10 +940,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_agent,
|
||||
)
|
||||
|
||||
# Apply stop words to full response
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
# Emit completion event and return full response
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -884,7 +999,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
def _handle_tool_use_conversation(
|
||||
self,
|
||||
initial_response: Message,
|
||||
initial_response: Message | BetaMessage,
|
||||
tool_uses: list[ToolUseBlock],
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
@@ -1002,22 +1117,40 @@ class AnthropicCompletion(BaseLLM):
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming async message completion."""
|
||||
if response_model:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"input_schema": response_model.model_json_schema(),
|
||||
}
|
||||
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||
betas: list[str] = []
|
||||
use_native_structured_output = False
|
||||
|
||||
if uses_file_api:
|
||||
betas.append(ANTHROPIC_FILES_API_BETA)
|
||||
|
||||
extra_body: dict[str, Any] | None = None
|
||||
if _is_pydantic_model_class(response_model):
|
||||
schema = transform_schema(response_model.model_json_schema())
|
||||
if _supports_native_structured_outputs(self.model):
|
||||
use_native_structured_output = True
|
||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||
extra_body = {
|
||||
"output_format": {
|
||||
"type": "json_schema",
|
||||
"schema": schema,
|
||||
}
|
||||
}
|
||||
else:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Output the structured response",
|
||||
"input_schema": schema,
|
||||
}
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
try:
|
||||
if uses_file_api:
|
||||
params["betas"] = [ANTHROPIC_FILES_API_BETA]
|
||||
response = await self.async_client.beta.messages.create(**params)
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = await self.async_client.beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = await self.async_client.messages.create(**params)
|
||||
|
||||
@@ -1030,23 +1163,34 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage = self._extract_anthropic_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and response.content:
|
||||
tool_uses = [
|
||||
block for block in response.content if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
if _is_pydantic_model_class(response_model) and response.content:
|
||||
if use_native_structured_output:
|
||||
for block in response.content:
|
||||
if isinstance(block, TextBlock):
|
||||
structured_json = block.text
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
else:
|
||||
for block in response.content:
|
||||
if (
|
||||
isinstance(block, ToolUseBlock)
|
||||
and block.name == "structured_output"
|
||||
):
|
||||
structured_json = json.dumps(block.input)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
|
||||
if response.content:
|
||||
tool_uses = [
|
||||
@@ -1102,25 +1246,49 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
) -> str | Any:
|
||||
"""Handle async streaming message completion."""
|
||||
if response_model:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"input_schema": response_model.model_json_schema(),
|
||||
}
|
||||
betas: list[str] = []
|
||||
use_native_structured_output = False
|
||||
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
extra_body: dict[str, Any] | None = None
|
||||
if _is_pydantic_model_class(response_model):
|
||||
schema = transform_schema(response_model.model_json_schema())
|
||||
if _supports_native_structured_outputs(self.model):
|
||||
use_native_structured_output = True
|
||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||
extra_body = {
|
||||
"output_format": {
|
||||
"type": "json_schema",
|
||||
"schema": schema,
|
||||
}
|
||||
}
|
||||
else:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Output the structured response",
|
||||
"input_schema": schema,
|
||||
}
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
full_response = ""
|
||||
|
||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||
|
||||
if betas:
|
||||
stream_params["betas"] = betas
|
||||
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
async with self.async_client.messages.stream(**stream_params) as stream:
|
||||
stream_context = (
|
||||
self.async_client.beta.messages.stream(
|
||||
**stream_params, extra_body=extra_body
|
||||
)
|
||||
if betas
|
||||
else self.async_client.messages.stream(**stream_params)
|
||||
)
|
||||
async with stream_context as stream:
|
||||
response_id = None
|
||||
async for event in stream:
|
||||
if hasattr(event, "message") and hasattr(event.message, "id"):
|
||||
@@ -1133,7 +1301,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
chunk=text_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
if event.type == "content_block_start":
|
||||
@@ -1160,7 +1328,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
"index": block_index,
|
||||
},
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
elif event.type == "content_block_delta":
|
||||
if event.delta.type == "input_json_delta":
|
||||
@@ -1184,33 +1352,38 @@ class AnthropicCompletion(BaseLLM):
|
||||
"index": block_index,
|
||||
},
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
final_message: Message = await stream.get_final_message()
|
||||
final_message = await stream.get_final_message()
|
||||
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and final_message.content:
|
||||
tool_uses = [
|
||||
block
|
||||
for block in final_message.content
|
||||
if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
|
||||
if _is_pydantic_model_class(response_model):
|
||||
if use_native_structured_output:
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
return full_response
|
||||
for block in final_message.content:
|
||||
if (
|
||||
isinstance(block, ToolUseBlock)
|
||||
and block.name == "structured_output"
|
||||
):
|
||||
structured_json = json.dumps(block.input)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
|
||||
if final_message.content:
|
||||
tool_uses = [
|
||||
@@ -1220,7 +1393,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
# If no available_functions, return tool calls for executor to handle
|
||||
if not available_functions:
|
||||
return list(tool_uses)
|
||||
|
||||
@@ -1247,7 +1419,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
async def _ahandle_tool_use_conversation(
|
||||
self,
|
||||
initial_response: Message,
|
||||
initial_response: Message | BetaMessage,
|
||||
tool_uses: list[ToolUseBlock],
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
@@ -1356,7 +1528,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
@staticmethod
|
||||
def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
|
||||
def _extract_anthropic_token_usage(
|
||||
response: Message | BetaMessage,
|
||||
) -> dict[str, Any]:
|
||||
"""Extract token usage from Anthropic response."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = response.usage
|
||||
|
||||
@@ -92,6 +92,7 @@ class AzureCompletion(BaseLLM):
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Azure AI Inference chat completion client.
|
||||
@@ -111,6 +112,9 @@ class AzureCompletion(BaseLLM):
|
||||
stop: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
interceptor: HTTP interceptor (not yet supported for Azure).
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
Only works with OpenAI models deployed on Azure.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
@@ -165,6 +169,7 @@ class AzureCompletion(BaseLLM):
|
||||
self.presence_penalty = presence_penalty
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
self.response_format = response_format
|
||||
|
||||
self.is_openai_model = any(
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
@@ -298,6 +303,7 @@ class AzureCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
# Format messages for Azure
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
@@ -307,7 +313,7 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools, response_model
|
||||
formatted_messages, tools, effective_response_model
|
||||
)
|
||||
|
||||
# Handle streaming vs non-streaming
|
||||
@@ -317,7 +323,7 @@ class AzureCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
return self._handle_completion(
|
||||
@@ -325,7 +331,7 @@ class AzureCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -364,11 +370,12 @@ class AzureCompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools, response_model
|
||||
formatted_messages, tools, effective_response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
@@ -377,7 +384,7 @@ class AzureCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_completion(
|
||||
@@ -385,7 +392,7 @@ class AzureCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -726,7 +733,7 @@ class AzureCompletion(BaseLLM):
|
||||
"""
|
||||
if update.choices:
|
||||
choice = update.choices[0]
|
||||
response_id = update.id if hasattr(update,"id") else None
|
||||
response_id = update.id if hasattr(update, "id") else None
|
||||
if choice.delta and choice.delta.content:
|
||||
content_delta = choice.delta.content
|
||||
full_response += content_delta
|
||||
@@ -734,7 +741,7 @@ class AzureCompletion(BaseLLM):
|
||||
chunk=content_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
if choice.delta and choice.delta.tool_calls:
|
||||
@@ -769,7 +776,7 @@ class AzureCompletion(BaseLLM):
|
||||
"index": idx,
|
||||
},
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
@@ -172,6 +172,7 @@ class BedrockCompletion(BaseLLM):
|
||||
additional_model_request_fields: dict[str, Any] | None = None,
|
||||
additional_model_response_field_paths: list[str] | None = None,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize AWS Bedrock completion client.
|
||||
@@ -192,6 +193,8 @@ class BedrockCompletion(BaseLLM):
|
||||
additional_model_request_fields: Model-specific request parameters
|
||||
additional_model_response_field_paths: Custom response field paths
|
||||
interceptor: HTTP interceptor (not yet supported for Bedrock).
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
@@ -248,6 +251,7 @@ class BedrockCompletion(BaseLLM):
|
||||
self.top_k = top_k
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences
|
||||
self.response_format = response_format
|
||||
|
||||
# Store advanced features (optional)
|
||||
self.guardrail_config = guardrail_config
|
||||
@@ -299,6 +303,8 @@ class BedrockCompletion(BaseLLM):
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Call AWS Bedrock Converse API."""
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
@@ -375,6 +381,7 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
return self._handle_converse(
|
||||
@@ -383,6 +390,7 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -425,6 +433,8 @@ class BedrockCompletion(BaseLLM):
|
||||
NotImplementedError: If aiobotocore is not installed.
|
||||
LLMContextLengthExceededError: If context window is exceeded.
|
||||
"""
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
if not AIOBOTOCORE_AVAILABLE:
|
||||
raise NotImplementedError(
|
||||
"Async support for AWS Bedrock requires aiobotocore. "
|
||||
@@ -494,11 +504,21 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
formatted_messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
formatted_messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -520,10 +540,29 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions: Mapping[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming converse API call following AWS best practices."""
|
||||
if response_model:
|
||||
structured_tool: ConverseToolTypeDef = {
|
||||
"toolSpec": {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"inputSchema": {"json": response_model.model_json_schema()},
|
||||
}
|
||||
}
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(
|
||||
object,
|
||||
{
|
||||
"tools": [structured_tool],
|
||||
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
# Validate messages format before API call
|
||||
if not messages:
|
||||
raise ValueError("Messages cannot be empty")
|
||||
|
||||
@@ -571,6 +610,21 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
# If there are tool uses but no available_functions, return them for the executor to handle
|
||||
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
|
||||
|
||||
if response_model and tool_uses:
|
||||
for tool_use in tool_uses:
|
||||
if tool_use.get("name") == "structured_output":
|
||||
structured_data = tool_use.get("input", {})
|
||||
result = response_model.model_validate(structured_data)
|
||||
self._emit_call_completed_event(
|
||||
response=result.model_dump_json(),
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
)
|
||||
return result
|
||||
|
||||
if tool_uses and not available_functions:
|
||||
self._emit_call_completed_event(
|
||||
response=tool_uses,
|
||||
@@ -717,8 +771,28 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming converse API call with comprehensive event handling."""
|
||||
if response_model:
|
||||
structured_tool: ConverseToolTypeDef = {
|
||||
"toolSpec": {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"inputSchema": {"json": response_model.model_json_schema()},
|
||||
}
|
||||
}
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(
|
||||
object,
|
||||
{
|
||||
"tools": [structured_tool],
|
||||
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
current_tool_use: dict[str, Any] | None = None
|
||||
tool_use_id: str | None = None
|
||||
@@ -805,7 +879,7 @@ class BedrockCompletion(BaseLLM):
|
||||
"index": tool_use_index,
|
||||
},
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
elif "contentBlockStop" in event:
|
||||
logging.debug("Content block stopped in stream")
|
||||
@@ -929,8 +1003,28 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions: Mapping[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle async non-streaming converse API call."""
|
||||
if response_model:
|
||||
structured_tool: ConverseToolTypeDef = {
|
||||
"toolSpec": {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"inputSchema": {"json": response_model.model_json_schema()},
|
||||
}
|
||||
}
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(
|
||||
object,
|
||||
{
|
||||
"tools": [structured_tool],
|
||||
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
if not messages:
|
||||
raise ValueError("Messages cannot be empty")
|
||||
@@ -976,6 +1070,21 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
# If there are tool uses but no available_functions, return them for the executor to handle
|
||||
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
|
||||
|
||||
if response_model and tool_uses:
|
||||
for tool_use in tool_uses:
|
||||
if tool_use.get("name") == "structured_output":
|
||||
structured_data = tool_use.get("input", {})
|
||||
result = response_model.model_validate(structured_data)
|
||||
self._emit_call_completed_event(
|
||||
response=result.model_dump_json(),
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
)
|
||||
return result
|
||||
|
||||
if tool_uses and not available_functions:
|
||||
self._emit_call_completed_event(
|
||||
response=tool_uses,
|
||||
@@ -1106,8 +1215,28 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle async streaming converse API call."""
|
||||
if response_model:
|
||||
structured_tool: ConverseToolTypeDef = {
|
||||
"toolSpec": {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"inputSchema": {"json": response_model.model_json_schema()},
|
||||
}
|
||||
}
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(
|
||||
object,
|
||||
{
|
||||
"tools": [structured_tool],
|
||||
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
current_tool_use: dict[str, Any] | None = None
|
||||
tool_use_id: str | None = None
|
||||
@@ -1174,7 +1303,7 @@ class BedrockCompletion(BaseLLM):
|
||||
chunk=text_chunk,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
elif "toolUse" in delta and current_tool_use:
|
||||
tool_input = delta["toolUse"].get("input", "")
|
||||
|
||||
@@ -56,6 +56,7 @@ class GeminiCompletion(BaseLLM):
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
use_vertexai: bool | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Google Gemini chat completion client.
|
||||
@@ -86,6 +87,8 @@ class GeminiCompletion(BaseLLM):
|
||||
- None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
|
||||
When using Vertex AI with API key (Express mode), http_options with
|
||||
api_version="v1" is automatically configured.
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
@@ -121,6 +124,7 @@ class GeminiCompletion(BaseLLM):
|
||||
self.safety_settings = safety_settings or {}
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self.tools: list[dict[str, Any]] | None = None
|
||||
self.response_format = response_format
|
||||
|
||||
# Model-specific settings
|
||||
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
|
||||
@@ -292,6 +296,7 @@ class GeminiCompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
self.tools = tools
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
formatted_content, system_instruction = self._format_messages_for_gemini(
|
||||
messages
|
||||
@@ -303,7 +308,7 @@ class GeminiCompletion(BaseLLM):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
config = self._prepare_generation_config(
|
||||
system_instruction, tools, response_model
|
||||
system_instruction, tools, effective_response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
@@ -313,7 +318,7 @@ class GeminiCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
return self._handle_completion(
|
||||
@@ -322,7 +327,7 @@ class GeminiCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
@@ -374,13 +379,14 @@ class GeminiCompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
self.tools = tools
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
formatted_content, system_instruction = self._format_messages_for_gemini(
|
||||
messages
|
||||
)
|
||||
|
||||
config = self._prepare_generation_config(
|
||||
system_instruction, tools, response_model
|
||||
system_instruction, tools, effective_response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
@@ -390,7 +396,7 @@ class GeminiCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_completion(
|
||||
@@ -399,7 +405,7 @@ class GeminiCompletion(BaseLLM):
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
@@ -570,10 +576,10 @@ class GeminiCompletion(BaseLLM):
|
||||
types.Content(role="user", parts=[function_response_part])
|
||||
)
|
||||
elif role == "assistant" and message.get("tool_calls"):
|
||||
parts: list[types.Part] = []
|
||||
tool_parts: list[types.Part] = []
|
||||
|
||||
if text_content:
|
||||
parts.append(types.Part.from_text(text=text_content))
|
||||
tool_parts.append(types.Part.from_text(text=text_content))
|
||||
|
||||
tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
|
||||
for tool_call in tool_calls:
|
||||
@@ -592,11 +598,11 @@ class GeminiCompletion(BaseLLM):
|
||||
else:
|
||||
func_args = func_args_raw
|
||||
|
||||
parts.append(
|
||||
tool_parts.append(
|
||||
types.Part.from_function_call(name=func_name, args=func_args)
|
||||
)
|
||||
|
||||
contents.append(types.Content(role="model", parts=parts))
|
||||
contents.append(types.Content(role="model", parts=tool_parts))
|
||||
else:
|
||||
# Convert role for Gemini (assistant -> model)
|
||||
gemini_role = "model" if role == "assistant" else "user"
|
||||
@@ -790,7 +796,7 @@ class GeminiCompletion(BaseLLM):
|
||||
Returns:
|
||||
Tuple of (updated full_response, updated function_calls, updated usage_data)
|
||||
"""
|
||||
response_id=chunk.response_id if hasattr(chunk,"response_id") else None
|
||||
response_id = chunk.response_id if hasattr(chunk, "response_id") else None
|
||||
if chunk.usage_metadata:
|
||||
usage_data = self._extract_token_usage(chunk)
|
||||
|
||||
@@ -800,7 +806,7 @@ class GeminiCompletion(BaseLLM):
|
||||
chunk=chunk.text,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
if chunk.candidates:
|
||||
@@ -837,7 +843,7 @@ class GeminiCompletion(BaseLLM):
|
||||
"index": call_index,
|
||||
},
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
response_id=response_id
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
return full_response, function_calls, usage_data
|
||||
@@ -972,7 +978,7 @@ class GeminiCompletion(BaseLLM):
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
) -> str | Any:
|
||||
"""Handle streaming content generation."""
|
||||
full_response = ""
|
||||
function_calls: dict[int, dict[str, Any]] = {}
|
||||
@@ -1050,7 +1056,7 @@ class GeminiCompletion(BaseLLM):
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
) -> str | Any:
|
||||
"""Handle async streaming content generation."""
|
||||
full_response = ""
|
||||
function_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
@@ -18,7 +18,6 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
@@ -52,6 +51,9 @@ if TYPE_CHECKING:
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
|
||||
GoogleGenAIVertexEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderSpec,
|
||||
@@ -163,7 +165,7 @@ def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunctio
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: VertexAIProviderSpec,
|
||||
) -> GoogleVertexEmbeddingFunction: ...
|
||||
) -> GoogleGenAIVertexEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -296,7 +298,9 @@ def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
|
||||
def build_embedder(
|
||||
spec: VertexAIProviderSpec,
|
||||
) -> GoogleGenAIVertexEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Google embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
|
||||
GoogleGenAIVertexEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.generative_ai import (
|
||||
GenerativeAiProvider,
|
||||
)
|
||||
@@ -18,6 +21,7 @@ __all__ = [
|
||||
"GenerativeAiProvider",
|
||||
"GenerativeAiProviderConfig",
|
||||
"GenerativeAiProviderSpec",
|
||||
"GoogleGenAIVertexEmbeddingFunction",
|
||||
"VertexAIProvider",
|
||||
"VertexAIProviderConfig",
|
||||
"VertexAIProviderSpec",
|
||||
|
||||
@@ -0,0 +1,237 @@
"""Google Vertex AI embedding function implementation.

This module supports both the new google-genai SDK and the deprecated
vertexai.language_models module for backwards compatibility.

The deprecated vertexai.language_models module will be removed after June 24, 2026.
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
"""

from typing import Any, ClassVar, cast
import warnings

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack

from crewai.rag.embeddings.providers.google.types import VertexAIProviderConfig


class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for Google Vertex AI with dual SDK support.

This class supports both:
- Legacy models (textembedding-gecko*) using the deprecated vertexai.language_models SDK
- New models (gemini-embedding-*, text-embedding-*) using the google-genai SDK

The SDK is automatically selected based on the model name. Legacy models will
emit a deprecation warning.

Supports two authentication modes:
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
2. API key: Set api_key for direct API access

Example:
# Using legacy model (will emit deprecation warning)
embedder = GoogleGenAIVertexEmbeddingFunction(
project_id="my-project",
region="us-central1",
model_name="textembedding-gecko"
)

# Using new model with google-genai SDK
embedder = GoogleGenAIVertexEmbeddingFunction(
project_id="my-project",
location="us-central1",
model_name="gemini-embedding-001"
)

# Using API key (new SDK only)
embedder = GoogleGenAIVertexEmbeddingFunction(
api_key="your-api-key",
model_name="gemini-embedding-001"
)
"""

# Models that use the legacy vertexai.language_models SDK
LEGACY_MODELS: ClassVar[set[str]] = {
"textembedding-gecko",
"textembedding-gecko@001",
"textembedding-gecko@002",
"textembedding-gecko@003",
"textembedding-gecko@latest",
"textembedding-gecko-multilingual",
"textembedding-gecko-multilingual@001",
"textembedding-gecko-multilingual@latest",
}

# Models that use the new google-genai SDK
GENAI_MODELS: ClassVar[set[str]] = {
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
}

def __init__(self, **kwargs: Unpack[VertexAIProviderConfig]) -> None:
"""Initialize Google Vertex AI embedding function.

Args:
**kwargs: Configuration parameters including:
- model_name: Model to use for embeddings (default: "textembedding-gecko")
- api_key: Optional API key for authentication (new SDK only)
- project_id: GCP project ID (for Vertex AI backend)
- location: GCP region (default: "us-central1")
- region: Deprecated alias for location
- task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT", new SDK only)
- output_dimensionality: Optional output embedding dimension (new SDK only)
"""
# Handle deprecated 'region' parameter (only if it has a value)
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
if region_value is not None:
warnings.warn(
"The 'region' parameter is deprecated, use 'location' instead. "
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
DeprecationWarning,
stacklevel=2,
)
if "location" not in kwargs or kwargs.get("location") is None:
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]

self._config = kwargs
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
self._use_legacy = self._is_legacy_model(self._model_name)

if self._use_legacy:
self._init_legacy_client(**kwargs)
else:
self._init_genai_client(**kwargs)

def _is_legacy_model(self, model_name: str) -> bool:
"""Check if the model uses the legacy SDK."""
return model_name in self.LEGACY_MODELS or model_name.startswith(
"textembedding-gecko"
)

def _init_legacy_client(self, **kwargs: Any) -> None:
"""Initialize using the deprecated vertexai.language_models SDK."""
warnings.warn(
f"Model '{self._model_name}' uses the deprecated vertexai.language_models SDK "
"which will be removed after June 24, 2026. Consider migrating to newer models "
"like 'gemini-embedding-001'. "
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
DeprecationWarning,
stacklevel=3,
)

try:
import vertexai
from vertexai.language_models import TextEmbeddingModel
except ImportError as e:
raise ImportError(
"vertexai is required for legacy embedding models (textembedding-gecko*). "
"Install it with: pip install google-cloud-aiplatform"
) from e

project_id = kwargs.get("project_id")
location = str(kwargs.get("location", "us-central1"))

if not project_id:
raise ValueError(
"project_id is required for legacy models. "
"For API key authentication, use newer models like 'gemini-embedding-001'."
)

vertexai.init(project=str(project_id), location=location)
self._legacy_model = TextEmbeddingModel.from_pretrained(self._model_name)

def _init_genai_client(self, **kwargs: Any) -> None:
"""Initialize using the new google-genai SDK."""
try:
from google import genai
from google.genai.types import EmbedContentConfig
except ImportError as e:
raise ImportError(
"google-genai is required for Google Gen AI embeddings. "
"Install it with: uv add 'crewai[google-genai]'"
) from e

self._genai = genai
self._EmbedContentConfig = EmbedContentConfig
self._task_type = kwargs.get("task_type", "RETRIEVAL_DOCUMENT")
self._output_dimensionality = kwargs.get("output_dimensionality")

# Initialize client based on authentication mode
api_key = kwargs.get("api_key")
project_id = kwargs.get("project_id")
location: str = str(kwargs.get("location", "us-central1"))

if api_key:
self._client = genai.Client(api_key=api_key)
elif project_id:
self._client = genai.Client(
vertexai=True,
project=str(project_id),
location=location,
)
else:
raise ValueError(
"Either 'api_key' (for API key authentication) or 'project_id' "
"(for Vertex AI backend with ADC) must be provided."
)

@staticmethod
def name() -> str:
"""Return the name of the embedding function for ChromaDB compatibility."""
return "google-vertex"

def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.

Args:
input: List of documents to embed.

Returns:
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]

if self._use_legacy:
return self._call_legacy(input)
return self._call_genai(input)

def _call_legacy(self, input: list[str]) -> Embeddings:
"""Generate embeddings using the legacy SDK."""
import numpy as np

embeddings_list = []
for text in input:
embedding_result = self._legacy_model.get_embeddings([text])
embeddings_list.append(
np.array(embedding_result[0].values, dtype=np.float32)
)

return cast(Embeddings, embeddings_list)

def _call_genai(self, input: list[str]) -> Embeddings:
"""Generate embeddings using the new google-genai SDK."""
# Build config for embed_content
config_kwargs: dict[str, Any] = {
"task_type": self._task_type,
}
if self._output_dimensionality is not None:
config_kwargs["output_dimensionality"] = self._output_dimensionality

config = self._EmbedContentConfig(**config_kwargs)

# Call the embedding API
response = self._client.models.embed_content(
model=self._model_name,
contents=input, # type: ignore[arg-type]
config=config,
)

# Extract embeddings from response
if response.embeddings is None:
raise ValueError("No embeddings returned from the API")
embeddings = [emb.values for emb in response.embeddings]
return cast(Embeddings, embeddings)
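Not part of the new file above, purely for illustration: a short sketch of calling the class directly, using only constructor arguments documented in its docstring (the 768-dimension output is an assumed choice, not a repository default):

```python
# Illustrative sketch; argument names come from the class above.
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
    GoogleGenAIVertexEmbeddingFunction,
)

embedder = GoogleGenAIVertexEmbeddingFunction(
    project_id="my-project",            # Vertex AI backend (Application Default Credentials)
    location="us-central1",
    model_name="gemini-embedding-001",  # routed to the google-genai SDK, no deprecation warning
    task_type="RETRIEVAL_DOCUMENT",
    output_dimensionality=768,          # optional, new-SDK models only (assumed value)
)
vectors = embedder(["What is CrewAI?"])  # one embedding vector per input document
```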
@@ -34,12 +34,47 @@ class GenerativeAiProviderSpec(TypedDict):


class VertexAIProviderConfig(TypedDict, total=False):
"""Configuration for Vertex AI provider."""
"""Configuration for Vertex AI provider with dual SDK support.

Supports both legacy models (textembedding-gecko*) using the deprecated
vertexai.language_models SDK and new models using google-genai SDK.

Attributes:
api_key: Google API key (optional if using project_id with ADC). Only for new SDK models.
model_name: Embedding model name (default: "textembedding-gecko").
Legacy models: textembedding-gecko, textembedding-gecko@001, etc.
New models: gemini-embedding-001, text-embedding-005, text-multilingual-embedding-002
project_id: GCP project ID (required for Vertex AI backend and legacy models).
location: GCP region/location (default: "us-central1").
region: Deprecated alias for location (kept for backwards compatibility).
task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT"). Only for new SDK models.
output_dimensionality: Output embedding dimension (optional). Only for new SDK models.
"""

api_key: str
model_name: Annotated[str, "textembedding-gecko"]
project_id: Annotated[str, "cloud-large-language-models"]
region: Annotated[str, "us-central1"]
model_name: Annotated[
Literal[
# Legacy models (deprecated vertexai.language_models SDK)
"textembedding-gecko",
"textembedding-gecko@001",
"textembedding-gecko@002",
"textembedding-gecko@003",
"textembedding-gecko@latest",
"textembedding-gecko-multilingual",
"textembedding-gecko-multilingual@001",
"textembedding-gecko-multilingual@latest",
# New models (google-genai SDK)
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
],
"textembedding-gecko",
]
project_id: str
location: Annotated[str, "us-central1"]
region: Annotated[str, "us-central1"] # Deprecated alias for location
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
output_dimensionality: int


class VertexAIProviderSpec(TypedDict, total=False):

@@ -1,46 +1,126 @@
"""Google Vertex AI embeddings provider."""
"""Google Vertex AI embeddings provider.

This module supports both the new google-genai SDK and the deprecated
vertexai.language_models module for backwards compatibility.

The SDK is automatically selected based on the model name:
- Legacy models (textembedding-gecko*) use vertexai.language_models (deprecated)
- New models (gemini-embedding-*, text-embedding-*) use google-genai

Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
"""

from __future__ import annotations

from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
from pydantic import AliasChoices, Field

from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)


class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
"""Google Vertex AI embeddings provider."""
class VertexAIProvider(BaseEmbeddingsProvider[GoogleGenAIVertexEmbeddingFunction]):
"""Google Vertex AI embeddings provider with dual SDK support.

embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
default=GoogleVertexEmbeddingFunction,
description="Vertex AI embedding function class",
Supports both legacy models (textembedding-gecko*) using the deprecated
vertexai.language_models SDK and new models (gemini-embedding-*, text-embedding-*)
using the google-genai SDK.

The SDK is automatically selected based on the model name. Legacy models will
emit a deprecation warning.

Authentication modes:
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
2. API key: Set api_key for direct API access (new SDK models only)

Example:
# Legacy model (backwards compatible, will emit deprecation warning)
provider = VertexAIProvider(
project_id="my-project",
region="us-central1", # or location="us-central1"
model_name="textembedding-gecko"
)

# New model with Vertex AI backend
provider = VertexAIProvider(
project_id="my-project",
location="us-central1",
model_name="gemini-embedding-001"
)

# New model with API key
provider = VertexAIProvider(
api_key="your-api-key",
model_name="gemini-embedding-001"
)
"""

embedding_callable: type[GoogleGenAIVertexEmbeddingFunction] = Field(
default=GoogleGenAIVertexEmbeddingFunction,
description="Google Vertex AI embedding function class",
)
model_name: str = Field(
default="textembedding-gecko",
description="Model name to use for embeddings",
description=(
"Model name to use for embeddings. Legacy models (textembedding-gecko*) "
"use the deprecated SDK. New models (gemini-embedding-001, text-embedding-005) "
"use the google-genai SDK."
),
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
"GOOGLE_VERTEX_MODEL_NAME",
"model",
),
)
api_key: str = Field(
description="Google API key",
api_key: str | None = Field(
default=None,
description="Google API key (optional if using project_id with Application Default Credentials)",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY",
"GOOGLE_CLOUD_API_KEY",
"GOOGLE_API_KEY",
),
)
project_id: str = Field(
default="cloud-large-language-models",
description="GCP project ID",
project_id: str | None = Field(
default=None,
description="GCP project ID (required for Vertex AI backend and legacy models)",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
"GOOGLE_CLOUD_PROJECT",
),
)
region: str = Field(
location: str = Field(
default="us-central1",
description="GCP region",
description="GCP region/location",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
"EMBEDDINGS_GOOGLE_CLOUD_LOCATION",
"EMBEDDINGS_GOOGLE_CLOUD_REGION",
"GOOGLE_CLOUD_LOCATION",
"GOOGLE_CLOUD_REGION",
),
)
region: str | None = Field(
default=None,
description="Deprecated: Use 'location' instead. GCP region (kept for backwards compatibility)",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_REGION",
"GOOGLE_VERTEX_REGION",
),
)
task_type: str = Field(
default="RETRIEVAL_DOCUMENT",
description="Task type for embeddings (e.g., RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY). Only used with new SDK models.",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_TASK_TYPE",
"GOOGLE_VERTEX_TASK_TYPE",
),
)
output_dimensionality: int | None = Field(
default=None,
description="Output embedding dimensionality (optional). Only used with new SDK models.",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
"GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
),
)

@@ -5,17 +5,29 @@ from __future__ import annotations
import asyncio
from collections.abc import Coroutine
import concurrent.futures
import logging
from typing import TYPE_CHECKING, TypeVar
from uuid import UUID

from aiocache import Cache # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]


if TYPE_CHECKING:
from aiocache import Cache
from crewai_files import FileInput

_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
logger = logging.getLogger(__name__)

_file_store: Cache | None = None

try:
from aiocache import Cache
from aiocache.serializers import PickleSerializer

_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
except ImportError:
logger.debug(
"aiocache is not installed. File store features will be disabled. "
"Install with: uv add aiocache"
)

T = TypeVar("T")

@@ -59,6 +71,8 @@ async def astore_files(
files: Dictionary mapping names to file inputs.
ttl: Time-to-live in seconds.
"""
if _file_store is None:
return
await _file_store.set(f"{_CREW_PREFIX}{execution_id}", files, ttl=ttl)


@@ -71,6 +85,8 @@ async def aget_files(execution_id: UUID) -> dict[str, FileInput] | None:
Returns:
Dictionary of files or None if not found.
"""
if _file_store is None:
return None
result: dict[str, FileInput] | None = await _file_store.get(
f"{_CREW_PREFIX}{execution_id}"
)
@@ -83,6 +99,8 @@ async def aclear_files(execution_id: UUID) -> None:
Args:
execution_id: Unique identifier for the crew execution.
"""
if _file_store is None:
return
await _file_store.delete(f"{_CREW_PREFIX}{execution_id}")


@@ -98,6 +116,8 @@ async def astore_task_files(
files: Dictionary mapping names to file inputs.
ttl: Time-to-live in seconds.
"""
if _file_store is None:
return
await _file_store.set(f"{_TASK_PREFIX}{task_id}", files, ttl=ttl)


@@ -110,6 +130,8 @@ async def aget_task_files(task_id: UUID) -> dict[str, FileInput] | None:
Returns:
Dictionary of files or None if not found.
"""
if _file_store is None:
return None
result: dict[str, FileInput] | None = await _file_store.get(
f"{_TASK_PREFIX}{task_id}"
)
@@ -122,6 +144,8 @@ async def aclear_task_files(task_id: UUID) -> None:
Args:
task_id: Unique identifier for the task.
"""
if _file_store is None:
return
await _file_store.delete(f"{_TASK_PREFIX}{task_id}")


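Not part of the diff: the guards added above follow a common optional-dependency pattern, where the in-memory store silently degrades to a no-op when `aiocache` is missing. A self-contained sketch of the same pattern with stand-in names (not the repository's actual module):

```python
# Generic sketch of the optional-dependency fallback used above; names are stand-ins.
import logging
from typing import Any

logger = logging.getLogger(__name__)

_store: Any | None = None
try:
    from aiocache import Cache
    from aiocache.serializers import PickleSerializer

    _store = Cache(Cache.MEMORY, serializer=PickleSerializer())
except ImportError:
    logger.debug("aiocache is not installed; in-memory file store is disabled.")


async def put(key: str, value: Any, ttl: int | None = None) -> None:
    if _store is None:  # degrade to a no-op when the dependency is missing
        return
    await _store.set(key, value, ttl=ttl)


async def get(key: str) -> Any | None:
    if _store is None:  # callers see "not found" instead of an ImportError
        return None
    return await _store.get(key)
```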
@@ -1,6 +1,8 @@
interactions:
- request:
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Returns structured data according to the schema","input_schema":{"description":"Response model for greeting test.","properties":{"greeting":{"title":"Greeting","type":"string"},"language":{"title":"Language","type":"string"}},"required":["greeting","language"],"title":"GreetingResponse","type":"object"}}]}'
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Output
the structured response","input_schema":{"type":"object","description":"Response
model for greeting test.","title":"GreetingResponse","properties":{"greeting":{"type":"string","title":"Greeting"},"language":{"type":"string","title":"Language"}},"additionalProperties":false,"required":["greeting","language"]}}]}'
headers:
User-Agent:
- X-USER-AGENT-XXX
@@ -13,7 +15,7 @@ interactions:
connection:
- keep-alive
content-length:
- '539'
- '551'
content-type:
- application/json
host:
@@ -29,7 +31,7 @@ interactions:
x-stainless-os:
- X-STAINLESS-OS-XXX
x-stainless-package-version:
- 0.75.0
- 0.76.0
x-stainless-retry-count:
- '0'
x-stainless-runtime:
@@ -42,7 +44,7 @@ interactions:
uri: https://api.anthropic.com/v1/messages
response:
body:
string: '{"model":"claude-sonnet-4-20250514","id":"msg_01XjvX2nCho1knuucbwwgCpw","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_019rfPRSDmBb7CyCTdGMv5rK","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":432,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
string: '{"model":"claude-sonnet-4-20250514","id":"msg_01CKTyVmak15L5oQ36mv4sL9","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_0174BYmn6xiSnUwVhFD8S7EW","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":436,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
headers:
CF-RAY:
- CF-RAY-XXX
@@ -51,7 +53,7 @@ interactions:
Content-Type:
- application/json
Date:
- Mon, 01 Dec 2025 11:19:38 GMT
- Mon, 26 Jan 2026 14:59:34 GMT
Server:
- cloudflare
Transfer-Encoding:
@@ -82,12 +84,10 @@ interactions:
- DYNAMIC
request-id:
- REQUEST-ID-XXX
retry-after:
- '24'
strict-transport-security:
- STS-XXX
x-envoy-upstream-service-time:
- '2101'
- '968'
status:
code: 200
message: OK

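Not part of the cassette: the `structured_output` tool schema embedded in the request body above corresponds to a Pydantic model roughly like the following (a reconstruction from the recorded JSON schema, not the repository's source):

```python
from pydantic import BaseModel


class GreetingResponse(BaseModel):
    """Response model for greeting test."""

    greeting: str
    language: str
```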
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -272,3 +272,100 @@ class TestEmbeddingFactory:
mock_build_from_provider.assert_called_once_with(mock_provider)
assert result == mock_embedding_function
mock_import.assert_not_called()

@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_genai_model(self, mock_import):
"""Test routing to Google Vertex provider with new genai model."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()

mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function

config = {
"provider": "google-vertex",
"config": {
"api_key": "test-google-api-key",
"model_name": "gemini-embedding-001",
},
}

build_embedder(config)

mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
mock_provider_class.assert_called_once()

call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-google-api-key"
assert call_kwargs["model_name"] == "gemini-embedding-001"

@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_legacy_model(self, mock_import):
"""Test routing to Google Vertex provider with legacy textembedding-gecko model."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()

mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function

config = {
"provider": "google-vertex",
"config": {
"project_id": "my-gcp-project",
"region": "us-central1",
"model_name": "textembedding-gecko",
},
}

build_embedder(config)

mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
mock_provider_class.assert_called_once()

call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["project_id"] == "my-gcp-project"
assert call_kwargs["region"] == "us-central1"
assert call_kwargs["model_name"] == "textembedding-gecko"

@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_location(self, mock_import):
"""Test routing to Google Vertex provider with location parameter."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()

mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function

config = {
"provider": "google-vertex",
"config": {
"project_id": "my-gcp-project",
"location": "europe-west1",
"model_name": "gemini-embedding-001",
"task_type": "RETRIEVAL_DOCUMENT",
"output_dimensionality": 768,
},
}

build_embedder(config)

mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)

call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["project_id"] == "my-gcp-project"
assert call_kwargs["location"] == "europe-west1"
assert call_kwargs["model_name"] == "gemini-embedding-001"
assert call_kwargs["task_type"] == "RETRIEVAL_DOCUMENT"
assert call_kwargs["output_dimensionality"] == 768

@@ -0,0 +1,176 @@
"""Integration tests for Google Vertex embeddings with Crew memory.

These tests make real API calls and use VCR to record/replay responses.
"""

import os
import threading
from collections import defaultdict
from unittest.mock import patch

import pytest

from crewai import Agent, Crew, Task
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)


@pytest.fixture(autouse=True)
def setup_vertex_ai_env():
"""Set up environment for Vertex AI tests.

Sets GOOGLE_GENAI_USE_VERTEXAI=true to ensure the SDK uses the Vertex AI
backend (aiplatform.googleapis.com) which matches the VCR cassettes.
Also mocks GOOGLE_API_KEY if not already set.
"""
env_updates = {"GOOGLE_GENAI_USE_VERTEXAI": "true"}

# Add a mock API key if none exists
if "GOOGLE_API_KEY" not in os.environ and "GEMINI_API_KEY" not in os.environ:
env_updates["GOOGLE_API_KEY"] = "test-key"

with patch.dict(os.environ, env_updates):
yield


@pytest.fixture
def google_vertex_embedder_config():
"""Fixture providing Google Vertex embedder configuration."""
return {
"provider": "google-vertex",
"config": {
"api_key": os.getenv("GOOGLE_API_KEY", "test-key"),
"model_name": "gemini-embedding-001",
},
}


@pytest.fixture
def simple_agent():
"""Fixture providing a simple test agent."""
return Agent(
role="Research Assistant",
goal="Help with research tasks",
backstory="You are a helpful research assistant.",
verbose=False,
)


@pytest.fixture
def simple_task(simple_agent):
"""Fixture providing a simple test task."""
return Task(
description="Summarize the key points about artificial intelligence in one sentence.",
expected_output="A one sentence summary about AI.",
agent=simple_agent,
)


@pytest.mark.vcr()
@pytest.mark.timeout(120) # Longer timeout for VCR recording
def test_crew_memory_with_google_vertex_embedder(
google_vertex_embedder_config, simple_agent, simple_task
) -> None:
"""Test that Crew with memory=True works with google-vertex embedder and memory is used."""
# Track memory events
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()

@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()

@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()

crew = Crew(
agents=[simple_agent],
tasks=[simple_task],
memory=True,
embedder=google_vertex_embedder_config,
verbose=False,
)

result = crew.kickoff()

assert result is not None
assert result.raw is not None
assert len(result.raw) > 0

with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=10,
)

assert success, "Timeout waiting for memory save events - memory may not be working"
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
assert len(events["MemorySaveCompletedEvent"]) >= 1, "Memory save completed events"


@pytest.mark.vcr()
@pytest.mark.timeout(120)
def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) -> None:
"""Test Crew memory with Google Vertex using project_id authentication."""
project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
if not project_id:
pytest.skip("GOOGLE_CLOUD_PROJECT environment variable not set")

# Track memory events
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()

@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()

@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()

embedder_config = {
"provider": "google-vertex",
"config": {
"project_id": project_id,
"location": "us-central1",
"model_name": "gemini-embedding-001",
},
}

crew = Crew(
agents=[simple_agent],
tasks=[simple_task],
memory=True,
embedder=embedder_config,
verbose=False,
)

result = crew.kickoff()

# Verify basic result
assert result is not None
assert result.raw is not None

# Wait for memory save events
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=10,
)

# Verify memory was actually used
assert success, "Timeout waiting for memory save events - memory may not be working"
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
assert len(events["MemorySaveCompletedEvent"]) >= 1, "No memory save completed events"
@@ -1,3 +1,3 @@
"""CrewAI development tools."""

__version__ = "1.8.1"
__version__ = "1.9.0"

uv.lock (generated): 8 changed lines
@@ -310,7 +310,7 @@ wheels = [

[[package]]
name = "anthropic"
version = "0.71.1"
version = "0.73.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -322,9 +322,9 @@ dependencies = [
{ name = "sniffio" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/05/4b/19620875841f692fdc35eb58bf0201c8ad8c47b8443fecbf1b225312175b/anthropic-0.71.1.tar.gz", hash = "sha256:a77d156d3e7d318b84681b59823b2dee48a8ac508a3e54e49f0ab0d074e4b0da", size = 493294, upload-time = "2025-10-28T17:28:42.213Z" }
sdist = { url = "https://files.pythonhosted.org/packages/f0/07/f550112c3f5299d02f06580577f602e8a112b1988ad7c98ac1a8f7292d7e/anthropic-0.73.0.tar.gz", hash = "sha256:30f0d7d86390165f86af6ca7c3041f8720bb2e1b0e12a44525c8edfdbd2c5239", size = 425168, upload-time = "2025-11-14T18:47:52.635Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4b/68/b2f988b13325f9ac9921b1e87f0b7994468014e1b5bd3bdbd2472f5baf45/anthropic-0.71.1-py3-none-any.whl", hash = "sha256:6ca6c579f0899a445faeeed9c0eb97aa4bdb751196262f9ccc96edfc0bb12679", size = 355020, upload-time = "2025-10-28T17:28:40.653Z" },
{ url = "https://files.pythonhosted.org/packages/15/b1/5d4d3f649e151e58dc938cf19c4d0cd19fca9a986879f30fea08a7b17138/anthropic-0.73.0-py3-none-any.whl", hash = "sha256:0d56cd8b3ca3fea9c9b5162868bdfd053fbc189b8b56d4290bd2d427b56db769", size = 367839, upload-time = "2025-11-14T18:47:51.195Z" },
]

[[package]]
@@ -1276,7 +1276,7 @@ requires-dist = [
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
{ name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
{ name = "aiosqlite", specifier = "~=0.21.0" },
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.71.0" },
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
{ name = "appdirs", specifier = "~=1.4.4" },
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
{ name = "boto3", marker = "extra == 'aws'", specifier = "~=1.40.38" },
