mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Test optional dependencies are not required in runtime (#260)
* Test optional dependencies are not required in runtime * Add dynamic imports to S3 tools * Setup CI
This commit is contained in:
@@ -8,10 +8,7 @@ from dotenv import load_dotenv
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
# Import custom exceptions
|
||||
from ..exceptions import BedrockAgentError, BedrockValidationError
|
||||
|
||||
# Load environment variables from .env file
|
||||
@@ -92,6 +89,12 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError:
|
||||
raise ImportError("`boto3` package not found, please run `uv add boto3`")
|
||||
|
||||
try:
|
||||
# Initialize the Bedrock Agent Runtime client
|
||||
bedrock_agent = boto3.client(
|
||||
|
||||
@@ -5,10 +5,7 @@ from dotenv import load_dotenv
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
# Import custom exceptions
|
||||
from ..exceptions import BedrockKnowledgeBaseError, BedrockValidationError
|
||||
|
||||
# Load environment variables from .env file
|
||||
@@ -179,6 +176,12 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
return result_object
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError:
|
||||
raise ImportError("`boto3` package not found, please run `uv add boto3`")
|
||||
|
||||
try:
|
||||
# Initialize the Bedrock Agent Runtime client
|
||||
bedrock_agent_runtime = boto3.client(
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from typing import Type
|
||||
from typing import Any, Type
|
||||
import os
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
|
||||
class S3ReaderToolInput(BaseModel):
|
||||
@@ -19,6 +17,12 @@ class S3ReaderTool(BaseTool):
|
||||
args_schema: Type[BaseModel] = S3ReaderToolInput
|
||||
|
||||
def _run(self, file_path: str) -> str:
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError:
|
||||
raise ImportError("`boto3` package not found, please run `uv add boto3`")
|
||||
|
||||
try:
|
||||
bucket_name, object_key = self._parse_s3_path(file_path)
|
||||
|
||||
|
||||
@@ -1,22 +1,27 @@
|
||||
from typing import Type
|
||||
from typing import Any, Type
|
||||
import os
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
class S3WriterToolInput(BaseModel):
|
||||
"""Input schema for S3WriterTool."""
|
||||
file_path: str = Field(..., description="S3 file path (e.g., 's3://bucket-name/file-name')")
|
||||
content: str = Field(..., description="Content to write to the file")
|
||||
|
||||
|
||||
class S3WriterTool(BaseTool):
|
||||
name: str = "S3 Writer Tool"
|
||||
description: str = "Writes content to a file in Amazon S3 given an S3 file path"
|
||||
args_schema: Type[BaseModel] = S3WriterToolInput
|
||||
|
||||
def _run(self, file_path: str, content: str) -> str:
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError:
|
||||
raise ImportError("`boto3` package not found, please run `uv add boto3`")
|
||||
|
||||
try:
|
||||
bucket_name, object_key = self._parse_s3_path(file_path)
|
||||
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from snowflake.connector.errors import DatabaseError, OperationalError
|
||||
|
||||
from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
|
||||
|
||||
# Test Data
|
||||
MENU_ITEMS = [
|
||||
(10001, "Ice Cream", "Freezing Point", "Lemonade", "Beverage", "Cold Option", 1, 4),
|
||||
(
|
||||
10002,
|
||||
"Ice Cream",
|
||||
"Freezing Point",
|
||||
"Vanilla Ice Cream",
|
||||
"Dessert",
|
||||
"Ice Cream",
|
||||
2,
|
||||
6,
|
||||
),
|
||||
]
|
||||
|
||||
INVALID_QUERIES = [
|
||||
("SELECT * FROM nonexistent_table", "relation 'nonexistent_table' does not exist"),
|
||||
("SELECT invalid_column FROM menu", "invalid identifier 'invalid_column'"),
|
||||
("INVALID SQL QUERY", "SQL compilation error"),
|
||||
]
|
||||
|
||||
|
||||
# Integration Test Fixtures
|
||||
@pytest.fixture
|
||||
def config():
|
||||
"""Create a Snowflake configuration with test credentials."""
|
||||
return SnowflakeConfig(
|
||||
account="lwyhjun-wx11931",
|
||||
user="crewgitci",
|
||||
password="crewaiT00ls_publicCIpass123",
|
||||
warehouse="COMPUTE_WH",
|
||||
database="tasty_bytes_sample_data",
|
||||
snowflake_schema="raw_pos",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def snowflake_tool(config):
|
||||
"""Create a SnowflakeSearchTool instance."""
|
||||
return SnowflakeSearchTool(config=config)
|
||||
|
||||
|
||||
# Integration Tests with Real Snowflake Connection
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"menu_id,expected_type,brand,item_name,category,subcategory,cost,price", MENU_ITEMS
|
||||
)
|
||||
async def test_menu_items(
|
||||
snowflake_tool,
|
||||
menu_id,
|
||||
expected_type,
|
||||
brand,
|
||||
item_name,
|
||||
category,
|
||||
subcategory,
|
||||
cost,
|
||||
price,
|
||||
):
|
||||
"""Test menu items with parameterized data for multiple test cases."""
|
||||
results = await snowflake_tool._run(
|
||||
query=f"SELECT * FROM menu WHERE menu_id = {menu_id}"
|
||||
)
|
||||
assert len(results) == 1
|
||||
menu_item = results[0]
|
||||
|
||||
# Validate all fields
|
||||
assert menu_item["MENU_ID"] == menu_id
|
||||
assert menu_item["MENU_TYPE"] == expected_type
|
||||
assert menu_item["TRUCK_BRAND_NAME"] == brand
|
||||
assert menu_item["MENU_ITEM_NAME"] == item_name
|
||||
assert menu_item["ITEM_CATEGORY"] == category
|
||||
assert menu_item["ITEM_SUBCATEGORY"] == subcategory
|
||||
assert menu_item["COST_OF_GOODS_USD"] == cost
|
||||
assert menu_item["SALE_PRICE_USD"] == price
|
||||
|
||||
# Validate health metrics JSON structure
|
||||
health_metrics = json.loads(menu_item["MENU_ITEM_HEALTH_METRICS_OBJ"])
|
||||
assert "menu_item_health_metrics" in health_metrics
|
||||
metrics = health_metrics["menu_item_health_metrics"][0]
|
||||
assert "ingredients" in metrics
|
||||
assert isinstance(metrics["ingredients"], list)
|
||||
assert all(isinstance(ingredient, str) for ingredient in metrics["ingredients"])
|
||||
assert metrics["is_dairy_free_flag"] in ["Y", "N"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_menu_categories_aggregation(snowflake_tool):
|
||||
"""Test complex aggregation query on menu categories with detailed validations."""
|
||||
results = await snowflake_tool._run(
|
||||
query="""
|
||||
SELECT
|
||||
item_category,
|
||||
COUNT(*) as item_count,
|
||||
AVG(sale_price_usd) as avg_price,
|
||||
SUM(sale_price_usd - cost_of_goods_usd) as total_margin,
|
||||
COUNT(DISTINCT menu_type) as menu_type_count,
|
||||
MIN(sale_price_usd) as min_price,
|
||||
MAX(sale_price_usd) as max_price
|
||||
FROM menu
|
||||
GROUP BY item_category
|
||||
HAVING COUNT(*) > 1
|
||||
ORDER BY item_count DESC
|
||||
"""
|
||||
)
|
||||
|
||||
assert len(results) > 0
|
||||
for category in results:
|
||||
# Basic presence checks
|
||||
assert all(
|
||||
key in category
|
||||
for key in [
|
||||
"ITEM_CATEGORY",
|
||||
"ITEM_COUNT",
|
||||
"AVG_PRICE",
|
||||
"TOTAL_MARGIN",
|
||||
"MENU_TYPE_COUNT",
|
||||
"MIN_PRICE",
|
||||
"MAX_PRICE",
|
||||
]
|
||||
)
|
||||
|
||||
# Value validations
|
||||
assert category["ITEM_COUNT"] > 1 # Due to HAVING clause
|
||||
assert category["MIN_PRICE"] <= category["MAX_PRICE"]
|
||||
assert category["AVG_PRICE"] >= category["MIN_PRICE"]
|
||||
assert category["AVG_PRICE"] <= category["MAX_PRICE"]
|
||||
assert category["MENU_TYPE_COUNT"] >= 1
|
||||
assert isinstance(category["TOTAL_MARGIN"], (float, Decimal))
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("invalid_query,expected_error", INVALID_QUERIES)
|
||||
async def test_invalid_queries(snowflake_tool, invalid_query, expected_error):
|
||||
"""Test error handling for invalid queries."""
|
||||
with pytest.raises((DatabaseError, OperationalError)) as exc_info:
|
||||
await snowflake_tool._run(query=invalid_query)
|
||||
assert expected_error.lower() in str(exc_info.value).lower()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_queries(snowflake_tool):
|
||||
"""Test handling of concurrent queries."""
|
||||
queries = [
|
||||
"SELECT COUNT(*) FROM menu",
|
||||
"SELECT COUNT(DISTINCT menu_type) FROM menu",
|
||||
"SELECT COUNT(DISTINCT item_category) FROM menu",
|
||||
]
|
||||
|
||||
tasks = [snowflake_tool._run(query=query) for query in queries]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(result, list) for result in results)
|
||||
assert all(len(result) == 1 for result in results)
|
||||
assert all(isinstance(result[0], dict) for result in results)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_timeout(snowflake_tool):
|
||||
"""Test query timeout handling with a complex query."""
|
||||
with pytest.raises((DatabaseError, OperationalError)) as exc_info:
|
||||
await snowflake_tool._run(
|
||||
query="""
|
||||
WITH RECURSIVE numbers AS (
|
||||
SELECT 1 as n
|
||||
UNION ALL
|
||||
SELECT n + 1
|
||||
FROM numbers
|
||||
WHERE n < 1000000
|
||||
)
|
||||
SELECT COUNT(*) FROM numbers
|
||||
"""
|
||||
)
|
||||
assert (
|
||||
"timeout" in str(exc_info.value).lower()
|
||||
or "execution time" in str(exc_info.value).lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_behavior(snowflake_tool):
|
||||
"""Test query caching behavior and performance."""
|
||||
query = "SELECT * FROM menu LIMIT 5"
|
||||
|
||||
# First execution
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
results1 = await snowflake_tool._run(query=query)
|
||||
first_duration = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
# Second execution (should be cached)
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
results2 = await snowflake_tool._run(query=query)
|
||||
second_duration = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
# Verify results
|
||||
assert results1 == results2
|
||||
assert len(results1) == 5
|
||||
assert second_duration < first_duration
|
||||
|
||||
# Verify cache invalidation with different query
|
||||
different_query = "SELECT * FROM menu LIMIT 10"
|
||||
different_results = await snowflake_tool._run(query=different_query)
|
||||
assert len(different_results) == 10
|
||||
assert different_results != results1
|
||||
@@ -1,46 +0,0 @@
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
from crewai_tools.tools.spider_tool.spider_tool import SpiderTool
|
||||
|
||||
|
||||
def test_spider_tool():
|
||||
spider_tool = SpiderTool()
|
||||
|
||||
searcher = Agent(
|
||||
role="Web Research Expert",
|
||||
goal="Find related information from specific URL's",
|
||||
backstory="An expert web researcher that uses the web extremely well",
|
||||
tools=[spider_tool],
|
||||
verbose=True,
|
||||
cache=False,
|
||||
)
|
||||
|
||||
choose_between_scrape_crawl = Task(
|
||||
description="Scrape the page of spider.cloud and return a summary of how fast it is",
|
||||
expected_output="spider.cloud is a fast scraping and crawling tool",
|
||||
agent=searcher,
|
||||
)
|
||||
|
||||
return_metadata = Task(
|
||||
description="Scrape https://spider.cloud with a limit of 1 and enable metadata",
|
||||
expected_output="Metadata and 10 word summary of spider.cloud",
|
||||
agent=searcher,
|
||||
)
|
||||
|
||||
css_selector = Task(
|
||||
description="Scrape one page of spider.cloud with the `body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1` CSS selector",
|
||||
expected_output="The content of the element with the css selector body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1",
|
||||
agent=searcher,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[searcher],
|
||||
tasks=[choose_between_scrape_crawl, return_metadata, css_selector],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
crew.kickoff()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_spider_tool()
|
||||
41
tests/test_optional_dependencies.py
Normal file
41
tests/test_optional_dependencies.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_project():
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
project_dir = Path(temp_dir.name) / "test_project"
|
||||
project_dir.mkdir()
|
||||
|
||||
pyproject_content = f"""
|
||||
[project]
|
||||
name = "test-project"
|
||||
version = "0.1.0"
|
||||
description = "Test project"
|
||||
requires-python = ">=3.10"
|
||||
"""
|
||||
|
||||
(project_dir / "pyproject.toml").write_text(pyproject_content)
|
||||
run_command(["uv", "add", "--editable", f"file://{Path.cwd().absolute()}"], project_dir)
|
||||
run_command(["uv", "sync"], project_dir)
|
||||
yield project_dir
|
||||
|
||||
|
||||
def run_command(cmd, cwd):
|
||||
return subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
|
||||
|
||||
|
||||
def test_no_optional_dependencies_in_init(temp_project):
|
||||
"""
|
||||
Test that crewai-tools can be imported without optional dependencies.
|
||||
|
||||
The package defines optional dependencies in pyproject.toml, but the base
|
||||
package should be importable without any of these optional dependencies
|
||||
being installed.
|
||||
"""
|
||||
result = run_command(["uv", "run", "python", "-c", "import crewai_tools"], temp_project)
|
||||
assert result.returncode == 0, f"Import failed with error: {result.stderr}"
|
||||
@@ -1,4 +1,6 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
@@ -27,7 +29,11 @@ def initialize_tool_with(mock_driver):
|
||||
return tool
|
||||
|
||||
|
||||
def test_tool_initialization():
|
||||
@patch("selenium.webdriver.Chrome")
|
||||
def test_tool_initialization(mocked_chrome):
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
mocked_chrome.return_value = MagicMock()
|
||||
|
||||
tool = SeleniumScrapingTool()
|
||||
|
||||
assert tool.website_url is None
|
||||
@@ -35,6 +41,11 @@ def test_tool_initialization():
|
||||
assert tool.cookie is None
|
||||
assert tool.wait_time == 3
|
||||
assert tool.return_html is False
|
||||
|
||||
try:
|
||||
os.rmdir(temp_dir)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
@patch("selenium.webdriver.Chrome")
|
||||
|
||||
Reference in New Issue
Block a user