crewAI/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py

import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

from crewai.tools.base_tool import BaseTool
from pydantic import BaseModel, ConfigDict, Field, SecretStr

if TYPE_CHECKING:
    # Import types for type checking only
    from snowflake.connector.connection import SnowflakeConnection
    from snowflake.connector.errors import DatabaseError, OperationalError

try:
    import snowflake.connector
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from snowflake.connector.errors import DatabaseError, OperationalError

    SNOWFLAKE_AVAILABLE = True
except ImportError:
    SNOWFLAKE_AVAILABLE = False
    # Fallbacks so the names referenced below still resolve when the optional
    # Snowflake dependencies are not installed.
    serialization = None
    DatabaseError = OperationalError = Exception

# Configure logging
logger = logging.getLogger(__name__)

# Cache for query results
_query_cache = {}


class SnowflakeConfig(BaseModel):
    """Configuration for Snowflake connection."""

    model_config = ConfigDict(protected_namespaces=())

    account: str = Field(
        ..., description="Snowflake account identifier", pattern=r"^[a-zA-Z0-9\-_]+$"
    )
    user: str = Field(..., description="Snowflake username")
    password: Optional[SecretStr] = Field(None, description="Snowflake password")
    private_key_path: Optional[str] = Field(
        None, description="Path to private key file"
    )
    warehouse: Optional[str] = Field(None, description="Snowflake warehouse")
    database: Optional[str] = Field(None, description="Default database")
    snowflake_schema: Optional[str] = Field(None, description="Default schema")
    role: Optional[str] = Field(None, description="Snowflake role")
    session_parameters: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Session parameters"
    )

    @property
    def has_auth(self) -> bool:
        return bool(self.password or self.private_key_path)

    def model_post_init(self, *args, **kwargs):
        if not self.has_auth:
            raise ValueError("Either password or private_key_path must be provided")

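# Illustrative only: a key-pair based configuration might look like the
# following (all values are placeholders; the key file must be an unencrypted
# PEM, since it is loaded with password=None in _create_connection below).
#
#   config = SnowflakeConfig(
#       account="my_account",
#       user="etl_user",
#       private_key_path="/path/to/rsa_key.p8",
#       warehouse="COMPUTE_WH",
#       database="ANALYTICS",
#       snowflake_schema="PUBLIC",
#   )
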
class SnowflakeSearchToolInput(BaseModel):
    """Input schema for SnowflakeSearchTool."""

    model_config = ConfigDict(protected_namespaces=())

    query: str = Field(..., description="SQL query or semantic search query to execute")
    database: Optional[str] = Field(None, description="Override default database")
    snowflake_schema: Optional[str] = Field(None, description="Override default schema")
    timeout: Optional[int] = Field(300, description="Query timeout in seconds")

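# Illustrative only: the agent-facing payload validated by this schema is a
# plain mapping such as {"query": "SELECT * FROM ORDERS LIMIT 10", "timeout": 60};
# database and snowflake_schema may be included to override the configured defaults.
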
class SnowflakeSearchTool(BaseTool):
    """Tool for executing queries and semantic search on Snowflake."""

    name: str = "Snowflake Database Search"
    description: str = (
        "Execute SQL queries or semantic search on Snowflake data warehouse. "
        "Supports both raw SQL and natural language queries."
    )
    args_schema: Type[BaseModel] = SnowflakeSearchToolInput

    # Define Pydantic fields
    config: SnowflakeConfig = Field(
        ..., description="Snowflake connection configuration"
    )
    pool_size: int = Field(default=5, description="Size of connection pool")
    max_retries: int = Field(default=3, description="Maximum retry attempts")
    retry_delay: float = Field(
        default=1.0, description="Delay between retries in seconds"
    )
    enable_caching: bool = Field(
        default=True, description="Enable query result caching"
    )

    model_config = ConfigDict(
        arbitrary_types_allowed=True, validate_assignment=True, frozen=False
    )

    _connection_pool: Optional[List["SnowflakeConnection"]] = None
    _pool_lock: Optional[asyncio.Lock] = None
    _thread_pool: Optional[ThreadPoolExecutor] = None
    _model_rebuilt: bool = False

    package_dependencies: List[str] = [
        "snowflake-connector-python",
        "snowflake-sqlalchemy",
        "cryptography",
    ]

    def __init__(self, **data):
        """Initialize SnowflakeSearchTool."""
        super().__init__(**data)
        self._initialize_snowflake()

    def _initialize_snowflake(self) -> None:
        try:
            if SNOWFLAKE_AVAILABLE:
                self._connection_pool = []
                self._pool_lock = asyncio.Lock()
                self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
            else:
                raise ImportError
        except ImportError:
            import click

            if click.confirm(
                "You are missing the 'snowflake-connector-python' package. Would you like to install it?"
            ):
                import subprocess

                try:
                    subprocess.run(
                        [
                            "uv",
                            "add",
                            "cryptography",
                            "snowflake-connector-python",
                            "snowflake-sqlalchemy",
                        ],
                        check=True,
                    )
                    self._connection_pool = []
                    self._pool_lock = asyncio.Lock()
                    self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
                except subprocess.CalledProcessError:
                    raise ImportError("Failed to install Snowflake dependencies")
            else:
                raise ImportError(
                    "Snowflake dependencies not found. Please install them by running "
                    "`uv add cryptography snowflake-connector-python snowflake-sqlalchemy`"
                )

    async def _get_connection(self) -> "SnowflakeConnection":
        """Get a connection from the pool or create a new one."""
        async with self._pool_lock:
            if not self._connection_pool:
                conn = await asyncio.get_event_loop().run_in_executor(
                    self._thread_pool, self._create_connection
                )
                self._connection_pool.append(conn)
            return self._connection_pool.pop()

    def _create_connection(self) -> "SnowflakeConnection":
        """Create a new Snowflake connection."""
        conn_params = {
            "account": self.config.account,
            "user": self.config.user,
            "warehouse": self.config.warehouse,
            "database": self.config.database,
            "schema": self.config.snowflake_schema,
            "role": self.config.role,
            "session_parameters": self.config.session_parameters,
        }

        if self.config.password:
            conn_params["password"] = self.config.password.get_secret_value()
        elif self.config.private_key_path and serialization:
            with open(self.config.private_key_path, "rb") as key_file:
                p_key = serialization.load_pem_private_key(
                    key_file.read(), password=None, backend=default_backend()
                )
                conn_params["private_key"] = p_key

        return snowflake.connector.connect(**conn_params)

    def _get_cache_key(self, query: str, timeout: int) -> str:
        """Generate a cache key for the query."""
        return f"{self.config.account}:{self.config.database}:{self.config.snowflake_schema}:{query}:{timeout}"

    async def _execute_query(
        self, query: str, timeout: int = 300
    ) -> List[Dict[str, Any]]:
        """Execute a query with retries and return results."""
        if self.enable_caching:
            cache_key = self._get_cache_key(query, timeout)
            if cache_key in _query_cache:
                logger.info("Returning cached result")
                return _query_cache[cache_key]

        for attempt in range(self.max_retries):
            try:
                conn = await self._get_connection()
                try:
                    cursor = conn.cursor()
                    cursor.execute(query, timeout=timeout)

                    if not cursor.description:
                        return []

                    columns = [col[0] for col in cursor.description]
                    results = [dict(zip(columns, row)) for row in cursor.fetchall()]

                    if self.enable_caching:
                        _query_cache[self._get_cache_key(query, timeout)] = results

                    return results
                finally:
                    cursor.close()
                    async with self._pool_lock:
                        self._connection_pool.append(conn)
            except (DatabaseError, OperationalError) as e:
                if attempt == self.max_retries - 1:
                    raise
                await asyncio.sleep(self.retry_delay * (2**attempt))
                logger.warning(f"Query failed, attempt {attempt + 1}: {str(e)}")
                continue

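    # With the defaults above (max_retries=3, retry_delay=1.0), a failing query
    # is retried after roughly 1s and then 2s of exponential backoff before the
    # final attempt re-raises the underlying error.
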
    async def _run(
        self,
        query: str,
        database: Optional[str] = None,
        snowflake_schema: Optional[str] = None,
        timeout: int = 300,
        **kwargs: Any,
    ) -> Any:
        """Execute the search query."""
        try:
            # Override database/schema if provided
            if database:
                await self._execute_query(f"USE DATABASE {database}")
            if snowflake_schema:
                await self._execute_query(f"USE SCHEMA {snowflake_schema}")

            results = await self._execute_query(query, timeout)
            return results
        except Exception as e:
            logger.error(f"Error executing query: {str(e)}")
            raise

    def __del__(self):
        """Cleanup connections on deletion."""
        try:
            if self._connection_pool:
                for conn in self._connection_pool:
                    try:
                        conn.close()
                    except Exception:
                        pass
            if self._thread_pool:
                self._thread_pool.shutdown()
        except Exception:
            pass

try:
    # Only rebuild if the class hasn't been initialized yet
    if not hasattr(SnowflakeSearchTool, "_model_rebuilt"):
        SnowflakeSearchTool.model_rebuild()
        SnowflakeSearchTool._model_rebuilt = True
except Exception:
    pass

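# Minimal usage sketch (illustrative only): the account, user, and password
# values are placeholders, and _run is a coroutine, so it is driven here with
# asyncio.run for a quick local check.
if __name__ == "__main__":
    example_config = SnowflakeConfig(
        account="my_account",  # placeholder account identifier
        user="analytics_user",  # placeholder username
        password="change-me",  # placeholder; private_key_path works as well
        warehouse="COMPUTE_WH",
        database="ANALYTICS",
        snowflake_schema="PUBLIC",
    )
    tool = SnowflakeSearchTool(config=example_config)
    rows = asyncio.run(tool._run(query="SELECT CURRENT_VERSION() AS version"))
    print(rows)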