mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Fix Bedrock cross-region inference profile support
- Add support for model_id parameter in BedrockCompletion - When model_id is provided, use it instead of model for API calls - This fixes issue #3791 where cross-region inference profiles (which require ARN as model_id) were not working - Add comprehensive tests for cross-region inference profiles - All existing bedrock tests pass Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -175,8 +175,11 @@ class BedrockCompletion(BaseLLM):
|
||||
guardrail_config: Guardrail configuration for content filtering
|
||||
additional_model_request_fields: Model-specific request parameters
|
||||
additional_model_response_field_paths: Custom response field paths
|
||||
**kwargs: Additional parameters
|
||||
**kwargs: Additional parameters (including model_id for cross-region inference)
|
||||
"""
|
||||
# Extract model_id from kwargs if provided (for cross-region inference profiles)
|
||||
custom_model_id = kwargs.pop("model_id", None)
|
||||
|
||||
# Extract provider from kwargs to avoid duplicate argument
|
||||
kwargs.pop("provider", None)
|
||||
|
||||
@@ -230,7 +233,7 @@ class BedrockCompletion(BaseLLM):
|
||||
self.supports_streaming = True
|
||||
|
||||
# Handle inference profiles for newer models
|
||||
self.model_id = model
|
||||
self.model_id = custom_model_id if custom_model_id else model
|
||||
|
||||
def call(
|
||||
self,
|
||||
|
||||
@@ -18,6 +18,7 @@ def mock_aws_credentials():
|
||||
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
|
||||
"AWS_DEFAULT_REGION": "us-east-1"
|
||||
}):
|
||||
import crewai.llms.providers.bedrock.completion
|
||||
# Mock boto3 Session to prevent actual AWS connections
|
||||
with patch('crewai.llms.providers.bedrock.completion.Session') as mock_session_class:
|
||||
# Create mock session instance
|
||||
@@ -736,3 +737,76 @@ def test_bedrock_client_error_handling():
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
llm.call("Hello")
|
||||
assert "throttled" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
def test_bedrock_cross_region_inference_profile():
|
||||
"""
|
||||
Test that Bedrock supports cross-region inference profiles with model_id parameter.
|
||||
|
||||
This tests the fix for issue #3791 where cross-region inference profiles
|
||||
(which require using ARN as model_id) were not working in version 1.20.0.
|
||||
|
||||
When using cross-region inference profiles, users need to:
|
||||
1. Set model to the base model name (e.g., "bedrock/anthropic.claude-sonnet-4-20250514-v1:0")
|
||||
2. Set model_id to the inference profile ARN
|
||||
|
||||
The BedrockCompletion should use the model_id parameter when provided,
|
||||
not the model parameter, for the actual API call.
|
||||
"""
|
||||
# Test with cross-region inference profile ARN
|
||||
inference_profile_arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
|
||||
llm = LLM(
|
||||
model="bedrock/anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
model_id=inference_profile_arn,
|
||||
temperature=0.3,
|
||||
max_tokens=4000,
|
||||
)
|
||||
|
||||
from crewai.llms.providers.bedrock.completion import BedrockCompletion
|
||||
assert isinstance(llm, BedrockCompletion)
|
||||
|
||||
assert llm.model_id == inference_profile_arn
|
||||
assert llm.model == "anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
|
||||
# Verify that the client.converse call would use the correct model_id
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
mock_converse.return_value = {
|
||||
'output': {
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': [{'text': 'Test response'}]
|
||||
}
|
||||
},
|
||||
'usage': {
|
||||
'inputTokens': 10,
|
||||
'outputTokens': 5,
|
||||
'totalTokens': 15
|
||||
}
|
||||
}
|
||||
|
||||
llm.call("Test message")
|
||||
|
||||
# Verify the converse call was made with the inference profile ARN
|
||||
mock_converse.assert_called_once()
|
||||
call_kwargs = mock_converse.call_args[1]
|
||||
assert call_kwargs['modelId'] == inference_profile_arn
|
||||
|
||||
|
||||
def test_bedrock_model_id_parameter_takes_precedence():
|
||||
"""
|
||||
Test that when both model and model_id are provided, model_id takes precedence
|
||||
for the actual API call, while model is used for internal identification.
|
||||
"""
|
||||
custom_model_id = "custom-model-identifier"
|
||||
|
||||
llm = LLM(
|
||||
model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
model_id=custom_model_id,
|
||||
)
|
||||
|
||||
from crewai.llms.providers.bedrock.completion import BedrockCompletion
|
||||
assert isinstance(llm, BedrockCompletion)
|
||||
|
||||
assert llm.model_id == custom_model_id
|
||||
assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
|
||||
Reference in New Issue
Block a user