diff --git a/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py b/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py index 6c43480c0..c064b9b2d 100644 --- a/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py +++ b/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py @@ -8,10 +8,7 @@ from dotenv import load_dotenv from crewai.tools import BaseTool from pydantic import BaseModel, Field -import boto3 -from botocore.exceptions import ClientError -# Import custom exceptions from ..exceptions import BedrockAgentError, BedrockValidationError # Load environment variables from .env file @@ -92,6 +89,12 @@ class BedrockInvokeAgentTool(BaseTool): raise BedrockValidationError(f"Parameter validation failed: {str(e)}") def _run(self, query: str) -> str: + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError: + raise ImportError("`boto3` package not found, please run `uv add boto3`") + try: # Initialize the Bedrock Agent Runtime client bedrock_agent = boto3.client( diff --git a/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py b/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py index 55a15b621..15c74077c 100644 --- a/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py +++ b/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py @@ -5,10 +5,7 @@ from dotenv import load_dotenv from crewai.tools import BaseTool from pydantic import BaseModel, Field -import boto3 -from botocore.exceptions import ClientError -# Import custom exceptions from ..exceptions import BedrockKnowledgeBaseError, BedrockValidationError # Load environment variables from .env file @@ -179,6 +176,12 @@ class BedrockKBRetrieverTool(BaseTool): return result_object def _run(self, query: str) -> str: + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError: + raise ImportError("`boto3` package not found, please run `uv add boto3`") + try: # Initialize the Bedrock Agent Runtime client bedrock_agent_runtime = boto3.client( diff --git a/src/crewai_tools/aws/s3/reader_tool.py b/src/crewai_tools/aws/s3/reader_tool.py index 7cd734081..4b3b9a394 100644 --- a/src/crewai_tools/aws/s3/reader_tool.py +++ b/src/crewai_tools/aws/s3/reader_tool.py @@ -1,10 +1,8 @@ -from typing import Type +from typing import Any, Type import os from crewai.tools import BaseTool from pydantic import BaseModel, Field -import boto3 -from botocore.exceptions import ClientError class S3ReaderToolInput(BaseModel): @@ -19,6 +17,12 @@ class S3ReaderTool(BaseTool): args_schema: Type[BaseModel] = S3ReaderToolInput def _run(self, file_path: str) -> str: + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError: + raise ImportError("`boto3` package not found, please run `uv add boto3`") + try: bucket_name, object_key = self._parse_s3_path(file_path) diff --git a/src/crewai_tools/aws/s3/writer_tool.py b/src/crewai_tools/aws/s3/writer_tool.py index 0c4201e0f..f0aaddb28 100644 --- a/src/crewai_tools/aws/s3/writer_tool.py +++ b/src/crewai_tools/aws/s3/writer_tool.py @@ -1,22 +1,27 @@ -from typing import Type +from typing import Any, Type import os from crewai.tools import BaseTool from pydantic import BaseModel, Field -import boto3 -from botocore.exceptions import ClientError class S3WriterToolInput(BaseModel): """Input schema for S3WriterTool.""" file_path: str = Field(..., description="S3 file path (e.g., 's3://bucket-name/file-name')") content: str = Field(..., description="Content to write to the file") + class S3WriterTool(BaseTool): name: str = "S3 Writer Tool" description: str = "Writes content to a file in Amazon S3 given an S3 file path" args_schema: Type[BaseModel] = S3WriterToolInput def _run(self, file_path: str, content: str) -> str: + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError: + raise ImportError("`boto3` package not found, please run `uv add boto3`") + try: bucket_name, object_key = self._parse_s3_path(file_path) diff --git a/tests/it/tools/snowflake_search_tool_test.py b/tests/it/tools/snowflake_search_tool_test.py deleted file mode 100644 index 70dc07953..000000000 --- a/tests/it/tools/snowflake_search_tool_test.py +++ /dev/null @@ -1,219 +0,0 @@ -import asyncio -import json -from decimal import Decimal - -import pytest -from snowflake.connector.errors import DatabaseError, OperationalError - -from crewai_tools import SnowflakeConfig, SnowflakeSearchTool - -# Test Data -MENU_ITEMS = [ - (10001, "Ice Cream", "Freezing Point", "Lemonade", "Beverage", "Cold Option", 1, 4), - ( - 10002, - "Ice Cream", - "Freezing Point", - "Vanilla Ice Cream", - "Dessert", - "Ice Cream", - 2, - 6, - ), -] - -INVALID_QUERIES = [ - ("SELECT * FROM nonexistent_table", "relation 'nonexistent_table' does not exist"), - ("SELECT invalid_column FROM menu", "invalid identifier 'invalid_column'"), - ("INVALID SQL QUERY", "SQL compilation error"), -] - - -# Integration Test Fixtures -@pytest.fixture -def config(): - """Create a Snowflake configuration with test credentials.""" - return SnowflakeConfig( - account="lwyhjun-wx11931", - user="crewgitci", - password="crewaiT00ls_publicCIpass123", - warehouse="COMPUTE_WH", - database="tasty_bytes_sample_data", - snowflake_schema="raw_pos", - ) - - -@pytest.fixture -def snowflake_tool(config): - """Create a SnowflakeSearchTool instance.""" - return SnowflakeSearchTool(config=config) - - -# Integration Tests with Real Snowflake Connection -@pytest.mark.integration -@pytest.mark.asyncio -@pytest.mark.parametrize( - "menu_id,expected_type,brand,item_name,category,subcategory,cost,price", MENU_ITEMS -) -async def test_menu_items( - snowflake_tool, - menu_id, - expected_type, - brand, - item_name, - category, - subcategory, - cost, - price, -): - """Test menu items with parameterized data for multiple test cases.""" - results = await snowflake_tool._run( - query=f"SELECT * FROM menu WHERE menu_id = {menu_id}" - ) - assert len(results) == 1 - menu_item = results[0] - - # Validate all fields - assert menu_item["MENU_ID"] == menu_id - assert menu_item["MENU_TYPE"] == expected_type - assert menu_item["TRUCK_BRAND_NAME"] == brand - assert menu_item["MENU_ITEM_NAME"] == item_name - assert menu_item["ITEM_CATEGORY"] == category - assert menu_item["ITEM_SUBCATEGORY"] == subcategory - assert menu_item["COST_OF_GOODS_USD"] == cost - assert menu_item["SALE_PRICE_USD"] == price - - # Validate health metrics JSON structure - health_metrics = json.loads(menu_item["MENU_ITEM_HEALTH_METRICS_OBJ"]) - assert "menu_item_health_metrics" in health_metrics - metrics = health_metrics["menu_item_health_metrics"][0] - assert "ingredients" in metrics - assert isinstance(metrics["ingredients"], list) - assert all(isinstance(ingredient, str) for ingredient in metrics["ingredients"]) - assert metrics["is_dairy_free_flag"] in ["Y", "N"] - - -@pytest.mark.integration -@pytest.mark.asyncio -async def test_menu_categories_aggregation(snowflake_tool): - """Test complex aggregation query on menu categories with detailed validations.""" - results = await snowflake_tool._run( - query=""" - SELECT - item_category, - COUNT(*) as item_count, - AVG(sale_price_usd) as avg_price, - SUM(sale_price_usd - cost_of_goods_usd) as total_margin, - COUNT(DISTINCT menu_type) as menu_type_count, - MIN(sale_price_usd) as min_price, - MAX(sale_price_usd) as max_price - FROM menu - GROUP BY item_category - HAVING COUNT(*) > 1 - ORDER BY item_count DESC - """ - ) - - assert len(results) > 0 - for category in results: - # Basic presence checks - assert all( - key in category - for key in [ - "ITEM_CATEGORY", - "ITEM_COUNT", - "AVG_PRICE", - "TOTAL_MARGIN", - "MENU_TYPE_COUNT", - "MIN_PRICE", - "MAX_PRICE", - ] - ) - - # Value validations - assert category["ITEM_COUNT"] > 1 # Due to HAVING clause - assert category["MIN_PRICE"] <= category["MAX_PRICE"] - assert category["AVG_PRICE"] >= category["MIN_PRICE"] - assert category["AVG_PRICE"] <= category["MAX_PRICE"] - assert category["MENU_TYPE_COUNT"] >= 1 - assert isinstance(category["TOTAL_MARGIN"], (float, Decimal)) - - -@pytest.mark.integration -@pytest.mark.asyncio -@pytest.mark.parametrize("invalid_query,expected_error", INVALID_QUERIES) -async def test_invalid_queries(snowflake_tool, invalid_query, expected_error): - """Test error handling for invalid queries.""" - with pytest.raises((DatabaseError, OperationalError)) as exc_info: - await snowflake_tool._run(query=invalid_query) - assert expected_error.lower() in str(exc_info.value).lower() - - -@pytest.mark.integration -@pytest.mark.asyncio -async def test_concurrent_queries(snowflake_tool): - """Test handling of concurrent queries.""" - queries = [ - "SELECT COUNT(*) FROM menu", - "SELECT COUNT(DISTINCT menu_type) FROM menu", - "SELECT COUNT(DISTINCT item_category) FROM menu", - ] - - tasks = [snowflake_tool._run(query=query) for query in queries] - results = await asyncio.gather(*tasks) - - assert len(results) == 3 - assert all(isinstance(result, list) for result in results) - assert all(len(result) == 1 for result in results) - assert all(isinstance(result[0], dict) for result in results) - - -@pytest.mark.integration -@pytest.mark.asyncio -async def test_query_timeout(snowflake_tool): - """Test query timeout handling with a complex query.""" - with pytest.raises((DatabaseError, OperationalError)) as exc_info: - await snowflake_tool._run( - query=""" - WITH RECURSIVE numbers AS ( - SELECT 1 as n - UNION ALL - SELECT n + 1 - FROM numbers - WHERE n < 1000000 - ) - SELECT COUNT(*) FROM numbers - """ - ) - assert ( - "timeout" in str(exc_info.value).lower() - or "execution time" in str(exc_info.value).lower() - ) - - -@pytest.mark.integration -@pytest.mark.asyncio -async def test_caching_behavior(snowflake_tool): - """Test query caching behavior and performance.""" - query = "SELECT * FROM menu LIMIT 5" - - # First execution - start_time = asyncio.get_event_loop().time() - results1 = await snowflake_tool._run(query=query) - first_duration = asyncio.get_event_loop().time() - start_time - - # Second execution (should be cached) - start_time = asyncio.get_event_loop().time() - results2 = await snowflake_tool._run(query=query) - second_duration = asyncio.get_event_loop().time() - start_time - - # Verify results - assert results1 == results2 - assert len(results1) == 5 - assert second_duration < first_duration - - # Verify cache invalidation with different query - different_query = "SELECT * FROM menu LIMIT 10" - different_results = await snowflake_tool._run(query=different_query) - assert len(different_results) == 10 - assert different_results != results1 diff --git a/tests/spider_tool_test.py b/tests/spider_tool_test.py deleted file mode 100644 index 7f5613fe6..000000000 --- a/tests/spider_tool_test.py +++ /dev/null @@ -1,46 +0,0 @@ -from crewai import Agent, Crew, Task - -from crewai_tools.tools.spider_tool.spider_tool import SpiderTool - - -def test_spider_tool(): - spider_tool = SpiderTool() - - searcher = Agent( - role="Web Research Expert", - goal="Find related information from specific URL's", - backstory="An expert web researcher that uses the web extremely well", - tools=[spider_tool], - verbose=True, - cache=False, - ) - - choose_between_scrape_crawl = Task( - description="Scrape the page of spider.cloud and return a summary of how fast it is", - expected_output="spider.cloud is a fast scraping and crawling tool", - agent=searcher, - ) - - return_metadata = Task( - description="Scrape https://spider.cloud with a limit of 1 and enable metadata", - expected_output="Metadata and 10 word summary of spider.cloud", - agent=searcher, - ) - - css_selector = Task( - description="Scrape one page of spider.cloud with the `body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1` CSS selector", - expected_output="The content of the element with the css selector body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1", - agent=searcher, - ) - - crew = Crew( - agents=[searcher], - tasks=[choose_between_scrape_crawl, return_metadata, css_selector], - verbose=True, - ) - - crew.kickoff() - - -if __name__ == "__main__": - test_spider_tool() diff --git a/tests/test_optional_dependencies.py b/tests/test_optional_dependencies.py new file mode 100644 index 000000000..b2d691a61 --- /dev/null +++ b/tests/test_optional_dependencies.py @@ -0,0 +1,41 @@ +import subprocess +import tempfile +from pathlib import Path + +import pytest + + +@pytest.fixture +def temp_project(): + temp_dir = tempfile.TemporaryDirectory() + project_dir = Path(temp_dir.name) / "test_project" + project_dir.mkdir() + + pyproject_content = f""" + [project] + name = "test-project" + version = "0.1.0" + description = "Test project" + requires-python = ">=3.10" + """ + + (project_dir / "pyproject.toml").write_text(pyproject_content) + run_command(["uv", "add", "--editable", f"file://{Path.cwd().absolute()}"], project_dir) + run_command(["uv", "sync"], project_dir) + yield project_dir + + +def run_command(cmd, cwd): + return subprocess.run(cmd, cwd=cwd, capture_output=True, text=True) + + +def test_no_optional_dependencies_in_init(temp_project): + """ + Test that crewai-tools can be imported without optional dependencies. + + The package defines optional dependencies in pyproject.toml, but the base + package should be importable without any of these optional dependencies + being installed. + """ + result = run_command(["uv", "run", "python", "-c", "import crewai_tools"], temp_project) + assert result.returncode == 0, f"Import failed with error: {result.stderr}" \ No newline at end of file diff --git a/tests/tools/selenium_scraping_tool_test.py b/tests/tools/selenium_scraping_tool_test.py index 271047449..4e0b890b5 100644 --- a/tests/tools/selenium_scraping_tool_test.py +++ b/tests/tools/selenium_scraping_tool_test.py @@ -1,4 +1,6 @@ from unittest.mock import MagicMock, patch +import tempfile +import os from bs4 import BeautifulSoup @@ -27,7 +29,11 @@ def initialize_tool_with(mock_driver): return tool -def test_tool_initialization(): +@patch("selenium.webdriver.Chrome") +def test_tool_initialization(mocked_chrome): + temp_dir = tempfile.mkdtemp() + mocked_chrome.return_value = MagicMock() + tool = SeleniumScrapingTool() assert tool.website_url is None @@ -35,6 +41,11 @@ def test_tool_initialization(): assert tool.cookie is None assert tool.wait_time == 3 assert tool.return_html is False + + try: + os.rmdir(temp_dir) + except: + pass @patch("selenium.webdriver.Chrome")