Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-05-03 00:02:36 +00:00
feat: enhance BedrockCompletion class with advanced features
* feat: enhance BedrockCompletion class with advanced features and error handling
  - Added support for guardrail configuration, additional model request fields, and custom response field paths in the BedrockCompletion class.
  - Improved error handling for AWS exceptions and added token usage tracking with stop reason logging.
  - Enhanced streaming response handling with comprehensive event management, including tool use and content block processing.
  - Updated documentation to reflect new features and initialization parameters.
  - Introduced a new test suite for BedrockCompletion to validate functionality and ensure robust integration with AWS Bedrock APIs.

* chore: add boto typing

* fix: use typing_extensions.Required for Python 3.10 compatibility

---------

Co-authored-by: Greyson Lalonde <greyson.r.lalonde@gmail.com>
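For orientation, a minimal usage sketch of the new initialization parameters follows. The import path, model ID, region, and guardrail identifier are illustrative assumptions, not values taken from the repository; the parameter names follow the __init__ signature in the diff below.

# Illustrative sketch only: import path, model ID, region, and guardrail ID are assumptions.
from crewai.llms.providers.bedrock.completion import BedrockCompletion  # assumed module path

llm = BedrockCompletion(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",  # placeholder model ID
    region_name="us-east-1",  # placeholder region
    temperature=0.2,
    max_tokens=1024,
    stream=False,
    guardrail_config={
        "guardrailIdentifier": "my-guardrail-id",  # placeholder guardrail
        "guardrailVersion": "1",
    },
    additional_model_request_fields={"top_k": 250},  # model-specific extra field
    additional_model_response_field_paths=["/stop_sequence"],  # example custom response path
)

print(llm.call("Summarize this sprint's changes in one sentence."))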
@@ -1,7 +1,11 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
import logging
import os
from typing import Any
from typing import TYPE_CHECKING, Any, TypedDict, cast

from typing_extensions import Required

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
@@ -11,6 +15,20 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
)


if TYPE_CHECKING:
    from mypy_boto3_bedrock_runtime.type_defs import (
        GuardrailConfigurationTypeDef,
        GuardrailStreamConfigurationTypeDef,
        InferenceConfigurationTypeDef,
        MessageOutputTypeDef,
        MessageTypeDef,
        SystemContentBlockTypeDef,
        TokenUsageTypeDef,
        ToolConfigurationTypeDef,
        ToolTypeDef,
    )


try:
    from boto3.session import Session
    from botocore.config import Config
@@ -21,11 +39,104 @@ except ImportError:
    ) from None


if TYPE_CHECKING:

    class EnhancedInferenceConfigurationTypeDef(
        InferenceConfigurationTypeDef, total=False
    ):
        """Extended InferenceConfigurationTypeDef with topK support.

        AWS Bedrock supports topK for Claude models, but it's not in the boto3 type stubs.
        This extends the base type to include topK while maintaining all other fields.
        """

        topK: int # noqa: N815 - AWS API uses topK naming

else:

    class EnhancedInferenceConfigurationTypeDef(TypedDict, total=False):
        """Extended InferenceConfigurationTypeDef with topK support.

        AWS Bedrock supports topK for Claude models, but it's not in the boto3 type stubs.
        This extends the base type to include topK while maintaining all other fields.
        """

        maxTokens: int
        temperature: float
        topP: float # noqa: N815 - AWS API uses topP naming
        stopSequences: list[str]
        topK: int # noqa: N815 - AWS API uses topK naming


class ToolInputSchema(TypedDict):
    """Type definition for tool input schema in Converse API."""

    json: dict[str, Any]


class ToolSpec(TypedDict, total=False):
    """Type definition for tool specification in Converse API."""

    name: Required[str]
    description: Required[str]
    inputSchema: ToolInputSchema


class ConverseToolTypeDef(TypedDict):
    """Type definition for a Converse API tool."""

    toolSpec: ToolSpec


class BedrockConverseRequestBody(TypedDict, total=False):
    """Type definition for AWS Bedrock Converse API request body.

    Based on AWS Bedrock Converse API specification.
    """

    inferenceConfig: Required[EnhancedInferenceConfigurationTypeDef]
    system: list[SystemContentBlockTypeDef]
    toolConfig: ToolConfigurationTypeDef
    guardrailConfig: GuardrailConfigurationTypeDef
    additionalModelRequestFields: dict[str, Any]
    additionalModelResponseFieldPaths: list[str]


class BedrockConverseStreamRequestBody(TypedDict, total=False):
    """Type definition for AWS Bedrock Converse Stream API request body.

    Based on AWS Bedrock Converse Stream API specification.
    """

    inferenceConfig: Required[EnhancedInferenceConfigurationTypeDef]
    system: list[SystemContentBlockTypeDef]
    toolConfig: ToolConfigurationTypeDef
    guardrailConfig: GuardrailStreamConfigurationTypeDef
    additionalModelRequestFields: dict[str, Any]
    additionalModelResponseFieldPaths: list[str]


class BedrockCompletion(BaseLLM):
    """AWS Bedrock native completion implementation using the Converse API.

    This class provides direct integration with AWS Bedrock using the modern
    Converse API, which provides a unified interface across all Bedrock models.

    Features:
    - Full tool calling support with proper conversation continuation
    - Streaming and non-streaming responses with comprehensive event handling
    - Guardrail configuration for content filtering
    - Model-specific parameters via additionalModelRequestFields
    - Custom response field extraction
    - Proper error handling for all AWS exception types
    - Token usage tracking and stop reason logging
    - Support for both text and tool use content blocks

    The implementation follows AWS Bedrock Converse API best practices including:
    - Proper tool use ID tracking for multi-turn tool conversations
    - Complete streaming event handling (messageStart, contentBlockStart, etc.)
    - Response metadata and trace information capture
    - Model-specific conversation format handling (e.g., Cohere requirements)
    """

    def __init__(
@@ -41,9 +152,30 @@ class BedrockCompletion(BaseLLM):
        top_k: int | None = None,
        stop_sequences: Sequence[str] | None = None,
        stream: bool = False,
        guardrail_config: dict[str, Any] | None = None,
        additional_model_request_fields: dict[str, Any] | None = None,
        additional_model_response_field_paths: list[str] | None = None,
        **kwargs,
    ):
        """Initialize AWS Bedrock completion client."""
        """Initialize AWS Bedrock completion client.

        Args:
            model: The Bedrock model ID to use
            aws_access_key_id: AWS access key (defaults to environment variable)
            aws_secret_access_key: AWS secret key (defaults to environment variable)
            aws_session_token: AWS session token for temporary credentials
            region_name: AWS region name
            temperature: Sampling temperature for response generation
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter (Claude models only)
            stop_sequences: List of sequences that stop generation
            stream: Whether to use streaming responses
            guardrail_config: Guardrail configuration for content filtering
            additional_model_request_fields: Model-specific request parameters
            additional_model_response_field_paths: Custom response field paths
            **kwargs: Additional parameters
        """
        # Extract provider from kwargs to avoid duplicate argument
        kwargs.pop("provider", None)
@@ -66,7 +198,6 @@ class BedrockCompletion(BaseLLM):

        # Configure client with timeouts and retries following AWS best practices
        config = Config(
            connect_timeout=60,
            read_timeout=300,
            retries={
                "max_attempts": 3,
@@ -85,6 +216,13 @@ class BedrockCompletion(BaseLLM):
        self.stream = stream
        self.stop_sequences = stop_sequences or []

        # Store advanced features (optional)
        self.guardrail_config = guardrail_config
        self.additional_model_request_fields = additional_model_request_fields
        self.additional_model_response_field_paths = (
            additional_model_response_field_paths
        )

        # Model-specific settings
        self.is_claude_model = "claude" in model.lower()
        self.supports_tools = True # Converse API supports tools for most models
@@ -96,7 +234,7 @@ class BedrockCompletion(BaseLLM):
    def call(
        self,
        messages: str | list[dict[str, str]],
        tools: Sequence[Mapping[str, Any]] | None = None,
        tools: list[dict[Any, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
@@ -119,24 +257,45 @@ class BedrockCompletion(BaseLLM):
            messages
        )

        # Prepare tool configuration
        tool_config = None
        if tools:
            tool_config = {"tools": self._format_tools_for_converse(tools)}

        # Prepare request body
        body = {
        body: BedrockConverseRequestBody = {
            "inferenceConfig": self._get_inference_config(),
        }

        # Add system message if present
        if system_message:
            body["system"] = [{"text": system_message}]
            body["system"] = cast(
                "list[SystemContentBlockTypeDef]",
                cast(object, [{"text": system_message}]),
            )

        # Add tool config if present
        if tool_config:
        if tools:
            tool_config: ToolConfigurationTypeDef = {
                "tools": cast(
                    "Sequence[ToolTypeDef]",
                    cast(object, self._format_tools_for_converse(tools)),
                )
            }
            body["toolConfig"] = tool_config

        # Add optional advanced features if configured
        if self.guardrail_config:
            guardrail_config: GuardrailConfigurationTypeDef = cast(
                "GuardrailConfigurationTypeDef", cast(object, self.guardrail_config)
            )
            body["guardrailConfig"] = guardrail_config

        if self.additional_model_request_fields:
            body["additionalModelRequestFields"] = (
                self.additional_model_request_fields
            )

        if self.additional_model_response_field_paths:
            body["additionalModelResponseFieldPaths"] = (
                self.additional_model_response_field_paths
            )

        if self.stream:
            return self._handle_streaming_converse(
                formatted_messages, body, available_functions, from_task, from_agent
@@ -161,7 +320,7 @@ class BedrockCompletion(BaseLLM):
    def _handle_converse(
        self,
        messages: list[dict[str, Any]],
        body: dict[str, Any],
        body: BedrockConverseRequestBody,
        available_functions: Mapping[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
@@ -183,13 +342,26 @@ class BedrockCompletion(BaseLLM):

        # Call Bedrock Converse API with proper error handling
        response = self.client.converse(
            modelId=self.model_id, messages=messages, **body
            modelId=self.model_id,
            messages=cast(
                "Sequence[MessageTypeDef | MessageOutputTypeDef]",
                cast(object, messages),
            ),
            **body,
        )

        # Track token usage according to AWS response format
        if "usage" in response:
            self._track_token_usage_internal(response["usage"])

        stop_reason = response.get("stopReason")
        if stop_reason:
            logging.debug(f"Response stop reason: {stop_reason}")
            if stop_reason == "max_tokens":
                logging.warning("Response truncated due to max_tokens limit")
            elif stop_reason == "content_filtered":
                logging.warning("Response was filtered due to content policy")

        # Extract content following AWS response structure
        output = response.get("output", {})
        message = output.get("message", {})
@@ -201,28 +373,59 @@ class BedrockCompletion(BaseLLM):
                "I apologize, but I received an empty response. Please try again."
            )

        # Extract text content from response
        # Process content blocks and handle tool use correctly
        text_content = ""

        for content_block in content:
            # Handle different content block types as per AWS documentation
            # Handle text content
            if "text" in content_block:
                text_content += content_block["text"]
            elif content_block.get("type") == "toolUse" and available_functions:
                # Handle tool use according to AWS format
                tool_use = content_block["toolUse"]
                function_name = tool_use.get("name")
                function_args = tool_use.get("input", {})

                result = self._handle_tool_execution(
            # Handle tool use - corrected structure according to AWS API docs
            elif "toolUse" in content_block and available_functions:
                tool_use_block = content_block["toolUse"]
                tool_use_id = tool_use_block.get("toolUseId")
                function_name = tool_use_block["name"]
                function_args = tool_use_block.get("input", {})

                logging.debug(
                    f"Tool use requested: {function_name} with ID {tool_use_id}"
                )

                # Execute the tool
                tool_result = self._handle_tool_execution(
                    function_name=function_name,
                    function_args=function_args,
                    available_functions=available_functions,
                    available_functions=dict(available_functions),
                    from_task=from_task,
                    from_agent=from_agent,
                )

                if result is not None:
                    return result
                if tool_result is not None:
                    messages.append(
                        {
                            "role": "assistant",
                            "content": [{"toolUse": tool_use_block}],
                        }
                    )

                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "toolResult": {
                                        "toolUseId": tool_use_id,
                                        "content": [{"text": str(tool_result)}],
                                    }
                                }
                            ],
                        }
                    )

                    return self._handle_converse(
                        messages, body, available_functions, from_task, from_agent
                    )

        # Apply stop sequences if configured
        text_content = self._apply_stop_words(text_content)
@@ -298,23 +501,43 @@ class BedrockCompletion(BaseLLM):
    def _handle_streaming_converse(
        self,
        messages: list[dict[str, Any]],
        body: dict[str, Any],
        body: BedrockConverseRequestBody,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle streaming converse API call."""
        """Handle streaming converse API call with comprehensive event handling."""
        full_response = ""
        current_tool_use = None
        tool_use_id = None

        try:
            response = self.client.converse_stream(
                modelId=self.model_id, messages=messages, **body
                modelId=self.model_id,
                messages=cast(
                    "Sequence[MessageTypeDef | MessageOutputTypeDef]",
                    cast(object, messages),
                ),
                **body, # type: ignore[arg-type]
            )

            stream = response.get("stream")
            if stream:
                for event in stream:
                    if "contentBlockDelta" in event:
                    if "messageStart" in event:
                        role = event["messageStart"].get("role")
                        logging.debug(f"Streaming message started with role: {role}")

                    elif "contentBlockStart" in event:
                        start = event["contentBlockStart"].get("start", {})
                        if "toolUse" in start:
                            current_tool_use = start["toolUse"]
                            tool_use_id = current_tool_use.get("toolUseId")
                            logging.debug(
                                f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
                            )

                    elif "contentBlockDelta" in event:
                        delta = event["contentBlockDelta"]["delta"]
                        if "text" in delta:
                            text_chunk = delta["text"]
@@ -325,10 +548,93 @@ class BedrockCompletion(BaseLLM):
                                from_task=from_task,
                                from_agent=from_agent,
                            )
                        elif "toolUse" in delta and current_tool_use:
                            tool_input = delta["toolUse"].get("input", "")
                            if tool_input:
                                logging.debug(f"Tool input delta: {tool_input}")

                    # Content block stop - end of a content block
                    elif "contentBlockStop" in event:
                        logging.debug("Content block stopped in stream")
                        # If we were accumulating a tool use, it's now complete
                        if current_tool_use and available_functions:
                            function_name = current_tool_use["name"]
                            function_args = cast(
                                dict[str, Any], current_tool_use.get("input", {})
                            )

                            # Execute tool
                            tool_result = self._handle_tool_execution(
                                function_name=function_name,
                                function_args=function_args,
                                available_functions=available_functions,
                                from_task=from_task,
                                from_agent=from_agent,
                            )

                            if tool_result is not None and tool_use_id:
                                # Continue conversation with tool result
                                messages.append(
                                    {
                                        "role": "assistant",
                                        "content": [{"toolUse": current_tool_use}],
                                    }
                                )

                                messages.append(
                                    {
                                        "role": "user",
                                        "content": [
                                            {
                                                "toolResult": {
                                                    "toolUseId": tool_use_id,
                                                    "content": [
                                                        {"text": str(tool_result)}
                                                    ],
                                                }
                                            }
                                        ],
                                    }
                                )

                                # Recursive call - note this switches to non-streaming
                                return self._handle_converse(
                                    messages,
                                    body,
                                    available_functions,
                                    from_task,
                                    from_agent,
                                )

                            current_tool_use = None
                            tool_use_id = None

                    # Message stop - end of entire message
                    elif "messageStop" in event:
                        # Handle end of message
                        stop_reason = event["messageStop"].get("stopReason")
                        logging.debug(f"Streaming message stopped: {stop_reason}")
                        if stop_reason == "max_tokens":
                            logging.warning(
                                "Streaming response truncated due to max_tokens"
                            )
                        elif stop_reason == "content_filtered":
                            logging.warning(
                                "Streaming response filtered due to content policy"
                            )
                        break

                    # Metadata - contains usage information and trace details
                    elif "metadata" in event:
                        metadata = event["metadata"]
                        if "usage" in metadata:
                            usage_metrics = metadata["usage"]
                            self._track_token_usage_internal(usage_metrics)
                            logging.debug(f"Token usage: {usage_metrics}")
                        if "trace" in metadata:
                            logging.debug(
                                f"Trace information available: {metadata['trace']}"
                            )

        except ClientError as e:
            error_msg = self._handle_client_error(e)
            raise RuntimeError(error_msg) from e
@@ -430,25 +736,27 @@ class BedrockCompletion(BaseLLM):

        return converse_messages, system_message

    def _format_tools_for_converse(self, tools: list[dict]) -> list[dict]:
    @staticmethod
    def _format_tools_for_converse(tools: list[dict]) -> list[ConverseToolTypeDef]:
        """Convert CrewAI tools to Converse API format following AWS specification."""
        from crewai.llms.providers.utils.common import safe_tool_conversion

        converse_tools = []
        converse_tools: list[ConverseToolTypeDef] = []

        for tool in tools:
            try:
                name, description, parameters = safe_tool_conversion(tool, "Bedrock")

                converse_tool = {
                    "toolSpec": {
                        "name": name,
                        "description": description,
                    }
                tool_spec: ToolSpec = {
                    "name": name,
                    "description": description,
                }

                if parameters and isinstance(parameters, dict):
                    converse_tool["toolSpec"]["inputSchema"] = {"json": parameters}
                    input_schema: ToolInputSchema = {"json": parameters}
                    tool_spec["inputSchema"] = input_schema

                converse_tool: ConverseToolTypeDef = {"toolSpec": tool_spec}

                converse_tools.append(converse_tool)
@@ -460,9 +768,9 @@ class BedrockCompletion(BaseLLM):

        return converse_tools

    def _get_inference_config(self) -> dict[str, Any]:
    def _get_inference_config(self) -> EnhancedInferenceConfigurationTypeDef:
        """Get inference configuration following AWS Converse API specification."""
        config = {}
        config: EnhancedInferenceConfigurationTypeDef = {}

        if self.max_tokens:
            config["maxTokens"] = self.max_tokens
@@ -503,7 +811,7 @@ class BedrockCompletion(BaseLLM):

        return full_error_msg

    def _track_token_usage_internal(self, usage: dict[str, Any]) -> None:
    def _track_token_usage_internal(self, usage: TokenUsageTypeDef) -> None: # type: ignore[override]
        """Track token usage from Bedrock response."""
        input_tokens = usage.get("inputTokens", 0)
        output_tokens = usage.get("outputTokens", 0)
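The commit message above mentions a new test suite for BedrockCompletion. As a rough sketch of the kind of check such a suite might perform against the non-streaming path (this is not the repository's actual test code; the import path and response fixture shape are assumptions based on the diff above):

# Sketch only: import path and fixture shape are assumptions, not repository code.
from unittest.mock import MagicMock

from crewai.llms.providers.bedrock.completion import BedrockCompletion  # assumed module path


def test_call_returns_text_and_tracks_usage():
    llm = BedrockCompletion(
        model="anthropic.claude-3-5-sonnet-20240620-v1:0",  # placeholder model ID
        region_name="us-east-1",  # placeholder region
    )
    # Stub the boto3 client so no real AWS call is made; the fixture mimics a
    # Converse API response with a text content block, usage, and stop reason.
    llm.client = MagicMock()
    llm.client.converse.return_value = {
        "output": {"message": {"content": [{"text": "Hello from Bedrock"}]}},
        "usage": {"inputTokens": 12, "outputTokens": 5, "totalTokens": 17},
        "stopReason": "end_turn",
    }

    result = llm.call("Say hello")

    assert "Hello from Bedrock" in result
    llm.client.converse.assert_called_once()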