mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 08:12:39 +00:00
feat: add crewai-tools workspace and fix tests/dependencies
* feat: add crewai-tools workspace structure * Squashed 'temp-crewai-tools/' content from commit 9bae5633 git-subtree-dir: temp-crewai-tools git-subtree-split: 9bae56339096cb70f03873e600192bd2cd207ac9 * feat: configure crewai-tools workspace package with dependencies * fix: apply ruff auto-formatting to crewai-tools code * chore: update lockfile * fix: don't allow tool tests yet * fix: comment out extra pytest flags for now * fix: remove conflicting conftest.py from crewai-tools tests * fix: resolve dependency conflicts and test issues - Pin vcrpy to 7.0.0 to fix pytest-recording compatibility - Comment out types-requests to resolve urllib3 conflict - Update requests requirement in crewai-tools to >=2.32.0
This commit is contained in:
@@ -0,0 +1,155 @@
|
||||
# Snowflake Search Tool
|
||||
|
||||
A tool for executing queries on Snowflake data warehouse with built-in connection pooling, retry logic, and async execution support.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
uv sync --extra snowflake
|
||||
|
||||
OR
|
||||
uv pip install 'snowflake-connector-python>=3.5.0' 'snowflake-sqlalchemy>=1.5.0' 'cryptography>=41.0.0'
|
||||
|
||||
OR
|
||||
pip install 'snowflake-connector-python>=3.5.0' 'snowflake-sqlalchemy>=1.5.0' 'cryptography>=41.0.0'
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from crewai_tools import SnowflakeSearchTool, SnowflakeConfig
|
||||
|
||||
# Create configuration
|
||||
config = SnowflakeConfig(
|
||||
account="your_account",
|
||||
user="your_username",
|
||||
password="your_password",
|
||||
warehouse="COMPUTE_WH",
|
||||
database="your_database",
|
||||
snowflake_schema="your_schema" # Note: Uses snowflake_schema instead of schema
|
||||
)
|
||||
|
||||
# Initialize tool
|
||||
tool = SnowflakeSearchTool(
|
||||
config=config,
|
||||
pool_size=5,
|
||||
max_retries=3,
|
||||
enable_caching=True
|
||||
)
|
||||
|
||||
# Execute query
|
||||
async def main():
|
||||
results = await tool._run(
|
||||
query="SELECT * FROM your_table LIMIT 10",
|
||||
timeout=300
|
||||
)
|
||||
print(f"Retrieved {len(results)} rows")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- ✨ Asynchronous query execution
|
||||
- 🚀 Connection pooling for better performance
|
||||
- 🔄 Automatic retries for transient failures
|
||||
- 💾 Query result caching (optional)
|
||||
- 🔒 Support for both password and key-pair authentication
|
||||
- 📝 Comprehensive error handling and logging
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### SnowflakeConfig Parameters
|
||||
|
||||
| Parameter | Required | Description |
|
||||
|-----------|----------|-------------|
|
||||
| account | Yes | Snowflake account identifier |
|
||||
| user | Yes | Snowflake username |
|
||||
| password | Yes* | Snowflake password |
|
||||
| private_key_path | No* | Path to private key file (alternative to password) |
|
||||
| warehouse | Yes | Snowflake warehouse name |
|
||||
| database | Yes | Default database |
|
||||
| snowflake_schema | Yes | Default schema |
|
||||
| role | No | Snowflake role |
|
||||
| session_parameters | No | Custom session parameters dict |
|
||||
|
||||
\* Either password or private_key_path must be provided
|
||||
|
||||
### Tool Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| pool_size | 5 | Number of connections in the pool |
|
||||
| max_retries | 3 | Maximum retry attempts for failed queries |
|
||||
| retry_delay | 1.0 | Delay between retries in seconds |
|
||||
| enable_caching | True | Enable/disable query result caching |
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Using Key-Pair Authentication
|
||||
|
||||
```python
|
||||
config = SnowflakeConfig(
|
||||
account="your_account",
|
||||
user="your_username",
|
||||
private_key_path="/path/to/private_key.p8",
|
||||
warehouse="your_warehouse",
|
||||
database="your_database",
|
||||
snowflake_schema="your_schema"
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Session Parameters
|
||||
|
||||
```python
|
||||
config = SnowflakeConfig(
|
||||
# ... other config parameters ...
|
||||
session_parameters={
|
||||
"QUERY_TAG": "my_app",
|
||||
"TIMEZONE": "America/Los_Angeles"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Error Handling**: Always wrap query execution in try-except blocks
|
||||
2. **Logging**: Enable logging to track query execution and errors
|
||||
3. **Connection Management**: Use appropriate pool sizes for your workload
|
||||
4. **Timeouts**: Set reasonable query timeouts to prevent hanging
|
||||
5. **Security**: Use key-pair auth in production and never hardcode credentials
|
||||
|
||||
## Example with Logging
|
||||
|
||||
```python
|
||||
import logging
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def main():
|
||||
try:
|
||||
# ... tool initialization ...
|
||||
results = await tool._run(query="SELECT * FROM table LIMIT 10")
|
||||
logger.info(f"Query completed successfully. Retrieved {len(results)} rows")
|
||||
except Exception as e:
|
||||
logger.error(f"Query failed: {str(e)}")
|
||||
raise
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The tool automatically handles common Snowflake errors:
|
||||
- DatabaseError
|
||||
- OperationalError
|
||||
- ProgrammingError
|
||||
- Network timeouts
|
||||
- Connection issues
|
||||
|
||||
Errors are logged and retried based on your retry configuration.
|
||||
@@ -0,0 +1,12 @@
|
||||
from .snowflake_search_tool import (
|
||||
SnowflakeConfig,
|
||||
SnowflakeSearchTool,
|
||||
SnowflakeSearchToolInput,
|
||||
)
|
||||
|
||||
|
||||
# Public names re-exported by the snowflake_search_tool package.
__all__ = ["SnowflakeConfig", "SnowflakeSearchTool", "SnowflakeSearchToolInput"]
||||
@@ -0,0 +1,273 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
    # Import types for type checking only — these names are NOT bound at
    # runtime. NOTE(review): SnowflakeSearchTool._execute_query's except
    # clause also uses DatabaseError/OperationalError at runtime — ensure a
    # runtime import exists there as well.
    from snowflake.connector.connection import SnowflakeConnection
    from snowflake.connector.errors import DatabaseError, OperationalError

try:
    # Optional heavy dependencies: their absence must not break importing the
    # package, so availability is recorded in SNOWFLAKE_AVAILABLE instead.
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    import snowflake.connector

    SNOWFLAKE_AVAILABLE = True
except ImportError:
    # SnowflakeSearchTool._initialize_snowflake offers an interactive install
    # when this is False.
    SNOWFLAKE_AVAILABLE = False

# Configure logging (module-level logger, standard crewai-tools convention)
logger = logging.getLogger(__name__)

# Cache for query results. Process-wide and unbounded; keys combine
# account/database/schema/query/timeout (see SnowflakeSearchTool._get_cache_key).
_query_cache = {}
|
||||
|
||||
|
||||
class SnowflakeConfig(BaseModel):
    """Connection settings for a Snowflake account.

    At least one credential — ``password`` or ``private_key_path`` — is
    required; this is enforced after construction in :meth:`model_post_init`.
    """

    model_config = ConfigDict(protected_namespaces=())

    # Identity and authentication.
    account: str = Field(
        ..., description="Snowflake account identifier", pattern=r"^[a-zA-Z0-9\-_]+$"
    )
    user: str = Field(..., description="Snowflake username")
    password: Optional[SecretStr] = Field(default=None, description="Snowflake password")
    private_key_path: Optional[str] = Field(
        default=None, description="Path to private key file"
    )

    # Session context defaults.
    warehouse: Optional[str] = Field(default=None, description="Snowflake warehouse")
    database: Optional[str] = Field(default=None, description="Default database")
    snowflake_schema: Optional[str] = Field(default=None, description="Default schema")
    role: Optional[str] = Field(default=None, description="Snowflake role")
    session_parameters: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Session parameters"
    )

    @property
    def has_auth(self) -> bool:
        """True when at least one authentication credential is configured."""
        credential = self.password or self.private_key_path
        return bool(credential)

    def model_post_init(self, *args, **kwargs):
        """Reject configurations that carry no authentication credential."""
        if self.has_auth:
            return
        raise ValueError("Either password or private_key_path must be provided")
|
||||
|
||||
|
||||
class SnowflakeSearchToolInput(BaseModel):
    """Input schema for SnowflakeSearchTool."""

    model_config = ConfigDict(protected_namespaces=())

    # `query` is the only required argument; database/schema override the
    # SnowflakeConfig defaults for a single call.
    query: str = Field(..., description="SQL query or semantic search query to execute")
    database: Optional[str] = Field(default=None, description="Override default database")
    snowflake_schema: Optional[str] = Field(
        default=None, description="Override default schema"
    )
    timeout: Optional[int] = Field(default=300, description="Query timeout in seconds")
|
||||
|
||||
|
||||
class SnowflakeSearchTool(BaseTool):
|
||||
"""Tool for executing queries and semantic search on Snowflake."""
|
||||
|
||||
name: str = "Snowflake Database Search"
|
||||
description: str = (
|
||||
"Execute SQL queries or semantic search on Snowflake data warehouse. "
|
||||
"Supports both raw SQL and natural language queries."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SnowflakeSearchToolInput
|
||||
|
||||
# Define Pydantic fields
|
||||
config: SnowflakeConfig = Field(
|
||||
..., description="Snowflake connection configuration"
|
||||
)
|
||||
pool_size: int = Field(default=5, description="Size of connection pool")
|
||||
max_retries: int = Field(default=3, description="Maximum retry attempts")
|
||||
retry_delay: float = Field(
|
||||
default=1.0, description="Delay between retries in seconds"
|
||||
)
|
||||
enable_caching: bool = Field(
|
||||
default=True, description="Enable query result caching"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True, validate_assignment=True, frozen=False
|
||||
)
|
||||
|
||||
_connection_pool: Optional[List["SnowflakeConnection"]] = None
|
||||
_pool_lock: Optional[asyncio.Lock] = None
|
||||
_thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
_model_rebuilt: bool = False
|
||||
package_dependencies: List[str] = [
|
||||
"snowflake-connector-python",
|
||||
"snowflake-sqlalchemy",
|
||||
"cryptography",
|
||||
]
|
||||
|
||||
def __init__(self, **data):
|
||||
"""Initialize SnowflakeSearchTool."""
|
||||
super().__init__(**data)
|
||||
self._initialize_snowflake()
|
||||
|
||||
def _initialize_snowflake(self) -> None:
|
||||
try:
|
||||
if SNOWFLAKE_AVAILABLE:
|
||||
self._connection_pool = []
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||
else:
|
||||
raise ImportError
|
||||
except ImportError:
|
||||
import click
|
||||
|
||||
if click.confirm(
|
||||
"You are missing the 'snowflake-connector-python' package. Would you like to install it?"
|
||||
):
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"uv",
|
||||
"add",
|
||||
"cryptography",
|
||||
"snowflake-connector-python",
|
||||
"snowflake-sqlalchemy",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
|
||||
self._connection_pool = []
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
|
||||
except subprocess.CalledProcessError:
|
||||
raise ImportError("Failed to install Snowflake dependencies")
|
||||
else:
|
||||
raise ImportError(
|
||||
"Snowflake dependencies not found. Please install them by running "
|
||||
"`uv add cryptography snowflake-connector-python snowflake-sqlalchemy`"
|
||||
)
|
||||
|
||||
async def _get_connection(self) -> "SnowflakeConnection":
|
||||
"""Get a connection from the pool or create a new one."""
|
||||
async with self._pool_lock:
|
||||
if not self._connection_pool:
|
||||
conn = await asyncio.get_event_loop().run_in_executor(
|
||||
self._thread_pool, self._create_connection
|
||||
)
|
||||
self._connection_pool.append(conn)
|
||||
return self._connection_pool.pop()
|
||||
|
||||
def _create_connection(self) -> "SnowflakeConnection":
|
||||
"""Create a new Snowflake connection."""
|
||||
conn_params = {
|
||||
"account": self.config.account,
|
||||
"user": self.config.user,
|
||||
"warehouse": self.config.warehouse,
|
||||
"database": self.config.database,
|
||||
"schema": self.config.snowflake_schema,
|
||||
"role": self.config.role,
|
||||
"session_parameters": self.config.session_parameters,
|
||||
}
|
||||
|
||||
if self.config.password:
|
||||
conn_params["password"] = self.config.password.get_secret_value()
|
||||
elif self.config.private_key_path and serialization:
|
||||
with open(self.config.private_key_path, "rb") as key_file:
|
||||
p_key = serialization.load_pem_private_key(
|
||||
key_file.read(), password=None, backend=default_backend()
|
||||
)
|
||||
conn_params["private_key"] = p_key
|
||||
|
||||
return snowflake.connector.connect(**conn_params)
|
||||
|
||||
def _get_cache_key(self, query: str, timeout: int) -> str:
|
||||
"""Generate a cache key for the query."""
|
||||
return f"{self.config.account}:{self.config.database}:{self.config.snowflake_schema}:{query}:{timeout}"
|
||||
|
||||
async def _execute_query(
|
||||
self, query: str, timeout: int = 300
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Execute a query with retries and return results."""
|
||||
|
||||
if self.enable_caching:
|
||||
cache_key = self._get_cache_key(query, timeout)
|
||||
if cache_key in _query_cache:
|
||||
logger.info("Returning cached result")
|
||||
return _query_cache[cache_key]
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
conn = await self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, timeout=timeout)
|
||||
|
||||
if not cursor.description:
|
||||
return []
|
||||
|
||||
columns = [col[0] for col in cursor.description]
|
||||
results = [dict(zip(columns, row)) for row in cursor.fetchall()]
|
||||
|
||||
if self.enable_caching:
|
||||
_query_cache[self._get_cache_key(query, timeout)] = results
|
||||
|
||||
return results
|
||||
finally:
|
||||
cursor.close()
|
||||
async with self._pool_lock:
|
||||
self._connection_pool.append(conn)
|
||||
except (DatabaseError, OperationalError) as e:
|
||||
if attempt == self.max_retries - 1:
|
||||
raise
|
||||
await asyncio.sleep(self.retry_delay * (2**attempt))
|
||||
logger.warning(f"Query failed, attempt {attempt + 1}: {e!s}")
|
||||
continue
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
query: str,
|
||||
database: Optional[str] = None,
|
||||
snowflake_schema: Optional[str] = None,
|
||||
timeout: int = 300,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Execute the search query."""
|
||||
|
||||
try:
|
||||
# Override database/schema if provided
|
||||
if database:
|
||||
await self._execute_query(f"USE DATABASE {database}")
|
||||
if snowflake_schema:
|
||||
await self._execute_query(f"USE SCHEMA {snowflake_schema}")
|
||||
|
||||
results = await self._execute_query(query, timeout)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query: {e!s}")
|
||||
raise
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup connections on deletion."""
|
||||
try:
|
||||
if self._connection_pool:
|
||||
for conn in self._connection_pool:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
if self._thread_pool:
|
||||
self._thread_pool.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
try:
    # Only rebuild if the class hasn't been initialized yet.
    # NOTE(review): SnowflakeSearchTool declares `_model_rebuilt: bool = False`
    # in its class body, so this hasattr() check looks like it is always True
    # and model_rebuild() would never run — confirm whether the guard (or the
    # rebuild itself) is still needed.
    if not hasattr(SnowflakeSearchTool, "_model_rebuilt"):
        SnowflakeSearchTool.model_rebuild()
        SnowflakeSearchTool._model_rebuilt = True
except Exception:
    # Best-effort: rebuild failures (e.g. unresolved forward refs) are ignored.
    pass
|
||||
Reference in New Issue
Block a user