feat: enhance BedrockCompletion class with advanced features

* feat: enhance BedrockCompletion class with advanced features and error handling

- Added support for guardrail configuration, additional model request fields, and custom response field paths in the BedrockCompletion class.
- Improved error handling for AWS exceptions and added token usage tracking with stop reason logging.
- Enhanced streaming response handling with comprehensive event management, including tool use and content block processing.
- Updated documentation to reflect new features and initialization parameters.
- Introduced a new test suite for BedrockCompletion to validate functionality and ensure robust integration with AWS Bedrock APIs.

* chore: add boto typing

* fix: use typing_extensions.Required for Python 3.10 compatibility

---------

Co-authored-by: Greyson Lalonde <greyson.r.lalonde@gmail.com>
This commit is contained in:
Lorenze Jay
2025-10-17 08:30:35 -07:00
committed by GitHub
parent 02d7ce7621
commit 3b32793e78
4 changed files with 1145 additions and 41 deletions

View File

@@ -1,7 +1,11 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
import logging
import os
from typing import Any
from typing import TYPE_CHECKING, Any, TypedDict, cast
from typing_extensions import Required
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
@@ -11,6 +15,20 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
)
if TYPE_CHECKING:
from mypy_boto3_bedrock_runtime.type_defs import (
GuardrailConfigurationTypeDef,
GuardrailStreamConfigurationTypeDef,
InferenceConfigurationTypeDef,
MessageOutputTypeDef,
MessageTypeDef,
SystemContentBlockTypeDef,
TokenUsageTypeDef,
ToolConfigurationTypeDef,
ToolTypeDef,
)
try:
from boto3.session import Session
from botocore.config import Config
@@ -21,11 +39,104 @@ except ImportError:
) from None
if TYPE_CHECKING:

    class EnhancedInferenceConfigurationTypeDef(
        InferenceConfigurationTypeDef, total=False
    ):
        """Inference configuration that additionally accepts ``topK``.

        Static-analysis view: inherits every field of the boto3 stub type
        ``InferenceConfigurationTypeDef`` and adds one optional key. Per the
        original author's note, AWS Bedrock supports topK for Claude models
        but the boto3 type stubs do not declare it.
        """

        topK: int  # noqa: N815 - AWS API uses topK naming

else:

    class EnhancedInferenceConfigurationTypeDef(TypedDict, total=False):
        """Inference configuration that additionally accepts ``topK``.

        Runtime view: the stub package is only imported under
        ``TYPE_CHECKING``, so this branch is a standalone TypedDict that
        spells out the fields explicitly and adds ``topK``. All keys are
        optional (``total=False``).
        """

        maxTokens: int
        temperature: float
        topP: float  # noqa: N815 - AWS API uses topP naming
        stopSequences: list[str]
        topK: int  # noqa: N815 - AWS API uses topK naming
class ToolInputSchema(TypedDict):
    """Input-schema wrapper for a Converse API tool.

    The schema document itself is carried under the single ``json`` key.
    """

    # Schema describing the tool's parameters (a plain dict).
    json: dict[str, Any]
class ToolSpec(TypedDict, total=False):
    """Specification of a single tool in Converse API form.

    Declared with ``total=False`` so ``inputSchema`` may be omitted, while
    ``name`` and ``description`` are explicitly marked ``Required``.
    """

    name: Required[str]
    description: Required[str]
    # Optional: only present when the tool declares parameters.
    inputSchema: ToolInputSchema
class ConverseToolTypeDef(TypedDict):
    """One entry of a Converse API tool list.

    Wraps the actual specification under the ``toolSpec`` key.
    """

    toolSpec: ToolSpec
class BedrockConverseRequestBody(TypedDict, total=False):
    """Keyword payload for the AWS Bedrock Converse API call.

    Modelled on the AWS Bedrock Converse API specification. Only
    ``inferenceConfig`` is mandatory; every other key is optional
    (``total=False``).
    """

    inferenceConfig: Required[EnhancedInferenceConfigurationTypeDef]
    # Optional sections of the request.
    system: list[SystemContentBlockTypeDef]
    toolConfig: ToolConfigurationTypeDef
    guardrailConfig: GuardrailConfigurationTypeDef
    additionalModelRequestFields: dict[str, Any]
    additionalModelResponseFieldPaths: list[str]
class BedrockConverseStreamRequestBody(TypedDict, total=False):
    """Keyword payload for the AWS Bedrock Converse Stream API call.

    Modelled on the AWS Bedrock Converse Stream API specification. It has
    the same keys as the non-streaming request body, except that
    ``guardrailConfig`` uses the stream-specific guardrail type. Only
    ``inferenceConfig`` is mandatory.
    """

    inferenceConfig: Required[EnhancedInferenceConfigurationTypeDef]
    # Optional sections of the request.
    system: list[SystemContentBlockTypeDef]
    toolConfig: ToolConfigurationTypeDef
    guardrailConfig: GuardrailStreamConfigurationTypeDef
    additionalModelRequestFields: dict[str, Any]
    additionalModelResponseFieldPaths: list[str]
class BedrockCompletion(BaseLLM):
"""AWS Bedrock native completion implementation using the Converse API.
This class provides direct integration with AWS Bedrock using the modern
Converse API, which provides a unified interface across all Bedrock models.
Features:
- Full tool calling support with proper conversation continuation
- Streaming and non-streaming responses with comprehensive event handling
- Guardrail configuration for content filtering
- Model-specific parameters via additionalModelRequestFields
- Custom response field extraction
- Proper error handling for all AWS exception types
- Token usage tracking and stop reason logging
- Support for both text and tool use content blocks
The implementation follows AWS Bedrock Converse API best practices including:
- Proper tool use ID tracking for multi-turn tool conversations
- Complete streaming event handling (messageStart, contentBlockStart, etc.)
- Response metadata and trace information capture
- Model-specific conversation format handling (e.g., Cohere requirements)
"""
def __init__(
@@ -41,9 +152,30 @@ class BedrockCompletion(BaseLLM):
top_k: int | None = None,
stop_sequences: Sequence[str] | None = None,
stream: bool = False,
guardrail_config: dict[str, Any] | None = None,
additional_model_request_fields: dict[str, Any] | None = None,
additional_model_response_field_paths: list[str] | None = None,
**kwargs,
):
"""Initialize AWS Bedrock completion client."""
"""Initialize AWS Bedrock completion client.
Args:
model: The Bedrock model ID to use
aws_access_key_id: AWS access key (defaults to environment variable)
aws_secret_access_key: AWS secret key (defaults to environment variable)
aws_session_token: AWS session token for temporary credentials
region_name: AWS region name
temperature: Sampling temperature for response generation
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter (Claude models only)
stop_sequences: List of sequences that stop generation
stream: Whether to use streaming responses
guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
**kwargs: Additional parameters
"""
# Extract provider from kwargs to avoid duplicate argument
kwargs.pop("provider", None)
@@ -66,7 +198,6 @@ class BedrockCompletion(BaseLLM):
# Configure client with timeouts and retries following AWS best practices
config = Config(
connect_timeout=60,
read_timeout=300,
retries={
"max_attempts": 3,
@@ -85,6 +216,13 @@ class BedrockCompletion(BaseLLM):
self.stream = stream
self.stop_sequences = stop_sequences or []
# Store advanced features (optional)
self.guardrail_config = guardrail_config
self.additional_model_request_fields = additional_model_request_fields
self.additional_model_response_field_paths = (
additional_model_response_field_paths
)
# Model-specific settings
self.is_claude_model = "claude" in model.lower()
self.supports_tools = True # Converse API supports tools for most models
@@ -96,7 +234,7 @@ class BedrockCompletion(BaseLLM):
def call(
self,
messages: str | list[dict[str, str]],
tools: Sequence[Mapping[str, Any]] | None = None,
tools: list[dict[Any, Any]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -119,24 +257,45 @@ class BedrockCompletion(BaseLLM):
messages
)
# Prepare tool configuration
tool_config = None
if tools:
tool_config = {"tools": self._format_tools_for_converse(tools)}
# Prepare request body
body = {
body: BedrockConverseRequestBody = {
"inferenceConfig": self._get_inference_config(),
}
# Add system message if present
if system_message:
body["system"] = [{"text": system_message}]
body["system"] = cast(
"list[SystemContentBlockTypeDef]",
cast(object, [{"text": system_message}]),
)
# Add tool config if present
if tool_config:
if tools:
tool_config: ToolConfigurationTypeDef = {
"tools": cast(
"Sequence[ToolTypeDef]",
cast(object, self._format_tools_for_converse(tools)),
)
}
body["toolConfig"] = tool_config
# Add optional advanced features if configured
if self.guardrail_config:
guardrail_config: GuardrailConfigurationTypeDef = cast(
"GuardrailConfigurationTypeDef", cast(object, self.guardrail_config)
)
body["guardrailConfig"] = guardrail_config
if self.additional_model_request_fields:
body["additionalModelRequestFields"] = (
self.additional_model_request_fields
)
if self.additional_model_response_field_paths:
body["additionalModelResponseFieldPaths"] = (
self.additional_model_response_field_paths
)
if self.stream:
return self._handle_streaming_converse(
formatted_messages, body, available_functions, from_task, from_agent
@@ -161,7 +320,7 @@ class BedrockCompletion(BaseLLM):
def _handle_converse(
self,
messages: list[dict[str, Any]],
body: dict[str, Any],
body: BedrockConverseRequestBody,
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
@@ -183,13 +342,26 @@ class BedrockCompletion(BaseLLM):
# Call Bedrock Converse API with proper error handling
response = self.client.converse(
modelId=self.model_id, messages=messages, **body
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
cast(object, messages),
),
**body,
)
# Track token usage according to AWS response format
if "usage" in response:
self._track_token_usage_internal(response["usage"])
stop_reason = response.get("stopReason")
if stop_reason:
logging.debug(f"Response stop reason: {stop_reason}")
if stop_reason == "max_tokens":
logging.warning("Response truncated due to max_tokens limit")
elif stop_reason == "content_filtered":
logging.warning("Response was filtered due to content policy")
# Extract content following AWS response structure
output = response.get("output", {})
message = output.get("message", {})
@@ -201,28 +373,59 @@ class BedrockCompletion(BaseLLM):
"I apologize, but I received an empty response. Please try again."
)
# Extract text content from response
# Process content blocks and handle tool use correctly
text_content = ""
for content_block in content:
# Handle different content block types as per AWS documentation
# Handle text content
if "text" in content_block:
text_content += content_block["text"]
elif content_block.get("type") == "toolUse" and available_functions:
# Handle tool use according to AWS format
tool_use = content_block["toolUse"]
function_name = tool_use.get("name")
function_args = tool_use.get("input", {})
result = self._handle_tool_execution(
# Handle tool use - corrected structure according to AWS API docs
elif "toolUse" in content_block and available_functions:
tool_use_block = content_block["toolUse"]
tool_use_id = tool_use_block.get("toolUseId")
function_name = tool_use_block["name"]
function_args = tool_use_block.get("input", {})
logging.debug(
f"Tool use requested: {function_name} with ID {tool_use_id}"
)
# Execute the tool
tool_result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
available_functions=dict(available_functions),
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
if tool_result is not None:
messages.append(
{
"role": "assistant",
"content": [{"toolUse": tool_use_block}],
}
)
messages.append(
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": tool_use_id,
"content": [{"text": str(tool_result)}],
}
}
],
}
)
return self._handle_converse(
messages, body, available_functions, from_task, from_agent
)
# Apply stop sequences if configured
text_content = self._apply_stop_words(text_content)
@@ -298,23 +501,43 @@ class BedrockCompletion(BaseLLM):
def _handle_streaming_converse(
self,
messages: list[dict[str, Any]],
body: dict[str, Any],
body: BedrockConverseRequestBody,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle streaming converse API call."""
"""Handle streaming converse API call with comprehensive event handling."""
full_response = ""
current_tool_use = None
tool_use_id = None
try:
response = self.client.converse_stream(
modelId=self.model_id, messages=messages, **body
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
cast(object, messages),
),
**body, # type: ignore[arg-type]
)
stream = response.get("stream")
if stream:
for event in stream:
if "contentBlockDelta" in event:
if "messageStart" in event:
role = event["messageStart"].get("role")
logging.debug(f"Streaming message started with role: {role}")
elif "contentBlockStart" in event:
start = event["contentBlockStart"].get("start", {})
if "toolUse" in start:
current_tool_use = start["toolUse"]
tool_use_id = current_tool_use.get("toolUseId")
logging.debug(
f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
)
elif "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
if "text" in delta:
text_chunk = delta["text"]
@@ -325,10 +548,93 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
)
elif "toolUse" in delta and current_tool_use:
tool_input = delta["toolUse"].get("input", "")
if tool_input:
logging.debug(f"Tool input delta: {tool_input}")
# Content block stop - end of a content block
elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream")
# If we were accumulating a tool use, it's now complete
if current_tool_use and available_functions:
function_name = current_tool_use["name"]
function_args = cast(
dict[str, Any], current_tool_use.get("input", {})
)
# Execute tool
tool_result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if tool_result is not None and tool_use_id:
# Continue conversation with tool result
messages.append(
{
"role": "assistant",
"content": [{"toolUse": current_tool_use}],
}
)
messages.append(
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": tool_use_id,
"content": [
{"text": str(tool_result)}
],
}
}
],
}
)
# Recursive call - note this switches to non-streaming
return self._handle_converse(
messages,
body,
available_functions,
from_task,
from_agent,
)
current_tool_use = None
tool_use_id = None
# Message stop - end of entire message
elif "messageStop" in event:
# Handle end of message
stop_reason = event["messageStop"].get("stopReason")
logging.debug(f"Streaming message stopped: {stop_reason}")
if stop_reason == "max_tokens":
logging.warning(
"Streaming response truncated due to max_tokens"
)
elif stop_reason == "content_filtered":
logging.warning(
"Streaming response filtered due to content policy"
)
break
# Metadata - contains usage information and trace details
elif "metadata" in event:
metadata = event["metadata"]
if "usage" in metadata:
usage_metrics = metadata["usage"]
self._track_token_usage_internal(usage_metrics)
logging.debug(f"Token usage: {usage_metrics}")
if "trace" in metadata:
logging.debug(
f"Trace information available: {metadata['trace']}"
)
except ClientError as e:
error_msg = self._handle_client_error(e)
raise RuntimeError(error_msg) from e
@@ -430,25 +736,27 @@ class BedrockCompletion(BaseLLM):
return converse_messages, system_message
def _format_tools_for_converse(self, tools: list[dict]) -> list[dict]:
@staticmethod
def _format_tools_for_converse(tools: list[dict]) -> list[ConverseToolTypeDef]:
"""Convert CrewAI tools to Converse API format following AWS specification."""
from crewai.llms.providers.utils.common import safe_tool_conversion
converse_tools = []
converse_tools: list[ConverseToolTypeDef] = []
for tool in tools:
try:
name, description, parameters = safe_tool_conversion(tool, "Bedrock")
converse_tool = {
"toolSpec": {
"name": name,
"description": description,
}
tool_spec: ToolSpec = {
"name": name,
"description": description,
}
if parameters and isinstance(parameters, dict):
converse_tool["toolSpec"]["inputSchema"] = {"json": parameters}
input_schema: ToolInputSchema = {"json": parameters}
tool_spec["inputSchema"] = input_schema
converse_tool: ConverseToolTypeDef = {"toolSpec": tool_spec}
converse_tools.append(converse_tool)
@@ -460,9 +768,9 @@ class BedrockCompletion(BaseLLM):
return converse_tools
def _get_inference_config(self) -> dict[str, Any]:
def _get_inference_config(self) -> EnhancedInferenceConfigurationTypeDef:
"""Get inference configuration following AWS Converse API specification."""
config = {}
config: EnhancedInferenceConfigurationTypeDef = {}
if self.max_tokens:
config["maxTokens"] = self.max_tokens
@@ -503,7 +811,7 @@ class BedrockCompletion(BaseLLM):
return full_error_msg
def _track_token_usage_internal(self, usage: dict[str, Any]) -> None:
def _track_token_usage_internal(self, usage: TokenUsageTypeDef) -> None: # type: ignore[override]
"""Track token usage from Bedrock response."""
input_tokens = usage.get("inputTokens", 0)
output_tokens = usage.get("outputTokens", 0)

View File

@@ -0,0 +1,733 @@
import os
import sys
import types
from unittest.mock import patch, MagicMock
import pytest
from crewai.llm import LLM
from crewai.crew import Crew
from crewai.agent import Agent
from crewai.task import Task
@pytest.fixture(autouse=True)
def mock_aws_credentials():
    """Patch AWS env credentials and the boto3 ``Session`` for every test here.

    Yields:
        A ``(mock_session_class, mock_client)`` tuple: the patched ``Session``
        class and the ``MagicMock`` client handed out by its instances.
    """
    fake_env = {
        "AWS_ACCESS_KEY_ID": "test-access-key",
        "AWS_SECRET_ACCESS_KEY": "test-secret-key",
        "AWS_DEFAULT_REGION": "us-east-1",
    }
    session_target = 'crewai.llms.providers.bedrock.completion.Session'

    # Patching Session keeps the tests from opening real AWS connections.
    with patch.dict(os.environ, fake_env), patch(session_target) as mock_session_class:
        # Canned replies so un-stubbed API calls return immediately instead
        # of hanging.
        canned_reply = {
            'output': {
                'message': {
                    'role': 'assistant',
                    'content': [
                        {'text': 'Test response'},
                    ],
                },
            },
            'usage': {
                'inputTokens': 10,
                'outputTokens': 5,
                'totalTokens': 15,
            },
        }
        bedrock_client = MagicMock()
        bedrock_client.converse.return_value = canned_reply
        bedrock_client.converse_stream.return_value = {'stream': []}

        # Wiring: Session(...) -> session instance -> .client(...) -> client.
        session_instance = MagicMock()
        session_instance.client.return_value = bedrock_client
        mock_session_class.return_value = session_instance

        yield mock_session_class, bedrock_client
def test_bedrock_completion_is_used_when_bedrock_provider():
    """
    A 'bedrock/' model prefix must resolve to the native BedrockCompletion
    class, with provider and bare model id recorded on the instance.
    """
    bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    observed = (
        bedrock_llm.__class__.__name__,
        bedrock_llm.provider,
        bedrock_llm.model,
    )
    assert observed[0] == "BedrockCompletion"
    assert observed[1] == "bedrock"
    assert observed[2] == "anthropic.claude-3-5-sonnet-20241022-v2:0"
def test_bedrock_completion_module_is_imported():
    """
    Creating a Bedrock LLM must (re-)import the native completion module.
    """
    target_module = "crewai.llms.providers.bedrock.completion"

    # Evict any cached copy so the construction below has to import it again.
    sys.modules.pop(target_module, None)

    LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    assert target_module in sys.modules
    loaded = sys.modules[target_module]
    assert isinstance(loaded, types.ModuleType)

    # The provider class must be defined by that module.
    assert hasattr(loaded, 'BedrockCompletion')
def test_fallback_to_litellm_when_native_bedrock_fails():
    """
    When the native provider class raises on construction, LLM must fall
    back to LiteLLM.
    """

    class FailingCompletion:
        """Stand-in native provider whose constructor always fails."""

        def __init__(self, *args, **kwargs):
            raise Exception("Native AWS Bedrock SDK failed")

    provider_lookup = 'crewai.llm.LLM._get_native_provider'
    with patch(provider_lookup, return_value=FailingCompletion):
        fallback_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

        assert hasattr(fallback_llm, 'is_litellm')
        # Equality (not identity) is kept to match the original assertion.
        assert fallback_llm.is_litellm == True  # noqa: E712
def test_bedrock_completion_initialization_parameters():
    """
    Constructor keyword arguments must be stored on the BedrockCompletion
    instance unchanged.
    """
    from crewai.llms.providers.bedrock.completion import BedrockCompletion

    llm = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        temperature=0.7,
        max_tokens=2000,
        top_p=0.9,
        top_k=40,
        region_name="us-west-2",
    )
    assert isinstance(llm, BedrockCompletion)

    expected_attrs = {
        "model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
        "temperature": 0.7,
        "max_tokens": 2000,
        "top_p": 0.9,
        "top_k": 40,
        "region_name": "us-west-2",
    }
    for attr_name, expected in expected_attrs.items():
        assert getattr(llm, attr_name) == expected
def test_bedrock_specific_parameters():
    """
    Test Bedrock-specific parameters like stop_sequences and streaming.

    Verifies that ``stop_sequences``, ``stream`` and ``region_name`` passed
    to the LLM factory end up on the resulting BedrockCompletion instance.
    """
    llm = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        stop_sequences=["Human:", "Assistant:"],
        stream=True,
        region_name="us-east-1",
    )

    from crewai.llms.providers.bedrock.completion import BedrockCompletion

    assert isinstance(llm, BedrockCompletion)
    assert llm.stop_sequences == ["Human:", "Assistant:"]
    # The flag is stored as passed, so check identity rather than `== True`,
    # which would also accept truthy non-bool values such as 1.
    assert llm.stream is True
    assert llm.region_name == "us-east-1"
def test_bedrock_completion_call():
    """
    A patched ``call`` on a Bedrock LLM instance returns the stubbed reply
    and records the prompt it was given.
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    stubbed_reply = "Hello! I'm Claude on Bedrock, ready to help."
    prompt = "Hello, how are you?"

    # Replace `call` on this specific instance only.
    with patch.object(llm, 'call', return_value=stubbed_reply) as mock_call:
        reply = llm.call(prompt)

    assert reply == stubbed_reply
    mock_call.assert_called_once_with(prompt)
def test_bedrock_completion_called_during_crew_execution():
    """
    Running a crew whose agent uses a Bedrock LLM must invoke that LLM's
    ``call`` method, and the stubbed answer must reach the crew result.
    """
    # The instance is created first so the patch targets this exact object.
    bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    with patch.object(
        bedrock_llm, 'call', return_value="Tokyo has 14 million people."
    ) as mock_call:
        researcher = Agent(
            role="Research Assistant",
            goal="Find population info",
            backstory="You research populations.",
            llm=bedrock_llm,
        )
        population_task = Task(
            description="Find Tokyo population",
            expected_output="Population number",
            agent=researcher,
        )
        outcome = Crew(agents=[researcher], tasks=[population_task]).kickoff()

        assert mock_call.called
        assert "14 million" in str(outcome)
@pytest.mark.skip(reason="Crew execution test - may hang, needs investigation")
def test_bedrock_completion_call_arguments():
    """
    ``call`` must receive the messages as its first positional argument,
    and those messages must mention the task description.
    """
    bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    with patch.object(
        bedrock_llm, 'call', return_value="Task completed successfully."
    ) as mock_call:
        agent = Agent(
            role="Test Agent",
            goal="Complete a simple task",
            backstory="You are a test agent.",
            llm=bedrock_llm,  # same instance that was patched
        )
        task = Task(
            description="Say hello world",
            expected_output="Hello world",
            agent=agent,
        )
        Crew(agents=[agent], tasks=[task]).kickoff()

        assert mock_call.called

        recorded = mock_call.call_args
        assert recorded is not None

        # First positional argument carries the messages.
        messages = recorded[0][0]
        assert isinstance(messages, (str, list))

        # Flatten to text either way, then look for the task description.
        searchable = messages if isinstance(messages, str) else str(messages)
        assert "hello world" in searchable.lower()
def test_multiple_bedrock_calls_in_crew():
    """
    A crew with two tasks must invoke the Bedrock LLM's ``call`` at least
    twice, each time with a non-None first positional argument.
    """
    bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    with patch.object(bedrock_llm, 'call', return_value="Task completed.") as mock_call:
        agent = Agent(
            role="Multi-task Agent",
            goal="Complete multiple tasks",
            backstory="You can handle multiple tasks.",
            llm=bedrock_llm,  # same instance that was patched
        )
        tasks = [
            Task(
                description="First task",
                expected_output="First result",
                agent=agent,
            ),
            Task(
                description="Second task",
                expected_output="Second result",
                agent=agent,
            ),
        ]
        Crew(agents=[agent], tasks=tasks).kickoff()

        # At least one LLM call per task.
        assert mock_call.call_count >= 2

        for recorded in mock_call.call_args_list:
            positional = recorded[0]
            assert len(positional) > 0  # has positional arguments
            assert positional[0] is not None
def test_bedrock_completion_with_tools():
    """
    When the agent has tools, ``call`` must be invoked; if a ``tools``
    keyword was passed to it, that value must be a non-empty collection.
    """
    from crewai.tools import tool

    @tool
    def sample_tool(query: str) -> str:
        """A sample tool for testing"""
        return f"Tool result for: {query}"

    bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    with patch.object(
        bedrock_llm, 'call', return_value="Task completed with tools."
    ) as mock_call:
        agent = Agent(
            role="Tool User",
            goal="Use tools to complete tasks",
            backstory="You can use tools.",
            llm=bedrock_llm,  # same instance that was patched
            tools=[sample_tool],
        )
        task = Task(
            description="Use the sample tool",
            expected_output="Tool usage result",
            agent=agent,
        )
        Crew(agents=[agent], tasks=[task]).kickoff()

        assert mock_call.called

        # Keyword arguments of the most recent call.
        passed_kwargs = mock_call.call_args.kwargs

        # The tools check is conditional: it only applies when the caller
        # actually supplied a `tools` keyword.
        if 'tools' in passed_kwargs:
            supplied_tools = passed_kwargs['tools']
            assert supplied_tools is not None
            assert len(supplied_tools) > 0
def test_bedrock_raises_error_when_model_not_found(mock_aws_credentials):
    """A ClientError from ``converse`` must surface as an exception from ``call``."""
    from botocore.exceptions import ClientError

    # Second element of the fixture tuple is the mocked Bedrock client.
    mock_client = mock_aws_credentials[1]

    not_found = ClientError(
        {
            'Error': {
                'Code': 'ResourceNotFoundException',
                'Message': 'Could not resolve the foundation model from the model identifier',
            }
        },
        'converse',
    )
    mock_client.converse.side_effect = not_found

    llm = LLM(model="bedrock/model-doesnt-exist")

    # Any exception type is accepted here.
    with pytest.raises(Exception):
        llm.call("Hello")
def test_bedrock_aws_credentials_configuration():
    """
    Credentials and region can come from environment variables or from
    explicit constructor arguments.
    """
    from crewai.llms.providers.bedrock.completion import BedrockCompletion

    env_credentials = {
        "AWS_ACCESS_KEY_ID": "test-access-key",
        "AWS_SECRET_ACCESS_KEY": "test-secret-key",
        "AWS_DEFAULT_REGION": "us-east-1",
    }

    # Case 1: everything is taken from the environment.
    with patch.dict(os.environ, env_credentials):
        env_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

        assert isinstance(env_llm, BedrockCompletion)
        assert env_llm.region_name == "us-east-1"

    # Case 2: explicit arguments are supplied to the constructor.
    explicit_llm = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        aws_access_key_id="explicit-key",
        aws_secret_access_key="explicit-secret",
        region_name="us-west-2",
    )
    assert isinstance(explicit_llm, BedrockCompletion)
    assert explicit_llm.region_name == "us-west-2"
def test_bedrock_model_capabilities():
    """
    Test that model capabilities are correctly identified.

    A Claude model id must set ``is_claude_model``; both Claude and
    non-Claude Bedrock models must report ``supports_tools``.
    """
    from crewai.llms.providers.bedrock.completion import BedrockCompletion

    # Claude model
    llm_claude = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
    assert isinstance(llm_claude, BedrockCompletion)
    # Both flags are real bools, so use identity checks instead of `== True`,
    # which would also accept truthy non-bool values.
    assert llm_claude.is_claude_model is True
    assert llm_claude.supports_tools is True

    # Other Bedrock model
    llm_titan = LLM(model="bedrock/amazon.titan-text-express-v1")
    assert isinstance(llm_titan, BedrockCompletion)
    assert llm_titan.supports_tools is True
def test_bedrock_inference_config():
    """
    ``_get_inference_config`` must expose the sampling parameters under
    their Converse API key names.
    """
    from crewai.llms.providers.bedrock.completion import BedrockCompletion

    llm = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        temperature=0.7,
        top_p=0.9,
        top_k=40,
        max_tokens=1000,
    )
    assert isinstance(llm, BedrockCompletion)

    inference_config = llm._get_inference_config()

    # Converse API key -> value supplied to the constructor above.
    expected_entries = {
        'temperature': 0.7,
        'topP': 0.9,
        'maxTokens': 1000,
        'topK': 40,
    }
    for api_key, expected in expected_entries.items():
        assert api_key in inference_config
        assert inference_config[api_key] == expected
def test_bedrock_model_detection():
    """
    Every 'bedrock/'-prefixed model id must be routed to BedrockCompletion,
    whatever the vendor.
    """
    from crewai.llms.providers.bedrock.completion import BedrockCompletion

    # Representative Bedrock model ids from several vendors.
    bedrock_model_ids = (
        "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
        "bedrock/amazon.titan-text-express-v1",
        "bedrock/meta.llama3-70b-instruct-v1:0",
    )

    for model_name in bedrock_model_ids:
        candidate = LLM(model=model_name)
        assert isinstance(candidate, BedrockCompletion), f"Failed for model: {model_name}"
def test_bedrock_supports_stop_words():
    """
    A Bedrock LLM must report that it supports stop sequences.
    """
    bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    stop_word_support = bedrock_llm.supports_stop_words()
    # Equality (not identity) is kept to match the original assertion.
    assert stop_word_support == True  # noqa: E712
def test_bedrock_context_window_size():
    """
    Reported context window sizes must exceed a per-model lower bound.
    """
    # (model id, strict lower bound on the reported context window size)
    lower_bounds = (
        ("bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", 150000),
        ("bedrock/amazon.titan-text-express-v1", 5000),
    )

    for model_name, minimum in lower_bounds:
        reported = LLM(model=model_name).get_context_window_size()
        assert reported > minimum
def test_bedrock_message_formatting():
    """
    Chat messages must be converted to Converse format, with the system
    prompt split off from the message list.
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"},
    ]

    converse_messages, system_prompt = llm._format_messages_for_converse(conversation)

    # The system entry is returned separately from the message list.
    assert system_prompt == "You are a helpful assistant."

    # The user / assistant / user turns remain in the list.
    assert len(converse_messages) >= 3

    first_turn, second_turn = converse_messages[0], converse_messages[1]
    assert first_turn["role"] == "user"
    assert second_turn["role"] == "assistant"

    # Content is a list of blocks, the first of which carries text.
    first_content = first_turn["content"]
    assert isinstance(first_content, list)
    assert "text" in first_content[0]
def test_bedrock_streaming_parameter():
    """
    Test that streaming parameter is properly handled.

    The ``stream`` flag passed to the constructor must be stored on the
    instance as the same boolean.
    """
    model_name = "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0"

    # The flag is stored as passed, so assert identity with the bool
    # singletons; `== False` / `== True` would also accept 0 / 1.
    llm_no_stream = LLM(model=model_name, stream=False)
    assert llm_no_stream.stream is False

    llm_stream = LLM(model=model_name, stream=True)
    assert llm_stream.stream is True
def test_bedrock_tool_conversion():
    """
    A CrewAI function-style tool must be converted to a Converse
    ``toolSpec`` entry with name, description and input schema.
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Tool description in CrewAI (function-calling) format.
    parameters_schema = {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "Search query"}
        },
        "required": ["query"],
    }
    crewai_tools = [
        {
            "type": "function",
            "function": {
                "name": "test_tool",
                "description": "A test tool",
                "parameters": parameters_schema,
            },
        }
    ]

    converted = llm._format_tools_for_converse(crewai_tools)
    assert len(converted) == 1

    only_tool = converted[0]
    assert "toolSpec" in only_tool

    spec = only_tool["toolSpec"]
    assert spec["name"] == "test_tool"
    assert spec["description"] == "A test tool"
    assert "inputSchema" in spec
def test_bedrock_environment_variable_credentials(mock_aws_credentials):
    """
    Check that AWS credentials set in the environment are handed to the
    mocked Session when an LLM is constructed.

    ``mock_aws_credentials`` is a fixture yielding a pair whose first item is
    the mocked Session class; the second item is not used here.
    """
    session_cls, _ = mock_aws_credentials

    # Drop any calls recorded before this test so call_args reflects only
    # the construction below
    session_cls.reset_mock()

    env_credentials = {
        "AWS_ACCESS_KEY_ID": "test-access-key-123",
        "AWS_SECRET_ACCESS_KEY": "test-secret-key-456",
    }
    with patch.dict(os.environ, env_credentials):
        llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

        # Session must have been instantiated at least once
        assert session_cls.called

        # Keyword arguments of the most recent Session(...) call; index 1 of
        # call_args is the kwargs dict
        last_call = session_cls.call_args
        if last_call:
            session_kwargs = last_call[1]
        else:
            session_kwargs = {}

        assert session_kwargs.get('aws_access_key_id') == "test-access-key-123"
        assert session_kwargs.get('aws_secret_access_key') == "test-secret-key-456"
def test_bedrock_token_usage_tracking():
    """
    Check that usage figures in a Converse response end up in
    ``llm._token_usage`` and that the response text is returned.
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Canned Converse response: one text block plus usage information
    assistant_message = {
        'role': 'assistant',
        'content': [
            {'text': 'test response'},
        ],
    }
    usage = {
        'inputTokens': 50,
        'outputTokens': 25,
        'totalTokens': 75,
    }
    canned_response = {
        'output': {'message': assistant_message},
        'usage': usage,
    }

    with patch.object(llm.client, 'converse') as mock_converse:
        mock_converse.return_value = canned_response

        reply = llm.call("Hello")

        # The text block is what the caller gets back
        assert reply == "test response"

        # The usage section is mirrored into the LLM's token counters
        tracked = llm._token_usage
        assert tracked['prompt_tokens'] == 50
        assert tracked['completion_tokens'] == 25
        assert tracked['total_tokens'] == 75
def test_bedrock_tool_use_conversation_flow():
    """
    Test that the Bedrock completion properly handles tool use conversation flow.

    The mocked ``converse`` client answers first with a ``toolUse`` block and
    then with a plain text message. The test expects ``llm.call`` to issue
    both requests and to return text mentioning the weather result.
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Mock tool function
    def mock_weather_tool(location: str) -> str:
        return f"The weather in {location} is sunny and 75°F"

    # Keyed by the tool name used in the toolUse block below
    available_functions = {"get_weather": mock_weather_tool}

    # Mock the Bedrock client responses
    with patch.object(llm.client, 'converse') as mock_converse:
        # First response: tool use request
        tool_use_response = {
            'output': {
                'message': {
                    'role': 'assistant',
                    'content': [
                        {
                            'toolUse': {
                                'toolUseId': 'tool-123',
                                'name': 'get_weather',
                                'input': {'location': 'San Francisco'}
                            }
                        }
                    ]
                }
            },
            'usage': {
                'inputTokens': 100,
                'outputTokens': 50,
                'totalTokens': 150
            }
        }

        # Second response: final answer after tool execution
        final_response = {
            'output': {
                'message': {
                    'role': 'assistant',
                    'content': [
                        {'text': 'Based on the weather data, it is sunny and 75°F in San Francisco.'}
                    ]
                }
            },
            'usage': {
                'inputTokens': 120,
                'outputTokens': 30,
                'totalTokens': 150
            }
        }

        # Configure mock to return different responses on successive calls
        mock_converse.side_effect = [tool_use_response, final_response]

        # Test the call
        messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}]
        result = llm.call(
            messages=messages,
            available_functions=available_functions
        )

        # Verify the final response contains the weather information
        assert "sunny" in result.lower() or "75" in result

        # Verify that the API was called twice (once for tool use, once for final answer)
        assert mock_converse.call_count == 2
def test_bedrock_handles_cohere_conversation_requirements():
    """
    Check that, for a Cohere model, a conversation ending on an assistant
    turn is formatted so that the last message is a user message.
    """
    llm = LLM(model="bedrock/cohere.command-r-plus-v1:0")

    # Conversation that deliberately ends with the assistant speaking
    conversation = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]

    # The extracted system message is not inspected by this test
    converse_messages, _ = llm._format_messages_for_converse(conversation)

    # For Cohere models a user message should have been added at the end
    trailing = converse_messages[-1]
    assert trailing["role"] == "user"

    trailing_text = trailing["content"][0]["text"]
    assert "continue" in trailing_text.lower()
def test_bedrock_client_error_handling():
    """
    Check how ``llm.call`` surfaces AWS ``ClientError`` failures raised by the
    mocked ``converse`` client: a ValidationException is expected as a
    ``ValueError`` and a ThrottlingException as a ``RuntimeError``.
    """
    from botocore.exceptions import ClientError

    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Each scenario: AWS error code, AWS error message, exception type
    # expected from llm.call, and the lowercase fragment expected in the
    # raised exception's text. Order matters: validation first, then
    # throttling.
    scenarios = (
        ('ValidationException', 'Invalid request format', ValueError, "validation"),
        ('ThrottlingException', 'Rate limit exceeded', RuntimeError, "throttled"),
    )

    for error_code, error_message, expected_exception, fragment in scenarios:
        # A fresh patch per scenario so side effects do not leak between them
        with patch.object(llm.client, 'converse') as mock_converse:
            aws_error = {
                'Error': {
                    'Code': error_code,
                    'Message': error_message,
                },
            }
            mock_converse.side_effect = ClientError(aws_error, 'converse')

            with pytest.raises(expected_exception) as exc_info:
                llm.call("Hello")

            raised_text = str(exc_info.value)
            assert fragment in raised_text.lower()

View File

@@ -25,6 +25,7 @@ dev = [
"types-pyyaml==6.0.*",
"types-regex==2024.11.6.*",
"types-appdirs==1.4.*",
"boto3-stubs[bedrock-runtime]>=1.40.54",
]

62
uv.lock generated
View File

@@ -38,6 +38,7 @@ members = [
[manifest.dependency-groups]
dev = [
{ name = "bandit", specifier = ">=1.8.6" },
{ name = "boto3-stubs", extras = ["bedrock-runtime"], specifier = ">=1.40.54" },
{ name = "mypy", specifier = ">=1.18.2" },
{ name = "pre-commit", specifier = ">=4.3.0" },
{ name = "pytest", specifier = ">=8.4.2" },
@@ -479,6 +480,25 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8e/db/7d3c27f530c2b354d546ad7fb94505be8b78a5ecabe34c6a1f9a9d6be03e/boto3-1.40.45-py3-none-any.whl", hash = "sha256:5b145752d20f29908e3cb8c823bee31c77e6bcf18787e570f36bbc545cc779ed", size = 139345, upload-time = "2025-10-03T19:32:11.145Z" },
]
[[package]]
name = "boto3-stubs"
version = "1.40.54"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "botocore-stubs" },
{ name = "types-s3transfer" },
{ name = "typing-extensions", marker = "python_full_version < '3.12'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e2/70/245477b7f07c9e1533c47fa69e611b172814423a6fd4637004f0d2a13b73/boto3_stubs-1.40.54.tar.gz", hash = "sha256:e21a9eda979a451935eb3196de3efbe15b9470e6bf9027406d1f6d0ac08b339e", size = 100919, upload-time = "2025-10-16T19:49:17.079Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9d/52/ee9dadd1cc8911e16f18ca9fa036a10328e0a0d3fddd54fadcc1ca0f9143/boto3_stubs-1.40.54-py3-none-any.whl", hash = "sha256:548a4786785ba7b43ef4ef1a2a764bebbb0301525f3201091fcf412e4c8ce323", size = 69712, upload-time = "2025-10-16T19:49:12.847Z" },
]
[package.optional-dependencies]
bedrock-runtime = [
{ name = "mypy-boto3-bedrock-runtime" },
]
[[package]]
name = "botocore"
version = "1.40.45"
@@ -494,6 +514,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/af/06/df47e2ecb74bd184c9d056666afd3db011a649eaca663337835a6dd5aee6/botocore-1.40.45-py3-none-any.whl", hash = "sha256:9abf473d8372ade8442c0d4634a9decb89c854d7862ffd5500574eb63ab8f240", size = 14063670, upload-time = "2025-10-03T19:31:58.999Z" },
]
[[package]]
name = "botocore-stubs"
version = "1.40.54"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "types-awscrt" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ea/c0/3e78314f9baa850aae648fb6b2506748046e1c3e374d6bb3514478e34590/botocore_stubs-1.40.54.tar.gz", hash = "sha256:fb38a794ab2b896f9cc237ec725546746accaffd34f382475a8d1b98ca1078e1", size = 42225, upload-time = "2025-10-16T20:26:56.711Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2f/9f/ab316f57a7e32d4a5b790070ffa5986991098044897b08f1b65951bced2a/botocore_stubs-1.40.54-py3-none-any.whl", hash = "sha256:997e6f1c03e079c244caedf315f7a515a07480af9f93f53535e506f17cdbe880", size = 66542, upload-time = "2025-10-16T20:26:54.109Z" },
]
[[package]]
name = "browserbase"
version = "1.4.0"
@@ -4055,6 +4087,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/87/e3/be76d87158ebafa0309946c4a73831974d4d6ab4f4ef40c3b53a385a66fd/mypy-1.18.2-py3-none-any.whl", hash = "sha256:22a1748707dd62b58d2ae53562ffc4d7f8bcc727e8ac7cbc69c053ddc874d47e", size = 2352367, upload-time = "2025-09-19T00:10:15.489Z" },
]
[[package]]
name = "mypy-boto3-bedrock-runtime"
version = "1.40.41"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.12'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c7/38/79989f7bce998776ed1a01c17f3f58e7bc6f5fc2bcbdff929701526fa2f1/mypy_boto3_bedrock_runtime-1.40.41.tar.gz", hash = "sha256:ee9bda6d6d478c8d0995e84e884bdf1798e150d437974ae27c175774a58ffaa5", size = 28333, upload-time = "2025-09-29T19:26:04.804Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3d/6c/d3431dadf473bb76aa590b1ed8cc91726a48b029b542eff9d3024f2d70b9/mypy_boto3_bedrock_runtime-1.40.41-py3-none-any.whl", hash = "sha256:d65dff200986ff06c6b3579ddcea102555f2067c8987fca379bf4f9ed8ba3121", size = 34181, upload-time = "2025-09-29T19:26:01.898Z" },
]
[[package]]
name = "mypy-extensions"
version = "1.1.0"
@@ -8181,6 +8225,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/cf/07/41f5b9b11f11855eb67760ed680330e0ce9136a44b51c24dd52edb1c4eb1/types_appdirs-1.4.3.5-py3-none-any.whl", hash = "sha256:337c750e423c40911d389359b4edabe5bbc2cdd5cd0bd0518b71d2839646273b", size = 2667, upload-time = "2023-03-14T15:21:32.431Z" },
]
[[package]]
name = "types-awscrt"
version = "0.28.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/60/19/a3a6377c9e2e389c1421c033a1830c29cac08f2e1e05a082ea84eb22c75f/types_awscrt-0.28.1.tar.gz", hash = "sha256:66d77ec283e1dc907526a44511a12624118723a396c36d3f3dd9855cb614ce14", size = 17410, upload-time = "2025-10-11T21:55:07.443Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ea/c7/0266b797d19b82aebe0e177efe35de7aabdc192bc1605ce3309331f0a505/types_awscrt-0.28.1-py3-none-any.whl", hash = "sha256:d88f43ef779f90b841ba99badb72fe153077225a4e426ae79e943184827b4443", size = 41851, upload-time = "2025-10-11T21:55:06.235Z" },
]
[[package]]
name = "types-pyyaml"
version = "6.0.12.20250915"
@@ -8251,6 +8304,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8b/ea/91b718b8c0b88e4f61cdd61357cc4a1f8767b32be691fb388299003a3ae3/types_requests-2.31.0.20240406-py3-none-any.whl", hash = "sha256:6216cdac377c6b9a040ac1c0404f7284bd13199c0e1bb235f4324627e8898cf5", size = 15347, upload-time = "2024-04-06T02:13:37.412Z" },
]
[[package]]
name = "types-s3transfer"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8e/9b/8913198b7fc700acc1dcb84827137bb2922052e43dde0f4fb0ed2dc6f118/types_s3transfer-0.14.0.tar.gz", hash = "sha256:17f800a87c7eafab0434e9d87452c809c290ae906c2024c24261c564479e9c95", size = 14218, upload-time = "2025-10-11T21:11:27.892Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/92/c3/4dfb2e87c15ca582b7d956dfb7e549de1d005c758eb9a305e934e1b83fda/types_s3transfer-0.14.0-py3-none-any.whl", hash = "sha256:108134854069a38b048e9b710b9b35904d22a9d0f37e4e1889c2e6b58e5b3253", size = 19697, upload-time = "2025-10-11T21:11:26.749Z" },
]
[[package]]
name = "types-urllib3"
version = "1.26.25.14"