Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-11 00:58:30 +00:00)
Lorenze/bedrock llm (#3693)
* feat: add AWS Bedrock support and update dependencies
  - Introduced BedrockCompletion class for AWS Bedrock integration in LLM.
  - Added boto3 as a new dependency in both pyproject.toml and uv.lock.
  - Updated LLM class to support Bedrock provider.
  - Created new files for Bedrock provider implementation.
* using converse api
* converse
* linted
* refactor: update BedrockCompletion class to improve parameter handling
  - Changed max_tokens from a fixed integer to an optional integer.
  - Simplified model ID assignment by removing the inference profile mapping method.
  - Cleaned up comments and unnecessary code related to tool specifications and model-specific parameters.
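For orientation, a minimal usage sketch (not part of the diff), assuming the standard crewai LLM entry point and AWS credentials in the environment; the model string, routing prefix, and prompt are illustrative:

from crewai import LLM

# Assumption: a "bedrock/" model prefix routes to the new native provider
# added in this commit; credentials are read from the environment.
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
print(llm.call("Summarize the Converse API in one sentence."))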
pyproject.toml
@@ -85,6 +85,9 @@ voyageai = [
 litellm = [
     "litellm>=1.74.9",
 ]
+boto3 = [
+    "boto3>=1.40.45",
+]


 [project.scripts]
@@ -367,6 +367,14 @@ class LLM(BaseLLM):
             except ImportError:
                 return None

+        elif provider == "bedrock":
+            try:
+                from crewai.llms.providers.bedrock.completion import BedrockCompletion
+
+                return BedrockCompletion
+            except ImportError:
+                return None
+
         return None

     def __init__(
lib/crewai/src/crewai/llms/providers/bedrock/completion.py (new file, 553 lines)
@@ -0,0 +1,553 @@
from collections.abc import Mapping, Sequence
import logging
import os
from typing import Any

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)


try:
    from boto3.session import Session
    from botocore.config import Config
    from botocore.exceptions import BotoCoreError, ClientError
except ImportError:
    raise ImportError(
        "AWS Bedrock native provider requires boto3; install it with `uv add boto3`"
    ) from None


class BedrockCompletion(BaseLLM):
    """AWS Bedrock native completion implementation using the Converse API.

    This class provides direct integration with AWS Bedrock using the modern
    Converse API, which offers a unified interface across all Bedrock models.
    """
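
    # Roughly, the Converse call this class wraps looks like the following
    # (illustrative sketch against a boto3 "bedrock-runtime" client; the
    # model ID and values are made up):
    #
    #   client.converse(
    #       modelId="anthropic.claude-3-5-sonnet-20241022-v2:0",
    #       messages=[{"role": "user", "content": [{"text": "Hi"}]}],
    #       inferenceConfig={"maxTokens": 256, "temperature": 0.2},
    #   )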

    def __init__(
        self,
        model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
        aws_session_token: str | None = None,
        region_name: str = "us-east-1",
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        top_k: int | None = None,
        stop_sequences: Sequence[str] | None = None,
        stream: bool = False,
        **kwargs,
    ):
        """Initialize AWS Bedrock completion client."""
        # Extract provider from kwargs to avoid duplicate argument
        kwargs.pop("provider", None)

        super().__init__(
            model=model,
            temperature=temperature,
            stop=stop_sequences or [],
            provider="bedrock",
            **kwargs,
        )

        # Initialize Bedrock client with proper configuration
        session = Session(
            aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=aws_secret_access_key
            or os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_token=aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
            region_name=region_name,
        )

        # Configure client with timeouts and retries following AWS best practices
        config = Config(
            connect_timeout=60,
            read_timeout=300,
            retries={
                "max_attempts": 3,
                "mode": "adaptive",
            },
            tcp_keepalive=True,
        )

        self.client = session.client("bedrock-runtime", config=config)
        self.region_name = region_name

        # Store completion parameters
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.stream = stream
        self.stop_sequences = stop_sequences or []

        # Model-specific settings
        self.is_claude_model = "claude" in model.lower()
        self.supports_tools = True  # Converse API supports tools for most models
        self.supports_streaming = True

        # Handle inference profiles for newer models
        self.model_id = model

    def call(
        self,
        messages: str | list[dict[str, str]],
        tools: Sequence[Mapping[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str | Any:
        """Call AWS Bedrock Converse API."""
        try:
            # Emit call started event
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )

            # Format messages for Converse API
            formatted_messages, system_message = self._format_messages_for_converse(
                messages
            )

            # Prepare tool configuration
            tool_config = None
            if tools:
                tool_config = {"tools": self._format_tools_for_converse(tools)}

            # Prepare request body
            body: dict[str, Any] = {
                "inferenceConfig": self._get_inference_config(),
            }

            # Add system message if present
            if system_message:
                body["system"] = [{"text": system_message}]

            # Add tool config if present
            if tool_config:
                body["toolConfig"] = tool_config

            # top_k is not a Converse inferenceConfig field; for Claude models
            # it is forwarded through additionalModelRequestFields instead.
            if self.is_claude_model and self.top_k is not None:
                body["additionalModelRequestFields"] = {"top_k": int(self.top_k)}

            if self.stream:
                return self._handle_streaming_converse(
                    formatted_messages, body, available_functions, from_task, from_agent
                )

            return self._handle_converse(
                formatted_messages, body, available_functions, from_task, from_agent
            )

        except Exception as e:
            if is_context_length_exceeded(e):
                logging.error(f"Context window exceeded: {e}")
                raise LLMContextLengthExceededError(str(e)) from e

            error_msg = f"AWS Bedrock API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise

    def _handle_converse(
        self,
        messages: list[dict[str, Any]],
        body: dict[str, Any],
        available_functions: Mapping[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle non-streaming converse API call following AWS best practices."""
        try:
            # Validate messages format before API call
            if not messages:
                raise ValueError("Messages cannot be empty")

            # Ensure we have valid message structure
            for i, msg in enumerate(messages):
                if (
                    not isinstance(msg, dict)
                    or "role" not in msg
                    or "content" not in msg
                ):
                    raise ValueError(f"Invalid message format at index {i}")

            # Call Bedrock Converse API with proper error handling
            response = self.client.converse(
                modelId=self.model_id, messages=messages, **body
            )

            # Track token usage according to AWS response format
            if "usage" in response:
                self._track_token_usage_internal(response["usage"])

            # Extract content following AWS response structure
            output = response.get("output", {})
            message = output.get("message", {})
            content = message.get("content", [])

            if not content:
                logging.warning("No content in Bedrock response")
                return (
                    "I apologize, but I received an empty response. Please try again."
                )

            # Extract text content from response
            text_content = ""
            for content_block in content:
                # Converse content blocks are keyed by block type ({"text": ...}
                # or {"toolUse": ...}); there is no "type" field, so check the key.
                if "text" in content_block:
                    text_content += content_block["text"]
                elif "toolUse" in content_block and available_functions:
                    # Handle tool use according to AWS format
                    tool_use = content_block["toolUse"]
                    function_name = tool_use.get("name")
                    function_args = tool_use.get("input", {})

                    result = self._handle_tool_execution(
                        function_name=function_name,
                        function_args=function_args,
                        available_functions=available_functions,
                        from_task=from_task,
                        from_agent=from_agent,
                    )

                    if result is not None:
                        return result

            # Apply stop sequences if configured
            text_content = self._apply_stop_words(text_content)

            # Validate final response
            if not text_content or text_content.strip() == "":
                logging.warning("Extracted empty text content from Bedrock response")
                text_content = "I apologize, but I couldn't generate a proper response. Please try again."

            self._emit_call_completed_event(
                response=text_content,
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=messages,
            )

            return text_content

        except ClientError as e:
            # Handle all AWS ClientError exceptions as per documentation
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            error_msg = e.response.get("Error", {}).get("Message", str(e))

            # Log the specific error for debugging
            logging.error(f"AWS Bedrock ClientError ({error_code}): {error_msg}")

            # Handle specific error codes as documented
            if error_code == "ValidationException":
                # Raised when message alternation rules are violated (e.g. Cohere)
                if "last turn" in error_msg and "user message" in error_msg:
                    raise ValueError(
                        f"Conversation format error: {error_msg}. Check message alternation."
                    ) from e
                raise ValueError(f"Request validation failed: {error_msg}") from e
            if error_code == "AccessDeniedException":
                raise PermissionError(
                    f"Access denied to model {self.model_id}: {error_msg}"
                ) from e
            if error_code == "ResourceNotFoundException":
                raise ValueError(f"Model {self.model_id} not found: {error_msg}") from e
            if error_code == "ThrottlingException":
                raise RuntimeError(
                    f"API throttled, please retry later: {error_msg}"
                ) from e
            if error_code == "ModelTimeoutException":
                raise TimeoutError(f"Model request timed out: {error_msg}") from e
            if error_code == "ServiceQuotaExceededException":
                raise RuntimeError(f"Service quota exceeded: {error_msg}") from e
            if error_code == "ModelNotReadyException":
                raise RuntimeError(
                    f"Model {self.model_id} not ready: {error_msg}"
                ) from e
            if error_code == "ModelErrorException":
                raise RuntimeError(f"Model error: {error_msg}") from e
            if error_code == "InternalServerException":
                raise RuntimeError(f"Internal server error: {error_msg}") from e
            if error_code == "ServiceUnavailableException":
                raise RuntimeError(f"Service unavailable: {error_msg}") from e

            raise RuntimeError(f"Bedrock API error ({error_code}): {error_msg}") from e

        except BotoCoreError as e:
            error_msg = f"Bedrock connection error: {e}"
            logging.error(error_msg)
            raise ConnectionError(error_msg) from e
        except Exception as e:
            # Catch any other unexpected errors
            error_msg = f"Unexpected error in Bedrock converse call: {e}"
            logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _handle_streaming_converse(
        self,
        messages: list[dict[str, Any]],
        body: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle streaming converse API call."""
        full_response = ""

        try:
            response = self.client.converse_stream(
                modelId=self.model_id, messages=messages, **body
            )

            stream = response.get("stream")
            if stream:
                for event in stream:
                    if "contentBlockDelta" in event:
                        delta = event["contentBlockDelta"]["delta"]
                        if "text" in delta:
                            text_chunk = delta["text"]
                            logging.debug(f"Streaming text chunk: {text_chunk[:50]}...")
                            full_response += text_chunk
                            self._emit_stream_chunk_event(
                                chunk=text_chunk,
                                from_task=from_task,
                                from_agent=from_agent,
                            )
                    elif "messageStop" in event:
                        # Handle end of message
                        break

        except ClientError as e:
            error_msg = self._handle_client_error(e)
            raise RuntimeError(error_msg) from e
        except BotoCoreError as e:
            error_msg = f"Bedrock streaming connection error: {e}"
            logging.error(error_msg)
            raise ConnectionError(error_msg) from e

        # Apply stop words to full response
        full_response = self._apply_stop_words(full_response)

        # Ensure we don't return empty content
        if not full_response or full_response.strip() == "":
            logging.warning("Bedrock streaming returned empty content, using fallback")
            full_response = (
                "I apologize, but I couldn't generate a response. Please try again."
            )

        # Emit completion event
        self._emit_call_completed_event(
            response=full_response,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=messages,
        )

        return full_response

    def _format_messages_for_converse(
        self, messages: str | list[dict[str, str]]
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Format messages for Converse API following AWS documentation."""
        # Use base class formatting first
        formatted_messages = self._format_messages(messages)

        converse_messages = []
        system_message = None

        for message in formatted_messages:
            role = message.get("role")
            content = message.get("content", "")

            if role == "system":
                # Extract system message - Converse API handles it separately
                if system_message:
                    system_message += f"\n\n{content}"
                else:
                    system_message = content
            else:
                # Convert to Converse API format with proper content structure
                converse_messages.append({"role": role, "content": [{"text": content}]})

        # CRITICAL: Handle model-specific conversation requirements
        # Cohere and some other models require conversation to end with user message
        if converse_messages:
            last_message = converse_messages[-1]
            if last_message["role"] == "assistant":
                # For Cohere models, add a continuation user message
                if "cohere" in self.model.lower():
                    converse_messages.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "text": "Please continue and provide your final answer."
                                }
                            ],
                        }
                    )
                # For other models that might have similar requirements
                elif any(
                    model_family in self.model.lower()
                    for model_family in ["command", "coral"]
                ):
                    converse_messages.append(
                        {
                            "role": "user",
                            "content": [{"text": "Continue your response."}],
                        }
                    )

        # Ensure first message is from user (required by Converse API)
        if not converse_messages:
            converse_messages.append(
                {
                    "role": "user",
                    "content": [{"text": "Hello, please help me with my request."}],
                }
            )
        elif converse_messages[0]["role"] != "user":
            converse_messages.insert(
                0,
                {
                    "role": "user",
                    "content": [{"text": "Hello, please help me with my request."}],
                },
            )

        return converse_messages, system_message
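
    # Illustrative transformation performed by _format_messages_for_converse
    # (values are made up):
    #   input:  [{"role": "system", "content": "Be terse."},
    #            {"role": "user", "content": "Hi"}]
    #   output: ([{"role": "user", "content": [{"text": "Hi"}]}], "Be terse.")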

    def _format_tools_for_converse(self, tools: list[dict]) -> list[dict]:
        """Convert CrewAI tools to Converse API format following AWS specification."""
        from crewai.llms.providers.utils.common import safe_tool_conversion

        converse_tools = []

        for tool in tools:
            try:
                name, description, parameters = safe_tool_conversion(tool, "Bedrock")

                converse_tool = {
                    "toolSpec": {
                        "name": name,
                        "description": description,
                    }
                }

                if parameters and isinstance(parameters, dict):
                    converse_tool["toolSpec"]["inputSchema"] = {"json": parameters}

                converse_tools.append(converse_tool)

            except Exception as e:  # noqa: PERF203
                logging.warning(
                    f"Failed to convert tool {tool.get('name', 'unknown')}: {e}"
                )
                continue

        return converse_tools
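
    # Illustrative toolSpec produced above for a hypothetical "get_weather"
    # tool (name and schema are made up):
    #   {"toolSpec": {"name": "get_weather",
    #                 "description": "Look up current weather for a city",
    #                 "inputSchema": {"json": {"type": "object",
    #                                          "properties": {"city": {"type": "string"}},
    #                                          "required": ["city"]}}}}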

    def _get_inference_config(self) -> dict[str, Any]:
        """Get inference configuration following AWS Converse API specification."""
        config: dict[str, Any] = {}

        if self.max_tokens:
            config["maxTokens"] = self.max_tokens

        if self.temperature is not None:
            config["temperature"] = float(self.temperature)
        if self.top_p is not None:
            config["topP"] = float(self.top_p)
        if self.stop_sequences:
            config["stopSequences"] = list(self.stop_sequences)

        # Note: top_k is supported by Claude models but is not part of the
        # Converse inferenceConfig schema (maxTokens, temperature, topP,
        # stopSequences); it is sent via additionalModelRequestFields in call().
        return config
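
    # Example: max_tokens=1024, temperature=0.2, stop_sequences=["\nObservation:"]
    # yields (illustrative):
    #   {"maxTokens": 1024, "temperature": 0.2,
    #    "stopSequences": ["\nObservation:"]}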

    def _handle_client_error(self, e: ClientError) -> str:
        """Handle AWS ClientError with specific error codes and return error message."""
        error_code = e.response.get("Error", {}).get("Code", "Unknown")
        error_msg = e.response.get("Error", {}).get("Message", str(e))

        error_mapping = {
            "AccessDeniedException": f"Access denied to model {self.model_id}: {error_msg}",
            "ResourceNotFoundException": f"Model {self.model_id} not found: {error_msg}",
            "ThrottlingException": f"API throttled, please retry later: {error_msg}",
            "ValidationException": f"Invalid request: {error_msg}",
            "ModelTimeoutException": f"Model request timed out: {error_msg}",
            "ServiceQuotaExceededException": f"Service quota exceeded: {error_msg}",
            "ModelNotReadyException": f"Model {self.model_id} not ready: {error_msg}",
            "ModelErrorException": f"Model error: {error_msg}",
        }

        full_error_msg = error_mapping.get(
            error_code, f"Bedrock API error: {error_msg}"
        )
        logging.error(f"Bedrock client error ({error_code}): {full_error_msg}")

        return full_error_msg

    def _track_token_usage_internal(self, usage: dict[str, Any]) -> None:
        """Track token usage from Bedrock response."""
        input_tokens = usage.get("inputTokens", 0)
        output_tokens = usage.get("outputTokens", 0)
        total_tokens = usage.get("totalTokens", input_tokens + output_tokens)

        self._token_usage["prompt_tokens"] += input_tokens
        self._token_usage["completion_tokens"] += output_tokens
        self._token_usage["total_tokens"] += total_tokens
        self._token_usage["successful_requests"] += 1

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling."""
        return self.supports_tools

    def supports_stop_words(self) -> bool:
        """Check if the model supports stop words."""
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size for the model."""
        from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO

        # Context window sizes for common Bedrock models
        context_windows = {
            "anthropic.claude-3-5-sonnet": 200000,
            "anthropic.claude-3-5-haiku": 200000,
            "anthropic.claude-3-opus": 200000,
            "anthropic.claude-3-sonnet": 200000,
            "anthropic.claude-3-haiku": 200000,
            "anthropic.claude-3-7-sonnet": 200000,
            "anthropic.claude-v2": 100000,
            "amazon.titan-text-express": 8000,
            "ai21.j2-ultra": 8192,
            "cohere.command-text": 4096,
            "meta.llama2-13b-chat": 4096,
            "meta.llama2-70b-chat": 4096,
            "meta.llama3-70b-instruct": 128000,
            "deepseek.r1": 32768,
        }

        # Find the best match for the model name
        for model_prefix, size in context_windows.items():
            if self.model.startswith(model_prefix):
                return int(size * CONTEXT_WINDOW_USAGE_RATIO)

        # Default context window size
        return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
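
A minimal sketch of driving the new class directly, assuming boto3 is installed and AWS credentials are resolvable from the environment; the model ID, region, and prompt are illustrative:

from crewai.llms.providers.bedrock.completion import BedrockCompletion

llm = BedrockCompletion(
    model="anthropic.claude-3-5-haiku-20241022-v1:0",  # illustrative model ID
    region_name="us-east-1",
    temperature=0.2,
    max_tokens=512,
)
# call() accepts either a plain string or a role/content message list.
print(llm.call([{"role": "user", "content": "Name one AWS region."}]))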
uv.lock (generated, 6 changes)
@@ -1020,6 +1020,9 @@ aisuite = [
 aws = [
     { name = "boto3" },
 ]
+boto3 = [
+    { name = "boto3" },
+]
 docling = [
     { name = "docling" },
 ]
@@ -1060,6 +1063,7 @@ requires-dist = [
     { name = "appdirs", specifier = ">=1.4.4" },
     { name = "blinker", specifier = ">=1.9.0" },
     { name = "boto3", marker = "extra == 'aws'", specifier = ">=1.40.38" },
+    { name = "boto3", marker = "extra == 'boto3'", specifier = ">=1.40.45" },
     { name = "chromadb", specifier = "~=1.1.0" },
     { name = "click", specifier = ">=8.1.7" },
     { name = "crewai-tools", marker = "extra == 'tools'", editable = "lib/crewai-tools" },
@@ -1095,7 +1099,7 @@ requires-dist = [
     { name = "uv", specifier = ">=0.4.25" },
     { name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.3.5" },
 ]
-provides-extras = ["aisuite", "aws", "docling", "embeddings", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"]
+provides-extras = ["aisuite", "aws", "boto3", "docling", "embeddings", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"]

 [[package]]
 name = "crewai-devtools"