diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py
index 4607316c5..3e71bcddb 100644
--- a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py
@@ -175,8 +175,11 @@ class BedrockCompletion(BaseLLM):
             guardrail_config: Guardrail configuration for content filtering
             additional_model_request_fields: Model-specific request parameters
             additional_model_response_field_paths: Custom response field paths
-            **kwargs: Additional parameters
+            **kwargs: Additional parameters (including model_id for cross-region inference)
         """
+        # Extract model_id from kwargs if provided (for cross-region inference profiles)
+        custom_model_id = kwargs.pop("model_id", None)
+
         # Extract provider from kwargs to avoid duplicate argument
         kwargs.pop("provider", None)
 
@@ -230,7 +233,7 @@ class BedrockCompletion(BaseLLM):
         self.supports_streaming = True
 
         # Handle inference profiles for newer models
-        self.model_id = model
+        self.model_id = custom_model_id if custom_model_id else model
 
     def call(
         self,
diff --git a/lib/crewai/tests/llms/bedrock/test_bedrock.py b/lib/crewai/tests/llms/bedrock/test_bedrock.py
index 9fd172cc6..36e0434cb 100644
--- a/lib/crewai/tests/llms/bedrock/test_bedrock.py
+++ b/lib/crewai/tests/llms/bedrock/test_bedrock.py
@@ -18,6 +18,7 @@ def mock_aws_credentials():
         "AWS_SECRET_ACCESS_KEY": "test-secret-key",
         "AWS_DEFAULT_REGION": "us-east-1"
     }):
+        import crewai.llms.providers.bedrock.completion
         # Mock boto3 Session to prevent actual AWS connections
         with patch('crewai.llms.providers.bedrock.completion.Session') as mock_session_class:
             # Create mock session instance
@@ -736,3 +737,76 @@ def test_bedrock_client_error_handling():
         with pytest.raises(RuntimeError) as exc_info:
             llm.call("Hello")
         assert "throttled" in str(exc_info.value).lower()
+
+
+def test_bedrock_cross_region_inference_profile():
+    """
+    Test that Bedrock supports cross-region inference profiles with model_id parameter.
+
+    This tests the fix for issue #3791 where cross-region inference profiles
+    (which require using ARN as model_id) were not working in version 1.20.0.
+
+    When using cross-region inference profiles, users need to:
+    1. Set model to the base model name (e.g., "bedrock/anthropic.claude-sonnet-4-20250514-v1:0")
+    2. Set model_id to the inference profile ARN
+
+    The BedrockCompletion should use the model_id parameter when provided,
+    not the model parameter, for the actual API call.
+    """
+    # Test with cross-region inference profile ARN
+    inference_profile_arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0"
+
+    llm = LLM(
+        model="bedrock/anthropic.claude-sonnet-4-20250514-v1:0",
+        model_id=inference_profile_arn,
+        temperature=0.3,
+        max_tokens=4000,
+    )
+
+    from crewai.llms.providers.bedrock.completion import BedrockCompletion
+    assert isinstance(llm, BedrockCompletion)
+
+    assert llm.model_id == inference_profile_arn
+    assert llm.model == "anthropic.claude-sonnet-4-20250514-v1:0"
+
+    # Verify that the client.converse call would use the correct model_id
+    with patch.object(llm.client, 'converse') as mock_converse:
+        mock_converse.return_value = {
+            'output': {
+                'message': {
+                    'role': 'assistant',
+                    'content': [{'text': 'Test response'}]
+                }
+            },
+            'usage': {
+                'inputTokens': 10,
+                'outputTokens': 5,
+                'totalTokens': 15
+            }
+        }
+
+        llm.call("Test message")
+
+        # Verify the converse call was made with the inference profile ARN
+        mock_converse.assert_called_once()
+        call_kwargs = mock_converse.call_args[1]
+        assert call_kwargs['modelId'] == inference_profile_arn
+
+
+def test_bedrock_model_id_parameter_takes_precedence():
+    """
+    Test that when both model and model_id are provided, model_id takes precedence
+    for the actual API call, while model is used for internal identification.
+    """
+    custom_model_id = "custom-model-identifier"
+
+    llm = LLM(
+        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
+        model_id=custom_model_id,
+    )
+
+    from crewai.llms.providers.bedrock.completion import BedrockCompletion
+    assert isinstance(llm, BedrockCompletion)
+
+    assert llm.model_id == custom_model_id
+    assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0"