feat: enhance BedrockCompletion class with advanced features

* feat: enhance BedrockCompletion class with advanced features and error handling

- Added support for guardrail configuration, additional model request fields, and custom response field paths in the BedrockCompletion class (see the usage sketch below).
- Improved error handling for AWS exceptions and added token usage tracking with stop reason logging.
- Enhanced streaming response handling with comprehensive event management, including tool use and content block processing.
- Updated documentation to reflect new features and initialization parameters.
- Introduced a new test suite for BedrockCompletion to validate functionality and ensure robust integration with AWS Bedrock APIs.
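
For illustration, a minimal initialization sketch using the new parameters (the import path, model ID, guardrail ID, and all values are assumptions for this example, not defaults introduced by this change):

    from crewai.llms.providers.bedrock.completion import BedrockCompletion  # path assumed

    llm = BedrockCompletion(
        model="anthropic.claude-3-5-sonnet-20240620-v1:0",  # any Converse-capable model ID
        region_name="us-east-1",
        temperature=0.7,
        max_tokens=1024,
        top_k=40,  # Claude models only; forwarded as the topK inference field
        guardrail_config={
            "guardrailIdentifier": "my-guardrail-id",  # hypothetical guardrail
            "guardrailVersion": "1",
        },
        additional_model_response_field_paths=["/stopReason"],
    )
    print(llm.call("Summarize the Converse API in one sentence."))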

* chore: add boto typing

* fix: use typing_extensions.Required for Python 3.10 compatibility
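
typing.Required only entered the standard library in Python 3.11, so on 3.10 it has to come from typing_extensions. Roughly, as used by the new TypedDicts in this change:

    from typing import TypedDict
    from typing_extensions import Required  # typing.Required is 3.11+

    class ToolSpec(TypedDict, total=False):
        name: Required[str]  # must be present even though total=False
        description: Required[str]
        inputSchema: dict  # optional; the real type in this change is ToolInputSchema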

---------

Co-authored-by: Greyson Lalonde <greyson.r.lalonde@gmail.com>
Authored by Lorenze Jay on 2025-10-17 08:30:35 -07:00, committed by GitHub
parent 02d7ce7621, commit 3b32793e78
4 changed files with 1145 additions and 41 deletions


@@ -1,7 +1,11 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
import logging
import os
-from typing import Any
from typing import TYPE_CHECKING, Any, TypedDict, cast
from typing_extensions import Required
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
@@ -11,6 +15,20 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
)
if TYPE_CHECKING:
from mypy_boto3_bedrock_runtime.type_defs import (
GuardrailConfigurationTypeDef,
GuardrailStreamConfigurationTypeDef,
InferenceConfigurationTypeDef,
MessageOutputTypeDef,
MessageTypeDef,
SystemContentBlockTypeDef,
TokenUsageTypeDef,
ToolConfigurationTypeDef,
ToolTypeDef,
)
try:
from boto3.session import Session
from botocore.config import Config
@@ -21,11 +39,104 @@ except ImportError:
) from None
if TYPE_CHECKING:
class EnhancedInferenceConfigurationTypeDef(
InferenceConfigurationTypeDef, total=False
):
"""Extended InferenceConfigurationTypeDef with topK support.
AWS Bedrock supports topK for Claude models, but it's not in the boto3 type stubs.
This extends the base type to include topK while maintaining all other fields.
"""
topK: int # noqa: N815 - AWS API uses topK naming
else:
class EnhancedInferenceConfigurationTypeDef(TypedDict, total=False):
"""Extended InferenceConfigurationTypeDef with topK support.
AWS Bedrock supports topK for Claude models, but it's not in the boto3 type stubs.
This extends the base type to include topK while maintaining all other fields.
"""
maxTokens: int
temperature: float
topP: float # noqa: N815 - AWS API uses topP naming
stopSequences: list[str]
topK: int # noqa: N815 - AWS API uses topK naming
class ToolInputSchema(TypedDict):
"""Type definition for tool input schema in Converse API."""
json: dict[str, Any]
class ToolSpec(TypedDict, total=False):
"""Type definition for tool specification in Converse API."""
name: Required[str]
description: Required[str]
inputSchema: ToolInputSchema
class ConverseToolTypeDef(TypedDict):
"""Type definition for a Converse API tool."""
toolSpec: ToolSpec
class BedrockConverseRequestBody(TypedDict, total=False):
"""Type definition for AWS Bedrock Converse API request body.
Based on AWS Bedrock Converse API specification.
"""
inferenceConfig: Required[EnhancedInferenceConfigurationTypeDef]
system: list[SystemContentBlockTypeDef]
toolConfig: ToolConfigurationTypeDef
guardrailConfig: GuardrailConfigurationTypeDef
additionalModelRequestFields: dict[str, Any]
additionalModelResponseFieldPaths: list[str]
class BedrockConverseStreamRequestBody(TypedDict, total=False):
"""Type definition for AWS Bedrock Converse Stream API request body.
Based on AWS Bedrock Converse Stream API specification.
"""
inferenceConfig: Required[EnhancedInferenceConfigurationTypeDef]
system: list[SystemContentBlockTypeDef]
toolConfig: ToolConfigurationTypeDef
guardrailConfig: GuardrailStreamConfigurationTypeDef
additionalModelRequestFields: dict[str, Any]
additionalModelResponseFieldPaths: list[str]
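# For orientation, a populated request body under these definitions might look like
# the following (an illustrative sketch; the values are assumptions, not taken from
# this change):
#
#     body: BedrockConverseRequestBody = {
#         "inferenceConfig": {"maxTokens": 1024, "temperature": 0.7, "topK": 40},
#         "system": [{"text": "You are a helpful assistant."}],
#         "additionalModelRequestFields": {"top_k": 250},
#     }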
class BedrockCompletion(BaseLLM):
"""AWS Bedrock native completion implementation using the Converse API.
This class provides direct integration with AWS Bedrock using the modern
Converse API, which exposes a unified interface across all Bedrock models.
Features:
- Full tool calling support with proper conversation continuation
- Streaming and non-streaming responses with comprehensive event handling
- Guardrail configuration for content filtering
- Model-specific parameters via additionalModelRequestFields
- Custom response field extraction
- Proper error handling for all AWS exception types
- Token usage tracking and stop reason logging
- Support for both text and tool use content blocks
The implementation follows AWS Bedrock Converse API best practices including:
- Proper tool use ID tracking for multi-turn tool conversations
- Complete streaming event handling (messageStart, contentBlockStart, etc.)
- Response metadata and trace information capture
- Model-specific conversation format handling (e.g., Cohere requirements)
"""
def __init__(
@@ -41,9 +152,30 @@ class BedrockCompletion(BaseLLM):
top_k: int | None = None,
stop_sequences: Sequence[str] | None = None,
stream: bool = False,
guardrail_config: dict[str, Any] | None = None,
additional_model_request_fields: dict[str, Any] | None = None,
additional_model_response_field_paths: list[str] | None = None,
**kwargs,
):
"""Initialize AWS Bedrock completion client."""
"""Initialize AWS Bedrock completion client.
Args:
model: The Bedrock model ID to use
aws_access_key_id: AWS access key (defaults to environment variable)
aws_secret_access_key: AWS secret key (defaults to environment variable)
aws_session_token: AWS session token for temporary credentials
region_name: AWS region name
temperature: Sampling temperature for response generation
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter (Claude models only)
stop_sequences: List of sequences that stop generation
stream: Whether to use streaming responses
guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
**kwargs: Additional parameters
"""
# Extract provider from kwargs to avoid duplicate argument
kwargs.pop("provider", None)
@@ -66,7 +198,6 @@ class BedrockCompletion(BaseLLM):
# Configure client with timeouts and retries following AWS best practices
config = Config(
connect_timeout=60,
read_timeout=300,
retries={
"max_attempts": 3,
@@ -85,6 +216,13 @@ class BedrockCompletion(BaseLLM):
self.stream = stream
self.stop_sequences = stop_sequences or []
# Store advanced features (optional)
self.guardrail_config = guardrail_config
self.additional_model_request_fields = additional_model_request_fields
self.additional_model_response_field_paths = (
additional_model_response_field_paths
)
# Model-specific settings
self.is_claude_model = "claude" in model.lower()
self.supports_tools = True # Converse API supports tools for most models
@@ -96,7 +234,7 @@ class BedrockCompletion(BaseLLM):
def call(
self,
messages: str | list[dict[str, str]],
-tools: Sequence[Mapping[str, Any]] | None = None,
tools: list[dict[Any, Any]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -119,24 +257,45 @@ class BedrockCompletion(BaseLLM):
messages
)
-# Prepare tool configuration
-tool_config = None
-if tools:
-tool_config = {"tools": self._format_tools_for_converse(tools)}
# Prepare request body
-body = {
body: BedrockConverseRequestBody = {
"inferenceConfig": self._get_inference_config(),
}
# Add system message if present
if system_message:
body["system"] = [{"text": system_message}]
body["system"] = cast(
"list[SystemContentBlockTypeDef]",
cast(object, [{"text": system_message}]),
)
# Add tool config if present
-if tool_config:
if tools:
tool_config: ToolConfigurationTypeDef = {
"tools": cast(
"Sequence[ToolTypeDef]",
cast(object, self._format_tools_for_converse(tools)),
)
}
body["toolConfig"] = tool_config
# Add optional advanced features if configured
if self.guardrail_config:
guardrail_config: GuardrailConfigurationTypeDef = cast(
"GuardrailConfigurationTypeDef", cast(object, self.guardrail_config)
)
body["guardrailConfig"] = guardrail_config
if self.additional_model_request_fields:
body["additionalModelRequestFields"] = (
self.additional_model_request_fields
)
if self.additional_model_response_field_paths:
body["additionalModelResponseFieldPaths"] = (
self.additional_model_response_field_paths
)
if self.stream:
return self._handle_streaming_converse(
formatted_messages, body, available_functions, from_task, from_agent
@@ -161,7 +320,7 @@ class BedrockCompletion(BaseLLM):
def _handle_converse(
self,
messages: list[dict[str, Any]],
-body: dict[str, Any],
body: BedrockConverseRequestBody,
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
@@ -183,13 +342,26 @@ class BedrockCompletion(BaseLLM):
# Call Bedrock Converse API with proper error handling
response = self.client.converse(
-modelId=self.model_id, messages=messages, **body
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
cast(object, messages),
),
**body,
)
# Track token usage according to AWS response format
if "usage" in response:
self._track_token_usage_internal(response["usage"])
stop_reason = response.get("stopReason")
if stop_reason:
logging.debug(f"Response stop reason: {stop_reason}")
if stop_reason == "max_tokens":
logging.warning("Response truncated due to max_tokens limit")
elif stop_reason == "content_filtered":
logging.warning("Response was filtered due to content policy")
# Extract content following AWS response structure
output = response.get("output", {})
message = output.get("message", {})
@@ -201,28 +373,59 @@ class BedrockCompletion(BaseLLM):
"I apologize, but I received an empty response. Please try again."
)
-# Extract text content from response
# Process content blocks and handle tool use correctly
text_content = ""
for content_block in content:
-# Handle different content block types as per AWS documentation
# Handle text content
if "text" in content_block:
text_content += content_block["text"]
elif content_block.get("type") == "toolUse" and available_functions:
# Handle tool use according to AWS format
tool_use = content_block["toolUse"]
function_name = tool_use.get("name")
function_args = tool_use.get("input", {})
result = self._handle_tool_execution(
# Handle tool use - corrected structure according to AWS API docs
elif "toolUse" in content_block and available_functions:
tool_use_block = content_block["toolUse"]
tool_use_id = tool_use_block.get("toolUseId")
function_name = tool_use_block["name"]
function_args = tool_use_block.get("input", {})
logging.debug(
f"Tool use requested: {function_name} with ID {tool_use_id}"
)
# Execute the tool
tool_result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
-available_functions=available_functions,
available_functions=dict(available_functions),
from_task=from_task,
from_agent=from_agent,
)
-if result is not None:
-return result
if tool_result is not None:
messages.append(
{
"role": "assistant",
"content": [{"toolUse": tool_use_block}],
}
)
messages.append(
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": tool_use_id,
"content": [{"text": str(tool_result)}],
}
}
],
}
)
return self._handle_converse(
messages, body, available_functions, from_task, from_agent
)
# Apply stop sequences if configured
text_content = self._apply_stop_words(text_content)
@@ -298,23 +501,43 @@ class BedrockCompletion(BaseLLM):
def _handle_streaming_converse(
self,
messages: list[dict[str, Any]],
-body: dict[str, Any],
body: BedrockConverseRequestBody,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle streaming converse API call."""
"""Handle streaming converse API call with comprehensive event handling."""
full_response = ""
current_tool_use = None
tool_use_id = None
try:
response = self.client.converse_stream(
-modelId=self.model_id, messages=messages, **body
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
cast(object, messages),
),
**body, # type: ignore[arg-type]
)
stream = response.get("stream")
if stream:
for event in stream:
if "contentBlockDelta" in event:
if "messageStart" in event:
role = event["messageStart"].get("role")
logging.debug(f"Streaming message started with role: {role}")
elif "contentBlockStart" in event:
start = event["contentBlockStart"].get("start", {})
if "toolUse" in start:
current_tool_use = start["toolUse"]
tool_use_id = current_tool_use.get("toolUseId")
logging.debug(
f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
)
elif "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
if "text" in delta:
text_chunk = delta["text"]
@@ -325,10 +548,93 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
)
elif "toolUse" in delta and current_tool_use:
tool_input = delta["toolUse"].get("input", "")
if tool_input:
logging.debug(f"Tool input delta: {tool_input}")
# Content block stop - end of a content block
elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream")
# If we were accumulating a tool use, it's now complete
if current_tool_use and available_functions:
function_name = current_tool_use["name"]
function_args = cast(
dict[str, Any], current_tool_use.get("input", {})
)
# Execute tool
tool_result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if tool_result is not None and tool_use_id:
# Continue conversation with tool result
messages.append(
{
"role": "assistant",
"content": [{"toolUse": current_tool_use}],
}
)
messages.append(
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": tool_use_id,
"content": [
{"text": str(tool_result)}
],
}
}
],
}
)
# Recursive call - note this switches to non-streaming
return self._handle_converse(
messages,
body,
available_functions,
from_task,
from_agent,
)
current_tool_use = None
tool_use_id = None
# Message stop - end of entire message
elif "messageStop" in event:
# Handle end of message
stop_reason = event["messageStop"].get("stopReason")
logging.debug(f"Streaming message stopped: {stop_reason}")
if stop_reason == "max_tokens":
logging.warning(
"Streaming response truncated due to max_tokens"
)
elif stop_reason == "content_filtered":
logging.warning(
"Streaming response filtered due to content policy"
)
break
# Metadata - contains usage information and trace details
elif "metadata" in event:
metadata = event["metadata"]
if "usage" in metadata:
usage_metrics = metadata["usage"]
self._track_token_usage_internal(usage_metrics)
logging.debug(f"Token usage: {usage_metrics}")
if "trace" in metadata:
logging.debug(
f"Trace information available: {metadata['trace']}"
)
except ClientError as e:
error_msg = self._handle_client_error(e)
raise RuntimeError(error_msg) from e
@@ -430,25 +736,27 @@ class BedrockCompletion(BaseLLM):
return converse_messages, system_message
-def _format_tools_for_converse(self, tools: list[dict]) -> list[dict]:
@staticmethod
def _format_tools_for_converse(tools: list[dict]) -> list[ConverseToolTypeDef]:
"""Convert CrewAI tools to Converse API format following AWS specification."""
from crewai.llms.providers.utils.common import safe_tool_conversion
-converse_tools = []
converse_tools: list[ConverseToolTypeDef] = []
for tool in tools:
try:
name, description, parameters = safe_tool_conversion(tool, "Bedrock")
-converse_tool = {
-"toolSpec": {
-"name": name,
-"description": description,
-}
-}
tool_spec: ToolSpec = {
"name": name,
"description": description,
}
if parameters and isinstance(parameters, dict):
converse_tool["toolSpec"]["inputSchema"] = {"json": parameters}
input_schema: ToolInputSchema = {"json": parameters}
tool_spec["inputSchema"] = input_schema
converse_tool: ConverseToolTypeDef = {"toolSpec": tool_spec}
converse_tools.append(converse_tool)
@@ -460,9 +768,9 @@ class BedrockCompletion(BaseLLM):
return converse_tools
-def _get_inference_config(self) -> dict[str, Any]:
def _get_inference_config(self) -> EnhancedInferenceConfigurationTypeDef:
"""Get inference configuration following AWS Converse API specification."""
-config = {}
config: EnhancedInferenceConfigurationTypeDef = {}
if self.max_tokens:
config["maxTokens"] = self.max_tokens
@@ -503,7 +811,7 @@ class BedrockCompletion(BaseLLM):
return full_error_msg
-def _track_token_usage_internal(self, usage: dict[str, Any]) -> None:
def _track_token_usage_internal(self, usage: TokenUsageTypeDef) -> None: # type: ignore[override]
"""Track token usage from Bedrock response."""
input_tokens = usage.get("inputTokens", 0)
output_tokens = usage.get("outputTokens", 0)