mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-23 07:08:14 +00:00
git-subtree-dir: packages/tools git-subtree-split: 78317b9c127f18bd040c1d77e3c0840cdc9a5b38
176 lines
6.6 KiB
Python
176 lines
6.6 KiB
Python
from typing import Type, Optional, Dict, Any, List
|
|
import os
|
|
import json
|
|
import uuid
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from dotenv import load_dotenv
|
|
|
|
from crewai.tools import BaseTool
|
|
from pydantic import BaseModel, Field
|
|
|
|
from ..exceptions import BedrockAgentError, BedrockValidationError
|
|
|
|
# Load environment variables from .env file
|
|
load_dotenv()
|
|
|
|
|
|
class BedrockInvokeAgentToolInput(BaseModel):
    """Argument schema for BedrockInvokeAgentTool.

    Declares the single required ``query`` string field.
    """

    # Ellipsis as the default marks the field as required.
    query: str = Field(
        default=...,
        description="The query to send to the agent",
    )
|
|
|
|
|
|
class BedrockInvokeAgentTool(BaseTool):
    """Tool that sends a query to a Bedrock agent and returns the text reply.

    The agent is addressed by ``agent_id`` / ``agent_alias_id``; both may be
    supplied to the constructor or read from the ``BEDROCK_AGENT_ID`` and
    ``BEDROCK_AGENT_ALIAS_ID`` environment variables. Each call to ``_run``
    creates a ``bedrock-agent-runtime`` boto3 client and invokes the agent
    with the configured session.
    """

    name: str = "Bedrock Agent Invoke Tool"
    description: str = "An agent responsible for policy analysis."
    args_schema: Type[BaseModel] = BedrockInvokeAgentToolInput
    # The identifiers default to None and are resolved in __init__, so they
    # are Optional at the field level; _validate_parameters rejects None.
    agent_id: Optional[str] = None
    agent_alias_id: Optional[str] = None
    session_id: Optional[str] = None
    enable_trace: bool = False
    end_session: bool = False
    package_dependencies: List[str] = ["boto3"]

    def __init__(
        self,
        agent_id: Optional[str] = None,
        agent_alias_id: Optional[str] = None,
        session_id: Optional[str] = None,
        enable_trace: bool = False,
        end_session: bool = False,
        description: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the BedrockInvokeAgentTool with agent configuration.

        Args:
            agent_id (str): The unique identifier of the Bedrock agent.
                Falls back to the ``BEDROCK_AGENT_ID`` environment variable.
            agent_alias_id (str): The unique identifier of the agent alias.
                Falls back to the ``BEDROCK_AGENT_ALIAS_ID`` environment
                variable.
            session_id (str): The unique identifier of the session. Defaults
                to the current Unix timestamp (whole seconds) as a string.
            enable_trace (bool): Whether to enable trace for the agent
                invocation.
            end_session (bool): Whether to end the session with the agent.
            description (Optional[str]): Custom description for the tool.
            **kwargs: Forwarded unchanged to ``BaseTool.__init__``.

        Raises:
            BedrockValidationError: If ``agent_id`` or ``agent_alias_id`` is
                missing or not a string, or ``session_id`` is not a string.
        """
        super().__init__(**kwargs)

        # Get values from environment variables if not provided
        self.agent_id = agent_id or os.getenv("BEDROCK_AGENT_ID")
        self.agent_alias_id = agent_alias_id or os.getenv("BEDROCK_AGENT_ALIAS_ID")
        # Use timestamp as session ID if not provided.
        # NOTE(review): two tools created within the same second get the same
        # default session ID; pass an explicit session_id when sessions must
        # be kept separate.
        self.session_id = session_id or str(int(time.time()))
        self.enable_trace = enable_trace
        self.end_session = end_session

        # Update the description if provided
        if description:
            self.description = description

        # Validate parameters
        self._validate_parameters()

    def _validate_parameters(self) -> None:
        """Check that the agent identifiers and session ID are usable.

        Raises:
            BedrockValidationError: With a message prefixed by
                ``"Parameter validation failed: "`` when ``agent_id`` or
                ``agent_alias_id`` is empty or not a string, or when a
                non-empty ``session_id`` is not a string.
        """
        try:
            # Validate agent_id
            if not self.agent_id:
                raise BedrockValidationError("agent_id cannot be empty")
            if not isinstance(self.agent_id, str):
                raise BedrockValidationError("agent_id must be a string")

            # Validate agent_alias_id
            if not self.agent_alias_id:
                raise BedrockValidationError("agent_alias_id cannot be empty")
            if not isinstance(self.agent_alias_id, str):
                raise BedrockValidationError("agent_alias_id must be a string")

            # Validate session_id if provided
            if self.session_id and not isinstance(self.session_id, str):
                raise BedrockValidationError("session_id must be a string")

        except BedrockValidationError as e:
            # Re-raise with a common prefix, keeping the original as the cause.
            raise BedrockValidationError(
                f"Parameter validation failed: {str(e)}"
            ) from e

    def _extract_completion(self, response: Dict[str, Any]) -> str:
        """Collect the completion text from an ``invoke_agent`` response.

        Reads every ``event['chunk']['bytes']`` payload found under
        ``response['completion']``. If that yields nothing, falls back to a
        top-level ``response['chunk']['bytes']`` payload.

        Args:
            response: The mapping returned by ``invoke_agent``.

        Returns:
            The concatenated text, or ``""`` when no payload was found.

        Raises:
            UnicodeDecodeError: If the byte payloads are not valid UTF-8.
        """
        # Local import: codecs is not among the file-level imports.
        import codecs

        # An incremental decoder buffers incomplete byte sequences between
        # chunks, so a multi-byte character split across two chunks is
        # decoded correctly instead of raising on the first half.
        decoder = codecs.getincrementaldecoder("utf-8")()
        parts: List[str] = []

        for event in response.get("completion", []):
            if "chunk" in event and "bytes" in event["chunk"]:
                payload = event["chunk"]["bytes"]
                if isinstance(payload, (bytes, bytearray)):
                    parts.append(decoder.decode(bytes(payload)))
                else:
                    parts.append(str(payload))

        # Flush the decoder; raises if the stream ended mid-character.
        parts.append(decoder.decode(b"", final=True))
        completion = "".join(parts)

        # If no completion found in streaming format, try direct format
        if not completion and "chunk" in response and "bytes" in response["chunk"]:
            payload = response["chunk"]["bytes"]
            if isinstance(payload, (bytes, bytearray)):
                completion = payload.decode("utf-8")
            else:
                completion = str(payload)

        return completion

    def _format_client_error(self, error: Exception) -> str:
        """Build the ``"Error (<code>): <message>"`` text for a client error.

        Uses ``error.response['Error']`` for the code and message when that
        mapping is present; otherwise the code is ``"Unknown"`` and the
        message is ``str(error)``.
        """
        error_code = "Unknown"
        error_message = str(error)

        # Try to extract error code if available
        if hasattr(error, "response") and "Error" in error.response:
            details = error.response["Error"]
            error_code = details.get("Code", "Unknown")
            error_message = details.get("Message", str(error))

        return f"Error ({error_code}): {error_message}"

    def _run(self, query: str) -> str:
        """Invoke the configured Bedrock agent with ``query``.

        The query is wrapped in a prompt that also states the current UTC
        time. The client region is taken from ``AWS_REGION``, then
        ``AWS_DEFAULT_REGION``, then defaults to ``us-west-2``.

        Args:
            query: The query to send to the agent.

        Returns:
            The completion text extracted from the agent response.

        Raises:
            ImportError: If ``boto3`` / ``botocore`` cannot be imported.
            BedrockAgentError: If the service call fails, no completion can
                be extracted from the response, or any other error occurs.
        """
        try:
            import boto3
            from botocore.exceptions import ClientError
        except ImportError as e:
            raise ImportError(
                "`boto3` package not found, please run `uv add boto3`"
            ) from e

        try:
            # Initialize the Bedrock Agent Runtime client
            bedrock_agent = boto3.client(
                "bedrock-agent-runtime",
                region_name=os.getenv(
                    "AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2")
                ),
            )

            # Format the prompt with current time. The literal is kept
            # flush-left so no indentation leaks into the text sent to the
            # agent.
            current_utc = datetime.now(timezone.utc)
            prompt = f"""
The current time is: {current_utc}

Below is the users query or task. Complete it and answer it consicely and to the point:
{query}
"""

            # Invoke the agent
            response = bedrock_agent.invoke_agent(
                agentId=self.agent_id,
                agentAliasId=self.agent_alias_id,
                sessionId=self.session_id,
                inputText=prompt,
                enableTrace=self.enable_trace,
                endSession=self.end_session,
            )

            # Process the response
            completion = self._extract_completion(response)

            # If still no completion, report what the response did contain
            if not completion:
                debug_info = {
                    "error": "Could not extract completion from response",
                    "response_keys": list(response.keys()),
                }

                # Add more debug info
                if "chunk" in response:
                    debug_info["chunk_keys"] = list(response["chunk"].keys())

                raise BedrockAgentError(
                    f"Failed to extract completion: {json.dumps(debug_info, indent=2)}"
                )

            return completion

        except ClientError as e:
            raise BedrockAgentError(self._format_client_error(e)) from e
        except BedrockAgentError:
            # Re-raise BedrockAgentError exceptions
            raise
        except Exception as e:
            raise BedrockAgentError(f"Unexpected error: {str(e)}") from e