crewAI/crewai_tools/tools/databricks_query_tool/databricks_query_tool.py

import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

from crewai.tools import BaseTool
from pydantic import BaseModel, Field, model_validator

if TYPE_CHECKING:
    from databricks.sdk import WorkspaceClient


class DatabricksQueryToolSchema(BaseModel):
    """Input schema for DatabricksQueryTool."""

    query: str = Field(
        ..., description="SQL query to execute against the Databricks workspace table"
    )
    catalog: Optional[str] = Field(
        None, description="Databricks catalog name (optional, defaults to configured catalog)"
    )
    db_schema: Optional[str] = Field(
        None, description="Databricks schema name (optional, defaults to configured schema)"
    )
    warehouse_id: Optional[str] = Field(
        None, description="Databricks SQL warehouse ID (optional, defaults to configured warehouse)"
    )
    row_limit: Optional[int] = Field(
        1000, description="Maximum number of rows to return (default: 1000)"
    )

    @model_validator(mode='after')
    def validate_input(self) -> 'DatabricksQueryToolSchema':
        """Validate the input parameters."""
        # Ensure the query is not empty
        if not self.query or not self.query.strip():
            raise ValueError("Query cannot be empty")

        # Add a LIMIT clause to the query if row_limit is provided and query doesn't have one
        if self.row_limit and "limit" not in self.query.lower():
            self.query = f"{self.query.rstrip(';')} LIMIT {self.row_limit};"
        return self
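
# Illustrative example of the validator's behavior (the table name below is a
# placeholder, not part of this module): a query without a LIMIT clause has one
# appended from row_limit, e.g.
#   DatabricksQueryToolSchema(query="SELECT * FROM sales", row_limit=50).query
#   == "SELECT * FROM sales LIMIT 50;"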


class DatabricksQueryTool(BaseTool):
    """
    A tool for querying Databricks workspace tables using SQL.

    This tool executes SQL queries against Databricks tables and returns the results.
    It requires Databricks authentication credentials to be set as environment variables.

    Authentication can be provided via:
    - Databricks CLI profile: Set DATABRICKS_CONFIG_PROFILE environment variable
    - Direct credentials: Set DATABRICKS_HOST and DATABRICKS_TOKEN environment variables

    Example:
        >>> tool = DatabricksQueryTool()
        >>> results = tool.run(query="SELECT * FROM my_table LIMIT 10")
    """

    name: str = "Databricks SQL Query"
    description: str = (
        "Execute SQL queries against Databricks workspace tables and return the results."
        " Provide a 'query' parameter with the SQL query to execute."
    )
    args_schema: Type[BaseModel] = DatabricksQueryToolSchema

    # Optional default parameters
    default_catalog: Optional[str] = None
    default_schema: Optional[str] = None
    default_warehouse_id: Optional[str] = None

    _workspace_client: Optional["WorkspaceClient"] = None
    package_dependencies: List[str] = ["databricks-sdk"]

    def __init__(
        self,
        default_catalog: Optional[str] = None,
        default_schema: Optional[str] = None,
        default_warehouse_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the DatabricksQueryTool.

        Args:
            default_catalog (Optional[str]): Default catalog to use for queries.
            default_schema (Optional[str]): Default schema to use for queries.
            default_warehouse_id (Optional[str]): Default SQL warehouse ID to use.
            **kwargs: Additional keyword arguments passed to BaseTool.
        """
        super().__init__(**kwargs)
        self.default_catalog = default_catalog
        self.default_schema = default_schema
        self.default_warehouse_id = default_warehouse_id
        self._validate_credentials()

    def _validate_credentials(self) -> None:
        """Validate that Databricks credentials are available."""
        has_profile = "DATABRICKS_CONFIG_PROFILE" in os.environ
        has_direct_auth = "DATABRICKS_HOST" in os.environ and "DATABRICKS_TOKEN" in os.environ
        if not (has_profile or has_direct_auth):
            raise ValueError(
                "Databricks authentication credentials are required. "
                "Set either DATABRICKS_CONFIG_PROFILE or both DATABRICKS_HOST and DATABRICKS_TOKEN environment variables."
            )
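
    # Illustrative only (values are placeholders): either environment setup below
    # satisfies _validate_credentials:
    #   DATABRICKS_CONFIG_PROFILE=my-profile
    # or
    #   DATABRICKS_HOST=https://<workspace>.cloud.databricks.com
    #   DATABRICKS_TOKEN=<personal-access-token>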

    @property
    def workspace_client(self) -> "WorkspaceClient":
        """Get or create a Databricks WorkspaceClient instance."""
        if self._workspace_client is None:
            try:
                from databricks.sdk import WorkspaceClient

                self._workspace_client = WorkspaceClient()
            except ImportError:
                raise ImportError(
                    "`databricks-sdk` package not found, please run `uv add databricks-sdk`"
                )
        return self._workspace_client

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Format query results as a readable string."""
        if not results:
            return "Query returned no results."

        # Get column names from the first row
        if not results[0]:
            return "Query returned empty rows with no columns."
        columns = list(results[0].keys())

        # If we have rows but they're all empty, handle that case
        if not columns:
            return "Query returned rows but with no column data."

        # Calculate column widths based on data
        col_widths = {col: len(col) for col in columns}
        for row in results:
            for col in columns:
                # Convert value to string and get its length
                # Handle None values gracefully
                value_str = str(row[col]) if row[col] is not None else "NULL"
                col_widths[col] = max(col_widths[col], len(value_str))

        # Create header row
        header = " | ".join(f"{col:{col_widths[col]}}" for col in columns)
        separator = "-+-".join("-" * col_widths[col] for col in columns)

        # Format data rows
        data_rows = []
        for row in results:
            # Handle None values by displaying "NULL"
            row_values = {col: str(row[col]) if row[col] is not None else "NULL" for col in columns}
            data_row = " | ".join(f"{row_values[col]:{col_widths[col]}}" for col in columns)
            data_rows.append(data_row)

        # Add row count information
        result_info = f"({len(results)} row{'s' if len(results) != 1 else ''} returned)"

        # Combine all parts
        return f"{header}\n{separator}\n" + "\n".join(data_rows) + f"\n\n{result_info}"

    def _run(
        self,
        **kwargs: Any,
    ) -> str:
        """
        Execute a SQL query against Databricks and return the results.

        Args:
            query (str): SQL query to execute
            catalog (Optional[str]): Databricks catalog name
            db_schema (Optional[str]): Databricks schema name
            warehouse_id (Optional[str]): SQL warehouse ID
            row_limit (Optional[int]): Maximum number of rows to return

        Returns:
            str: Formatted query results
        """
        try:
            # Get parameters with fallbacks to default values
            query = kwargs.get("query")
            catalog = kwargs.get("catalog") or self.default_catalog
            db_schema = kwargs.get("db_schema") or self.default_schema
            warehouse_id = kwargs.get("warehouse_id") or self.default_warehouse_id
            row_limit = kwargs.get("row_limit", 1000)

            # Validate schema and query
            validated_input = DatabricksQueryToolSchema(
                query=query,
                catalog=catalog,
                db_schema=db_schema,
                warehouse_id=warehouse_id,
                row_limit=row_limit,
            )

            # Extract validated parameters
            query = validated_input.query
            catalog = validated_input.catalog
            db_schema = validated_input.db_schema
            warehouse_id = validated_input.warehouse_id

            # Setup SQL context with catalog/schema if provided
            context = {}
            if catalog:
                context["catalog"] = catalog
            if db_schema:
                context["schema"] = db_schema

            # Execute query
            statement = self.workspace_client.statement_execution
            try:
                # Execute the statement
                execution = statement.execute_statement(
                    warehouse_id=warehouse_id,
                    statement=query,
                    **context,
                )
                statement_id = execution.statement_id
            except Exception as execute_error:
                # Handle immediate execution errors
                return f"Error starting query execution: {str(execute_error)}"

            # Poll for results with better error handling
            import time

            result = None
            timeout = 300  # 5 minutes timeout
            start_time = time.time()
            poll_count = 0
            previous_state = None  # Track previous state to detect changes

            while time.time() - start_time < timeout:
                poll_count += 1
                try:
                    # Get statement status
                    result = statement.get_statement(statement_id)

                    # Check if finished - be very explicit about state checking
                    if hasattr(result, 'status') and hasattr(result.status, 'state'):
                        state_value = str(result.status.state)  # Convert to string to handle both string and enum

                        # Track state changes for debugging
                        if previous_state != state_value:
                            previous_state = state_value

                        # Check if state indicates completion
                        if "SUCCEEDED" in state_value:
                            break
                        elif "FAILED" in state_value:
                            # Extract error message with more robust handling
                            error_info = "No detailed error info"
                            try:
                                # First try direct access to error.message
                                if hasattr(result.status, 'error') and result.status.error:
                                    if hasattr(result.status.error, 'message'):
                                        error_info = result.status.error.message
                                    # Some APIs may have a different structure
                                    elif hasattr(result.status.error, 'error_message'):
                                        error_info = result.status.error.error_message
                                    # Last resort, try to convert the whole error object to string
                                    else:
                                        error_info = str(result.status.error)
                            except Exception as err_extract_error:
                                # If all else fails, try to get any info we can
                                error_info = f"Error details unavailable: {str(err_extract_error)}"
                            # Return immediately on first FAILED state detection
                            return f"Query execution failed: {error_info}"
                        elif "CANCELED" in state_value:
                            return "Query was canceled"
                except Exception as poll_error:
                    # Don't immediately fail - try again a few times
                    if poll_count > 3:
                        return f"Error checking query status: {str(poll_error)}"

                # Wait before polling again
                time.sleep(2)

            # Check if we timed out
            if result is None:
                return "Query returned no result (likely timed out or failed)"
            if not hasattr(result, 'status') or not hasattr(result.status, 'state'):
                return "Query completed but returned an invalid result structure"

            # Convert state to string for comparison
            state_value = str(result.status.state)
            if not any(state in state_value for state in ["SUCCEEDED", "FAILED", "CANCELED"]):
                return f"Query timed out after 5 minutes (last state: {state_value})"

            # Get results - adapt this based on the actual structure of the result object
            chunk_results = []

            # Check if we have results and a schema in a very defensive way
            has_schema = (
                hasattr(result, 'manifest') and result.manifest is not None
                and hasattr(result.manifest, 'schema') and result.manifest.schema is not None
            )
            has_result = hasattr(result, 'result') and result.result is not None

            if has_schema and has_result:
                try:
                    # Get schema for column names
                    columns = [col.name for col in result.manifest.schema.columns]

                    # Debug info for schema
                    # Keep track of all dynamic columns we create
                    all_columns = set(columns)

                    # Dump the raw structure of result data to help troubleshoot
                    if hasattr(result.result, 'data_array'):
                        # Add defensive check for None data_array
                        if result.result.data_array is None:
                            print("data_array is None - likely an empty result set or DDL query")
                            # Return empty result handling rather than trying to process null data
                            return "Query executed successfully (no data returned)"

                    # IMPROVED DETECTION LOGIC: Check if we're possibly dealing with rows where each item
                    # contains a single value or character (which could indicate incorrect row structure)
                    is_likely_incorrect_row_structure = False

                    # Only try to analyze sample if data_array exists and has content
                    if (hasattr(result.result, 'data_array') and result.result.data_array
                            and len(result.result.data_array) > 0 and len(result.result.data_array[0]) > 0):
                        sample_size = min(20, len(result.result.data_array[0]))
                        if sample_size > 0:
                            single_char_count = 0
                            single_digit_count = 0
                            total_items = 0
                            for i in range(sample_size):
                                val = result.result.data_array[0][i]
                                total_items += 1
                                if isinstance(val, str) and len(val) == 1 and not val.isdigit():
                                    single_char_count += 1
                                elif isinstance(val, str) and len(val) == 1 and val.isdigit():
                                    single_digit_count += 1

                            # If a significant portion of the first values are single characters or digits,
                            # this likely indicates data is being incorrectly structured
                            if total_items > 0 and (single_char_count + single_digit_count) / total_items > 0.5:
                                is_likely_incorrect_row_structure = True

                    # Additional check: if many rows have just 1 item when we expect multiple columns
                    rows_with_single_item = 0
                    if (hasattr(result.result, 'data_array') and result.result.data_array
                            and len(result.result.data_array) > 0):
                        sample_size_for_rows = (
                            min(sample_size, len(result.result.data_array[0]))
                            if 'sample_size' in locals()
                            else min(20, len(result.result.data_array[0]))
                        )
                        rows_with_single_item = sum(
                            1 for row in result.result.data_array[0][:sample_size_for_rows]
                            if isinstance(row, list) and len(row) == 1
                        )
                        if rows_with_single_item > sample_size_for_rows * 0.5 and len(columns) > 1:
                            is_likely_incorrect_row_structure = True

                    # If we're getting primarily single characters or the data structure seems off,
                    # use special handling
                    if 'is_likely_incorrect_row_structure' in locals() and is_likely_incorrect_row_structure:
                        print("Data appears to be malformed - will use special row reconstruction")
                        needs_special_string_handling = True
                    else:
                        needs_special_string_handling = False

                    # Process results differently based on detection
                    if 'needs_special_string_handling' in locals() and needs_special_string_handling:
                        # We're dealing with data where the rows may be incorrectly structured
                        print("Using row reconstruction processing mode")

                        # Collect all values into a flat list
                        all_values = []
                        if hasattr(result.result, 'data_array') and result.result.data_array:
                            # Flatten all values into a single list
                            for chunk in result.result.data_array:
                                for item in chunk:
                                    if isinstance(item, (list, tuple)):
                                        all_values.extend(item)
                                    else:
                                        all_values.append(item)

                        # Get the expected column count from schema
                        expected_column_count = len(columns)

                        # Try to reconstruct rows using pattern recognition
                        reconstructed_rows = []

                        # PATTERN RECOGNITION APPROACH
                        # Look for likely indicators of row boundaries in the data
                        # For Netflix data, we expect IDs as numbers, titles as text strings, etc.

                        # Use regex pattern to identify ID columns that likely start a new row
                        import re

                        id_pattern = re.compile(r'^\d{5,9}$')  # Netflix IDs are often 5-9 digits
                        id_indices = []
                        for i, val in enumerate(all_values):
                            if isinstance(val, str) and id_pattern.match(val):
                                # This value looks like an ID, might be the start of a row
                                if i < len(all_values) - 1:
                                    next_few_values = all_values[i+1:i+5]
                                    # If following values look like they could be part of a title
                                    if any(isinstance(v, str) and len(v) > 1 for v in next_few_values):
                                        id_indices.append(i)

                        if id_indices:
                            # If we found potential row starts, use them to extract rows
                            for i in range(len(id_indices)):
                                start_idx = id_indices[i]
                                end_idx = id_indices[i+1] if i+1 < len(id_indices) else len(all_values)

                                # Extract values for this row
                                row_values = all_values[start_idx:end_idx]

                                # Special handling for Netflix title data
                                # Titles might be split into individual characters
                                if 'Title' in columns and len(row_values) > expected_column_count:
                                    # Try to reconstruct by looking for patterns
                                    # We know ID is first, then Title (which may be split)
                                    # Then other fields like Genre, etc.

                                    # Take first value as ID
                                    row_dict = {columns[0]: row_values[0]}

                                    # Look for Genre or other non-title fields to determine where title ends
                                    title_end_idx = 1
                                    for j in range(2, min(100, len(row_values))):
                                        val = row_values[j]
                                        # Check for common genres or non-title markers
                                        if isinstance(val, str) and val in ['Comedy', 'Drama', 'Action', 'Horror', 'Thriller', 'Documentary']:
                                            # Likely found the Genre field
                                            title_end_idx = j
                                            break

                                    # Reconstruct title from individual characters
                                    if title_end_idx > 1:
                                        title_chars = row_values[1:title_end_idx]
                                        # Check if they're individual characters
                                        if all(isinstance(c, str) and len(c) == 1 for c in title_chars):
                                            title = ''.join(title_chars)
                                            row_dict['Title'] = title

                                            # Assign remaining values to columns
                                            remaining_values = row_values[title_end_idx:]
                                            for j, col_name in enumerate(columns[2:], 2):
                                                if j - 2 < len(remaining_values):
                                                    row_dict[col_name] = remaining_values[j - 2]
                                                else:
                                                    row_dict[col_name] = None
                                    else:
                                        # Fallback: simple mapping
                                        for j, col_name in enumerate(columns):
                                            if j < len(row_values):
                                                row_dict[col_name] = row_values[j]
                                            else:
                                                row_dict[col_name] = None
                                else:
                                    # Standard mapping
                                    row_dict = {}
                                    for j, col_name in enumerate(columns):
                                        if j < len(row_values):
                                            row_dict[col_name] = row_values[j]
                                        else:
                                            row_dict[col_name] = None

                                reconstructed_rows.append(row_dict)
                        else:
                            # More intelligent chunking - try to detect where columns like Title might be split
                            title_idx = columns.index('Title') if 'Title' in columns else -1
                            if title_idx >= 0:
                                print("Attempting title reconstruction method")
                                # Try to detect if title is split across multiple values
                                i = 0
                                while i < len(all_values):
                                    # Check if this could be an ID (start of a row)
                                    if isinstance(all_values[i], str) and id_pattern.match(all_values[i]):
                                        row_dict = {columns[0]: all_values[i]}
                                        i += 1

                                        # Try to reconstruct title if it appears to be split
                                        title_chars = []
                                        while (i < len(all_values) and
                                               isinstance(all_values[i], str) and
                                               len(all_values[i]) <= 1 and
                                               len(title_chars) < 100):  # Cap title length
                                            title_chars.append(all_values[i])
                                            i += 1
                                        if title_chars:
                                            row_dict[columns[title_idx]] = ''.join(title_chars)

                                        # Add remaining fields
                                        for j in range(title_idx + 1, len(columns)):
                                            if i < len(all_values):
                                                row_dict[columns[j]] = all_values[i]
                                                i += 1
                                            else:
                                                row_dict[columns[j]] = None

                                        reconstructed_rows.append(row_dict)
                                    else:
                                        i += 1

                        # If we still don't have rows, use simple chunking as fallback
                        if not reconstructed_rows:
                            print("Falling back to basic chunking approach")
                            chunks = [all_values[i:i+expected_column_count] for i in range(0, len(all_values), expected_column_count)]
                            for chunk in chunks:
                                # Skip chunks that seem to be partial/incomplete rows
                                if len(chunk) < expected_column_count * 0.75:  # Allow for some missing values
                                    continue
                                row_dict = {}
                                # Map values to column names
                                for i, col in enumerate(columns):
                                    if i < len(chunk):
                                        row_dict[col] = chunk[i]
                                    else:
                                        row_dict[col] = None
                                reconstructed_rows.append(row_dict)

                        # Apply post-processing to fix known issues
                        if reconstructed_rows and 'Title' in columns:
                            print("Applying post-processing to improve data quality")
                            for row in reconstructed_rows:
                                # Fix titles that might still have issues
                                if isinstance(row.get('Title'), str) and len(row.get('Title')) <= 1:
                                    # This is likely still a fragmented title - mark as potentially incomplete
                                    row['Title'] = f"[INCOMPLETE] {row.get('Title')}"

                        # Ensure we respect the row limit
                        if row_limit and len(reconstructed_rows) > row_limit:
                            reconstructed_rows = reconstructed_rows[:row_limit]

                        chunk_results = reconstructed_rows
                    else:
                        # Process normal result structure as before
                        print("Using standard processing mode")

                        # Check different result structures
                        if hasattr(result.result, 'data_array') and result.result.data_array:
                            # Check if data appears to be malformed within chunks
                            for chunk_idx, chunk in enumerate(result.result.data_array):
                                # Check if chunk might actually contain individual columns of a single row
                                # This is another way data might be malformed - check the first few values
                                if len(chunk) > 0 and len(columns) > 1:
                                    # If there seems to be a mismatch between chunk structure and expected columns
                                    first_few_values = chunk[:min(5, len(chunk))]
                                    if all(isinstance(val, (str, int, float)) and not isinstance(val, (list, dict)) for val in first_few_values):
                                        if len(chunk) > len(columns) * 3:  # Heuristic: if chunk has way more items than columns
                                            print("Chunk appears to contain individual values rather than rows - switching to row reconstruction")
                                            # This chunk might actually be values of multiple rows - try to reconstruct
                                            values = chunk  # All values in this chunk
                                            reconstructed_rows = []

                                            # Try to create rows based on expected column count
                                            for i in range(0, len(values), len(columns)):
                                                if i + len(columns) <= len(values):  # Ensure we have enough values
                                                    row_values = values[i:i+len(columns)]
                                                    row_dict = {col: val for col, val in zip(columns, row_values)}
                                                    reconstructed_rows.append(row_dict)

                                            if reconstructed_rows:
                                                chunk_results.extend(reconstructed_rows)
                                                continue  # Skip normal processing for this chunk

                                # Special case: when chunk contains exactly the right number of values for a single row
                                # This handles the case where instead of a list of rows, we just got all values in a flat list
                                if all(isinstance(val, (str, int, float)) and not isinstance(val, (list, dict)) for val in chunk):
                                    if len(chunk) == len(columns) or (len(chunk) > 0 and len(chunk) % len(columns) == 0):
                                        # Process flat list of values as rows
                                        for i in range(0, len(chunk), len(columns)):
                                            row_values = chunk[i:i+len(columns)]
                                            if len(row_values) == len(columns):  # Only process complete rows
                                                row_dict = {col: val for col, val in zip(columns, row_values)}
                                                chunk_results.append(row_dict)
                                        # Skip regular row processing for this chunk
                                        continue

                                # Normal processing for typical row structure
                                for row_idx, row in enumerate(chunk):
                                    # Ensure row is actually a collection of values
                                    if not isinstance(row, (list, tuple, dict)):
                                        # This might be a single value; skip it or handle specially
                                        continue

                                    # Convert each row to a dictionary with column names as keys
                                    row_dict = {}

                                    # Handle dict rows directly
                                    if isinstance(row, dict):
                                        # Use the existing column mapping
                                        row_dict = dict(row)
                                    elif isinstance(row, (list, tuple)):
                                        # Map list of values to columns
                                        for i, val in enumerate(row):
                                            if i < len(columns):  # Only process if we have a matching column
                                                row_dict[columns[i]] = val
                                            else:
                                                # Extra values without column names
                                                dynamic_col = f"Column_{i}"
                                                row_dict[dynamic_col] = val
                                                all_columns.add(dynamic_col)

                                    # If we have fewer values than columns, set missing values to None
                                    for col in columns:
                                        if col not in row_dict:
                                            row_dict[col] = None

                                    chunk_results.append(row_dict)
                        elif hasattr(result.result, 'data') and result.result.data:
                            # Alternative data structure
                            for row_idx, row in enumerate(result.result.data):
                                # Debug info
                                # Safely create dictionary matching column names to values
                                row_dict = {}
                                for i, val in enumerate(row):
                                    if i < len(columns):  # Only process if we have a matching column
                                        row_dict[columns[i]] = val
                                    else:
                                        # Extra values without column names
                                        dynamic_col = f"Column_{i}"
                                        row_dict[dynamic_col] = val
                                        all_columns.add(dynamic_col)

                                # If we have fewer values than columns, set missing values to None
                                for i, col in enumerate(columns):
                                    if i >= len(row):
                                        row_dict[col] = None

                                chunk_results.append(row_dict)

                    # After processing all rows, ensure all rows have all columns
                    normalized_results = []
                    for row in chunk_results:
                        # Create a new row with all columns, defaulting to None for missing ones
                        normalized_row = {col: row.get(col, None) for col in all_columns}
                        normalized_results.append(normalized_row)

                    # Replace the original results with normalized ones
                    chunk_results = normalized_results
                except Exception as results_error:
                    # Enhanced error message with more context
                    import traceback

                    error_details = traceback.format_exc()
                    return f"Error processing query results: {str(results_error)}\n\nDetails:\n{error_details}"

            # If we have no results but the query succeeded (e.g., for DDL statements)
            if not chunk_results and hasattr(result, 'status'):
                state_value = str(result.status.state)
                if "SUCCEEDED" in state_value:
                    return "Query executed successfully (no results to display)"

            # Format and return results
            return self._format_results(chunk_results)
        except Exception as e:
            # Include more details in the error message to help with debugging
            import traceback

            error_details = traceback.format_exc()
            return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"