mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-21 05:48:14 +00:00
Squashed 'packages/tools/' changes from 78317b9c..0b3f00e6
0b3f00e6 chore: update project version to 0.73.0 and revise uv.lock dependencies (#455) ad19b074 feat: replace embedchain with native crewai adapter (#451) git-subtree-dir: packages/tools git-subtree-split: 0b3f00e67c0dae24d188c292dc99759fd1c841f7
This commit is contained in:
99
crewai_tools/rag/loaders/postgres_loader.py
Normal file
99
crewai_tools/rag/loaders/postgres_loader.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""PostgreSQL database loader."""
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import RealDictCursor
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class PostgresLoader(BaseLoader):
|
||||
"""Loader for PostgreSQL database content."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a PostgreSQL database table.
|
||||
|
||||
Args:
|
||||
source: SQL query (e.g., "SELECT * FROM table_name")
|
||||
**kwargs: Additional arguments including db_uri
|
||||
|
||||
Returns:
|
||||
LoaderResult with database content
|
||||
"""
|
||||
metadata = kwargs.get("metadata", {})
|
||||
db_uri = metadata.get("db_uri")
|
||||
|
||||
if not db_uri:
|
||||
raise ValueError("Database URI is required for PostgreSQL loader")
|
||||
|
||||
query = source.source
|
||||
|
||||
parsed = urlparse(db_uri)
|
||||
if parsed.scheme not in ["postgresql", "postgres", "postgresql+psycopg2"]:
|
||||
raise ValueError(f"Invalid PostgreSQL URI scheme: {parsed.scheme}")
|
||||
|
||||
connection_params = {
|
||||
"host": parsed.hostname or "localhost",
|
||||
"port": parsed.port or 5432,
|
||||
"user": parsed.username,
|
||||
"password": parsed.password,
|
||||
"database": parsed.path.lstrip("/") if parsed.path else None,
|
||||
"cursor_factory": RealDictCursor
|
||||
}
|
||||
|
||||
if not connection_params["database"]:
|
||||
raise ValueError("Database name is required in the URI")
|
||||
|
||||
try:
|
||||
connection = psycopg2.connect(**connection_params)
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
content = "No data found in the table"
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={"source": query, "row_count": 0},
|
||||
doc_id=self.generate_doc_id(source_ref=query, content=content)
|
||||
)
|
||||
|
||||
text_parts = []
|
||||
|
||||
columns = list(rows[0].keys())
|
||||
text_parts.append(f"Columns: {', '.join(columns)}")
|
||||
text_parts.append(f"Total rows: {len(rows)}")
|
||||
text_parts.append("")
|
||||
|
||||
for i, row in enumerate(rows, 1):
|
||||
text_parts.append(f"Row {i}:")
|
||||
for col, val in row.items():
|
||||
if val is not None:
|
||||
text_parts.append(f" {col}: {val}")
|
||||
text_parts.append("")
|
||||
|
||||
content = "\n".join(text_parts)
|
||||
|
||||
if len(content) > 100000:
|
||||
content = content[:100000] + "\n\n[Content truncated...]"
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={
|
||||
"source": query,
|
||||
"database": connection_params["database"],
|
||||
"row_count": len(rows),
|
||||
"columns": columns
|
||||
},
|
||||
doc_id=self.generate_doc_id(source_ref=query, content=content)
|
||||
)
|
||||
finally:
|
||||
connection.close()
|
||||
except psycopg2.Error as e:
|
||||
raise ValueError(f"PostgreSQL database error: {e}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load data from PostgreSQL: {e}")
|
||||
Reference in New Issue
Block a user