mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Merge pull request #239 from raju-rangan/main
Amazon Bedrock Knowledge Bases Retriever and Agents support
This commit is contained in:
@@ -1,3 +1,9 @@
|
||||
from .s3 import S3ReaderTool, S3WriterTool
|
||||
from .bedrock import BedrockKBRetrieverTool, BedrockInvokeAgentTool
|
||||
|
||||
__all__ = ['S3ReaderTool', 'S3WriterTool']
|
||||
__all__ = [
|
||||
'S3ReaderTool',
|
||||
'S3WriterTool',
|
||||
'BedrockKBRetrieverTool',
|
||||
'BedrockInvokeAgentTool'
|
||||
]
|
||||
4
src/crewai_tools/aws/bedrock/__init__.py
Normal file
4
src/crewai_tools/aws/bedrock/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .knowledge_base.retriever_tool import BedrockKBRetrieverTool
|
||||
from .agents.invoke_agent_tool import BedrockInvokeAgentTool
|
||||
|
||||
__all__ = ["BedrockKBRetrieverTool", "BedrockInvokeAgentTool"]
|
||||
181
src/crewai_tools/aws/bedrock/agents/README.md
Normal file
181
src/crewai_tools/aws/bedrock/agents/README.md
Normal file
@@ -0,0 +1,181 @@
|
||||
# BedrockInvokeAgentTool
|
||||
|
||||
The `BedrockInvokeAgentTool` enables CrewAI agents to invoke Amazon Bedrock Agents and leverage their capabilities within your workflows.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install 'crewai[tools]'
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
- AWS credentials configured (either through environment variables or AWS CLI)
|
||||
- `boto3` and `python-dotenv` packages
|
||||
- Access to Amazon Bedrock Agents
|
||||
|
||||
## Usage
|
||||
|
||||
Here's how to use the tool with a CrewAI agent:
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew
|
||||
from crewai_tools.aws.bedrock.agents.invoke_agent_tool import BedrockInvokeAgentTool
|
||||
|
||||
# Initialize the tool
|
||||
agent_tool = BedrockInvokeAgentTool(
|
||||
agent_id="your-agent-id",
|
||||
agent_alias_id="your-agent-alias-id"
|
||||
)
|
||||
|
||||
# Create a CrewAI agent that uses the tool
|
||||
aws_expert = Agent(
|
||||
role='AWS Service Expert',
|
||||
goal='Help users understand AWS services and quotas',
|
||||
backstory='I am an expert in AWS services and can provide detailed information about them.',
|
||||
tools=[agent_tool],
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Create a task for the agent
|
||||
quota_task = Task(
|
||||
description="Find out the current service quotas for EC2 in us-west-2 and explain any recent changes.",
|
||||
agent=aws_expert
|
||||
)
|
||||
|
||||
# Create a crew with the agent
|
||||
crew = Crew(
|
||||
agents=[aws_expert],
|
||||
tasks=[quota_task],
|
||||
verbose=2
|
||||
)
|
||||
|
||||
# Run the crew
|
||||
result = crew.kickoff()
|
||||
print(result)
|
||||
```
|
||||
|
||||
## Tool Arguments
|
||||
|
||||
| Argument | Type | Required | Default | Description |
|
||||
|----------|------|----------|---------|-------------|
|
||||
| agent_id | str | Yes | None | The unique identifier of the Bedrock agent |
|
||||
| agent_alias_id | str | Yes | None | The unique identifier of the agent alias |
|
||||
| session_id | str | No | timestamp | The unique identifier of the session |
|
||||
| enable_trace | bool | No | False | Whether to enable trace for debugging |
|
||||
| end_session | bool | No | False | Whether to end the session after invocation |
|
||||
| description | str | No | None | Custom description for the tool |
|
||||
|
||||
## Environment Variables
|
||||
|
||||
```bash
|
||||
BEDROCK_AGENT_ID=your-agent-id # Alternative to passing agent_id
|
||||
BEDROCK_AGENT_ALIAS_ID=your-agent-alias-id # Alternative to passing agent_alias_id
|
||||
AWS_REGION=your-aws-region # Defaults to us-west-2
|
||||
AWS_ACCESS_KEY_ID=your-access-key # Required for AWS authentication
|
||||
AWS_SECRET_ACCESS_KEY=your-secret-key # Required for AWS authentication
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Multi-Agent Workflow with Session Management
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew, Process
|
||||
from crewai_tools.aws.bedrock.agents.invoke_agent_tool import BedrockInvokeAgentTool
|
||||
|
||||
# Initialize tools with session management
|
||||
initial_tool = BedrockInvokeAgentTool(
|
||||
agent_id="your-agent-id",
|
||||
agent_alias_id="your-agent-alias-id",
|
||||
session_id="custom-session-id"
|
||||
)
|
||||
|
||||
followup_tool = BedrockInvokeAgentTool(
|
||||
agent_id="your-agent-id",
|
||||
agent_alias_id="your-agent-alias-id",
|
||||
session_id="custom-session-id"
|
||||
)
|
||||
|
||||
final_tool = BedrockInvokeAgentTool(
|
||||
agent_id="your-agent-id",
|
||||
agent_alias_id="your-agent-alias-id",
|
||||
session_id="custom-session-id",
|
||||
end_session=True
|
||||
)
|
||||
|
||||
# Create agents for different stages
|
||||
researcher = Agent(
|
||||
role='AWS Service Researcher',
|
||||
goal='Gather information about AWS services',
|
||||
backstory='I am specialized in finding detailed AWS service information.',
|
||||
tools=[initial_tool]
|
||||
)
|
||||
|
||||
analyst = Agent(
|
||||
role='Service Compatibility Analyst',
|
||||
goal='Analyze service compatibility and requirements',
|
||||
backstory='I analyze AWS services for compatibility and integration possibilities.',
|
||||
tools=[followup_tool]
|
||||
)
|
||||
|
||||
summarizer = Agent(
|
||||
role='Technical Documentation Writer',
|
||||
goal='Create clear technical summaries',
|
||||
backstory='I specialize in creating clear, concise technical documentation.',
|
||||
tools=[final_tool]
|
||||
)
|
||||
|
||||
# Create tasks
|
||||
research_task = Task(
|
||||
description="Find all available AWS services in us-west-2 region.",
|
||||
agent=researcher
|
||||
)
|
||||
|
||||
analysis_task = Task(
|
||||
description="Analyze which services support IPv6 and their implementation requirements.",
|
||||
agent=analyst
|
||||
)
|
||||
|
||||
summary_task = Task(
|
||||
description="Create a summary of IPv6-compatible services and their key features.",
|
||||
agent=summarizer
|
||||
)
|
||||
|
||||
# Create a crew with the agents and tasks
|
||||
crew = Crew(
|
||||
agents=[researcher, analyst, summarizer],
|
||||
tasks=[research_task, analysis_task, summary_task],
|
||||
process=Process.sequential,
|
||||
verbose=2
|
||||
)
|
||||
|
||||
# Run the crew
|
||||
result = crew.kickoff()
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### Hybrid Multi-Agent Collaborations
|
||||
- Create workflows where CrewAI agents collaborate with managed Bedrock agents running as services in AWS
|
||||
- Enable scenarios where sensitive data processing happens within your AWS environment while other agents operate externally
|
||||
- Bridge on-premises CrewAI agents with cloud-based Bedrock agents for distributed intelligence workflows
|
||||
|
||||
### Data Sovereignty and Compliance
|
||||
- Keep data-sensitive agentic workflows within your AWS environment while allowing external CrewAI agents to orchestrate tasks
|
||||
- Maintain compliance with data residency requirements by processing sensitive information only within your AWS account
|
||||
- Enable secure multi-agent collaborations where some agents cannot access your organization's private data
|
||||
|
||||
### Seamless AWS Service Integration
|
||||
- Access any AWS service through Amazon Bedrock Actions without writing complex integration code
|
||||
- Enable CrewAI agents to interact with AWS services through natural language requests
|
||||
- Leverage pre-built Bedrock agent capabilities to interact with AWS services like Bedrock Knowledge Bases, Lambda, and more
|
||||
|
||||
### Scalable Hybrid Agent Architectures
|
||||
- Offload computationally intensive tasks to managed Bedrock agents while lightweight tasks run in CrewAI
|
||||
- Scale agent processing by distributing workloads between local CrewAI agents and cloud-based Bedrock agents
|
||||
|
||||
### Cross-Organizational Agent Collaboration
|
||||
- Enable secure collaboration between your organization's CrewAI agents and partner organizations' Bedrock agents
|
||||
- Create workflows where external expertise from Bedrock agents can be incorporated without exposing sensitive data
|
||||
- Build agent ecosystems that span organizational boundaries while maintaining security and data control
|
||||
3
src/crewai_tools/aws/bedrock/agents/__init__.py
Normal file
3
src/crewai_tools/aws/bedrock/agents/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .invoke_agent_tool import BedrockInvokeAgentTool
|
||||
|
||||
__all__ = ["BedrockInvokeAgentTool"]
|
||||
172
src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py
Normal file
172
src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from typing import Type, Optional, Dict, Any
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
# Import custom exceptions
|
||||
from ..exceptions import BedrockAgentError, BedrockValidationError
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class BedrockInvokeAgentToolInput(BaseModel):
|
||||
"""Input schema for BedrockInvokeAgentTool."""
|
||||
query: str = Field(..., description="The query to send to the agent")
|
||||
|
||||
|
||||
class BedrockInvokeAgentTool(BaseTool):
|
||||
name: str = "Bedrock Agent Invoke Tool"
|
||||
description: str = "An agent responsible for policy analysis."
|
||||
args_schema: Type[BaseModel] = BedrockInvokeAgentToolInput
|
||||
agent_id: str = None
|
||||
agent_alias_id: str = None
|
||||
session_id: str = None
|
||||
enable_trace: bool = False
|
||||
end_session: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str = None,
|
||||
agent_alias_id: str = None,
|
||||
session_id: str = None,
|
||||
enable_trace: bool = False,
|
||||
end_session: bool = False,
|
||||
description: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Initialize the BedrockInvokeAgentTool with agent configuration.
|
||||
|
||||
Args:
|
||||
agent_id (str): The unique identifier of the Bedrock agent
|
||||
agent_alias_id (str): The unique identifier of the agent alias
|
||||
session_id (str): The unique identifier of the session
|
||||
enable_trace (bool): Whether to enable trace for the agent invocation
|
||||
end_session (bool): Whether to end the session with the agent
|
||||
description (Optional[str]): Custom description for the tool
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Get values from environment variables if not provided
|
||||
self.agent_id = agent_id or os.getenv('BEDROCK_AGENT_ID')
|
||||
self.agent_alias_id = agent_alias_id or os.getenv('BEDROCK_AGENT_ALIAS_ID')
|
||||
self.session_id = session_id or str(int(time.time())) # Use timestamp as session ID if not provided
|
||||
self.enable_trace = enable_trace
|
||||
self.end_session = end_session
|
||||
|
||||
# Update the description if provided
|
||||
if description:
|
||||
self.description = description
|
||||
|
||||
# Validate parameters
|
||||
self._validate_parameters()
|
||||
|
||||
def _validate_parameters(self):
|
||||
"""Validate the parameters according to AWS API requirements."""
|
||||
try:
|
||||
# Validate agent_id
|
||||
if not self.agent_id:
|
||||
raise BedrockValidationError("agent_id cannot be empty")
|
||||
if not isinstance(self.agent_id, str):
|
||||
raise BedrockValidationError("agent_id must be a string")
|
||||
|
||||
# Validate agent_alias_id
|
||||
if not self.agent_alias_id:
|
||||
raise BedrockValidationError("agent_alias_id cannot be empty")
|
||||
if not isinstance(self.agent_alias_id, str):
|
||||
raise BedrockValidationError("agent_alias_id must be a string")
|
||||
|
||||
# Validate session_id if provided
|
||||
if self.session_id and not isinstance(self.session_id, str):
|
||||
raise BedrockValidationError("session_id must be a string")
|
||||
|
||||
except BedrockValidationError as e:
|
||||
raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
try:
|
||||
# Initialize the Bedrock Agent Runtime client
|
||||
bedrock_agent = boto3.client(
|
||||
"bedrock-agent-runtime",
|
||||
region_name=os.getenv('AWS_REGION', os.getenv('AWS_DEFAULT_REGION', 'us-west-2'))
|
||||
)
|
||||
|
||||
# Format the prompt with current time
|
||||
current_utc = datetime.now(timezone.utc)
|
||||
prompt = f"""
|
||||
The current time is: {current_utc}
|
||||
|
||||
Below is the users query or task. Complete it and answer it consicely and to the point:
|
||||
{query}
|
||||
"""
|
||||
|
||||
# Invoke the agent
|
||||
response = bedrock_agent.invoke_agent(
|
||||
agentId=self.agent_id,
|
||||
agentAliasId=self.agent_alias_id,
|
||||
sessionId=self.session_id,
|
||||
inputText=prompt,
|
||||
enableTrace=self.enable_trace,
|
||||
endSession=self.end_session
|
||||
)
|
||||
|
||||
# Process the response
|
||||
completion = ""
|
||||
|
||||
# Check if response contains a completion field
|
||||
if 'completion' in response:
|
||||
# Process streaming response format
|
||||
for event in response.get('completion', []):
|
||||
if 'chunk' in event and 'bytes' in event['chunk']:
|
||||
chunk_bytes = event['chunk']['bytes']
|
||||
if isinstance(chunk_bytes, (bytes, bytearray)):
|
||||
completion += chunk_bytes.decode('utf-8')
|
||||
else:
|
||||
completion += str(chunk_bytes)
|
||||
|
||||
# If no completion found in streaming format, try direct format
|
||||
if not completion and 'chunk' in response and 'bytes' in response['chunk']:
|
||||
chunk_bytes = response['chunk']['bytes']
|
||||
if isinstance(chunk_bytes, (bytes, bytearray)):
|
||||
completion = chunk_bytes.decode('utf-8')
|
||||
else:
|
||||
completion = str(chunk_bytes)
|
||||
|
||||
# If still no completion, return debug info
|
||||
if not completion:
|
||||
debug_info = {
|
||||
"error": "Could not extract completion from response",
|
||||
"response_keys": list(response.keys())
|
||||
}
|
||||
|
||||
# Add more debug info
|
||||
if 'chunk' in response:
|
||||
debug_info["chunk_keys"] = list(response['chunk'].keys())
|
||||
|
||||
raise BedrockAgentError(f"Failed to extract completion: {json.dumps(debug_info, indent=2)}")
|
||||
|
||||
return completion
|
||||
|
||||
except ClientError as e:
|
||||
error_code = "Unknown"
|
||||
error_message = str(e)
|
||||
|
||||
# Try to extract error code if available
|
||||
if hasattr(e, 'response') and 'Error' in e.response:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', str(e))
|
||||
|
||||
raise BedrockAgentError(f"Error ({error_code}): {error_message}")
|
||||
except BedrockAgentError:
|
||||
# Re-raise BedrockAgentError exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
raise BedrockAgentError(f"Unexpected error: {str(e)}")
|
||||
17
src/crewai_tools/aws/bedrock/exceptions.py
Normal file
17
src/crewai_tools/aws/bedrock/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Custom exceptions for AWS Bedrock integration."""
|
||||
|
||||
class BedrockError(Exception):
|
||||
"""Base exception for Bedrock-related errors."""
|
||||
pass
|
||||
|
||||
class BedrockAgentError(BedrockError):
|
||||
"""Exception raised for errors in the Bedrock Agent operations."""
|
||||
pass
|
||||
|
||||
class BedrockKnowledgeBaseError(BedrockError):
|
||||
"""Exception raised for errors in the Bedrock Knowledge Base operations."""
|
||||
pass
|
||||
|
||||
class BedrockValidationError(BedrockError):
|
||||
"""Exception raised for validation errors in Bedrock operations."""
|
||||
pass
|
||||
159
src/crewai_tools/aws/bedrock/knowledge_base/README.md
Normal file
159
src/crewai_tools/aws/bedrock/knowledge_base/README.md
Normal file
@@ -0,0 +1,159 @@
|
||||
# BedrockKBRetrieverTool
|
||||
|
||||
The `BedrockKBRetrieverTool` enables CrewAI agents to retrieve information from Amazon Bedrock Knowledge Bases using natural language queries.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install 'crewai[tools]'
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
- AWS credentials configured (either through environment variables or AWS CLI)
|
||||
- `boto3` and `python-dotenv` packages
|
||||
- Access to Amazon Bedrock Knowledge Base
|
||||
|
||||
## Usage
|
||||
|
||||
Here's how to use the tool with a CrewAI agent:
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew
|
||||
from crewai_tools.aws.bedrock.knowledge_base.retriever_tool import BedrockKBRetrieverTool
|
||||
|
||||
# Initialize the tool
|
||||
kb_tool = BedrockKBRetrieverTool(
|
||||
knowledge_base_id="your-kb-id",
|
||||
number_of_results=5
|
||||
)
|
||||
|
||||
# Create a CrewAI agent that uses the tool
|
||||
researcher = Agent(
|
||||
role='Knowledge Base Researcher',
|
||||
goal='Find information about company policies',
|
||||
backstory='I am a researcher specialized in retrieving and analyzing company documentation.',
|
||||
tools=[kb_tool],
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Create a task for the agent
|
||||
research_task = Task(
|
||||
description="Find our company's remote work policy and summarize the key points.",
|
||||
agent=researcher
|
||||
)
|
||||
|
||||
# Create a crew with the agent
|
||||
crew = Crew(
|
||||
agents=[researcher],
|
||||
tasks=[research_task],
|
||||
verbose=2
|
||||
)
|
||||
|
||||
# Run the crew
|
||||
result = crew.kickoff()
|
||||
print(result)
|
||||
```
|
||||
|
||||
## Tool Arguments
|
||||
|
||||
| Argument | Type | Required | Default | Description |
|
||||
|----------|------|----------|---------|-------------|
|
||||
| knowledge_base_id | str | Yes | None | The unique identifier of the knowledge base (0-10 alphanumeric characters) |
|
||||
| number_of_results | int | No | 5 | Maximum number of results to return |
|
||||
| retrieval_configuration | dict | No | None | Custom configurations for the knowledge base query |
|
||||
| guardrail_configuration | dict | No | None | Content filtering settings |
|
||||
| next_token | str | No | None | Token for pagination |
|
||||
|
||||
## Environment Variables
|
||||
|
||||
```bash
|
||||
BEDROCK_KB_ID=your-knowledge-base-id # Alternative to passing knowledge_base_id
|
||||
AWS_REGION=your-aws-region # Defaults to us-east-1
|
||||
AWS_ACCESS_KEY_ID=your-access-key # Required for AWS authentication
|
||||
AWS_SECRET_ACCESS_KEY=your-secret-key # Required for AWS authentication
|
||||
```
|
||||
|
||||
## Response Format
|
||||
|
||||
The tool returns results in JSON format:
|
||||
|
||||
```json
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"content": "Retrieved text content",
|
||||
"content_type": "text",
|
||||
"source_type": "S3",
|
||||
"source_uri": "s3://bucket/document.pdf",
|
||||
"score": 0.95,
|
||||
"metadata": {
|
||||
"additional": "metadata"
|
||||
}
|
||||
}
|
||||
],
|
||||
"nextToken": "pagination-token",
|
||||
"guardrailAction": "NONE"
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Custom Retrieval Configuration
|
||||
|
||||
```python
|
||||
kb_tool = BedrockKBRetrieverTool(
|
||||
knowledge_base_id="your-kb-id",
|
||||
retrieval_configuration={
|
||||
"vectorSearchConfiguration": {
|
||||
"numberOfResults": 10,
|
||||
"overrideSearchType": "HYBRID"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
policy_expert = Agent(
|
||||
role='Policy Expert',
|
||||
goal='Analyze company policies in detail',
|
||||
backstory='I am an expert in corporate policy analysis with deep knowledge of regulatory requirements.',
|
||||
tools=[kb_tool]
|
||||
)
|
||||
```
|
||||
|
||||
## Supported Data Sources
|
||||
|
||||
- Amazon S3
|
||||
- Confluence
|
||||
- Salesforce
|
||||
- SharePoint
|
||||
- Web pages
|
||||
- Custom document locations
|
||||
- Amazon Kendra
|
||||
- SQL databases
|
||||
|
||||
## Use Cases
|
||||
|
||||
### Enterprise Knowledge Integration
|
||||
- Enable CrewAI agents to access your organization's proprietary knowledge without exposing sensitive data
|
||||
- Allow agents to make decisions based on your company's specific policies, procedures, and documentation
|
||||
- Create agents that can answer questions based on your internal documentation while maintaining data security
|
||||
|
||||
### Specialized Domain Knowledge
|
||||
- Connect CrewAI agents to domain-specific knowledge bases (legal, medical, technical) without retraining models
|
||||
- Leverage existing knowledge repositories that are already maintained in your AWS environment
|
||||
- Combine CrewAI's reasoning with domain-specific information from your knowledge bases
|
||||
|
||||
### Data-Driven Decision Making
|
||||
- Ground CrewAI agent responses in your actual company data rather than general knowledge
|
||||
- Ensure agents provide recommendations based on your specific business context and documentation
|
||||
- Reduce hallucinations by retrieving factual information from your knowledge bases
|
||||
|
||||
### Scalable Information Access
|
||||
- Access terabytes of organizational knowledge without embedding it all into your models
|
||||
- Dynamically query only the relevant information needed for specific tasks
|
||||
- Leverage AWS's scalable infrastructure to handle large knowledge bases efficiently
|
||||
|
||||
### Compliance and Governance
|
||||
- Ensure CrewAI agents provide responses that align with your company's approved documentation
|
||||
- Create auditable trails of information sources used by your agents
|
||||
- Maintain control over what information sources your agents can access
|
||||
3
src/crewai_tools/aws/bedrock/knowledge_base/__init__.py
Normal file
3
src/crewai_tools/aws/bedrock/knowledge_base/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .retriever_tool import BedrockKBRetrieverTool
|
||||
|
||||
__all__ = ["BedrockKBRetrieverTool"]
|
||||
244
src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py
Normal file
244
src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py
Normal file
@@ -0,0 +1,244 @@
|
||||
from typing import Type, Optional, List, Dict, Any
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
# Import custom exceptions
|
||||
from ..exceptions import BedrockKnowledgeBaseError, BedrockValidationError
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class BedrockKBRetrieverToolInput(BaseModel):
|
||||
"""Input schema for BedrockKBRetrieverTool."""
|
||||
query: str = Field(..., description="The query to retrieve information from the knowledge base")
|
||||
|
||||
|
||||
class BedrockKBRetrieverTool(BaseTool):
|
||||
name: str = "Bedrock Knowledge Base Retriever Tool"
|
||||
description: str = "Retrieves information from an Amazon Bedrock Knowledge Base given a query"
|
||||
args_schema: Type[BaseModel] = BedrockKBRetrieverToolInput
|
||||
knowledge_base_id: str = None
|
||||
number_of_results: Optional[int] = 5
|
||||
retrieval_configuration: Optional[Dict[str, Any]] = None
|
||||
guardrail_configuration: Optional[Dict[str, Any]] = None
|
||||
next_token: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
knowledge_base_id: str = None,
|
||||
number_of_results: Optional[int] = 5,
|
||||
retrieval_configuration: Optional[Dict[str, Any]] = None,
|
||||
guardrail_configuration: Optional[Dict[str, Any]] = None,
|
||||
next_token: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Initialize the BedrockKBRetrieverTool with knowledge base configuration.
|
||||
|
||||
Args:
|
||||
knowledge_base_id (str): The unique identifier of the knowledge base to query
|
||||
number_of_results (Optional[int], optional): The maximum number of results to return. Defaults to 5.
|
||||
retrieval_configuration (Optional[Dict[str, Any]], optional): Configurations for the knowledge base query and retrieval process. Defaults to None.
|
||||
guardrail_configuration (Optional[Dict[str, Any]], optional): Guardrail settings. Defaults to None.
|
||||
next_token (Optional[str], optional): Token for retrieving the next batch of results. Defaults to None.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Get knowledge_base_id from environment variable if not provided
|
||||
self.knowledge_base_id = knowledge_base_id or os.getenv('BEDROCK_KB_ID')
|
||||
self.number_of_results = number_of_results
|
||||
self.guardrail_configuration = guardrail_configuration
|
||||
self.next_token = next_token
|
||||
|
||||
# Initialize retrieval_configuration with provided parameters or use the one provided
|
||||
if retrieval_configuration is None:
|
||||
self.retrieval_configuration = self._build_retrieval_configuration()
|
||||
else:
|
||||
self.retrieval_configuration = retrieval_configuration
|
||||
|
||||
# Validate parameters
|
||||
self._validate_parameters()
|
||||
|
||||
# Update the description to include the knowledge base details
|
||||
self.description = f"Retrieves information from Amazon Bedrock Knowledge Base '{self.knowledge_base_id}' given a query"
|
||||
|
||||
def _build_retrieval_configuration(self) -> Dict[str, Any]:
|
||||
"""Build the retrieval configuration based on provided parameters.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The constructed retrieval configuration
|
||||
"""
|
||||
vector_search_config = {}
|
||||
|
||||
# Add number of results if provided
|
||||
if self.number_of_results is not None:
|
||||
vector_search_config["numberOfResults"] = self.number_of_results
|
||||
|
||||
return {"vectorSearchConfiguration": vector_search_config}
|
||||
|
||||
def _validate_parameters(self):
|
||||
"""Validate the parameters according to AWS API requirements."""
|
||||
try:
|
||||
# Validate knowledge_base_id
|
||||
if not self.knowledge_base_id:
|
||||
raise BedrockValidationError("knowledge_base_id cannot be empty")
|
||||
if not isinstance(self.knowledge_base_id, str):
|
||||
raise BedrockValidationError("knowledge_base_id must be a string")
|
||||
if len(self.knowledge_base_id) > 10:
|
||||
raise BedrockValidationError("knowledge_base_id must be 10 characters or less")
|
||||
if not all(c.isalnum() for c in self.knowledge_base_id):
|
||||
raise BedrockValidationError("knowledge_base_id must contain only alphanumeric characters")
|
||||
|
||||
# Validate next_token if provided
|
||||
if self.next_token:
|
||||
if not isinstance(self.next_token, str):
|
||||
raise BedrockValidationError("next_token must be a string")
|
||||
if len(self.next_token) < 1 or len(self.next_token) > 2048:
|
||||
raise BedrockValidationError("next_token must be between 1 and 2048 characters")
|
||||
if ' ' in self.next_token:
|
||||
raise BedrockValidationError("next_token cannot contain spaces")
|
||||
|
||||
# Validate number_of_results if provided
|
||||
if self.number_of_results is not None:
|
||||
if not isinstance(self.number_of_results, int):
|
||||
raise BedrockValidationError("number_of_results must be an integer")
|
||||
if self.number_of_results < 1:
|
||||
raise BedrockValidationError("number_of_results must be greater than 0")
|
||||
|
||||
except BedrockValidationError as e:
|
||||
raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
|
||||
|
||||
def _process_retrieval_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Process a single retrieval result from Bedrock Knowledge Base.
|
||||
|
||||
Args:
|
||||
result (Dict[str, Any]): Raw result from Bedrock Knowledge Base
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Processed result with standardized format
|
||||
"""
|
||||
# Extract content
|
||||
content_obj = result.get('content', {})
|
||||
content = content_obj.get('text', '')
|
||||
content_type = content_obj.get('type', 'text')
|
||||
|
||||
# Extract location information
|
||||
location = result.get('location', {})
|
||||
location_type = location.get('type', 'unknown')
|
||||
source_uri = None
|
||||
|
||||
# Map for location types and their URI fields
|
||||
location_mapping = {
|
||||
's3Location': {'field': 'uri', 'type': 'S3'},
|
||||
'confluenceLocation': {'field': 'url', 'type': 'Confluence'},
|
||||
'salesforceLocation': {'field': 'url', 'type': 'Salesforce'},
|
||||
'sharePointLocation': {'field': 'url', 'type': 'SharePoint'},
|
||||
'webLocation': {'field': 'url', 'type': 'Web'},
|
||||
'customDocumentLocation': {'field': 'id', 'type': 'CustomDocument'},
|
||||
'kendraDocumentLocation': {'field': 'uri', 'type': 'KendraDocument'},
|
||||
'sqlLocation': {'field': 'query', 'type': 'SQL'}
|
||||
}
|
||||
|
||||
# Extract the URI based on location type
|
||||
for loc_key, config in location_mapping.items():
|
||||
if loc_key in location:
|
||||
source_uri = location[loc_key].get(config['field'])
|
||||
if not location_type or location_type == 'unknown':
|
||||
location_type = config['type']
|
||||
break
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
'content': content,
|
||||
'content_type': content_type,
|
||||
'source_type': location_type,
|
||||
'source_uri': source_uri
|
||||
}
|
||||
|
||||
# Add optional fields if available
|
||||
if 'score' in result:
|
||||
result_object['score'] = result['score']
|
||||
|
||||
if 'metadata' in result:
|
||||
result_object['metadata'] = result['metadata']
|
||||
|
||||
# Handle byte content if present
|
||||
if 'byteContent' in content_obj:
|
||||
result_object['byte_content'] = content_obj['byteContent']
|
||||
|
||||
# Handle row content if present
|
||||
if 'row' in content_obj:
|
||||
result_object['row_content'] = content_obj['row']
|
||||
|
||||
return result_object
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
try:
|
||||
# Initialize the Bedrock Agent Runtime client
|
||||
bedrock_agent_runtime = boto3.client(
|
||||
'bedrock-agent-runtime',
|
||||
region_name=os.getenv('AWS_REGION', os.getenv('AWS_DEFAULT_REGION', 'us-east-1')),
|
||||
# AWS SDK will automatically use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from environment
|
||||
)
|
||||
|
||||
# Prepare the request parameters
|
||||
retrieve_params = {
|
||||
'knowledgeBaseId': self.knowledge_base_id,
|
||||
'retrievalQuery': {
|
||||
'text': query
|
||||
}
|
||||
}
|
||||
|
||||
# Add optional parameters if provided
|
||||
if self.retrieval_configuration:
|
||||
retrieve_params['retrievalConfiguration'] = self.retrieval_configuration
|
||||
|
||||
if self.guardrail_configuration:
|
||||
retrieve_params['guardrailConfiguration'] = self.guardrail_configuration
|
||||
|
||||
if self.next_token:
|
||||
retrieve_params['nextToken'] = self.next_token
|
||||
|
||||
# Make the retrieve API call
|
||||
response = bedrock_agent_runtime.retrieve(**retrieve_params)
|
||||
|
||||
# Process the response
|
||||
results = []
|
||||
for result in response.get('retrievalResults', []):
|
||||
processed_result = self._process_retrieval_result(result)
|
||||
results.append(processed_result)
|
||||
|
||||
# Build the response object
|
||||
response_object = {}
|
||||
if results:
|
||||
response_object["results"] = results
|
||||
else:
|
||||
response_object["message"] = "No results found for the given query."
|
||||
|
||||
if "nextToken" in response:
|
||||
response_object["nextToken"] = response["nextToken"]
|
||||
|
||||
if "guardrailAction" in response:
|
||||
response_object["guardrailAction"] = response["guardrailAction"]
|
||||
|
||||
# Return the results as a JSON string
|
||||
return json.dumps(response_object, indent=2)
|
||||
|
||||
except ClientError as e:
|
||||
error_code = "Unknown"
|
||||
error_message = str(e)
|
||||
|
||||
# Try to extract error code if available
|
||||
if hasattr(e, 'response') and 'Error' in e.response:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', str(e))
|
||||
|
||||
raise BedrockKnowledgeBaseError(f"Error ({error_code}): {error_message}")
|
||||
except Exception as e:
|
||||
raise BedrockKnowledgeBaseError(f"Unexpected error: {str(e)}")
|
||||
Reference in New Issue
Block a user