feat: add structured outputs support to Bedrock and Anthropic providers

This commit is contained in:
Greyson Lalonde
2026-01-26 08:43:34 -05:00
parent c9b240a86c
commit 2e48e4e276
2 changed files with 388 additions and 133 deletions

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
import json import json
import logging import logging
import os import os
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Final, Literal, cast
from anthropic.types import ThinkingBlock
from pydantic import BaseModel from pydantic import BaseModel
from crewai.events.types.llm_events import LLMCallType from crewai.events.types.llm_events import LLMCallType
@@ -24,6 +23,7 @@ if TYPE_CHECKING:
try: try:
from anthropic import Anthropic, AsyncAnthropic from anthropic import Anthropic, AsyncAnthropic
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
from anthropic.types.beta import BetaMessage
import httpx import httpx
except ImportError: except ImportError:
raise ImportError( raise ImportError(
@@ -31,7 +31,36 @@ except ImportError:
) from None ) from None
ANTHROPIC_FILES_API_BETA = "files-api-2025-04-14" ANTHROPIC_FILES_API_BETA: Final = "files-api-2025-04-14"
# Beta identifier appended to the request's `betas` list whenever the native
# structured-output path (the `output_format` request parameter) is used.
ANTHROPIC_STRUCTURED_OUTPUTS_BETA: Final = "structured-outputs-2025-11-13"
# Model-name fragments that select the native structured-output path.
# `_supports_native_structured_outputs` tests each entry as a substring of the
# lower-cased model name, so any identifier containing one of these fragments
# matches; all other models fall back to the forced `structured_output` tool.
# NOTE(review): these fragments also match the 4.0 releases of each family --
# confirm that every matching model actually accepts the structured-outputs beta.
NATIVE_STRUCTURED_OUTPUT_MODELS: Final[
tuple[
Literal["claude-sonnet-4"],
Literal["claude-opus-4"],
Literal["claude-haiku-4"],
]
] = (
"claude-sonnet-4",
"claude-opus-4",
"claude-haiku-4",
)
def _supports_native_structured_outputs(model: str) -> bool:
    """Report whether ``model`` is routed to native structured outputs.

    The decision is a case-insensitive substring match: the model name is
    lower-cased and compared against every entry of
    ``NATIVE_STRUCTURED_OUTPUT_MODELS``. Models that do not match are served
    through the tool-based fallback by the completion handlers.

    Args:
        model: The model name/identifier.

    Returns:
        True if any known model fragment occurs in the model name,
        False otherwise.
    """
    normalized_name = model.lower()
    for fragment in NATIVE_STRUCTURED_OUTPUT_MODELS:
        # First hit is enough; order of the fragments does not matter.
        if fragment in normalized_name:
            return True
    return False
def _contains_file_id_reference(messages: list[dict[str, Any]]) -> bool: def _contains_file_id_reference(messages: list[dict[str, Any]]) -> bool:
@@ -84,6 +113,7 @@ class AnthropicCompletion(BaseLLM):
client_params: dict[str, Any] | None = None, client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
thinking: AnthropicThinkingConfig | None = None, thinking: AnthropicThinkingConfig | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any, **kwargs: Any,
): ):
"""Initialize Anthropic chat completion client. """Initialize Anthropic chat completion client.
@@ -101,6 +131,8 @@ class AnthropicCompletion(BaseLLM):
stream: Enable streaming responses stream: Enable streaming responses
client_params: Additional parameters for the Anthropic client client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level. interceptor: HTTP interceptor for modifying requests/responses at transport level.
response_format: Pydantic model for structured output. When provided, responses
will be validated against this model schema.
**kwargs: Additional parameters **kwargs: Additional parameters
""" """
super().__init__( super().__init__(
@@ -131,6 +163,7 @@ class AnthropicCompletion(BaseLLM):
self.stop_sequences = stop_sequences or [] self.stop_sequences = stop_sequences or []
self.thinking = thinking self.thinking = thinking
self.previous_thinking_blocks: list[ThinkingBlock] = [] self.previous_thinking_blocks: list[ThinkingBlock] = []
self.response_format = response_format
# Model-specific settings # Model-specific settings
self.is_claude_3 = "claude-3" in model.lower() self.is_claude_3 = "claude-3" in model.lower()
self.supports_tools = True self.supports_tools = True
@@ -231,6 +264,8 @@ class AnthropicCompletion(BaseLLM):
formatted_messages, system_message, tools formatted_messages, system_message, tools
) )
effective_response_model = response_model or self.response_format
# Handle streaming vs non-streaming # Handle streaming vs non-streaming
if self.stream: if self.stream:
return self._handle_streaming_completion( return self._handle_streaming_completion(
@@ -238,7 +273,7 @@ class AnthropicCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
return self._handle_completion( return self._handle_completion(
@@ -246,7 +281,7 @@ class AnthropicCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
except Exception as e: except Exception as e:
@@ -298,13 +333,15 @@ class AnthropicCompletion(BaseLLM):
formatted_messages, system_message, tools formatted_messages, system_message, tools
) )
effective_response_model = response_model or self.response_format
if self.stream: if self.stream:
return await self._ahandle_streaming_completion( return await self._ahandle_streaming_completion(
completion_params, completion_params,
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
return await self._ahandle_completion( return await self._ahandle_completion(
@@ -312,7 +349,7 @@ class AnthropicCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
response_model, effective_response_model,
) )
except Exception as e: except Exception as e:
@@ -565,21 +602,33 @@ class AnthropicCompletion(BaseLLM):
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str | Any: ) -> str | Any:
"""Handle non-streaming message completion.""" """Handle non-streaming message completion."""
if response_model:
structured_tool = {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
uses_file_api = _contains_file_id_reference(params.get("messages", [])) uses_file_api = _contains_file_id_reference(params.get("messages", []))
betas: list[str] = []
use_native_structured_output = False
if uses_file_api:
betas.append(ANTHROPIC_FILES_API_BETA)
if response_model:
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
params["output_format"] = {
"type": "json_schema",
"schema": response_model.model_json_schema(),
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
try: try:
if uses_file_api: if betas:
params["betas"] = [ANTHROPIC_FILES_API_BETA] params["betas"] = betas
response = self.client.beta.messages.create(**params) response = self.client.beta.messages.create(**params)
else: else:
response = self.client.messages.create(**params) response = self.client.messages.create(**params)
@@ -594,21 +643,33 @@ class AnthropicCompletion(BaseLLM):
self._track_token_usage_internal(usage) self._track_token_usage_internal(usage)
if response_model and response.content: if response_model and response.content:
tool_uses = [ if use_native_structured_output:
block for block in response.content if isinstance(block, ToolUseBlock) for block in response.content:
] if isinstance(block, TextBlock):
if tool_uses and tool_uses[0].name == "structured_output": structured_json = block.text
structured_data = tool_uses[0].input self._emit_call_completed_event(
structured_json = json.dumps(structured_data) response=structured_json,
self._emit_call_completed_event( call_type=LLMCallType.LLM_CALL,
response=structured_json, from_task=from_task,
call_type=LLMCallType.LLM_CALL, from_agent=from_agent,
from_task=from_task, messages=params["messages"],
from_agent=from_agent, )
messages=params["messages"], return structured_json
) else:
for block in response.content:
return structured_json if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
# Check if Claude wants to use tools # Check if Claude wants to use tools
if response.content: if response.content:
@@ -678,17 +739,27 @@ class AnthropicCompletion(BaseLLM):
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str: ) -> str | Any:
"""Handle streaming message completion.""" """Handle streaming message completion."""
if response_model: betas: list[str] = []
structured_tool = { use_native_structured_output = False
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool] if response_model:
params["tool_choice"] = {"type": "tool", "name": "structured_output"} if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
params["output_format"] = {
"type": "json_schema",
"schema": response_model.model_json_schema(),
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
full_response = "" full_response = ""
@@ -696,15 +767,22 @@ class AnthropicCompletion(BaseLLM):
# (the SDK sets it internally) # (the SDK sets it internally)
stream_params = {k: v for k, v in params.items() if k != "stream"} stream_params = {k: v for k, v in params.items() if k != "stream"}
if betas:
stream_params["betas"] = betas
current_tool_calls: dict[int, dict[str, Any]] = {} current_tool_calls: dict[int, dict[str, Any]] = {}
# Make streaming API call stream_context = (
with self.client.messages.stream(**stream_params) as stream: self.client.beta.messages.stream(**stream_params)
if betas
else self.client.messages.stream(**stream_params)
)
with stream_context as stream:
response_id = None response_id = None
for event in stream: for event in stream:
if hasattr(event, "message") and hasattr(event.message, "id"): if hasattr(event, "message") and hasattr(event.message, "id"):
response_id = event.message.id response_id = event.message.id
if hasattr(event, "delta") and hasattr(event.delta, "text"): if hasattr(event, "delta") and hasattr(event.delta, "text"):
text_delta = event.delta.text text_delta = event.delta.text
full_response += text_delta full_response += text_delta
@@ -712,7 +790,7 @@ class AnthropicCompletion(BaseLLM):
chunk=text_delta, chunk=text_delta,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
response_id=response_id response_id=response_id,
) )
if event.type == "content_block_start": if event.type == "content_block_start":
@@ -739,7 +817,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
elif event.type == "content_block_delta": elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta": if event.delta.type == "input_json_delta":
@@ -763,10 +841,10 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
final_message: Message = stream.get_final_message() final_message = stream.get_final_message()
thinking_blocks: list[ThinkingBlock] = [] thinking_blocks: list[ThinkingBlock] = []
if final_message.content: if final_message.content:
@@ -781,25 +859,30 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(final_message) usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage) self._track_token_usage_internal(usage)
if response_model and final_message.content: if response_model:
tool_uses = [ if use_native_structured_output:
block
for block in final_message.content
if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
self._emit_call_completed_event( self._emit_call_completed_event(
response=structured_json, response=full_response,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
messages=params["messages"], messages=params["messages"],
) )
return full_response
return structured_json for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if final_message.content: if final_message.content:
tool_uses = [ tool_uses = [
@@ -809,11 +892,9 @@ class AnthropicCompletion(BaseLLM):
] ]
if tool_uses: if tool_uses:
# If no available_functions, return tool calls for executor to handle
if not available_functions: if not available_functions:
return list(tool_uses) return list(tool_uses)
# Handle tool use conversation flow internally
return self._handle_tool_use_conversation( return self._handle_tool_use_conversation(
final_message, final_message,
tool_uses, tool_uses,
@@ -823,10 +904,8 @@ class AnthropicCompletion(BaseLLM):
from_agent, from_agent,
) )
# Apply stop words to full response
full_response = self._apply_stop_words(full_response) full_response = self._apply_stop_words(full_response)
# Emit completion event and return full response
self._emit_call_completed_event( self._emit_call_completed_event(
response=full_response, response=full_response,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
@@ -884,7 +963,7 @@ class AnthropicCompletion(BaseLLM):
def _handle_tool_use_conversation( def _handle_tool_use_conversation(
self, self,
initial_response: Message, initial_response: Message | BetaMessage,
tool_uses: list[ToolUseBlock], tool_uses: list[ToolUseBlock],
params: dict[str, Any], params: dict[str, Any],
available_functions: dict[str, Any], available_functions: dict[str, Any],
@@ -1002,21 +1081,33 @@ class AnthropicCompletion(BaseLLM):
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str | Any: ) -> str | Any:
"""Handle non-streaming async message completion.""" """Handle non-streaming async message completion."""
if response_model:
structured_tool = {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
uses_file_api = _contains_file_id_reference(params.get("messages", [])) uses_file_api = _contains_file_id_reference(params.get("messages", []))
betas: list[str] = []
use_native_structured_output = False
if uses_file_api:
betas.append(ANTHROPIC_FILES_API_BETA)
if response_model:
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
params["output_format"] = {
"type": "json_schema",
"schema": response_model.model_json_schema(),
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
try: try:
if uses_file_api: if betas:
params["betas"] = [ANTHROPIC_FILES_API_BETA] params["betas"] = betas
response = await self.async_client.beta.messages.create(**params) response = await self.async_client.beta.messages.create(**params)
else: else:
response = await self.async_client.messages.create(**params) response = await self.async_client.messages.create(**params)
@@ -1031,22 +1122,33 @@ class AnthropicCompletion(BaseLLM):
self._track_token_usage_internal(usage) self._track_token_usage_internal(usage)
if response_model and response.content: if response_model and response.content:
tool_uses = [ if use_native_structured_output:
block for block in response.content if isinstance(block, ToolUseBlock) for block in response.content:
] if isinstance(block, TextBlock):
if tool_uses and tool_uses[0].name == "structured_output": structured_json = block.text
structured_data = tool_uses[0].input self._emit_call_completed_event(
structured_json = json.dumps(structured_data) response=structured_json,
call_type=LLMCallType.LLM_CALL,
self._emit_call_completed_event( from_task=from_task,
response=structured_json, from_agent=from_agent,
call_type=LLMCallType.LLM_CALL, messages=params["messages"],
from_task=from_task, )
from_agent=from_agent, return structured_json
messages=params["messages"], else:
) for block in response.content:
if (
return structured_json isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if response.content: if response.content:
tool_uses = [ tool_uses = [
@@ -1102,25 +1204,43 @@ class AnthropicCompletion(BaseLLM):
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str: ) -> str | Any:
"""Handle async streaming message completion.""" """Handle async streaming message completion."""
if response_model: betas: list[str] = []
structured_tool = { use_native_structured_output = False
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool] if response_model:
params["tool_choice"] = {"type": "tool", "name": "structured_output"} if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
params["output_format"] = {
"type": "json_schema",
"schema": response_model.model_json_schema(),
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
full_response = "" full_response = ""
stream_params = {k: v for k, v in params.items() if k != "stream"} stream_params = {k: v for k, v in params.items() if k != "stream"}
if betas:
stream_params["betas"] = betas
current_tool_calls: dict[int, dict[str, Any]] = {} current_tool_calls: dict[int, dict[str, Any]] = {}
async with self.async_client.messages.stream(**stream_params) as stream: stream_context = (
self.async_client.beta.messages.stream(**stream_params)
if betas
else self.async_client.messages.stream(**stream_params)
)
async with stream_context as stream:
response_id = None response_id = None
async for event in stream: async for event in stream:
if hasattr(event, "message") and hasattr(event.message, "id"): if hasattr(event, "message") and hasattr(event.message, "id"):
@@ -1133,7 +1253,7 @@ class AnthropicCompletion(BaseLLM):
chunk=text_delta, chunk=text_delta,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
response_id=response_id response_id=response_id,
) )
if event.type == "content_block_start": if event.type == "content_block_start":
@@ -1160,7 +1280,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
elif event.type == "content_block_delta": elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta": if event.delta.type == "input_json_delta":
@@ -1184,33 +1304,38 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
final_message: Message = await stream.get_final_message() final_message = await stream.get_final_message()
usage = self._extract_anthropic_token_usage(final_message) usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage) self._track_token_usage_internal(usage)
if response_model and final_message.content: if response_model:
tool_uses = [ if use_native_structured_output:
block
for block in final_message.content
if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
self._emit_call_completed_event( self._emit_call_completed_event(
response=structured_json, response=full_response,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
messages=params["messages"], messages=params["messages"],
) )
return full_response
return structured_json for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if final_message.content: if final_message.content:
tool_uses = [ tool_uses = [
@@ -1220,7 +1345,6 @@ class AnthropicCompletion(BaseLLM):
] ]
if tool_uses: if tool_uses:
# If no available_functions, return tool calls for executor to handle
if not available_functions: if not available_functions:
return list(tool_uses) return list(tool_uses)
@@ -1247,7 +1371,7 @@ class AnthropicCompletion(BaseLLM):
async def _ahandle_tool_use_conversation( async def _ahandle_tool_use_conversation(
self, self,
initial_response: Message, initial_response: Message | BetaMessage,
tool_uses: list[ToolUseBlock], tool_uses: list[ToolUseBlock],
params: dict[str, Any], params: dict[str, Any],
available_functions: dict[str, Any], available_functions: dict[str, Any],
@@ -1356,7 +1480,9 @@ class AnthropicCompletion(BaseLLM):
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO) return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
@staticmethod @staticmethod
def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]: def _extract_anthropic_token_usage(
response: Message | BetaMessage,
) -> dict[str, Any]:
"""Extract token usage from Anthropic response.""" """Extract token usage from Anthropic response."""
if hasattr(response, "usage") and response.usage: if hasattr(response, "usage") and response.usage:
usage = response.usage usage = response.usage

View File

@@ -172,6 +172,7 @@ class BedrockCompletion(BaseLLM):
additional_model_request_fields: dict[str, Any] | None = None, additional_model_request_fields: dict[str, Any] | None = None,
additional_model_response_field_paths: list[str] | None = None, additional_model_response_field_paths: list[str] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None, interceptor: BaseInterceptor[Any, Any] | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Initialize AWS Bedrock completion client. """Initialize AWS Bedrock completion client.
@@ -192,6 +193,8 @@ class BedrockCompletion(BaseLLM):
additional_model_request_fields: Model-specific request parameters additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths additional_model_response_field_paths: Custom response field paths
interceptor: HTTP interceptor (not yet supported for Bedrock). interceptor: HTTP interceptor (not yet supported for Bedrock).
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
**kwargs: Additional parameters **kwargs: Additional parameters
""" """
if interceptor is not None: if interceptor is not None:
@@ -248,6 +251,7 @@ class BedrockCompletion(BaseLLM):
self.top_k = top_k self.top_k = top_k
self.stream = stream self.stream = stream
self.stop_sequences = stop_sequences self.stop_sequences = stop_sequences
self.response_format = response_format
# Store advanced features (optional) # Store advanced features (optional)
self.guardrail_config = guardrail_config self.guardrail_config = guardrail_config
@@ -299,6 +303,8 @@ class BedrockCompletion(BaseLLM):
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str | Any: ) -> str | Any:
"""Call AWS Bedrock Converse API.""" """Call AWS Bedrock Converse API."""
effective_response_model = response_model or self.response_format
try: try:
# Emit call started event # Emit call started event
self._emit_call_started_event( self._emit_call_started_event(
@@ -375,6 +381,7 @@ class BedrockCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
effective_response_model,
) )
return self._handle_converse( return self._handle_converse(
@@ -383,6 +390,7 @@ class BedrockCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
effective_response_model,
) )
except Exception as e: except Exception as e:
@@ -425,6 +433,8 @@ class BedrockCompletion(BaseLLM):
NotImplementedError: If aiobotocore is not installed. NotImplementedError: If aiobotocore is not installed.
LLMContextLengthExceededError: If context window is exceeded. LLMContextLengthExceededError: If context window is exceeded.
""" """
effective_response_model = response_model or self.response_format
if not AIOBOTOCORE_AVAILABLE: if not AIOBOTOCORE_AVAILABLE:
raise NotImplementedError( raise NotImplementedError(
"Async support for AWS Bedrock requires aiobotocore. " "Async support for AWS Bedrock requires aiobotocore. "
@@ -494,11 +504,21 @@ class BedrockCompletion(BaseLLM):
if self.stream: if self.stream:
return await self._ahandle_streaming_converse( return await self._ahandle_streaming_converse(
formatted_messages, body, available_functions, from_task, from_agent formatted_messages,
body,
available_functions,
from_task,
from_agent,
effective_response_model,
) )
return await self._ahandle_converse( return await self._ahandle_converse(
formatted_messages, body, available_functions, from_task, from_agent formatted_messages,
body,
available_functions,
from_task,
from_agent,
effective_response_model,
) )
except Exception as e: except Exception as e:
@@ -520,10 +540,29 @@ class BedrockCompletion(BaseLLM):
available_functions: Mapping[str, Any] | None = None, available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
) -> str: response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming converse API call following AWS best practices.""" """Handle non-streaming converse API call following AWS best practices."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
try: try:
# Validate messages format before API call
if not messages: if not messages:
raise ValueError("Messages cannot be empty") raise ValueError("Messages cannot be empty")
@@ -571,6 +610,21 @@ class BedrockCompletion(BaseLLM):
# If there are tool uses but no available_functions, return them for the executor to handle # If there are tool uses but no available_functions, return them for the executor to handle
tool_uses = [block["toolUse"] for block in content if "toolUse" in block] tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
if response_model and tool_uses:
for tool_use in tool_uses:
if tool_use.get("name") == "structured_output":
structured_data = tool_use.get("input", {})
result = response_model.model_validate(structured_data)
self._emit_call_completed_event(
response=result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages,
)
return result
if tool_uses and not available_functions: if tool_uses and not available_functions:
self._emit_call_completed_event( self._emit_call_completed_event(
response=tool_uses, response=tool_uses,
@@ -717,8 +771,28 @@ class BedrockCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None, available_functions: dict[str, Any] | None = None,
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str: ) -> str:
"""Handle streaming converse API call with comprehensive event handling.""" """Handle streaming converse API call with comprehensive event handling."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
full_response = "" full_response = ""
current_tool_use: dict[str, Any] | None = None current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None tool_use_id: str | None = None
@@ -805,7 +879,7 @@ class BedrockCompletion(BaseLLM):
"index": tool_use_index, "index": tool_use_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
elif "contentBlockStop" in event: elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream") logging.debug("Content block stopped in stream")
@@ -929,8 +1003,28 @@ class BedrockCompletion(BaseLLM):
available_functions: Mapping[str, Any] | None = None, available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
) -> str: response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle async non-streaming converse API call.""" """Handle async non-streaming converse API call."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
try: try:
if not messages: if not messages:
raise ValueError("Messages cannot be empty") raise ValueError("Messages cannot be empty")
@@ -976,6 +1070,21 @@ class BedrockCompletion(BaseLLM):
# If there are tool uses but no available_functions, return them for the executor to handle # If there are tool uses but no available_functions, return them for the executor to handle
tool_uses = [block["toolUse"] for block in content if "toolUse" in block] tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
if response_model and tool_uses:
for tool_use in tool_uses:
if tool_use.get("name") == "structured_output":
structured_data = tool_use.get("input", {})
result = response_model.model_validate(structured_data)
self._emit_call_completed_event(
response=result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages,
)
return result
if tool_uses and not available_functions: if tool_uses and not available_functions:
self._emit_call_completed_event( self._emit_call_completed_event(
response=tool_uses, response=tool_uses,
@@ -1106,8 +1215,28 @@ class BedrockCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None, available_functions: dict[str, Any] | None = None,
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str: ) -> str:
"""Handle async streaming converse API call.""" """Handle async streaming converse API call."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
full_response = "" full_response = ""
current_tool_use: dict[str, Any] | None = None current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None tool_use_id: str | None = None
@@ -1174,7 +1303,7 @@ class BedrockCompletion(BaseLLM):
chunk=text_chunk, chunk=text_chunk,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
response_id=response_id response_id=response_id,
) )
elif "toolUse" in delta and current_tool_use: elif "toolUse" in delta and current_tool_use:
tool_input = delta["toolUse"].get("input", "") tool_input = delta["toolUse"].get("input", "")