From db309ca1ae5f26fcffd297b5c2112ee348d136fd Mon Sep 17 00:00:00 2001 From: Raju Rangan Date: Tue, 11 Mar 2025 16:46:12 -0400 Subject: [PATCH] - Add custom exceptions for better error handling - Add parameter validation for Bedrock tools - Improve response processing and debug information - Maintain backward compatibility with existing implementations --- .../aws/bedrock/agents/invoke_agent_tool.py | 42 +++- src/crewai_tools/aws/bedrock/exceptions.py | 17 ++ .../bedrock/knowledge_base/retriever_tool.py | 204 +++++++++++------- 3 files changed, 184 insertions(+), 79 deletions(-) create mode 100644 src/crewai_tools/aws/bedrock/exceptions.py diff --git a/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py b/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py index 41ecad75b..6c43480c0 100644 --- a/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py +++ b/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py @@ -11,6 +11,9 @@ from pydantic import BaseModel, Field import boto3 from botocore.exceptions import ClientError +# Import custom exceptions +from ..exceptions import BedrockAgentError, BedrockValidationError + # Load environment variables from .env file load_dotenv() @@ -62,6 +65,31 @@ class BedrockInvokeAgentTool(BaseTool): # Update the description if provided if description: self.description = description + + # Validate parameters + self._validate_parameters() + + def _validate_parameters(self): + """Validate the parameters according to AWS API requirements.""" + try: + # Validate agent_id + if not self.agent_id: + raise BedrockValidationError("agent_id cannot be empty") + if not isinstance(self.agent_id, str): + raise BedrockValidationError("agent_id must be a string") + + # Validate agent_alias_id + if not self.agent_alias_id: + raise BedrockValidationError("agent_alias_id cannot be empty") + if not isinstance(self.agent_alias_id, str): + raise BedrockValidationError("agent_alias_id must be a string") + + # Validate session_id if provided + if self.session_id and not isinstance(self.session_id, str): + raise BedrockValidationError("session_id must be a string") + + except BedrockValidationError as e: + raise BedrockValidationError(f"Parameter validation failed: {str(e)}") def _run(self, query: str) -> str: try: @@ -123,7 +151,7 @@ Below is the users query or task. Complete it and answer it consicely and to the if 'chunk' in response: debug_info["chunk_keys"] = list(response['chunk'].keys()) - return json.dumps(debug_info, indent=2) + raise BedrockAgentError(f"Failed to extract completion: {json.dumps(debug_info, indent=2)}") return completion @@ -132,9 +160,13 @@ Below is the users query or task. Complete it and answer it consicely and to the error_message = str(e) # Try to extract error code if available - if hasattr(e, 'response') and 'Error' in e.response and 'Code' in e.response['Error']: - error_code = e.response['Error']['Code'] + if hasattr(e, 'response') and 'Error' in e.response: + error_code = e.response['Error'].get('Code', 'Unknown') + error_message = e.response['Error'].get('Message', str(e)) - return f"Error invoking Bedrock Agent ({error_code}): {error_message}" + raise BedrockAgentError(f"Error ({error_code}): {error_message}") + except BedrockAgentError: + # Re-raise BedrockAgentError exceptions + raise except Exception as e: - return f"Error: {str(e)}" \ No newline at end of file + raise BedrockAgentError(f"Unexpected error: {str(e)}") \ No newline at end of file diff --git a/src/crewai_tools/aws/bedrock/exceptions.py b/src/crewai_tools/aws/bedrock/exceptions.py new file mode 100644 index 000000000..d1aa2623c --- /dev/null +++ b/src/crewai_tools/aws/bedrock/exceptions.py @@ -0,0 +1,17 @@ +"""Custom exceptions for AWS Bedrock integration.""" + +class BedrockError(Exception): + """Base exception for Bedrock-related errors.""" + pass + +class BedrockAgentError(BedrockError): + """Exception raised for errors in the Bedrock Agent operations.""" + pass + +class BedrockKnowledgeBaseError(BedrockError): + """Exception raised for errors in the Bedrock Knowledge Base operations.""" + pass + +class BedrockValidationError(BedrockError): + """Exception raised for validation errors in Bedrock operations.""" + pass \ No newline at end of file diff --git a/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py b/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py index c01e83cff..55a15b621 100644 --- a/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py +++ b/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py @@ -8,6 +8,9 @@ from pydantic import BaseModel, Field import boto3 from botocore.exceptions import ClientError +# Import custom exceptions +from ..exceptions import BedrockKnowledgeBaseError, BedrockValidationError + # Load environment variables from .env file load_dotenv() @@ -39,7 +42,7 @@ class BedrockKBRetrieverTool(BaseTool): """Initialize the BedrockKBRetrieverTool with knowledge base configuration. Args: - knowledge_base_id (str): The unique identifier of the knowledge base to query (length: 0-10, pattern: ^[0-9a-zA-Z]+$) + knowledge_base_id (str): The unique identifier of the knowledge base to query number_of_results (Optional[int], optional): The maximum number of results to return. Defaults to 5. retrieval_configuration (Optional[Dict[str, Any]], optional): Configurations for the knowledge base query and retrieval process. Defaults to None. guardrail_configuration (Optional[Dict[str, Any]], optional): Guardrail settings. Defaults to None. @@ -50,19 +53,14 @@ class BedrockKBRetrieverTool(BaseTool): # Get knowledge_base_id from environment variable if not provided self.knowledge_base_id = knowledge_base_id or os.getenv('BEDROCK_KB_ID') self.number_of_results = number_of_results - - # Initialize retrieval_configuration with number_of_results if provided - if retrieval_configuration is None and number_of_results is not None: - self.retrieval_configuration = { - "vectorSearchConfiguration": { - "numberOfResults": number_of_results - } - } - else: - self.retrieval_configuration = retrieval_configuration - self.guardrail_configuration = guardrail_configuration self.next_token = next_token + + # Initialize retrieval_configuration with provided parameters or use the one provided + if retrieval_configuration is None: + self.retrieval_configuration = self._build_retrieval_configuration() + else: + self.retrieval_configuration = retrieval_configuration # Validate parameters self._validate_parameters() @@ -70,15 +68,115 @@ class BedrockKBRetrieverTool(BaseTool): # Update the description to include the knowledge base details self.description = f"Retrieves information from Amazon Bedrock Knowledge Base '{self.knowledge_base_id}' given a query" + def _build_retrieval_configuration(self) -> Dict[str, Any]: + """Build the retrieval configuration based on provided parameters. + + Returns: + Dict[str, Any]: The constructed retrieval configuration + """ + vector_search_config = {} + + # Add number of results if provided + if self.number_of_results is not None: + vector_search_config["numberOfResults"] = self.number_of_results + + return {"vectorSearchConfiguration": vector_search_config} + def _validate_parameters(self): """Validate the parameters according to AWS API requirements.""" - # Validate knowledge_base_id - if not self.knowledge_base_id or len(self.knowledge_base_id) > 10 or not all(c.isalnum() for c in self.knowledge_base_id): - raise ValueError("knowledge_base_id must be 0-10 alphanumeric characters") + try: + # Validate knowledge_base_id + if not self.knowledge_base_id: + raise BedrockValidationError("knowledge_base_id cannot be empty") + if not isinstance(self.knowledge_base_id, str): + raise BedrockValidationError("knowledge_base_id must be a string") + if len(self.knowledge_base_id) > 10: + raise BedrockValidationError("knowledge_base_id must be 10 characters or less") + if not all(c.isalnum() for c in self.knowledge_base_id): + raise BedrockValidationError("knowledge_base_id must contain only alphanumeric characters") + + # Validate next_token if provided + if self.next_token: + if not isinstance(self.next_token, str): + raise BedrockValidationError("next_token must be a string") + if len(self.next_token) < 1 or len(self.next_token) > 2048: + raise BedrockValidationError("next_token must be between 1 and 2048 characters") + if ' ' in self.next_token: + raise BedrockValidationError("next_token cannot contain spaces") + + # Validate number_of_results if provided + if self.number_of_results is not None: + if not isinstance(self.number_of_results, int): + raise BedrockValidationError("number_of_results must be an integer") + if self.number_of_results < 1: + raise BedrockValidationError("number_of_results must be greater than 0") + + except BedrockValidationError as e: + raise BedrockValidationError(f"Parameter validation failed: {str(e)}") + + def _process_retrieval_result(self, result: Dict[str, Any]) -> Dict[str, Any]: + """Process a single retrieval result from Bedrock Knowledge Base. - # Validate next_token if provided - if self.next_token and (len(self.next_token) < 1 or len(self.next_token) > 2048 or ' ' in self.next_token): - raise ValueError("next_token must be 1-2048 characters and match pattern ^\\S*$") + Args: + result (Dict[str, Any]): Raw result from Bedrock Knowledge Base + + Returns: + Dict[str, Any]: Processed result with standardized format + """ + # Extract content + content_obj = result.get('content', {}) + content = content_obj.get('text', '') + content_type = content_obj.get('type', 'text') + + # Extract location information + location = result.get('location', {}) + location_type = location.get('type', 'unknown') + source_uri = None + + # Map for location types and their URI fields + location_mapping = { + 's3Location': {'field': 'uri', 'type': 'S3'}, + 'confluenceLocation': {'field': 'url', 'type': 'Confluence'}, + 'salesforceLocation': {'field': 'url', 'type': 'Salesforce'}, + 'sharePointLocation': {'field': 'url', 'type': 'SharePoint'}, + 'webLocation': {'field': 'url', 'type': 'Web'}, + 'customDocumentLocation': {'field': 'id', 'type': 'CustomDocument'}, + 'kendraDocumentLocation': {'field': 'uri', 'type': 'KendraDocument'}, + 'sqlLocation': {'field': 'query', 'type': 'SQL'} + } + + # Extract the URI based on location type + for loc_key, config in location_mapping.items(): + if loc_key in location: + source_uri = location[loc_key].get(config['field']) + if not location_type or location_type == 'unknown': + location_type = config['type'] + break + + # Create result object + result_object = { + 'content': content, + 'content_type': content_type, + 'source_type': location_type, + 'source_uri': source_uri + } + + # Add optional fields if available + if 'score' in result: + result_object['score'] = result['score'] + + if 'metadata' in result: + result_object['metadata'] = result['metadata'] + + # Handle byte content if present + if 'byteContent' in content_obj: + result_object['byte_content'] = content_obj['byteContent'] + + # Handle row content if present + if 'row' in content_obj: + result_object['row_content'] = content_obj['row'] + + return result_object def _run(self, query: str) -> str: try: @@ -113,62 +211,10 @@ class BedrockKBRetrieverTool(BaseTool): # Process the response results = [] for result in response.get('retrievalResults', []): - # Extract content - content_obj = result.get('content', {}) - content = content_obj.get('text', '') - content_type = content_obj.get('type', 'text') - - # Extract location information - location = result.get('location', {}) - location_type = location.get('type', 'unknown') - source_uri = None - - # Map for location types and their URI fields - location_mapping = { - 's3Location': {'field': 'uri', 'type': 'S3'}, - 'confluenceLocation': {'field': 'url', 'type': 'Confluence'}, - 'salesforceLocation': {'field': 'url', 'type': 'Salesforce'}, - 'sharePointLocation': {'field': 'url', 'type': 'SharePoint'}, - 'webLocation': {'field': 'url', 'type': 'Web'}, - 'customDocumentLocation': {'field': 'id', 'type': 'CustomDocument'}, - 'kendraDocumentLocation': {'field': 'uri', 'type': 'KendraDocument'}, - 'sqlLocation': {'field': 'query', 'type': 'SQL'} - } - - # Extract the URI based on location type - for loc_key, config in location_mapping.items(): - if loc_key in location: - source_uri = location[loc_key].get(config['field']) - if not location_type or location_type == 'unknown': - location_type = config['type'] - break - - # Include score if available - score = result.get('score') - - # Include metadata if available - metadata = result.get('metadata') - - # Create a well-formed JSON object for each result - result_object = { - 'content': content, - 'content_type': content_type, - 'source_type': location_type, - 'source_uri': source_uri - } - - # Add score if available - if score is not None: - result_object['score'] = score - - # Add metadata if available - if metadata: - result_object['metadata'] = metadata - - # Add the JSON object to results - results.append(result_object) + processed_result = self._process_retrieval_result(result) + results.append(processed_result) - # Include nextToken in the response if available + # Build the response object response_object = {} if results: response_object["results"] = results @@ -185,4 +231,14 @@ class BedrockKBRetrieverTool(BaseTool): return json.dumps(response_object, indent=2) except ClientError as e: - return f"Error retrieving from Bedrock Knowledge Base: {str(e)}" \ No newline at end of file + error_code = "Unknown" + error_message = str(e) + + # Try to extract error code if available + if hasattr(e, 'response') and 'Error' in e.response: + error_code = e.response['Error'].get('Code', 'Unknown') + error_message = e.response['Error'].get('Message', str(e)) + + raise BedrockKnowledgeBaseError(f"Error ({error_code}): {error_message}") + except Exception as e: + raise BedrockKnowledgeBaseError(f"Unexpected error: {str(e)}") \ No newline at end of file