mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
Merge pull request #80 from crewAIInc/feat/NL2SQL-tool
feat: Add nl2sql tool to run and execute sql queries in databases
This commit is contained in:
74
src/crewai_tools/tools/nl2sql/README.md
Normal file
74
src/crewai_tools/tools/nl2sql/README.md
Normal file
@@ -0,0 +1,74 @@
|
||||
# NL2SQL Tool
|
||||
|
||||
## Description
|
||||
|
||||
This tool is used to convert natural language to SQL queries. When passsed to the agent it will generate queries and then use them to interact with the database.
|
||||
|
||||
This enables multiple workflows like having an Agent to access the database fetch information based on the goal and then use the information to generate a response, report or any other output. Along with that proivdes the ability for the Agent to update the database based on its goal.
|
||||
|
||||
**Attention**: Make sure that the Agent has access to a Read-Replica or that is okay for the Agent to run insert/update queries on the database.
|
||||
|
||||
## Requirements
|
||||
|
||||
- SqlAlchemy
|
||||
- Any DB compatible library (e.g. psycopg2, mysql-connector-python)
|
||||
|
||||
## Installation
|
||||
Install the crewai_tools package
|
||||
```shell
|
||||
pip install 'crewai[tools]'
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
In order to use the NL2SQLTool, you need to pass the database URI to the tool. The URI should be in the format `dialect+driver://username:password@host:port/database`.
|
||||
|
||||
|
||||
```python
|
||||
from crewai_tools import NL2SQLTool
|
||||
|
||||
# psycopg2 was installed to run this example with PostgreSQL
|
||||
nl2sql = NL2SQLTool(db_uri="postgresql://example@localhost:5432/test_db")
|
||||
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config["researcher"],
|
||||
allow_delegation=False,
|
||||
tools=[nl2sql]
|
||||
)
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
The primary task goal was:
|
||||
|
||||
"Retrieve the average, maximum, and minimum monthly revenue for each city, but only include cities that have more than one user. Also, count the number of user in each city and sort the results by the average monthly revenue in descending order"
|
||||
|
||||
So the Agent tried to get information from the DB, the first one is wrong so the Agent tries again and gets the correct information and passes to the next agent.
|
||||
|
||||

|
||||

|
||||
|
||||
|
||||
The second task goal was:
|
||||
|
||||
"Review the data and create a detailed report, and then create the table on the database with the fields based on the data provided.
|
||||
Include information on the average, maximum, and minimum monthly revenue for each city, but only include cities that have more than one user. Also, count the number of users in each city and sort the results by the average monthly revenue in descending order."
|
||||
|
||||
Now things start to get interesting, the Agent generates the SQL query to not only create the table but also insert the data into the table. And in the end the Agent still returns the final report which is exactly what was in the database.
|
||||
|
||||

|
||||

|
||||
|
||||

|
||||

|
||||
|
||||
|
||||
This is a simple example of how the NL2SQLTool can be used to interact with the database and generate reports based on the data in the database.
|
||||
|
||||
The Tool provides endless possibilities on the logic of the Agent and how it can interact with the database.
|
||||
|
||||
```
|
||||
DB -> Agent -> ... -> Agent -> DB
|
||||
```
|
||||
BIN
src/crewai_tools/tools/nl2sql/images/image-2.png
Normal file
BIN
src/crewai_tools/tools/nl2sql/images/image-2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 83 KiB |
BIN
src/crewai_tools/tools/nl2sql/images/image-3.png
Normal file
BIN
src/crewai_tools/tools/nl2sql/images/image-3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 82 KiB |
BIN
src/crewai_tools/tools/nl2sql/images/image-4.png
Normal file
BIN
src/crewai_tools/tools/nl2sql/images/image-4.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 82 KiB |
BIN
src/crewai_tools/tools/nl2sql/images/image-5.png
Normal file
BIN
src/crewai_tools/tools/nl2sql/images/image-5.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 65 KiB |
BIN
src/crewai_tools/tools/nl2sql/images/image-7.png
Normal file
BIN
src/crewai_tools/tools/nl2sql/images/image-7.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
BIN
src/crewai_tools/tools/nl2sql/images/image-9.png
Normal file
BIN
src/crewai_tools/tools/nl2sql/images/image-9.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 55 KiB |
72
src/crewai_tools/tools/nl2sql/nl2sql_tool.py
Normal file
72
src/crewai_tools/tools/nl2sql/nl2sql_tool.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from crewai_tools import BaseTool
|
||||
from pydantic import Field
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
class NL2SQLTool(BaseTool):
|
||||
name: str = "NL2SQLTool"
|
||||
description: str = "Converts natural language to SQL queries and executes them."
|
||||
db_uri: str = Field(
|
||||
title="Database URI",
|
||||
description="The URI of the database to connect to.",
|
||||
)
|
||||
tables: list = []
|
||||
columns: dict = {}
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
data = {}
|
||||
tables = self._fetch_available_tables()
|
||||
|
||||
for table in tables:
|
||||
table_columns = self._fetch_all_available_columns(table["table_name"])
|
||||
data[f'{table["table_name"]}_columns'] = table_columns
|
||||
|
||||
self.tables = tables
|
||||
self.columns = data
|
||||
|
||||
def _fetch_available_tables(self):
|
||||
return self.execute_sql(
|
||||
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
|
||||
)
|
||||
|
||||
def _fetch_all_available_columns(self, table_name: str):
|
||||
return self.execute_sql(
|
||||
f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"
|
||||
)
|
||||
|
||||
def _run(self, sql_query: str):
|
||||
try:
|
||||
data = self.execute_sql(sql_query)
|
||||
except Exception as exc:
|
||||
data = (
|
||||
f"Based on these tables {self.tables} and columns {self.columns}, "
|
||||
"you can create SQL queries to retrieve data from the database."
|
||||
f"Get the original request {sql_query} and the error {exc} and create the correct SQL query."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def execute_sql(self, sql_query: str) -> Union[list, str]:
|
||||
engine = create_engine(self.db_uri)
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
try:
|
||||
result = session.execute(text(sql_query))
|
||||
session.commit()
|
||||
|
||||
if result.returns_rows:
|
||||
columns = result.keys()
|
||||
data = [dict(zip(columns, row)) for row in result.fetchall()]
|
||||
return data
|
||||
else:
|
||||
return f"Query {sql_query} executed successfully"
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise e
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
Reference in New Issue
Block a user