diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py index 615d8838d..155090d58 100644 --- a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py @@ -1,7 +1,11 @@ +from __future__ import annotations + from collections.abc import Mapping, Sequence import logging import os -from typing import Any +from typing import TYPE_CHECKING, Any, TypedDict, cast + +from typing_extensions import Required from crewai.events.types.llm_events import LLMCallType from crewai.llms.base_llm import BaseLLM @@ -11,6 +15,20 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( ) +if TYPE_CHECKING: + from mypy_boto3_bedrock_runtime.type_defs import ( + GuardrailConfigurationTypeDef, + GuardrailStreamConfigurationTypeDef, + InferenceConfigurationTypeDef, + MessageOutputTypeDef, + MessageTypeDef, + SystemContentBlockTypeDef, + TokenUsageTypeDef, + ToolConfigurationTypeDef, + ToolTypeDef, + ) + + try: from boto3.session import Session from botocore.config import Config @@ -21,11 +39,104 @@ except ImportError: ) from None +if TYPE_CHECKING: + + class EnhancedInferenceConfigurationTypeDef( + InferenceConfigurationTypeDef, total=False + ): + """Extended InferenceConfigurationTypeDef with topK support. + + AWS Bedrock supports topK for Claude models, but it's not in the boto3 type stubs. + This extends the base type to include topK while maintaining all other fields. + """ + + topK: int # noqa: N815 - AWS API uses topK naming + +else: + + class EnhancedInferenceConfigurationTypeDef(TypedDict, total=False): + """Extended InferenceConfigurationTypeDef with topK support. + + AWS Bedrock supports topK for Claude models, but it's not in the boto3 type stubs. + This extends the base type to include topK while maintaining all other fields. 
+ """ + + maxTokens: int + temperature: float + topP: float # noqa: N815 - AWS API uses topP naming + stopSequences: list[str] + topK: int # noqa: N815 - AWS API uses topK naming + + +class ToolInputSchema(TypedDict): + """Type definition for tool input schema in Converse API.""" + + json: dict[str, Any] + + +class ToolSpec(TypedDict, total=False): + """Type definition for tool specification in Converse API.""" + + name: Required[str] + description: Required[str] + inputSchema: ToolInputSchema + + +class ConverseToolTypeDef(TypedDict): + """Type definition for a Converse API tool.""" + + toolSpec: ToolSpec + + +class BedrockConverseRequestBody(TypedDict, total=False): + """Type definition for AWS Bedrock Converse API request body. + + Based on AWS Bedrock Converse API specification. + """ + + inferenceConfig: Required[EnhancedInferenceConfigurationTypeDef] + system: list[SystemContentBlockTypeDef] + toolConfig: ToolConfigurationTypeDef + guardrailConfig: GuardrailConfigurationTypeDef + additionalModelRequestFields: dict[str, Any] + additionalModelResponseFieldPaths: list[str] + + +class BedrockConverseStreamRequestBody(TypedDict, total=False): + """Type definition for AWS Bedrock Converse Stream API request body. + + Based on AWS Bedrock Converse Stream API specification. + """ + + inferenceConfig: Required[EnhancedInferenceConfigurationTypeDef] + system: list[SystemContentBlockTypeDef] + toolConfig: ToolConfigurationTypeDef + guardrailConfig: GuardrailStreamConfigurationTypeDef + additionalModelRequestFields: dict[str, Any] + additionalModelResponseFieldPaths: list[str] + + class BedrockCompletion(BaseLLM): """AWS Bedrock native completion implementation using the Converse API. This class provides direct integration with AWS Bedrock using the modern Converse API, which provides a unified interface across all Bedrock models. 
+ + Features: + - Full tool calling support with proper conversation continuation + - Streaming and non-streaming responses with comprehensive event handling + - Guardrail configuration for content filtering + - Model-specific parameters via additionalModelRequestFields + - Custom response field extraction + - Proper error handling for all AWS exception types + - Token usage tracking and stop reason logging + - Support for both text and tool use content blocks + + The implementation follows AWS Bedrock Converse API best practices including: + - Proper tool use ID tracking for multi-turn tool conversations + - Complete streaming event handling (messageStart, contentBlockStart, etc.) + - Response metadata and trace information capture + - Model-specific conversation format handling (e.g., Cohere requirements) """ def __init__( @@ -41,9 +152,30 @@ class BedrockCompletion(BaseLLM): top_k: int | None = None, stop_sequences: Sequence[str] | None = None, stream: bool = False, + guardrail_config: dict[str, Any] | None = None, + additional_model_request_fields: dict[str, Any] | None = None, + additional_model_response_field_paths: list[str] | None = None, **kwargs, ): - """Initialize AWS Bedrock completion client.""" + """Initialize AWS Bedrock completion client. 
+ + Args: + model: The Bedrock model ID to use + aws_access_key_id: AWS access key (defaults to environment variable) + aws_secret_access_key: AWS secret key (defaults to environment variable) + aws_session_token: AWS session token for temporary credentials + region_name: AWS region name + temperature: Sampling temperature for response generation + max_tokens: Maximum tokens to generate + top_p: Nucleus sampling parameter + top_k: Top-k sampling parameter (Claude models only) + stop_sequences: List of sequences that stop generation + stream: Whether to use streaming responses + guardrail_config: Guardrail configuration for content filtering + additional_model_request_fields: Model-specific request parameters + additional_model_response_field_paths: Custom response field paths + **kwargs: Additional parameters + """ # Extract provider from kwargs to avoid duplicate argument kwargs.pop("provider", None) @@ -66,7 +198,6 @@ class BedrockCompletion(BaseLLM): # Configure client with timeouts and retries following AWS best practices config = Config( - connect_timeout=60, read_timeout=300, retries={ "max_attempts": 3, @@ -85,6 +216,13 @@ class BedrockCompletion(BaseLLM): self.stream = stream self.stop_sequences = stop_sequences or [] + # Store advanced features (optional) + self.guardrail_config = guardrail_config + self.additional_model_request_fields = additional_model_request_fields + self.additional_model_response_field_paths = ( + additional_model_response_field_paths + ) + # Model-specific settings self.is_claude_model = "claude" in model.lower() self.supports_tools = True # Converse API supports tools for most models @@ -96,7 +234,7 @@ class BedrockCompletion(BaseLLM): def call( self, messages: str | list[dict[str, str]], - tools: Sequence[Mapping[str, Any]] | None = None, + tools: list[dict[Any, Any]] | None = None, callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, from_task: Any | None = None, @@ -119,24 +257,45 @@ class 
BedrockCompletion(BaseLLM): messages ) - # Prepare tool configuration - tool_config = None - if tools: - tool_config = {"tools": self._format_tools_for_converse(tools)} - # Prepare request body - body = { + body: BedrockConverseRequestBody = { "inferenceConfig": self._get_inference_config(), } # Add system message if present if system_message: - body["system"] = [{"text": system_message}] + body["system"] = cast( + "list[SystemContentBlockTypeDef]", + cast(object, [{"text": system_message}]), + ) # Add tool config if present - if tool_config: + if tools: + tool_config: ToolConfigurationTypeDef = { + "tools": cast( + "Sequence[ToolTypeDef]", + cast(object, self._format_tools_for_converse(tools)), + ) + } body["toolConfig"] = tool_config + # Add optional advanced features if configured + if self.guardrail_config: + guardrail_config: GuardrailConfigurationTypeDef = cast( + "GuardrailConfigurationTypeDef", cast(object, self.guardrail_config) + ) + body["guardrailConfig"] = guardrail_config + + if self.additional_model_request_fields: + body["additionalModelRequestFields"] = ( + self.additional_model_request_fields + ) + + if self.additional_model_response_field_paths: + body["additionalModelResponseFieldPaths"] = ( + self.additional_model_response_field_paths + ) + if self.stream: return self._handle_streaming_converse( formatted_messages, body, available_functions, from_task, from_agent @@ -161,7 +320,7 @@ class BedrockCompletion(BaseLLM): def _handle_converse( self, messages: list[dict[str, Any]], - body: dict[str, Any], + body: BedrockConverseRequestBody, available_functions: Mapping[str, Any] | None = None, from_task: Any | None = None, from_agent: Any | None = None, @@ -183,13 +342,26 @@ class BedrockCompletion(BaseLLM): # Call Bedrock Converse API with proper error handling response = self.client.converse( - modelId=self.model_id, messages=messages, **body + modelId=self.model_id, + messages=cast( + "Sequence[MessageTypeDef | MessageOutputTypeDef]", + 
cast(object, messages), + ), + **body, ) # Track token usage according to AWS response format if "usage" in response: self._track_token_usage_internal(response["usage"]) + stop_reason = response.get("stopReason") + if stop_reason: + logging.debug(f"Response stop reason: {stop_reason}") + if stop_reason == "max_tokens": + logging.warning("Response truncated due to max_tokens limit") + elif stop_reason == "content_filtered": + logging.warning("Response was filtered due to content policy") + # Extract content following AWS response structure output = response.get("output", {}) message = output.get("message", {}) @@ -201,28 +373,59 @@ class BedrockCompletion(BaseLLM): "I apologize, but I received an empty response. Please try again." ) - # Extract text content from response + # Process content blocks and handle tool use correctly text_content = "" + for content_block in content: - # Handle different content block types as per AWS documentation + # Handle text content if "text" in content_block: text_content += content_block["text"] - elif content_block.get("type") == "toolUse" and available_functions: - # Handle tool use according to AWS format - tool_use = content_block["toolUse"] - function_name = tool_use.get("name") - function_args = tool_use.get("input", {}) - result = self._handle_tool_execution( + # Handle tool use - corrected structure according to AWS API docs + elif "toolUse" in content_block and available_functions: + tool_use_block = content_block["toolUse"] + tool_use_id = tool_use_block.get("toolUseId") + function_name = tool_use_block["name"] + function_args = tool_use_block.get("input", {}) + + logging.debug( + f"Tool use requested: {function_name} with ID {tool_use_id}" + ) + + # Execute the tool + tool_result = self._handle_tool_execution( function_name=function_name, function_args=function_args, - available_functions=available_functions, + available_functions=dict(available_functions), from_task=from_task, from_agent=from_agent, ) - if result is not 
None: - return result + if tool_result is not None: + messages.append( + { + "role": "assistant", + "content": [{"toolUse": tool_use_block}], + } + ) + + messages.append( + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": tool_use_id, + "content": [{"text": str(tool_result)}], + } + } + ], + } + ) + + return self._handle_converse( + messages, body, available_functions, from_task, from_agent + ) # Apply stop sequences if configured text_content = self._apply_stop_words(text_content) @@ -298,23 +501,43 @@ class BedrockCompletion(BaseLLM): def _handle_streaming_converse( self, messages: list[dict[str, Any]], - body: dict[str, Any], + body: BedrockConverseRequestBody, available_functions: dict[str, Any] | None = None, from_task: Any | None = None, from_agent: Any | None = None, ) -> str: - """Handle streaming converse API call.""" + """Handle streaming converse API call with comprehensive event handling.""" full_response = "" + current_tool_use = None + tool_use_id = None try: response = self.client.converse_stream( - modelId=self.model_id, messages=messages, **body + modelId=self.model_id, + messages=cast( + "Sequence[MessageTypeDef | MessageOutputTypeDef]", + cast(object, messages), + ), + **body, # type: ignore[arg-type] ) stream = response.get("stream") if stream: for event in stream: - if "contentBlockDelta" in event: + if "messageStart" in event: + role = event["messageStart"].get("role") + logging.debug(f"Streaming message started with role: {role}") + + elif "contentBlockStart" in event: + start = event["contentBlockStart"].get("start", {}) + if "toolUse" in start: + current_tool_use = start["toolUse"] + tool_use_id = current_tool_use.get("toolUseId") + logging.debug( + f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})" + ) + + elif "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: text_chunk = delta["text"] @@ -325,10 +548,93 @@ class BedrockCompletion(BaseLLM): 
from_task=from_task, from_agent=from_agent, ) + elif "toolUse" in delta and current_tool_use: + tool_input = delta["toolUse"].get("input", "") + if tool_input: + logging.debug(f"Tool input delta: {tool_input}") + + # Content block stop - end of a content block + elif "contentBlockStop" in event: + logging.debug("Content block stopped in stream") + # If we were accumulating a tool use, it's now complete + if current_tool_use and available_functions: + function_name = current_tool_use["name"] + function_args = cast( + dict[str, Any], current_tool_use.get("input", {}) + ) + + # Execute tool + tool_result = self._handle_tool_execution( + function_name=function_name, + function_args=function_args, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + if tool_result is not None and tool_use_id: + # Continue conversation with tool result + messages.append( + { + "role": "assistant", + "content": [{"toolUse": current_tool_use}], + } + ) + + messages.append( + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": tool_use_id, + "content": [ + {"text": str(tool_result)} + ], + } + } + ], + } + ) + + # Recursive call - note this switches to non-streaming + return self._handle_converse( + messages, + body, + available_functions, + from_task, + from_agent, + ) + + current_tool_use = None + tool_use_id = None + + # Message stop - end of entire message elif "messageStop" in event: - # Handle end of message + stop_reason = event["messageStop"].get("stopReason") + logging.debug(f"Streaming message stopped: {stop_reason}") + if stop_reason == "max_tokens": + logging.warning( + "Streaming response truncated due to max_tokens" + ) + elif stop_reason == "content_filtered": + logging.warning( + "Streaming response filtered due to content policy" + ) break + # Metadata - contains usage information and trace details + elif "metadata" in event: + metadata = event["metadata"] + if "usage" in metadata: + usage_metrics = 
metadata["usage"] + self._track_token_usage_internal(usage_metrics) + logging.debug(f"Token usage: {usage_metrics}") + if "trace" in metadata: + logging.debug( + f"Trace information available: {metadata['trace']}" + ) + except ClientError as e: error_msg = self._handle_client_error(e) raise RuntimeError(error_msg) from e @@ -430,25 +736,27 @@ class BedrockCompletion(BaseLLM): return converse_messages, system_message - def _format_tools_for_converse(self, tools: list[dict]) -> list[dict]: + @staticmethod + def _format_tools_for_converse(tools: list[dict]) -> list[ConverseToolTypeDef]: """Convert CrewAI tools to Converse API format following AWS specification.""" from crewai.llms.providers.utils.common import safe_tool_conversion - converse_tools = [] + converse_tools: list[ConverseToolTypeDef] = [] for tool in tools: try: name, description, parameters = safe_tool_conversion(tool, "Bedrock") - converse_tool = { - "toolSpec": { - "name": name, - "description": description, - } + tool_spec: ToolSpec = { + "name": name, + "description": description, } if parameters and isinstance(parameters, dict): - converse_tool["toolSpec"]["inputSchema"] = {"json": parameters} + input_schema: ToolInputSchema = {"json": parameters} + tool_spec["inputSchema"] = input_schema + + converse_tool: ConverseToolTypeDef = {"toolSpec": tool_spec} converse_tools.append(converse_tool) @@ -460,9 +768,9 @@ class BedrockCompletion(BaseLLM): return converse_tools - def _get_inference_config(self) -> dict[str, Any]: + def _get_inference_config(self) -> EnhancedInferenceConfigurationTypeDef: """Get inference configuration following AWS Converse API specification.""" - config = {} + config: EnhancedInferenceConfigurationTypeDef = {} if self.max_tokens: config["maxTokens"] = self.max_tokens @@ -503,7 +811,7 @@ class BedrockCompletion(BaseLLM): return full_error_msg - def _track_token_usage_internal(self, usage: dict[str, Any]) -> None: + def _track_token_usage_internal(self, usage: TokenUsageTypeDef) -> 
None: # type: ignore[override] """Track token usage from Bedrock response.""" input_tokens = usage.get("inputTokens", 0) output_tokens = usage.get("outputTokens", 0) diff --git a/lib/crewai/tests/llms/bedrock/test_bedrock.py b/lib/crewai/tests/llms/bedrock/test_bedrock.py new file mode 100644 index 000000000..eb9bbf3d4 --- /dev/null +++ b/lib/crewai/tests/llms/bedrock/test_bedrock.py @@ -0,0 +1,733 @@ +import os +import sys +import types +from unittest.mock import patch, MagicMock +import pytest + +from crewai.llm import LLM +from crewai.crew import Crew +from crewai.agent import Agent +from crewai.task import Task + + +@pytest.fixture(autouse=True) +def mock_aws_credentials(): + """Automatically mock AWS credentials and boto3 Session for all tests in this module.""" + with patch.dict(os.environ, { + "AWS_ACCESS_KEY_ID": "test-access-key", + "AWS_SECRET_ACCESS_KEY": "test-secret-key", + "AWS_DEFAULT_REGION": "us-east-1" + }): + # Mock boto3 Session to prevent actual AWS connections + with patch('crewai.llms.providers.bedrock.completion.Session') as mock_session_class: + # Create mock session instance + mock_session_instance = MagicMock() + mock_client = MagicMock() + + # Set up default mock responses to prevent hanging + default_response = { + 'output': { + 'message': { + 'role': 'assistant', + 'content': [ + {'text': 'Test response'} + ] + } + }, + 'usage': { + 'inputTokens': 10, + 'outputTokens': 5, + 'totalTokens': 15 + } + } + mock_client.converse.return_value = default_response + mock_client.converse_stream.return_value = {'stream': []} + + # Configure the mock session instance to return the mock client + mock_session_instance.client.return_value = mock_client + + # Configure the mock Session class to return the mock session instance + mock_session_class.return_value = mock_session_instance + + yield mock_session_class, mock_client + + +def test_bedrock_completion_is_used_when_bedrock_provider(): + """ + Test that BedrockCompletion from completion.py is used 
when LLM uses provider 'bedrock' + """ + llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + + assert llm.__class__.__name__ == "BedrockCompletion" + assert llm.provider == "bedrock" + assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0" + + +def test_bedrock_completion_module_is_imported(): + """ + Test that the completion module is properly imported when using Bedrock provider + """ + module_name = "crewai.llms.providers.bedrock.completion" + + # Remove module from cache if it exists + if module_name in sys.modules: + del sys.modules[module_name] + + # Create LLM instance - this should trigger the import + LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + + # Verify the module was imported + assert module_name in sys.modules + completion_mod = sys.modules[module_name] + assert isinstance(completion_mod, types.ModuleType) + + # Verify the class exists in the module + assert hasattr(completion_mod, 'BedrockCompletion') + + +def test_fallback_to_litellm_when_native_bedrock_fails(): + """ + Test that LLM falls back to LiteLLM when native Bedrock completion fails + """ + # Mock the _get_native_provider to return a failing class + with patch('crewai.llm.LLM._get_native_provider') as mock_get_provider: + + class FailingCompletion: + def __init__(self, *args, **kwargs): + raise Exception("Native AWS Bedrock SDK failed") + + mock_get_provider.return_value = FailingCompletion + + # This should fall back to LiteLLM + llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + + # Check that it's using LiteLLM + assert hasattr(llm, 'is_litellm') + assert llm.is_litellm == True + + +def test_bedrock_completion_initialization_parameters(): + """ + Test that BedrockCompletion is initialized with correct parameters + """ + llm = LLM( + model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", + temperature=0.7, + max_tokens=2000, + top_p=0.9, + top_k=40, + region_name="us-west-2" + ) + + from 
crewai.llms.providers.bedrock.completion import BedrockCompletion + assert isinstance(llm, BedrockCompletion) + assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0" + assert llm.temperature == 0.7 + assert llm.max_tokens == 2000 + assert llm.top_p == 0.9 + assert llm.top_k == 40 + assert llm.region_name == "us-west-2" + + +def test_bedrock_specific_parameters(): + """ + Test Bedrock-specific parameters like stop_sequences and streaming + """ + llm = LLM( + model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", + stop_sequences=["Human:", "Assistant:"], + stream=True, + region_name="us-east-1" + ) + + from crewai.llms.providers.bedrock.completion import BedrockCompletion + assert isinstance(llm, BedrockCompletion) + assert llm.stop_sequences == ["Human:", "Assistant:"] + assert llm.stream == True + assert llm.region_name == "us-east-1" + + +def test_bedrock_completion_call(): + """ + Test that BedrockCompletion call method works + """ + llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + + # Mock the call method on the instance + with patch.object(llm, 'call', return_value="Hello! I'm Claude on Bedrock, ready to help.") as mock_call: + result = llm.call("Hello, how are you?") + + assert result == "Hello! I'm Claude on Bedrock, ready to help." 
+ mock_call.assert_called_once_with("Hello, how are you?") + + +def test_bedrock_completion_called_during_crew_execution(): + """ + Test that BedrockCompletion.call is actually invoked when running a crew + """ + # Create the LLM instance first + bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + + # Mock the call method on the specific instance + with patch.object(bedrock_llm, 'call', return_value="Tokyo has 14 million people.") as mock_call: + + # Create agent with explicit LLM configuration + agent = Agent( + role="Research Assistant", + goal="Find population info", + backstory="You research populations.", + llm=bedrock_llm, + ) + + task = Task( + description="Find Tokyo population", + expected_output="Population number", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + result = crew.kickoff() + + # Verify mock was called + assert mock_call.called + assert "14 million" in str(result) + + +@pytest.mark.skip(reason="Crew execution test - may hang, needs investigation") +def test_bedrock_completion_call_arguments(): + """ + Test that BedrockCompletion.call is invoked with correct arguments + """ + # Create LLM instance first + bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + + # Mock the instance method + with patch.object(bedrock_llm, 'call') as mock_call: + mock_call.return_value = "Task completed successfully." 
+ + agent = Agent( + role="Test Agent", + goal="Complete a simple task", + backstory="You are a test agent.", + llm=bedrock_llm # Use same instance + ) + + task = Task( + description="Say hello world", + expected_output="Hello world", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + crew.kickoff() + + # Verify call was made + assert mock_call.called + + # Check the arguments passed to the call method + call_args = mock_call.call_args + assert call_args is not None + + # The first argument should be the messages + messages = call_args[0][0] # First positional argument + assert isinstance(messages, (str, list)) + + # Verify that the task description appears in the messages + if isinstance(messages, str): + assert "hello world" in messages.lower() + elif isinstance(messages, list): + message_content = str(messages).lower() + assert "hello world" in message_content + + +def test_multiple_bedrock_calls_in_crew(): + """ + Test that BedrockCompletion.call is invoked multiple times for multiple tasks + """ + # Create LLM instance first + bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + + # Mock the instance method + with patch.object(bedrock_llm, 'call') as mock_call: + mock_call.return_value = "Task completed." 
+ + agent = Agent( + role="Multi-task Agent", + goal="Complete multiple tasks", + backstory="You can handle multiple tasks.", + llm=bedrock_llm # Use same instance + ) + + task1 = Task( + description="First task", + expected_output="First result", + agent=agent, + ) + + task2 = Task( + description="Second task", + expected_output="Second result", + agent=agent, + ) + + crew = Crew( + agents=[agent], + tasks=[task1, task2] + ) + crew.kickoff() + + # Verify multiple calls were made + assert mock_call.call_count >= 2 # At least one call per task + + # Verify each call had proper arguments + for call in mock_call.call_args_list: + assert len(call[0]) > 0 # Has positional arguments + messages = call[0][0] + assert messages is not None + +def test_bedrock_completion_with_tools(): + """ + Test that BedrockCompletion.call is invoked with tools when agent has tools + """ + from crewai.tools import tool + + @tool + def sample_tool(query: str) -> str: + """A sample tool for testing""" + return f"Tool result for: {query}" + + # Create LLM instance first + bedrock_llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + + # Mock the instance method + with patch.object(bedrock_llm, 'call') as mock_call: + mock_call.return_value = "Task completed with tools." 
+ + agent = Agent( + role="Tool User", + goal="Use tools to complete tasks", + backstory="You can use tools.", + llm=bedrock_llm, # Use same instance + tools=[sample_tool] + ) + + task = Task( + description="Use the sample tool", + expected_output="Tool usage result", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + + crew.kickoff() + + assert mock_call.called + + call_args = mock_call.call_args + call_kwargs = call_args[1] if len(call_args) > 1 else {} + + if 'tools' in call_kwargs: + assert call_kwargs['tools'] is not None + assert len(call_kwargs['tools']) > 0 + + +def test_bedrock_raises_error_when_model_not_found(mock_aws_credentials): + """Test that BedrockCompletion raises appropriate error when model not found""" + from botocore.exceptions import ClientError + + # Get the mock client from the fixture + _, mock_client = mock_aws_credentials + + error_response = { + 'Error': { + 'Code': 'ResourceNotFoundException', + 'Message': 'Could not resolve the foundation model from the model identifier' + } + } + mock_client.converse.side_effect = ClientError(error_response, 'converse') + + llm = LLM(model="bedrock/model-doesnt-exist") + + with pytest.raises(Exception): # Should raise some error for unsupported model + llm.call("Hello") + + +def test_bedrock_aws_credentials_configuration(): + """ + Test that AWS credentials configuration works properly + """ + # Test with environment variables + with patch.dict(os.environ, { + "AWS_ACCESS_KEY_ID": "test-access-key", + "AWS_SECRET_ACCESS_KEY": "test-secret-key", + "AWS_DEFAULT_REGION": "us-east-1" + }): + llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + + from crewai.llms.providers.bedrock.completion import BedrockCompletion + assert isinstance(llm, BedrockCompletion) + assert llm.region_name == "us-east-1" + + # Test with explicit credentials + llm_explicit = LLM( + model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", + aws_access_key_id="explicit-key", + 
aws_secret_access_key="explicit-secret", + region_name="us-west-2" + ) + assert isinstance(llm_explicit, BedrockCompletion) + assert llm_explicit.region_name == "us-west-2" + + +def test_bedrock_model_capabilities(): + """ + Test that model capabilities are correctly identified + """ + # Test Claude model + llm_claude = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + from crewai.llms.providers.bedrock.completion import BedrockCompletion + assert isinstance(llm_claude, BedrockCompletion) + assert llm_claude.is_claude_model == True + assert llm_claude.supports_tools == True + + # Test other Bedrock model + llm_titan = LLM(model="bedrock/amazon.titan-text-express-v1") + assert isinstance(llm_titan, BedrockCompletion) + assert llm_titan.supports_tools == True + + +def test_bedrock_inference_config(): + """ + Test that inference config is properly prepared + """ + llm = LLM( + model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", + temperature=0.7, + top_p=0.9, + top_k=40, + max_tokens=1000 + ) + + from crewai.llms.providers.bedrock.completion import BedrockCompletion + assert isinstance(llm, BedrockCompletion) + + # Test config preparation + config = llm._get_inference_config() + + # Verify config has the expected parameters + assert 'temperature' in config + assert config['temperature'] == 0.7 + assert 'topP' in config + assert config['topP'] == 0.9 + assert 'maxTokens' in config + assert config['maxTokens'] == 1000 + assert 'topK' in config + assert config['topK'] == 40 + + +def test_bedrock_model_detection(): + """ + Test that various Bedrock model formats are properly detected + """ + # Test Bedrock model naming patterns + bedrock_test_cases = [ + "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", + "bedrock/anthropic.claude-3-haiku-20240307-v1:0", + "bedrock/amazon.titan-text-express-v1", + "bedrock/meta.llama3-70b-instruct-v1:0" + ] + + for model_name in bedrock_test_cases: + llm = LLM(model=model_name) + from 
def test_bedrock_supports_stop_words():
    """Bedrock models report support for stop sequences."""
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
    # Identity check instead of `== True` (pycodestyle E712).
    assert llm.supports_stop_words() is True


def test_bedrock_context_window_size():
    """Bedrock models expose a context window above a per-model lower bound."""
    # Claude 3.5 Sonnet
    llm_claude = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
    context_size_claude = llm_claude.get_context_window_size()
    assert context_size_claude > 150000  # Should be substantial (200K tokens with ratio)

    # Titan
    llm_titan = LLM(model="bedrock/amazon.titan-text-express-v1")
    context_size_titan = llm_titan.get_context_window_size()
    assert context_size_titan > 5000  # Should have 8K context window


def test_bedrock_message_formatting():
    """Chat messages are converted to the Bedrock Converse message format.

    The system message must be returned separately, and the remaining
    messages must keep their roles and carry a list of content blocks.
    """
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    test_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"},
    ]

    formatted_messages, system_message = llm._format_messages_for_converse(
        test_messages
    )

    # System message should be extracted
    assert system_message == "You are a helpful assistant."

    # Remaining messages should be in Converse format (user, assistant, user)
    assert len(formatted_messages) >= 3

    assert formatted_messages[0]["role"] == "user"
    assert formatted_messages[1]["role"] == "assistant"

    # Messages should have a content array whose first block holds text
    assert isinstance(formatted_messages[0]["content"], list)
    assert "text" in formatted_messages[0]["content"][0]


def test_bedrock_streaming_parameter():
    """The ``stream`` constructor argument is reflected on the LLM instance."""
    llm_no_stream = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", stream=False
    )
    # Identity checks instead of `== False` / `== True` (pycodestyle E712).
    assert llm_no_stream.stream is False

    llm_stream = LLM(
        model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", stream=True
    )
    assert llm_stream.stream is True


def test_bedrock_tool_conversion():
    """Function-style tool definitions are converted to Converse ``toolSpec`` entries."""
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Tool in CrewAI (function-calling) format
    crewai_tools = [
        {
            "type": "function",
            "function": {
                "name": "test_tool",
                "description": "A test tool",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string", "description": "Search query"}
                    },
                    "required": ["query"],
                },
            },
        }
    ]

    bedrock_tools = llm._format_tools_for_converse(crewai_tools)

    assert len(bedrock_tools) == 1
    # Bedrock tools should have toolSpec structure
    tool_spec = bedrock_tools[0]["toolSpec"]
    assert tool_spec["name"] == "test_tool"
    assert tool_spec["description"] == "A test tool"
    assert "inputSchema" in tool_spec


def test_bedrock_environment_variable_credentials(mock_aws_credentials):
    """AWS credentials from the environment are passed to the boto3 ``Session``."""
    mock_session_class, _ = mock_aws_credentials

    # Reset the mock to clear any previous calls
    mock_session_class.reset_mock()

    with patch.dict(
        os.environ,
        {
            "AWS_ACCESS_KEY_ID": "test-access-key-123",
            "AWS_SECRET_ACCESS_KEY": "test-secret-key-456",
        },
    ):
        # Only construction matters here; the instance itself is not used.
        LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

        # Verify Session was called with environment credentials
        assert mock_session_class.called
        # `called` was just asserted, so `call_args` is set; read the keyword
        # arguments of the most recent Session(...) call by name.
        call_kwargs = mock_session_class.call_args.kwargs
        assert call_kwargs.get("aws_access_key_id") == "test-access-key-123"
        assert call_kwargs.get("aws_secret_access_key") == "test-secret-key-456"


def test_bedrock_token_usage_tracking():
    """Token counts from the Converse ``usage`` payload are recorded on the LLM."""
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Mock the Bedrock response with usage information
    with patch.object(llm.client, "converse") as mock_converse:
        mock_converse.return_value = {
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [{"text": "test response"}],
                }
            },
            "usage": {
                "inputTokens": 50,
                "outputTokens": 25,
                "totalTokens": 75,
            },
        }

        result = llm.call("Hello")

        # Verify the response
        assert result == "test response"

        # Verify token usage was tracked
        assert llm._token_usage["prompt_tokens"] == 50
        assert llm._token_usage["completion_tokens"] == 25
        assert llm._token_usage["total_tokens"] == 75


def test_bedrock_tool_use_conversation_flow():
    """A tool-use response is followed by a second Converse call for the final answer."""
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Tool made available to the model
    def mock_weather_tool(location: str) -> str:
        return f"The weather in {location} is sunny and 75°F"

    available_functions = {"get_weather": mock_weather_tool}

    with patch.object(llm.client, "converse") as mock_converse:
        # First response: tool use request
        tool_use_response = {
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [
                        {
                            "toolUse": {
                                "toolUseId": "tool-123",
                                "name": "get_weather",
                                "input": {"location": "San Francisco"},
                            }
                        }
                    ],
                }
            },
            "usage": {
                "inputTokens": 100,
                "outputTokens": 50,
                "totalTokens": 150,
            },
        }

        # Second response: final answer after tool execution
        final_response = {
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [
                        {
                            "text": "Based on the weather data, it is sunny and 75°F in San Francisco."
                        }
                    ],
                }
            },
            "usage": {
                "inputTokens": 120,
                "outputTokens": 30,
                "totalTokens": 150,
            },
        }

        # Return the two responses on successive calls
        mock_converse.side_effect = [tool_use_response, final_response]

        messages = [
            {"role": "user", "content": "What's the weather like in San Francisco?"}
        ]
        result = llm.call(
            messages=messages,
            available_functions=available_functions,
        )

        # Verify the final response contains the weather information
        assert "sunny" in result.lower() or "75" in result

        # The API is called twice: once for tool use, once for the final answer
        assert mock_converse.call_count == 2


def test_bedrock_handles_cohere_conversation_requirements():
    """For a Cohere model, a conversation ending on an assistant turn gets a trailing user turn."""
    llm = LLM(model="bedrock/cohere.command-r-plus-v1:0")

    # Conversation ending in an assistant message
    test_messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]

    formatted_messages, system_message = llm._format_messages_for_converse(
        test_messages
    )

    # For Cohere models, a user message should be added at the end
    last_message = formatted_messages[-1]
    assert last_message["role"] == "user"
    assert "continue" in last_message["content"][0]["text"].lower()


def test_bedrock_client_error_handling():
    """AWS ``ClientError`` codes are translated into the expected Python exceptions."""
    from botocore.exceptions import ClientError

    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # ValidationException -> ValueError
    with patch.object(llm.client, "converse") as mock_converse:
        error_response = {
            "Error": {
                "Code": "ValidationException",
                "Message": "Invalid request format",
            }
        }
        mock_converse.side_effect = ClientError(error_response, "converse")

        with pytest.raises(ValueError) as exc_info:
            llm.call("Hello")
        # Must sit outside the `raises` block: statements placed after the
        # raising call inside the block would never execute.
        assert "validation" in str(exc_info.value).lower()

    # ThrottlingException -> RuntimeError
    with patch.object(llm.client, "converse") as mock_converse:
        error_response = {
            "Error": {
                "Code": "ThrottlingException",
                "Message": "Rate limit exceeded",
            }
        }
        mock_converse.side_effect = ClientError(error_response, "converse")

        with pytest.raises(RuntimeError) as exc_info:
            llm.call("Hello")
        assert "throttled" in str(exc_info.value).lower()
"sha256:5b145752d20f29908e3cb8c823bee31c77e6bcf18787e570f36bbc545cc779ed", size = 139345, upload-time = "2025-10-03T19:32:11.145Z" }, ] +[[package]] +name = "boto3-stubs" +version = "1.40.54" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore-stubs" }, + { name = "types-s3transfer" }, + { name = "typing-extensions", marker = "python_full_version < '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e2/70/245477b7f07c9e1533c47fa69e611b172814423a6fd4637004f0d2a13b73/boto3_stubs-1.40.54.tar.gz", hash = "sha256:e21a9eda979a451935eb3196de3efbe15b9470e6bf9027406d1f6d0ac08b339e", size = 100919, upload-time = "2025-10-16T19:49:17.079Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/52/ee9dadd1cc8911e16f18ca9fa036a10328e0a0d3fddd54fadcc1ca0f9143/boto3_stubs-1.40.54-py3-none-any.whl", hash = "sha256:548a4786785ba7b43ef4ef1a2a764bebbb0301525f3201091fcf412e4c8ce323", size = 69712, upload-time = "2025-10-16T19:49:12.847Z" }, +] + +[package.optional-dependencies] +bedrock-runtime = [ + { name = "mypy-boto3-bedrock-runtime" }, +] + [[package]] name = "botocore" version = "1.40.45" @@ -494,6 +514,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/06/df47e2ecb74bd184c9d056666afd3db011a649eaca663337835a6dd5aee6/botocore-1.40.45-py3-none-any.whl", hash = "sha256:9abf473d8372ade8442c0d4634a9decb89c854d7862ffd5500574eb63ab8f240", size = 14063670, upload-time = "2025-10-03T19:31:58.999Z" }, ] +[[package]] +name = "botocore-stubs" +version = "1.40.54" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "types-awscrt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ea/c0/3e78314f9baa850aae648fb6b2506748046e1c3e374d6bb3514478e34590/botocore_stubs-1.40.54.tar.gz", hash = "sha256:fb38a794ab2b896f9cc237ec725546746accaffd34f382475a8d1b98ca1078e1", size = 42225, upload-time = "2025-10-16T20:26:56.711Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/2f/9f/ab316f57a7e32d4a5b790070ffa5986991098044897b08f1b65951bced2a/botocore_stubs-1.40.54-py3-none-any.whl", hash = "sha256:997e6f1c03e079c244caedf315f7a515a07480af9f93f53535e506f17cdbe880", size = 66542, upload-time = "2025-10-16T20:26:54.109Z" }, +] + [[package]] name = "browserbase" version = "1.4.0" @@ -4055,6 +4087,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/e3/be76d87158ebafa0309946c4a73831974d4d6ab4f4ef40c3b53a385a66fd/mypy-1.18.2-py3-none-any.whl", hash = "sha256:22a1748707dd62b58d2ae53562ffc4d7f8bcc727e8ac7cbc69c053ddc874d47e", size = 2352367, upload-time = "2025-09-19T00:10:15.489Z" }, ] +[[package]] +name = "mypy-boto3-bedrock-runtime" +version = "1.40.41" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/38/79989f7bce998776ed1a01c17f3f58e7bc6f5fc2bcbdff929701526fa2f1/mypy_boto3_bedrock_runtime-1.40.41.tar.gz", hash = "sha256:ee9bda6d6d478c8d0995e84e884bdf1798e150d437974ae27c175774a58ffaa5", size = 28333, upload-time = "2025-09-29T19:26:04.804Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/6c/d3431dadf473bb76aa590b1ed8cc91726a48b029b542eff9d3024f2d70b9/mypy_boto3_bedrock_runtime-1.40.41-py3-none-any.whl", hash = "sha256:d65dff200986ff06c6b3579ddcea102555f2067c8987fca379bf4f9ed8ba3121", size = 34181, upload-time = "2025-09-29T19:26:01.898Z" }, +] + [[package]] name = "mypy-extensions" version = "1.1.0" @@ -8181,6 +8225,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cf/07/41f5b9b11f11855eb67760ed680330e0ce9136a44b51c24dd52edb1c4eb1/types_appdirs-1.4.3.5-py3-none-any.whl", hash = "sha256:337c750e423c40911d389359b4edabe5bbc2cdd5cd0bd0518b71d2839646273b", size = 2667, upload-time = "2023-03-14T15:21:32.431Z" }, ] +[[package]] +name = "types-awscrt" +version = "0.28.1" +source = { registry = 
"https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/60/19/a3a6377c9e2e389c1421c033a1830c29cac08f2e1e05a082ea84eb22c75f/types_awscrt-0.28.1.tar.gz", hash = "sha256:66d77ec283e1dc907526a44511a12624118723a396c36d3f3dd9855cb614ce14", size = 17410, upload-time = "2025-10-11T21:55:07.443Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/c7/0266b797d19b82aebe0e177efe35de7aabdc192bc1605ce3309331f0a505/types_awscrt-0.28.1-py3-none-any.whl", hash = "sha256:d88f43ef779f90b841ba99badb72fe153077225a4e426ae79e943184827b4443", size = 41851, upload-time = "2025-10-11T21:55:06.235Z" }, +] + [[package]] name = "types-pyyaml" version = "6.0.12.20250915" @@ -8251,6 +8304,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/ea/91b718b8c0b88e4f61cdd61357cc4a1f8767b32be691fb388299003a3ae3/types_requests-2.31.0.20240406-py3-none-any.whl", hash = "sha256:6216cdac377c6b9a040ac1c0404f7284bd13199c0e1bb235f4324627e8898cf5", size = 15347, upload-time = "2024-04-06T02:13:37.412Z" }, ] +[[package]] +name = "types-s3transfer" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/9b/8913198b7fc700acc1dcb84827137bb2922052e43dde0f4fb0ed2dc6f118/types_s3transfer-0.14.0.tar.gz", hash = "sha256:17f800a87c7eafab0434e9d87452c809c290ae906c2024c24261c564479e9c95", size = 14218, upload-time = "2025-10-11T21:11:27.892Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/c3/4dfb2e87c15ca582b7d956dfb7e549de1d005c758eb9a305e934e1b83fda/types_s3transfer-0.14.0-py3-none-any.whl", hash = "sha256:108134854069a38b048e9b710b9b35904d22a9d0f37e4e1889c2e6b58e5b3253", size = 19697, upload-time = "2025-10-11T21:11:26.749Z" }, +] + [[package]] name = "types-urllib3" version = "1.26.25.14"