Fix Bedrock cross-region inference profile support

- Add support for model_id parameter in BedrockCompletion
- When model_id is provided, use it instead of model for API calls
- This fixes issue #3791 where cross-region inference profiles
  (which require ARN as model_id) were not working
- Add comprehensive tests for cross-region inference profiles
- All existing bedrock tests pass

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-10-24 18:21:38 +00:00
parent a83c57a2f2
commit 143d7b88a9
2 changed files with 79 additions and 2 deletions

View File

@@ -175,8 +175,11 @@ class BedrockCompletion(BaseLLM):
guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
**kwargs: Additional parameters
**kwargs: Additional parameters (including model_id for cross-region inference)
"""
# Extract model_id from kwargs if provided (for cross-region inference profiles)
custom_model_id = kwargs.pop("model_id", None)
# Extract provider from kwargs to avoid duplicate argument
kwargs.pop("provider", None)
@@ -230,7 +233,7 @@ class BedrockCompletion(BaseLLM):
self.supports_streaming = True
# Handle inference profiles for newer models
self.model_id = model
self.model_id = custom_model_id if custom_model_id else model
def call(
self,

View File

@@ -18,6 +18,7 @@ def mock_aws_credentials():
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
"AWS_DEFAULT_REGION": "us-east-1"
}):
import crewai.llms.providers.bedrock.completion
# Mock boto3 Session to prevent actual AWS connections
with patch('crewai.llms.providers.bedrock.completion.Session') as mock_session_class:
# Create mock session instance
@@ -736,3 +737,76 @@ def test_bedrock_client_error_handling():
with pytest.raises(RuntimeError) as exc_info:
llm.call("Hello")
assert "throttled" in str(exc_info.value).lower()
def test_bedrock_cross_region_inference_profile():
"""
Test that Bedrock supports cross-region inference profiles with model_id parameter.
This tests the fix for issue #3791 where cross-region inference profiles
(which require using ARN as model_id) were not working in version 1.20.0.
When using cross-region inference profiles, users need to:
1. Set model to the base model name (e.g., "bedrock/anthropic.claude-sonnet-4-20250514-v1:0")
2. Set model_id to the inference profile ARN
The BedrockCompletion should use the model_id parameter when provided,
not the model parameter, for the actual API call.
"""
# Test with cross-region inference profile ARN
inference_profile_arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0"
llm = LLM(
model="bedrock/anthropic.claude-sonnet-4-20250514-v1:0",
model_id=inference_profile_arn,
temperature=0.3,
max_tokens=4000,
)
from crewai.llms.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion)
assert llm.model_id == inference_profile_arn
assert llm.model == "anthropic.claude-sonnet-4-20250514-v1:0"
# Verify that the client.converse call would use the correct model_id
with patch.object(llm.client, 'converse') as mock_converse:
mock_converse.return_value = {
'output': {
'message': {
'role': 'assistant',
'content': [{'text': 'Test response'}]
}
},
'usage': {
'inputTokens': 10,
'outputTokens': 5,
'totalTokens': 15
}
}
llm.call("Test message")
# Verify the converse call was made with the inference profile ARN
mock_converse.assert_called_once()
call_kwargs = mock_converse.call_args[1]
assert call_kwargs['modelId'] == inference_profile_arn
def test_bedrock_model_id_parameter_takes_precedence():
"""
Test that when both model and model_id are provided, model_id takes precedence
for the actual API call, while model is used for internal identification.
"""
custom_model_id = "custom-model-identifier"
llm = LLM(
model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
model_id=custom_model_id,
)
from crewai.llms.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion)
assert llm.model_id == custom_model_id
assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0"