Mirror of https://github.com/crewAIInc/crewAI.git
Synced 2026-01-21 13:58:15 +00:00
git-subtree-dir: packages/tools
git-subtree-split: 78317b9c127f18bd040c1d77e3c0840cdc9a5b38
248 lines, 10 KiB, Python
from typing import Type, Optional, List, Dict, Any
|
|
import os
|
|
import json
|
|
from dotenv import load_dotenv
|
|
|
|
from crewai.tools import BaseTool
|
|
from pydantic import BaseModel, Field
|
|
|
|
from ..exceptions import BedrockKnowledgeBaseError, BedrockValidationError
|
|
|
|
# Load environment variables from a .env file at import time, so that values
# this module later reads via os.getenv() (BEDROCK_KB_ID, AWS_REGION,
# AWS_DEFAULT_REGION) can be supplied from that file.
load_dotenv()
|
|
|
|
|
|
class BedrockKBRetrieverToolInput(BaseModel):
    """Input schema for BedrockKBRetrieverTool."""

    # Required free-text question; the tool sends it as the retrieval query text.
    query: str = Field(
        default=...,
        description="The query to retrieve information from the knowledge base",
    )
|
|
|
|
|
|
class BedrockKBRetrieverTool(BaseTool):
    """CrewAI tool that retrieves passages from an Amazon Bedrock Knowledge Base.

    Each invocation calls the ``bedrock-agent-runtime`` ``retrieve`` operation
    against the configured knowledge base and returns the processed results
    as a JSON string.
    """

    name: str = "Bedrock Knowledge Base Retriever Tool"
    description: str = "Retrieves information from an Amazon Bedrock Knowledge Base given a query"
    args_schema: Type[BaseModel] = BedrockKBRetrieverToolInput
    # Resolved in __init__ from the argument or the BEDROCK_KB_ID environment variable.
    knowledge_base_id: Optional[str] = None
    number_of_results: Optional[int] = 5
    retrieval_configuration: Optional[Dict[str, Any]] = None
    guardrail_configuration: Optional[Dict[str, Any]] = None
    # When set, this pagination token is sent with every retrieve request.
    next_token: Optional[str] = None
    package_dependencies: List[str] = ["boto3"]

    def __init__(
        self,
        knowledge_base_id: Optional[str] = None,
        number_of_results: Optional[int] = 5,
        retrieval_configuration: Optional[Dict[str, Any]] = None,
        guardrail_configuration: Optional[Dict[str, Any]] = None,
        next_token: Optional[str] = None,
        **kwargs
    ):
        """Initialize the BedrockKBRetrieverTool with knowledge base configuration.

        Args:
            knowledge_base_id (Optional[str]): The unique identifier of the knowledge base
                to query. Falls back to the BEDROCK_KB_ID environment variable when not given.
            number_of_results (Optional[int], optional): The maximum number of results to
                return. Only used to build the default retrieval configuration. Defaults to 5.
            retrieval_configuration (Optional[Dict[str, Any]], optional): Configurations for
                the knowledge base query and retrieval process. Defaults to None, in which
                case one is built from ``number_of_results``.
            guardrail_configuration (Optional[Dict[str, Any]], optional): Guardrail settings.
                Defaults to None.
            next_token (Optional[str], optional): Token for retrieving the next batch of
                results. Defaults to None.
            **kwargs: Forwarded to ``BaseTool``.

        Raises:
            BedrockValidationError: If any parameter fails validation, e.g. no knowledge
                base ID was given and BEDROCK_KB_ID is unset.
        """
        super().__init__(**kwargs)

        # Get knowledge_base_id from environment variable if not provided
        self.knowledge_base_id = knowledge_base_id or os.getenv('BEDROCK_KB_ID')
        self.number_of_results = number_of_results
        self.guardrail_configuration = guardrail_configuration
        self.next_token = next_token

        # An explicit retrieval_configuration wins; otherwise derive one from
        # number_of_results (which must already be assigned above).
        if retrieval_configuration is None:
            self.retrieval_configuration = self._build_retrieval_configuration()
        else:
            self.retrieval_configuration = retrieval_configuration

        self._validate_parameters()

        # Overrides any description passed through kwargs so the agent can see
        # which knowledge base this tool instance targets.
        self.description = f"Retrieves information from Amazon Bedrock Knowledge Base '{self.knowledge_base_id}' given a query"

    def _build_retrieval_configuration(self) -> Dict[str, Any]:
        """Build the default retrieval configuration from ``number_of_results``.

        Returns:
            Dict[str, Any]: ``{"vectorSearchConfiguration": {...}}``. The inner dict holds
            ``numberOfResults`` only when ``number_of_results`` is not None.
        """
        vector_search_config: Dict[str, Any] = {}
        if self.number_of_results is not None:
            vector_search_config["numberOfResults"] = self.number_of_results
        return {"vectorSearchConfiguration": vector_search_config}

    def _validate_parameters(self):
        """Validate the parameters according to AWS API requirements.

        Raises:
            BedrockValidationError: With a "Parameter validation failed: " prefix,
                chained to the specific failure.
        """
        try:
            kb_id = self.knowledge_base_id
            if not kb_id:
                raise BedrockValidationError("knowledge_base_id cannot be empty")
            if not isinstance(kb_id, str):
                raise BedrockValidationError("knowledge_base_id must be a string")
            if len(kb_id) > 10:
                raise BedrockValidationError("knowledge_base_id must be 10 characters or less")
            # The API pattern is ASCII-only ([0-9a-zA-Z]+); str.isalnum() on its own
            # would also accept non-ASCII letters/digits such as 'é' or '٣'.
            if not (kb_id.isascii() and kb_id.isalnum()):
                raise BedrockValidationError("knowledge_base_id must contain only alphanumeric characters")

            token = self.next_token
            if token:
                if not isinstance(token, str):
                    raise BedrockValidationError("next_token must be a string")
                if not 1 <= len(token) <= 2048:
                    raise BedrockValidationError("next_token must be between 1 and 2048 characters")
                # The API pattern is ^\S*$, so any whitespace (tabs, newlines, ...) is invalid.
                if any(ch.isspace() for ch in token):
                    raise BedrockValidationError("next_token cannot contain spaces or other whitespace")

            count = self.number_of_results
            if count is not None:
                # bool is a subclass of int, so reject it explicitly.
                if isinstance(count, bool) or not isinstance(count, int):
                    raise BedrockValidationError("number_of_results must be an integer")
                if count < 1:
                    raise BedrockValidationError("number_of_results must be greater than 0")

        except BedrockValidationError as e:
            raise BedrockValidationError(f"Parameter validation failed: {str(e)}") from e

    def _process_retrieval_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single retrieval result from Bedrock Knowledge Base.

        Args:
            result (Dict[str, Any]): Raw entry from the response's ``retrievalResults``.

        Returns:
            Dict[str, Any]: Dict with ``content``, ``content_type``, ``source_type`` and
            ``source_uri``. It also carries ``score``, ``metadata``, ``byte_content`` and
            ``row_content`` when present in the raw result.
        """
        content_obj = result.get('content', {})
        location = result.get('location', {})
        location_type = location.get('type', 'unknown')
        source_uri = None

        # Each location variant stores its identifier under a different field.
        # The first variant present in the result is used.
        location_mapping = {
            's3Location': {'field': 'uri', 'type': 'S3'},
            'confluenceLocation': {'field': 'url', 'type': 'Confluence'},
            'salesforceLocation': {'field': 'url', 'type': 'Salesforce'},
            'sharePointLocation': {'field': 'url', 'type': 'SharePoint'},
            'webLocation': {'field': 'url', 'type': 'Web'},
            'customDocumentLocation': {'field': 'id', 'type': 'CustomDocument'},
            'kendraDocumentLocation': {'field': 'uri', 'type': 'KendraDocument'},
            'sqlLocation': {'field': 'query', 'type': 'SQL'},
        }

        for loc_key, config in location_mapping.items():
            if loc_key in location:
                source_uri = location[loc_key].get(config['field'])
                # Only substitute the mapped label when the API gave no usable type.
                if not location_type or location_type == 'unknown':
                    location_type = config['type']
                break

        result_object = {
            'content': content_obj.get('text', ''),
            'content_type': content_obj.get('type', 'text'),
            'source_type': location_type,
            'source_uri': source_uri,
        }

        # Optional fields are copied only when the API included them.
        if 'score' in result:
            result_object['score'] = result['score']
        if 'metadata' in result:
            result_object['metadata'] = result['metadata']
        if 'byteContent' in content_obj:
            result_object['byte_content'] = content_obj['byteContent']
        if 'row' in content_obj:
            result_object['row_content'] = content_obj['row']

        return result_object

    def _build_retrieve_params(self, query: str) -> Dict[str, Any]:
        """Assemble the keyword arguments for the ``retrieve`` call; falsy options are omitted."""
        params: Dict[str, Any] = {
            'knowledgeBaseId': self.knowledge_base_id,
            'retrievalQuery': {'text': query},
        }
        if self.retrieval_configuration:
            params['retrievalConfiguration'] = self.retrieval_configuration
        if self.guardrail_configuration:
            params['guardrailConfiguration'] = self.guardrail_configuration
        if self.next_token:
            params['nextToken'] = self.next_token
        return params

    def _build_response_object(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a raw ``retrieve`` response into the dict that ``_run`` serializes."""
        results = [
            self._process_retrieval_result(item)
            for item in response.get('retrievalResults', [])
        ]

        response_object: Dict[str, Any] = {}
        if results:
            response_object["results"] = results
        else:
            response_object["message"] = "No results found for the given query."

        # Pass pagination and guardrail information through unchanged.
        for key in ("nextToken", "guardrailAction"):
            if key in response:
                response_object[key] = response[key]

        return response_object

    def _run(self, query: str) -> str:
        """Query the knowledge base and return the processed results as a JSON string.

        Args:
            query (str): Text sent as the retrieval query.

        Returns:
            str: Indented JSON object. It holds ``results`` (or a ``message`` when
            nothing matched), plus ``nextToken`` / ``guardrailAction`` when the API
            returned them.

        Raises:
            ImportError: If boto3/botocore are not installed.
            BedrockKnowledgeBaseError: For AWS client errors and any other failure
                during the request or result processing.
        """
        try:
            import boto3
            from botocore.exceptions import ClientError
        except ImportError as e:
            raise ImportError("`boto3` package not found, please run `uv add boto3`") from e

        try:
            # Credentials are resolved by boto3's default chain (e.g. AWS_ACCESS_KEY_ID /
            # AWS_SECRET_ACCESS_KEY); the region comes from the environment.
            bedrock_agent_runtime = boto3.client(
                'bedrock-agent-runtime',
                region_name=os.getenv('AWS_REGION', os.getenv('AWS_DEFAULT_REGION', 'us-east-1')),
            )

            response = bedrock_agent_runtime.retrieve(**self._build_retrieve_params(query))
            return json.dumps(self._build_response_object(response), indent=2)

        except ClientError as e:
            error_code = "Unknown"
            error_message = str(e)

            # Tolerate a missing or malformed 'Error' section rather than masking
            # the original AWS error with an AttributeError.
            error_details = (getattr(e, 'response', None) or {}).get('Error')
            if isinstance(error_details, dict):
                error_code = error_details.get('Code', 'Unknown')
                error_message = error_details.get('Message', str(e))

            raise BedrockKnowledgeBaseError(f"Error ({error_code}): {error_message}") from e
        except Exception as e:
            raise BedrockKnowledgeBaseError(f"Unexpected error: {str(e)}") from e