Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-29 10:08:13 +00:00)

Compare commits: 1.9.0...devin/1769 (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 615f6ad9d6 |  |
@@ -401,58 +401,23 @@ crew = Crew(
 
 ### Vertex AI Embeddings
 
-For Google Cloud users with Vertex AI access. Supports both legacy and new embedding models with automatic SDK selection.
+For Google Cloud users with Vertex AI access.
 
-<Note>
-**Deprecation Notice:** Legacy models (`textembedding-gecko*`) use the deprecated `vertexai.language_models` SDK which will be removed after June 24, 2026. Consider migrating to newer models like `gemini-embedding-001`. See the [Google migration guide](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk) for details.
-</Note>
-
 ```python
-# Recommended: Using new models with google-genai SDK
 crew = Crew(
     memory=True,
     embedder={
-        "provider": "google-vertex",
+        "provider": "vertexai",
         "config": {
             "project_id": "your-gcp-project-id",
-            "location": "us-central1",
-            "model_name": "gemini-embedding-001", # or "text-embedding-005", "text-multilingual-embedding-002"
-            "task_type": "RETRIEVAL_DOCUMENT", # Optional
-            "output_dimensionality": 768 # Optional
-        }
-    }
-)
-
-# Using API key authentication (Exp)
-crew = Crew(
-    memory=True,
-    embedder={
-        "provider": "google-vertex",
-        "config": {
-            "api_key": "your-google-api-key",
-            "model_name": "gemini-embedding-001"
-        }
-    }
-)
-
-# Legacy models (backwards compatible, emits deprecation warning)
-crew = Crew(
-    memory=True,
-    embedder={
-        "provider": "google-vertex",
-        "config": {
-            "project_id": "your-gcp-project-id",
-            "region": "us-central1", # or "location" (region is deprecated)
-            "model_name": "textembedding-gecko" # Legacy model
+            "region": "us-central1", # or your preferred region
+            "api_key": "your-service-account-key",
+            "model_name": "textembedding-gecko"
         }
     }
 )
 ```
 
-**Available models:**
-- **New SDK models** (recommended): `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002`
-- **Legacy models** (deprecated): `textembedding-gecko`, `textembedding-gecko@001`, `textembedding-gecko-multilingual`
-
 ### Ollama Embeddings (Local)
 
 Run embeddings locally for privacy and cost savings.
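That sentence is paired with a config block in the docs. As a rough sketch in the same shape as the examples above (the model name and endpoint URL are assumptions; use whichever embedding model you have pulled locally):

```python
# Sketch only: a local Ollama embedder; "mxbai-embed-large" and the URL are assumptions.
crew = Crew(
    memory=True,
    embedder={
        "provider": "ollama",
        "config": {
            "model": "mxbai-embed-large",
            "url": "http://localhost:11434/api/embeddings",
        },
    },
)
```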
@@ -604,7 +569,7 @@ mem0_client_embedder_config = {
     "project_id": "my_project_id", # Optional
     "api_key": "custom-api-key" # Optional - overrides env var
     "run_id": "my_run_id", # Optional - for short-term memory
     "includes": "include1", # Optional
     "excludes": "exclude1", # Optional
     "infer": True # Optional defaults to True
     "custom_categories": new_categories # Optional - custom categories for user memory
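These keys configure the Mem0 client used for crew memory. A minimal sketch of how such a block is typically passed to a crew (key names mirror the snippet above; `MEM0_API_KEY` in the environment is assumed unless `api_key` is set):

```python
# Minimal sketch, assuming the mem0 provider and MEM0_API_KEY set in the environment.
crew = Crew(
    memory=True,
    memory_config={
        "provider": "mem0",
        "config": {
            "user_id": "user-123",   # whose memories to read and write
            "run_id": "my_run_id",   # optional: scope short-term memory to one run
            "infer": True,           # optional: let Mem0 infer memories (defaults to True)
        },
    },
)
```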
@@ -626,7 +591,7 @@ crew = Crew(
 
 ### Choosing the Right Embedding Provider
 
 When selecting an embedding provider, consider factors like performance, privacy, cost, and integration needs.
 Below is a comparison to help you decide:
 
 | Provider | Best For | Pros | Cons |
@@ -784,7 +749,7 @@ Entity Memory supports batching when saving multiple entities at once. When you
 
 This improves performance and observability when writing many entities in one operation.
 
 ## 2. External Memory
 External Memory provides a standalone memory system that operates independently from the crew's built-in memory. This is ideal for specialized memory providers or cross-application memory sharing.
 
 ### Basic External Memory with Mem0
@@ -854,7 +819,7 @@ external_memory = ExternalMemory(
     "project_id": "my_project_id", # Optional
     "api_key": "custom-api-key" # Optional - overrides env var
     "run_id": "my_run_id", # Optional - for short-term memory
     "includes": "include1", # Optional
     "excludes": "exclude1", # Optional
     "infer": True # Optional defaults to True
     "custom_categories": new_categories # Optional - custom categories for user memory
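The same config shape drives External Memory directly. A hedged sketch (the import path and constructor arguments follow the surrounding docs and may differ between releases):

```python
# Hedged sketch: standalone ExternalMemory backed by Mem0, attached to a crew.
from crewai import Crew
from crewai.memory.external.external_memory import ExternalMemory  # path may vary by version

external_memory = ExternalMemory(
    embedder_config={
        "provider": "mem0",
        "config": {"user_id": "user-123"},
    }
)

crew = Crew(
    agents=[...],  # your agents
    tasks=[...],   # your tasks
    external_memory=external_memory,
)
```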
@@ -152,4 +152,4 @@ __all__ = [
     "wrap_file_source",
 ]
 
-__version__ = "1.9.0"
+__version__ = "1.8.1"
@@ -12,7 +12,7 @@ dependencies = [
     "pytube~=15.0.0",
     "requests~=2.32.5",
     "docker~=7.1.0",
-    "crewai==1.9.0",
+    "crewai==1.8.1",
     "lancedb~=0.5.4",
     "tiktoken~=0.8.0",
     "beautifulsoup4~=4.13.4",
@@ -291,4 +291,4 @@ __all__ = [
     "ZapierActionTools",
 ]
 
-__version__ = "1.9.0"
+__version__ = "1.8.1"
@@ -49,7 +49,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
 
 [project.optional-dependencies]
 tools = [
-    "crewai-tools==1.9.0",
+    "crewai-tools==1.8.1",
 ]
 embeddings = [
     "tiktoken~=0.8.0"
@@ -90,7 +90,7 @@ azure-ai-inference = [
     "azure-ai-inference~=1.0.0b9",
 ]
 anthropic = [
-    "anthropic~=0.73.0",
+    "anthropic~=0.71.0",
 ]
 a2a = [
     "a2a-sdk~=0.3.10",
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
 
 _suppress_pydantic_deprecation_warnings()
 
-__version__ = "1.9.0"
+__version__ = "1.8.1"
 _telemetry_submitted = False
 
 
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
 authors = [{ name = "Your Name", email = "you@example.com" }]
 requires-python = ">=3.10,<3.14"
 dependencies = [
-    "crewai[tools]==1.9.0"
+    "crewai[tools]==1.8.1"
 ]
 
 [project.scripts]
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
 authors = [{ name = "Your Name", email = "you@example.com" }]
 requires-python = ">=3.10,<3.14"
 dependencies = [
-    "crewai[tools]==1.9.0"
+    "crewai[tools]==1.8.1"
 ]
 
 [project.scripts]
@@ -84,4 +84,3 @@ class LLMStreamChunkEvent(LLMEventBase):
     chunk: str
     tool_call: ToolCall | None = None
     call_type: LLMCallType | None = None
-    response_id: str | None = None
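For context, this event is consumed through the crewAI event bus; a small listener sketch (import paths are inferred from module names visible in this diff and may differ by version):

```python
# Hedged sketch: printing streamed tokens as they arrive. After this change the
# event no longer carries response_id; only chunk, tool_call and call_type remain.
from crewai.events import crewai_event_bus
from crewai.events.types.llm_events import LLMStreamChunkEvent

@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_chunk(source, event: LLMStreamChunkEvent) -> None:
    print(event.chunk, end="", flush=True)
```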
@@ -768,10 +768,6 @@ class LLM(BaseLLM):
 
             # Extract content from the chunk
             chunk_content = None
-            response_id = None
-
-            if hasattr(chunk,'id'):
-                response_id = chunk.id
 
             # Safely extract content from various chunk formats
             try:
@@ -827,7 +823,6 @@ class LLM(BaseLLM):
                 available_functions=available_functions,
                 from_task=from_task,
                 from_agent=from_agent,
-                response_id=response_id
             )
 
             if result is not None:
@@ -849,7 +844,6 @@ class LLM(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     call_type=LLMCallType.LLM_CALL,
-                    response_id=response_id
                 ),
             )
         # --- 4) Fallback to non-streaming if no content received
@@ -1027,7 +1021,6 @@ class LLM(BaseLLM):
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
         from_agent: Agent | None = None,
-        response_id: str | None = None,
     ) -> Any:
         for tool_call in tool_calls:
             current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -1048,7 +1041,6 @@ class LLM(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     call_type=LLMCallType.TOOL_CALL,
-                    response_id=response_id
                 ),
             )
 
@@ -1410,13 +1402,11 @@ class LLM(BaseLLM):
 
         params["stream"] = True
         params["stream_options"] = {"include_usage": True}
-        response_id = None
 
         try:
             async for chunk in await litellm.acompletion(**params):
                 chunk_count += 1
                 chunk_content = None
-                response_id = chunk.id if hasattr(chunk, "id") else None
 
                 try:
                     choices = None
@@ -1476,7 +1466,6 @@ class LLM(BaseLLM):
                         chunk=chunk_content,
                         from_task=from_task,
                         from_agent=from_agent,
-                        response_id=response_id
                     ),
                 )
 
@@ -1514,7 +1503,6 @@ class LLM(BaseLLM):
                     available_functions=available_functions,
                     from_task=from_task,
                     from_agent=from_agent,
-                    response_id=response_id,
                 )
                 if result is not None:
                     return result
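All of the removals above undo the same small pattern: capturing an optional provider-assigned id from each streamed chunk before emitting chunk events. As a standalone illustration (not crewAI API):

```python
from typing import Any

def extract_response_id(chunk: Any) -> str | None:
    """Return the provider-assigned id carried by a streamed chunk, if any."""
    return getattr(chunk, "id", None)
```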
@@ -404,7 +404,6 @@ class BaseLLM(ABC):
         from_agent: Agent | None = None,
         tool_call: dict[str, Any] | None = None,
         call_type: LLMCallType | None = None,
-        response_id: str | None = None
     ) -> None:
         """Emit stream chunk event.
 
@@ -414,7 +413,6 @@ class BaseLLM(ABC):
             from_agent: The agent that initiated the call.
             tool_call: Tool call information if this is a tool call chunk.
             call_type: The type of LLM call (LLM_CALL or TOOL_CALL).
-            response_id: Unique ID for a particular LLM response, chunks have same response_id.
         """
         if not hasattr(crewai_event_bus, "emit"):
             raise ValueError("crewai_event_bus does not have an emit method") from None
@@ -427,7 +425,6 @@ class BaseLLM(ABC):
                 from_task=from_task,
                 from_agent=from_agent,
                 call_type=call_type,
-                response_id=response_id
             ),
         )
 
@@ -3,8 +3,9 @@ from __future__ import annotations
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast
+from typing import TYPE_CHECKING, Any, Literal, cast
 
+from anthropic.types import ThinkingBlock
 from pydantic import BaseModel
 
 from crewai.events.types.llm_events import LLMCallType
@@ -21,9 +22,8 @@ if TYPE_CHECKING:
     from crewai.llms.hooks.base import BaseInterceptor
 
 try:
-    from anthropic import Anthropic, AsyncAnthropic, transform_schema
+    from anthropic import Anthropic, AsyncAnthropic
     from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
-    from anthropic.types.beta import BetaMessage
     import httpx
 except ImportError:
     raise ImportError(
@@ -31,62 +31,7 @@ except ImportError:
     ) from None
 
 
-ANTHROPIC_FILES_API_BETA: Final = "files-api-2025-04-14"
-ANTHROPIC_STRUCTURED_OUTPUTS_BETA: Final = "structured-outputs-2025-11-13"
-
-NATIVE_STRUCTURED_OUTPUT_MODELS: Final[
-    tuple[
-        Literal["claude-sonnet-4-5"],
-        Literal["claude-sonnet-4.5"],
-        Literal["claude-opus-4-5"],
-        Literal["claude-opus-4.5"],
-        Literal["claude-opus-4-1"],
-        Literal["claude-opus-4.1"],
-        Literal["claude-haiku-4-5"],
-        Literal["claude-haiku-4.5"],
-    ]
-] = (
-    "claude-sonnet-4-5",
-    "claude-sonnet-4.5",
-    "claude-opus-4-5",
-    "claude-opus-4.5",
-    "claude-opus-4-1",
-    "claude-opus-4.1",
-    "claude-haiku-4-5",
-    "claude-haiku-4.5",
-)
-
-
-def _supports_native_structured_outputs(model: str) -> bool:
-    """Check if the model supports native structured outputs.
-
-    Native structured outputs are only available for Claude 4.5 models
-    (Sonnet 4.5, Opus 4.5, Opus 4.1, Haiku 4.5).
-    Other models require the tool-based fallback approach.
-
-    Args:
-        model: The model name/identifier.
-
-    Returns:
-        True if the model supports native structured outputs.
-    """
-    model_lower = model.lower()
-    return any(prefix in model_lower for prefix in NATIVE_STRUCTURED_OUTPUT_MODELS)
-
-
-def _is_pydantic_model_class(obj: Any) -> TypeGuard[type[BaseModel]]:
-    """Check if an object is a Pydantic model class.
-
-    This distinguishes between Pydantic model classes that support structured
-    outputs (have model_json_schema) and plain dicts like {"type": "json_object"}.
-
-    Args:
-        obj: The object to check.
-
-    Returns:
-        True if obj is a Pydantic model class.
-    """
-    return isinstance(obj, type) and issubclass(obj, BaseModel)
+ANTHROPIC_FILES_API_BETA = "files-api-2025-04-14"
 
 
 def _contains_file_id_reference(messages: list[dict[str, Any]]) -> bool:
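The helper removed above is a plain substring check over the Claude 4.x family names; an equivalent standalone sketch shows why fully versioned model ids still match:

```python
# Standalone illustration of the removed prefix check (not the crewAI API itself).
NATIVE_MODELS = ("claude-sonnet-4-5", "claude-opus-4-5", "claude-opus-4-1", "claude-haiku-4-5")

def supports_native_structured_outputs(model: str) -> bool:
    model_lower = model.lower()
    return any(prefix in model_lower for prefix in NATIVE_MODELS)

assert supports_native_structured_outputs("claude-sonnet-4-5-20250929")
assert not supports_native_structured_outputs("claude-3-opus-20240229")
```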
@@ -139,7 +84,6 @@ class AnthropicCompletion(BaseLLM):
         client_params: dict[str, Any] | None = None,
         interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
         thinking: AnthropicThinkingConfig | None = None,
-        response_format: type[BaseModel] | None = None,
         **kwargs: Any,
     ):
         """Initialize Anthropic chat completion client.
@@ -157,8 +101,6 @@ class AnthropicCompletion(BaseLLM):
             stream: Enable streaming responses
             client_params: Additional parameters for the Anthropic client
             interceptor: HTTP interceptor for modifying requests/responses at transport level.
-            response_format: Pydantic model for structured output. When provided, responses
-                will be validated against this model schema.
             **kwargs: Additional parameters
         """
         super().__init__(
@@ -189,7 +131,6 @@ class AnthropicCompletion(BaseLLM):
         self.stop_sequences = stop_sequences or []
         self.thinking = thinking
         self.previous_thinking_blocks: list[ThinkingBlock] = []
-        self.response_format = response_format
         # Model-specific settings
         self.is_claude_3 = "claude-3" in model.lower()
         self.supports_tools = True
@@ -290,8 +231,6 @@ class AnthropicCompletion(BaseLLM):
                 formatted_messages, system_message, tools
             )
 
-            effective_response_model = response_model or self.response_format
-
             # Handle streaming vs non-streaming
             if self.stream:
                 return self._handle_streaming_completion(
@@ -299,7 +238,7 @@ class AnthropicCompletion(BaseLLM):
                     available_functions,
                     from_task,
                     from_agent,
-                    effective_response_model,
+                    response_model,
                 )
 
             return self._handle_completion(
@@ -307,7 +246,7 @@ class AnthropicCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                effective_response_model,
+                response_model,
             )
 
         except Exception as e:
@@ -359,15 +298,13 @@ class AnthropicCompletion(BaseLLM):
                 formatted_messages, system_message, tools
             )
 
-            effective_response_model = response_model or self.response_format
-
             if self.stream:
                 return await self._ahandle_streaming_completion(
                     completion_params,
                     available_functions,
                     from_task,
                     from_agent,
-                    effective_response_model,
+                    response_model,
                 )
 
             return await self._ahandle_completion(
@@ -375,7 +312,7 @@ class AnthropicCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                effective_response_model,
+                response_model,
             )
 
         except Exception as e:
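These hunks all replace the `effective_response_model = response_model or self.response_format` fallback with the bare per-call `response_model`. The pattern being removed, reduced to a hedged sketch:

```python
# Sketch of the constructor-default / per-call-override pattern being removed.
from pydantic import BaseModel

class Summary(BaseModel):
    text: str

class ClientSketch:
    def __init__(self, response_format: type[BaseModel] | None = None) -> None:
        self.response_format = response_format

    def call(self, prompt: str, response_model: type[BaseModel] | None = None):
        effective = response_model or self.response_format
        return effective  # the real client validates the reply against this schema

assert ClientSketch(response_format=Summary).call("hi") is Summary
assert ClientSketch().call("hi", response_model=Summary) is Summary
```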
@@ -628,40 +565,22 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Handle non-streaming message completion."""
|
"""Handle non-streaming message completion."""
|
||||||
|
if response_model:
|
||||||
|
structured_tool = {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Returns structured data according to the schema",
|
||||||
|
"input_schema": response_model.model_json_schema(),
|
||||||
|
}
|
||||||
|
|
||||||
|
params["tools"] = [structured_tool]
|
||||||
|
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||||
|
|
||||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||||
betas: list[str] = []
|
|
||||||
use_native_structured_output = False
|
|
||||||
|
|
||||||
if uses_file_api:
|
|
||||||
betas.append(ANTHROPIC_FILES_API_BETA)
|
|
||||||
|
|
||||||
extra_body: dict[str, Any] | None = None
|
|
||||||
if _is_pydantic_model_class(response_model):
|
|
||||||
schema = transform_schema(response_model.model_json_schema())
|
|
||||||
if _supports_native_structured_outputs(self.model):
|
|
||||||
use_native_structured_output = True
|
|
||||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
|
||||||
extra_body = {
|
|
||||||
"output_format": {
|
|
||||||
"type": "json_schema",
|
|
||||||
"schema": schema,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
structured_tool = {
|
|
||||||
"name": "structured_output",
|
|
||||||
"description": "Output the structured response",
|
|
||||||
"input_schema": schema,
|
|
||||||
}
|
|
||||||
params["tools"] = [structured_tool]
|
|
||||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if betas:
|
if uses_file_api:
|
||||||
params["betas"] = betas
|
params["betas"] = [ANTHROPIC_FILES_API_BETA]
|
||||||
response = self.client.beta.messages.create(
|
response = self.client.beta.messages.create(**params)
|
||||||
**params, extra_body=extra_body
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
response = self.client.messages.create(**params)
|
response = self.client.messages.create(**params)
|
||||||
|
|
||||||
@@ -674,34 +593,22 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
usage = self._extract_anthropic_token_usage(response)
|
usage = self._extract_anthropic_token_usage(response)
|
||||||
self._track_token_usage_internal(usage)
|
self._track_token_usage_internal(usage)
|
||||||
|
|
||||||
if _is_pydantic_model_class(response_model) and response.content:
|
if response_model and response.content:
|
||||||
if use_native_structured_output:
|
tool_uses = [
|
||||||
for block in response.content:
|
block for block in response.content if isinstance(block, ToolUseBlock)
|
||||||
if isinstance(block, TextBlock):
|
]
|
||||||
structured_json = block.text
|
if tool_uses and tool_uses[0].name == "structured_output":
|
||||||
self._emit_call_completed_event(
|
structured_data = tool_uses[0].input
|
||||||
response=structured_json,
|
structured_json = json.dumps(structured_data)
|
||||||
call_type=LLMCallType.LLM_CALL,
|
self._emit_call_completed_event(
|
||||||
from_task=from_task,
|
response=structured_json,
|
||||||
from_agent=from_agent,
|
call_type=LLMCallType.LLM_CALL,
|
||||||
messages=params["messages"],
|
from_task=from_task,
|
||||||
)
|
from_agent=from_agent,
|
||||||
return structured_json
|
messages=params["messages"],
|
||||||
else:
|
)
|
||||||
for block in response.content:
|
|
||||||
if (
|
return structured_json
|
||||||
isinstance(block, ToolUseBlock)
|
|
||||||
and block.name == "structured_output"
|
|
||||||
):
|
|
||||||
structured_json = json.dumps(block.input)
|
|
||||||
self._emit_call_completed_event(
|
|
||||||
response=structured_json,
|
|
||||||
call_type=LLMCallType.LLM_CALL,
|
|
||||||
from_task=from_task,
|
|
||||||
from_agent=from_agent,
|
|
||||||
messages=params["messages"],
|
|
||||||
)
|
|
||||||
return structured_json
|
|
||||||
|
|
||||||
# Check if Claude wants to use tools
|
# Check if Claude wants to use tools
|
||||||
if response.content:
|
if response.content:
|
||||||
@@ -771,31 +678,17 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str:
|
||||||
"""Handle streaming message completion."""
|
"""Handle streaming message completion."""
|
||||||
betas: list[str] = []
|
if response_model:
|
||||||
use_native_structured_output = False
|
structured_tool = {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Returns structured data according to the schema",
|
||||||
|
"input_schema": response_model.model_json_schema(),
|
||||||
|
}
|
||||||
|
|
||||||
extra_body: dict[str, Any] | None = None
|
params["tools"] = [structured_tool]
|
||||||
if _is_pydantic_model_class(response_model):
|
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||||
schema = transform_schema(response_model.model_json_schema())
|
|
||||||
if _supports_native_structured_outputs(self.model):
|
|
||||||
use_native_structured_output = True
|
|
||||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
|
||||||
extra_body = {
|
|
||||||
"output_format": {
|
|
||||||
"type": "json_schema",
|
|
||||||
"schema": schema,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
structured_tool = {
|
|
||||||
"name": "structured_output",
|
|
||||||
"description": "Output the structured response",
|
|
||||||
"input_schema": schema,
|
|
||||||
}
|
|
||||||
params["tools"] = [structured_tool]
|
|
||||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
|
||||||
|
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
|
||||||
@@ -803,22 +696,11 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
# (the SDK sets it internally)
|
# (the SDK sets it internally)
|
||||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||||
|
|
||||||
if betas:
|
|
||||||
stream_params["betas"] = betas
|
|
||||||
|
|
||||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||||
|
|
||||||
stream_context = (
|
# Make streaming API call
|
||||||
self.client.beta.messages.stream(**stream_params, extra_body=extra_body)
|
with self.client.messages.stream(**stream_params) as stream:
|
||||||
if betas
|
|
||||||
else self.client.messages.stream(**stream_params)
|
|
||||||
)
|
|
||||||
with stream_context as stream:
|
|
||||||
response_id = None
|
|
||||||
for event in stream:
|
for event in stream:
|
||||||
if hasattr(event, "message") and hasattr(event.message, "id"):
|
|
||||||
response_id = event.message.id
|
|
||||||
|
|
||||||
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
||||||
text_delta = event.delta.text
|
text_delta = event.delta.text
|
||||||
full_response += text_delta
|
full_response += text_delta
|
||||||
@@ -826,7 +708,6 @@ class AnthropicCompletion(BaseLLM):
                         chunk=text_delta,
                         from_task=from_task,
                         from_agent=from_agent,
-                        response_id=response_id,
                     )
 
                 if event.type == "content_block_start":
@@ -853,7 +734,6 @@ class AnthropicCompletion(BaseLLM):
                             "index": block_index,
                         },
                         call_type=LLMCallType.TOOL_CALL,
-                        response_id=response_id,
                     )
                 elif event.type == "content_block_delta":
                     if event.delta.type == "input_json_delta":
@@ -877,10 +757,9 @@ class AnthropicCompletion(BaseLLM):
                             "index": block_index,
                         },
                         call_type=LLMCallType.TOOL_CALL,
-                        response_id=response_id,
                     )
 
-            final_message = stream.get_final_message()
+            final_message: Message = stream.get_final_message()
 
             thinking_blocks: list[ThinkingBlock] = []
             if final_message.content:
@@ -895,30 +774,25 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
usage = self._extract_anthropic_token_usage(final_message)
|
usage = self._extract_anthropic_token_usage(final_message)
|
||||||
self._track_token_usage_internal(usage)
|
self._track_token_usage_internal(usage)
|
||||||
|
|
||||||
if _is_pydantic_model_class(response_model):
|
if response_model and final_message.content:
|
||||||
if use_native_structured_output:
|
tool_uses = [
|
||||||
|
block
|
||||||
|
for block in final_message.content
|
||||||
|
if isinstance(block, ToolUseBlock)
|
||||||
|
]
|
||||||
|
if tool_uses and tool_uses[0].name == "structured_output":
|
||||||
|
structured_data = tool_uses[0].input
|
||||||
|
structured_json = json.dumps(structured_data)
|
||||||
|
|
||||||
self._emit_call_completed_event(
|
self._emit_call_completed_event(
|
||||||
response=full_response,
|
response=structured_json,
|
||||||
call_type=LLMCallType.LLM_CALL,
|
call_type=LLMCallType.LLM_CALL,
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
messages=params["messages"],
|
messages=params["messages"],
|
||||||
)
|
)
|
||||||
return full_response
|
|
||||||
for block in final_message.content:
|
return structured_json
|
||||||
if (
|
|
||||||
isinstance(block, ToolUseBlock)
|
|
||||||
and block.name == "structured_output"
|
|
||||||
):
|
|
||||||
structured_json = json.dumps(block.input)
|
|
||||||
self._emit_call_completed_event(
|
|
||||||
response=structured_json,
|
|
||||||
call_type=LLMCallType.LLM_CALL,
|
|
||||||
from_task=from_task,
|
|
||||||
from_agent=from_agent,
|
|
||||||
messages=params["messages"],
|
|
||||||
)
|
|
||||||
return structured_json
|
|
||||||
|
|
||||||
if final_message.content:
|
if final_message.content:
|
||||||
tool_uses = [
|
tool_uses = [
|
||||||
@@ -928,9 +802,11 @@ class AnthropicCompletion(BaseLLM):
             ]
 
             if tool_uses:
+                # If no available_functions, return tool calls for executor to handle
                 if not available_functions:
                     return list(tool_uses)
 
+                # Handle tool use conversation flow internally
                 return self._handle_tool_use_conversation(
                     final_message,
                     tool_uses,
@@ -940,8 +816,10 @@ class AnthropicCompletion(BaseLLM):
                     from_agent,
                 )
 
+            # Apply stop words to full response
             full_response = self._apply_stop_words(full_response)
 
+            # Emit completion event and return full response
             self._emit_call_completed_event(
                 response=full_response,
                 call_type=LLMCallType.LLM_CALL,
@@ -999,7 +877,7 @@ class AnthropicCompletion(BaseLLM):
 
     def _handle_tool_use_conversation(
         self,
-        initial_response: Message | BetaMessage,
+        initial_response: Message,
         tool_uses: list[ToolUseBlock],
         params: dict[str, Any],
         available_functions: dict[str, Any],
@@ -1117,40 +995,22 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Handle non-streaming async message completion."""
|
"""Handle non-streaming async message completion."""
|
||||||
|
if response_model:
|
||||||
|
structured_tool = {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Returns structured data according to the schema",
|
||||||
|
"input_schema": response_model.model_json_schema(),
|
||||||
|
}
|
||||||
|
|
||||||
|
params["tools"] = [structured_tool]
|
||||||
|
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||||
|
|
||||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||||
betas: list[str] = []
|
|
||||||
use_native_structured_output = False
|
|
||||||
|
|
||||||
if uses_file_api:
|
|
||||||
betas.append(ANTHROPIC_FILES_API_BETA)
|
|
||||||
|
|
||||||
extra_body: dict[str, Any] | None = None
|
|
||||||
if _is_pydantic_model_class(response_model):
|
|
||||||
schema = transform_schema(response_model.model_json_schema())
|
|
||||||
if _supports_native_structured_outputs(self.model):
|
|
||||||
use_native_structured_output = True
|
|
||||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
|
||||||
extra_body = {
|
|
||||||
"output_format": {
|
|
||||||
"type": "json_schema",
|
|
||||||
"schema": schema,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
structured_tool = {
|
|
||||||
"name": "structured_output",
|
|
||||||
"description": "Output the structured response",
|
|
||||||
"input_schema": schema,
|
|
||||||
}
|
|
||||||
params["tools"] = [structured_tool]
|
|
||||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if betas:
|
if uses_file_api:
|
||||||
params["betas"] = betas
|
params["betas"] = [ANTHROPIC_FILES_API_BETA]
|
||||||
response = await self.async_client.beta.messages.create(
|
response = await self.async_client.beta.messages.create(**params)
|
||||||
**params, extra_body=extra_body
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
response = await self.async_client.messages.create(**params)
|
response = await self.async_client.messages.create(**params)
|
||||||
|
|
||||||
@@ -1163,34 +1023,23 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
usage = self._extract_anthropic_token_usage(response)
|
usage = self._extract_anthropic_token_usage(response)
|
||||||
self._track_token_usage_internal(usage)
|
self._track_token_usage_internal(usage)
|
||||||
|
|
||||||
if _is_pydantic_model_class(response_model) and response.content:
|
if response_model and response.content:
|
||||||
if use_native_structured_output:
|
tool_uses = [
|
||||||
for block in response.content:
|
block for block in response.content if isinstance(block, ToolUseBlock)
|
||||||
if isinstance(block, TextBlock):
|
]
|
||||||
structured_json = block.text
|
if tool_uses and tool_uses[0].name == "structured_output":
|
||||||
self._emit_call_completed_event(
|
structured_data = tool_uses[0].input
|
||||||
response=structured_json,
|
structured_json = json.dumps(structured_data)
|
||||||
call_type=LLMCallType.LLM_CALL,
|
|
||||||
from_task=from_task,
|
self._emit_call_completed_event(
|
||||||
from_agent=from_agent,
|
response=structured_json,
|
||||||
messages=params["messages"],
|
call_type=LLMCallType.LLM_CALL,
|
||||||
)
|
from_task=from_task,
|
||||||
return structured_json
|
from_agent=from_agent,
|
||||||
else:
|
messages=params["messages"],
|
||||||
for block in response.content:
|
)
|
||||||
if (
|
|
||||||
isinstance(block, ToolUseBlock)
|
return structured_json
|
||||||
and block.name == "structured_output"
|
|
||||||
):
|
|
||||||
structured_json = json.dumps(block.input)
|
|
||||||
self._emit_call_completed_event(
|
|
||||||
response=structured_json,
|
|
||||||
call_type=LLMCallType.LLM_CALL,
|
|
||||||
from_task=from_task,
|
|
||||||
from_agent=from_agent,
|
|
||||||
messages=params["messages"],
|
|
||||||
)
|
|
||||||
return structured_json
|
|
||||||
|
|
||||||
if response.content:
|
if response.content:
|
||||||
tool_uses = [
|
tool_uses = [
|
||||||
@@ -1246,54 +1095,26 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str:
|
||||||
"""Handle async streaming message completion."""
|
"""Handle async streaming message completion."""
|
||||||
betas: list[str] = []
|
if response_model:
|
||||||
use_native_structured_output = False
|
structured_tool = {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Returns structured data according to the schema",
|
||||||
|
"input_schema": response_model.model_json_schema(),
|
||||||
|
}
|
||||||
|
|
||||||
extra_body: dict[str, Any] | None = None
|
params["tools"] = [structured_tool]
|
||||||
if _is_pydantic_model_class(response_model):
|
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||||
schema = transform_schema(response_model.model_json_schema())
|
|
||||||
if _supports_native_structured_outputs(self.model):
|
|
||||||
use_native_structured_output = True
|
|
||||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
|
||||||
extra_body = {
|
|
||||||
"output_format": {
|
|
||||||
"type": "json_schema",
|
|
||||||
"schema": schema,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
structured_tool = {
|
|
||||||
"name": "structured_output",
|
|
||||||
"description": "Output the structured response",
|
|
||||||
"input_schema": schema,
|
|
||||||
}
|
|
||||||
params["tools"] = [structured_tool]
|
|
||||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
|
||||||
|
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
|
||||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||||
|
|
||||||
if betas:
|
|
||||||
stream_params["betas"] = betas
|
|
||||||
|
|
||||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||||
|
|
||||||
stream_context = (
|
async with self.async_client.messages.stream(**stream_params) as stream:
|
||||||
self.async_client.beta.messages.stream(
|
|
||||||
**stream_params, extra_body=extra_body
|
|
||||||
)
|
|
||||||
if betas
|
|
||||||
else self.async_client.messages.stream(**stream_params)
|
|
||||||
)
|
|
||||||
async with stream_context as stream:
|
|
||||||
response_id = None
|
|
||||||
async for event in stream:
|
async for event in stream:
|
||||||
if hasattr(event, "message") and hasattr(event.message, "id"):
|
|
||||||
response_id = event.message.id
|
|
||||||
|
|
||||||
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
||||||
text_delta = event.delta.text
|
text_delta = event.delta.text
|
||||||
full_response += text_delta
|
full_response += text_delta
|
||||||
@@ -1301,7 +1122,6 @@ class AnthropicCompletion(BaseLLM):
                         chunk=text_delta,
                         from_task=from_task,
                         from_agent=from_agent,
-                        response_id=response_id,
                     )
 
                 if event.type == "content_block_start":
@@ -1328,7 +1148,6 @@ class AnthropicCompletion(BaseLLM):
                             "index": block_index,
                         },
                         call_type=LLMCallType.TOOL_CALL,
-                        response_id=response_id,
                     )
                 elif event.type == "content_block_delta":
                     if event.delta.type == "input_json_delta":
@@ -1352,38 +1171,32 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
"index": block_index,
|
"index": block_index,
|
||||||
},
|
},
|
||||||
call_type=LLMCallType.TOOL_CALL,
|
call_type=LLMCallType.TOOL_CALL,
|
||||||
response_id=response_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
final_message = await stream.get_final_message()
|
final_message: Message = await stream.get_final_message()
|
||||||
|
|
||||||
usage = self._extract_anthropic_token_usage(final_message)
|
usage = self._extract_anthropic_token_usage(final_message)
|
||||||
self._track_token_usage_internal(usage)
|
self._track_token_usage_internal(usage)
|
||||||
|
|
||||||
if _is_pydantic_model_class(response_model):
|
if response_model and final_message.content:
|
||||||
if use_native_structured_output:
|
tool_uses = [
|
||||||
|
block
|
||||||
|
for block in final_message.content
|
||||||
|
if isinstance(block, ToolUseBlock)
|
||||||
|
]
|
||||||
|
if tool_uses and tool_uses[0].name == "structured_output":
|
||||||
|
structured_data = tool_uses[0].input
|
||||||
|
structured_json = json.dumps(structured_data)
|
||||||
|
|
||||||
self._emit_call_completed_event(
|
self._emit_call_completed_event(
|
||||||
response=full_response,
|
response=structured_json,
|
||||||
call_type=LLMCallType.LLM_CALL,
|
call_type=LLMCallType.LLM_CALL,
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
messages=params["messages"],
|
messages=params["messages"],
|
||||||
)
|
)
|
||||||
return full_response
|
|
||||||
for block in final_message.content:
|
return structured_json
|
||||||
if (
|
|
||||||
isinstance(block, ToolUseBlock)
|
|
||||||
and block.name == "structured_output"
|
|
||||||
):
|
|
||||||
structured_json = json.dumps(block.input)
|
|
||||||
self._emit_call_completed_event(
|
|
||||||
response=structured_json,
|
|
||||||
call_type=LLMCallType.LLM_CALL,
|
|
||||||
from_task=from_task,
|
|
||||||
from_agent=from_agent,
|
|
||||||
messages=params["messages"],
|
|
||||||
)
|
|
||||||
return structured_json
|
|
||||||
|
|
||||||
if final_message.content:
|
if final_message.content:
|
||||||
tool_uses = [
|
tool_uses = [
|
||||||
@@ -1393,6 +1206,7 @@ class AnthropicCompletion(BaseLLM):
             ]
 
             if tool_uses:
+                # If no available_functions, return tool calls for executor to handle
                 if not available_functions:
                     return list(tool_uses)
 
@@ -1419,7 +1233,7 @@ class AnthropicCompletion(BaseLLM):
 
     async def _ahandle_tool_use_conversation(
         self,
-        initial_response: Message | BetaMessage,
+        initial_response: Message,
         tool_uses: list[ToolUseBlock],
         params: dict[str, Any],
         available_functions: dict[str, Any],
@@ -1528,9 +1342,7 @@ class AnthropicCompletion(BaseLLM):
         return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
 
     @staticmethod
-    def _extract_anthropic_token_usage(
-        response: Message | BetaMessage,
-    ) -> dict[str, Any]:
+    def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
         """Extract token usage from Anthropic response."""
         if hasattr(response, "usage") and response.usage:
             usage = response.usage
@@ -92,7 +92,6 @@ class AzureCompletion(BaseLLM):
         stop: list[str] | None = None,
         stream: bool = False,
         interceptor: BaseInterceptor[Any, Any] | None = None,
-        response_format: type[BaseModel] | None = None,
         **kwargs: Any,
     ):
         """Initialize Azure AI Inference chat completion client.
@@ -112,9 +111,6 @@ class AzureCompletion(BaseLLM):
             stop: Stop sequences
             stream: Enable streaming responses
             interceptor: HTTP interceptor (not yet supported for Azure).
-            response_format: Pydantic model for structured output. Used as default when
-                response_model is not passed to call()/acall() methods.
-                Only works with OpenAI models deployed on Azure.
             **kwargs: Additional parameters
         """
         if interceptor is not None:
@@ -169,7 +165,6 @@ class AzureCompletion(BaseLLM):
         self.presence_penalty = presence_penalty
         self.max_tokens = max_tokens
         self.stream = stream
-        self.response_format = response_format
 
         self.is_openai_model = any(
             prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
@@ -303,7 +298,6 @@ class AzureCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
            )
-            effective_response_model = response_model or self.response_format
 
             # Format messages for Azure
             formatted_messages = self._format_messages_for_azure(messages)
@@ -313,7 +307,7 @@ class AzureCompletion(BaseLLM):
 
             # Prepare completion parameters
             completion_params = self._prepare_completion_params(
-                formatted_messages, tools, effective_response_model
+                formatted_messages, tools, response_model
             )
 
             # Handle streaming vs non-streaming
@@ -323,7 +317,7 @@ class AzureCompletion(BaseLLM):
                     available_functions,
                     from_task,
                     from_agent,
-                    effective_response_model,
+                    response_model,
                 )
 
             return self._handle_completion(
@@ -331,7 +325,7 @@ class AzureCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                effective_response_model,
+                response_model,
             )
 
         except Exception as e:
@@ -370,12 +364,11 @@ class AzureCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
            )
-            effective_response_model = response_model or self.response_format
 
             formatted_messages = self._format_messages_for_azure(messages)
 
             completion_params = self._prepare_completion_params(
-                formatted_messages, tools, effective_response_model
+                formatted_messages, tools, response_model
            )
 
             if self.stream:
@@ -384,7 +377,7 @@ class AzureCompletion(BaseLLM):
                     available_functions,
                     from_task,
                     from_agent,
-                    effective_response_model,
+                    response_model,
                 )
 
             return await self._ahandle_completion(
@@ -392,7 +385,7 @@ class AzureCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                effective_response_model,
+                response_model,
             )
 
         except Exception as e:
@@ -733,7 +726,6 @@ class AzureCompletion(BaseLLM):
         """
         if update.choices:
             choice = update.choices[0]
-            response_id = update.id if hasattr(update, "id") else None
             if choice.delta and choice.delta.content:
                 content_delta = choice.delta.content
                 full_response += content_delta
@@ -741,7 +733,6 @@ class AzureCompletion(BaseLLM):
                     chunk=content_delta,
                     from_task=from_task,
                     from_agent=from_agent,
-                    response_id=response_id,
                 )
 
             if choice.delta and choice.delta.tool_calls:
@@ -776,7 +767,6 @@ class AzureCompletion(BaseLLM):
                         "index": idx,
                     },
                     call_type=LLMCallType.TOOL_CALL,
-                    response_id=response_id,
                 )
 
         return full_response
@@ -172,7 +172,6 @@ class BedrockCompletion(BaseLLM):
|
|||||||
additional_model_request_fields: dict[str, Any] | None = None,
|
additional_model_request_fields: dict[str, Any] | None = None,
|
||||||
additional_model_response_field_paths: list[str] | None = None,
|
additional_model_response_field_paths: list[str] | None = None,
|
||||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||||
response_format: type[BaseModel] | None = None,
|
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize AWS Bedrock completion client.
|
"""Initialize AWS Bedrock completion client.
|
||||||
@@ -193,8 +192,6 @@ class BedrockCompletion(BaseLLM):
|
|||||||
additional_model_request_fields: Model-specific request parameters
|
additional_model_request_fields: Model-specific request parameters
|
||||||
    additional_model_response_field_paths: Custom response field paths
    interceptor: HTTP interceptor (not yet supported for Bedrock).
-   response_format: Pydantic model for structured output. Used as default when
-       response_model is not passed to call()/acall() methods.
    **kwargs: Additional parameters
"""
if interceptor is not None:
@@ -250,8 +247,7 @@ class BedrockCompletion(BaseLLM):
- self.stop_sequences = stop_sequences
+ self.stop_sequences = stop_sequences or []
- self.response_format = response_format
@@ -271,7 +267,7 @@ class BedrockCompletion(BaseLLM):
- return [] if self.stop_sequences is None else list(self.stop_sequences)
+ return list(self.stop_sequences)
@@ -303,8 +299,6 @@ class BedrockCompletion(BaseLLM):
- effective_response_model = response_model or self.response_format
@@ -381,7 +375,6 @@ class BedrockCompletion(BaseLLM):
- effective_response_model,
@@ -390,7 +383,6 @@ class BedrockCompletion(BaseLLM):
- effective_response_model,
@@ -433,8 +425,6 @@ class BedrockCompletion(BaseLLM):
- effective_response_model = response_model or self.response_format
@@ -504,21 +494,11 @@ class BedrockCompletion(BaseLLM):
return await self._ahandle_streaming_converse(
-     formatted_messages,
-     body,
-     available_functions,
-     from_task,
-     from_agent,
-     effective_response_model,
+     formatted_messages, body, available_functions, from_task, from_agent
)
return await self._ahandle_converse(
-     formatted_messages,
-     body,
-     available_functions,
-     from_task,
-     from_agent,
-     effective_response_model,
+     formatted_messages, body, available_functions, from_task, from_agent
)
@@ -540,29 +520,10 @@ class BedrockCompletion(BaseLLM):
- response_model: type[BaseModel] | None = None,
- ) -> str | Any:
+ ) -> str:
"""Handle non-streaming converse API call following AWS best practices."""
- if response_model:
-     structured_tool: ConverseToolTypeDef = {
-         "toolSpec": {
-             "name": "structured_output",
-             "description": "Returns structured data according to the schema",
-             "inputSchema": {"json": response_model.model_json_schema()},
-         }
-     }
-     body["toolConfig"] = cast(
-         "ToolConfigurationTypeDef",
-         cast(
-             object,
-             {
-                 "tools": [structured_tool],
-                 "toolChoice": {"tool": {"name": "structured_output"}},
-             },
-         ),
-     )
try:
+     # Validate messages format before API call
      if not messages:
          raise ValueError("Messages cannot be empty")
@@ -610,21 +571,6 @@ class BedrockCompletion(BaseLLM):
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
- if response_model and tool_uses:
-     for tool_use in tool_uses:
-         if tool_use.get("name") == "structured_output":
-             structured_data = tool_use.get("input", {})
-             result = response_model.model_validate(structured_data)
-             self._emit_call_completed_event(
-                 response=result.model_dump_json(),
-                 call_type=LLMCallType.LLM_CALL,
-                 from_task=from_task,
-                 from_agent=from_agent,
-                 messages=messages,
-             )
-             return result
if tool_uses and not available_functions:
@@ -771,28 +717,8 @@ class BedrockCompletion(BaseLLM):
- response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming converse API call with comprehensive event handling."""
- if response_model:
-     structured_tool: ConverseToolTypeDef = {
-         "toolSpec": {
-             "name": "structured_output",
-             "description": "Returns structured data according to the schema",
-             "inputSchema": {"json": response_model.model_json_schema()},
-         }
-     }
-     body["toolConfig"] = cast(
-         "ToolConfigurationTypeDef",
-         cast(
-             object,
-             {
-                 "tools": [structured_tool],
-                 "toolChoice": {"tool": {"name": "structured_output"}},
-             },
-         ),
-     )
@@ -810,7 +736,6 @@ class BedrockCompletion(BaseLLM):
- response_id = None
@@ -842,7 +767,6 @@ class BedrockCompletion(BaseLLM):
- response_id=response_id,
@@ -858,7 +782,6 @@ class BedrockCompletion(BaseLLM):
- response_id=response_id,
@@ -879,7 +802,6 @@ class BedrockCompletion(BaseLLM):
- response_id=response_id,
@@ -1003,28 +925,8 @@ class BedrockCompletion(BaseLLM):
- response_model: type[BaseModel] | None = None,
- ) -> str | Any:
+ ) -> str:
"""Handle async non-streaming converse API call."""
- if response_model:
-     structured_tool: ConverseToolTypeDef = {
-         "toolSpec": {
-             "name": "structured_output",
-             "description": "Returns structured data according to the schema",
-             "inputSchema": {"json": response_model.model_json_schema()},
-         }
-     }
-     body["toolConfig"] = cast(
-         "ToolConfigurationTypeDef",
-         cast(
-             object,
-             {
-                 "tools": [structured_tool],
-                 "toolChoice": {"tool": {"name": "structured_output"}},
-             },
-         ),
-     )
@@ -1070,21 +972,6 @@ class BedrockCompletion(BaseLLM):
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
- if response_model and tool_uses:
-     for tool_use in tool_uses:
-         if tool_use.get("name") == "structured_output":
-             structured_data = tool_use.get("input", {})
-             result = response_model.model_validate(structured_data)
-             self._emit_call_completed_event(
-                 response=result.model_dump_json(),
-                 call_type=LLMCallType.LLM_CALL,
-                 from_task=from_task,
-                 from_agent=from_agent,
-                 messages=messages,
-             )
-             return result
if tool_uses and not available_functions:
@@ -1215,28 +1102,8 @@ class BedrockCompletion(BaseLLM):
- response_model: type[BaseModel] | None = None,
) -> str:
"""Handle async streaming converse API call."""
- if response_model:
-     structured_tool: ConverseToolTypeDef = {
-         "toolSpec": {
-             "name": "structured_output",
-             "description": "Returns structured data according to the schema",
-             "inputSchema": {"json": response_model.model_json_schema()},
-         }
-     }
-     body["toolConfig"] = cast(
-         "ToolConfigurationTypeDef",
-         cast(
-             object,
-             {
-                 "tools": [structured_tool],
-                 "toolChoice": {"tool": {"name": "structured_output"}},
-             },
-         ),
-     )
@@ -1255,7 +1122,6 @@ class BedrockCompletion(BaseLLM):
- response_id = None
@@ -1287,7 +1153,6 @@ class BedrockCompletion(BaseLLM):
- response_id=response_id,
@@ -1303,7 +1168,6 @@ class BedrockCompletion(BaseLLM):
- response_id=response_id,
@@ -1324,7 +1188,6 @@ class BedrockCompletion(BaseLLM):
- response_id=response_id,
@@ -56,7 +56,6 @@ class GeminiCompletion(BaseLLM):
- response_format: type[BaseModel] | None = None,
@@ -87,8 +86,6 @@ class GeminiCompletion(BaseLLM):
- response_format: Pydantic model for structured output. Used as default when
-     response_model is not passed to call()/acall() methods.
@@ -124,7 +121,6 @@ class GeminiCompletion(BaseLLM):
- self.response_format = response_format
@@ -296,7 +292,6 @@ class GeminiCompletion(BaseLLM):
- effective_response_model = response_model or self.response_format
@@ -308,7 +303,7 @@ class GeminiCompletion(BaseLLM):
- system_instruction, tools, effective_response_model
+ system_instruction, tools, response_model
@@ -318,7 +313,7 @@ class GeminiCompletion(BaseLLM):
- effective_response_model,
+ response_model,
@@ -327,7 +322,7 @@ class GeminiCompletion(BaseLLM):
- effective_response_model,
+ response_model,
@@ -379,14 +374,13 @@ class GeminiCompletion(BaseLLM):
- effective_response_model = response_model or self.response_format
- system_instruction, tools, effective_response_model
+ system_instruction, tools, response_model
@@ -396,7 +390,7 @@ class GeminiCompletion(BaseLLM):
- effective_response_model,
+ response_model,
@@ -405,7 +399,7 @@ class GeminiCompletion(BaseLLM):
- effective_response_model,
+ response_model,
@@ -561,11 +555,7 @@ class GeminiCompletion(BaseLLM):
- parsed = json.loads(text_content) if text_content else {}
- if isinstance(parsed, dict):
-     response_data = parsed
- else:
-     response_data = {"result": parsed}
+ response_data = json.loads(text_content) if text_content else {}
@@ -576,10 +566,10 @@ class GeminiCompletion(BaseLLM):
- tool_parts: list[types.Part] = []
+ parts: list[types.Part] = []
- tool_parts.append(types.Part.from_text(text=text_content))
+ parts.append(types.Part.from_text(text=text_content))
@@ -598,11 +588,11 @@ class GeminiCompletion(BaseLLM):
- tool_parts.append(
+ parts.append(
- contents.append(types.Content(role="model", parts=tool_parts))
+ contents.append(types.Content(role="model", parts=parts))
@@ -796,7 +786,6 @@ class GeminiCompletion(BaseLLM):
- response_id = chunk.response_id if hasattr(chunk, "response_id") else None
@@ -806,7 +795,6 @@ class GeminiCompletion(BaseLLM):
- response_id=response_id,
@@ -843,7 +831,6 @@ class GeminiCompletion(BaseLLM):
- response_id=response_id,
@@ -978,7 +965,7 @@ class GeminiCompletion(BaseLLM):
- ) -> str | Any:
+ ) -> str:
@@ -1056,7 +1043,7 @@ class GeminiCompletion(BaseLLM):
- ) -> str | Any:
+ ) -> str:
@@ -1047,12 +1047,8 @@ class OpenAICompletion(BaseLLM):
- response_id_stream = None
- if event.type == "response.created":
-     response_id_stream = event.response.id
@@ -1060,7 +1056,6 @@ class OpenAICompletion(BaseLLM):
- response_id=response_id_stream
@@ -1175,12 +1170,8 @@ class OpenAICompletion(BaseLLM):
- response_id_stream = None
- if event.type == "response.created":
-     response_id_stream = event.response.id
@@ -1188,7 +1179,6 @@ class OpenAICompletion(BaseLLM):
- response_id=response_id_stream,
@@ -1709,8 +1699,6 @@ class OpenAICompletion(BaseLLM):
- response_id_stream=chunk.id if hasattr(chunk,"id") else None
@@ -1718,7 +1706,6 @@ class OpenAICompletion(BaseLLM):
- response_id=response_id_stream
@@ -1748,8 +1735,6 @@ class OpenAICompletion(BaseLLM):
- response_id_stream=completion_chunk.id if hasattr(completion_chunk,"id") else None
@@ -1766,7 +1751,6 @@ class OpenAICompletion(BaseLLM):
- response_id=response_id_stream
@@ -1805,7 +1789,6 @@ class OpenAICompletion(BaseLLM):
- response_id=response_id_stream
@@ -2017,8 +2000,6 @@ class OpenAICompletion(BaseLLM):
- response_id_stream=chunk.id if hasattr(chunk,"id") else None
@@ -2035,7 +2016,6 @@ class OpenAICompletion(BaseLLM):
- response_id=response_id_stream
@@ -2071,8 +2051,6 @@ class OpenAICompletion(BaseLLM):
- response_id_stream=chunk.id if hasattr(chunk,"id") else None
@@ -2089,7 +2067,6 @@ class OpenAICompletion(BaseLLM):
- response_id=response_id_stream
@@ -2128,7 +2105,6 @@ class OpenAICompletion(BaseLLM):
- response_id=response_id_stream
@@ -41,7 +41,6 @@ def _default_settings() -> Settings:
- anonymized_telemetry=False,
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
from chromadb.utils.embedding_functions.google_embedding_function import (
    GoogleGenerativeAiEmbeddingFunction,
+     GoogleVertexEmbeddingFunction,
)
@@ -51,9 +52,6 @@ if TYPE_CHECKING:
- from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
-     GoogleGenAIVertexEmbeddingFunction,
- )
@@ -165,7 +163,7 @@ def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunctio
- ) -> GoogleGenAIVertexEmbeddingFunction: ...
+ ) -> GoogleVertexEmbeddingFunction: ...
@@ -298,9 +296,7 @@ def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
- def build_embedder(
-     spec: VertexAIProviderSpec,
- ) -> GoogleGenAIVertexEmbeddingFunction: ...
+ def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
@@ -1,8 +1,5 @@
"""Google embedding providers."""
- from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
-     GoogleGenAIVertexEmbeddingFunction,
- )
@@ -21,7 +18,6 @@ __all__ = [
- "GoogleGenAIVertexEmbeddingFunction",
@@ -1,237 +0,0 @@
- """Google Vertex AI embedding function implementation.
- This module supports both the new google-genai SDK and the deprecated
- vertexai.language_models module for backwards compatibility.
- The deprecated vertexai.language_models module will be removed after June 24, 2026.
- Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
- """
- from typing import Any, ClassVar, cast
- import warnings
- from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
- from typing_extensions import Unpack
- from crewai.rag.embeddings.providers.google.types import VertexAIProviderConfig
- class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
- """Embedding function for Google Vertex AI with dual SDK support.
- This class supports both:
- - Legacy models (textembedding-gecko*) using the deprecated vertexai.language_models SDK
- - New models (gemini-embedding-*, text-embedding-*) using the google-genai SDK
- The SDK is automatically selected based on the model name. Legacy models will
- emit a deprecation warning.
- Supports two authentication modes:
- 1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
- 2. API key: Set api_key for direct API access
- Example:
- # Using legacy model (will emit deprecation warning)
- embedder = GoogleGenAIVertexEmbeddingFunction(
- project_id="my-project",
- region="us-central1",
- model_name="textembedding-gecko"
- )
- # Using new model with google-genai SDK
- embedder = GoogleGenAIVertexEmbeddingFunction(
- project_id="my-project",
- location="us-central1",
- model_name="gemini-embedding-001"
- )
- # Using API key (new SDK only)
- embedder = GoogleGenAIVertexEmbeddingFunction(
- api_key="your-api-key",
- model_name="gemini-embedding-001"
- )
- """
- # Models that use the legacy vertexai.language_models SDK
- LEGACY_MODELS: ClassVar[set[str]] = {
- "textembedding-gecko",
- "textembedding-gecko@001",
- "textembedding-gecko@002",
- "textembedding-gecko@003",
- "textembedding-gecko@latest",
- "textembedding-gecko-multilingual",
- "textembedding-gecko-multilingual@001",
- "textembedding-gecko-multilingual@latest",
- }
- # Models that use the new google-genai SDK
- GENAI_MODELS: ClassVar[set[str]] = {
- "gemini-embedding-001",
- "text-embedding-005",
- "text-multilingual-embedding-002",
- }
- def __init__(self, **kwargs: Unpack[VertexAIProviderConfig]) -> None:
- """Initialize Google Vertex AI embedding function.
- Args:
- **kwargs: Configuration parameters including:
- - model_name: Model to use for embeddings (default: "textembedding-gecko")
- - api_key: Optional API key for authentication (new SDK only)
- - project_id: GCP project ID (for Vertex AI backend)
- - location: GCP region (default: "us-central1")
- - region: Deprecated alias for location
- - task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT", new SDK only)
- - output_dimensionality: Optional output embedding dimension (new SDK only)
- """
- # Handle deprecated 'region' parameter (only if it has a value)
- region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
- if region_value is not None:
- warnings.warn(
- "The 'region' parameter is deprecated, use 'location' instead. "
- "See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
- DeprecationWarning,
- stacklevel=2,
- )
- if "location" not in kwargs or kwargs.get("location") is None:
- kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
- self._config = kwargs
- self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
- self._use_legacy = self._is_legacy_model(self._model_name)
- if self._use_legacy:
- self._init_legacy_client(**kwargs)
- else:
- self._init_genai_client(**kwargs)
- def _is_legacy_model(self, model_name: str) -> bool:
- """Check if the model uses the legacy SDK."""
- return model_name in self.LEGACY_MODELS or model_name.startswith(
- "textembedding-gecko"
- )
- def _init_legacy_client(self, **kwargs: Any) -> None:
- """Initialize using the deprecated vertexai.language_models SDK."""
- warnings.warn(
- f"Model '{self._model_name}' uses the deprecated vertexai.language_models SDK "
- "which will be removed after June 24, 2026. Consider migrating to newer models "
- "like 'gemini-embedding-001'. "
- "See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
- DeprecationWarning,
- stacklevel=3,
- )
- try:
- import vertexai
- from vertexai.language_models import TextEmbeddingModel
- except ImportError as e:
- raise ImportError(
- "vertexai is required for legacy embedding models (textembedding-gecko*). "
- "Install it with: pip install google-cloud-aiplatform"
- ) from e
- project_id = kwargs.get("project_id")
- location = str(kwargs.get("location", "us-central1"))
- if not project_id:
- raise ValueError(
- "project_id is required for legacy models. "
- "For API key authentication, use newer models like 'gemini-embedding-001'."
- )
- vertexai.init(project=str(project_id), location=location)
- self._legacy_model = TextEmbeddingModel.from_pretrained(self._model_name)
- def _init_genai_client(self, **kwargs: Any) -> None:
- """Initialize using the new google-genai SDK."""
- try:
- from google import genai
- from google.genai.types import EmbedContentConfig
- except ImportError as e:
- raise ImportError(
- "google-genai is required for Google Gen AI embeddings. "
- "Install it with: uv add 'crewai[google-genai]'"
- ) from e
- self._genai = genai
- self._EmbedContentConfig = EmbedContentConfig
- self._task_type = kwargs.get("task_type", "RETRIEVAL_DOCUMENT")
- self._output_dimensionality = kwargs.get("output_dimensionality")
- # Initialize client based on authentication mode
- api_key = kwargs.get("api_key")
- project_id = kwargs.get("project_id")
- location: str = str(kwargs.get("location", "us-central1"))
- if api_key:
- self._client = genai.Client(api_key=api_key)
- elif project_id:
- self._client = genai.Client(
- vertexai=True,
- project=str(project_id),
- location=location,
- )
- else:
- raise ValueError(
- "Either 'api_key' (for API key authentication) or 'project_id' "
- "(for Vertex AI backend with ADC) must be provided."
- )
- @staticmethod
- def name() -> str:
- """Return the name of the embedding function for ChromaDB compatibility."""
- return "google-vertex"
- def __call__(self, input: Documents) -> Embeddings:
- """Generate embeddings for input documents.
- Args:
- input: List of documents to embed.
- Returns:
- List of embedding vectors.
- """
- if isinstance(input, str):
- input = [input]
- if self._use_legacy:
- return self._call_legacy(input)
- return self._call_genai(input)
- def _call_legacy(self, input: list[str]) -> Embeddings:
- """Generate embeddings using the legacy SDK."""
- import numpy as np
- embeddings_list = []
- for text in input:
- embedding_result = self._legacy_model.get_embeddings([text])
- embeddings_list.append(
- np.array(embedding_result[0].values, dtype=np.float32)
- )
- return cast(Embeddings, embeddings_list)
- def _call_genai(self, input: list[str]) -> Embeddings:
- """Generate embeddings using the new google-genai SDK."""
- # Build config for embed_content
- config_kwargs: dict[str, Any] = {
- "task_type": self._task_type,
- }
- if self._output_dimensionality is not None:
- config_kwargs["output_dimensionality"] = self._output_dimensionality
- config = self._EmbedContentConfig(**config_kwargs)
- # Call the embedding API
- response = self._client.models.embed_content(
- model=self._model_name,
- contents=input, # type: ignore[arg-type]
- config=config,
- )
- # Extract embeddings from response
- if response.embeddings is None:
- raise ValueError("No embeddings returned from the API")
- embeddings = [emb.values for emb in response.embeddings]
- return cast(Embeddings, embeddings)
@@ -34,47 +34,12 @@ class GenerativeAiProviderSpec(TypedDict):
class VertexAIProviderConfig(TypedDict, total=False):
- """Configuration for Vertex AI provider with dual SDK support.
- Supports both legacy models (textembedding-gecko*) using the deprecated
- vertexai.language_models SDK and new models using google-genai SDK.
- Attributes:
- api_key: Google API key (optional if using project_id with ADC). Only for new SDK models.
- model_name: Embedding model name (default: "textembedding-gecko").
- Legacy models: textembedding-gecko, textembedding-gecko@001, etc.
- New models: gemini-embedding-001, text-embedding-005, text-multilingual-embedding-002
- project_id: GCP project ID (required for Vertex AI backend and legacy models).
- location: GCP region/location (default: "us-central1").
- region: Deprecated alias for location (kept for backwards compatibility).
- task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT"). Only for new SDK models.
- output_dimensionality: Output embedding dimension (optional). Only for new SDK models.
- """
+ """Configuration for Vertex AI provider."""
api_key: str
- model_name: Annotated[
-     Literal[
-         # Legacy models (deprecated vertexai.language_models SDK)
-         "textembedding-gecko",
-         "textembedding-gecko@001",
-         "textembedding-gecko@002",
-         "textembedding-gecko@003",
-         "textembedding-gecko@latest",
-         "textembedding-gecko-multilingual",
-         "textembedding-gecko-multilingual@001",
-         "textembedding-gecko-multilingual@latest",
-         # New models (google-genai SDK)
-         "gemini-embedding-001",
-         "text-embedding-005",
-         "text-multilingual-embedding-002",
-     ],
-     "textembedding-gecko",
- ]
- project_id: str
- location: Annotated[str, "us-central1"]
- region: Annotated[str, "us-central1"] # Deprecated alias for location
- task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
- output_dimensionality: int
+ model_name: Annotated[str, "textembedding-gecko"]
+ project_id: Annotated[str, "cloud-large-language-models"]
+ region: Annotated[str, "us-central1"]
class VertexAIProviderSpec(TypedDict, total=False):
@@ -1,126 +1,46 @@
- """Google Vertex AI embeddings provider.
- This module supports both the new google-genai SDK and the deprecated
- vertexai.language_models module for backwards compatibility.
- The SDK is automatically selected based on the model name:
- - Legacy models (textembedding-gecko*) use vertexai.language_models (deprecated)
- - New models (gemini-embedding-*, text-embedding-*) use google-genai
- Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
- """
- from __future__ import annotations
+ """Google Vertex AI embeddings provider."""
+ from chromadb.utils.embedding_functions.google_embedding_function import (
+     GoogleVertexEmbeddingFunction,
+ )
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
- from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
-     GoogleGenAIVertexEmbeddingFunction,
- )
- class VertexAIProvider(BaseEmbeddingsProvider[GoogleGenAIVertexEmbeddingFunction]):
+ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
- """Google Vertex AI embeddings provider with dual SDK support.
- Supports both legacy models (textembedding-gecko*) using the deprecated
- vertexai.language_models SDK and new models (gemini-embedding-*, text-embedding-*)
- using the google-genai SDK.
- The SDK is automatically selected based on the model name. Legacy models will
- emit a deprecation warning.
- Authentication modes:
- 1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
- 2. API key: Set api_key for direct API access (new SDK models only)
- Example:
- # Legacy model (backwards compatible, will emit deprecation warning)
- provider = VertexAIProvider(
- project_id="my-project",
- region="us-central1", # or location="us-central1"
- model_name="textembedding-gecko"
- )
- # New model with Vertex AI backend
- provider = VertexAIProvider(
- project_id="my-project",
- location="us-central1",
- model_name="gemini-embedding-001"
- )
- # New model with API key
- provider = VertexAIProvider(
- api_key="your-api-key",
- model_name="gemini-embedding-001"
- )
- """
- embedding_callable: type[GoogleGenAIVertexEmbeddingFunction] = Field(
-     default=GoogleGenAIVertexEmbeddingFunction,
-     description="Google Vertex AI embedding function class",
+ """Google Vertex AI embeddings provider."""
+ embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
+     default=GoogleVertexEmbeddingFunction,
+     description="Vertex AI embedding function class",
)
model_name: str = Field(
    default="textembedding-gecko",
-     description=(
-         "Model name to use for embeddings. Legacy models (textembedding-gecko*) "
-         "use the deprecated SDK. New models (gemini-embedding-001, text-embedding-005) "
-         "use the google-genai SDK."
-     ),
+     description="Model name to use for embeddings",
    validation_alias=AliasChoices(
        "EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
        "GOOGLE_VERTEX_MODEL_NAME",
        "model",
    ),
)
- api_key: str | None = Field(
-     default=None,
-     description="Google API key (optional if using project_id with Application Default Credentials)",
+ api_key: str = Field(
+     description="Google API key",
    validation_alias=AliasChoices(
-         "EMBEDDINGS_GOOGLE_CLOUD_API_KEY",
-         "GOOGLE_CLOUD_API_KEY",
-         "GOOGLE_API_KEY",
+         "EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
    ),
)
- project_id: str | None = Field(
-     default=None,
-     description="GCP project ID (required for Vertex AI backend and legacy models)",
+ project_id: str = Field(
+     default="cloud-large-language-models",
+     description="GCP project ID",
    validation_alias=AliasChoices(
-         "EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
-         "GOOGLE_CLOUD_PROJECT",
+         "EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
    ),
)
- location: str = Field(
+ region: str = Field(
    default="us-central1",
-     description="GCP region/location",
+     description="GCP region",
    validation_alias=AliasChoices(
-         "EMBEDDINGS_GOOGLE_CLOUD_LOCATION",
-         "EMBEDDINGS_GOOGLE_CLOUD_REGION",
-         "GOOGLE_CLOUD_LOCATION",
-         "GOOGLE_CLOUD_REGION",
-     ),
- )
- region: str | None = Field(
-     default=None,
-     description="Deprecated: Use 'location' instead. GCP region (kept for backwards compatibility)",
-     validation_alias=AliasChoices(
-         "EMBEDDINGS_GOOGLE_VERTEX_REGION",
-         "GOOGLE_VERTEX_REGION",
-     ),
- )
- task_type: str = Field(
-     default="RETRIEVAL_DOCUMENT",
-     description="Task type for embeddings (e.g., RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY). Only used with new SDK models.",
-     validation_alias=AliasChoices(
-         "EMBEDDINGS_GOOGLE_VERTEX_TASK_TYPE",
-         "GOOGLE_VERTEX_TASK_TYPE",
-     ),
- )
- output_dimensionality: int | None = Field(
-     default=None,
-     description="Output embedding dimensionality (optional). Only used with new SDK models.",
-     validation_alias=AliasChoices(
-         "EMBEDDINGS_GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
-         "GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
+         "EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
    ),
)
@@ -5,29 +5,17 @@ from __future__ import annotations
- import logging
+ from aiocache import Cache # type: ignore[import-untyped]
+ from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
if TYPE_CHECKING:
-     from aiocache import Cache
    from crewai_files import FileInput
- logger = logging.getLogger(__name__)
- _file_store: Cache | None = None
- try:
-     from aiocache import Cache
-     from aiocache.serializers import PickleSerializer
-     _file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
- except ImportError:
-     logger.debug(
-         "aiocache is not installed. File store features will be disabled. "
-         "Install with: uv add aiocache"
-     )
+ _file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
@@ -71,8 +59,6 @@ async def astore_files(
- if _file_store is None:
-     return
@@ -85,8 +71,6 @@ async def aget_files(execution_id: UUID) -> dict[str, FileInput] | None:
- if _file_store is None:
-     return None
@@ -99,8 +83,6 @@ async def aclear_files(execution_id: UUID) -> None:
- if _file_store is None:
-     return
@@ -116,8 +98,6 @@ async def astore_task_files(
- if _file_store is None:
-     return
@@ -130,8 +110,6 @@ async def aget_task_files(task_id: UUID) -> dict[str, FileInput] | None:
- if _file_store is None:
-     return None
@@ -144,8 +122,6 @@ async def aclear_task_files(task_id: UUID) -> None:
- if _file_store is None:
-     return
@@ -1,8 +1,6 @@
interactions:
- request:
-     body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Output
-       the structured response","input_schema":{"type":"object","description":"Response
-       model for greeting test.","title":"GreetingResponse","properties":{"greeting":{"type":"string","title":"Greeting"},"language":{"type":"string","title":"Language"}},"additionalProperties":false,"required":["greeting","language"]}}]}'
+     body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Returns structured data according to the schema","input_schema":{"description":"Response model for greeting test.","properties":{"greeting":{"title":"Greeting","type":"string"},"language":{"title":"Language","type":"string"}},"required":["greeting","language"],"title":"GreetingResponse","type":"object"}}]}'
@@ -15,7 +13,7 @@ interactions:
-     - '551'
+     - '539'
@@ -31,7 +29,7 @@ interactions:
-     - 0.76.0
+     - 0.75.0
@@ -44,7 +42,7 @@ interactions:
-     string: '{"model":"claude-sonnet-4-20250514","id":"msg_01CKTyVmak15L5oQ36mv4sL9","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_0174BYmn6xiSnUwVhFD8S7EW","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":436,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
+     string: '{"model":"claude-sonnet-4-20250514","id":"msg_01XjvX2nCho1knuucbwwgCpw","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_019rfPRSDmBb7CyCTdGMv5rK","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":432,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
@@ -53,7 +51,7 @@ interactions:
-     - Mon, 26 Jan 2026 14:59:34 GMT
+     - Mon, 01 Dec 2025 11:19:38 GMT
@@ -84,10 +82,12 @@ interactions:
+     retry-after:
+     - '24'
-     - '968'
+     - '2101'
@@ -1,319 +0,0 @@
|
|||||||
interactions:
|
|
||||||
- request:
|
|
||||||
body: '{"contents": [{"parts": [{"text": "\nCurrent Task: What is 10000 + 20000?
|
|
||||||
Use the sum_numbers tool to calculate this.\n\nThis is the expected criteria
|
|
||||||
for your final answer: The sum of the two numbers\nyou MUST return the actual
|
|
||||||
complete content as the final answer, not a summary.\n\nThis is VERY important
|
|
||||||
to you, your job depends on it!"}], "role": "user"}], "systemInstruction": {"parts":
|
|
||||||
[{"text": "You are Calculator. You are a calculator that adds numbers.\nYour
|
|
||||||
personal goal is: Calculate numbers accurately"}], "role": "user"}, "tools":
|
|
||||||
[{"functionDeclarations": [{"description": "Add two numbers together and return
|
|
||||||
the result", "name": "sum_numbers", "parameters": {"properties": {"a": {"description":
|
|
||||||
"The first number to add", "title": "A", "type": "NUMBER"}, "b": {"description":
|
|
||||||
"The second number to add", "title": "B", "type": "NUMBER"}}, "required": ["a",
|
|
||||||
"b"], "type": "OBJECT"}}]}], "generationConfig": {"stopSequences": ["\nObservation:"]}}'
|
|
||||||
headers:
|
|
||||||
User-Agent:
|
|
||||||
- X-USER-AGENT-XXX
|
|
||||||
accept:
|
|
||||||
- '*/*'
|
|
||||||
accept-encoding:
|
|
||||||
- ACCEPT-ENCODING-XXX
|
|
||||||
connection:
|
|
||||||
- keep-alive
|
|
||||||
content-length:
|
|
||||||
- '962'
|
|
||||||
content-type:
|
|
||||||
- application/json
|
|
||||||
host:
|
|
||||||
- generativelanguage.googleapis.com
|
|
||||||
x-goog-api-client:
|
|
||||||
- google-genai-sdk/1.49.0 gl-python/3.13.3
|
|
||||||
x-goog-api-key:
|
|
||||||
- X-GOOG-API-KEY-XXX
|
|
||||||
method: POST
|
|
||||||
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-001:generateContent
|
|
||||||
response:
|
|
||||||
body:
|
|
||||||
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
|
|
||||||
[\n {\n \"functionCall\": {\n \"name\": \"sum_numbers\",\n
|
|
||||||
\ \"args\": {\n \"a\": 10000,\n \"b\":
|
|
||||||
20000\n }\n }\n }\n ],\n \"role\":
|
|
||||||
\"model\"\n },\n \"finishReason\": \"STOP\",\n \"avgLogprobs\":
|
|
||||||
-0.00059548033667462211\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\":
|
|
||||||
127,\n \"candidatesTokenCount\": 7,\n \"totalTokenCount\": 134,\n \"promptTokensDetails\":
|
|
||||||
[\n {\n \"modality\": \"TEXT\",\n \"tokenCount\": 127\n
|
|
||||||
\ }\n ],\n \"candidatesTokensDetails\": [\n {\n \"modality\":
|
|
||||||
\"TEXT\",\n \"tokenCount\": 7\n }\n ]\n },\n \"modelVersion\":
|
|
||||||
\"gemini-2.0-flash-001\",\n \"responseId\": \"bLBzabiACaP3-8YP7s-P6QI\"\n}\n"
|
|
||||||
headers:
|
|
||||||
Alt-Svc:
|
|
||||||
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
|
|
||||||
Content-Type:
|
|
||||||
- application/json; charset=UTF-8
|
|
||||||
Date:
|
|
||||||
- Fri, 23 Jan 2026 17:31:24 GMT
|
|
||||||
Server:
|
|
||||||
- scaffolding on HTTPServer2
|
|
||||||
Server-Timing:
|
|
||||||
- gfet4t7; dur=673
|
|
||||||
Transfer-Encoding:
|
|
||||||
- chunked
|
|
||||||
Vary:
|
|
||||||
- Origin
|
|
||||||
- X-Origin
|
|
||||||
- Referer
|
|
||||||
X-Content-Type-Options:
|
|
||||||
- X-CONTENT-TYPE-XXX
|
|
||||||
X-Frame-Options:
|
|
||||||
- X-FRAME-OPTIONS-XXX
|
|
||||||
X-XSS-Protection:
|
|
||||||
- '0'
|
|
||||||
status:
|
|
||||||
code: 200
|
|
||||||
message: OK
|
|
||||||
- request:
|
|
||||||
body: '{"contents": [{"parts": [{"text": "\nCurrent Task: What is 10000 + 20000?
|
|
||||||
Use the sum_numbers tool to calculate this.\n\nThis is the expected criteria
|
|
||||||
for your final answer: The sum of the two numbers\nyou MUST return the actual
|
|
||||||
complete content as the final answer, not a summary.\n\nThis is VERY important
|
|
||||||
to you, your job depends on it!"}], "role": "user"}, {"parts": [{"functionCall":
|
|
||||||
{"args": {"a": 10000, "b": 20000}, "name": "sum_numbers"}}], "role": "model"},
|
|
||||||
{"parts": [{"functionResponse": {"name": "sum_numbers", "response": {"result":
|
|
||||||
30000}}}], "role": "user"}, {"parts": [{"text": "Analyze the tool result. If
|
|
||||||
requirements are met, provide the Final Answer. Otherwise, call the next tool.
|
|
||||||
Deliver only the answer without meta-commentary."}], "role": "user"}], "systemInstruction":
|
|
||||||
{"parts": [{"text": "You are Calculator. You are a calculator that adds numbers.\nYour
|
|
||||||
personal goal is: Calculate numbers accurately"}], "role": "user"}, "tools":
|
|
||||||
[{"functionDeclarations": [{"description": "Add two numbers together and return
|
|
||||||
the result", "name": "sum_numbers", "parameters": {"properties": {"a": {"description":
|
|
||||||
"The first number to add", "title": "A", "type": "NUMBER"}, "b": {"description":
|
|
||||||
"The second number to add", "title": "B", "type": "NUMBER"}}, "required": ["a",
|
|
||||||
"b"], "type": "OBJECT"}}]}], "generationConfig": {"stopSequences": ["\nObservation:"]}}'
|
|
||||||
headers:
|
|
||||||
User-Agent:
|
|
||||||
- X-USER-AGENT-XXX
|
|
||||||
accept:
|
|
||||||
- '*/*'
|
|
||||||
accept-encoding:
|
|
||||||
- ACCEPT-ENCODING-XXX
|
|
||||||
connection:
|
|
||||||
- keep-alive
|
|
||||||
content-length:
|
|
||||||
- '1374'
|
|
||||||
content-type:
|
|
||||||
- application/json
|
|
||||||
host:
|
|
||||||
- generativelanguage.googleapis.com
|
|
||||||
x-goog-api-client:
|
|
||||||
- google-genai-sdk/1.49.0 gl-python/3.13.3
|
|
||||||
x-goog-api-key:
|
|
||||||
- X-GOOG-API-KEY-XXX
|
|
||||||
method: POST
|
|
||||||
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-001:generateContent
|
|
||||||
response:
|
|
||||||
body:
|
|
||||||
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
|
|
||||||
[\n {\n \"text\": \"\"\n }\n ],\n \"role\":
|
|
||||||
\"model\"\n },\n \"finishReason\": \"STOP\"\n }\n ],\n \"usageMetadata\":
|
|
||||||
{\n \"promptTokenCount\": 171,\n \"totalTokenCount\": 171,\n \"promptTokensDetails\":
|
|
||||||
[\n {\n \"modality\": \"TEXT\",\n \"tokenCount\": 171\n
|
|
||||||
\ }\n ]\n },\n \"modelVersion\": \"gemini-2.0-flash-001\",\n \"responseId\":
|
|
||||||
\"bLBzaaKgMc-ajrEPk7bIuQ8\"\n}\n"
|
|
||||||
headers:
|
|
||||||
Alt-Svc:
|
|
||||||
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
|
|
||||||
Content-Type:
|
|
||||||
- application/json; charset=UTF-8
|
|
||||||
Date:
|
|
||||||
- Fri, 23 Jan 2026 17:31:25 GMT
|
|
||||||
Server:
|
|
||||||
- scaffolding on HTTPServer2
|
|
||||||
Server-Timing:
|
|
||||||
- gfet4t7; dur=382
|
|
||||||
Transfer-Encoding:
|
|
||||||
- chunked
|
|
||||||
Vary:
|
|
||||||
- Origin
|
|
||||||
- X-Origin
|
|
||||||
- Referer
|
|
||||||
X-Content-Type-Options:
|
|
||||||
- X-CONTENT-TYPE-XXX
|
|
||||||
X-Frame-Options:
|
|
||||||
- X-FRAME-OPTIONS-XXX
|
|
||||||
X-XSS-Protection:
|
|
||||||
- '0'
|
|
||||||
status:
|
|
||||||
code: 200
|
|
||||||
message: OK
|
|
||||||
- request:
|
|
||||||
body: '{"contents": [{"parts": [{"text": "\nCurrent Task: What is 10000 + 20000?
|
|
||||||
Use the sum_numbers tool to calculate this.\n\nThis is the expected criteria
|
|
||||||
for your final answer: The sum of the two numbers\nyou MUST return the actual
|
|
||||||
complete content as the final answer, not a summary.\n\nThis is VERY important
|
|
||||||
to you, your job depends on it!"}], "role": "user"}, {"parts": [{"functionCall":
|
|
||||||
{"args": {"a": 10000, "b": 20000}, "name": "sum_numbers"}}], "role": "model"},
|
|
||||||
{"parts": [{"functionResponse": {"name": "sum_numbers", "response": {"result":
|
|
||||||
30000}}}], "role": "user"}, {"parts": [{"text": "Analyze the tool result. If
|
|
||||||
requirements are met, provide the Final Answer. Otherwise, call the next tool.
|
|
||||||
Deliver only the answer without meta-commentary."}], "role": "user"}, {"parts":
|
|
||||||
[{"text": "\nCurrent Task: What is 10000 + 20000? Use the sum_numbers tool to
|
|
||||||
calculate this.\n\nThis is the expected criteria for your final answer: The
|
|
||||||
sum of the two numbers\nyou MUST return the actual complete content as the final
|
|
||||||
answer, not a summary.\n\nThis is VERY important to you, your job depends on
|
|
||||||
it!"}], "role": "user"}], "systemInstruction": {"parts": [{"text": "You are
|
|
||||||
Calculator. You are a calculator that adds numbers.\nYour personal goal is:
|
|
||||||
Calculate numbers accurately\n\nYou are Calculator. You are a calculator that
|
|
||||||
adds numbers.\nYour personal goal is: Calculate numbers accurately"}], "role":
|
|
||||||
"user"}, "tools": [{"functionDeclarations": [{"description": "Add two numbers
|
|
||||||
together and return the result", "name": "sum_numbers", "parameters": {"properties":
|
|
||||||
{"a": {"description": "The first number to add", "title": "A", "type": "NUMBER"},
|
|
||||||
"b": {"description": "The second number to add", "title": "B", "type": "NUMBER"}},
|
|
||||||
"required": ["a", "b"], "type": "OBJECT"}}]}], "generationConfig": {"stopSequences":
|
|
||||||
["\nObservation:"]}}'
|
|
||||||
headers:
|
|
||||||
User-Agent:
|
|
||||||
- X-USER-AGENT-XXX
|
|
||||||
accept:
|
|
||||||
- '*/*'
|
|
||||||
accept-encoding:
|
|
||||||
- ACCEPT-ENCODING-XXX
|
|
||||||
connection:
|
|
||||||
- keep-alive
|
|
||||||
content-length:
|
|
||||||
- '1837'
|
|
||||||
content-type:
|
|
||||||
- application/json
|
|
||||||
host:
|
|
||||||
- generativelanguage.googleapis.com
|
|
||||||
x-goog-api-client:
|
|
||||||
- google-genai-sdk/1.49.0 gl-python/3.13.3
|
|
||||||
x-goog-api-key:
|
|
||||||
- X-GOOG-API-KEY-XXX
|
|
||||||
method: POST
|
|
||||||
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-001:generateContent
|
|
||||||
response:
|
|
||||||
body:
|
|
||||||
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
|
|
||||||
[\n {\n \"text\": \"\"\n }\n ],\n \"role\":
|
|
||||||
\"model\"\n },\n \"finishReason\": \"STOP\"\n }\n ],\n \"usageMetadata\":
|
|
||||||
{\n \"promptTokenCount\": 271,\n \"totalTokenCount\": 271,\n \"promptTokensDetails\":
|
|
||||||
[\n {\n \"modality\": \"TEXT\",\n \"tokenCount\": 271\n
|
|
||||||
\ }\n ]\n },\n \"modelVersion\": \"gemini-2.0-flash-001\",\n \"responseId\":
|
|
||||||
\"bbBzaczHDcW7jrEPgaj1CA\"\n}\n"
|
|
||||||
headers:
|
|
||||||
Alt-Svc:
|
|
||||||
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
|
|
||||||
Content-Type:
|
|
||||||
- application/json; charset=UTF-8
|
|
||||||
Date:
|
|
||||||
- Fri, 23 Jan 2026 17:31:25 GMT
|
|
||||||
Server:
|
|
||||||
- scaffolding on HTTPServer2
|
|
||||||
Server-Timing:
|
|
||||||
- gfet4t7; dur=410
|
|
||||||
Transfer-Encoding:
|
|
||||||
- chunked
|
|
||||||
Vary:
|
|
||||||
- Origin
|
|
||||||
- X-Origin
|
|
||||||
- Referer
|
|
||||||
X-Content-Type-Options:
|
|
||||||
- X-CONTENT-TYPE-XXX
|
|
||||||
X-Frame-Options:
|
|
||||||
- X-FRAME-OPTIONS-XXX
|
|
||||||
X-XSS-Protection:
|
|
||||||
- '0'
|
|
||||||
status:
|
|
||||||
code: 200
|
|
||||||
message: OK
|
|
||||||
- request:
|
|
||||||
body: '{"contents": [{"parts": [{"text": "\nCurrent Task: What is 10000 + 20000?
|
|
||||||
Use the sum_numbers tool to calculate this.\n\nThis is the expected criteria
|
|
||||||
for your final answer: The sum of the two numbers\nyou MUST return the actual
|
|
||||||
complete content as the final answer, not a summary.\n\nThis is VERY important
|
|
||||||
to you, your job depends on it!"}], "role": "user"}, {"parts": [{"functionCall":
|
|
||||||
{"args": {"a": 10000, "b": 20000}, "name": "sum_numbers"}}], "role": "model"},
|
|
||||||
{"parts": [{"functionResponse": {"name": "sum_numbers", "response": {"result":
|
|
||||||
30000}}}], "role": "user"}, {"parts": [{"text": "Analyze the tool result. If
|
|
||||||
requirements are met, provide the Final Answer. Otherwise, call the next tool.
|
|
||||||
Deliver only the answer without meta-commentary."}], "role": "user"}, {"parts":
|
|
||||||
[{"text": "\nCurrent Task: What is 10000 + 20000? Use the sum_numbers tool to
|
|
||||||
calculate this.\n\nThis is the expected criteria for your final answer: The
|
|
||||||
sum of the two numbers\nyou MUST return the actual complete content as the final
|
|
||||||
answer, not a summary.\n\nThis is VERY important to you, your job depends on
|
|
||||||
it!"}], "role": "user"}, {"parts": [{"text": "\nCurrent Task: What is 10000
|
|
||||||
+ 20000? Use the sum_numbers tool to calculate this.\n\nThis is the expected
|
|
||||||
criteria for your final answer: The sum of the two numbers\nyou MUST return
|
|
||||||
the actual complete content as the final answer, not a summary.\n\nThis is VERY
|
|
||||||
important to you, your job depends on it!"}], "role": "user"}], "systemInstruction":
|
|
||||||
{"parts": [{"text": "You are Calculator. You are a calculator that adds numbers.\nYour
|
|
||||||
personal goal is: Calculate numbers accurately\n\nYou are Calculator. You are
|
|
||||||
a calculator that adds numbers.\nYour personal goal is: Calculate numbers accurately\n\nYou
|
|
||||||
are Calculator. You are a calculator that adds numbers.\nYour personal goal
|
|
||||||
is: Calculate numbers accurately"}], "role": "user"}, "tools": [{"functionDeclarations":
|
|
||||||
[{"description": "Add two numbers together and return the result", "name": "sum_numbers",
|
|
||||||
"parameters": {"properties": {"a": {"description": "The first number to add",
|
|
||||||
"title": "A", "type": "NUMBER"}, "b": {"description": "The second number to
|
|
||||||
add", "title": "B", "type": "NUMBER"}}, "required": ["a", "b"], "type": "OBJECT"}}]}],
|
|
||||||
"generationConfig": {"stopSequences": ["\nObservation:"]}}'
|
|
||||||
headers:
|
|
||||||
User-Agent:
|
|
||||||
- X-USER-AGENT-XXX
|
|
||||||
accept:
|
|
||||||
- '*/*'
|
|
||||||
accept-encoding:
|
|
||||||
- ACCEPT-ENCODING-XXX
|
|
||||||
connection:
|
|
||||||
- keep-alive
|
|
||||||
content-length:
|
|
||||||
- '2300'
|
|
||||||
content-type:
|
|
||||||
- application/json
|
|
||||||
host:
|
|
||||||
- generativelanguage.googleapis.com
|
|
||||||
x-goog-api-client:
|
|
||||||
- google-genai-sdk/1.49.0 gl-python/3.13.3
|
|
||||||
x-goog-api-key:
|
|
||||||
- X-GOOG-API-KEY-XXX
|
|
||||||
method: POST
|
|
||||||
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-001:generateContent
|
|
||||||
response:
|
|
||||||
body:
|
|
||||||
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
|
|
||||||
[\n {\n \"text\": \"\\n{\\\"sum_numbers_response\\\":
|
|
||||||
{\\\"result\\\": 30000}}\\n\"\n }\n ],\n \"role\":
|
|
||||||
\"model\"\n },\n \"finishReason\": \"STOP\",\n \"avgLogprobs\":
|
|
||||||
-0.0038021293125654523\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\":
|
|
||||||
371,\n \"candidatesTokenCount\": 19,\n \"totalTokenCount\": 390,\n \"promptTokensDetails\":
|
|
||||||
[\n {\n \"modality\": \"TEXT\",\n \"tokenCount\": 371\n
|
|
||||||
\ }\n ],\n \"candidatesTokensDetails\": [\n {\n \"modality\":
|
|
||||||
\"TEXT\",\n \"tokenCount\": 19\n }\n ]\n },\n \"modelVersion\":
|
|
||||||
\"gemini-2.0-flash-001\",\n \"responseId\": \"bbBzaauxJ_SgjrEP7onK2Ak\"\n}\n"
|
|
||||||
headers:
|
|
||||||
Alt-Svc:
|
|
||||||
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
|
|
||||||
Content-Type:
|
|
||||||
- application/json; charset=UTF-8
|
|
||||||
Date:
|
|
||||||
- Fri, 23 Jan 2026 17:31:26 GMT
|
|
||||||
Server:
|
|
||||||
- scaffolding on HTTPServer2
|
|
||||||
Server-Timing:
|
|
||||||
- gfet4t7; dur=454
|
|
||||||
Transfer-Encoding:
|
|
||||||
- chunked
|
|
||||||
Vary:
|
|
||||||
- Origin
|
|
||||||
- X-Origin
|
|
||||||
- Referer
|
|
||||||
X-Content-Type-Options:
|
|
||||||
- X-CONTENT-TYPE-XXX
|
|
||||||
X-Frame-Options:
|
|
||||||
- X-FRAME-OPTIONS-XXX
|
|
||||||
X-XSS-Protection:
|
|
||||||
- '0'
|
|
||||||
status:
|
|
||||||
code: 200
|
|
||||||
message: OK
|
|
||||||
version: 1
|
|
||||||
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -193,3 +193,40 @@ def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None:
 
     with pytest.raises(ValueError, match="Embedding dimension mismatch"):
         storage.save(["test document"])
+
+
+@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
+def test_save_empty_documents_list(mock_get_client: MagicMock) -> None:
+    """Test that save() handles empty documents list gracefully.
+
+    Calling save() with an empty documents list should be a no-op and not
+    propagate low-level storage exceptions from ChromaDB.
+    """
+    mock_client = MagicMock()
+    mock_get_client.return_value = mock_client
+
+    storage = KnowledgeStorage(collection_name="empty_docs_test")
+
+    storage.save([])
+
+    mock_client.get_or_create_collection.assert_not_called()
+    mock_client.add_documents.assert_not_called()
+
+
+@pytest.mark.asyncio
+@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
+async def test_asave_empty_documents_list(mock_get_client: MagicMock) -> None:
+    """Test that asave() handles empty documents list gracefully.
+
+    Calling asave() with an empty documents list should be a no-op and not
+    propagate low-level storage exceptions from ChromaDB.
+    """
+    mock_client = MagicMock()
+    mock_get_client.return_value = mock_client
+
+    storage = KnowledgeStorage(collection_name="empty_docs_async_test")
+
+    await storage.asave([])
+
+    mock_client.aget_or_create_collection.assert_not_called()
+    mock_client.aadd_documents.assert_not_called()
@@ -635,54 +635,6 @@ def test_gemini_token_usage_tracking():
     assert usage.total_tokens > 0
 
 
-@pytest.mark.vcr()
-def test_gemini_tool_returning_float():
-    """
-    Test that Gemini properly handles tools that return non-dict values like floats.
-
-    This is an end-to-end test that verifies the agent can use a tool that returns
-    a float (which gets wrapped in {"result": value} for Gemini's FunctionResponse).
-    """
-    from pydantic import BaseModel, Field
-    from typing import Type
-    from crewai.tools import BaseTool
-
-    class SumNumbersToolInput(BaseModel):
-        a: float = Field(..., description="The first number to add")
-        b: float = Field(..., description="The second number to add")
-
-    class SumNumbersTool(BaseTool):
-        name: str = "sum_numbers"
-        description: str = "Add two numbers together and return the result"
-        args_schema: Type[BaseModel] = SumNumbersToolInput
-
-        def _run(self, a: float, b: float) -> float:
-            return a + b
-
-    sum_tool = SumNumbersTool()
-
-    agent = Agent(
-        role="Calculator",
-        goal="Calculate numbers accurately",
-        backstory="You are a calculator that adds numbers.",
-        llm=LLM(model="google/gemini-2.0-flash-001"),
-        tools=[sum_tool],
-        verbose=True,
-    )
-
-    task = Task(
-        description="What is 10000 + 20000? Use the sum_numbers tool to calculate this.",
-        expected_output="The sum of the two numbers",
-        agent=agent,
-    )
-
-    crew = Crew(agents=[agent], tasks=[task], verbose=True)
-    result = crew.kickoff()
-
-    # The result should contain 30000 (the sum)
-    assert "30000" in result.raw
-
-
 def test_gemini_stop_sequences_sync():
     """Test that stop and stop_sequences attributes stay synchronized."""
     llm = LLM(model="google/gemini-2.0-flash-001")
@@ -511,13 +511,10 @@ def test_openai_streaming_with_response_model():
     mock_chunk1 = MagicMock()
     mock_chunk1.type = "content.delta"
     mock_chunk1.delta = '{"answer": "test", '
-    mock_chunk1.id = "response-1"
 
-    # Second chunk
     mock_chunk2 = MagicMock()
     mock_chunk2.type = "content.delta"
     mock_chunk2.delta = '"confidence": 0.95}'
-    mock_chunk2.id = "response-2"
 
     # Create mock final completion with parsed result
     mock_parsed = TestResponse(answer="test", confidence=0.95)
@@ -272,100 +272,3 @@ class TestEmbeddingFactory:
         mock_build_from_provider.assert_called_once_with(mock_provider)
         assert result == mock_embedding_function
         mock_import.assert_not_called()
-
-    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
-    def test_build_embedder_google_vertex_with_genai_model(self, mock_import):
-        """Test routing to Google Vertex provider with new genai model."""
-        mock_provider_class = MagicMock()
-        mock_provider_instance = MagicMock()
-        mock_embedding_function = MagicMock()
-
-        mock_import.return_value = mock_provider_class
-        mock_provider_class.return_value = mock_provider_instance
-        mock_provider_instance.embedding_callable.return_value = mock_embedding_function
-
-        config = {
-            "provider": "google-vertex",
-            "config": {
-                "api_key": "test-google-api-key",
-                "model_name": "gemini-embedding-001",
-            },
-        }
-
-        build_embedder(config)
-
-        mock_import.assert_called_once_with(
-            "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
-        )
-        mock_provider_class.assert_called_once()
-
-        call_kwargs = mock_provider_class.call_args.kwargs
-        assert call_kwargs["api_key"] == "test-google-api-key"
-        assert call_kwargs["model_name"] == "gemini-embedding-001"
-
-    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
-    def test_build_embedder_google_vertex_with_legacy_model(self, mock_import):
-        """Test routing to Google Vertex provider with legacy textembedding-gecko model."""
-        mock_provider_class = MagicMock()
-        mock_provider_instance = MagicMock()
-        mock_embedding_function = MagicMock()
-
-        mock_import.return_value = mock_provider_class
-        mock_provider_class.return_value = mock_provider_instance
-        mock_provider_instance.embedding_callable.return_value = mock_embedding_function
-
-        config = {
-            "provider": "google-vertex",
-            "config": {
-                "project_id": "my-gcp-project",
-                "region": "us-central1",
-                "model_name": "textembedding-gecko",
-            },
-        }
-
-        build_embedder(config)
-
-        mock_import.assert_called_once_with(
-            "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
-        )
-        mock_provider_class.assert_called_once()
-
-        call_kwargs = mock_provider_class.call_args.kwargs
-        assert call_kwargs["project_id"] == "my-gcp-project"
-        assert call_kwargs["region"] == "us-central1"
-        assert call_kwargs["model_name"] == "textembedding-gecko"
-
-    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
-    def test_build_embedder_google_vertex_with_location(self, mock_import):
-        """Test routing to Google Vertex provider with location parameter."""
-        mock_provider_class = MagicMock()
-        mock_provider_instance = MagicMock()
-        mock_embedding_function = MagicMock()
-
-        mock_import.return_value = mock_provider_class
-        mock_provider_class.return_value = mock_provider_instance
-        mock_provider_instance.embedding_callable.return_value = mock_embedding_function
-
-        config = {
-            "provider": "google-vertex",
-            "config": {
-                "project_id": "my-gcp-project",
-                "location": "europe-west1",
-                "model_name": "gemini-embedding-001",
-                "task_type": "RETRIEVAL_DOCUMENT",
-                "output_dimensionality": 768,
-            },
-        }
-
-        build_embedder(config)
-
-        mock_import.assert_called_once_with(
-            "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
-        )
-
-        call_kwargs = mock_provider_class.call_args.kwargs
-        assert call_kwargs["project_id"] == "my-gcp-project"
-        assert call_kwargs["location"] == "europe-west1"
-        assert call_kwargs["model_name"] == "gemini-embedding-001"
-        assert call_kwargs["task_type"] == "RETRIEVAL_DOCUMENT"
-        assert call_kwargs["output_dimensionality"] == 768
@@ -1,176 +0,0 @@
-"""Integration tests for Google Vertex embeddings with Crew memory.
-
-These tests make real API calls and use VCR to record/replay responses.
-"""
-
-import os
-import threading
-from collections import defaultdict
-from unittest.mock import patch
-
-import pytest
-
-from crewai import Agent, Crew, Task
-from crewai.events.event_bus import crewai_event_bus
-from crewai.events.types.memory_events import (
-    MemorySaveCompletedEvent,
-    MemorySaveStartedEvent,
-)
-
-
-@pytest.fixture(autouse=True)
-def setup_vertex_ai_env():
-    """Set up environment for Vertex AI tests.
-
-    Sets GOOGLE_GENAI_USE_VERTEXAI=true to ensure the SDK uses the Vertex AI
-    backend (aiplatform.googleapis.com) which matches the VCR cassettes.
-    Also mocks GOOGLE_API_KEY if not already set.
-    """
-    env_updates = {"GOOGLE_GENAI_USE_VERTEXAI": "true"}
-
-    # Add a mock API key if none exists
-    if "GOOGLE_API_KEY" not in os.environ and "GEMINI_API_KEY" not in os.environ:
-        env_updates["GOOGLE_API_KEY"] = "test-key"
-
-    with patch.dict(os.environ, env_updates):
-        yield
-
-
-@pytest.fixture
-def google_vertex_embedder_config():
-    """Fixture providing Google Vertex embedder configuration."""
-    return {
-        "provider": "google-vertex",
-        "config": {
-            "api_key": os.getenv("GOOGLE_API_KEY", "test-key"),
-            "model_name": "gemini-embedding-001",
-        },
-    }
-
-
-@pytest.fixture
-def simple_agent():
-    """Fixture providing a simple test agent."""
-    return Agent(
-        role="Research Assistant",
-        goal="Help with research tasks",
-        backstory="You are a helpful research assistant.",
-        verbose=False,
-    )
-
-
-@pytest.fixture
-def simple_task(simple_agent):
-    """Fixture providing a simple test task."""
-    return Task(
-        description="Summarize the key points about artificial intelligence in one sentence.",
-        expected_output="A one sentence summary about AI.",
-        agent=simple_agent,
-    )
-
-
-@pytest.mark.vcr()
-@pytest.mark.timeout(120)  # Longer timeout for VCR recording
-def test_crew_memory_with_google_vertex_embedder(
-    google_vertex_embedder_config, simple_agent, simple_task
-) -> None:
-    """Test that Crew with memory=True works with google-vertex embedder and memory is used."""
-    # Track memory events
-    events: dict[str, list] = defaultdict(list)
-    condition = threading.Condition()
-
-    @crewai_event_bus.on(MemorySaveStartedEvent)
-    def on_save_started(source, event):
-        with condition:
-            events["MemorySaveStartedEvent"].append(event)
-            condition.notify()
-
-    @crewai_event_bus.on(MemorySaveCompletedEvent)
-    def on_save_completed(source, event):
-        with condition:
-            events["MemorySaveCompletedEvent"].append(event)
-            condition.notify()
-
-    crew = Crew(
-        agents=[simple_agent],
-        tasks=[simple_task],
-        memory=True,
-        embedder=google_vertex_embedder_config,
-        verbose=False,
-    )
-
-    result = crew.kickoff()
-
-    assert result is not None
-    assert result.raw is not None
-    assert len(result.raw) > 0
-
-    with condition:
-        success = condition.wait_for(
-            lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
-            timeout=10,
-        )
-
-    assert success, "Timeout waiting for memory save events - memory may not be working"
-    assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
-    assert len(events["MemorySaveCompletedEvent"]) >= 1, "Memory save completed events"
-
-
-@pytest.mark.vcr()
-@pytest.mark.timeout(120)
-def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) -> None:
-    """Test Crew memory with Google Vertex using project_id authentication."""
-    project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
-    if not project_id:
-        pytest.skip("GOOGLE_CLOUD_PROJECT environment variable not set")
-
-    # Track memory events
-    events: dict[str, list] = defaultdict(list)
-    condition = threading.Condition()
-
-    @crewai_event_bus.on(MemorySaveStartedEvent)
-    def on_save_started(source, event):
-        with condition:
-            events["MemorySaveStartedEvent"].append(event)
-            condition.notify()
-
-    @crewai_event_bus.on(MemorySaveCompletedEvent)
-    def on_save_completed(source, event):
-        with condition:
-            events["MemorySaveCompletedEvent"].append(event)
-            condition.notify()
-
-    embedder_config = {
-        "provider": "google-vertex",
-        "config": {
-            "project_id": project_id,
-            "location": "us-central1",
-            "model_name": "gemini-embedding-001",
-        },
-    }
-
-    crew = Crew(
-        agents=[simple_agent],
-        tasks=[simple_task],
-        memory=True,
-        embedder=embedder_config,
-        verbose=False,
-    )
-
-    result = crew.kickoff()
-
-    # Verify basic result
-    assert result is not None
-    assert result.raw is not None
-
-    # Wait for memory save events
-    with condition:
-        success = condition.wait_for(
-            lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
-            timeout=10,
-        )
-
-    # Verify memory was actually used
-    assert success, "Timeout waiting for memory save events - memory may not be working"
-    assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
-    assert len(events["MemorySaveCompletedEvent"]) >= 1, "No memory save completed events"
@@ -984,8 +984,8 @@ def test_streaming_fallback_to_non_streaming():
     def mock_call(messages, tools=None, callbacks=None, available_functions=None):
         nonlocal fallback_called
         # Emit a couple of chunks to simulate partial streaming
-        crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1", response_id = "Id"))
-        crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2", response_id = "Id"))
+        crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
+        crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))
 
         # Mark that fallback would be called
         fallback_called = True
@@ -1041,7 +1041,7 @@ def test_streaming_empty_response_handling():
     def mock_call(messages, tools=None, callbacks=None, available_functions=None):
         # Emit a few empty chunks
         for _ in range(3):
-            crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="",response_id="id"))
+            crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
 
         # Return the default message for empty responses
         return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
@@ -1,3 +1,3 @@
 """CrewAI development tools."""
 
-__version__ = "1.9.0"
+__version__ = "1.8.1"
uv.lock (generated file, 8 changed lines)
@@ -310,7 +310,7 @@ wheels = [
 
 [[package]]
 name = "anthropic"
-version = "0.73.0"
+version = "0.71.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "anyio" },
@@ -322,9 +322,9 @@ dependencies = [
     { name = "sniffio" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/f0/07/f550112c3f5299d02f06580577f602e8a112b1988ad7c98ac1a8f7292d7e/anthropic-0.73.0.tar.gz", hash = "sha256:30f0d7d86390165f86af6ca7c3041f8720bb2e1b0e12a44525c8edfdbd2c5239", size = 425168, upload-time = "2025-11-14T18:47:52.635Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/05/4b/19620875841f692fdc35eb58bf0201c8ad8c47b8443fecbf1b225312175b/anthropic-0.71.1.tar.gz", hash = "sha256:a77d156d3e7d318b84681b59823b2dee48a8ac508a3e54e49f0ab0d074e4b0da", size = 493294, upload-time = "2025-10-28T17:28:42.213Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/15/b1/5d4d3f649e151e58dc938cf19c4d0cd19fca9a986879f30fea08a7b17138/anthropic-0.73.0-py3-none-any.whl", hash = "sha256:0d56cd8b3ca3fea9c9b5162868bdfd053fbc189b8b56d4290bd2d427b56db769", size = 367839, upload-time = "2025-11-14T18:47:51.195Z" },
+    { url = "https://files.pythonhosted.org/packages/4b/68/b2f988b13325f9ac9921b1e87f0b7994468014e1b5bd3bdbd2472f5baf45/anthropic-0.71.1-py3-none-any.whl", hash = "sha256:6ca6c579f0899a445faeeed9c0eb97aa4bdb751196262f9ccc96edfc0bb12679", size = 355020, upload-time = "2025-10-28T17:28:40.653Z" },
 ]
 
 [[package]]
@@ -1276,7 +1276,7 @@ requires-dist = [
     { name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
     { name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
     { name = "aiosqlite", specifier = "~=0.21.0" },
-    { name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
+    { name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.71.0" },
     { name = "appdirs", specifier = "~=1.4.4" },
     { name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
     { name = "boto3", marker = "extra == 'aws'", specifier = "~=1.40.38" },