mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
- Add custom exceptions for better error handling
- Add parameter validation for Bedrock tools
- Improve response processing and debug information
- Maintain backward compatibility with existing implementations
This commit is contained in:
@@ -11,6 +11,9 @@ from pydantic import BaseModel, Field
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
# Import custom exceptions
|
||||
from ..exceptions import BedrockAgentError, BedrockValidationError
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
@@ -62,6 +65,31 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
# Update the description if provided
|
||||
if description:
|
||||
self.description = description
|
||||
|
||||
# Validate parameters
|
||||
self._validate_parameters()
|
||||
|
||||
def _validate_parameters(self):
    """Validate the parameters according to AWS API requirements.

    ``agent_id`` and ``agent_alias_id`` are required and must be non-empty
    strings. ``session_id`` is optional and is only type-checked when it
    holds a truthy value.

    Raises:
        BedrockValidationError: If any check fails. The message is prefixed
            with "Parameter validation failed: " and the specific failure
            is attached as ``__cause__``.
    """
    try:
        # Both required identifiers are subject to the same two checks.
        for name in ("agent_id", "agent_alias_id"):
            value = getattr(self, name)
            if not value:
                raise BedrockValidationError(f"{name} cannot be empty")
            if not isinstance(value, str):
                raise BedrockValidationError(f"{name} must be a string")

        # Validate session_id if provided
        if self.session_id and not isinstance(self.session_id, str):
            raise BedrockValidationError("session_id must be a string")

    except BedrockValidationError as e:
        # Chain explicitly so the specific failure stays visible in tracebacks.
        raise BedrockValidationError(f"Parameter validation failed: {e}") from e
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
try:
|
||||
@@ -123,7 +151,7 @@ Below is the users query or task. Complete it and answer it consicely and to the
|
||||
if 'chunk' in response:
|
||||
debug_info["chunk_keys"] = list(response['chunk'].keys())
|
||||
|
||||
return json.dumps(debug_info, indent=2)
|
||||
raise BedrockAgentError(f"Failed to extract completion: {json.dumps(debug_info, indent=2)}")
|
||||
|
||||
return completion
|
||||
|
||||
@@ -132,9 +160,13 @@ Below is the users query or task. Complete it and answer it consicely and to the
|
||||
error_message = str(e)
|
||||
|
||||
# Try to extract error code if available
|
||||
if hasattr(e, 'response') and 'Error' in e.response and 'Code' in e.response['Error']:
|
||||
error_code = e.response['Error']['Code']
|
||||
if hasattr(e, 'response') and 'Error' in e.response:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', str(e))
|
||||
|
||||
return f"Error invoking Bedrock Agent ({error_code}): {error_message}"
|
||||
raise BedrockAgentError(f"Error ({error_code}): {error_message}")
|
||||
except BedrockAgentError:
|
||||
# Re-raise BedrockAgentError exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
raise BedrockAgentError(f"Unexpected error: {str(e)}")
|
||||
17
src/crewai_tools/aws/bedrock/exceptions.py
Normal file
17
src/crewai_tools/aws/bedrock/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Custom exceptions for AWS Bedrock integration."""
|
||||
|
||||
class BedrockError(Exception):
    """Base exception for Bedrock-related errors."""


class BedrockAgentError(BedrockError):
    """Exception raised for errors in the Bedrock Agent operations."""


class BedrockKnowledgeBaseError(BedrockError):
    """Exception raised for errors in the Bedrock Knowledge Base operations."""


class BedrockValidationError(BedrockError):
    """Exception raised for validation errors in Bedrock operations."""
|
||||
@@ -8,6 +8,9 @@ from pydantic import BaseModel, Field
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
# Import custom exceptions
|
||||
from ..exceptions import BedrockKnowledgeBaseError, BedrockValidationError
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
@@ -39,7 +42,7 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
"""Initialize the BedrockKBRetrieverTool with knowledge base configuration.
|
||||
|
||||
Args:
|
||||
knowledge_base_id (str): The unique identifier of the knowledge base to query (length: 0-10, pattern: ^[0-9a-zA-Z]+$)
|
||||
knowledge_base_id (str): The unique identifier of the knowledge base to query
|
||||
number_of_results (Optional[int], optional): The maximum number of results to return. Defaults to 5.
|
||||
retrieval_configuration (Optional[Dict[str, Any]], optional): Configurations for the knowledge base query and retrieval process. Defaults to None.
|
||||
guardrail_configuration (Optional[Dict[str, Any]], optional): Guardrail settings. Defaults to None.
|
||||
@@ -50,19 +53,14 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
# Get knowledge_base_id from environment variable if not provided
|
||||
self.knowledge_base_id = knowledge_base_id or os.getenv('BEDROCK_KB_ID')
|
||||
self.number_of_results = number_of_results
|
||||
|
||||
# Initialize retrieval_configuration with number_of_results if provided
|
||||
if retrieval_configuration is None and number_of_results is not None:
|
||||
self.retrieval_configuration = {
|
||||
"vectorSearchConfiguration": {
|
||||
"numberOfResults": number_of_results
|
||||
}
|
||||
}
|
||||
else:
|
||||
self.retrieval_configuration = retrieval_configuration
|
||||
|
||||
self.guardrail_configuration = guardrail_configuration
|
||||
self.next_token = next_token
|
||||
|
||||
# Initialize retrieval_configuration with provided parameters or use the one provided
|
||||
if retrieval_configuration is None:
|
||||
self.retrieval_configuration = self._build_retrieval_configuration()
|
||||
else:
|
||||
self.retrieval_configuration = retrieval_configuration
|
||||
|
||||
# Validate parameters
|
||||
self._validate_parameters()
|
||||
@@ -70,15 +68,115 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
# Update the description to include the knowledge base details
|
||||
self.description = f"Retrieves information from Amazon Bedrock Knowledge Base '{self.knowledge_base_id}' given a query"
|
||||
|
||||
def _build_retrieval_configuration(self) -> Dict[str, Any]:
    """Build the retrieval configuration based on provided parameters.

    Reads ``self.number_of_results`` and, when it is not ``None``, stores it
    under ``numberOfResults`` inside the vector search configuration.

    Returns:
        Dict[str, Any]: A dict with the single key
            ``vectorSearchConfiguration``; its value is empty when
            ``number_of_results`` is ``None``.
    """
    vector_search_config: Dict[str, Any] = {}

    # Add number of results if provided
    if self.number_of_results is not None:
        vector_search_config["numberOfResults"] = self.number_of_results

    return {"vectorSearchConfiguration": vector_search_config}
|
||||
|
||||
def _validate_parameters(self):
    """Validate the parameters according to AWS API requirements.

    Checks, in order:

    * ``knowledge_base_id`` -- required; a string of at most 10 ASCII
      alphanumeric characters.
    * ``next_token`` -- optional; when set, a string of at most 2048
      characters containing no whitespace.
    * ``number_of_results`` -- optional; when set, an integer greater
      than 0.

    Raises:
        BedrockValidationError: If any check fails. The message is prefixed
            with "Parameter validation failed: " and the specific failure
            is attached as ``__cause__``.
    """
    try:
        # Validate knowledge_base_id
        kb_id = self.knowledge_base_id
        if not kb_id:
            raise BedrockValidationError("knowledge_base_id cannot be empty")
        if not isinstance(kb_id, str):
            raise BedrockValidationError("knowledge_base_id must be a string")
        if len(kb_id) > 10:
            raise BedrockValidationError("knowledge_base_id must be 10 characters or less")
        # str.isalnum() alone also accepts non-ASCII letters and digits;
        # the identifier is restricted to [0-9a-zA-Z].
        if not (kb_id.isascii() and kb_id.isalnum()):
            raise BedrockValidationError("knowledge_base_id must contain only alphanumeric characters")

        # Validate next_token if provided
        token = self.next_token
        if token:
            if not isinstance(token, str):
                raise BedrockValidationError("next_token must be a string")
            # A truthy string always has length >= 1, so only the upper
            # bound needs checking.
            if len(token) > 2048:
                raise BedrockValidationError("next_token must be between 1 and 2048 characters")
            # Reject every whitespace character (tabs, newlines, ...), not
            # only the space character.
            if any(ch.isspace() for ch in token):
                raise BedrockValidationError("next_token cannot contain spaces")

        # Validate number_of_results if provided
        count = self.number_of_results
        if count is not None:
            # bool is a subclass of int; True/False are not result counts.
            if isinstance(count, bool) or not isinstance(count, int):
                raise BedrockValidationError("number_of_results must be an integer")
            if count < 1:
                raise BedrockValidationError("number_of_results must be greater than 0")

    except BedrockValidationError as e:
        # Chain explicitly so the specific failure stays visible in tracebacks.
        raise BedrockValidationError(f"Parameter validation failed: {e}") from e
|
||||
|
||||
def _process_retrieval_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
    """Process a single retrieval result from Bedrock Knowledge Base.

    Args:
        result (Dict[str, Any]): Raw result from Bedrock Knowledge Base

    Returns:
        Dict[str, Any]: Processed result with standardized format. Always
            contains ``content``, ``content_type``, ``source_type`` and
            ``source_uri``; ``score``, ``metadata``, ``byte_content`` and
            ``row_content`` are added only when the corresponding key is
            present in the raw result.
    """
    # Extract content
    content_obj = result.get('content', {})

    # Extract location information
    location = result.get('location', {})
    location_type = location.get('type', 'unknown')
    source_uri = None

    # Map for location types and their URI fields
    location_mapping = {
        's3Location': {'field': 'uri', 'type': 'S3'},
        'confluenceLocation': {'field': 'url', 'type': 'Confluence'},
        'salesforceLocation': {'field': 'url', 'type': 'Salesforce'},
        'sharePointLocation': {'field': 'url', 'type': 'SharePoint'},
        'webLocation': {'field': 'url', 'type': 'Web'},
        'customDocumentLocation': {'field': 'id', 'type': 'CustomDocument'},
        'kendraDocumentLocation': {'field': 'uri', 'type': 'KendraDocument'},
        'sqlLocation': {'field': 'query', 'type': 'SQL'},
    }

    # Extract the URI from the first location key that is present
    for loc_key, config in location_mapping.items():
        if loc_key in location:
            source_uri = location[loc_key].get(config['field'])
            # Only fall back to the mapped label when the result did not
            # carry a usable type of its own.
            if not location_type or location_type == 'unknown':
                location_type = config['type']
            break

    # Create result object
    result_object = {
        'content': content_obj.get('text', ''),
        'content_type': content_obj.get('type', 'text'),
        'source_type': location_type,
        'source_uri': source_uri,
    }

    # Add optional fields if available
    if 'score' in result:
        result_object['score'] = result['score']

    if 'metadata' in result:
        result_object['metadata'] = result['metadata']

    # Handle byte content if present
    if 'byteContent' in content_obj:
        result_object['byte_content'] = content_obj['byteContent']

    # Handle row content if present
    if 'row' in content_obj:
        result_object['row_content'] = content_obj['row']

    return result_object
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
try:
|
||||
@@ -113,62 +211,10 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
# Process the response
|
||||
results = []
|
||||
for result in response.get('retrievalResults', []):
|
||||
# Extract content
|
||||
content_obj = result.get('content', {})
|
||||
content = content_obj.get('text', '')
|
||||
content_type = content_obj.get('type', 'text')
|
||||
|
||||
# Extract location information
|
||||
location = result.get('location', {})
|
||||
location_type = location.get('type', 'unknown')
|
||||
source_uri = None
|
||||
|
||||
# Map for location types and their URI fields
|
||||
location_mapping = {
|
||||
's3Location': {'field': 'uri', 'type': 'S3'},
|
||||
'confluenceLocation': {'field': 'url', 'type': 'Confluence'},
|
||||
'salesforceLocation': {'field': 'url', 'type': 'Salesforce'},
|
||||
'sharePointLocation': {'field': 'url', 'type': 'SharePoint'},
|
||||
'webLocation': {'field': 'url', 'type': 'Web'},
|
||||
'customDocumentLocation': {'field': 'id', 'type': 'CustomDocument'},
|
||||
'kendraDocumentLocation': {'field': 'uri', 'type': 'KendraDocument'},
|
||||
'sqlLocation': {'field': 'query', 'type': 'SQL'}
|
||||
}
|
||||
|
||||
# Extract the URI based on location type
|
||||
for loc_key, config in location_mapping.items():
|
||||
if loc_key in location:
|
||||
source_uri = location[loc_key].get(config['field'])
|
||||
if not location_type or location_type == 'unknown':
|
||||
location_type = config['type']
|
||||
break
|
||||
|
||||
# Include score if available
|
||||
score = result.get('score')
|
||||
|
||||
# Include metadata if available
|
||||
metadata = result.get('metadata')
|
||||
|
||||
# Create a well-formed JSON object for each result
|
||||
result_object = {
|
||||
'content': content,
|
||||
'content_type': content_type,
|
||||
'source_type': location_type,
|
||||
'source_uri': source_uri
|
||||
}
|
||||
|
||||
# Add score if available
|
||||
if score is not None:
|
||||
result_object['score'] = score
|
||||
|
||||
# Add metadata if available
|
||||
if metadata:
|
||||
result_object['metadata'] = metadata
|
||||
|
||||
# Add the JSON object to results
|
||||
results.append(result_object)
|
||||
processed_result = self._process_retrieval_result(result)
|
||||
results.append(processed_result)
|
||||
|
||||
# Include nextToken in the response if available
|
||||
# Build the response object
|
||||
response_object = {}
|
||||
if results:
|
||||
response_object["results"] = results
|
||||
@@ -185,4 +231,14 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
return json.dumps(response_object, indent=2)
|
||||
|
||||
except ClientError as e:
|
||||
return f"Error retrieving from Bedrock Knowledge Base: {str(e)}"
|
||||
error_code = "Unknown"
|
||||
error_message = str(e)
|
||||
|
||||
# Try to extract error code if available
|
||||
if hasattr(e, 'response') and 'Error' in e.response:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', str(e))
|
||||
|
||||
raise BedrockKnowledgeBaseError(f"Error ({error_code}): {error_message}")
|
||||
except Exception as e:
|
||||
raise BedrockKnowledgeBaseError(f"Unexpected error: {str(e)}")
|
||||
Reference in New Issue
Block a user