Test optional dependencies are not required at runtime (#260)

* Test optional dependencies are not required at runtime

* Add dynamic imports to S3 tools

* Setup CI
This commit is contained in:
Vini Brasil
2025-04-08 13:20:11 -04:00
committed by GitHub
parent 6f95572e18
commit 257f4bf385
8 changed files with 80 additions and 278 deletions

View File

@@ -8,10 +8,7 @@ from dotenv import load_dotenv
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
import boto3
from botocore.exceptions import ClientError
# Import custom exceptions
from ..exceptions import BedrockAgentError, BedrockValidationError
# Load environment variables from .env file
@@ -92,6 +89,12 @@ class BedrockInvokeAgentTool(BaseTool):
raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
def _run(self, query: str) -> str:
try:
import boto3
from botocore.exceptions import ClientError
except ImportError:
raise ImportError("`boto3` package not found, please run `uv add boto3`")
try:
# Initialize the Bedrock Agent Runtime client
bedrock_agent = boto3.client(

View File

@@ -5,10 +5,7 @@ from dotenv import load_dotenv
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
import boto3
from botocore.exceptions import ClientError
# Import custom exceptions
from ..exceptions import BedrockKnowledgeBaseError, BedrockValidationError
# Load environment variables from .env file
@@ -179,6 +176,12 @@ class BedrockKBRetrieverTool(BaseTool):
return result_object
def _run(self, query: str) -> str:
try:
import boto3
from botocore.exceptions import ClientError
except ImportError:
raise ImportError("`boto3` package not found, please run `uv add boto3`")
try:
# Initialize the Bedrock Agent Runtime client
bedrock_agent_runtime = boto3.client(

View File

@@ -1,10 +1,8 @@
from typing import Type
from typing import Any, Type
import os
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
import boto3
from botocore.exceptions import ClientError
class S3ReaderToolInput(BaseModel):
@@ -19,6 +17,12 @@ class S3ReaderTool(BaseTool):
args_schema: Type[BaseModel] = S3ReaderToolInput
def _run(self, file_path: str) -> str:
try:
import boto3
from botocore.exceptions import ClientError
except ImportError:
raise ImportError("`boto3` package not found, please run `uv add boto3`")
try:
bucket_name, object_key = self._parse_s3_path(file_path)

View File

@@ -1,22 +1,27 @@
from typing import Type
from typing import Any, Type
import os
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
import boto3
from botocore.exceptions import ClientError
class S3WriterToolInput(BaseModel):
    """Input schema for S3WriterTool."""

    # Full S3 URI of the destination object, e.g. 's3://bucket-name/file-name'.
    file_path: str = Field(..., description="S3 file path (e.g., 's3://bucket-name/file-name')")
    # Text payload to upload to the object.
    content: str = Field(..., description="Content to write to the file")
class S3WriterTool(BaseTool):
name: str = "S3 Writer Tool"
description: str = "Writes content to a file in Amazon S3 given an S3 file path"
args_schema: Type[BaseModel] = S3WriterToolInput
def _run(self, file_path: str, content: str) -> str:
try:
import boto3
from botocore.exceptions import ClientError
except ImportError:
raise ImportError("`boto3` package not found, please run `uv add boto3`")
try:
bucket_name, object_key = self._parse_s3_path(file_path)

View File

@@ -1,219 +0,0 @@
import asyncio
import json
from decimal import Decimal
import pytest
from snowflake.connector.errors import DatabaseError, OperationalError
from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
# Test Data
# Expected rows from the sample `menu` table, in column order:
# (menu_id, menu_type, truck_brand_name, menu_item_name,
#  item_category, item_subcategory, cost_of_goods_usd, sale_price_usd)
# — the field-by-field assertions in test_menu_items establish this mapping.
MENU_ITEMS = [
    (10001, "Ice Cream", "Freezing Point", "Lemonade", "Beverage", "Cold Option", 1, 4),
    (
        10002,
        "Ice Cream",
        "Freezing Point",
        "Vanilla Ice Cream",
        "Dessert",
        "Ice Cream",
        2,
        6,
    ),
]
# Pairs of (invalid SQL, substring expected in the raised error message).
INVALID_QUERIES = [
    ("SELECT * FROM nonexistent_table", "relation 'nonexistent_table' does not exist"),
    ("SELECT invalid_column FROM menu", "invalid identifier 'invalid_column'"),
    ("INVALID SQL QUERY", "SQL compilation error"),
]
# Integration Test Fixtures
@pytest.fixture
def config():
    """Create a Snowflake configuration with test credentials.

    NOTE(review): credentials are hard-coded in source; presumably these are
    disposable CI-only credentials — confirm, and consider sourcing them from
    environment variables / CI secrets instead.
    """
    return SnowflakeConfig(
        account="lwyhjun-wx11931",
        user="crewgitci",
        password="crewaiT00ls_publicCIpass123",
        warehouse="COMPUTE_WH",
        database="tasty_bytes_sample_data",
        snowflake_schema="raw_pos",
    )
@pytest.fixture
def snowflake_tool(config):
    """Create a SnowflakeSearchTool instance backed by the test `config` fixture."""
    return SnowflakeSearchTool(config=config)
# Integration Tests with Real Snowflake Connection
@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "menu_id,expected_type,brand,item_name,category,subcategory,cost,price", MENU_ITEMS
)
async def test_menu_items(
    snowflake_tool,
    menu_id,
    expected_type,
    brand,
    item_name,
    category,
    subcategory,
    cost,
    price,
):
    """Test menu items with parameterized data for multiple test cases."""
    # Each menu_id is expected to identify exactly one row in the sample data.
    results = await snowflake_tool._run(
        query=f"SELECT * FROM menu WHERE menu_id = {menu_id}"
    )
    assert len(results) == 1
    menu_item = results[0]
    # Validate all fields
    assert menu_item["MENU_ID"] == menu_id
    assert menu_item["MENU_TYPE"] == expected_type
    assert menu_item["TRUCK_BRAND_NAME"] == brand
    assert menu_item["MENU_ITEM_NAME"] == item_name
    assert menu_item["ITEM_CATEGORY"] == category
    assert menu_item["ITEM_SUBCATEGORY"] == subcategory
    assert menu_item["COST_OF_GOODS_USD"] == cost
    assert menu_item["SALE_PRICE_USD"] == price
    # Validate health metrics JSON structure
    # The column holds a JSON document; only its shape is checked, not values.
    health_metrics = json.loads(menu_item["MENU_ITEM_HEALTH_METRICS_OBJ"])
    assert "menu_item_health_metrics" in health_metrics
    metrics = health_metrics["menu_item_health_metrics"][0]
    assert "ingredients" in metrics
    assert isinstance(metrics["ingredients"], list)
    assert all(isinstance(ingredient, str) for ingredient in metrics["ingredients"])
    assert metrics["is_dairy_free_flag"] in ["Y", "N"]
@pytest.mark.integration
@pytest.mark.asyncio
async def test_menu_categories_aggregation(snowflake_tool):
    """Test complex aggregation query on menu categories with detailed validations."""
    results = await snowflake_tool._run(
        query="""
        SELECT
            item_category,
            COUNT(*) as item_count,
            AVG(sale_price_usd) as avg_price,
            SUM(sale_price_usd - cost_of_goods_usd) as total_margin,
            COUNT(DISTINCT menu_type) as menu_type_count,
            MIN(sale_price_usd) as min_price,
            MAX(sale_price_usd) as max_price
        FROM menu
        GROUP BY item_category
        HAVING COUNT(*) > 1
        ORDER BY item_count DESC
        """
    )
    assert len(results) > 0
    for category in results:
        # Basic presence checks
        assert all(
            key in category
            for key in [
                "ITEM_CATEGORY",
                "ITEM_COUNT",
                "AVG_PRICE",
                "TOTAL_MARGIN",
                "MENU_TYPE_COUNT",
                "MIN_PRICE",
                "MAX_PRICE",
            ]
        )
        # Value validations
        assert category["ITEM_COUNT"] > 1  # Due to HAVING clause
        # Min/avg/max must be internally consistent for every group.
        assert category["MIN_PRICE"] <= category["MAX_PRICE"]
        assert category["AVG_PRICE"] >= category["MIN_PRICE"]
        assert category["AVG_PRICE"] <= category["MAX_PRICE"]
        assert category["MENU_TYPE_COUNT"] >= 1
        # Numeric columns may come back as Decimal rather than float.
        assert isinstance(category["TOTAL_MARGIN"], (float, Decimal))
@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.parametrize("invalid_query,expected_error", INVALID_QUERIES)
async def test_invalid_queries(snowflake_tool, invalid_query, expected_error):
    """Test error handling for invalid queries."""
    # Each bad query must raise a Snowflake connector error whose message
    # contains the expected substring (compared case-insensitively).
    with pytest.raises((DatabaseError, OperationalError)) as exc_info:
        await snowflake_tool._run(query=invalid_query)
    assert expected_error.lower() in str(exc_info.value).lower()
@pytest.mark.integration
@pytest.mark.asyncio
async def test_concurrent_queries(snowflake_tool):
    """Test handling of concurrent queries."""
    queries = [
        "SELECT COUNT(*) FROM menu",
        "SELECT COUNT(DISTINCT menu_type) FROM menu",
        "SELECT COUNT(DISTINCT item_category) FROM menu",
    ]
    # Kick off all three queries concurrently and wait for every result.
    results = await asyncio.gather(
        *(snowflake_tool._run(query=query) for query in queries)
    )
    assert len(results) == 3
    # Every COUNT query produces a single-row result, returned as one dict.
    for result in results:
        assert isinstance(result, list)
        assert len(result) == 1
        assert isinstance(result[0], dict)
@pytest.mark.integration
@pytest.mark.asyncio
async def test_query_timeout(snowflake_tool):
    """Test query timeout handling with a complex query."""
    # A recursive CTE counting to one million is expected to exceed the
    # configured execution limits and raise a connector error.
    with pytest.raises((DatabaseError, OperationalError)) as exc_info:
        await snowflake_tool._run(
            query="""
            WITH RECURSIVE numbers AS (
                SELECT 1 as n
                UNION ALL
                SELECT n + 1
                FROM numbers
                WHERE n < 1000000
            )
            SELECT COUNT(*) FROM numbers
            """
        )
    # Accept either wording of Snowflake's timeout-style error messages.
    assert (
        "timeout" in str(exc_info.value).lower()
        or "execution time" in str(exc_info.value).lower()
    )
@pytest.mark.integration
@pytest.mark.asyncio
async def test_caching_behavior(snowflake_tool):
    """Test query caching behavior and performance.

    NOTE(review): the `second_duration < first_duration` assertion is a wall-clock
    comparison and may be flaky on a loaded CI runner — consider a tolerance.
    """
    query = "SELECT * FROM menu LIMIT 5"
    # First execution
    start_time = asyncio.get_event_loop().time()
    results1 = await snowflake_tool._run(query=query)
    first_duration = asyncio.get_event_loop().time() - start_time
    # Second execution (should be cached)
    start_time = asyncio.get_event_loop().time()
    results2 = await snowflake_tool._run(query=query)
    second_duration = asyncio.get_event_loop().time() - start_time
    # Verify results
    assert results1 == results2
    assert len(results1) == 5
    assert second_duration < first_duration
    # Verify cache invalidation with different query
    different_query = "SELECT * FROM menu LIMIT 10"
    different_results = await snowflake_tool._run(query=different_query)
    assert len(different_results) == 10
    assert different_results != results1

View File

@@ -1,46 +0,0 @@
from crewai import Agent, Crew, Task
from crewai_tools.tools.spider_tool.spider_tool import SpiderTool
def test_spider_tool():
    """End-to-end smoke test: run a crew whose agent uses SpiderTool.

    NOTE(review): this hits the live spider.cloud site and an LLM backend;
    there are no assertions — it only verifies the crew runs to completion.
    """
    spider_tool = SpiderTool()
    # Single agent equipped with the spider tool; caching disabled so every
    # task actually invokes the tool.
    searcher = Agent(
        role="Web Research Expert",
        goal="Find related information from specific URL's",
        backstory="An expert web researcher that uses the web extremely well",
        tools=[spider_tool],
        verbose=True,
        cache=False,
    )
    # Task 1: plain scrape-and-summarize.
    choose_between_scrape_crawl = Task(
        description="Scrape the page of spider.cloud and return a summary of how fast it is",
        expected_output="spider.cloud is a fast scraping and crawling tool",
        agent=searcher,
    )
    # Task 2: exercise the metadata option with a crawl limit of 1.
    return_metadata = Task(
        description="Scrape https://spider.cloud with a limit of 1 and enable metadata",
        expected_output="Metadata and 10 word summary of spider.cloud",
        agent=searcher,
    )
    # Task 3: exercise CSS-selector-scoped scraping.
    css_selector = Task(
        description="Scrape one page of spider.cloud with the `body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1` CSS selector",
        expected_output="The content of the element with the css selector body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1",
        agent=searcher,
    )
    crew = Crew(
        agents=[searcher],
        tasks=[choose_between_scrape_crawl, return_metadata, css_selector],
        verbose=True,
    )
    crew.kickoff()


if __name__ == "__main__":
    test_spider_tool()

View File

@@ -0,0 +1,41 @@
import subprocess
import tempfile
from pathlib import Path
import pytest
@pytest.fixture
def temp_project():
    """Yield a fresh uv-managed project that depends on the local checkout.

    Creates a throwaway project directory inside a temporary directory,
    installs the repository under test as an editable dependency, runs
    `uv sync`, and hands the project path to the consuming test. The
    temporary directory is removed explicitly when the test finishes
    (the original relied on GC finalizing the TemporaryDirectory).
    """
    temp_dir = tempfile.TemporaryDirectory()
    try:
        project_dir = Path(temp_dir.name) / "test_project"
        project_dir.mkdir()
        # Plain string, not an f-string — there are no placeholders to fill.
        pyproject_content = """
[project]
name = "test-project"
version = "0.1.0"
description = "Test project"
requires-python = ">=3.10"
"""
        (project_dir / "pyproject.toml").write_text(pyproject_content)
        # Fail fast with the captured stderr instead of letting a silent uv
        # failure surface as a confusing import error later in the test.
        install = run_command(
            ["uv", "add", "--editable", f"file://{Path.cwd().absolute()}"], project_dir
        )
        assert install.returncode == 0, f"uv add failed: {install.stderr}"
        sync = run_command(["uv", "sync"], project_dir)
        assert sync.returncode == 0, f"uv sync failed: {sync.stderr}"
        yield project_dir
    finally:
        # Explicit cleanup rather than relying on interpreter finalizers.
        temp_dir.cleanup()
def run_command(cmd, cwd):
    """Run *cmd* inside directory *cwd* and return the CompletedProcess.

    stdout/stderr are captured as text. The call never raises on a
    non-zero exit status — callers inspect ``returncode`` themselves.
    """
    completed = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
    return completed
def test_no_optional_dependencies_in_init(temp_project):
    """
    Test that crewai-tools can be imported without optional dependencies.

    The package defines optional dependencies in pyproject.toml, but the base
    package should be importable without any of these optional dependencies
    being installed.
    """
    # Import inside the isolated uv project, where none of the optional
    # extras exist; any eager import of an optional dependency fails here.
    result = run_command(["uv", "run", "python", "-c", "import crewai_tools"], temp_project)
    assert result.returncode == 0, f"Import failed with error: {result.stderr}"

View File

@@ -1,4 +1,6 @@
from unittest.mock import MagicMock, patch
import tempfile
import os
from bs4 import BeautifulSoup
@@ -27,7 +29,11 @@ def initialize_tool_with(mock_driver):
return tool
def test_tool_initialization():
@patch("selenium.webdriver.Chrome")
def test_tool_initialization(mocked_chrome):
temp_dir = tempfile.mkdtemp()
mocked_chrome.return_value = MagicMock()
tool = SeleniumScrapingTool()
assert tool.website_url is None
@@ -35,6 +41,11 @@ def test_tool_initialization():
assert tool.cookie is None
assert tool.wait_time == 3
assert tool.return_html is False
try:
os.rmdir(temp_dir)
except:
pass
@patch("selenium.webdriver.Chrome")