Lorenze/bedrock llm (#3693)

* feat: add AWS Bedrock support and update dependencies

- Introduced BedrockCompletion class for AWS Bedrock integration in LLM.
- Added boto3 as a new dependency in both pyproject.toml and uv.lock.
- Updated LLM class to support Bedrock provider.
- Created new files for Bedrock provider implementation.
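A minimal usage sketch of the new provider (hedged: the `bedrock/<model-id>` model-string routing and environment-based AWS credentials are assumptions inferred from this diff, not confirmed documentation):

from crewai import LLM

# Hypothetical usage; the model ID below is the class default added in this PR,
# and credentials are expected in AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY.
llm = LLM(
    model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
    temperature=0.2,
)
print(llm.call("In one sentence, what does the Bedrock Converse API do?"))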

* using converse api

* converse

* linted

* refactor: update BedrockCompletion class to improve parameter handling

- Changed max_tokens from a fixed integer to an optional integer.
- Simplified model ID assignment by removing the inference profile mapping method.
- Cleaned up comments and unnecessary code related to tool specifications and model-specific parameters.
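A short sketch of the new optional behavior, based on `_get_inference_config` in the diff below: when `max_tokens` is left as None, no `maxTokens` key is sent and the service-side default applies (requires boto3; `_get_inference_config` is an internal method, used here only for illustration):

from crewai.llms.providers.bedrock.completion import BedrockCompletion

llm = BedrockCompletion(model="anthropic.claude-3-5-sonnet-20241022-v2:0")
assert "maxTokens" not in llm._get_inference_config()  # omitted when unset

llm = BedrockCompletion(
    model="anthropic.claude-3-5-sonnet-20241022-v2:0", max_tokens=1024
)
assert llm._get_inference_config()["maxTokens"] == 1024  # forwarded when set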
Authored by Lorenze Jay on 2025-10-13 20:42:34 -07:00, committed by GitHub
commit cec4e4c2e9 (parent 6c5ac13242)
5 changed files with 569 additions and 1 deletion

pyproject.toml

@@ -85,6 +85,9 @@ voyageai = [
litellm = [
    "litellm>=1.74.9",
]
boto3 = [
    "boto3>=1.40.45",
]

[project.scripts]
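With the new extra in place, `uv add "crewai[boto3]"` (or `pip install "crewai[boto3]"`) pulls boto3 in alongside the package; the distribution name `crewai` is assumed from this repository.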

crewai/llm.py

@@ -367,6 +367,14 @@ class LLM(BaseLLM):
            except ImportError:
                return None
        elif provider == "bedrock":
            try:
                from crewai.llms.providers.bedrock.completion import BedrockCompletion

                return BedrockCompletion
            except ImportError:
                return None
        return None

    def __init__(
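The guarded import keeps boto3 optional: when it is missing, the factory returns None rather than raising. A standalone sketch of the same pattern (the helper name is illustrative, not part of the diff):

def bedrock_available() -> bool:
    # Mirrors the guarded import above: True only if boto3 and the
    # provider module import cleanly.
    try:
        from crewai.llms.providers.bedrock.completion import BedrockCompletion  # noqa: F401
    except ImportError:
        return False
    return True

if not bedrock_available():
    print("Native Bedrock provider unavailable; install boto3 (`uv add boto3`).")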

crewai/llms/providers/bedrock/completion.py (new file)

@@ -0,0 +1,553 @@
from collections.abc import Mapping, Sequence
import logging
import os
from typing import Any

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)

try:
    from boto3.session import Session
    from botocore.config import Config
    from botocore.exceptions import BotoCoreError, ClientError
except ImportError:
    raise ImportError(
        "AWS Bedrock native provider is not available. To install: `uv add boto3`"
    ) from None


class BedrockCompletion(BaseLLM):
    """AWS Bedrock native completion implementation using the Converse API.

    This class provides direct integration with AWS Bedrock using the modern
    Converse API, which offers a unified interface across all Bedrock models.
    """

    def __init__(
        self,
        model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
        aws_session_token: str | None = None,
        region_name: str = "us-east-1",
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        top_k: int | None = None,
        stop_sequences: Sequence[str] | None = None,
        stream: bool = False,
        **kwargs,
    ):
        """Initialize AWS Bedrock completion client."""
        # Extract provider from kwargs to avoid a duplicate keyword argument
        kwargs.pop("provider", None)

        super().__init__(
            model=model,
            temperature=temperature,
            stop=stop_sequences or [],
            provider="bedrock",
            **kwargs,
        )

        # Initialize the Bedrock client; explicit credentials take
        # precedence over the environment variables
        session = Session(
            aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=aws_secret_access_key
            or os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_token=aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
            region_name=region_name,
        )

        # Configure the client with timeouts and retries following AWS best practices
        config = Config(
            connect_timeout=60,
            read_timeout=300,
            retries={
                "max_attempts": 3,
                "mode": "adaptive",
            },
            tcp_keepalive=True,
        )

        self.client = session.client("bedrock-runtime", config=config)
        self.region_name = region_name

        # Store completion parameters
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.stream = stream
        self.stop_sequences = stop_sequences or []

        # Model-specific settings
        self.is_claude_model = "claude" in model.lower()
        self.supports_tools = True  # The Converse API supports tools for most models
        self.supports_streaming = True

        # Use the model ID as provided (inference-profile mapping was removed)
        self.model_id = model

    def call(
        self,
        messages: str | list[dict[str, str]],
        tools: Sequence[Mapping[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str | Any:
        """Call the AWS Bedrock Converse API."""
        try:
            # Emit call started event
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )

            # Format messages for the Converse API
            formatted_messages, system_message = self._format_messages_for_converse(
                messages
            )

            # Prepare tool configuration
            tool_config = None
            if tools:
                tool_config = {"tools": self._format_tools_for_converse(tools)}

            # Prepare request body
            body = {
                "inferenceConfig": self._get_inference_config(),
            }

            # Add system message if present
            if system_message:
                body["system"] = [{"text": system_message}]

            # Add tool config if present
            if tool_config:
                body["toolConfig"] = tool_config

            if self.stream:
                return self._handle_streaming_converse(
                    formatted_messages, body, available_functions, from_task, from_agent
                )

            return self._handle_converse(
                formatted_messages, body, available_functions, from_task, from_agent
            )

        except Exception as e:
            if is_context_length_exceeded(e):
                logging.error(f"Context window exceeded: {e}")
                raise LLMContextLengthExceededError(str(e)) from e

            error_msg = f"AWS Bedrock API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise

    def _handle_converse(
        self,
        messages: list[dict[str, Any]],
        body: dict[str, Any],
        available_functions: Mapping[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle a non-streaming Converse API call following AWS best practices."""
        try:
            # Validate message format before the API call
            if not messages:
                raise ValueError("Messages cannot be empty")

            # Ensure we have a valid message structure
            for i, msg in enumerate(messages):
                if (
                    not isinstance(msg, dict)
                    or "role" not in msg
                    or "content" not in msg
                ):
                    raise ValueError(f"Invalid message format at index {i}")

            # Call the Bedrock Converse API with proper error handling
            response = self.client.converse(
                modelId=self.model_id, messages=messages, **body
            )

            # Track token usage according to the AWS response format
            if "usage" in response:
                self._track_token_usage_internal(response["usage"])

            # Extract content following the AWS response structure
            output = response.get("output", {})
            message = output.get("message", {})
            content = message.get("content", [])

            if not content:
                logging.warning("No content in Bedrock response")
                return (
                    "I apologize, but I received an empty response. Please try again."
                )

            # Extract text content from the response
            text_content = ""
            for content_block in content:
                # Converse content blocks are keyed by type: {"text": ...} or
                # {"toolUse": ...}; there is no "type" field inside the block
                if "text" in content_block:
                    text_content += content_block["text"]
                elif "toolUse" in content_block and available_functions:
                    # Handle tool use according to the AWS format
                    tool_use = content_block["toolUse"]
                    function_name = tool_use.get("name")
                    function_args = tool_use.get("input", {})

                    result = self._handle_tool_execution(
                        function_name=function_name,
                        function_args=function_args,
                        available_functions=available_functions,
                        from_task=from_task,
                        from_agent=from_agent,
                    )
                    if result is not None:
                        return result

            # Apply stop sequences if configured
            text_content = self._apply_stop_words(text_content)

            # Validate the final response
            if not text_content or text_content.strip() == "":
                logging.warning("Extracted empty text content from Bedrock response")
                text_content = "I apologize, but I couldn't generate a proper response. Please try again."

            self._emit_call_completed_event(
                response=text_content,
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=messages,
            )
            return text_content
        except ClientError as e:
            # Handle all AWS ClientError exceptions as per the documentation
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            error_msg = e.response.get("Error", {}).get("Message", str(e))

            # Log the specific error for debugging
            logging.error(f"AWS Bedrock ClientError ({error_code}): {error_msg}")

            # Handle specific error codes as documented
            if error_code == "ValidationException":
                # Cohere-style models raise this when the conversation
                # does not end with a user turn
                if "last turn" in error_msg and "user message" in error_msg:
                    raise ValueError(
                        f"Conversation format error: {error_msg}. Check message alternation."
                    ) from e
                raise ValueError(f"Request validation failed: {error_msg}") from e
            if error_code == "AccessDeniedException":
                raise PermissionError(
                    f"Access denied to model {self.model_id}: {error_msg}"
                ) from e
            if error_code == "ResourceNotFoundException":
                raise ValueError(f"Model {self.model_id} not found: {error_msg}") from e
            if error_code == "ThrottlingException":
                raise RuntimeError(
                    f"API throttled, please retry later: {error_msg}"
                ) from e
            if error_code == "ModelTimeoutException":
                raise TimeoutError(f"Model request timed out: {error_msg}") from e
            if error_code == "ServiceQuotaExceededException":
                raise RuntimeError(f"Service quota exceeded: {error_msg}") from e
            if error_code == "ModelNotReadyException":
                raise RuntimeError(
                    f"Model {self.model_id} not ready: {error_msg}"
                ) from e
            if error_code == "ModelErrorException":
                raise RuntimeError(f"Model error: {error_msg}") from e
            if error_code == "InternalServerException":
                raise RuntimeError(f"Internal server error: {error_msg}") from e
            if error_code == "ServiceUnavailableException":
                raise RuntimeError(f"Service unavailable: {error_msg}") from e
            raise RuntimeError(f"Bedrock API error ({error_code}): {error_msg}") from e
        except BotoCoreError as e:
            error_msg = f"Bedrock connection error: {e}"
            logging.error(error_msg)
            raise ConnectionError(error_msg) from e
        except Exception as e:
            # Catch any other unexpected errors
            error_msg = f"Unexpected error in Bedrock converse call: {e}"
            logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _handle_streaming_converse(
        self,
        messages: list[dict[str, Any]],
        body: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle a streaming Converse API call."""
        full_response = ""

        try:
            response = self.client.converse_stream(
                modelId=self.model_id, messages=messages, **body
            )

            stream = response.get("stream")
            if stream:
                for event in stream:
                    if "contentBlockDelta" in event:
                        delta = event["contentBlockDelta"]["delta"]
                        if "text" in delta:
                            text_chunk = delta["text"]
                            logging.debug(f"Streaming text chunk: {text_chunk[:50]}...")
                            full_response += text_chunk
                            self._emit_stream_chunk_event(
                                chunk=text_chunk,
                                from_task=from_task,
                                from_agent=from_agent,
                            )
                    elif "messageStop" in event:
                        # End of message
                        break
        except ClientError as e:
            error_msg = self._handle_client_error(e)
            raise RuntimeError(error_msg) from e
        except BotoCoreError as e:
            error_msg = f"Bedrock streaming connection error: {e}"
            logging.error(error_msg)
            raise ConnectionError(error_msg) from e

        # Apply stop words to the full response
        full_response = self._apply_stop_words(full_response)

        # Ensure we don't return empty content
        if not full_response or full_response.strip() == "":
            logging.warning("Bedrock streaming returned empty content, using fallback")
            full_response = (
                "I apologize, but I couldn't generate a response. Please try again."
            )

        # Emit completion event
        self._emit_call_completed_event(
            response=full_response,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=messages,
        )
        return full_response

    def _format_messages_for_converse(
        self, messages: str | list[dict[str, str]]
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Format messages for the Converse API following AWS documentation."""
        # Use base class formatting first
        formatted_messages = self._format_messages(messages)

        converse_messages = []
        system_message = None

        for message in formatted_messages:
            role = message.get("role")
            content = message.get("content", "")

            if role == "system":
                # Extract the system message; the Converse API handles it separately
                if system_message:
                    system_message += f"\n\n{content}"
                else:
                    system_message = content
            else:
                # Convert to the Converse API format with the proper content structure
                converse_messages.append({"role": role, "content": [{"text": content}]})

        # CRITICAL: Handle model-specific conversation requirements.
        # Cohere and some other models require the conversation to end
        # with a user message.
        if converse_messages:
            last_message = converse_messages[-1]
            if last_message["role"] == "assistant":
                # For Cohere models, add a continuation user message
                if "cohere" in self.model.lower():
                    converse_messages.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "text": "Please continue and provide your final answer."
                                }
                            ],
                        }
                    )
                # For other model families with similar requirements
                elif any(
                    model_family in self.model.lower()
                    for model_family in ["command", "coral"]
                ):
                    converse_messages.append(
                        {
                            "role": "user",
                            "content": [{"text": "Continue your response."}],
                        }
                    )

        # Ensure the first message is from the user (required by the Converse API)
        if not converse_messages:
            converse_messages.append(
                {
                    "role": "user",
                    "content": [{"text": "Hello, please help me with my request."}],
                }
            )
        elif converse_messages[0]["role"] != "user":
            converse_messages.insert(
                0,
                {
                    "role": "user",
                    "content": [{"text": "Hello, please help me with my request."}],
                },
            )

        return converse_messages, system_message

    def _format_tools_for_converse(self, tools: list[dict]) -> list[dict]:
        """Convert CrewAI tools to the Converse API format following the AWS specification."""
        from crewai.llms.providers.utils.common import safe_tool_conversion

        converse_tools = []
        for tool in tools:
            try:
                name, description, parameters = safe_tool_conversion(tool, "Bedrock")

                converse_tool = {
                    "toolSpec": {
                        "name": name,
                        "description": description,
                    }
                }
                if parameters and isinstance(parameters, dict):
                    converse_tool["toolSpec"]["inputSchema"] = {"json": parameters}
                converse_tools.append(converse_tool)
            except Exception as e:  # noqa: PERF203
                logging.warning(
                    f"Failed to convert tool {tool.get('name', 'unknown')}: {e}"
                )
                continue
        return converse_tools

    def _get_inference_config(self) -> dict[str, Any]:
        """Get the inference configuration following the AWS Converse API specification."""
        config = {}

        if self.max_tokens:
            config["maxTokens"] = self.max_tokens
        if self.temperature is not None:
            config["temperature"] = float(self.temperature)
        if self.top_p is not None:
            config["topP"] = float(self.top_p)
        if self.stop_sequences:
            config["stopSequences"] = self.stop_sequences
        if self.is_claude_model and self.top_k is not None:
            # top_k is supported by Claude models
            config["topK"] = int(self.top_k)

        return config

    def _handle_client_error(self, e: ClientError) -> str:
        """Handle AWS ClientError with specific error codes and return the error message."""
        error_code = e.response.get("Error", {}).get("Code", "Unknown")
        error_msg = e.response.get("Error", {}).get("Message", str(e))

        error_mapping = {
            "AccessDeniedException": f"Access denied to model {self.model_id}: {error_msg}",
            "ResourceNotFoundException": f"Model {self.model_id} not found: {error_msg}",
            "ThrottlingException": f"API throttled, please retry later: {error_msg}",
            "ValidationException": f"Invalid request: {error_msg}",
            "ModelTimeoutException": f"Model request timed out: {error_msg}",
            "ServiceQuotaExceededException": f"Service quota exceeded: {error_msg}",
            "ModelNotReadyException": f"Model {self.model_id} not ready: {error_msg}",
            "ModelErrorException": f"Model error: {error_msg}",
        }

        full_error_msg = error_mapping.get(
            error_code, f"Bedrock API error: {error_msg}"
        )
        logging.error(f"Bedrock client error ({error_code}): {full_error_msg}")
        return full_error_msg

    def _track_token_usage_internal(self, usage: dict[str, Any]) -> None:
        """Track token usage from a Bedrock response."""
        input_tokens = usage.get("inputTokens", 0)
        output_tokens = usage.get("outputTokens", 0)
        total_tokens = usage.get("totalTokens", input_tokens + output_tokens)

        self._token_usage["prompt_tokens"] += input_tokens
        self._token_usage["completion_tokens"] += output_tokens
        self._token_usage["total_tokens"] += total_tokens
        self._token_usage["successful_requests"] += 1

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling."""
        return self.supports_tools

    def supports_stop_words(self) -> bool:
        """Check if the model supports stop words."""
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size for the model."""
        from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO

        # Context window sizes for common Bedrock models
        context_windows = {
            "anthropic.claude-3-5-sonnet": 200000,
            "anthropic.claude-3-5-haiku": 200000,
            "anthropic.claude-3-opus": 200000,
            "anthropic.claude-3-sonnet": 200000,
            "anthropic.claude-3-haiku": 200000,
            "anthropic.claude-3-7-sonnet": 200000,
            "anthropic.claude-v2": 100000,
            "amazon.titan-text-express": 8000,
            "ai21.j2-ultra": 8192,
            "cohere.command-text": 4096,
            "meta.llama2-13b-chat": 4096,
            "meta.llama2-70b-chat": 4096,
            "meta.llama3-70b-instruct": 128000,
            "deepseek.r1": 32768,
        }

        # Find the best match for the model name
        for model_prefix, size in context_windows.items():
            if self.model.startswith(model_prefix):
                return int(size * CONTEXT_WINDOW_USAGE_RATIO)

        # Default context window size
        return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
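For reference, a hedged end-to-end sketch driving the new class directly, using only the constructor and call() signatures above (requires boto3 and AWS credentials in the environment; the model ID is the class default):

from crewai.llms.providers.bedrock.completion import BedrockCompletion

llm = BedrockCompletion(
    model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    region_name="us-east-1",
    temperature=0.0,
    max_tokens=256,
)

# Plain text in, plain text out; call() wraps the prompt into the Converse
# {"role": ..., "content": [{"text": ...}]} structure internally.
print(llm.call("List two AWS regions where Bedrock is available."))

# Streaming uses the same call surface; chunks are emitted as events and
# the concatenated text is returned at the end.
stream_llm = BedrockCompletion(
    model="anthropic.claude-3-5-sonnet-20241022-v2:0", stream=True
)
print(stream_llm.call("Say hello."))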

uv.lock (generated)

@@ -1020,6 +1020,9 @@ aisuite = [
aws = [
    { name = "boto3" },
]
boto3 = [
    { name = "boto3" },
]
docling = [
    { name = "docling" },
]
@@ -1060,6 +1063,7 @@ requires-dist = [
    { name = "appdirs", specifier = ">=1.4.4" },
    { name = "blinker", specifier = ">=1.9.0" },
    { name = "boto3", marker = "extra == 'aws'", specifier = ">=1.40.38" },
    { name = "boto3", marker = "extra == 'boto3'", specifier = ">=1.40.45" },
    { name = "chromadb", specifier = "~=1.1.0" },
    { name = "click", specifier = ">=8.1.7" },
    { name = "crewai-tools", marker = "extra == 'tools'", editable = "lib/crewai-tools" },
@@ -1095,7 +1099,7 @@ requires-dist = [
    { name = "uv", specifier = ">=0.4.25" },
    { name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.3.5" },
]
-provides-extras = ["aisuite", "aws", "docling", "embeddings", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"]
+provides-extras = ["aisuite", "aws", "boto3", "docling", "embeddings", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"]
[[package]]
name = "crewai-devtools"