Release/v1.0.0 (#3618)

* feat: add `apps` & `actions` attributes to Agent (#3504)

* feat: add app attributes to Agent

* feat: add actions attribute to Agent

* chore: resolve linter issues

* refactor: merge the apps and actions parameters into a single one

* fix: remove unnecessary print

* feat: logging error when CrewaiPlatformTools fails

* chore: export CrewaiPlatformTools directly from crewai_tools

* style: resolve linter issues

* test: fix broken tests

* style: solve linter issues

* fix: fix broken test

* feat: monorepo restructure and test/ci updates

- Add crewai workspace member
- Fix vcr cassette paths and restore test dirs
- Resolve ci failures and update linter/pytest rules

* chore: update python version to 3.13 and package metadata

* feat: add crewai-tools workspace and fix tests/dependencies

* feat: add crewai-tools workspace structure

* Squashed 'temp-crewai-tools/' content from commit 9bae5633

git-subtree-dir: temp-crewai-tools
git-subtree-split: 9bae56339096cb70f03873e600192bd2cd207ac9

* feat: configure crewai-tools workspace package with dependencies

* fix: apply ruff auto-formatting to crewai-tools code

* chore: update lockfile

* fix: don't allow tool tests yet

* fix: comment out extra pytest flags for now

* fix: remove conflicting conftest.py from crewai-tools tests

* fix: resolve dependency conflicts and test issues

- Pin vcrpy to 7.0.0 to fix pytest-recording compatibility
- Comment out types-requests to resolve urllib3 conflict
- Update requests requirement in crewai-tools to >=2.32.0

* chore: update CI workflows and docs for monorepo structure

* chore: update CI workflows and docs for monorepo structure

* fix: actions syntax

* chore: ci publish and pin versions

* fix: add permission to action

* chore: bump version to 1.0.0a1 across all packages

- Updated version to 1.0.0a1 in pyproject.toml for crewai and crewai-tools
- Adjusted version in __init__.py files for consistency

* WIP: v1 docs (#3626)

(cherry picked from commit d46e20fa09bcd2f5916282f5553ddeb7183bd92c)

* docs: parity for all translations

* docs: full name of acronym AMP

* docs: fix lingering unused code

* docs: expand contextual options in docs.json

* docs: add contextual action to request feature on GitHub (#3635)

* chore: apply linting fixes to crewai-tools

* feat: add required env var validation for brightdata

Co-authored-by: Greyson Lalonde <greyson.r.lalonde@gmail.com>

* fix: handle properly anyOf oneOf allOf schema's props

Co-authored-by: Greyson Lalonde <greyson.r.lalonde@gmail.com>

* feat: bump version to 1.0.0a2

* Lorenze/native inference sdks (#3619)

* ruff linted

* using native sdks with litellm fallback

* drop exa

* drop print on completion

* Refactor LLM and utility functions for type consistency

- Updated `max_tokens` parameter in `LLM` class to accept `float` in addition to `int`.
- Modified `create_llm` function to ensure consistent type hints and return types, now returning `LLM | BaseLLM | None`.
- Adjusted type hints for various parameters in `create_llm` and `_llm_via_environment_or_fallback` functions for improved clarity and type safety.
- Enhanced test cases to reflect changes in type handling and ensure proper instantiation of LLM instances.

* fix agent_tests

* fix litellm tests and usagemetrics fix

* drop print

* Refactor LLM event handling and improve test coverage

- Removed commented-out event emission for LLM call failures in `llm.py`.
- Added `from_agent` parameter to `CrewAgentExecutor` for better context in LLM responses.
- Enhanced test for LLM call failure to simulate OpenAI API failure and updated assertions for clarity.
- Updated agent and task ID assertions in tests to ensure they are consistently treated as strings.

* fix test_converter

* fixed tests/agents/test_agent.py

* Refactor LLM context length exception handling and improve provider integration

- Renamed `LLMContextLengthExceededException` to `LLMContextLengthExceededExceptionError` for clarity and consistency.
- Updated LLM class to pass the provider parameter correctly during initialization.
- Enhanced error handling in various LLM provider implementations to raise the new exception type.
- Adjusted tests to reflect the updated exception name and ensure proper error handling in context length scenarios.

* Enhance LLM context window handling across providers

- Introduced CONTEXT_WINDOW_USAGE_RATIO to adjust context window sizes dynamically for Anthropic, Azure, Gemini, and OpenAI LLMs.
- Added validation for context window sizes in Azure and Gemini providers to ensure they fall within acceptable limits.
- Updated context window size calculations to use the new ratio, improving consistency and adaptability across different models.
- Removed hardcoded context window sizes in favor of ratio-based calculations for better flexibility.
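The ratio-based sizing described above can be sketched as follows; the 0.85 ratio, the size table, and the helper name are illustrative assumptions, not the exact library values:

```python
# Hedged sketch: derive a usable context window from a model's advertised
# maximum, leaving headroom for the response. Figures below are illustrative.
CONTEXT_WINDOW_USAGE_RATIO = 0.85

# Advertised maximum context sizes (illustrative, not authoritative).
MODEL_CONTEXT_WINDOWS = {
    "gpt-4o": 128_000,
    "claude-3-5-sonnet": 200_000,
    "gemini-1.5-pro": 1_000_000,
}

MIN_WINDOW, MAX_WINDOW = 1024, 2_000_000  # validation bounds, per the Azure/Gemini note


def usable_context_window(model: str) -> int:
    """Return the usable window: advertised size scaled by the usage ratio."""
    raw = MODEL_CONTEXT_WINDOWS.get(model, 8192)
    if not MIN_WINDOW <= raw <= MAX_WINDOW:
        raise ValueError(f"context window {raw} outside [{MIN_WINDOW}, {MAX_WINDOW}]")
    return int(raw * CONTEXT_WINDOW_USAGE_RATIO)
```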

* fix test agent again

* fix test agent

* feat: add native LLM providers for Anthropic, Azure, and Gemini

- Introduced new completion implementations for Anthropic, Azure, and Gemini, integrating their respective SDKs.
- Added utility functions for tool validation and extraction to support function calling across LLM providers.
- Enhanced context window management and token usage extraction for each provider.
- Created a common utility module for shared functionality among LLM providers.

* chore: update dependencies and improve context management

- Removed direct dependency on `litellm` from the main dependencies and added it under extras for better modularity.
- Updated the `litellm` dependency specification to allow for greater flexibility in versioning.
- Refactored context length exception handling across various LLM providers to use a consistent error class.
- Enhanced platform-specific dependency markers for NVIDIA packages to ensure compatibility across different systems.

* refactor(tests): update LLM instantiation to include is_litellm flag in test cases

- Modified multiple test cases in test_llm.py to set the is_litellm parameter to True when instantiating the LLM class.
- This change ensures that the tests are aligned with the latest LLM configuration requirements and improves consistency across test scenarios.
- Adjusted relevant assertions and comments to reflect the updated LLM behavior.

* linter

* linted

* revert constants

* fix(tests): correct type hint in expected model description

- Updated the expected description in the test_generate_model_description_dict_field function to use 'Dict' instead of 'dict' for consistency with type hinting conventions.
- This change ensures that the test accurately reflects the expected output format for model descriptions.

* refactor(llm): enhance LLM instantiation and error handling

- Updated the LLM class to include validation for the model parameter, ensuring it is a non-empty string.
- Improved error handling by logging warnings when the native SDK fails, allowing for a fallback to LiteLLM.
- Adjusted the instantiation of LLM in test cases to consistently include the is_litellm flag, aligning with recent changes in LLM configuration.
- Modified relevant tests to reflect these updates, ensuring better coverage and accuracy in testing scenarios.

* fixed test

* refactor(llm): enhance token usage tracking and add copy methods

- Updated the LLM class to track token usage and log callbacks in streaming mode, improving monitoring capabilities.
- Introduced shallow and deep copy methods for the LLM instance, allowing for better management of LLM configurations and parameters.
- Adjusted test cases to instantiate LLM with the is_litellm flag, ensuring alignment with recent changes in LLM configuration.
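The shallow/deep copy support mentioned above typically hooks Python's copy protocol; this is a minimal stand-in class, not the actual `LLM` implementation:

```python
import copy


class LLMConfig:
    """Minimal stand-in for the LLM class's copy support (illustrative only)."""

    def __init__(self, model: str, params: dict):
        self.model = model
        self.params = params

    def __copy__(self):
        # Shallow copy: a new wrapper that shares the params dict.
        return type(self)(self.model, self.params)

    def __deepcopy__(self, memo):
        # Deep copy: the params dict is recursively duplicated as well.
        return type(self)(self.model, copy.deepcopy(self.params, memo))
```

With this split, `copy.copy` gives a cheap clone whose parameter edits propagate, while `copy.deepcopy` yields a fully independent configuration.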

* refactor(tests): reorganize imports and enhance error messages in test cases

- Cleaned up import statements in test_crew.py for better organization and readability.
- Enhanced error messages in test cases to use `re.escape` for improved regex matching, ensuring more robust error handling.
- Adjusted comments for clarity and consistency across test scenarios.
- Ensured that all necessary modules are imported correctly to avoid potential runtime issues.

* feat: add base devtooling

* fix: ensure dep refs are updated for devtools

* fix: allow pre-release

* feat: allow release after tag

* feat: bump versions to 1.0.0a3 

Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>

* fix: match tag and release title, ignore devtools build for pypi

* fix: allow failed pypi publish

* feat: introduce trigger listing and execution commands for local development (#3643)

* chore: exclude tests from ruff linting

* chore: exclude tests from GitHub Actions linter

* fix: replace print statements with logger in agent and memory handling

* chore: add noqa for intentional print in printer utility

* fix: resolve linting errors across codebase

* feat: update docs with new approach to consume Platform Actions (#3675)

* fix: remove duplicate line and add explicit env var

* feat: bump versions to 1.0.0a4 (#3686)

* Update triggers docs (#3678)

* docs: introduce triggers list & triggers run command

* docs: add KO triggers docs

* docs: ensure CREWAI_PLATFORM_INTEGRATION_TOKEN is mentioned on docs (#3687)

* Lorenze/bedrock llm (#3693)

* feat: add AWS Bedrock support and update dependencies

- Introduced BedrockCompletion class for AWS Bedrock integration in LLM.
- Added boto3 as a new dependency in both pyproject.toml and uv.lock.
- Updated LLM class to support Bedrock provider.
- Created new files for Bedrock provider implementation.

* using converse api

* converse

* linted

* refactor: update BedrockCompletion class to improve parameter handling

- Changed max_tokens from a fixed integer to an optional integer.
- Simplified model ID assignment by removing the inference profile mapping method.
- Cleaned up comments and unnecessary code related to tool specifications and model-specific parameters.

* feat: improve event bus thread safety and async support

Add thread-safe, async-compatible event bus with read–write locking and
handler dependency ordering. Remove blinker dependency and implement
direct dispatch. Improve type safety, error handling, and deterministic
event synchronization.

Refactor tests to auto-wait for async handlers, ensure clean teardown,
and add comprehensive concurrency coverage. Replace thread-local state
in AgentEvaluator with instance-based locking for correct cross-thread
access. Enhance tracing reliability and event finalization.
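The read–write locking style described above can be sketched with a minimal counter-based lock (many concurrent readers, exclusive writers); this is an illustrative pattern, not the event bus's actual code:

```python
import threading


class RWLock:
    """Minimal read-write lock: many readers may hold it at once, writers
    get exclusive access. Illustrative sketch only."""

    def __init__(self):
        self._readers = 0
        self._count_lock = threading.Lock()  # guards the reader count
        self._writer_lock = threading.Lock()  # held while readers or a writer are active

    def acquire_read(self):
        with self._count_lock:
            self._readers += 1
            if self._readers == 1:
                self._writer_lock.acquire()  # first reader blocks writers

    def release_read(self):
        with self._count_lock:
            self._readers -= 1
            if self._readers == 0:
                self._writer_lock.release()  # last reader unblocks writers

    def acquire_write(self):
        self._writer_lock.acquire()

    def release_write(self):
        self._writer_lock.release()
```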

* feat: enhance OpenAICompletion class with additional client parameters (#3701)

* feat: enhance OpenAICompletion class with additional client parameters

- Added support for default_headers, default_query, and client_params in the OpenAICompletion class.
- Refactored client initialization to use a dedicated method for client parameter retrieval.
- Introduced new test cases to validate the correct usage of OpenAICompletion with various parameters.

* fix: correct test case for unsupported OpenAI model

- Updated the test_openai.py to ensure that the LLM instance is created before calling the method, maintaining proper error handling for unsupported models.
- This change ensures that the test accurately checks for the NotFoundError when an invalid model is specified.

* fix: enhance error handling in OpenAICompletion class

- Added specific exception handling for NotFoundError and APIConnectionError in the OpenAICompletion class to provide clearer error messages and improve logging.
- Updated the test case for unsupported models to ensure it raises a ValueError with the appropriate message when a non-existent model is specified.
- This change improves the robustness of the OpenAI API integration and enhances the clarity of error reporting.

* fix: improve test for unsupported OpenAI model handling

- Refactored the test case in test_openai.py to create the LLM instance after mocking the OpenAI client, ensuring proper error handling for unsupported models.
- This change enhances the clarity of the test by accurately checking for ValueError when a non-existent model is specified, aligning with recent improvements in error handling for the OpenAICompletion class.

* feat: bump versions to 1.0.0b1 (#3706)

* Lorenze/tools drop litellm (#3710)

* completely drop litellm and correctly pass config for qdrant

* feat: add support for additional embedding models in EmbeddingService

- Expanded the list of supported embedding models to include Google Vertex, Hugging Face, Jina, Ollama, OpenAI, Roboflow, Watson X, custom embeddings, Sentence Transformers, Text2Vec, OpenClip, and Instructor.
- This enhancement improves the versatility of the EmbeddingService by allowing integration with a wider range of embedding providers.

* fix: update collection parameter handling in CrewAIRagAdapter

- Changed the condition for setting vectors_config in the CrewAIRagAdapter to check for QdrantConfig instance instead of using hasattr. This improves type safety and ensures proper configuration handling for Qdrant integration.

* moved stagehand as optional dep (#3712)

* feat: bump versions to 1.0.0b2 (#3713)

* feat: enhance AnthropicCompletion class with additional client parame… (#3707)

* feat: enhance AnthropicCompletion class with additional client parameters and tool handling

- Added support for client_params in the AnthropicCompletion class to allow for additional client configuration.
- Refactored client initialization to use a dedicated method for retrieving client parameters.
- Implemented a new method to handle tool use conversation flow, ensuring proper execution and response handling.
- Introduced comprehensive test cases to validate the functionality of the AnthropicCompletion class, including tool use scenarios and parameter handling.

* drop print statements

* test: add fixture to mock ANTHROPIC_API_KEY for tests

- Introduced a pytest fixture to automatically mock the ANTHROPIC_API_KEY environment variable for all tests in the test_anthropic.py module.
- This change ensures that tests can run without requiring a real API key, improving test isolation and reliability.

* refactor: streamline streaming message handling in AnthropicCompletion class

- Removed the 'stream' parameter from the API call as it is set internally by the SDK.
- Simplified the handling of tool use events and response construction by extracting token usage from the final message.
- Enhanced the flow for managing tool use conversation, ensuring proper integration with the streaming API response.

* fix streaming here too

* fix: improve error handling in tool conversion for AnthropicCompletion class

- Enhanced exception handling during tool conversion by catching KeyError and ValueError.
- Added logging for conversion errors to aid in debugging and maintain robustness in tool integration.

* feat: enhance GeminiCompletion class with client parameter support (#3717)

* feat: enhance GeminiCompletion class with client parameter support

- Added support for client_params in the GeminiCompletion class to allow for additional client configuration.
- Refactored client initialization into a dedicated method for improved parameter handling.
- Introduced a new method to retrieve client parameters, ensuring compatibility with the base class.
- Enhanced error handling during client initialization to provide clearer messages for missing configuration.
- Updated documentation to reflect the changes in client parameter usage.

* add optional dependencies

* refactor: update test fixture to mock GOOGLE_API_KEY

- Renamed the fixture from `mock_anthropic_api_key` to `mock_google_api_key` to reflect the change in the environment variable being mocked.
- This update ensures that all tests in the module can run with a mocked GOOGLE_API_KEY, improving test isolation and reliability.

* fix tests

* feat: enhance BedrockCompletion class with advanced features

* feat: enhance BedrockCompletion class with advanced features and error handling

- Added support for guardrail configuration, additional model request fields, and custom response field paths in the BedrockCompletion class.
- Improved error handling for AWS exceptions and added token usage tracking with stop reason logging.
- Enhanced streaming response handling with comprehensive event management, including tool use and content block processing.
- Updated documentation to reflect new features and initialization parameters.
- Introduced a new test suite for BedrockCompletion to validate functionality and ensure robust integration with AWS Bedrock APIs.

* chore: add boto typing

* fix: use typing_extensions.Required for Python 3.10 compatibility

---------

Co-authored-by: Greyson Lalonde <greyson.r.lalonde@gmail.com>

* feat: azure native tests

* feat: add Azure AI Inference support and related tests

- Introduced the `azure-ai-inference` package with version `1.0.0b9` and its dependencies in `uv.lock` and `pyproject.toml`.
- Added new test files for Azure LLM functionality, including tests for Azure completion and tool handling.
- Implemented comprehensive test cases to validate Azure-specific behavior and integration with the CrewAI framework.
- Enhanced the testing framework to mock Azure credentials and ensure proper isolation during tests.

* feat: enhance AzureCompletion class with Azure OpenAI support

- Added support for the Azure OpenAI endpoint in the AzureCompletion class, allowing for flexible endpoint configurations.
- Implemented endpoint validation and correction to ensure proper URL formats for Azure OpenAI deployments.
- Enhanced error handling to provide clearer messages for common HTTP errors, including authentication and rate limit issues.
- Updated tests to validate the new endpoint handling and error messaging, ensuring robust integration with Azure AI Inference.
- Refactored parameter preparation to conditionally include the model parameter based on the endpoint type.
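Endpoint correction of the kind described above might be sketched like this; the deployment path shape is an assumption about Azure OpenAI URLs, not the library's exact logic:

```python
def normalize_azure_endpoint(endpoint: str, deployment: str) -> str:
    """Accept either a bare resource endpoint or a full deployment-scoped URL
    and return the deployment-scoped form (hypothetical sketch)."""
    endpoint = endpoint.rstrip("/")
    if "/openai/deployments/" in endpoint:
        return endpoint  # already deployment-scoped, keep as-is
    return f"{endpoint}/openai/deployments/{deployment}"
```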

* refactor: convert project module to metaclass with full typing

* Lorenze/OpenAI base url backwards support (#3723)

* fix: enhance OpenAICompletion class base URL handling

- Updated the base URL assignment in the OpenAICompletion class to prioritize the new `api_base` attribute and fallback to the environment variable `OPENAI_BASE_URL` if both are not set.
- Added `api_base` to the list of parameters in the OpenAICompletion class to ensure proper configuration and flexibility in API endpoint management.

* feat: enhance OpenAICompletion class with api_base support

- Added the `api_base` parameter to the OpenAICompletion class to allow for flexible API endpoint configuration.
- Updated the `_get_client_params` method to prioritize `base_url` over `api_base`, ensuring correct URL handling.
- Introduced comprehensive tests to validate the behavior of `api_base` and `base_url` in various scenarios, including environment variable fallback.
- Enhanced test coverage for client parameter retrieval, ensuring robust integration with the OpenAI API.

* fix: improve OpenAICompletion class configuration handling

- Added a debug print statement to log the client configuration parameters during initialization for better traceability.
- Updated the base URL assignment logic to ensure it defaults to None if no valid base URL is provided, enhancing robustness in API endpoint configuration.
- Refined the retrieval of the `api_base` environment variable to streamline the configuration process.

* drop print

* feat: improvements on import native sdk support (#3725)

* feat: add support for Anthropic provider and enhance logging

- Introduced the `anthropic` package with version `0.69.0` in `pyproject.toml` and `uv.lock`, allowing for integration with the Anthropic API.
- Updated logging in the LLM class to provide clearer error messages when importing native providers, enhancing debugging capabilities.
- Improved error handling in the AnthropicCompletion class to guide users on installation via the updated error message format.
- Refactored import error handling in other provider classes to maintain consistency in error messaging and installation instructions.

* feat: enhance LLM support with Bedrock provider and update dependencies

- Added support for the `bedrock` provider in the LLM class, allowing integration with AWS Bedrock APIs.
- Updated `uv.lock` to replace `boto3` with `bedrock` in the dependencies, reflecting the new provider structure.
- Introduced `SUPPORTED_NATIVE_PROVIDERS` to include `bedrock` and ensure proper error handling when instantiating native providers.
- Enhanced error handling in the LLM class to raise informative errors when native provider instantiation fails.
- Added tests to validate the behavior of the new Bedrock provider and ensure fallback mechanisms work correctly for unsupported providers.
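The dispatch-with-fallback behavior above can be sketched as follows; the provider set, module name, and error message are assumptions drawn from the bullet points, and the missing-module import stands in for an optional SDK:

```python
SUPPORTED_NATIVE_PROVIDERS = {"openai", "anthropic", "gemini", "azure", "bedrock"}


def instantiate_provider(provider: str, use_native: bool = True) -> str:
    """Try the native SDK path for supported providers; raise an informative
    ImportError when the optional dependency is missing, and route everything
    else through LiteLLM. Illustrative sketch only."""
    if not use_native or provider not in SUPPORTED_NATIVE_PROVIDERS:
        return f"litellm:{provider}"
    try:
        # Stand-in for importing the provider's optional SDK (hypothetical name).
        __import__(f"crewai_native_{provider}_sdk_demo")
    except ImportError as e:
        raise ImportError(
            f"Native {provider} support requires an optional dependency; "
            f"install it with: pip install 'crewai[{provider}]'"
        ) from e
    return f"native:{provider}"
```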

* test: update native provider fallback tests to expect ImportError

* adjust the test with the expected behavior - raising ImportError

* this is expecting the litellm format; all gemini native tests are in test_google.py

---------

Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>

* fix: remove stdout prints, improve test determinism, and update trace handling

Removed `print` statements from the `LLMStreamChunkEvent` handler to prevent
LLM response chunks from being written directly to stdout. The listener now
only tracks chunks internally.

Fixes #3715

Added explicit return statements for trace-related tests.

Updated cassette for `test_failed_evaluation` to reflect new behavior where
an empty trace dict is used instead of returning early.

Ensured deterministic cleanup order in test fixtures by making
`clear_event_bus_handlers` depend on `setup_test_environment`. This guarantees
event bus shutdown and file handle cleanup occur before temporary directory
deletion, resolving intermittent “Directory not empty” errors in CI.

* chore: remove lib/crewai exclusion from pre-commit hooks

* feat: enhance task guardrail functionality and validation

* feat: enhance task guardrail functionality and validation

- Introduced support for multiple guardrails in the Task class, allowing for sequential processing of guardrails.
- Added a new `guardrails` field to the Task model to accept a list of callable guardrails or string descriptions.
- Implemented validation to ensure guardrails are processed correctly, including handling of retries and error messages.
- Enhanced the `_invoke_guardrail_function` method to manage guardrail execution and integrate with existing task output processing.
- Updated tests to cover various scenarios involving multiple guardrails, including success, failure, and retry mechanisms.

This update improves the flexibility and robustness of task execution by allowing for more complex validation scenarios.
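The sequential processing described above can be sketched as a simple pipeline; the `(ok, value_or_error)` return shape is an assumed convention, not the library's confirmed signature:

```python
from typing import Callable

# Each guardrail takes the current output and returns (ok, value):
# on success, value is the validated/transformed output; on failure,
# value carries the error message. Assumed shape for illustration.
Guardrail = Callable[[str], tuple[bool, str]]


def run_guardrails(output: str, guardrails: list[Guardrail]) -> tuple[bool, str]:
    """Run guardrails in order, feeding each the previous one's output,
    and stop at the first failure (hypothetical sketch)."""
    for guardrail in guardrails:
        ok, result = guardrail(output)
        if not ok:
            return False, result
        output = result
    return True, output
```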

* refactor: enhance guardrail type handling in Task model

- Updated the Task class to improve guardrail type definitions, introducing GuardrailType and GuardrailsType for better clarity and type safety.
- Simplified the validation logic for guardrails, ensuring that both single and multiple guardrails are processed correctly.
- Enhanced error messages for guardrail validation to provide clearer feedback when incorrect types are provided.
- This refactor improves the maintainability and robustness of task execution by standardizing guardrail handling.

* feat: implement per-guardrail retry tracking in Task model

- Introduced a new private attribute `_guardrail_retry_counts` to the Task class for tracking retry attempts on a per-guardrail basis.
- Updated the guardrail processing logic to utilize the new retry tracking, allowing for independent retry counts for each guardrail.
- Enhanced error handling to provide clearer feedback when guardrails fail validation after exceeding retry limits.
- Modified existing tests to validate the new retry tracking behavior, ensuring accurate assertions on guardrail retries.

This update improves the robustness and flexibility of task execution by allowing for more granular control over guardrail validation and retry mechanisms.
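Per-guardrail retry bookkeeping of the kind described can be a small counter map keyed by guardrail; the class and method names here are illustrative, mirroring but not reproducing the private `_guardrail_retry_counts` attribute:

```python
from collections import defaultdict


class GuardrailRetryTracker:
    """Track retry attempts independently per guardrail (illustrative sketch)."""

    def __init__(self, max_retries: int = 3):
        self.max_retries = max_retries
        self._retry_counts: dict[int, int] = defaultdict(int)

    def should_retry(self, guardrail_index: int) -> bool:
        """Record one failed attempt; True while this guardrail has retries left."""
        self._retry_counts[guardrail_index] += 1
        return self._retry_counts[guardrail_index] <= self.max_retries
```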

* chore: 1.0.0b3 bump (#3734)

* chore: full ruff and mypy

- Improved linting, pre-commit setup, and internal architecture: configured Ruff to respect .gitignore, added stricter rules, and introduced a lock pre-commit hook with virtualenv activation.
- Fixed type shadowing in EXASearchTool using a `type_` alias to avoid PEP 563 conflicts, and resolved circular imports in the agent executor and guardrail modules.
- Removed agent-ops attributes and the deprecated watson alias, and dropped crewai-enterprise tools with corresponding test updates.
- Refactored cache and memoization for thread safety, and cleaned up structured output adapters and related logic.

* New MCL DSL (#3738)

* Adding MCP implementation

* New tests for MCP implementation

* fix tests

* update docs

* Revert "New tests for MCP implementation"

This reverts commit 0bbe6dee90.

* linter

* linter

* fix

* verify mcp package exists

* adjust docs to be clear only remote servers are supported

* reverted

* ensure args schema generated properly

* properly close out

---------

Co-authored-by: lorenzejay <lorenzejaytech@gmail.com>
Co-authored-by: Greyson Lalonde <greyson.r.lalonde@gmail.com>

* feat: a2a experimental

experimental a2a support

---------

Co-authored-by: Lucas Gomide <lucaslg200@gmail.com>
Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
Co-authored-by: Tony Kipkemboi <iamtonykipkemboi@gmail.com>
Co-authored-by: Mike Plachta <mplachta@users.noreply.github.com>
Co-authored-by: João Moura <joaomdmoura@gmail.com>
Authored by Lorenze Jay, 2025-10-20 14:10:19 -07:00; committed by GitHub.
parent 42f2b4d551, commit d1343b96ed
1339 changed files with 111657 additions and 19564 deletions


@@ -0,0 +1,130 @@
from pathlib import Path
from unittest.mock import MagicMock, patch
import urllib.error
import xml.etree.ElementTree as ET

import pytest

from crewai_tools import ArxivPaperTool


@pytest.fixture
def tool():
    return ArxivPaperTool(download_pdfs=False)


def mock_arxiv_response():
    return """<?xml version="1.0" encoding="UTF-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
<entry>
<id>http://arxiv.org/abs/1234.5678</id>
<title>Sample Paper</title>
<summary>This is a summary of the sample paper.</summary>
<published>2022-01-01T00:00:00Z</published>
<author><name>John Doe</name></author>
<link title="pdf" href="http://arxiv.org/pdf/1234.5678.pdf"/>
</entry>
</feed>"""


@patch("urllib.request.urlopen")
def test_fetch_arxiv_data(mock_urlopen, tool):
    mock_response = MagicMock()
    mock_response.status = 200
    mock_response.read.return_value = mock_arxiv_response().encode("utf-8")
    mock_urlopen.return_value.__enter__.return_value = mock_response
    results = tool.fetch_arxiv_data("transformer", 1)
    assert isinstance(results, list)
    assert results[0]["title"] == "Sample Paper"


@patch("urllib.request.urlopen", side_effect=urllib.error.URLError("Timeout"))
def test_fetch_arxiv_data_network_error(mock_urlopen, tool):
    with pytest.raises(urllib.error.URLError):
        tool.fetch_arxiv_data("transformer", 1)


@patch("urllib.request.urlretrieve")
def test_download_pdf_success(mock_urlretrieve):
    tool = ArxivPaperTool()
    tool.download_pdf("http://arxiv.org/pdf/1234.5678.pdf", Path("test.pdf"))
    mock_urlretrieve.assert_called_once()


@patch("urllib.request.urlretrieve", side_effect=OSError("Permission denied"))
def test_download_pdf_oserror(mock_urlretrieve):
    tool = ArxivPaperTool()
    with pytest.raises(OSError):
        tool.download_pdf(
            "http://arxiv.org/pdf/1234.5678.pdf", Path("/restricted/test.pdf")
        )


@patch("urllib.request.urlopen")
@patch("urllib.request.urlretrieve")
def test_run_with_download(mock_urlretrieve, mock_urlopen):
    mock_response = MagicMock()
    mock_response.status = 200
    mock_response.read.return_value = mock_arxiv_response().encode("utf-8")
    mock_urlopen.return_value.__enter__.return_value = mock_response
    tool = ArxivPaperTool(download_pdfs=True)
    output = tool._run("transformer", 1)
    assert "Title: Sample Paper" in output
    mock_urlretrieve.assert_called_once()


@patch("urllib.request.urlopen")
def test_run_no_download(mock_urlopen):
    mock_response = MagicMock()
    mock_response.status = 200
    mock_response.read.return_value = mock_arxiv_response().encode("utf-8")
    mock_urlopen.return_value.__enter__.return_value = mock_response
    tool = ArxivPaperTool(download_pdfs=False)
    result = tool._run("transformer", 1)
    assert "Title: Sample Paper" in result


@patch("pathlib.Path.mkdir")
def test_validate_save_path_creates_directory(mock_mkdir):
    path = ArxivPaperTool._validate_save_path("new_folder")
    mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
    assert isinstance(path, Path)


@patch("urllib.request.urlopen")
def test_run_handles_exception(mock_urlopen):
    mock_urlopen.side_effect = Exception("API failure")
    tool = ArxivPaperTool()
    result = tool._run("transformer", 1)
    assert "Failed to fetch or download Arxiv papers" in result


@patch("urllib.request.urlopen")
def test_invalid_xml_response(mock_urlopen, tool):
    mock_response = MagicMock()
    mock_response.read.return_value = b"<invalid><xml>"
    mock_response.status = 200
    mock_urlopen.return_value.__enter__.return_value = mock_response
    with pytest.raises(ET.ParseError):
        tool.fetch_arxiv_data("quantum", 1)


@patch.object(ArxivPaperTool, "fetch_arxiv_data")
def test_run_with_max_results(mock_fetch, tool):
    mock_fetch.return_value = [
        {
            "arxiv_id": f"test_{i}",
            "title": f"Title {i}",
            "summary": "Summary",
            "authors": ["Author"],
            "published_date": "2023-01-01",
            "pdf_url": None,
        }
        for i in range(100)
    ]
    result = tool._run(search_query="test", max_results=100)
    assert result.count("Title:") == 100


@@ -0,0 +1,48 @@
from unittest.mock import patch

import pytest

from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool


@pytest.fixture
def brave_tool():
    return BraveSearchTool(n_results=2)


def test_brave_tool_initialization():
    tool = BraveSearchTool()
    assert tool.n_results == 10
    assert tool.save_file is False


@patch("requests.get")
def test_brave_tool_search(mock_get, brave_tool):
    mock_response = {
        "web": {
            "results": [
                {
                    "title": "Test Title",
                    "url": "http://test.com",
                    "description": "Test Description",
                }
            ]
        }
    }
    mock_get.return_value.json.return_value = mock_response
    result = brave_tool.run(search_query="test")
    assert "Test Title" in result
    assert "http://test.com" in result


def test_brave_tool():
    tool = BraveSearchTool(
        n_results=2,
    )
    tool.run(search_query="ChatGPT")


if __name__ == "__main__":
    test_brave_tool()
    test_brave_tool_initialization()
    # test_brave_tool_search(brave_tool)


@@ -0,0 +1,54 @@
import unittest
from unittest.mock import MagicMock, patch

from crewai_tools.tools.brightdata_tool.brightdata_serp import BrightDataSearchTool


class TestBrightDataSearchTool(unittest.TestCase):
    @patch.dict(
        "os.environ",
        {"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
    )
    def setUp(self):
        self.tool = BrightDataSearchTool()

    @patch("requests.post")
    def test_run_successful_search(self, mock_post):
        # Sample mock JSON response
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.text = "mock response text"
        mock_post.return_value = mock_response
        # Define search input
        input_data = {
            "query": "latest AI news",
            "search_engine": "google",
            "country": "us",
            "language": "en",
            "search_type": "nws",
            "device_type": "desktop",
            "parse_results": True,
            "save_file": False,
        }
        result = self.tool._run(**input_data)
        # Assertions
        self.assertIsInstance(result, str)  # The tool returns response.text (a string)
        mock_post.assert_called_once()

    @patch("requests.post")
    def test_run_with_request_exception(self, mock_post):
        mock_post.side_effect = Exception("Timeout")
        result = self.tool._run(query="AI", search_engine="google")
        self.assertIn("Error", result)

    def tearDown(self):
        # Clean up env vars
        pass


if __name__ == "__main__":
    unittest.main()


@@ -0,0 +1,61 @@
from unittest.mock import Mock, patch
from crewai_tools.tools.brightdata_tool.brightdata_unlocker import (
BrightDataWebUnlockerTool,
)
import requests
@patch.dict(
"os.environ",
{"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
)
@patch("crewai_tools.tools.brightdata_tool.brightdata_unlocker.requests.post")
def test_run_success_html(mock_post):
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "<html><body>Test</body></html>"
mock_response.raise_for_status = Mock()
mock_post.return_value = mock_response
tool = BrightDataWebUnlockerTool()
    result = tool._run(url="https://example.com", format="html", save_file=False)
    assert isinstance(result, str)
@patch.dict(
"os.environ",
{"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
)
@patch("crewai_tools.tools.brightdata_tool.brightdata_unlocker.requests.post")
def test_run_success_json(mock_post):
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "mock response text"
mock_response.raise_for_status = Mock()
mock_post.return_value = mock_response
tool = BrightDataWebUnlockerTool()
result = tool._run(url="https://example.com", format="json")
assert isinstance(result, str)
@patch.dict(
"os.environ",
{"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
)
@patch("crewai_tools.tools.brightdata_tool.brightdata_unlocker.requests.post")
def test_run_http_error(mock_post):
mock_response = Mock()
mock_response.status_code = 403
mock_response.text = "Forbidden"
mock_response.raise_for_status.side_effect = requests.HTTPError(
response=mock_response
)
mock_post.return_value = mock_response
tool = BrightDataWebUnlockerTool()
result = tool._run(url="https://example.com")
assert "HTTP Error" in result
assert "Forbidden" in result
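The HTTP-error test above works by making `raise_for_status` raise an exception that carries the mock response, which the tool is expected to fold into an error message. A dependency-free sketch of the pattern, with a stand-in exception class instead of `requests.HTTPError`:

```python
from unittest.mock import Mock

class FakeHTTPError(Exception):
    """Dependency-free stand-in for requests.HTTPError."""
    def __init__(self, response=None):
        super().__init__("HTTP error")
        self.response = response

mock_response = Mock()
mock_response.status_code = 403
mock_response.text = "Forbidden"
mock_response.raise_for_status.side_effect = FakeHTTPError(response=mock_response)

# The code under test catches the error and builds a message from the body:
try:
    mock_response.raise_for_status()
    result = mock_response.text
except FakeHTTPError as exc:
    result = f"HTTP Error {exc.response.status_code}: {exc.response.text}"

print(result)
```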

File diff suppressed because one or more lines are too long (5 files)


@@ -0,0 +1,450 @@
from unittest.mock import MagicMock, patch
import pytest
# Mock the couchbase library before importing the tool
# This prevents ImportErrors if couchbase isn't installed in the test environment
mock_couchbase = MagicMock()
mock_couchbase.search = MagicMock()
mock_couchbase.cluster = MagicMock()
mock_couchbase.options = MagicMock()
mock_couchbase.vector_search = MagicMock()
# Simulate the structure needed for checks
mock_couchbase.cluster.Cluster = MagicMock()
mock_couchbase.options.SearchOptions = MagicMock()
mock_couchbase.vector_search.VectorQuery = MagicMock()
mock_couchbase.vector_search.VectorSearch = MagicMock()
mock_couchbase.search.SearchRequest = MagicMock() # Mock the class itself
mock_couchbase.search.SearchRequest.create = MagicMock() # Mock the class method
# Add necessary exception types if needed for testing error handling
class MockCouchbaseException(Exception):
pass
mock_couchbase.exceptions = MagicMock()
mock_couchbase.exceptions.BucketNotFoundException = MockCouchbaseException
mock_couchbase.exceptions.ScopeNotFoundException = MockCouchbaseException
mock_couchbase.exceptions.CollectionNotFoundException = MockCouchbaseException
mock_couchbase.exceptions.IndexNotFoundException = MockCouchbaseException
import sys
sys.modules["couchbase"] = mock_couchbase
sys.modules["couchbase.search"] = mock_couchbase.search
sys.modules["couchbase.cluster"] = mock_couchbase.cluster
sys.modules["couchbase.options"] = mock_couchbase.options
sys.modules["couchbase.vector_search"] = mock_couchbase.vector_search
sys.modules["couchbase.exceptions"] = mock_couchbase.exceptions
# Now import the tool
from crewai_tools.tools.couchbase_tool.couchbase_tool import (
CouchbaseFTSVectorSearchTool,
)
# --- Test Fixtures ---
@pytest.fixture(autouse=True)
def reset_global_mocks():
"""Reset call counts for globally defined mocks before each test."""
    # Reset the global mocks whose call counts would otherwise leak
    # between tests.
    mock_couchbase.vector_search.VectorQuery.reset_mock()
    mock_couchbase.vector_search.VectorSearch.from_vector_query.reset_mock()
    mock_couchbase.search.SearchRequest.create.reset_mock()
# Additional fixture to handle import pollution in full test suite
@pytest.fixture(autouse=True)
def ensure_couchbase_mocks():
"""Ensure that couchbase imports are properly mocked even when other tests have run first."""
# This fixture ensures our mocks are in place regardless of import order
original_modules = {}
# Store any existing modules
for module_name in [
"couchbase",
"couchbase.search",
"couchbase.cluster",
"couchbase.options",
"couchbase.vector_search",
"couchbase.exceptions",
]:
if module_name in sys.modules:
original_modules[module_name] = sys.modules[module_name]
# Ensure our mocks are active
sys.modules["couchbase"] = mock_couchbase
sys.modules["couchbase.search"] = mock_couchbase.search
sys.modules["couchbase.cluster"] = mock_couchbase.cluster
sys.modules["couchbase.options"] = mock_couchbase.options
sys.modules["couchbase.vector_search"] = mock_couchbase.vector_search
sys.modules["couchbase.exceptions"] = mock_couchbase.exceptions
yield
# Restore original modules if they existed
for module_name, original_module in original_modules.items():
if original_module is not None:
sys.modules[module_name] = original_module
@pytest.fixture
def mock_cluster():
cluster = MagicMock()
bucket_manager = MagicMock()
search_index_manager = MagicMock()
bucket = MagicMock()
scope = MagicMock()
collection = MagicMock()
scope_search_index_manager = MagicMock()
# Setup mock return values for checks
cluster.buckets.return_value = bucket_manager
cluster.search_indexes.return_value = search_index_manager
cluster.bucket.return_value = bucket
bucket.scope.return_value = scope
scope.collection.return_value = collection
scope.search_indexes.return_value = scope_search_index_manager
# Mock bucket existence check
bucket_manager.get_bucket.return_value = True
# Mock scope/collection existence check
mock_scope_spec = MagicMock()
mock_scope_spec.name = "test_scope"
mock_collection_spec = MagicMock()
mock_collection_spec.name = "test_collection"
mock_scope_spec.collections = [mock_collection_spec]
bucket.collections.return_value.get_all_scopes.return_value = [mock_scope_spec]
# Mock index existence check
mock_index_def = MagicMock()
mock_index_def.name = "test_index"
scope_search_index_manager.get_all_indexes.return_value = [mock_index_def]
search_index_manager.get_all_indexes.return_value = [mock_index_def]
return cluster
@pytest.fixture
def mock_embedding_function():
    # Mock embedding function returning a fixed 10-dimensional vector
return MagicMock(return_value=[0.1] * 10)
@pytest.fixture
def tool_config(mock_cluster, mock_embedding_function):
return {
"cluster": mock_cluster,
"bucket_name": "test_bucket",
"scope_name": "test_scope",
"collection_name": "test_collection",
"index_name": "test_index",
"embedding_function": mock_embedding_function,
"limit": 5,
"embedding_key": "test_embedding",
"scoped_index": True,
}
@pytest.fixture
def couchbase_tool(tool_config):
# Patch COUCHBASE_AVAILABLE to True for these tests
with patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE", True
):
tool = CouchbaseFTSVectorSearchTool(**tool_config)
return tool
@pytest.fixture
def mock_search_iter():
mock_iter = MagicMock()
# Simulate search results with a 'fields' attribute
mock_row1 = MagicMock()
mock_row1.fields = {"id": "doc1", "text": "content 1", "test_embedding": [0.1] * 10}
mock_row2 = MagicMock()
mock_row2.fields = {"id": "doc2", "text": "content 2", "test_embedding": [0.2] * 10}
mock_iter.rows.return_value = [mock_row1, mock_row2]
return mock_iter
# --- Test Cases ---
def test_initialization_success(couchbase_tool, tool_config):
"""Test successful initialization with valid config."""
assert couchbase_tool.cluster == tool_config["cluster"]
assert couchbase_tool.bucket_name == "test_bucket"
assert couchbase_tool.scope_name == "test_scope"
assert couchbase_tool.collection_name == "test_collection"
assert couchbase_tool.index_name == "test_index"
assert couchbase_tool.embedding_function is not None
assert couchbase_tool.limit == 5
assert couchbase_tool.embedding_key == "test_embedding"
assert couchbase_tool.scoped_index
# Check if helper methods were called during init (via mocks in fixture)
couchbase_tool.cluster.buckets().get_bucket.assert_called_once_with("test_bucket")
couchbase_tool.cluster.bucket().collections().get_all_scopes.assert_called_once()
couchbase_tool.cluster.bucket().scope().search_indexes().get_all_indexes.assert_called_once()
def test_initialization_missing_required_args(mock_cluster, mock_embedding_function):
"""Test initialization fails when required arguments are missing."""
base_config = {
"cluster": mock_cluster,
"bucket_name": "b",
"scope_name": "s",
"collection_name": "c",
"index_name": "i",
"embedding_function": mock_embedding_function,
}
required_keys = base_config.keys()
with patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE", True
):
for key in required_keys:
incomplete_config = base_config.copy()
del incomplete_config[key]
with pytest.raises(ValueError):
CouchbaseFTSVectorSearchTool(**incomplete_config)
def test_initialization_couchbase_unavailable():
"""Test behavior when couchbase library is not available."""
with patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE", False
):
with patch("click.confirm", return_value=False) as mock_confirm:
with pytest.raises(
ImportError, match="The 'couchbase' package is required"
):
CouchbaseFTSVectorSearchTool(
cluster=MagicMock(),
bucket_name="b",
scope_name="s",
collection_name="c",
index_name="i",
embedding_function=MagicMock(),
)
mock_confirm.assert_called_once() # Ensure user was prompted
def test_run_success_scoped_index(
couchbase_tool, mock_search_iter, tool_config, mock_embedding_function
):
"""Test successful _run execution with a scoped index."""
query = "find relevant documents"
# Mock the scope search method
couchbase_tool._scope.search = MagicMock(return_value=mock_search_iter)
# Mock the VectorQuery/VectorSearch/SearchRequest creation using runtime patching
with (
patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.VectorQuery"
) as mock_vq,
patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.VectorSearch"
) as mock_vs,
patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.search.SearchRequest"
) as mock_sr,
patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.SearchOptions"
) as mock_so,
):
# Set up the mock objects and their return values
mock_vector_query = MagicMock()
mock_vector_search = MagicMock()
mock_search_req = MagicMock()
mock_search_options = MagicMock()
mock_vq.return_value = mock_vector_query
mock_vs.from_vector_query.return_value = mock_vector_search
mock_sr.create.return_value = mock_search_req
mock_so.return_value = mock_search_options
result = couchbase_tool._run(query=query)
# Check embedding function call
tool_config["embedding_function"].assert_called_once_with(query)
# Check VectorQuery call
mock_vq.assert_called_once_with(
tool_config["embedding_key"],
mock_embedding_function.return_value,
tool_config["limit"],
)
# Check VectorSearch call
mock_vs.from_vector_query.assert_called_once_with(mock_vector_query)
# Check SearchRequest creation
mock_sr.create.assert_called_once_with(mock_vector_search)
# Check SearchOptions creation
mock_so.assert_called_once_with(limit=tool_config["limit"], fields=["*"])
# Check that scope search was called correctly
couchbase_tool._scope.search.assert_called_once_with(
tool_config["index_name"], mock_search_req, mock_search_options
)
# Check cluster search was NOT called
couchbase_tool.cluster.search.assert_not_called()
# Check result format (simple check for JSON structure)
assert '"id": "doc1"' in result
assert '"id": "doc2"' in result
        assert result.startswith("[")  # result should be a JSON array
def test_run_success_global_index(
tool_config, mock_search_iter, mock_embedding_function
):
"""Test successful _run execution with a global (non-scoped) index."""
tool_config["scoped_index"] = False
with patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE", True
):
couchbase_tool = CouchbaseFTSVectorSearchTool(**tool_config)
query = "find global documents"
# Mock the cluster search method
couchbase_tool.cluster.search = MagicMock(return_value=mock_search_iter)
# Mock the VectorQuery/VectorSearch/SearchRequest creation using runtime patching
with (
patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.VectorQuery"
) as mock_vq,
patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.VectorSearch"
) as mock_vs,
patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.search.SearchRequest"
) as mock_sr,
patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.SearchOptions"
) as mock_so,
):
# Set up the mock objects and their return values
mock_vector_query = MagicMock()
mock_vector_search = MagicMock()
mock_search_req = MagicMock()
mock_search_options = MagicMock()
mock_vq.return_value = mock_vector_query
mock_vs.from_vector_query.return_value = mock_vector_search
mock_sr.create.return_value = mock_search_req
mock_so.return_value = mock_search_options
result = couchbase_tool._run(query=query)
# Check embedding function call
tool_config["embedding_function"].assert_called_once_with(query)
# Check VectorQuery/Search call
mock_vq.assert_called_once_with(
tool_config["embedding_key"],
mock_embedding_function.return_value,
tool_config["limit"],
)
mock_sr.create.assert_called_once_with(mock_vector_search)
# Check SearchOptions creation
mock_so.assert_called_once_with(limit=tool_config["limit"], fields=["*"])
# Check that cluster search was called correctly
couchbase_tool.cluster.search.assert_called_once_with(
tool_config["index_name"], mock_search_req, mock_search_options
)
# Check scope search was NOT called
couchbase_tool._scope.search.assert_not_called()
# Check result format
assert '"id": "doc1"' in result
assert '"id": "doc2"' in result
def test_check_bucket_exists_fail(tool_config):
"""Test check for bucket non-existence."""
mock_cluster = tool_config["cluster"]
mock_cluster.buckets().get_bucket.side_effect = (
mock_couchbase.exceptions.BucketNotFoundException("Bucket not found")
)
with patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE", True
):
with pytest.raises(ValueError, match="Bucket test_bucket does not exist."):
CouchbaseFTSVectorSearchTool(**tool_config)
def test_check_scope_exists_fail(tool_config):
"""Test check for scope non-existence."""
mock_cluster = tool_config["cluster"]
# Simulate scope not being in the list returned
mock_scope_spec = MagicMock()
mock_scope_spec.name = "wrong_scope"
mock_cluster.bucket().collections().get_all_scopes.return_value = [mock_scope_spec]
with patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE", True
):
with pytest.raises(ValueError, match="Scope test_scope not found"):
CouchbaseFTSVectorSearchTool(**tool_config)
def test_check_collection_exists_fail(tool_config):
"""Test check for collection non-existence."""
mock_cluster = tool_config["cluster"]
# Simulate collection not being in the scope's list
mock_scope_spec = MagicMock()
mock_scope_spec.name = "test_scope"
mock_collection_spec = MagicMock()
mock_collection_spec.name = "wrong_collection"
mock_scope_spec.collections = [mock_collection_spec] # Only has wrong collection
mock_cluster.bucket().collections().get_all_scopes.return_value = [mock_scope_spec]
with patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE", True
):
with pytest.raises(ValueError, match="Collection test_collection not found"):
CouchbaseFTSVectorSearchTool(**tool_config)
def test_check_index_exists_fail_scoped(tool_config):
"""Test check for scoped index non-existence."""
mock_cluster = tool_config["cluster"]
# Simulate index not being in the list returned by scope manager
mock_cluster.bucket().scope().search_indexes().get_all_indexes.return_value = []
with patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE", True
):
with pytest.raises(ValueError, match="Index test_index does not exist"):
CouchbaseFTSVectorSearchTool(**tool_config)
def test_check_index_exists_fail_global(tool_config):
"""Test check for global index non-existence."""
tool_config["scoped_index"] = False
mock_cluster = tool_config["cluster"]
# Simulate index not being in the list returned by cluster manager
mock_cluster.search_indexes().get_all_indexes.return_value = []
with patch(
"crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE", True
):
with pytest.raises(ValueError, match="Index test_index does not exist"):
CouchbaseFTSVectorSearchTool(**tool_config)
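The module-level stubbing at the top of this file works because Python's import machinery consults `sys.modules` before searching for a real package, so a `MagicMock` registered there satisfies the import without the dependency installed. A minimal sketch with a hypothetical package name:

```python
import sys
from unittest.mock import MagicMock

# Register a stub under a (hypothetical) package name before any import:
# Python checks sys.modules first, so the import below never touches disk.
fake_pkg = MagicMock()
sys.modules["fake_vector_db"] = fake_pkg
sys.modules["fake_vector_db.search"] = fake_pkg.search

import fake_vector_db  # resolves to the MagicMock stub

fake_vector_db.search.SearchRequest.create.return_value = "request"
req = fake_vector_db.search.SearchRequest.create("query")
print(req)
```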


@@ -0,0 +1,251 @@
from typing import Union, get_args, get_origin
from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import (
CrewAIPlatformActionTool,
)
class TestSchemaProcessing:
def setup_method(self):
self.base_action_schema = {
"function": {
"parameters": {
"properties": {},
"required": []
}
}
}
def create_test_tool(self, action_name="test_action"):
return CrewAIPlatformActionTool(
description="Test tool",
action_name=action_name,
action_schema=self.base_action_schema
)
def test_anyof_multiple_types(self):
tool = self.create_test_tool()
test_schema = {
"anyOf": [
{"type": "string"},
{"type": "number"},
{"type": "integer"}
]
}
result_type = tool._process_schema_type(test_schema, "TestField")
assert get_origin(result_type) is Union
args = get_args(result_type)
expected_types = (str, float, int)
for expected_type in expected_types:
assert expected_type in args
def test_anyof_with_null(self):
tool = self.create_test_tool()
test_schema = {
"anyOf": [
{"type": "string"},
{"type": "number"},
{"type": "null"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldNullable")
assert get_origin(result_type) is Union
args = get_args(result_type)
assert type(None) in args
assert str in args
assert float in args
def test_anyof_single_type(self):
tool = self.create_test_tool()
test_schema = {
"anyOf": [
{"type": "string"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldSingle")
assert result_type is str
def test_oneof_multiple_types(self):
tool = self.create_test_tool()
test_schema = {
"oneOf": [
{"type": "string"},
{"type": "boolean"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldOneOf")
assert get_origin(result_type) is Union
args = get_args(result_type)
expected_types = (str, bool)
for expected_type in expected_types:
assert expected_type in args
def test_oneof_single_type(self):
tool = self.create_test_tool()
test_schema = {
"oneOf": [
{"type": "integer"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldOneOfSingle")
assert result_type is int
def test_basic_types(self):
tool = self.create_test_tool()
test_cases = [
({"type": "string"}, str),
({"type": "integer"}, int),
({"type": "number"}, float),
({"type": "boolean"}, bool),
({"type": "array", "items": {"type": "string"}}, list),
]
for schema, expected_type in test_cases:
result_type = tool._process_schema_type(schema, "TestField")
if schema["type"] == "array":
assert get_origin(result_type) is list
else:
assert result_type is expected_type
def test_enum_handling(self):
tool = self.create_test_tool()
test_schema = {
"type": "string",
"enum": ["option1", "option2", "option3"]
}
result_type = tool._process_schema_type(test_schema, "TestFieldEnum")
assert result_type is str
def test_nested_anyof(self):
tool = self.create_test_tool()
test_schema = {
"anyOf": [
{"type": "string"},
{
"anyOf": [
{"type": "integer"},
{"type": "boolean"}
]
}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldNested")
assert get_origin(result_type) is Union
args = get_args(result_type)
assert str in args
if len(args) == 3:
assert int in args
assert bool in args
else:
nested_union = next(arg for arg in args if get_origin(arg) is Union)
nested_args = get_args(nested_union)
assert int in nested_args
assert bool in nested_args
def test_allof_same_types(self):
tool = self.create_test_tool()
test_schema = {
"allOf": [
{"type": "string"},
{"type": "string", "maxLength": 100}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldAllOfSame")
assert result_type is str
def test_allof_object_merge(self):
tool = self.create_test_tool()
test_schema = {
"allOf": [
{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name"]
},
{
"type": "object",
"properties": {
"email": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["email"]
}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldAllOfMerged")
        # The merge should yield an object-like type: either a generated
        # model class or dict as a fallback, never a scalar.
        assert result_type is not str
        assert result_type is not int
        assert result_type is not bool
        assert result_type is dict or hasattr(result_type, "__name__")
def test_allof_single_schema(self):
"""Test that allOf with single schema works correctly."""
tool = self.create_test_tool()
test_schema = {
"allOf": [
{"type": "boolean"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldAllOfSingle")
# Should be just bool
assert result_type is bool
def test_allof_mixed_types(self):
tool = self.create_test_tool()
test_schema = {
"allOf": [
{"type": "string"},
{"type": "integer"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldAllOfMixed")
assert result_type is str
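For orientation, here is an illustrative re-implementation of the `anyOf`/`oneOf` contract these tests exercise; it is a sketch of the expected behavior, not the actual `_process_schema_type` code:

```python
from typing import Union, get_args, get_origin

# Basic JSON-schema scalar types mapped to Python types.
BASIC_TYPES = {"string": str, "integer": int, "number": float,
               "boolean": bool, "null": type(None)}

def process_schema_type(schema: dict):
    variants = schema.get("anyOf") or schema.get("oneOf")
    if variants:
        resolved = tuple(process_schema_type(v) for v in variants)
        # A single variant collapses to the bare type, not a Union.
        return resolved[0] if len(resolved) == 1 else Union[resolved]
    return BASIC_TYPES.get(schema.get("type"), dict)

t = process_schema_type({"anyOf": [{"type": "string"}, {"type": "null"}]})
print(get_origin(t) is Union)
```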


@@ -0,0 +1,260 @@
import unittest
from unittest.mock import Mock, patch
from crewai_tools.tools.crewai_platform_tools import (
CrewAIPlatformActionTool,
CrewaiPlatformToolBuilder,
)
class TestCrewaiPlatformToolBuilder(unittest.TestCase):
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch(
"crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get"
)
def test_fetch_actions_success(self, mock_get):
mock_api_response = {
"actions": {
"github": [
{
"name": "create_issue",
"description": "Create a GitHub issue",
"parameters": {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "Issue title",
}
},
"required": ["title"],
},
}
]
}
}
builder = CrewaiPlatformToolBuilder(apps=["github", "slack/send_message"])
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = mock_api_response
mock_get.return_value = mock_response
builder._fetch_actions()
mock_get.assert_called_once()
args, kwargs = mock_get.call_args
assert "/actions" in args[0]
assert kwargs["headers"]["Authorization"] == "Bearer test_token"
assert kwargs["params"]["apps"] == "github,slack/send_message"
assert "create_issue" in builder._actions_schema
assert (
builder._actions_schema["create_issue"]["function"]["name"]
== "create_issue"
)
def test_fetch_actions_no_token(self):
builder = CrewaiPlatformToolBuilder(apps=["github"])
with patch.dict("os.environ", {}, clear=True):
with self.assertRaises(ValueError) as context:
builder._fetch_actions()
assert "No platform integration token found" in str(context.exception)
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch(
"crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get"
)
def test_create_tools(self, mock_get):
mock_api_response = {
"actions": {
"github": [
{
"name": "create_issue",
"description": "Create a GitHub issue",
"parameters": {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "Issue title",
}
},
"required": ["title"],
},
}
],
"slack": [
{
"name": "send_message",
"description": "Send a Slack message",
"parameters": {
"type": "object",
"properties": {
"channel": {
"type": "string",
"description": "Channel name",
}
},
"required": ["channel"],
},
}
],
}
}
builder = CrewaiPlatformToolBuilder(apps=["github", "slack"])
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = mock_api_response
mock_get.return_value = mock_response
tools = builder.tools()
assert len(tools) == 2
assert all(isinstance(tool, CrewAIPlatformActionTool) for tool in tools)
tool_names = [tool.action_name for tool in tools]
assert "create_issue" in tool_names
assert "send_message" in tool_names
github_tool = next((t for t in tools if t.action_name == "create_issue"), None)
slack_tool = next((t for t in tools if t.action_name == "send_message"), None)
assert github_tool is not None
assert slack_tool is not None
assert "Create a GitHub issue" in github_tool.description
assert "Send a Slack message" in slack_tool.description
def test_tools_caching(self):
builder = CrewaiPlatformToolBuilder(apps=["github"])
cached_tools = []
def mock_create_tools():
builder._tools = cached_tools
with (
patch.object(builder, "_fetch_actions") as mock_fetch,
patch.object(
builder, "_create_tools", side_effect=mock_create_tools
) as mock_create,
):
tools1 = builder.tools()
assert mock_fetch.call_count == 1
assert mock_create.call_count == 1
tools2 = builder.tools()
assert mock_fetch.call_count == 1
assert mock_create.call_count == 1
assert tools1 is tools2
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
def test_empty_apps_list(self):
builder = CrewaiPlatformToolBuilder(apps=[])
with patch(
"crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get"
) as mock_get:
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {"actions": {}}
mock_get.return_value = mock_response
tools = builder.tools()
assert isinstance(tools, list)
assert len(tools) == 0
_, kwargs = mock_get.call_args
assert kwargs["params"]["apps"] == ""
def test_detailed_description_generation(self):
builder = CrewaiPlatformToolBuilder(apps=["test"])
complex_schema = {
"type": "object",
"properties": {
"simple_string": {"type": "string", "description": "A simple string"},
"nested_object": {
"type": "object",
"properties": {
"inner_prop": {
"type": "integer",
"description": "Inner property",
}
},
"description": "Nested object",
},
"array_prop": {
"type": "array",
"items": {"type": "string"},
"description": "Array of strings",
},
},
}
descriptions = builder._generate_detailed_description(complex_schema)
assert isinstance(descriptions, list)
assert len(descriptions) > 0
description_text = "\n".join(descriptions)
assert "simple_string" in description_text
assert "nested_object" in description_text
assert "array_prop" in description_text
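`test_tools_caching` pins down a simple memoization contract: fetch and build run once, and later calls return the same list object. A minimal sketch of that contract (the class and method bodies here are illustrative, not the real builder):

```python
class ToolBuilder:
    """Sketch of the caching contract: fetch/build run once, and every
    subsequent tools() call returns the same cached list."""

    def __init__(self):
        self._tools = None
        self.fetch_calls = 0

    def _fetch_actions(self):
        self.fetch_calls += 1

    def _create_tools(self):
        self._tools = ["tool"]

    def tools(self):
        if self._tools is None:
            self._fetch_actions()
            self._create_tools()
        return self._tools

builder = ToolBuilder()
first = builder.tools()
second = builder.tools()
print(first is second, builder.fetch_calls)
```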


@@ -0,0 +1,115 @@
import unittest
from unittest.mock import Mock, patch
from crewai_tools.tools.crewai_platform_tools import CrewaiPlatformTools
class TestCrewaiPlatformTools(unittest.TestCase):
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch(
"crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get"
)
def test_crewai_platform_tools_basic(self, mock_get):
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {"actions": {"github": []}}
mock_get.return_value = mock_response
tools = CrewaiPlatformTools(apps=["github"])
assert tools is not None
assert isinstance(tools, list)
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch(
"crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get"
)
def test_crewai_platform_tools_multiple_apps(self, mock_get):
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {
"actions": {
"github": [
{
"name": "create_issue",
"description": "Create a GitHub issue",
"parameters": {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "Issue title",
},
"body": {"type": "string", "description": "Issue body"},
},
"required": ["title"],
},
}
],
"slack": [
{
"name": "send_message",
"description": "Send a Slack message",
"parameters": {
"type": "object",
"properties": {
"channel": {
"type": "string",
"description": "Channel to send to",
},
"text": {
"type": "string",
"description": "Message text",
},
},
"required": ["channel", "text"],
},
}
],
}
}
mock_get.return_value = mock_response
tools = CrewaiPlatformTools(apps=["github", "slack"])
assert tools is not None
assert isinstance(tools, list)
assert len(tools) == 2
mock_get.assert_called_once()
args, kwargs = mock_get.call_args
assert (
"apps=github,slack" in args[0]
or kwargs.get("params", {}).get("apps") == "github,slack"
)
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
def test_crewai_platform_tools_empty_apps(self):
with patch(
"crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get"
) as mock_get:
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {"actions": {}}
mock_get.return_value = mock_response
tools = CrewaiPlatformTools(apps=[])
assert tools is not None
assert isinstance(tools, list)
assert len(tools) == 0
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch(
"crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get"
)
def test_crewai_platform_tools_api_error_handling(self, mock_get):
mock_get.side_effect = Exception("API Error")
tools = CrewaiPlatformTools(apps=["github"])
assert tools is not None
assert isinstance(tools, list)
assert len(tools) == 0
def test_crewai_platform_tools_no_token(self):
with patch.dict("os.environ", {}, clear=True):
with self.assertRaises(ValueError) as context:
CrewaiPlatformTools(apps=["github"])
assert "No platform integration token found" in str(context.exception)


@@ -0,0 +1,86 @@
import os
from unittest.mock import patch
from crewai_tools import EXASearchTool
import pytest
@pytest.fixture
def exa_search_tool():
return EXASearchTool(api_key="test_api_key")
@pytest.fixture(autouse=True)
def mock_exa_api_key():
with patch.dict(os.environ, {"EXA_API_KEY": "test_key_from_env"}):
yield
def test_exa_search_tool_initialization():
with patch.dict(os.environ, {}, clear=True):
with patch(
"crewai_tools.tools.exa_tools.exa_search_tool.Exa"
) as mock_exa_class:
api_key = "test_api_key"
tool = EXASearchTool(api_key=api_key)
assert tool.api_key == api_key
assert tool.content is False
assert tool.summary is False
assert tool.type == "auto"
mock_exa_class.assert_called_once_with(api_key=api_key)
def test_exa_search_tool_initialization_with_env(mock_exa_api_key):
with patch.dict(os.environ, {"EXA_API_KEY": "test_key_from_env"}, clear=True):
with patch(
"crewai_tools.tools.exa_tools.exa_search_tool.Exa"
) as mock_exa_class:
EXASearchTool()
mock_exa_class.assert_called_once_with(api_key="test_key_from_env")
def test_exa_search_tool_initialization_with_base_url():
with patch.dict(os.environ, {}, clear=True):
with patch(
"crewai_tools.tools.exa_tools.exa_search_tool.Exa"
) as mock_exa_class:
api_key = "test_api_key"
base_url = "https://custom.exa.api.com"
tool = EXASearchTool(api_key=api_key, base_url=base_url)
assert tool.api_key == api_key
assert tool.base_url == base_url
assert tool.content is False
assert tool.summary is False
assert tool.type == "auto"
mock_exa_class.assert_called_once_with(api_key=api_key, base_url=base_url)
@pytest.fixture
def mock_exa_base_url():
with patch.dict(os.environ, {"EXA_BASE_URL": "https://env.exa.api.com"}):
yield
def test_exa_search_tool_initialization_with_env_base_url(
mock_exa_api_key, mock_exa_base_url
):
with patch("crewai_tools.tools.exa_tools.exa_search_tool.Exa") as mock_exa_class:
EXASearchTool()
mock_exa_class.assert_called_once_with(
api_key="test_key_from_env", base_url="https://env.exa.api.com"
)
def test_exa_search_tool_initialization_without_base_url():
with patch.dict(os.environ, {}, clear=True):
with patch(
"crewai_tools.tools.exa_tools.exa_search_tool.Exa"
) as mock_exa_class:
api_key = "test_api_key"
tool = EXASearchTool(api_key=api_key)
assert tool.api_key == api_key
assert tool.base_url is None
mock_exa_class.assert_called_once_with(api_key=api_key)
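The initialization tests above pin down the precedence between an explicit `api_key` argument and the `EXA_API_KEY` environment variable. A hedged sketch of that common resolution pattern (EXASearchTool's real logic may differ):

```python
import os
from typing import Optional


def resolve_api_key(explicit: Optional[str] = None) -> Optional[str]:
    # An explicit argument wins; otherwise fall back to the environment.
    return explicit if explicit is not None else os.environ.get("EXA_API_KEY")


os.environ["EXA_API_KEY"] = "test_key_from_env"
assert resolve_api_key("test_api_key") == "test_api_key"  # explicit wins
assert resolve_api_key() == "test_key_from_env"           # env fallback
```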


@@ -0,0 +1,131 @@
from unittest.mock import patch
from crewai_tools.tools.files_compressor_tool import FileCompressorTool
import pytest
@pytest.fixture
def tool():
return FileCompressorTool()
@patch("os.path.exists", return_value=False)
def test_input_path_does_not_exist(mock_exists, tool):
result = tool._run("nonexistent_path")
assert "does not exist" in result
@patch("os.path.exists", return_value=True)
@patch("os.getcwd", return_value="/mocked/cwd")
@patch.object(FileCompressorTool, "_compress_zip") # Mock actual compression
@patch.object(FileCompressorTool, "_prepare_output", return_value=True)
def test_generate_output_path_default(
mock_prepare, mock_compress, mock_cwd, mock_exists, tool
):
result = tool._run(input_path="mydir", format="zip")
assert "Successfully compressed" in result
mock_compress.assert_called_once()
@patch("os.path.exists", return_value=True)
@patch.object(FileCompressorTool, "_compress_zip")
@patch.object(FileCompressorTool, "_prepare_output", return_value=True)
def test_zip_compression(mock_prepare, mock_compress, mock_exists, tool):
result = tool._run(
input_path="some/path", output_path="archive.zip", format="zip", overwrite=True
)
assert "Successfully compressed" in result
mock_compress.assert_called_once()
@patch("os.path.exists", return_value=True)
@patch.object(FileCompressorTool, "_compress_tar")
@patch.object(FileCompressorTool, "_prepare_output", return_value=True)
def test_tar_gz_compression(mock_prepare, mock_compress, mock_exists, tool):
result = tool._run(
input_path="some/path",
output_path="archive.tar.gz",
format="tar.gz",
overwrite=True,
)
assert "Successfully compressed" in result
mock_compress.assert_called_once()
@pytest.mark.parametrize("format", ["tar", "tar.bz2", "tar.xz"])
@patch("os.path.exists", return_value=True)
@patch.object(FileCompressorTool, "_compress_tar")
@patch.object(FileCompressorTool, "_prepare_output", return_value=True)
def test_other_tar_formats(mock_prepare, mock_compress, mock_exists, format, tool):
result = tool._run(
input_path="path/to/input",
output_path=f"archive.{format}",
format=format,
overwrite=True,
)
assert "Successfully compressed" in result
mock_compress.assert_called_once()
@pytest.mark.parametrize("format", ["rar", "7z"])
@patch("os.path.exists", return_value=True) # Ensure input_path exists
def test_unsupported_format(_, tool, format):
result = tool._run(
input_path="some/path", output_path=f"archive.{format}", format=format
)
assert "not supported" in result
@patch("os.path.exists", return_value=True)
def test_extension_mismatch(_, tool):
result = tool._run(
input_path="some/path", output_path="archive.zip", format="tar.gz"
)
assert "must have a '.tar.gz' extension" in result
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
def test_existing_output_no_overwrite(_, __, tool):
result = tool._run(
input_path="some/path", output_path="archive.zip", format="zip", overwrite=False
)
assert "overwrite is set to False" in result
@patch("os.path.exists", return_value=True)
@patch("zipfile.ZipFile", side_effect=PermissionError)
def test_permission_error(mock_zip, _, tool):
result = tool._run(
input_path="file.txt", output_path="file.zip", format="zip", overwrite=True
)
assert "Permission denied" in result
@patch("os.path.exists", return_value=True)
@patch("zipfile.ZipFile", side_effect=FileNotFoundError)
def test_file_not_found_during_zip(mock_zip, _, tool):
result = tool._run(
input_path="file.txt", output_path="file.zip", format="zip", overwrite=True
)
assert "File not found" in result
@patch("os.path.exists", return_value=True)
@patch("zipfile.ZipFile", side_effect=Exception("Unexpected"))
def test_general_exception_during_zip(mock_zip, _, tool):
result = tool._run(
input_path="file.txt", output_path="file.zip", format="zip", overwrite=True
)
assert "unexpected error" in result
# Test: Output directory is created when missing
@patch("os.makedirs")
@patch("os.path.exists", return_value=False)
def test_prepare_output_makes_dir(mock_exists, mock_makedirs):
tool = FileCompressorTool()
result = tool._prepare_output("some/missing/path/file.zip", overwrite=True)
assert result is True
mock_makedirs.assert_called_once()
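The compressor tests stack several `@patch` decorators; mocks are injected bottom-up, which is why `mock_prepare` (the decorator closest to the function) comes first in the signatures above. A minimal demonstration of the ordering:

```python
from unittest.mock import patch


@patch("os.getcwd", return_value="/top")      # outermost decorator, injected last
@patch("os.path.exists", return_value=True)  # closest to the function, injected first
def demo(mock_exists, mock_getcwd):
    # The decorator closest to the function supplies the first argument.
    import os
    return os.path.exists("x"), os.getcwd()


assert demo() == (True, "/top")
```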


@@ -0,0 +1,186 @@
import os
from unittest.mock import MagicMock, patch
from crewai_tools.tools.generate_crewai_automation_tool.generate_crewai_automation_tool import (
GenerateCrewaiAutomationTool,
GenerateCrewaiAutomationToolSchema,
)
import pytest
import requests
@pytest.fixture(autouse=True)
def mock_env():
with patch.dict(os.environ, {"CREWAI_PERSONAL_ACCESS_TOKEN": "test_token"}):
os.environ.pop("CREWAI_PLUS_URL", None)
yield
@pytest.fixture
def tool():
return GenerateCrewaiAutomationTool()
@pytest.fixture
def custom_url_tool():
with patch.dict(os.environ, {"CREWAI_PLUS_URL": "https://custom.crewai.com"}):
return GenerateCrewaiAutomationTool()
def test_default_initialization(tool):
assert tool.crewai_enterprise_url == "https://app.crewai.com"
assert tool.personal_access_token == "test_token"
assert tool.name == "Generate CrewAI Automation"
def test_custom_base_url_from_environment(custom_url_tool):
assert custom_url_tool.crewai_enterprise_url == "https://custom.crewai.com"
def test_personal_access_token_from_environment(tool):
assert tool.personal_access_token == "test_token"
def test_valid_prompt_only():
schema = GenerateCrewaiAutomationToolSchema(
prompt="Create a web scraping automation"
)
assert schema.prompt == "Create a web scraping automation"
assert schema.organization_id is None
def test_valid_prompt_with_organization_id():
schema = GenerateCrewaiAutomationToolSchema(
prompt="Create automation", organization_id="org-123"
)
assert schema.prompt == "Create automation"
assert schema.organization_id == "org-123"
def test_empty_prompt_validation():
schema = GenerateCrewaiAutomationToolSchema(prompt="")
assert schema.prompt == ""
assert schema.organization_id is None
@patch("requests.post")
def test_successful_generation_without_org_id(mock_post, tool):
mock_response = MagicMock()
mock_response.json.return_value = {
"url": "https://app.crewai.com/studio/project-123"
}
mock_post.return_value = mock_response
result = tool.run(prompt="Create automation")
assert (
result
== "Generated CrewAI Studio project URL: https://app.crewai.com/studio/project-123"
)
mock_post.assert_called_once_with(
"https://app.crewai.com/crewai_plus/api/v1/studio",
headers={
"Authorization": "Bearer test_token",
"Content-Type": "application/json",
"Accept": "application/json",
},
json={"prompt": "Create automation"},
)
@patch("requests.post")
def test_successful_generation_with_org_id(mock_post, tool):
mock_response = MagicMock()
mock_response.json.return_value = {
"url": "https://app.crewai.com/studio/project-456"
}
mock_post.return_value = mock_response
result = tool.run(prompt="Create automation", organization_id="org-456")
assert (
result
== "Generated CrewAI Studio project URL: https://app.crewai.com/studio/project-456"
)
mock_post.assert_called_once_with(
"https://app.crewai.com/crewai_plus/api/v1/studio",
headers={
"Authorization": "Bearer test_token",
"Content-Type": "application/json",
"Accept": "application/json",
"X-Crewai-Organization-Id": "org-456",
},
json={"prompt": "Create automation"},
)
@patch("requests.post")
def test_custom_base_url_usage(mock_post, custom_url_tool):
mock_response = MagicMock()
mock_response.json.return_value = {
"url": "https://custom.crewai.com/studio/project-789"
}
mock_post.return_value = mock_response
custom_url_tool.run(prompt="Create automation")
mock_post.assert_called_once_with(
"https://custom.crewai.com/crewai_plus/api/v1/studio",
headers={
"Authorization": "Bearer test_token",
"Content-Type": "application/json",
"Accept": "application/json",
},
json={"prompt": "Create automation"},
)
@patch("requests.post")
def test_api_error_response_handling(mock_post, tool):
mock_post.return_value.raise_for_status.side_effect = requests.HTTPError(
"400 Bad Request"
)
with pytest.raises(requests.HTTPError):
tool.run(prompt="Create automation")
@patch("requests.post")
def test_network_error_handling(mock_post, tool):
mock_post.side_effect = requests.ConnectionError("Network unreachable")
with pytest.raises(requests.ConnectionError):
tool.run(prompt="Create automation")
@patch("requests.post")
def test_api_response_missing_url(mock_post, tool):
mock_response = MagicMock()
mock_response.json.return_value = {"status": "success"}
mock_post.return_value = mock_response
result = tool.run(prompt="Create automation")
assert result == "Generated CrewAI Studio project URL: None"
def test_authorization_header_construction(tool):
headers = tool._get_headers()
assert headers["Authorization"] == "Bearer test_token"
assert headers["Content-Type"] == "application/json"
assert headers["Accept"] == "application/json"
assert "X-Crewai-Organization-Id" not in headers
def test_authorization_header_with_org_id(tool):
headers = tool._get_headers(organization_id="org-123")
assert headers["Authorization"] == "Bearer test_token"
assert headers["X-Crewai-Organization-Id"] == "org-123"
def test_missing_personal_access_token():
with patch.dict(os.environ, {}, clear=True):
tool = GenerateCrewaiAutomationTool()
assert tool.personal_access_token is None
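The header tests above imply that `_get_headers` adds `X-Crewai-Organization-Id` only when an organization id is supplied. A plausible sketch of that construction (the actual implementation is not shown in this diff):

```python
from typing import Optional


def build_headers(token: str, organization_id: Optional[str] = None) -> dict:
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    # The org header is conditional, matching the two header tests above.
    if organization_id is not None:
        headers["X-Crewai-Organization-Id"] = organization_id
    return headers


assert "X-Crewai-Organization-Id" not in build_headers("test_token")
assert build_headers("test_token", "org-123")["X-Crewai-Organization-Id"] == "org-123"
```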


@@ -0,0 +1,44 @@
import json
from unittest.mock import patch
from urllib.parse import urlparse
from crewai_tools.tools.parallel_tools.parallel_search_tool import (
ParallelSearchTool,
)
def test_requires_env_var(monkeypatch):
monkeypatch.delenv("PARALLEL_API_KEY", raising=False)
tool = ParallelSearchTool()
result = tool.run(objective="test")
assert "PARALLEL_API_KEY" in result
@patch("crewai_tools.tools.parallel_tools.parallel_search_tool.requests.post")
def test_happy_path(mock_post, monkeypatch):
monkeypatch.setenv("PARALLEL_API_KEY", "test")
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {
"search_id": "search_123",
"results": [
{
"url": "https://www.un.org/en/about-us/history-of-the-un",
"title": "History of the United Nations",
"excerpts": [
"Four months after the San Francisco Conference ended, the United Nations officially began, on 24 October 1945..."
],
}
],
}
tool = ParallelSearchTool()
result = tool.run(
objective="When was the UN established?", search_queries=["Founding year UN"]
)
data = json.loads(result)
assert "search_id" in data
urls = [r.get("url", "") for r in data.get("results", [])]
# Validate host against allowed set instead of substring matching
allowed_hosts = {"www.un.org", "un.org"}
assert any(urlparse(u).netloc in allowed_hosts for u in urls)
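The comment above prefers host allowlisting over substring matching because a check like `"un.org" in url` also matches lookalike hosts such as `un.org.evil.example`. A standalone illustration:

```python
from urllib.parse import urlparse

ALLOWED_HOSTS = {"www.un.org", "un.org"}


def host_allowed(url: str) -> bool:
    # Compare the parsed hostname against an exact allowlist.
    return urlparse(url).netloc in ALLOWED_HOSTS


assert host_allowed("https://www.un.org/en/about-us/history-of-the-un")
# A naive substring check ('un.org' in url) would wrongly accept this:
assert not host_allowed("https://un.org.evil.example/phish")
```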


@@ -0,0 +1,178 @@
"""Tests for RAG tool with mocked embeddings and vector database."""
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import cast
from unittest.mock import MagicMock, Mock, patch
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
from crewai_tools.tools.rag.rag_tool import RagTool
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_initialization(
mock_create_client: Mock, mock_get_rag_client: Mock
) -> None:
"""Test that RagTool initializes with CrewAI adapter by default."""
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_get_rag_client.return_value = mock_client
mock_create_client.return_value = mock_client
class MyTool(RagTool):
pass
tool = MyTool()
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
adapter = cast(CrewAIRagAdapter, tool.adapter)
assert adapter.collection_name == "rag_tool_collection"
assert adapter._client is not None
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_add_and_query(
mock_create_client: Mock, mock_get_rag_client: Mock
) -> None:
"""Test adding content and querying with RagTool."""
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_client.add_documents = MagicMock(return_value=None)
mock_client.search = MagicMock(
return_value=[
{"content": "The sky is blue on a clear day.", "metadata": {}, "score": 0.9}
]
)
mock_get_rag_client.return_value = mock_client
mock_create_client.return_value = mock_client
class MyTool(RagTool):
pass
tool = MyTool()
tool.add("The sky is blue on a clear day.")
tool.add("Machine learning is a subset of artificial intelligence.")
# Verify documents were added
assert mock_client.add_documents.call_count == 2
result = tool._run(query="What color is the sky?")
assert "Relevant Content:" in result
assert "The sky is blue" in result
mock_client.search.return_value = [
{
"content": "Machine learning is a subset of artificial intelligence.",
"metadata": {},
"score": 0.85,
}
]
result = tool._run(query="Tell me about machine learning")
assert "Relevant Content:" in result
assert "Machine learning" in result
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_file(
mock_create_client: Mock, mock_get_rag_client: Mock
) -> None:
"""Test RagTool with file content."""
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_client.add_documents = MagicMock(return_value=None)
mock_client.search = MagicMock(
return_value=[
{
"content": "Python is a programming language known for its simplicity.",
"metadata": {"file_path": "test.txt"},
"score": 0.95,
}
]
)
mock_get_rag_client.return_value = mock_client
mock_create_client.return_value = mock_client
with TemporaryDirectory() as tmpdir:
test_file = Path(tmpdir) / "test.txt"
test_file.write_text(
"Python is a programming language known for its simplicity."
)
class MyTool(RagTool):
pass
tool = MyTool()
tool.add(str(test_file))
assert mock_client.add_documents.called
result = tool._run(query="What is Python?")
assert "Relevant Content:" in result
assert "Python is a programming language" in result
@patch("crewai_tools.tools.rag.rag_tool.RagTool._create_embedding_function")
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_custom_embeddings(
mock_create_client: Mock, mock_create_embedding: Mock
) -> None:
"""Test RagTool with custom embeddings configuration to ensure no API calls."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.2] * 1536]
mock_create_embedding.return_value = mock_embedding_func
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_client.add_documents = MagicMock(return_value=None)
mock_client.search = MagicMock(
return_value=[{"content": "Test content", "metadata": {}, "score": 0.8}]
)
mock_create_client.return_value = mock_client
class MyTool(RagTool):
pass
config = {
"vectordb": {"provider": "chromadb", "config": {}},
"embedding_model": {
"provider": "openai",
"config": {"model": "text-embedding-3-small"},
},
}
tool = MyTool(config=config)
tool.add("Test content")
result = tool._run(query="Test query")
assert "Relevant Content:" in result
assert "Test content" in result
mock_create_embedding.assert_called()
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_no_results(
mock_create_client: Mock, mock_get_rag_client: Mock
) -> None:
"""Test RagTool when no relevant content is found."""
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_client.search = MagicMock(return_value=[])
mock_get_rag_client.return_value = mock_client
mock_create_client.return_value = mock_client
class MyTool(RagTool):
pass
tool = MyTool()
result = tool._run(query="Non-existent content")
assert "Relevant Content:" in result
assert "No relevant content found" in result
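The assertions above imply the tool prefixes output with `Relevant Content:` and falls back to a no-results message. A hedged sketch of such formatting (not the actual RagTool code, just one shape consistent with the tests):

```python
def format_results(results: list) -> str:
    # Hypothetical formatter matching the assertions in the tests above.
    if not results:
        return "Relevant Content:\nNo relevant content found"
    body = "\n".join(r["content"] for r in results)
    return f"Relevant Content:\n{body}"


out = format_results(
    [{"content": "The sky is blue on a clear day.", "metadata": {}, "score": 0.9}]
)
assert "Relevant Content:" in out and "The sky is blue" in out
assert "No relevant content found" in format_results([])
```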


@@ -0,0 +1,131 @@
import os
import tempfile
from unittest.mock import MagicMock, patch
from bs4 import BeautifulSoup
from crewai_tools.tools.selenium_scraping_tool.selenium_scraping_tool import (
SeleniumScrapingTool,
)
from selenium.webdriver.chrome.options import Options
def mock_driver_with_html(html_content):
driver = MagicMock()
mock_element = MagicMock()
mock_element.get_attribute.return_value = html_content
bs = BeautifulSoup(html_content, "html.parser")
mock_element.text = bs.get_text()
driver.find_elements.return_value = [mock_element]
driver.find_element.return_value = mock_element
return driver
def initialize_tool_with(mock_driver):
tool = SeleniumScrapingTool(driver=mock_driver)
return tool
@patch("selenium.webdriver.Chrome")
def test_tool_initialization(mocked_chrome):
temp_dir = tempfile.mkdtemp()
mocked_chrome.return_value = MagicMock()
tool = SeleniumScrapingTool()
assert tool.website_url is None
assert tool.css_element is None
assert tool.cookie is None
assert tool.wait_time == 3
assert tool.return_html is False
    try:
        os.rmdir(temp_dir)
    except OSError:
        pass
@patch("selenium.webdriver.Chrome")
def test_tool_initialization_with_options(mocked_chrome):
mocked_chrome.return_value = MagicMock()
options = Options()
options.add_argument("--disable-gpu")
SeleniumScrapingTool(options=options)
mocked_chrome.assert_called_once_with(options=options)
@patch("selenium.webdriver.Chrome")
def test_scrape_without_css_selector(_mocked_chrome_driver):
html_content = "<html><body><div>test content</div></body></html>"
mock_driver = mock_driver_with_html(html_content)
tool = initialize_tool_with(mock_driver)
result = tool._run(website_url="https://example.com")
assert "test content" in result
mock_driver.get.assert_called_once_with("https://example.com")
mock_driver.find_element.assert_called_with("tag name", "body")
mock_driver.close.assert_called_once()
@patch("selenium.webdriver.Chrome")
def test_scrape_with_css_selector(_mocked_chrome_driver):
html_content = "<html><body><div>test content</div><div class='test'>test content in a specific div</div></body></html>"
mock_driver = mock_driver_with_html(html_content)
tool = initialize_tool_with(mock_driver)
result = tool._run(website_url="https://example.com", css_element="div.test")
assert "test content in a specific div" in result
mock_driver.get.assert_called_once_with("https://example.com")
mock_driver.find_elements.assert_called_with("css selector", "div.test")
mock_driver.close.assert_called_once()
@patch("selenium.webdriver.Chrome")
def test_scrape_with_return_html_true(_mocked_chrome_driver):
html_content = "<html><body><div>HTML content</div></body></html>"
mock_driver = mock_driver_with_html(html_content)
tool = initialize_tool_with(mock_driver)
result = tool._run(website_url="https://example.com", return_html=True)
assert html_content in result
mock_driver.get.assert_called_once_with("https://example.com")
mock_driver.find_element.assert_called_with("tag name", "body")
mock_driver.close.assert_called_once()
@patch("selenium.webdriver.Chrome")
def test_scrape_with_return_html_false(_mocked_chrome_driver):
html_content = "<html><body><div>HTML content</div></body></html>"
mock_driver = mock_driver_with_html(html_content)
tool = initialize_tool_with(mock_driver)
result = tool._run(website_url="https://example.com", return_html=False)
assert "HTML content" in result
mock_driver.get.assert_called_once_with("https://example.com")
mock_driver.find_element.assert_called_with("tag name", "body")
mock_driver.close.assert_called_once()
@patch("selenium.webdriver.Chrome")
def test_scrape_with_driver_error(_mocked_chrome_driver):
mock_driver = MagicMock()
mock_driver.find_element.side_effect = Exception("WebDriver error occurred")
tool = initialize_tool_with(mock_driver)
result = tool._run(website_url="https://example.com")
assert result == "Error scraping website: WebDriver error occurred"
mock_driver.close.assert_called_once()
@patch("selenium.webdriver.Chrome")
def test_initialization_with_driver(_mocked_chrome_driver):
mock_driver = MagicMock()
tool = initialize_tool_with(mock_driver)
assert tool.driver == mock_driver
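`test_scrape_with_driver_error` uses `side_effect` to make the mocked driver raise, then checks that the error is turned into a string result. The mocking pattern in isolation:

```python
from unittest.mock import MagicMock

driver = MagicMock()
driver.find_element.side_effect = Exception("WebDriver error occurred")

try:
    driver.find_element("tag name", "body")
    result = "no error"
except Exception as e:  # the tool catches this and returns an error string
    result = f"Error scraping website: {e}"

assert result == "Error scraping website: WebDriver error occurred"
driver.close()
driver.close.assert_called_once()
```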


@@ -0,0 +1,141 @@
import os
from unittest.mock import patch
from crewai_tools.tools.serper_dev_tool.serper_dev_tool import SerperDevTool
import pytest
@pytest.fixture(autouse=True)
def mock_serper_api_key():
with patch.dict(os.environ, {"SERPER_API_KEY": "test_key"}):
yield
@pytest.fixture
def serper_tool():
return SerperDevTool(n_results=2)
def test_serper_tool_initialization():
tool = SerperDevTool()
assert tool.n_results == 10
assert tool.save_file is False
assert tool.search_type == "search"
assert tool.country == ""
assert tool.location == ""
assert tool.locale == ""
def test_serper_tool_custom_initialization():
tool = SerperDevTool(
n_results=5,
save_file=True,
search_type="news",
country="US",
location="New York",
locale="en",
)
assert tool.n_results == 5
assert tool.save_file is True
assert tool.search_type == "news"
assert tool.country == "US"
assert tool.location == "New York"
assert tool.locale == "en"
@patch("requests.post")
def test_serper_tool_search(mock_post):
tool = SerperDevTool(n_results=2)
mock_response = {
"searchParameters": {"q": "test query", "type": "search"},
"organic": [
{
"title": "Test Title 1",
"link": "http://test1.com",
"snippet": "Test Description 1",
"position": 1,
},
{
"title": "Test Title 2",
"link": "http://test2.com",
"snippet": "Test Description 2",
"position": 2,
},
],
"peopleAlsoAsk": [
{
"question": "Test Question",
"snippet": "Test Answer",
"title": "Test Source",
"link": "http://test.com",
}
],
}
mock_post.return_value.json.return_value = mock_response
mock_post.return_value.status_code = 200
result = tool.run(search_query="test query")
assert "searchParameters" in result
assert result["searchParameters"]["q"] == "test query"
assert len(result["organic"]) == 2
assert result["organic"][0]["title"] == "Test Title 1"
@patch("requests.post")
def test_serper_tool_news_search(mock_post):
tool = SerperDevTool(n_results=2, search_type="news")
mock_response = {
"searchParameters": {"q": "test news", "type": "news"},
"news": [
{
"title": "News Title 1",
"link": "http://news1.com",
"snippet": "News Description 1",
"date": "2024-01-01",
"source": "News Source 1",
"imageUrl": "http://image1.com",
}
],
}
mock_post.return_value.json.return_value = mock_response
mock_post.return_value.status_code = 200
result = tool.run(search_query="test news")
assert "news" in result
assert len(result["news"]) == 1
assert result["news"][0]["title"] == "News Title 1"
@patch("requests.post")
def test_serper_tool_with_location_params(mock_post):
tool = SerperDevTool(n_results=2, country="US", location="New York", locale="en")
tool.run(search_query="test")
called_payload = mock_post.call_args.kwargs["json"]
assert called_payload["gl"] == "US"
assert called_payload["location"] == "New York"
assert called_payload["hl"] == "en"
def test_invalid_search_type():
tool = SerperDevTool()
with pytest.raises(ValueError) as exc_info:
tool.run(search_query="test", search_type="invalid")
assert "Invalid search type" in str(exc_info.value)
@patch("requests.post")
def test_api_error_handling(mock_post):
tool = SerperDevTool()
mock_post.side_effect = Exception("API Error")
with pytest.raises(Exception) as exc_info:
tool.run(search_query="test")
assert "API Error" in str(exc_info.value)
if __name__ == "__main__":
pytest.main([__file__])
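Several tests here configure `mock_post.return_value.json.return_value` in one line; `MagicMock` auto-creates each attribute in that chain, so no intermediate setup is needed. Minimal form:

```python
from unittest.mock import MagicMock

mock_post = MagicMock()
mock_post.return_value.json.return_value = {"organic": []}
mock_post.return_value.status_code = 200

# Any call arguments work on a MagicMock; the URL here is just illustrative.
response = mock_post("https://google.serper.dev/search")
assert response.status_code == 200
assert response.json() == {"organic": []}
```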


@@ -0,0 +1,335 @@
from collections.abc import Generator
import os
from crewai_tools import SingleStoreSearchTool
from crewai_tools.tools.singlestore_search_tool import SingleStoreSearchToolSchema
import pytest
from singlestoredb import connect
from singlestoredb.server import docker
@pytest.fixture(scope="session")
def docker_server_url() -> Generator[str, None, None]:
"""Start a SingleStore Docker server for tests."""
try:
sdb = docker.start(license="")
conn = sdb.connect()
curr = conn.cursor()
curr.execute("CREATE DATABASE test_crewai")
curr.close()
conn.close()
yield sdb.connection_url
sdb.stop()
except Exception as e:
pytest.skip(f"Could not start SingleStore Docker container: {e}")
@pytest.fixture(scope="function")
def clean_db_url(docker_server_url) -> Generator[str, None, None]:
"""Provide a clean database URL and clean up tables after test."""
yield docker_server_url
try:
conn = connect(host=docker_server_url, database="test_crewai")
curr = conn.cursor()
curr.execute("SHOW TABLES")
results = curr.fetchall()
for result in results:
curr.execute(f"DROP TABLE {result[0]}")
curr.close()
conn.close()
except Exception:
# Ignore cleanup errors
pass
@pytest.fixture
def sample_table_setup(clean_db_url):
"""Set up sample tables for testing."""
conn = connect(host=clean_db_url, database="test_crewai")
curr = conn.cursor()
# Create sample tables
curr.execute(
"""
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10,2)
)
"""
)
curr.execute(
"""
CREATE TABLE departments (
id INT PRIMARY KEY,
name VARCHAR(100),
budget DECIMAL(12,2)
)
"""
)
# Insert sample data
curr.execute(
"""
INSERT INTO employees VALUES
(1, 'Alice Smith', 'Engineering', 75000.00),
(2, 'Bob Johnson', 'Marketing', 65000.00),
(3, 'Carol Davis', 'Engineering', 80000.00)
"""
)
curr.execute(
"""
INSERT INTO departments VALUES
(1, 'Engineering', 500000.00),
(2, 'Marketing', 300000.00)
"""
)
curr.close()
conn.close()
return clean_db_url
class TestSingleStoreSearchTool:
"""Test suite for SingleStoreSearchTool."""
def test_tool_creation_with_connection_params(self, sample_table_setup):
"""Test tool creation with individual connection parameters."""
# Parse URL components for individual parameters
url_parts = sample_table_setup.split("@")[1].split(":")
host = url_parts[0]
port = int(url_parts[1].split("/")[0])
user = "root"
password = sample_table_setup.split("@")[0].split(":")[2]
tool = SingleStoreSearchTool(
tables=[],
host=host,
port=port,
user=user,
password=password,
database="test_crewai",
)
assert tool.name == "Search a database's table(s) content"
assert "SingleStore" in tool.description
assert (
"employees(id int(11), name varchar(100), department varchar(50), salary decimal(10,2))"
in tool.description.lower()
)
assert (
"departments(id int(11), name varchar(100), budget decimal(12,2))"
in tool.description.lower()
)
assert tool.args_schema == SingleStoreSearchToolSchema
assert tool.connection_pool is not None
def test_tool_creation_with_connection_url(self, sample_table_setup):
"""Test tool creation with connection URL."""
tool = SingleStoreSearchTool(host=f"{sample_table_setup}/test_crewai")
assert tool.name == "Search a database's table(s) content"
assert tool.connection_pool is not None
def test_tool_creation_with_specific_tables(self, sample_table_setup):
"""Test tool creation with specific table list."""
tool = SingleStoreSearchTool(
tables=["employees"],
host=sample_table_setup,
database="test_crewai",
)
# Check that description includes specific tables
assert "employees" in tool.description
assert "departments" not in tool.description
def test_tool_creation_with_nonexistent_table(self, sample_table_setup):
"""Test tool creation fails with non-existent table."""
with pytest.raises(ValueError, match="Table nonexistent does not exist"):
SingleStoreSearchTool(
tables=["employees", "nonexistent"],
host=sample_table_setup,
database="test_crewai",
)
def test_tool_creation_with_empty_database(self, clean_db_url):
"""Test tool creation fails with empty database."""
with pytest.raises(ValueError, match="No tables found in the database"):
SingleStoreSearchTool(host=clean_db_url, database="test_crewai")
def test_description_generation(self, sample_table_setup):
"""Test that tool description is properly generated with table info."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
# Check description contains table definitions
assert "employees(" in tool.description
assert "departments(" in tool.description
assert "id int" in tool.description.lower()
assert "name varchar" in tool.description.lower()
def test_query_validation_select_allowed(self, sample_table_setup):
"""Test that SELECT queries are allowed."""
os.environ["SINGLESTOREDB_URL"] = sample_table_setup
tool = SingleStoreSearchTool(database="test_crewai")
valid, message = tool._validate_query("SELECT * FROM employees")
assert valid is True
assert message == "Valid query"
def test_query_validation_show_allowed(self, sample_table_setup):
"""Test that SHOW queries are allowed."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, message = tool._validate_query("SHOW TABLES")
assert valid is True
assert message == "Valid query"
def test_query_validation_case_insensitive(self, sample_table_setup):
"""Test that query validation is case insensitive."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, _ = tool._validate_query("select * from employees")
assert valid is True
valid, _ = tool._validate_query("SHOW tables")
assert valid is True
def test_query_validation_insert_denied(self, sample_table_setup):
"""Test that INSERT queries are denied."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, message = tool._validate_query(
"INSERT INTO employees VALUES (4, 'Test', 'Test', 1000)"
)
assert valid is False
assert "Only SELECT and SHOW queries are supported" in message
def test_query_validation_update_denied(self, sample_table_setup):
"""Test that UPDATE queries are denied."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, message = tool._validate_query("UPDATE employees SET salary = 90000")
assert valid is False
assert "Only SELECT and SHOW queries are supported" in message
def test_query_validation_delete_denied(self, sample_table_setup):
"""Test that DELETE queries are denied."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, message = tool._validate_query("DELETE FROM employees WHERE id = 1")
assert valid is False
assert "Only SELECT and SHOW queries are supported" in message
def test_query_validation_non_string(self, sample_table_setup):
"""Test that non-string queries are rejected."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, message = tool._validate_query(123)
assert valid is False
assert "Search query must be a string" in message
def test_run_select_query(self, sample_table_setup):
"""Test executing a SELECT query."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run("SELECT * FROM employees ORDER BY id")
assert "Search Results:" in result
assert "Alice Smith" in result
assert "Bob Johnson" in result
assert "Carol Davis" in result
def test_run_filtered_query(self, sample_table_setup):
"""Test executing a filtered SELECT query."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run(
"SELECT name FROM employees WHERE department = 'Engineering'"
)
assert "Search Results:" in result
assert "Alice Smith" in result
assert "Carol Davis" in result
assert "Bob Johnson" not in result
def test_run_show_query(self, sample_table_setup):
"""Test executing a SHOW query."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run("SHOW TABLES")
assert "Search Results:" in result
assert "employees" in result
assert "departments" in result
def test_run_empty_result(self, sample_table_setup):
"""Test executing a query that returns no results."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run("SELECT * FROM employees WHERE department = 'NonExistent'")
assert result == "No results found."
def test_run_invalid_query_syntax(self, sample_table_setup):
"""Test executing a query with invalid syntax."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run("SELECT * FORM employees") # Intentional typo
assert "Error executing search query:" in result
def test_run_denied_query(self, sample_table_setup):
"""Test that denied queries return appropriate error message."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run("DELETE FROM employees")
assert "Invalid search query:" in result
assert "Only SELECT and SHOW queries are supported" in result
def test_connection_pool_usage(self, sample_table_setup):
"""Test that connection pooling works correctly."""
tool = SingleStoreSearchTool(
host=sample_table_setup,
database="test_crewai",
pool_size=2,
)
# Execute multiple queries to test pool usage
results = []
for _ in range(5):
result = tool._run("SELECT COUNT(*) FROM employees")
results.append(result)
# All queries should succeed
for result in results:
assert "Search Results:" in result
assert "3" in result # Count of employees
def test_tool_schema_validation(self):
"""Test that the tool schema validation works correctly."""
# Valid input
valid_input = SingleStoreSearchToolSchema(search_query="SELECT * FROM test")
assert valid_input.search_query == "SELECT * FROM test"
# Test that description is present
schema_dict = SingleStoreSearchToolSchema.model_json_schema()
assert "search_query" in schema_dict["properties"]
assert "description" in schema_dict["properties"]["search_query"]
def test_connection_error_handling(self):
"""Test handling of connection errors."""
with pytest.raises(Exception):
# This should fail due to invalid connection parameters
SingleStoreSearchTool(
host="invalid_host",
port=9999,
user="invalid_user",
password="invalid_password",
database="invalid_db",
)


@@ -0,0 +1,102 @@
import asyncio
from unittest.mock import MagicMock, patch
from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
import pytest
# Unit Test Fixtures
@pytest.fixture
def mock_snowflake_connection():
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.description = [("col1",), ("col2",)]
mock_cursor.fetchall.return_value = [(1, "value1"), (2, "value2")]
mock_cursor.execute.return_value = None
mock_conn.cursor.return_value = mock_cursor
return mock_conn
@pytest.fixture
def mock_config():
return SnowflakeConfig(
account="test_account",
user="test_user",
password="test_password",
warehouse="test_warehouse",
database="test_db",
snowflake_schema="test_schema",
)
@pytest.fixture
def snowflake_tool(mock_config):
with patch("snowflake.connector.connect"):
tool = SnowflakeSearchTool(config=mock_config)
yield tool
# Unit Tests
@pytest.mark.asyncio
async def test_successful_query_execution(snowflake_tool, mock_snowflake_connection):
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
mock_create_conn.return_value = mock_snowflake_connection
results = await snowflake_tool._run(
query="SELECT * FROM test_table", timeout=300
)
assert len(results) == 2
assert results[0]["col1"] == 1
assert results[0]["col2"] == "value1"
mock_snowflake_connection.cursor.assert_called_once()
@pytest.mark.asyncio
async def test_connection_pooling(snowflake_tool, mock_snowflake_connection):
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
mock_create_conn.return_value = mock_snowflake_connection
# Execute multiple queries
await asyncio.gather(
snowflake_tool._run("SELECT 1"),
snowflake_tool._run("SELECT 2"),
snowflake_tool._run("SELECT 3"),
)
# Should reuse connections from pool
assert mock_create_conn.call_count <= snowflake_tool.pool_size
@pytest.mark.asyncio
async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection):
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
mock_create_conn.return_value = mock_snowflake_connection
# Add connection to pool
await snowflake_tool._get_connection()
# Return connection to pool
async with snowflake_tool._pool_lock:
snowflake_tool._connection_pool.append(mock_snowflake_connection)
# Trigger cleanup
snowflake_tool.__del__()
mock_snowflake_connection.close.assert_called_once()
def test_config_validation():
# Test missing required fields
with pytest.raises(ValueError):
SnowflakeConfig()
# Test invalid account format
with pytest.raises(ValueError):
SnowflakeConfig(
account="invalid//account", user="test_user", password="test_pass"
)
# Test missing authentication
with pytest.raises(ValueError):
SnowflakeConfig(account="test_account", user="test_user")


@@ -0,0 +1,281 @@
import sys
from unittest.mock import MagicMock, patch
import pytest
# Create mock classes that will be used by our fixture
class MockStagehandModule:
def __init__(self):
self.Stagehand = MagicMock()
self.StagehandConfig = MagicMock()
self.StagehandPage = MagicMock()
class MockStagehandSchemas:
def __init__(self):
self.ActOptions = MagicMock()
self.ExtractOptions = MagicMock()
self.ObserveOptions = MagicMock()
self.AvailableModel = MagicMock()
class MockStagehandUtils:
def __init__(self):
self.configure_logging = MagicMock()
@pytest.fixture(scope="module", autouse=True)
def mock_stagehand_modules():
"""Mock stagehand modules at the start of this test module."""
# Store original modules if they exist
original_modules = {}
for module_name in ["stagehand", "stagehand.schemas", "stagehand.utils"]:
if module_name in sys.modules:
original_modules[module_name] = sys.modules[module_name]
# Create and inject mock modules
mock_stagehand = MockStagehandModule()
mock_stagehand_schemas = MockStagehandSchemas()
mock_stagehand_utils = MockStagehandUtils()
sys.modules["stagehand"] = mock_stagehand
sys.modules["stagehand.schemas"] = mock_stagehand_schemas
sys.modules["stagehand.utils"] = mock_stagehand_utils
# Import after mocking
from crewai_tools.tools.stagehand_tool.stagehand_tool import (
StagehandResult,
StagehandTool,
)
# Make these available to tests in this module
sys.modules[__name__].StagehandResult = StagehandResult
sys.modules[__name__].StagehandTool = StagehandTool
yield
# Restore original modules
for module_name, module in original_modules.items():
sys.modules[module_name] = module
class MockStagehandPage(MagicMock):
def act(self, options):
mock_result = MagicMock()
mock_result.model_dump.return_value = {
"message": "Action completed successfully"
}
return mock_result
def goto(self, url):
return MagicMock()
def extract(self, options):
mock_result = MagicMock()
mock_result.model_dump.return_value = {
"data": "Extracted content",
"metadata": {"source": "test"},
}
return mock_result
def observe(self, options):
result1 = MagicMock()
result1.description = "Button element"
result1.method = "click"
result2 = MagicMock()
result2.description = "Input field"
result2.method = "type"
return [result1, result2]
class MockStagehand(MagicMock):
def init(self):
self.session_id = "test-session-id"
self.page = MockStagehandPage()
def close(self):
pass
@pytest.fixture
def mock_stagehand_instance():
with patch(
"crewai_tools.tools.stagehand_tool.stagehand_tool.Stagehand",
return_value=MockStagehand(),
) as mock:
yield mock
@pytest.fixture
def stagehand_tool():
return StagehandTool(
api_key="test_api_key",
project_id="test_project_id",
model_api_key="test_model_api_key",
_testing=True, # Enable testing mode to bypass dependency check
)
def test_stagehand_tool_initialization():
"""Test that the StagehandTool initializes with the correct default values."""
tool = StagehandTool(
api_key="test_api_key",
project_id="test_project_id",
model_api_key="test_model_api_key",
_testing=True, # Enable testing mode
)
assert tool.api_key == "test_api_key"
assert tool.project_id == "test_project_id"
assert tool.model_api_key == "test_model_api_key"
assert tool.headless is False
assert tool.dom_settle_timeout_ms == 3000
assert tool.self_heal is True
assert tool.wait_for_captcha_solves is True
@patch(
"crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True
)
def test_act_command(mock_run, stagehand_tool):
"""Test the 'act' command functionality."""
# Setup mock
mock_run.return_value = "Action result: Action completed successfully"
# Run the tool
result = stagehand_tool._run(
instruction="Click the submit button", command_type="act"
)
# Assertions
assert "Action result" in result
assert "Action completed successfully" in result
@patch(
"crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True
)
def test_navigate_command(mock_run, stagehand_tool):
"""Test the 'navigate' command functionality."""
# Setup mock
mock_run.return_value = "Successfully navigated to https://example.com"
# Run the tool
result = stagehand_tool._run(
instruction="Go to example.com",
url="https://example.com",
command_type="navigate",
)
# Assertions
assert "https://example.com" in result
@patch(
"crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True
)
def test_extract_command(mock_run, stagehand_tool):
"""Test the 'extract' command functionality."""
# Setup mock
mock_run.return_value = (
'Extracted data: {"data": "Extracted content", "metadata": {"source": "test"}}'
)
# Run the tool
result = stagehand_tool._run(
instruction="Extract all product names and prices", command_type="extract"
)
# Assertions
assert "Extracted data" in result
assert "Extracted content" in result
@patch(
"crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True
)
def test_observe_command(mock_run, stagehand_tool):
"""Test the 'observe' command functionality."""
# Setup mock
mock_run.return_value = "Element 1: Button element\nSuggested action: click\nElement 2: Input field\nSuggested action: type"
# Run the tool
result = stagehand_tool._run(
instruction="Find all interactive elements", command_type="observe"
)
# Assertions
assert "Element 1: Button element" in result
assert "Element 2: Input field" in result
assert "Suggested action: click" in result
assert "Suggested action: type" in result
@patch(
"crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True
)
def test_error_handling(mock_run, stagehand_tool):
"""Test error handling in the tool."""
# Setup mock
mock_run.return_value = "Error: Browser automation error"
# Run the tool
result = stagehand_tool._run(
instruction="Click a non-existent button", command_type="act"
)
# Assertions
assert "Error:" in result
assert "Browser automation error" in result
def test_initialization_parameters():
"""Test that the StagehandTool initializes with the correct parameters."""
# Create tool with custom parameters
tool = StagehandTool(
api_key="custom_api_key",
project_id="custom_project_id",
model_api_key="custom_model_api_key",
headless=True,
dom_settle_timeout_ms=5000,
self_heal=False,
wait_for_captcha_solves=False,
verbose=3,
_testing=True, # Enable testing mode
)
# Verify the tool was initialized with the correct parameters
assert tool.api_key == "custom_api_key"
assert tool.project_id == "custom_project_id"
assert tool.model_api_key == "custom_model_api_key"
assert tool.headless is True
assert tool.dom_settle_timeout_ms == 5000
assert tool.self_heal is False
assert tool.wait_for_captcha_solves is False
assert tool.verbose == 3
def test_close_method():
"""Test that the close method cleans up resources correctly."""
# Create the tool with testing mode
tool = StagehandTool(
api_key="test_api_key",
project_id="test_project_id",
model_api_key="test_model_api_key",
_testing=True,
)
# Setup mock stagehand instance
tool._stagehand = MagicMock()
tool._stagehand.close = MagicMock() # Non-async mock
tool._page = MagicMock()
# Call the close method
tool.close()
# Verify resources were cleaned up
assert tool._stagehand is None
assert tool._page is None


@@ -0,0 +1,174 @@
from unittest.mock import patch
from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import (
CodeInterpreterTool,
SandboxPython,
)
import pytest
@pytest.fixture
def printer_mock():
with patch("crewai_tools.printer.Printer.print") as mock:
yield mock
@pytest.fixture
def docker_unavailable_mock():
with patch(
"crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.CodeInterpreterTool._check_docker_available",
return_value=False,
) as mock:
yield mock
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
def test_run_code_in_docker(docker_mock, printer_mock):
tool = CodeInterpreterTool()
code = "print('Hello, World!')"
libraries_used = ["numpy", "pandas"]
expected_output = "Hello, World!\n"
docker_mock().containers.run().exec_run().exit_code = 0
docker_mock().containers.run().exec_run().output = expected_output.encode()
result = tool.run_code_in_docker(code, libraries_used)
assert result == expected_output
printer_mock.assert_called_with(
"Running code in Docker environment", color="bold_blue"
)
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
def test_run_code_in_docker_with_error(docker_mock, printer_mock):
tool = CodeInterpreterTool()
code = "print(1/0)"
libraries_used = ["numpy", "pandas"]
expected_output = "Something went wrong while running the code: \nZeroDivisionError: division by zero\n"
docker_mock().containers.run().exec_run().exit_code = 1
docker_mock().containers.run().exec_run().output = (
b"ZeroDivisionError: division by zero\n"
)
result = tool.run_code_in_docker(code, libraries_used)
assert result == expected_output
printer_mock.assert_called_with(
"Running code in Docker environment", color="bold_blue"
)
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
def test_run_code_in_docker_with_script(docker_mock, printer_mock):
tool = CodeInterpreterTool()
code = """print("This is line 1")
print("This is line 2")"""
libraries_used = []
expected_output = "This is line 1\nThis is line 2\n"
docker_mock().containers.run().exec_run().exit_code = 0
docker_mock().containers.run().exec_run().output = expected_output.encode()
result = tool.run_code_in_docker(code, libraries_used)
assert result == expected_output
printer_mock.assert_called_with(
"Running code in Docker environment", color="bold_blue"
)
def test_restricted_sandbox_basic_code_execution(printer_mock, docker_unavailable_mock):
"""Test basic code execution."""
tool = CodeInterpreterTool()
code = """
result = 2 + 2
print(result)
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"Running code in restricted sandbox", color="yellow"
)
assert result == 4
def test_restricted_sandbox_running_with_blocked_modules(
printer_mock, docker_unavailable_mock
):
"""Test that restricted modules cannot be imported."""
tool = CodeInterpreterTool()
restricted_modules = SandboxPython.BLOCKED_MODULES
for module in restricted_modules:
code = f"""
import {module}
result = "Import succeeded"
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"Running code in restricted sandbox", color="yellow"
)
assert f"An error occurred: Importing '{module}' is not allowed" in result
def test_restricted_sandbox_running_with_blocked_builtins(
printer_mock, docker_unavailable_mock
):
"""Test that restricted builtins are not available."""
tool = CodeInterpreterTool()
restricted_builtins = SandboxPython.UNSAFE_BUILTINS
for builtin in restricted_builtins:
code = f"""
{builtin}("test")
result = "Builtin available"
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"Running code in restricted sandbox", color="yellow"
)
assert f"An error occurred: name '{builtin}' is not defined" in result
def test_restricted_sandbox_running_with_no_result_variable(
printer_mock, docker_unavailable_mock
):
"""Test behavior when no result variable is set."""
tool = CodeInterpreterTool()
code = """
x = 10
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"Running code in restricted sandbox", color="yellow"
)
assert result == "No result variable found."
def test_unsafe_mode_running_with_no_result_variable(
printer_mock, docker_unavailable_mock
):
"""Test behavior when no result variable is set."""
tool = CodeInterpreterTool(unsafe_mode=True)
code = """
x = 10
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"WARNING: Running code in unsafe mode", color="bold_magenta"
)
assert result == "No result variable found."
def test_unsafe_mode_running_unsafe_code(printer_mock, docker_unavailable_mock):
"""Test behavior when no result variable is set."""
tool = CodeInterpreterTool(unsafe_mode=True)
code = """
import os
os.system("ls -la")
result = eval("5/1")
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"WARNING: Running code in unsafe mode", color="bold_magenta"
)
assert 5.0 == result


@@ -0,0 +1,137 @@
import os
import shutil
import tempfile
from crewai_tools.tools.file_writer_tool.file_writer_tool import FileWriterTool
import pytest
@pytest.fixture
def tool():
return FileWriterTool()
@pytest.fixture
def temp_env():
temp_dir = tempfile.mkdtemp()
test_file = "test.txt"
test_content = "Hello, World!"
yield {
"temp_dir": temp_dir,
"test_file": test_file,
"test_content": test_content,
}
shutil.rmtree(temp_dir, ignore_errors=True)
def get_test_path(filename, directory):
return os.path.join(directory, filename)
def read_file(path):
with open(path, "r") as f:
return f.read()
def test_basic_file_write(tool, temp_env):
result = tool._run(
filename=temp_env["test_file"],
directory=temp_env["temp_dir"],
content=temp_env["test_content"],
overwrite=True,
)
path = get_test_path(temp_env["test_file"], temp_env["temp_dir"])
assert os.path.exists(path)
assert read_file(path) == temp_env["test_content"]
assert "successfully written" in result
def test_directory_creation(tool, temp_env):
new_dir = os.path.join(temp_env["temp_dir"], "nested_dir")
result = tool._run(
filename=temp_env["test_file"],
directory=new_dir,
content=temp_env["test_content"],
overwrite=True,
)
path = get_test_path(temp_env["test_file"], new_dir)
assert os.path.exists(new_dir)
assert os.path.exists(path)
assert "successfully written" in result
@pytest.mark.parametrize(
"overwrite",
["y", "yes", "t", "true", "on", "1", True],
)
def test_overwrite_true(tool, temp_env, overwrite):
path = get_test_path(temp_env["test_file"], temp_env["temp_dir"])
with open(path, "w") as f:
f.write("Original content")
result = tool._run(
filename=temp_env["test_file"],
directory=temp_env["temp_dir"],
content="New content",
overwrite=overwrite,
)
assert read_file(path) == "New content"
assert "successfully written" in result
def test_invalid_overwrite_value(tool, temp_env):
result = tool._run(
filename=temp_env["test_file"],
directory=temp_env["temp_dir"],
content=temp_env["test_content"],
overwrite="invalid",
)
assert "invalid value" in result
def test_missing_required_fields(tool, temp_env):
result = tool._run(
directory=temp_env["temp_dir"],
content=temp_env["test_content"],
overwrite=True,
)
assert "An error occurred while accessing key: 'filename'" in result
def test_empty_content(tool, temp_env):
result = tool._run(
filename=temp_env["test_file"],
directory=temp_env["temp_dir"],
content="",
overwrite=True,
)
path = get_test_path(temp_env["test_file"], temp_env["temp_dir"])
assert os.path.exists(path)
assert read_file(path) == ""
assert "successfully written" in result
@pytest.mark.parametrize(
"overwrite",
["n", "no", "f", "false", "off", "0", False],
)
def test_file_exists_error_handling(tool, temp_env, overwrite):
path = get_test_path(temp_env["test_file"], temp_env["temp_dir"])
with open(path, "w") as f:
f.write("Pre-existing content")
result = tool._run(
filename=temp_env["test_file"],
directory=temp_env["temp_dir"],
content="Should not be written",
overwrite=overwrite,
)
assert "already exists and overwrite option was not passed" in result
assert read_file(path) == "Pre-existing content"


@@ -0,0 +1,10 @@
from pydantic.warnings import PydanticDeprecatedSince20
import pytest
@pytest.mark.filterwarnings("error", category=PydanticDeprecatedSince20)
def test_import_tools_without_pydantic_deprecation_warnings():
    # Ensure that importing crewai_tools raises no Pydantic deprecation warnings.
import crewai_tools
assert crewai_tools


@@ -0,0 +1,74 @@
import json
from unittest.mock import patch
from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool
import pytest
# Unit Test Fixtures
@pytest.fixture
def mongodb_vector_search_tool():
tool = MongoDBVectorSearchTool(
connection_string="foo", database_name="bar", collection_name="test"
)
tool._embed_texts = lambda x: [[0.1]]
yield tool
# Unit Tests
def test_successful_query_execution(mongodb_vector_search_tool):
# Enable embedding
with patch.object(mongodb_vector_search_tool._coll, "aggregate") as mock_aggregate:
mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
results = json.loads(mongodb_vector_search_tool._run(query="sandwiches"))
assert len(results) == 1
assert results[0]["text"] == "foo"
assert results[0]["_id"] == 1
def test_provide_config():
query_config = MongoDBVectorSearchConfig(limit=10)
tool = MongoDBVectorSearchTool(
connection_string="foo",
database_name="bar",
collection_name="test",
query_config=query_config,
vector_index_name="foo",
embedding_model="bar",
)
tool._embed_texts = lambda x: [[0.1]]
with patch.object(tool._coll, "aggregate") as mock_aggregate:
mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
tool._run(query="sandwiches")
assert mock_aggregate.mock_calls[-1].args[0][0]["$vectorSearch"]["limit"] == 10
def test_cleanup_on_deletion(mongodb_vector_search_tool):
with patch.object(mongodb_vector_search_tool, "_client") as mock_client:
# Trigger cleanup
mongodb_vector_search_tool.__del__()
mock_client.close.assert_called_once()
def test_create_search_index(mongodb_vector_search_tool):
with patch(
"crewai_tools.tools.mongodb_vector_search_tool.vector_search.create_vector_search_index"
) as mock_create_search_index:
mongodb_vector_search_tool.create_vector_search_index(dimensions=10)
kwargs = mock_create_search_index.mock_calls[0].kwargs
assert kwargs["dimensions"] == 10
assert kwargs["similarity"] == "cosine"
def test_add_texts(mongodb_vector_search_tool):
with patch.object(mongodb_vector_search_tool._coll, "bulk_write") as bulk_write:
mongodb_vector_search_tool.add_texts(["foo"])
args = bulk_write.mock_calls[0].args
assert "ReplaceOne" in str(args[0][0])
assert "foo" in str(args[0][0])


@@ -0,0 +1,161 @@
import json
import os
from unittest.mock import MagicMock
from crewai.tools.base_tool import BaseTool
from crewai_tools import (
OxylabsAmazonProductScraperTool,
OxylabsAmazonSearchScraperTool,
OxylabsGoogleSearchScraperTool,
OxylabsUniversalScraperTool,
)
from crewai_tools.tools.oxylabs_amazon_product_scraper_tool.oxylabs_amazon_product_scraper_tool import (
OxylabsAmazonProductScraperConfig,
)
from crewai_tools.tools.oxylabs_google_search_scraper_tool.oxylabs_google_search_scraper_tool import (
OxylabsGoogleSearchScraperConfig,
)
from oxylabs import RealtimeClient
from oxylabs.sources.response import Response as OxylabsResponse
from pydantic import BaseModel
import pytest
@pytest.fixture
def oxylabs_api() -> RealtimeClient:
oxylabs_api_mock = MagicMock()
html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Scraping Sandbox</title>
</head>
<body>
<div id="main">
<div id="product-list">
<div>
<p>Amazing product</p>
<p>Price $14.99</p>
</div>
<div>
<p>Good product</p>
<p>Price $9.99</p>
</div>
</div>
</div>
</body>
</html>
"""
json_content = {
"results": {
"products": [
{"title": "Amazing product", "price": 14.99, "currency": "USD"},
{"title": "Good product", "price": 9.99, "currency": "USD"},
],
},
}
html_response = OxylabsResponse({"results": [{"content": html_content}]})
json_response = OxylabsResponse({"results": [{"content": json_content}]})
oxylabs_api_mock.universal.scrape_url.side_effect = [json_response, html_response]
oxylabs_api_mock.amazon.scrape_search.side_effect = [json_response, html_response]
oxylabs_api_mock.amazon.scrape_product.side_effect = [json_response, html_response]
oxylabs_api_mock.google.scrape_search.side_effect = [json_response, html_response]
return oxylabs_api_mock
@pytest.mark.parametrize(
("tool_class",),
[
(OxylabsUniversalScraperTool,),
(OxylabsAmazonSearchScraperTool,),
(OxylabsGoogleSearchScraperTool,),
(OxylabsAmazonProductScraperTool,),
],
)
def test_tool_initialization(tool_class: type[BaseTool]):
tool = tool_class(username="username", password="password")
assert isinstance(tool, tool_class)
@pytest.mark.parametrize(
("tool_class",),
[
(OxylabsUniversalScraperTool,),
(OxylabsAmazonSearchScraperTool,),
(OxylabsGoogleSearchScraperTool,),
(OxylabsAmazonProductScraperTool,),
],
)
def test_tool_initialization_with_env_vars(tool_class: type[BaseTool]):
os.environ["OXYLABS_USERNAME"] = "username"
os.environ["OXYLABS_PASSWORD"] = "password"
tool = tool_class()
assert isinstance(tool, tool_class)
del os.environ["OXYLABS_USERNAME"]
del os.environ["OXYLABS_PASSWORD"]
@pytest.mark.parametrize(
("tool_class",),
[
(OxylabsUniversalScraperTool,),
(OxylabsAmazonSearchScraperTool,),
(OxylabsGoogleSearchScraperTool,),
(OxylabsAmazonProductScraperTool,),
],
)
def test_tool_initialization_failure(tool_class: type[BaseTool]):
# making sure env vars are not set
for key in ["OXYLABS_USERNAME", "OXYLABS_PASSWORD"]:
if key in os.environ:
del os.environ[key]
with pytest.raises(ValueError):
tool_class()
@pytest.mark.parametrize(
("tool_class", "tool_config"),
[
(OxylabsUniversalScraperTool, {"geo_location": "Paris, France"}),
(
OxylabsAmazonSearchScraperTool,
{"domain": "co.uk"},
),
(
OxylabsGoogleSearchScraperTool,
OxylabsGoogleSearchScraperConfig(render="html"),
),
(
OxylabsAmazonProductScraperTool,
OxylabsAmazonProductScraperConfig(parse=True),
),
],
)
def test_tool_invocation(
tool_class: type[BaseTool],
tool_config: BaseModel,
oxylabs_api: RealtimeClient,
):
tool = tool_class(username="username", password="password", config=tool_config)
# setting via __dict__ to bypass pydantic validation
tool.__dict__["oxylabs_api"] = oxylabs_api
# verifying parsed job returns json content
result = tool.run("Scraping Query 1")
assert isinstance(result, str)
assert isinstance(json.loads(result), dict)
# verifying raw job returns str
result = tool.run("Scraping Query 2")
assert isinstance(result, str)
assert "<!DOCTYPE html>" in result


@@ -0,0 +1,352 @@
import os
from pathlib import Path
import tempfile
from unittest.mock import MagicMock
from crewai_tools.rag.data_types import DataType
from crewai_tools.tools import (
CSVSearchTool,
CodeDocsSearchTool,
DOCXSearchTool,
DirectorySearchTool,
GithubSearchTool,
JSONSearchTool,
MDXSearchTool,
PDFSearchTool,
TXTSearchTool,
WebsiteSearchTool,
XMLSearchTool,
YoutubeChannelSearchTool,
YoutubeVideoSearchTool,
)
from crewai_tools.tools.rag.rag_tool import Adapter
import pytest
pytestmark = [pytest.mark.vcr(filter_headers=["authorization"])]
@pytest.fixture
def mock_adapter():
mock_adapter = MagicMock(spec=Adapter)
return mock_adapter
def test_directory_search_tool():
with tempfile.TemporaryDirectory() as temp_dir:
test_file = Path(temp_dir) / "test.txt"
test_file.write_text("This is a test file for directory search")
tool = DirectorySearchTool(directory=temp_dir)
result = tool._run(search_query="test file")
assert "test file" in result.lower()
def test_pdf_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
tool = PDFSearchTool(pdf="test.pdf", adapter=mock_adapter)
result = tool._run(query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
mock_adapter.query.assert_called_once_with(
"test content", similarity_threshold=0.6, limit=5
)
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = PDFSearchTool(adapter=mock_adapter)
result = tool._run(pdf="test.pdf", query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
mock_adapter.query.assert_called_once_with(
"test content", similarity_threshold=0.6, limit=5
)
def test_txt_search_tool():
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_file:
temp_file.write(b"This is a test file for txt search")
temp_file_path = temp_file.name
try:
tool = TXTSearchTool()
tool.add(temp_file_path)
result = tool._run(search_query="test file")
assert "test file" in result.lower()
finally:
os.unlink(temp_file_path)
def test_docx_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
tool = DOCXSearchTool(docx="test.docx", adapter=mock_adapter)
result = tool._run(search_query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
mock_adapter.query.assert_called_once_with(
"test content", similarity_threshold=0.6, limit=5
)
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = DOCXSearchTool(adapter=mock_adapter)
result = tool._run(docx="test.docx", search_query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
mock_adapter.query.assert_called_once_with(
"test content", similarity_threshold=0.6, limit=5
)
def test_json_search_tool():
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
temp_file.write(b'{"test": "This is a test JSON file"}')
temp_file_path = temp_file.name
try:
tool = JSONSearchTool()
result = tool._run(search_query="test JSON", json_path=temp_file_path)
assert "test json" in result.lower()
finally:
os.unlink(temp_file_path)
def test_xml_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
tool = XMLSearchTool(adapter=mock_adapter)
result = tool._run(search_query="test XML", xml="test.xml")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.xml")
mock_adapter.query.assert_called_once_with(
"test XML", similarity_threshold=0.6, limit=5
)
def test_csv_search_tool():
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp_file:
temp_file.write(b"name,description\ntest,This is a test CSV file")
temp_file_path = temp_file.name
try:
tool = CSVSearchTool()
tool.add(temp_file_path)
result = tool._run(search_query="test CSV")
assert "test csv" in result.lower()
finally:
os.unlink(temp_file_path)
def test_mdx_search_tool():
with tempfile.NamedTemporaryFile(suffix=".mdx", delete=False) as temp_file:
temp_file.write(b"# Test MDX\nThis is a test MDX file")
temp_file_path = temp_file.name
try:
tool = MDXSearchTool()
tool.add(temp_file_path)
result = tool._run(search_query="test MDX")
assert "test mdx" in result.lower()
finally:
os.unlink(temp_file_path)
def test_website_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
website = "https://crewai.com"
search_query = "what is crewai?"
tool = WebsiteSearchTool(website=website, adapter=mock_adapter)
result = tool._run(search_query=search_query)
mock_adapter.query.assert_called_once_with(
"what is crewai?", similarity_threshold=0.6, limit=5
)
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
assert "this is a test" in result.lower()
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = WebsiteSearchTool(adapter=mock_adapter)
result = tool._run(website=website, search_query=search_query)
mock_adapter.query.assert_called_once_with(
"what is crewai?", similarity_threshold=0.6, limit=5
)
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
assert "this is a test" in result.lower()


def test_youtube_video_search_tool(mock_adapter):
mock_adapter.query.return_value = "some video description"
youtube_video_url = "https://www.youtube.com/watch?v=sample-video-id"
search_query = "what is the video about?"
tool = YoutubeVideoSearchTool(
youtube_video_url=youtube_video_url,
adapter=mock_adapter,
)
result = tool._run(search_query=search_query)
assert "some video description" in result
mock_adapter.add.assert_called_once_with(
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
)
mock_adapter.query.assert_called_once_with(
search_query, similarity_threshold=0.6, limit=5
)
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = YoutubeVideoSearchTool(adapter=mock_adapter)
result = tool._run(youtube_video_url=youtube_video_url, search_query=search_query)
assert "some video description" in result
mock_adapter.add.assert_called_once_with(
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
)
mock_adapter.query.assert_called_once_with(
search_query, similarity_threshold=0.6, limit=5
)


def test_youtube_channel_search_tool(mock_adapter):
mock_adapter.query.return_value = "channel description"
youtube_channel_handle = "@crewai"
search_query = "what is the channel about?"
tool = YoutubeChannelSearchTool(
youtube_channel_handle=youtube_channel_handle, adapter=mock_adapter
)
result = tool._run(search_query=search_query)
assert "channel description" in result
mock_adapter.add.assert_called_once_with(
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
)
mock_adapter.query.assert_called_once_with(
search_query, similarity_threshold=0.6, limit=5
)
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = YoutubeChannelSearchTool(adapter=mock_adapter)
result = tool._run(
youtube_channel_handle=youtube_channel_handle, search_query=search_query
)
assert "channel description" in result
mock_adapter.add.assert_called_once_with(
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
)
mock_adapter.query.assert_called_once_with(
search_query, similarity_threshold=0.6, limit=5
)


def test_code_docs_search_tool(mock_adapter):
mock_adapter.query.return_value = "test documentation"
docs_url = "https://crewai.com/any-docs-url"
search_query = "test documentation"
tool = CodeDocsSearchTool(docs_url=docs_url, adapter=mock_adapter)
result = tool._run(search_query=search_query)
assert "test documentation" in result
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
mock_adapter.query.assert_called_once_with(
search_query, similarity_threshold=0.6, limit=5
)
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = CodeDocsSearchTool(adapter=mock_adapter)
result = tool._run(docs_url=docs_url, search_query=search_query)
assert "test documentation" in result
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
mock_adapter.query.assert_called_once_with(
search_query, similarity_threshold=0.6, limit=5
)


def test_github_search_tool(mock_adapter):
mock_adapter.query.return_value = "repo description"
# ensure the provided repo and content types are used after initialization
tool = GithubSearchTool(
gh_token="test_token",
github_repo="crewai/crewai",
content_types=["code"],
adapter=mock_adapter,
)
result = tool._run(search_query="tell me about crewai repo")
assert "repo description" in result
mock_adapter.add.assert_called_once_with(
"https://github.com/crewai/crewai",
data_type=DataType.GITHUB,
metadata={"content_types": ["code"], "gh_token": "test_token"},
)
mock_adapter.query.assert_called_once_with(
"tell me about crewai repo", similarity_threshold=0.6, limit=5
)
    # ensure the content types provided in the run call are used
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
result = tool._run(
github_repo="crewai/crewai",
content_types=["code", "issue"],
search_query="tell me about crewai repo",
)
assert "repo description" in result
mock_adapter.add.assert_called_once_with(
"https://github.com/crewai/crewai",
data_type=DataType.GITHUB,
metadata={"content_types": ["code", "issue"], "gh_token": "test_token"},
)
mock_adapter.query.assert_called_once_with(
"tell me about crewai repo", similarity_threshold=0.6, limit=5
)
# ensure default content types are used if not provided
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
result = tool._run(
github_repo="crewai/crewai",
search_query="tell me about crewai repo",
)
assert "repo description" in result
mock_adapter.add.assert_called_once_with(
"https://github.com/crewai/crewai",
data_type=DataType.GITHUB,
metadata={
"content_types": ["code", "repo", "pr", "issue"],
"gh_token": "test_token",
},
)
mock_adapter.query.assert_called_once_with(
"tell me about crewai repo", similarity_threshold=0.6, limit=5
)
# ensure nothing is added if no repo is provided
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
    tool._run(search_query="tell me about crewai repo")
mock_adapter.add.assert_not_called()
mock_adapter.query.assert_called_once_with(
"tell me about crewai repo", similarity_threshold=0.6, limit=5
)


@@ -0,0 +1,231 @@
import unittest
from unittest.mock import MagicMock

from crewai.tools import BaseTool
from crewai_tools.adapters.tool_collection import ToolCollection


class TestToolCollection(unittest.TestCase):
def setUp(self):
        self.search_tool = self._create_mock_tool(
            "SearcH", "Search Tool"
        )  # mixed casing: stored names keep their case, but name lookup is case-insensitive
self.calculator_tool = self._create_mock_tool("calculator", "Calculator Tool")
self.translator_tool = self._create_mock_tool("translator", "Translator Tool")
self.tools = ToolCollection(
[self.search_tool, self.calculator_tool, self.translator_tool]
)

def _create_mock_tool(self, name, description):
mock_tool = MagicMock(spec=BaseTool)
mock_tool.name = name
mock_tool.description = description
return mock_tool

def test_initialization(self):
self.assertEqual(len(self.tools), 3)
self.assertEqual(self.tools[0].name, "SearcH")
self.assertEqual(self.tools[1].name, "calculator")
self.assertEqual(self.tools[2].name, "translator")

def test_empty_initialization(self):
empty_collection = ToolCollection()
self.assertEqual(len(empty_collection), 0)
self.assertEqual(empty_collection._name_cache, {})

def test_initialization_with_none(self):
collection = ToolCollection(None)
self.assertEqual(len(collection), 0)
self.assertEqual(collection._name_cache, {})

def test_access_by_index(self):
self.assertEqual(self.tools[0], self.search_tool)
self.assertEqual(self.tools[1], self.calculator_tool)
self.assertEqual(self.tools[2], self.translator_tool)

def test_access_by_name(self):
self.assertEqual(self.tools["search"], self.search_tool)
self.assertEqual(self.tools["calculator"], self.calculator_tool)
self.assertEqual(self.tools["translator"], self.translator_tool)

def test_key_error_for_invalid_name(self):
with self.assertRaises(KeyError):
_ = self.tools["nonexistent"]

def test_index_error_for_invalid_index(self):
with self.assertRaises(IndexError):
_ = self.tools[10]

def test_negative_index(self):
self.assertEqual(self.tools[-1], self.translator_tool)
self.assertEqual(self.tools[-2], self.calculator_tool)
self.assertEqual(self.tools[-3], self.search_tool)

def test_append(self):
new_tool = self._create_mock_tool("new", "New Tool")
self.tools.append(new_tool)
self.assertEqual(len(self.tools), 4)
self.assertEqual(self.tools[3], new_tool)
self.assertEqual(self.tools["new"], new_tool)
self.assertIn("new", self.tools._name_cache)

def test_append_duplicate_name(self):
duplicate_tool = self._create_mock_tool("search", "Duplicate Search Tool")
self.tools.append(duplicate_tool)
self.assertEqual(len(self.tools), 4)
self.assertEqual(self.tools["search"], duplicate_tool)

def test_extend(self):
new_tools = [
self._create_mock_tool("tool4", "Tool 4"),
self._create_mock_tool("tool5", "Tool 5"),
]
self.tools.extend(new_tools)
self.assertEqual(len(self.tools), 5)
self.assertEqual(self.tools["tool4"], new_tools[0])
self.assertEqual(self.tools["tool5"], new_tools[1])
self.assertIn("tool4", self.tools._name_cache)
self.assertIn("tool5", self.tools._name_cache)

def test_insert(self):
new_tool = self._create_mock_tool("inserted", "Inserted Tool")
self.tools.insert(1, new_tool)
self.assertEqual(len(self.tools), 4)
self.assertEqual(self.tools[1], new_tool)
self.assertEqual(self.tools["inserted"], new_tool)
self.assertIn("inserted", self.tools._name_cache)

def test_remove(self):
self.tools.remove(self.calculator_tool)
self.assertEqual(len(self.tools), 2)
with self.assertRaises(KeyError):
_ = self.tools["calculator"]
self.assertNotIn("calculator", self.tools._name_cache)

def test_remove_nonexistent_tool(self):
nonexistent_tool = self._create_mock_tool("nonexistent", "Nonexistent Tool")
with self.assertRaises(ValueError):
self.tools.remove(nonexistent_tool)

def test_pop(self):
popped = self.tools.pop(1)
self.assertEqual(popped, self.calculator_tool)
self.assertEqual(len(self.tools), 2)
with self.assertRaises(KeyError):
_ = self.tools["calculator"]
self.assertNotIn("calculator", self.tools._name_cache)

def test_pop_last(self):
popped = self.tools.pop()
self.assertEqual(popped, self.translator_tool)
self.assertEqual(len(self.tools), 2)
with self.assertRaises(KeyError):
_ = self.tools["translator"]
self.assertNotIn("translator", self.tools._name_cache)

def test_clear(self):
self.tools.clear()
self.assertEqual(len(self.tools), 0)
self.assertEqual(self.tools._name_cache, {})
with self.assertRaises(KeyError):
_ = self.tools["search"]

def test_iteration(self):
tools_list = list(self.tools)
self.assertEqual(
tools_list, [self.search_tool, self.calculator_tool, self.translator_tool]
)

def test_contains(self):
self.assertIn(self.search_tool, self.tools)
self.assertIn(self.calculator_tool, self.tools)
self.assertIn(self.translator_tool, self.tools)
nonexistent_tool = self._create_mock_tool("nonexistent", "Nonexistent Tool")
self.assertNotIn(nonexistent_tool, self.tools)

def test_slicing(self):
slice_result = self.tools[1:3]
self.assertEqual(len(slice_result), 2)
self.assertEqual(slice_result[0], self.calculator_tool)
self.assertEqual(slice_result[1], self.translator_tool)
self.assertIsInstance(slice_result, list)
self.assertNotIsInstance(slice_result, ToolCollection)

def test_getitem_with_tool_name_as_int(self):
numeric_name_tool = self._create_mock_tool("123", "Numeric Name Tool")
self.tools.append(numeric_name_tool)
self.assertEqual(self.tools["123"], numeric_name_tool)
with self.assertRaises(IndexError):
_ = self.tools[123]

def test_filter_by_names(self):
filtered = self.tools.filter_by_names(None)
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 3)
filtered = self.tools.filter_by_names(["search", "translator"])
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 2)
self.assertEqual(filtered[0], self.search_tool)
self.assertEqual(filtered[1], self.translator_tool)
self.assertEqual(filtered["search"], self.search_tool)
self.assertEqual(filtered["translator"], self.translator_tool)
filtered = self.tools.filter_by_names(["search", "nonexistent"])
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 1)
self.assertEqual(filtered[0], self.search_tool)
filtered = self.tools.filter_by_names(["nonexistent1", "nonexistent2"])
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 0)
filtered = self.tools.filter_by_names([])
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 0)

def test_filter_where(self):
filtered = self.tools.filter_where(lambda tool: tool.name.startswith("S"))
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 1)
self.assertEqual(filtered[0], self.search_tool)
self.assertEqual(filtered["search"], self.search_tool)
filtered = self.tools.filter_where(lambda tool: True)
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 3)
self.assertEqual(filtered[0], self.search_tool)
self.assertEqual(filtered[1], self.calculator_tool)
self.assertEqual(filtered[2], self.translator_tool)
filtered = self.tools.filter_where(lambda tool: False)
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 0)
filtered = self.tools.filter_where(lambda tool: len(tool.name) > 8)
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 2)
self.assertEqual(filtered[0], self.calculator_tool)
self.assertEqual(filtered[1], self.translator_tool)