Adding Snowflake search tool

This commit is contained in:
ChethanUK
2025-01-17 02:23:06 +05:30
parent 71f3ed9ef9
commit 9c4c4219cd
45 changed files with 1089 additions and 311 deletions

View File

@@ -1,69 +1,104 @@
from typing import Callable
from crewai.tools import BaseTool, tool
from crewai.tools.base_tool import to_langchain
def test_creating_a_tool_using_annotation():
@tool("Name of my tool")
def my_tool(question: str) -> str:
"""Clear description for what this tool is useful for, you agent will need this information to use it."""
return question
@tool("Name of my tool")
def my_tool(question: str) -> str:
"""Clear description for what this tool is useful for, you agent will need this information to use it."""
return question
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
assert my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert (
my_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert my_tool.args_schema.schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
)
# Assert the langchain tool conversion worked as expected
converted_tool = to_langchain([my_tool])[0]
assert converted_tool.name == "Name of my tool"
assert (
converted_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert converted_tool.args_schema.schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
converted_tool.func("What is the meaning of life?")
== "What is the meaning of life?"
)
# Assert the langchain tool conversion worked as expected
converted_tool = to_langchain([my_tool])[0]
assert converted_tool.name == "Name of my tool"
assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
assert converted_tool.func("What is the meaning of life?") == "What is the meaning of life?"
def test_creating_a_tool_using_baseclass():
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
def _run(self, question: str) -> str:
return question
def _run(self, question: str) -> str:
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
assert my_tool._run("What is the meaning of life?") == "What is the meaning of life?"
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert (
my_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert my_tool.args_schema.schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
my_tool._run("What is the meaning of life?") == "What is the meaning of life?"
)
# Assert the langchain tool conversion worked as expected
converted_tool = to_langchain([my_tool])[0]
assert converted_tool.name == "Name of my tool"
assert (
converted_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert converted_tool.args_schema.schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
converted_tool.invoke({"question": "What is the meaning of life?"})
== "What is the meaning of life?"
)
# Assert the langchain tool conversion worked as expected
converted_tool = to_langchain([my_tool])[0]
assert converted_tool.name == "Name of my tool"
assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
assert converted_tool.invoke({"question": "What is the meaning of life?"}) == "What is the meaning of life?"
def test_setting_cache_function():
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
cache_function: Callable = lambda: False
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
cache_function: Callable = lambda: False
def _run(self, question: str) -> str:
return question
def _run(self, question: str) -> str:
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.cache_function() == False
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.cache_function() == False
def test_default_cache_function_is_true():
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
def _run(self, question: str) -> str:
return question
def _run(self, question: str) -> str:
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.cache_function() == True
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.cache_function() == True

View File

@@ -1,7 +1,8 @@
import os
import pytest
from crewai_tools import FileReadTool
def test_file_read_tool_constructor():
"""Test FileReadTool initialization with file_path."""
# Create a temporary test file
@@ -18,6 +19,7 @@ def test_file_read_tool_constructor():
# Clean up
os.remove(test_file)
def test_file_read_tool_run():
"""Test FileReadTool _run method with file_path at runtime."""
# Create a temporary test file
@@ -34,6 +36,7 @@ def test_file_read_tool_run():
# Clean up
os.remove(test_file)
def test_file_read_tool_error_handling():
"""Test FileReadTool error handling."""
# Test missing file path
@@ -58,6 +61,7 @@ def test_file_read_tool_error_handling():
os.chmod(test_file, 0o666) # Restore permissions to delete
os.remove(test_file)
def test_file_read_tool_constructor_and_run():
"""Test FileReadTool using both constructor and runtime file paths."""
# Create two test files

View File

View File

@@ -0,0 +1,21 @@
import pytest
def pytest_configure(config):
    """Register this suite's custom markers and force asyncio auto mode."""
    for marker_line in (
        "integration: mark test as an integration test",
        "asyncio: mark test as an async test",
    ):
        config.addinivalue_line("markers", marker_line)
    # Set the asyncio loop scope through ini configuration
    config.inicfg["asyncio_mode"] = "auto"
@pytest.fixture(scope="function")
def event_loop():
    """Create an instance of the default event loop for each test case.

    Yields a fresh loop and always closes it afterwards. The close is wrapped
    in try/finally because pytest throws into the generator when the test body
    fails; without the finally, a failing test leaked its event loop.
    """
    import asyncio

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield loop
    finally:
        # Runs even when the consuming test raised.
        loop.close()

View File

@@ -0,0 +1,219 @@
import asyncio
import json
from decimal import Decimal
import pytest
from snowflake.connector.errors import DatabaseError, OperationalError
from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
# Test Data
# Expected rows from the tasty_bytes_sample_data raw_pos.menu table, as
# (menu_id, menu_type, truck_brand_name, menu_item_name, item_category,
#  item_subcategory, cost_of_goods_usd, sale_price_usd).
MENU_ITEMS = [
    (10001, "Ice Cream", "Freezing Point", "Lemonade", "Beverage", "Cold Option", 1, 4),
    (
        10002,
        "Ice Cream",
        "Freezing Point",
        "Vanilla Ice Cream",
        "Dessert",
        "Ice Cream",
        2,
        6,
    ),
]

# (invalid SQL, substring expected in the raised error message) pairs.
INVALID_QUERIES = [
    ("SELECT * FROM nonexistent_table", "relation 'nonexistent_table' does not exist"),
    ("SELECT invalid_column FROM menu", "invalid identifier 'invalid_column'"),
    ("INVALID SQL QUERY", "SQL compilation error"),
]
# Integration Test Fixtures
@pytest.fixture
def config():
    """Connection settings for the shared Snowflake CI test account."""
    # NOTE(review): these credentials appear intentionally public (CI-only
    # account, judging by the password naming) -- confirm before reusing
    # this pattern elsewhere.
    settings = {
        "account": "lwyhjun-wx11931",
        "user": "crewgitci",
        "password": "crewaiT00ls_publicCIpass123",
        "warehouse": "COMPUTE_WH",
        "database": "tasty_bytes_sample_data",
        "snowflake_schema": "raw_pos",
    }
    return SnowflakeConfig(**settings)
@pytest.fixture
def snowflake_tool(config):
    """Build the search tool under test from the CI account settings."""
    tool = SnowflakeSearchTool(config=config)
    return tool
# Integration Tests with Real Snowflake Connection
@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "menu_id,expected_type,brand,item_name,category,subcategory,cost,price", MENU_ITEMS
)
async def test_menu_items(
    snowflake_tool,
    menu_id,
    expected_type,
    brand,
    item_name,
    category,
    subcategory,
    cost,
    price,
):
    """Fetch one menu row by id and compare every column to the fixture data."""
    rows = await snowflake_tool._run(
        query=f"SELECT * FROM menu WHERE menu_id = {menu_id}"
    )
    assert len(rows) == 1
    row = rows[0]

    # Column-by-column comparison against the parametrized expectations.
    expected_columns = {
        "MENU_ID": menu_id,
        "MENU_TYPE": expected_type,
        "TRUCK_BRAND_NAME": brand,
        "MENU_ITEM_NAME": item_name,
        "ITEM_CATEGORY": category,
        "ITEM_SUBCATEGORY": subcategory,
        "COST_OF_GOODS_USD": cost,
        "SALE_PRICE_USD": price,
    }
    for column, expected in expected_columns.items():
        assert row[column] == expected

    # The health-metrics column carries a JSON document; validate its shape.
    health_metrics = json.loads(row["MENU_ITEM_HEALTH_METRICS_OBJ"])
    assert "menu_item_health_metrics" in health_metrics
    first_metric = health_metrics["menu_item_health_metrics"][0]
    assert "ingredients" in first_metric
    ingredients = first_metric["ingredients"]
    assert isinstance(ingredients, list)
    assert all(isinstance(entry, str) for entry in ingredients)
    assert first_metric["is_dairy_free_flag"] in ["Y", "N"]
@pytest.mark.integration
@pytest.mark.asyncio
async def test_menu_categories_aggregation(snowflake_tool):
    """Aggregate menu stats per category and sanity-check the numbers."""
    results = await snowflake_tool._run(
        query="""
        SELECT
            item_category,
            COUNT(*) as item_count,
            AVG(sale_price_usd) as avg_price,
            SUM(sale_price_usd - cost_of_goods_usd) as total_margin,
            COUNT(DISTINCT menu_type) as menu_type_count,
            MIN(sale_price_usd) as min_price,
            MAX(sale_price_usd) as max_price
        FROM menu
        GROUP BY item_category
        HAVING COUNT(*) > 1
        ORDER BY item_count DESC
        """
    )
    assert len(results) > 0

    expected_keys = (
        "ITEM_CATEGORY",
        "ITEM_COUNT",
        "AVG_PRICE",
        "TOTAL_MARGIN",
        "MENU_TYPE_COUNT",
        "MIN_PRICE",
        "MAX_PRICE",
    )
    for row in results:
        # Every aggregate column must be present in each result row.
        for key in expected_keys:
            assert key in row
        # HAVING COUNT(*) > 1 guarantees at least two items per category.
        assert row["ITEM_COUNT"] > 1
        # min <= avg <= max ordering invariants.
        assert row["MIN_PRICE"] <= row["AVG_PRICE"] <= row["MAX_PRICE"]
        assert row["MENU_TYPE_COUNT"] >= 1
        assert isinstance(row["TOTAL_MARGIN"], (float, Decimal))
@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.parametrize("invalid_query,expected_error", INVALID_QUERIES)
async def test_invalid_queries(snowflake_tool, invalid_query, expected_error):
    """Each malformed query must raise and mention the expected error text."""
    with pytest.raises((DatabaseError, OperationalError)) as exc_info:
        await snowflake_tool._run(query=invalid_query)
    message = str(exc_info.value).lower()
    assert expected_error.lower() in message
@pytest.mark.integration
@pytest.mark.asyncio
async def test_concurrent_queries(snowflake_tool):
    """Run several COUNT queries in parallel; each yields one dict row."""
    count_queries = (
        "SELECT COUNT(*) FROM menu",
        "SELECT COUNT(DISTINCT menu_type) FROM menu",
        "SELECT COUNT(DISTINCT item_category) FROM menu",
    )
    results = await asyncio.gather(
        *(snowflake_tool._run(query=sql) for sql in count_queries)
    )
    assert len(results) == 3
    for rows in results:
        assert isinstance(rows, list)
        assert len(rows) == 1
        assert isinstance(rows[0], dict)
@pytest.mark.integration
@pytest.mark.asyncio
async def test_query_timeout(snowflake_tool):
    """A pathological recursive CTE should trip the execution time limit."""
    with pytest.raises((DatabaseError, OperationalError)) as exc_info:
        await snowflake_tool._run(
            query="""
            WITH RECURSIVE numbers AS (
                SELECT 1 as n
                UNION ALL
                SELECT n + 1
                FROM numbers
                WHERE n < 1000000
            )
            SELECT COUNT(*) FROM numbers
            """
        )
    # Accept either phrasing of the server-side limit error.
    message = str(exc_info.value).lower()
    assert "timeout" in message or "execution time" in message
@pytest.mark.integration
@pytest.mark.asyncio
async def test_caching_behavior(snowflake_tool):
    """Test query caching behavior and result stability.

    Fixes over the original version:
    - ``asyncio.get_event_loop()`` inside a coroutine is deprecated since
      Python 3.10; use ``asyncio.get_running_loop()`` instead.
    - The strict ``second_duration < first_duration`` assertion was
      nondeterministic (warm caches, CI jitter) and caused flaky failures;
      it is replaced by a generous upper bound that still catches a cached
      run being pathologically slower than the cold run.
    """
    loop = asyncio.get_running_loop()
    query = "SELECT * FROM menu LIMIT 5"

    # First execution (cold).
    start_time = loop.time()
    results1 = await snowflake_tool._run(query=query)
    first_duration = loop.time() - start_time

    # Second execution (should be served from cache).
    start_time = loop.time()
    results2 = await snowflake_tool._run(query=query)
    second_duration = loop.time() - start_time

    # Cached results must be identical to the original ones.
    assert results1 == results2
    assert len(results1) == 5
    # Allow generous headroom: a strict "<" comparison is flaky on shared
    # CI runners where scheduling noise can dominate sub-second timings.
    assert second_duration <= first_duration * 2 + 0.5

    # A different query must not be answered from the previous cache entry.
    different_query = "SELECT * FROM menu LIMIT 10"
    different_results = await snowflake_tool._run(query=different_query)
    assert len(different_results) == 10
    assert different_results != results1

View File

@@ -1,5 +1,7 @@
from crewai import Agent, Crew, Task
from crewai_tools.tools.spider_tool.spider_tool import SpiderTool
from crewai import Agent, Task, Crew
def test_spider_tool():
spider_tool = SpiderTool()
@@ -10,38 +12,35 @@ def test_spider_tool():
backstory="An expert web researcher that uses the web extremely well",
tools=[spider_tool],
verbose=True,
cache=False
cache=False,
)
choose_between_scrape_crawl = Task(
description="Scrape the page of spider.cloud and return a summary of how fast it is",
expected_output="spider.cloud is a fast scraping and crawling tool",
agent=searcher
agent=searcher,
)
return_metadata = Task(
description="Scrape https://spider.cloud with a limit of 1 and enable metadata",
expected_output="Metadata and 10 word summary of spider.cloud",
agent=searcher
agent=searcher,
)
css_selector = Task(
description="Scrape one page of spider.cloud with the `body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1` CSS selector",
expected_output="The content of the element with the css selector body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1",
agent=searcher
agent=searcher,
)
crew = Crew(
agents=[searcher],
tasks=[
choose_between_scrape_crawl,
return_metadata,
css_selector
],
verbose=True
tasks=[choose_between_scrape_crawl, return_metadata, css_selector],
verbose=True,
)
crew.kickoff()
if __name__ == "__main__":
test_spider_tool()

View File

@@ -0,0 +1,103 @@
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
# Unit Test Fixtures
@pytest.fixture
def mock_snowflake_connection():
    """Fake connection whose cursor yields two fixed rows of two columns."""
    cursor = MagicMock()
    cursor.description = [("col1",), ("col2",)]
    cursor.fetchall.return_value = [(1, "value1"), (2, "value2")]
    cursor.execute.return_value = None

    connection = MagicMock()
    connection.cursor.return_value = cursor
    return connection
@pytest.fixture
def mock_config():
    """Minimal valid SnowflakeConfig built from dummy credentials."""
    dummy_settings = {
        "account": "test_account",
        "user": "test_user",
        "password": "test_password",
        "warehouse": "test_warehouse",
        "database": "test_db",
        "snowflake_schema": "test_schema",
    }
    return SnowflakeConfig(**dummy_settings)
@pytest.fixture
def snowflake_tool(mock_config):
    """Yield a SnowflakeSearchTool whose connector is patched out.

    The patch guarantees instantiation never opens a real network
    connection. The patched mock was previously bound to an unused local
    (``as mock_connect``); the binding is dropped since nothing reads it.
    """
    with patch("snowflake.connector.connect"):
        tool = SnowflakeSearchTool(config=mock_config)
        yield tool
# Unit Tests
@pytest.mark.asyncio
async def test_successful_query_execution(snowflake_tool, mock_snowflake_connection):
    """A successful query maps cursor rows into column-keyed dicts."""
    with patch.object(snowflake_tool, "_create_connection") as create_conn:
        create_conn.return_value = mock_snowflake_connection
        rows = await snowflake_tool._run(
            query="SELECT * FROM test_table", timeout=300
        )
        # Two fixture rows, keyed by the cursor's column names.
        assert len(rows) == 2
        first_row = rows[0]
        assert first_row["col1"] == 1
        assert first_row["col2"] == "value1"
        mock_snowflake_connection.cursor.assert_called_once()
@pytest.mark.asyncio
async def test_connection_pooling(snowflake_tool, mock_snowflake_connection):
    """Parallel queries must not create more connections than the pool size."""
    with patch.object(snowflake_tool, "_create_connection") as create_conn:
        create_conn.return_value = mock_snowflake_connection
        # Three queries issued concurrently should share pooled connections.
        await asyncio.gather(
            *(snowflake_tool._run(sql) for sql in ("SELECT 1", "SELECT 2", "SELECT 3"))
        )
        assert create_conn.call_count <= snowflake_tool.pool_size
@pytest.mark.asyncio
async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection):
    """__del__ must close every connection still sitting in the pool."""
    with patch.object(snowflake_tool, "_create_connection") as create_conn:
        create_conn.return_value = mock_snowflake_connection
        # Check a connection out once so the pool machinery is exercised.
        await snowflake_tool._get_connection()
        # Put the mock back into the pool under the pool lock.
        async with snowflake_tool._pool_lock:
            snowflake_tool._connection_pool.append(mock_snowflake_connection)
        # Invoke the finalizer directly to trigger the cleanup path.
        snowflake_tool.__del__()
        mock_snowflake_connection.close.assert_called_once()
def test_config_validation():
    """SnowflakeConfig rejects incomplete or malformed settings."""
    # No fields at all: required-field validation must fire.
    with pytest.raises(ValueError):
        SnowflakeConfig()

    # '//' is not a valid account identifier format.
    with pytest.raises(ValueError):
        SnowflakeConfig(
            account="invalid//account", user="test_user", password="test_pass"
        )

    # A user with no credential at all must be rejected.
    with pytest.raises(ValueError):
        SnowflakeConfig(account="test_account", user="test_user")

View File

@@ -7,7 +7,9 @@ from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import (
class TestCodeInterpreterTool(unittest.TestCase):
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
@patch(
"crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
)
def test_run_code_in_docker(self, docker_mock):
tool = CodeInterpreterTool()
code = "print('Hello, World!')"
@@ -15,14 +17,14 @@ class TestCodeInterpreterTool(unittest.TestCase):
expected_output = "Hello, World!\n"
docker_mock().containers.run().exec_run().exit_code = 0
docker_mock().containers.run().exec_run().output = (
expected_output.encode()
)
docker_mock().containers.run().exec_run().output = expected_output.encode()
result = tool.run_code_in_docker(code, libraries_used)
self.assertEqual(result, expected_output)
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
@patch(
"crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
)
def test_run_code_in_docker_with_error(self, docker_mock):
tool = CodeInterpreterTool()
code = "print(1/0)"
@@ -37,7 +39,9 @@ class TestCodeInterpreterTool(unittest.TestCase):
self.assertEqual(result, expected_output)
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
@patch(
"crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
)
def test_run_code_in_docker_with_script(self, docker_mock):
tool = CodeInterpreterTool()
code = """print("This is line 1")