mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Fix Bedrock cross-region inference profile support
- Add support for a `model_id` parameter in BedrockCompletion.
- When `model_id` is provided, use it instead of `model` for API calls.
- This fixes issue #3791, where cross-region inference profiles (which require an ARN as `model_id`) were not working.
- Add comprehensive tests for cross-region inference profiles.
- All existing Bedrock tests pass.

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -175,8 +175,11 @@ class BedrockCompletion(BaseLLM):
|
|||||||
guardrail_config: Guardrail configuration for content filtering
|
guardrail_config: Guardrail configuration for content filtering
|
||||||
additional_model_request_fields: Model-specific request parameters
|
additional_model_request_fields: Model-specific request parameters
|
||||||
additional_model_response_field_paths: Custom response field paths
|
additional_model_response_field_paths: Custom response field paths
|
||||||
**kwargs: Additional parameters
|
**kwargs: Additional parameters (including model_id for cross-region inference)
|
||||||
"""
|
"""
|
||||||
|
# Extract model_id from kwargs if provided (for cross-region inference profiles)
|
||||||
|
custom_model_id = kwargs.pop("model_id", None)
|
||||||
|
|
||||||
# Extract provider from kwargs to avoid duplicate argument
|
# Extract provider from kwargs to avoid duplicate argument
|
||||||
kwargs.pop("provider", None)
|
kwargs.pop("provider", None)
|
||||||
|
|
||||||
@@ -230,7 +233,7 @@ class BedrockCompletion(BaseLLM):
|
|||||||
self.supports_streaming = True
|
self.supports_streaming = True
|
||||||
|
|
||||||
# Handle inference profiles for newer models
|
# Handle inference profiles for newer models
|
||||||
self.model_id = model
|
self.model_id = custom_model_id if custom_model_id else model
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ def mock_aws_credentials():
|
|||||||
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
|
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
|
||||||
"AWS_DEFAULT_REGION": "us-east-1"
|
"AWS_DEFAULT_REGION": "us-east-1"
|
||||||
}):
|
}):
|
||||||
|
import crewai.llms.providers.bedrock.completion
|
||||||
# Mock boto3 Session to prevent actual AWS connections
|
# Mock boto3 Session to prevent actual AWS connections
|
||||||
with patch('crewai.llms.providers.bedrock.completion.Session') as mock_session_class:
|
with patch('crewai.llms.providers.bedrock.completion.Session') as mock_session_class:
|
||||||
# Create mock session instance
|
# Create mock session instance
|
||||||
@@ -736,3 +737,76 @@ def test_bedrock_client_error_handling():
|
|||||||
with pytest.raises(RuntimeError) as exc_info:
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
llm.call("Hello")
|
llm.call("Hello")
|
||||||
assert "throttled" in str(exc_info.value).lower()
|
assert "throttled" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_bedrock_cross_region_inference_profile():
|
||||||
|
"""
|
||||||
|
Test that Bedrock supports cross-region inference profiles with model_id parameter.
|
||||||
|
|
||||||
|
This tests the fix for issue #3791 where cross-region inference profiles
|
||||||
|
(which require using ARN as model_id) were not working in version 1.20.0.
|
||||||
|
|
||||||
|
When using cross-region inference profiles, users need to:
|
||||||
|
1. Set model to the base model name (e.g., "bedrock/anthropic.claude-sonnet-4-20250514-v1:0")
|
||||||
|
2. Set model_id to the inference profile ARN
|
||||||
|
|
||||||
|
The BedrockCompletion should use the model_id parameter when provided,
|
||||||
|
not the model parameter, for the actual API call.
|
||||||
|
"""
|
||||||
|
# Test with cross-region inference profile ARN
|
||||||
|
inference_profile_arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0"
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="bedrock/anthropic.claude-sonnet-4-20250514-v1:0",
|
||||||
|
model_id=inference_profile_arn,
|
||||||
|
temperature=0.3,
|
||||||
|
max_tokens=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
from crewai.llms.providers.bedrock.completion import BedrockCompletion
|
||||||
|
assert isinstance(llm, BedrockCompletion)
|
||||||
|
|
||||||
|
assert llm.model_id == inference_profile_arn
|
||||||
|
assert llm.model == "anthropic.claude-sonnet-4-20250514-v1:0"
|
||||||
|
|
||||||
|
# Verify that the client.converse call would use the correct model_id
|
||||||
|
with patch.object(llm.client, 'converse') as mock_converse:
|
||||||
|
mock_converse.return_value = {
|
||||||
|
'output': {
|
||||||
|
'message': {
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': [{'text': 'Test response'}]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'usage': {
|
||||||
|
'inputTokens': 10,
|
||||||
|
'outputTokens': 5,
|
||||||
|
'totalTokens': 15
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llm.call("Test message")
|
||||||
|
|
||||||
|
# Verify the converse call was made with the inference profile ARN
|
||||||
|
mock_converse.assert_called_once()
|
||||||
|
call_kwargs = mock_converse.call_args[1]
|
||||||
|
assert call_kwargs['modelId'] == inference_profile_arn
|
||||||
|
|
||||||
|
|
||||||
|
def test_bedrock_model_id_parameter_takes_precedence():
|
||||||
|
"""
|
||||||
|
Test that when both model and model_id are provided, model_id takes precedence
|
||||||
|
for the actual API call, while model is used for internal identification.
|
||||||
|
"""
|
||||||
|
custom_model_id = "custom-model-identifier"
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
model_id=custom_model_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
from crewai.llms.providers.bedrock.completion import BedrockCompletion
|
||||||
|
assert isinstance(llm, BedrockCompletion)
|
||||||
|
|
||||||
|
assert llm.model_id == custom_model_id
|
||||||
|
assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||||
|
|||||||
Reference in New Issue
Block a user