mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
0b3f00e6 chore: update project version to 0.73.0 and revise uv.lock dependencies (#455) ad19b074 feat: replace embedchain with native crewai adapter (#451) git-subtree-dir: packages/tools git-subtree-split: 0b3f00e67c0dae24d188c292dc99759fd1c841f7
99 lines
3.7 KiB
Python
99 lines
3.7 KiB
Python
"""MySQL database loader."""
|
|
|
|
from typing import Any
from urllib.parse import unquote, urlparse

import pymysql

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
|
|
|
|
|
|
class MySQLLoader(BaseLoader):
    """Loader for MySQL database content.

    Runs a SQL query (taken from ``SourceContent.source``) against the database
    named in ``metadata["db_uri"]`` and renders the returned rows as plain text.
    """

    # Upper bound on the length of the rendered text placed in a LoaderResult;
    # longer output is cut and suffixed with a truncation marker.
    MAX_CONTENT_CHARS = 100_000

    def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult:
        """Load the rows returned by a SQL query against a MySQL database.

        Args:
            source: Content whose ``source`` attribute is the SQL query to run
                (e.g. ``"SELECT * FROM table_name"``). The query is executed
                verbatim, so it must come from a trusted caller.
            **kwargs: Additional arguments. ``metadata["db_uri"]`` must hold a
                ``mysql://`` or ``mysql+pymysql://`` URI that names the
                database, e.g. ``mysql://user:pass@host:3306/dbname``.
                Percent-encoded user names, passwords and database names
                (e.g. ``p%40ss`` for ``p@ss``) are decoded before connecting.

        Returns:
            LoaderResult whose content is a plain-text listing of the rows
            (at most ``MAX_CONTENT_CHARS`` characters plus a truncation marker)
            and whose metadata records the query, database, row count and
            column names.

        Raises:
            ValueError: If the URI is missing, has the wrong scheme or lacks a
                database name, or if connecting, querying or rendering fails.
        """
        # Tolerate an explicit ``metadata=None`` as well as a missing key.
        metadata = kwargs.get("metadata") or {}
        db_uri = metadata.get("db_uri")

        if not db_uri:
            raise ValueError("Database URI is required for MySQL loader")

        query = source.source
        connection_params = self._connection_params(db_uri)

        try:
            connection = pymysql.connect(**connection_params)
            try:
                with connection.cursor() as cursor:
                    cursor.execute(query)
                    rows = cursor.fetchall()
            finally:
                # Release the connection as soon as the rows are in memory.
                connection.close()
            # Result building stays inside the outer try so any failure is
            # still reported to callers as a ValueError.
            return self._build_result(query, connection_params["database"], rows)
        except pymysql.Error as e:
            raise ValueError(f"MySQL database error: {e}") from e
        except Exception as e:
            raise ValueError(f"Failed to load data from MySQL: {e}") from e

    @staticmethod
    def _connection_params(db_uri: str) -> dict[str, Any]:
        """Translate a MySQL URI into keyword arguments for ``pymysql.connect``.

        Raises:
            ValueError: If the scheme is not ``mysql``/``mysql+pymysql`` or the
                URI has no database name.
        """
        parsed = urlparse(db_uri)
        if parsed.scheme not in ("mysql", "mysql+pymysql"):
            raise ValueError(f"Invalid MySQL URI scheme: {parsed.scheme}")

        # urlparse leaves userinfo and path percent-encoded, so characters such
        # as '@', ':' or '/' inside credentials arrive as "%40", "%3A", "%2F".
        database = unquote(parsed.path.lstrip("/")) if parsed.path else None
        if not database:
            raise ValueError("Database name is required in the URI")

        user = unquote(parsed.username) if parsed.username is not None else None
        password = unquote(parsed.password) if parsed.password is not None else None

        return {
            "host": parsed.hostname or "localhost",
            "port": parsed.port or 3306,
            "user": user,
            "password": password,
            "database": database,
            "charset": "utf8mb4",
            # Rows come back as column-name -> value dicts.
            "cursorclass": pymysql.cursors.DictCursor,
        }

    def _build_result(self, query: str, database: str, rows: Any) -> LoaderResult:
        """Render fetched rows as text and wrap them in a LoaderResult.

        ``rows`` is the ``fetchall()`` output of a DictCursor: a sequence of
        dicts keyed by column name.
        """
        if not rows:
            content = "No data found in the table"
            return LoaderResult(
                content=content,
                metadata={"source": query, "row_count": 0},
                doc_id=self.generate_doc_id(source_ref=query, content=content),
            )

        columns = list(rows[0].keys())
        text_parts = [
            f"Columns: {', '.join(columns)}",
            f"Total rows: {len(rows)}",
            "",
        ]

        for i, row in enumerate(rows, 1):
            text_parts.append(f"Row {i}:")
            # NULL (None) values are omitted from the per-row listing.
            text_parts.extend(
                f" {col}: {val}" for col, val in row.items() if val is not None
            )
            text_parts.append("")

        content = "\n".join(text_parts)

        limit = self.MAX_CONTENT_CHARS
        if len(content) > limit:
            content = content[:limit] + "\n\n[Content truncated...]"

        return LoaderResult(
            content=content,
            metadata={
                "source": query,
                "database": database,
                "row_count": len(rows),
                "columns": columns,
            },
            doc_id=self.generate_doc_id(source_ref=query, content=content),
        )