Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-09 16:18:30 +00:00
feat: enhance BedrockCompletion class with advanced features
* feat: enhance BedrockCompletion class with advanced features and error handling
  - Added support for guardrail configuration, additional model request fields, and custom response field paths in the BedrockCompletion class.
  - Improved error handling for AWS exceptions and added token usage tracking with stop-reason logging.
  - Enhanced streaming response handling with comprehensive event management, including tool use and content block processing.
  - Updated documentation to reflect new features and initialization parameters.
  - Introduced a new test suite for BedrockCompletion to validate functionality and ensure robust integration with AWS Bedrock APIs.
* chore: add boto typing
* fix: use typing_extensions.Required for Python 3.10 compatibility

---------

Co-authored-by: Greyson Lalonde <greyson.r.lalonde@gmail.com>
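A minimal usage sketch of the new initialization parameters, assuming the crewAI `LLM` wrapper forwards provider-specific kwargs to `BedrockCompletion` (as the tests below do for `top_k` and `region_name`); the guardrail identifier/version and the extra request fields are illustrative placeholders, not values from this commit:

    from crewai.llm import LLM

    # Hypothetical values: replace the guardrail ID/version with a guardrail
    # configured in your own AWS account; top_k passthrough applies to Claude models.
    llm = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        temperature=0.7,
        max_tokens=2000,
        top_k=40,
        guardrail_config={
            "guardrailIdentifier": "my-guardrail-id",
            "guardrailVersion": "1",
        },
        additional_model_request_fields={"top_k": 40},
        additional_model_response_field_paths=["/stopReason"],
    )

    print(llm.call("Say hello from Bedrock."))  # requires valid AWS credentials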
@@ -1,7 +1,11 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
import logging
import os
from typing import Any
from typing import TYPE_CHECKING, Any, TypedDict, cast

from typing_extensions import Required

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
@@ -11,6 +15,20 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
)


if TYPE_CHECKING:
    from mypy_boto3_bedrock_runtime.type_defs import (
        GuardrailConfigurationTypeDef,
        GuardrailStreamConfigurationTypeDef,
        InferenceConfigurationTypeDef,
        MessageOutputTypeDef,
        MessageTypeDef,
        SystemContentBlockTypeDef,
        TokenUsageTypeDef,
        ToolConfigurationTypeDef,
        ToolTypeDef,
    )


try:
    from boto3.session import Session
    from botocore.config import Config
@@ -21,11 +39,104 @@ except ImportError:
    ) from None


if TYPE_CHECKING:

    class EnhancedInferenceConfigurationTypeDef(
        InferenceConfigurationTypeDef, total=False
    ):
        """Extended InferenceConfigurationTypeDef with topK support.

        AWS Bedrock supports topK for Claude models, but it's not in the boto3 type stubs.
        This extends the base type to include topK while maintaining all other fields.
        """

        topK: int  # noqa: N815 - AWS API uses topK naming

else:

    class EnhancedInferenceConfigurationTypeDef(TypedDict, total=False):
        """Extended InferenceConfigurationTypeDef with topK support.

        AWS Bedrock supports topK for Claude models, but it's not in the boto3 type stubs.
        This extends the base type to include topK while maintaining all other fields.
        """

        maxTokens: int
        temperature: float
        topP: float  # noqa: N815 - AWS API uses topP naming
        stopSequences: list[str]
        topK: int  # noqa: N815 - AWS API uses topK naming


class ToolInputSchema(TypedDict):
    """Type definition for tool input schema in Converse API."""

    json: dict[str, Any]


class ToolSpec(TypedDict, total=False):
    """Type definition for tool specification in Converse API."""

    name: Required[str]
    description: Required[str]
    inputSchema: ToolInputSchema


class ConverseToolTypeDef(TypedDict):
    """Type definition for a Converse API tool."""

    toolSpec: ToolSpec


class BedrockConverseRequestBody(TypedDict, total=False):
    """Type definition for AWS Bedrock Converse API request body.

    Based on AWS Bedrock Converse API specification.
    """

    inferenceConfig: Required[EnhancedInferenceConfigurationTypeDef]
    system: list[SystemContentBlockTypeDef]
    toolConfig: ToolConfigurationTypeDef
    guardrailConfig: GuardrailConfigurationTypeDef
    additionalModelRequestFields: dict[str, Any]
    additionalModelResponseFieldPaths: list[str]


class BedrockConverseStreamRequestBody(TypedDict, total=False):
    """Type definition for AWS Bedrock Converse Stream API request body.

    Based on AWS Bedrock Converse Stream API specification.
    """

    inferenceConfig: Required[EnhancedInferenceConfigurationTypeDef]
    system: list[SystemContentBlockTypeDef]
    toolConfig: ToolConfigurationTypeDef
    guardrailConfig: GuardrailStreamConfigurationTypeDef
    additionalModelRequestFields: dict[str, Any]
    additionalModelResponseFieldPaths: list[str]


class BedrockCompletion(BaseLLM):
    """AWS Bedrock native completion implementation using the Converse API.

    This class provides direct integration with AWS Bedrock using the modern
    Converse API, which provides a unified interface across all Bedrock models.

    Features:
    - Full tool calling support with proper conversation continuation
    - Streaming and non-streaming responses with comprehensive event handling
    - Guardrail configuration for content filtering
    - Model-specific parameters via additionalModelRequestFields
    - Custom response field extraction
    - Proper error handling for all AWS exception types
    - Token usage tracking and stop reason logging
    - Support for both text and tool use content blocks

    The implementation follows AWS Bedrock Converse API best practices including:
    - Proper tool use ID tracking for multi-turn tool conversations
    - Complete streaming event handling (messageStart, contentBlockStart, etc.)
    - Response metadata and trace information capture
    - Model-specific conversation format handling (e.g., Cohere requirements)
    """

    def __init__(
@@ -41,9 +152,30 @@ class BedrockCompletion(BaseLLM):
        top_k: int | None = None,
        stop_sequences: Sequence[str] | None = None,
        stream: bool = False,
        guardrail_config: dict[str, Any] | None = None,
        additional_model_request_fields: dict[str, Any] | None = None,
        additional_model_response_field_paths: list[str] | None = None,
        **kwargs,
    ):
        """Initialize AWS Bedrock completion client."""
        """Initialize AWS Bedrock completion client.

        Args:
            model: The Bedrock model ID to use
            aws_access_key_id: AWS access key (defaults to environment variable)
            aws_secret_access_key: AWS secret key (defaults to environment variable)
            aws_session_token: AWS session token for temporary credentials
            region_name: AWS region name
            temperature: Sampling temperature for response generation
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter (Claude models only)
            stop_sequences: List of sequences that stop generation
            stream: Whether to use streaming responses
            guardrail_config: Guardrail configuration for content filtering
            additional_model_request_fields: Model-specific request parameters
            additional_model_response_field_paths: Custom response field paths
            **kwargs: Additional parameters
        """
        # Extract provider from kwargs to avoid duplicate argument
        kwargs.pop("provider", None)

@@ -66,7 +198,6 @@ class BedrockCompletion(BaseLLM):

        # Configure client with timeouts and retries following AWS best practices
        config = Config(
            connect_timeout=60,
            read_timeout=300,
            retries={
                "max_attempts": 3,
@@ -85,6 +216,13 @@ class BedrockCompletion(BaseLLM):
        self.stream = stream
        self.stop_sequences = stop_sequences or []

        # Store advanced features (optional)
        self.guardrail_config = guardrail_config
        self.additional_model_request_fields = additional_model_request_fields
        self.additional_model_response_field_paths = (
            additional_model_response_field_paths
        )

        # Model-specific settings
        self.is_claude_model = "claude" in model.lower()
        self.supports_tools = True  # Converse API supports tools for most models
@@ -96,7 +234,7 @@ class BedrockCompletion(BaseLLM):
    def call(
        self,
        messages: str | list[dict[str, str]],
        tools: Sequence[Mapping[str, Any]] | None = None,
        tools: list[dict[Any, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
@@ -119,24 +257,45 @@ class BedrockCompletion(BaseLLM):
            messages
        )

        # Prepare tool configuration
        tool_config = None
        if tools:
            tool_config = {"tools": self._format_tools_for_converse(tools)}

        # Prepare request body
        body = {
        body: BedrockConverseRequestBody = {
            "inferenceConfig": self._get_inference_config(),
        }

        # Add system message if present
        if system_message:
            body["system"] = [{"text": system_message}]
            body["system"] = cast(
                "list[SystemContentBlockTypeDef]",
                cast(object, [{"text": system_message}]),
            )

        # Add tool config if present
        if tool_config:
        if tools:
            tool_config: ToolConfigurationTypeDef = {
                "tools": cast(
                    "Sequence[ToolTypeDef]",
                    cast(object, self._format_tools_for_converse(tools)),
                )
            }
            body["toolConfig"] = tool_config

        # Add optional advanced features if configured
        if self.guardrail_config:
            guardrail_config: GuardrailConfigurationTypeDef = cast(
                "GuardrailConfigurationTypeDef", cast(object, self.guardrail_config)
            )
            body["guardrailConfig"] = guardrail_config

        if self.additional_model_request_fields:
            body["additionalModelRequestFields"] = (
                self.additional_model_request_fields
            )

        if self.additional_model_response_field_paths:
            body["additionalModelResponseFieldPaths"] = (
                self.additional_model_response_field_paths
            )

        if self.stream:
            return self._handle_streaming_converse(
                formatted_messages, body, available_functions, from_task, from_agent
@@ -161,7 +320,7 @@ class BedrockCompletion(BaseLLM):
    def _handle_converse(
        self,
        messages: list[dict[str, Any]],
        body: dict[str, Any],
        body: BedrockConverseRequestBody,
        available_functions: Mapping[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
@@ -183,13 +342,26 @@ class BedrockCompletion(BaseLLM):

        # Call Bedrock Converse API with proper error handling
        response = self.client.converse(
            modelId=self.model_id, messages=messages, **body
            modelId=self.model_id,
            messages=cast(
                "Sequence[MessageTypeDef | MessageOutputTypeDef]",
                cast(object, messages),
            ),
            **body,
        )

        # Track token usage according to AWS response format
        if "usage" in response:
            self._track_token_usage_internal(response["usage"])

        stop_reason = response.get("stopReason")
        if stop_reason:
            logging.debug(f"Response stop reason: {stop_reason}")
            if stop_reason == "max_tokens":
                logging.warning("Response truncated due to max_tokens limit")
            elif stop_reason == "content_filtered":
                logging.warning("Response was filtered due to content policy")

        # Extract content following AWS response structure
        output = response.get("output", {})
        message = output.get("message", {})
@@ -201,28 +373,59 @@ class BedrockCompletion(BaseLLM):
                "I apologize, but I received an empty response. Please try again."
            )

        # Extract text content from response
        # Process content blocks and handle tool use correctly
        text_content = ""

        for content_block in content:
            # Handle different content block types as per AWS documentation
            # Handle text content
            if "text" in content_block:
                text_content += content_block["text"]
            elif content_block.get("type") == "toolUse" and available_functions:
                # Handle tool use according to AWS format
                tool_use = content_block["toolUse"]
                function_name = tool_use.get("name")
                function_args = tool_use.get("input", {})

                result = self._handle_tool_execution(
            # Handle tool use - corrected structure according to AWS API docs
            elif "toolUse" in content_block and available_functions:
                tool_use_block = content_block["toolUse"]
                tool_use_id = tool_use_block.get("toolUseId")
                function_name = tool_use_block["name"]
                function_args = tool_use_block.get("input", {})

                logging.debug(
                    f"Tool use requested: {function_name} with ID {tool_use_id}"
                )

                # Execute the tool
                tool_result = self._handle_tool_execution(
                    function_name=function_name,
                    function_args=function_args,
                    available_functions=available_functions,
                    available_functions=dict(available_functions),
                    from_task=from_task,
                    from_agent=from_agent,
                )

                if result is not None:
                    return result
                if tool_result is not None:
                    messages.append(
                        {
                            "role": "assistant",
                            "content": [{"toolUse": tool_use_block}],
                        }
                    )

                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "toolResult": {
                                        "toolUseId": tool_use_id,
                                        "content": [{"text": str(tool_result)}],
                                    }
                                }
                            ],
                        }
                    )

                    return self._handle_converse(
                        messages, body, available_functions, from_task, from_agent
                    )

        # Apply stop sequences if configured
        text_content = self._apply_stop_words(text_content)
@@ -298,23 +501,43 @@ class BedrockCompletion(BaseLLM):
    def _handle_streaming_converse(
        self,
        messages: list[dict[str, Any]],
        body: dict[str, Any],
        body: BedrockConverseRequestBody,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle streaming converse API call."""
        """Handle streaming converse API call with comprehensive event handling."""
        full_response = ""
        current_tool_use = None
        tool_use_id = None

        try:
            response = self.client.converse_stream(
                modelId=self.model_id, messages=messages, **body
                modelId=self.model_id,
                messages=cast(
                    "Sequence[MessageTypeDef | MessageOutputTypeDef]",
                    cast(object, messages),
                ),
                **body,  # type: ignore[arg-type]
            )

            stream = response.get("stream")
            if stream:
                for event in stream:
                    if "contentBlockDelta" in event:
                    if "messageStart" in event:
                        role = event["messageStart"].get("role")
                        logging.debug(f"Streaming message started with role: {role}")

                    elif "contentBlockStart" in event:
                        start = event["contentBlockStart"].get("start", {})
                        if "toolUse" in start:
                            current_tool_use = start["toolUse"]
                            tool_use_id = current_tool_use.get("toolUseId")
                            logging.debug(
                                f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
                            )

                    elif "contentBlockDelta" in event:
                        delta = event["contentBlockDelta"]["delta"]
                        if "text" in delta:
                            text_chunk = delta["text"]
@@ -325,10 +548,93 @@ class BedrockCompletion(BaseLLM):
                                from_task=from_task,
                                from_agent=from_agent,
                            )
                        elif "toolUse" in delta and current_tool_use:
                            tool_input = delta["toolUse"].get("input", "")
                            if tool_input:
                                logging.debug(f"Tool input delta: {tool_input}")

                    # Content block stop - end of a content block
                    elif "contentBlockStop" in event:
                        logging.debug("Content block stopped in stream")
                        # If we were accumulating a tool use, it's now complete
                        if current_tool_use and available_functions:
                            function_name = current_tool_use["name"]
                            function_args = cast(
                                dict[str, Any], current_tool_use.get("input", {})
                            )

                            # Execute tool
                            tool_result = self._handle_tool_execution(
                                function_name=function_name,
                                function_args=function_args,
                                available_functions=available_functions,
                                from_task=from_task,
                                from_agent=from_agent,
                            )

                            if tool_result is not None and tool_use_id:
                                # Continue conversation with tool result
                                messages.append(
                                    {
                                        "role": "assistant",
                                        "content": [{"toolUse": current_tool_use}],
                                    }
                                )

                                messages.append(
                                    {
                                        "role": "user",
                                        "content": [
                                            {
                                                "toolResult": {
                                                    "toolUseId": tool_use_id,
                                                    "content": [
                                                        {"text": str(tool_result)}
                                                    ],
                                                }
                                            }
                                        ],
                                    }
                                )

                                # Recursive call - note this switches to non-streaming
                                return self._handle_converse(
                                    messages,
                                    body,
                                    available_functions,
                                    from_task,
                                    from_agent,
                                )

                            current_tool_use = None
                            tool_use_id = None

                    # Message stop - end of entire message
                    elif "messageStop" in event:
                        # Handle end of message
                        stop_reason = event["messageStop"].get("stopReason")
                        logging.debug(f"Streaming message stopped: {stop_reason}")
                        if stop_reason == "max_tokens":
                            logging.warning(
                                "Streaming response truncated due to max_tokens"
                            )
                        elif stop_reason == "content_filtered":
                            logging.warning(
                                "Streaming response filtered due to content policy"
                            )
                        break

                    # Metadata - contains usage information and trace details
                    elif "metadata" in event:
                        metadata = event["metadata"]
                        if "usage" in metadata:
                            usage_metrics = metadata["usage"]
                            self._track_token_usage_internal(usage_metrics)
                            logging.debug(f"Token usage: {usage_metrics}")
                        if "trace" in metadata:
                            logging.debug(
                                f"Trace information available: {metadata['trace']}"
                            )

        except ClientError as e:
            error_msg = self._handle_client_error(e)
            raise RuntimeError(error_msg) from e
@@ -430,25 +736,27 @@ class BedrockCompletion(BaseLLM):

        return converse_messages, system_message

    def _format_tools_for_converse(self, tools: list[dict]) -> list[dict]:
    @staticmethod
    def _format_tools_for_converse(tools: list[dict]) -> list[ConverseToolTypeDef]:
        """Convert CrewAI tools to Converse API format following AWS specification."""
        from crewai.llms.providers.utils.common import safe_tool_conversion

        converse_tools = []
        converse_tools: list[ConverseToolTypeDef] = []

        for tool in tools:
            try:
                name, description, parameters = safe_tool_conversion(tool, "Bedrock")

                converse_tool = {
                    "toolSpec": {
                        "name": name,
                        "description": description,
                    }
                tool_spec: ToolSpec = {
                    "name": name,
                    "description": description,
                }

                if parameters and isinstance(parameters, dict):
                    converse_tool["toolSpec"]["inputSchema"] = {"json": parameters}
                    input_schema: ToolInputSchema = {"json": parameters}
                    tool_spec["inputSchema"] = input_schema

                converse_tool: ConverseToolTypeDef = {"toolSpec": tool_spec}

                converse_tools.append(converse_tool)

@@ -460,9 +768,9 @@ class BedrockCompletion(BaseLLM):

        return converse_tools

    def _get_inference_config(self) -> dict[str, Any]:
    def _get_inference_config(self) -> EnhancedInferenceConfigurationTypeDef:
        """Get inference configuration following AWS Converse API specification."""
        config = {}
        config: EnhancedInferenceConfigurationTypeDef = {}

        if self.max_tokens:
            config["maxTokens"] = self.max_tokens
@@ -503,7 +811,7 @@ class BedrockCompletion(BaseLLM):

        return full_error_msg

    def _track_token_usage_internal(self, usage: dict[str, Any]) -> None:
    def _track_token_usage_internal(self, usage: TokenUsageTypeDef) -> None:  # type: ignore[override]
        """Track token usage from Bedrock response."""
        input_tokens = usage.get("inputTokens", 0)
        output_tokens = usage.get("outputTokens", 0)

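To show how the new TypedDicts fit together, here is a small standalone sketch (outside crewAI) of assembling a Converse request body the way `call()` does; the guardrail values and extra fields are placeholders, and plain dicts stand in for the TypedDicts defined in the diff:

    from typing import Any

    # Mirrors EnhancedInferenceConfigurationTypeDef: topK is the Claude-only
    # extension that is not in the stock boto3 stubs.
    inference_config: dict[str, Any] = {
        "maxTokens": 1024,
        "temperature": 0.7,
        "topP": 0.9,
        "topK": 40,
    }

    # Mirrors BedrockConverseRequestBody: inferenceConfig is required, the rest optional.
    body: dict[str, Any] = {"inferenceConfig": inference_config}
    body["system"] = [{"text": "You are a helpful assistant."}]
    body["guardrailConfig"] = {                     # placeholder guardrail
        "guardrailIdentifier": "my-guardrail-id",
        "guardrailVersion": "1",
    }
    body["additionalModelRequestFields"] = {"top_k": 40}

    # The completed body is then splatted into the Converse call, e.g.:
    # client.converse(modelId=model_id, messages=messages, **body)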
lib/crewai/tests/llms/bedrock/test_bedrock.py (new file, 733 lines)
@@ -0,0 +1,733 @@
import os
import sys
import types
from unittest.mock import patch, MagicMock
import pytest

from crewai.llm import LLM
from crewai.crew import Crew
from crewai.agent import Agent
from crewai.task import Task


@pytest.fixture(autouse=True)
def mock_aws_credentials():
    """Automatically mock AWS credentials and boto3 Session for all tests in this module."""
    with patch.dict(os.environ, {
        "AWS_ACCESS_KEY_ID": "test-access-key",
        "AWS_SECRET_ACCESS_KEY": "test-secret-key",
        "AWS_DEFAULT_REGION": "us-east-1"
    }):
        # Mock boto3 Session to prevent actual AWS connections
        with patch('crewai.llms.providers.bedrock.completion.Session') as mock_session_class:
            # Create mock session instance
            mock_session_instance = MagicMock()
            mock_client = MagicMock()

            # Set up default mock responses to prevent hanging
            default_response = {
                'output': {
                    'message': {
                        'role': 'assistant',
                        'content': [
                            {'text': 'Test response'}
                        ]
                    }
                },
                'usage': {
                    'inputTokens': 10,
                    'outputTokens': 5,
                    'totalTokens': 15
                }
            }
            mock_client.converse.return_value = default_response
            mock_client.converse_stream.return_value = {'stream': []}

            # Configure the mock session instance to return the mock client
            mock_session_instance.client.return_value = mock_client

            # Configure the mock Session class to return the mock session instance
            mock_session_class.return_value = mock_session_instance

            yield mock_session_class, mock_client


def test_bedrock_completion_is_used_when_bedrock_provider():
    """
    Test that BedrockCompletion from completion.py is used when LLM uses provider 'bedrock'
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    assert llm.__class__.__name__ == "BedrockCompletion"
    assert llm.provider == "bedrock"
    assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0"


def test_bedrock_completion_module_is_imported():
    """
    Test that the completion module is properly imported when using Bedrock provider
    """
    module_name = "crewai.llms.providers.bedrock.completion"

    # Remove module from cache if it exists
    if module_name in sys.modules:
        del sys.modules[module_name]

    # Create LLM instance - this should trigger the import
    LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Verify the module was imported
    assert module_name in sys.modules
    completion_mod = sys.modules[module_name]
    assert isinstance(completion_mod, types.ModuleType)

    # Verify the class exists in the module
    assert hasattr(completion_mod, 'BedrockCompletion')


def test_fallback_to_litellm_when_native_bedrock_fails():
    """
    Test that LLM falls back to LiteLLM when native Bedrock completion fails
    """
    # Mock the _get_native_provider to return a failing class
    with patch('crewai.llm.LLM._get_native_provider') as mock_get_provider:

        class FailingCompletion:
            def __init__(self, *args, **kwargs):
                raise Exception("Native AWS Bedrock SDK failed")

        mock_get_provider.return_value = FailingCompletion

        # This should fall back to LiteLLM
        llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

        # Check that it's using LiteLLM
        assert hasattr(llm, 'is_litellm')
        assert llm.is_litellm == True


def test_bedrock_completion_initialization_parameters():
    """
    Test that BedrockCompletion is initialized with correct parameters
    """
    llm = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        temperature=0.7,
        max_tokens=2000,
        top_p=0.9,
        top_k=40,
        region_name="us-west-2"
    )

    from crewai.llms.providers.bedrock.completion import BedrockCompletion
    assert isinstance(llm, BedrockCompletion)
    assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0"
    assert llm.temperature == 0.7
    assert llm.max_tokens == 2000
    assert llm.top_p == 0.9
    assert llm.top_k == 40
    assert llm.region_name == "us-west-2"


def test_bedrock_specific_parameters():
    """
    Test Bedrock-specific parameters like stop_sequences and streaming
    """
    llm = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        stop_sequences=["Human:", "Assistant:"],
        stream=True,
        region_name="us-east-1"
    )

    from crewai.llms.providers.bedrock.completion import BedrockCompletion
    assert isinstance(llm, BedrockCompletion)
    assert llm.stop_sequences == ["Human:", "Assistant:"]
    assert llm.stream == True
    assert llm.region_name == "us-east-1"


def test_bedrock_completion_call():
    """
    Test that BedrockCompletion call method works
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Mock the call method on the instance
    with patch.object(llm, 'call', return_value="Hello! I'm Claude on Bedrock, ready to help.") as mock_call:
        result = llm.call("Hello, how are you?")

        assert result == "Hello! I'm Claude on Bedrock, ready to help."
        mock_call.assert_called_once_with("Hello, how are you?")


def test_bedrock_completion_called_during_crew_execution():
    """
    Test that BedrockCompletion.call is actually invoked when running a crew
    """
    # Create the LLM instance first
    bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Mock the call method on the specific instance
    with patch.object(bedrock_llm, 'call', return_value="Tokyo has 14 million people.") as mock_call:

        # Create agent with explicit LLM configuration
        agent = Agent(
            role="Research Assistant",
            goal="Find population info",
            backstory="You research populations.",
            llm=bedrock_llm,
        )

        task = Task(
            description="Find Tokyo population",
            expected_output="Population number",
            agent=agent,
        )

        crew = Crew(agents=[agent], tasks=[task])
        result = crew.kickoff()

        # Verify mock was called
        assert mock_call.called
        assert "14 million" in str(result)


@pytest.mark.skip(reason="Crew execution test - may hang, needs investigation")
def test_bedrock_completion_call_arguments():
    """
    Test that BedrockCompletion.call is invoked with correct arguments
    """
    # Create LLM instance first
    bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Mock the instance method
    with patch.object(bedrock_llm, 'call') as mock_call:
        mock_call.return_value = "Task completed successfully."

        agent = Agent(
            role="Test Agent",
            goal="Complete a simple task",
            backstory="You are a test agent.",
            llm=bedrock_llm  # Use same instance
        )

        task = Task(
            description="Say hello world",
            expected_output="Hello world",
            agent=agent,
        )

        crew = Crew(agents=[agent], tasks=[task])
        crew.kickoff()

        # Verify call was made
        assert mock_call.called

        # Check the arguments passed to the call method
        call_args = mock_call.call_args
        assert call_args is not None

        # The first argument should be the messages
        messages = call_args[0][0]  # First positional argument
        assert isinstance(messages, (str, list))

        # Verify that the task description appears in the messages
        if isinstance(messages, str):
            assert "hello world" in messages.lower()
        elif isinstance(messages, list):
            message_content = str(messages).lower()
            assert "hello world" in message_content


def test_multiple_bedrock_calls_in_crew():
    """
    Test that BedrockCompletion.call is invoked multiple times for multiple tasks
    """
    # Create LLM instance first
    bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Mock the instance method
    with patch.object(bedrock_llm, 'call') as mock_call:
        mock_call.return_value = "Task completed."

        agent = Agent(
            role="Multi-task Agent",
            goal="Complete multiple tasks",
            backstory="You can handle multiple tasks.",
            llm=bedrock_llm  # Use same instance
        )

        task1 = Task(
            description="First task",
            expected_output="First result",
            agent=agent,
        )

        task2 = Task(
            description="Second task",
            expected_output="Second result",
            agent=agent,
        )

        crew = Crew(
            agents=[agent],
            tasks=[task1, task2]
        )
        crew.kickoff()

        # Verify multiple calls were made
        assert mock_call.call_count >= 2  # At least one call per task

        # Verify each call had proper arguments
        for call in mock_call.call_args_list:
            assert len(call[0]) > 0  # Has positional arguments
            messages = call[0][0]
            assert messages is not None


def test_bedrock_completion_with_tools():
    """
    Test that BedrockCompletion.call is invoked with tools when agent has tools
    """
    from crewai.tools import tool

    @tool
    def sample_tool(query: str) -> str:
        """A sample tool for testing"""
        return f"Tool result for: {query}"

    # Create LLM instance first
    bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Mock the instance method
    with patch.object(bedrock_llm, 'call') as mock_call:
        mock_call.return_value = "Task completed with tools."

        agent = Agent(
            role="Tool User",
            goal="Use tools to complete tasks",
            backstory="You can use tools.",
            llm=bedrock_llm,  # Use same instance
            tools=[sample_tool]
        )

        task = Task(
            description="Use the sample tool",
            expected_output="Tool usage result",
            agent=agent,
        )

        crew = Crew(agents=[agent], tasks=[task])

        crew.kickoff()

        assert mock_call.called

        call_args = mock_call.call_args
        call_kwargs = call_args[1] if len(call_args) > 1 else {}

        if 'tools' in call_kwargs:
            assert call_kwargs['tools'] is not None
            assert len(call_kwargs['tools']) > 0


def test_bedrock_raises_error_when_model_not_found(mock_aws_credentials):
    """Test that BedrockCompletion raises appropriate error when model not found"""
    from botocore.exceptions import ClientError

    # Get the mock client from the fixture
    _, mock_client = mock_aws_credentials

    error_response = {
        'Error': {
            'Code': 'ResourceNotFoundException',
            'Message': 'Could not resolve the foundation model from the model identifier'
        }
    }
    mock_client.converse.side_effect = ClientError(error_response, 'converse')

    llm = LLM(model="bedrock/model-doesnt-exist")

    with pytest.raises(Exception):  # Should raise some error for unsupported model
        llm.call("Hello")


def test_bedrock_aws_credentials_configuration():
    """
    Test that AWS credentials configuration works properly
    """
    # Test with environment variables
    with patch.dict(os.environ, {
        "AWS_ACCESS_KEY_ID": "test-access-key",
        "AWS_SECRET_ACCESS_KEY": "test-secret-key",
        "AWS_DEFAULT_REGION": "us-east-1"
    }):
        llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

        from crewai.llms.providers.bedrock.completion import BedrockCompletion
        assert isinstance(llm, BedrockCompletion)
        assert llm.region_name == "us-east-1"

    # Test with explicit credentials
    llm_explicit = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        aws_access_key_id="explicit-key",
        aws_secret_access_key="explicit-secret",
        region_name="us-west-2"
    )
    assert isinstance(llm_explicit, BedrockCompletion)
    assert llm_explicit.region_name == "us-west-2"


def test_bedrock_model_capabilities():
    """
    Test that model capabilities are correctly identified
    """
    # Test Claude model
    llm_claude = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
    from crewai.llms.providers.bedrock.completion import BedrockCompletion
    assert isinstance(llm_claude, BedrockCompletion)
    assert llm_claude.is_claude_model == True
    assert llm_claude.supports_tools == True

    # Test other Bedrock model
    llm_titan = LLM(model="bedrock/amazon.titan-text-express-v1")
    assert isinstance(llm_titan, BedrockCompletion)
    assert llm_titan.supports_tools == True


def test_bedrock_inference_config():
    """
    Test that inference config is properly prepared
    """
    llm = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        temperature=0.7,
        top_p=0.9,
        top_k=40,
        max_tokens=1000
    )

    from crewai.llms.providers.bedrock.completion import BedrockCompletion
    assert isinstance(llm, BedrockCompletion)

    # Test config preparation
    config = llm._get_inference_config()

    # Verify config has the expected parameters
    assert 'temperature' in config
    assert config['temperature'] == 0.7
    assert 'topP' in config
    assert config['topP'] == 0.9
    assert 'maxTokens' in config
    assert config['maxTokens'] == 1000
    assert 'topK' in config
    assert config['topK'] == 40


def test_bedrock_model_detection():
    """
    Test that various Bedrock model formats are properly detected
    """
    # Test Bedrock model naming patterns
    bedrock_test_cases = [
        "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
        "bedrock/amazon.titan-text-express-v1",
        "bedrock/meta.llama3-70b-instruct-v1:0"
    ]

    for model_name in bedrock_test_cases:
        llm = LLM(model=model_name)
        from crewai.llms.providers.bedrock.completion import BedrockCompletion
        assert isinstance(llm, BedrockCompletion), f"Failed for model: {model_name}"


def test_bedrock_supports_stop_words():
    """
    Test that Bedrock models support stop sequences
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
    assert llm.supports_stop_words() == True


def test_bedrock_context_window_size():
    """
    Test that Bedrock models return correct context window sizes
    """
    # Test Claude 3.5 Sonnet
    llm_claude = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
    context_size_claude = llm_claude.get_context_window_size()
    assert context_size_claude > 150000  # Should be substantial (200K tokens with ratio)

    # Test Titan
    llm_titan = LLM(model="bedrock/amazon.titan-text-express-v1")
    context_size_titan = llm_titan.get_context_window_size()
    assert context_size_titan > 5000  # Should have 8K context window


def test_bedrock_message_formatting():
    """
    Test that messages are properly formatted for Bedrock Converse API
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Test message formatting
    test_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"}
    ]

    formatted_messages, system_message = llm._format_messages_for_converse(test_messages)

    # System message should be extracted
    assert system_message == "You are a helpful assistant."

    # Remaining messages should be in Converse format
    assert len(formatted_messages) >= 3  # Should have user, assistant, user messages

    # First message should be user role
    assert formatted_messages[0]["role"] == "user"
    # Second should be assistant
    assert formatted_messages[1]["role"] == "assistant"

    # Messages should have content array with text
    assert isinstance(formatted_messages[0]["content"], list)
    assert "text" in formatted_messages[0]["content"][0]


def test_bedrock_streaming_parameter():
    """
    Test that streaming parameter is properly handled
    """
    # Test non-streaming
    llm_no_stream = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", stream=False)
    assert llm_no_stream.stream == False

    # Test streaming
    llm_stream = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", stream=True)
    assert llm_stream.stream == True


def test_bedrock_tool_conversion():
    """
    Test that tools are properly converted to Bedrock Converse format
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Mock tool in CrewAI format
    crewai_tools = [{
        "type": "function",
        "function": {
            "name": "test_tool",
            "description": "A test tool",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "Search query"}
                },
                "required": ["query"]
            }
        }
    }]

    # Test tool conversion
    bedrock_tools = llm._format_tools_for_converse(crewai_tools)

    assert len(bedrock_tools) == 1
    # Bedrock tools should have toolSpec structure
    assert "toolSpec" in bedrock_tools[0]
    assert bedrock_tools[0]["toolSpec"]["name"] == "test_tool"
    assert bedrock_tools[0]["toolSpec"]["description"] == "A test tool"
    assert "inputSchema" in bedrock_tools[0]["toolSpec"]


def test_bedrock_environment_variable_credentials(mock_aws_credentials):
    """
    Test that AWS credentials are properly loaded from environment
    """
    mock_session_class, _ = mock_aws_credentials

    # Reset the mock to clear any previous calls
    mock_session_class.reset_mock()

    with patch.dict(os.environ, {
        "AWS_ACCESS_KEY_ID": "test-access-key-123",
        "AWS_SECRET_ACCESS_KEY": "test-secret-key-456"
    }):
        llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

        # Verify Session was called with environment credentials
        assert mock_session_class.called
        # Get the most recent call - Session is called as Session(...)
        call_kwargs = mock_session_class.call_args[1] if mock_session_class.call_args else {}
        assert call_kwargs.get('aws_access_key_id') == "test-access-key-123"
        assert call_kwargs.get('aws_secret_access_key') == "test-secret-key-456"


def test_bedrock_token_usage_tracking():
    """
    Test that token usage is properly tracked for Bedrock responses
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Mock the Bedrock response with usage information
    with patch.object(llm.client, 'converse') as mock_converse:
        mock_response = {
            'output': {
                'message': {
                    'role': 'assistant',
                    'content': [
                        {'text': 'test response'}
                    ]
                }
            },
            'usage': {
                'inputTokens': 50,
                'outputTokens': 25,
                'totalTokens': 75
            }
        }
        mock_converse.return_value = mock_response

        result = llm.call("Hello")

        # Verify the response
        assert result == "test response"

        # Verify token usage was tracked
        assert llm._token_usage['prompt_tokens'] == 50
        assert llm._token_usage['completion_tokens'] == 25
        assert llm._token_usage['total_tokens'] == 75


def test_bedrock_tool_use_conversation_flow():
    """
    Test that the Bedrock completion properly handles tool use conversation flow
    """
    from unittest.mock import Mock

    # Create BedrockCompletion instance
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Mock tool function
    def mock_weather_tool(location: str) -> str:
        return f"The weather in {location} is sunny and 75°F"

    available_functions = {"get_weather": mock_weather_tool}

    # Mock the Bedrock client responses
    with patch.object(llm.client, 'converse') as mock_converse:
        # First response: tool use request
        tool_use_response = {
            'output': {
                'message': {
                    'role': 'assistant',
                    'content': [
                        {
                            'toolUse': {
                                'toolUseId': 'tool-123',
                                'name': 'get_weather',
                                'input': {'location': 'San Francisco'}
                            }
                        }
                    ]
                }
            },
            'usage': {
                'inputTokens': 100,
                'outputTokens': 50,
                'totalTokens': 150
            }
        }

        # Second response: final answer after tool execution
        final_response = {
            'output': {
                'message': {
                    'role': 'assistant',
                    'content': [
                        {'text': 'Based on the weather data, it is sunny and 75°F in San Francisco.'}
                    ]
                }
            },
            'usage': {
                'inputTokens': 120,
                'outputTokens': 30,
                'totalTokens': 150
            }
        }

        # Configure mock to return different responses on successive calls
        mock_converse.side_effect = [tool_use_response, final_response]

        # Test the call
        messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}]
        result = llm.call(
            messages=messages,
            available_functions=available_functions
        )

        # Verify the final response contains the weather information
        assert "sunny" in result.lower() or "75" in result

        # Verify that the API was called twice (once for tool use, once for final answer)
        assert mock_converse.call_count == 2


def test_bedrock_handles_cohere_conversation_requirements():
    """
    Test that Bedrock properly handles Cohere model's requirement for user message at end
    """
    llm = LLM(model="bedrock/cohere.command-r-plus-v1:0")

    # Test message formatting with conversation ending in assistant message
    test_messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"}
    ]

    formatted_messages, system_message = llm._format_messages_for_converse(test_messages)

    # For Cohere models, should add a user message at the end
    assert formatted_messages[-1]["role"] == "user"
    assert "continue" in formatted_messages[-1]["content"][0]["text"].lower()


def test_bedrock_client_error_handling():
    """
    Test that Bedrock properly handles various AWS client errors
    """
    from botocore.exceptions import ClientError

    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Test ValidationException
    with patch.object(llm.client, 'converse') as mock_converse:
        error_response = {
            'Error': {
                'Code': 'ValidationException',
                'Message': 'Invalid request format'
            }
        }
        mock_converse.side_effect = ClientError(error_response, 'converse')

        with pytest.raises(ValueError) as exc_info:
            llm.call("Hello")
        assert "validation" in str(exc_info.value).lower()

    # Test ThrottlingException
    with patch.object(llm.client, 'converse') as mock_converse:
        error_response = {
            'Error': {
                'Code': 'ThrottlingException',
                'Message': 'Rate limit exceeded'
            }
        }
        mock_converse.side_effect = ClientError(error_response, 'converse')

        with pytest.raises(RuntimeError) as exc_info:
            llm.call("Hello")
        assert "throttled" in str(exc_info.value).lower()

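For reference, the tool-use continuation that both `_handle_converse` and the conversation-flow test above rely on appends two messages before re-calling the Converse API: the assistant's `toolUse` block and a user `toolResult` keyed by the same `toolUseId`. A minimal sketch of that message shape (values are illustrative, taken from the test above):

    tool_use_block = {
        "toolUseId": "tool-123",
        "name": "get_weather",
        "input": {"location": "San Francisco"},
    }
    tool_result = "The weather in San Francisco is sunny and 75°F"

    messages = [{"role": "user", "content": [{"text": "What's the weather like?"}]}]

    # Echo the model's tool request back as an assistant turn...
    messages.append({"role": "assistant", "content": [{"toolUse": tool_use_block}]})

    # ...then attach the tool output as a user turn referencing the same toolUseId.
    messages.append({
        "role": "user",
        "content": [{
            "toolResult": {
                "toolUseId": tool_use_block["toolUseId"],
                "content": [{"text": tool_result}],
            }
        }],
    })
    # The extended conversation is then sent back through client.converse(...) for the final answer.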
@@ -25,6 +25,7 @@ dev = [
    "types-pyyaml==6.0.*",
    "types-regex==2024.11.6.*",
    "types-appdirs==1.4.*",
    "boto3-stubs[bedrock-runtime]>=1.40.54",
]
uv.lock (generated; 62 changed lines)
@@ -38,6 +38,7 @@ members = [
[manifest.dependency-groups]
dev = [
    { name = "bandit", specifier = ">=1.8.6" },
    { name = "boto3-stubs", extras = ["bedrock-runtime"], specifier = ">=1.40.54" },
    { name = "mypy", specifier = ">=1.18.2" },
    { name = "pre-commit", specifier = ">=4.3.0" },
    { name = "pytest", specifier = ">=8.4.2" },
@@ -479,6 +480,25 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/8e/db/7d3c27f530c2b354d546ad7fb94505be8b78a5ecabe34c6a1f9a9d6be03e/boto3-1.40.45-py3-none-any.whl", hash = "sha256:5b145752d20f29908e3cb8c823bee31c77e6bcf18787e570f36bbc545cc779ed", size = 139345, upload-time = "2025-10-03T19:32:11.145Z" },
]

[[package]]
name = "boto3-stubs"
version = "1.40.54"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "botocore-stubs" },
    { name = "types-s3transfer" },
    { name = "typing-extensions", marker = "python_full_version < '3.12'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e2/70/245477b7f07c9e1533c47fa69e611b172814423a6fd4637004f0d2a13b73/boto3_stubs-1.40.54.tar.gz", hash = "sha256:e21a9eda979a451935eb3196de3efbe15b9470e6bf9027406d1f6d0ac08b339e", size = 100919, upload-time = "2025-10-16T19:49:17.079Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/9d/52/ee9dadd1cc8911e16f18ca9fa036a10328e0a0d3fddd54fadcc1ca0f9143/boto3_stubs-1.40.54-py3-none-any.whl", hash = "sha256:548a4786785ba7b43ef4ef1a2a764bebbb0301525f3201091fcf412e4c8ce323", size = 69712, upload-time = "2025-10-16T19:49:12.847Z" },
]

[package.optional-dependencies]
bedrock-runtime = [
    { name = "mypy-boto3-bedrock-runtime" },
]

[[package]]
name = "botocore"
version = "1.40.45"
@@ -494,6 +514,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/af/06/df47e2ecb74bd184c9d056666afd3db011a649eaca663337835a6dd5aee6/botocore-1.40.45-py3-none-any.whl", hash = "sha256:9abf473d8372ade8442c0d4634a9decb89c854d7862ffd5500574eb63ab8f240", size = 14063670, upload-time = "2025-10-03T19:31:58.999Z" },
]

[[package]]
name = "botocore-stubs"
version = "1.40.54"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "types-awscrt" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ea/c0/3e78314f9baa850aae648fb6b2506748046e1c3e374d6bb3514478e34590/botocore_stubs-1.40.54.tar.gz", hash = "sha256:fb38a794ab2b896f9cc237ec725546746accaffd34f382475a8d1b98ca1078e1", size = 42225, upload-time = "2025-10-16T20:26:56.711Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/2f/9f/ab316f57a7e32d4a5b790070ffa5986991098044897b08f1b65951bced2a/botocore_stubs-1.40.54-py3-none-any.whl", hash = "sha256:997e6f1c03e079c244caedf315f7a515a07480af9f93f53535e506f17cdbe880", size = 66542, upload-time = "2025-10-16T20:26:54.109Z" },
]

[[package]]
name = "browserbase"
version = "1.4.0"
@@ -4055,6 +4087,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/87/e3/be76d87158ebafa0309946c4a73831974d4d6ab4f4ef40c3b53a385a66fd/mypy-1.18.2-py3-none-any.whl", hash = "sha256:22a1748707dd62b58d2ae53562ffc4d7f8bcc727e8ac7cbc69c053ddc874d47e", size = 2352367, upload-time = "2025-09-19T00:10:15.489Z" },
]

[[package]]
name = "mypy-boto3-bedrock-runtime"
version = "1.40.41"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "typing-extensions", marker = "python_full_version < '3.12'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c7/38/79989f7bce998776ed1a01c17f3f58e7bc6f5fc2bcbdff929701526fa2f1/mypy_boto3_bedrock_runtime-1.40.41.tar.gz", hash = "sha256:ee9bda6d6d478c8d0995e84e884bdf1798e150d437974ae27c175774a58ffaa5", size = 28333, upload-time = "2025-09-29T19:26:04.804Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/3d/6c/d3431dadf473bb76aa590b1ed8cc91726a48b029b542eff9d3024f2d70b9/mypy_boto3_bedrock_runtime-1.40.41-py3-none-any.whl", hash = "sha256:d65dff200986ff06c6b3579ddcea102555f2067c8987fca379bf4f9ed8ba3121", size = 34181, upload-time = "2025-09-29T19:26:01.898Z" },
]

[[package]]
name = "mypy-extensions"
version = "1.1.0"
@@ -8181,6 +8225,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/cf/07/41f5b9b11f11855eb67760ed680330e0ce9136a44b51c24dd52edb1c4eb1/types_appdirs-1.4.3.5-py3-none-any.whl", hash = "sha256:337c750e423c40911d389359b4edabe5bbc2cdd5cd0bd0518b71d2839646273b", size = 2667, upload-time = "2023-03-14T15:21:32.431Z" },
]

[[package]]
name = "types-awscrt"
version = "0.28.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/60/19/a3a6377c9e2e389c1421c033a1830c29cac08f2e1e05a082ea84eb22c75f/types_awscrt-0.28.1.tar.gz", hash = "sha256:66d77ec283e1dc907526a44511a12624118723a396c36d3f3dd9855cb614ce14", size = 17410, upload-time = "2025-10-11T21:55:07.443Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/ea/c7/0266b797d19b82aebe0e177efe35de7aabdc192bc1605ce3309331f0a505/types_awscrt-0.28.1-py3-none-any.whl", hash = "sha256:d88f43ef779f90b841ba99badb72fe153077225a4e426ae79e943184827b4443", size = 41851, upload-time = "2025-10-11T21:55:06.235Z" },
]

[[package]]
name = "types-pyyaml"
version = "6.0.12.20250915"
@@ -8251,6 +8304,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/8b/ea/91b718b8c0b88e4f61cdd61357cc4a1f8767b32be691fb388299003a3ae3/types_requests-2.31.0.20240406-py3-none-any.whl", hash = "sha256:6216cdac377c6b9a040ac1c0404f7284bd13199c0e1bb235f4324627e8898cf5", size = 15347, upload-time = "2024-04-06T02:13:37.412Z" },
]

[[package]]
name = "types-s3transfer"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8e/9b/8913198b7fc700acc1dcb84827137bb2922052e43dde0f4fb0ed2dc6f118/types_s3transfer-0.14.0.tar.gz", hash = "sha256:17f800a87c7eafab0434e9d87452c809c290ae906c2024c24261c564479e9c95", size = 14218, upload-time = "2025-10-11T21:11:27.892Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/92/c3/4dfb2e87c15ca582b7d956dfb7e549de1d005c758eb9a305e934e1b83fda/types_s3transfer-0.14.0-py3-none-any.whl", hash = "sha256:108134854069a38b048e9b710b9b35904d22a9d0f37e4e1889c2e6b58e5b3253", size = 19697, upload-time = "2025-10-11T21:11:26.749Z" },
]

[[package]]
name = "types-urllib3"
version = "1.26.25.14"