Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-26 00:28:13 +00:00)

Squashed 'packages/tools/' content from commit 78317b9c

git-subtree-dir: packages/tools
git-subtree-split: 78317b9c127f18bd040c1d77e3c0840cdc9a5b38
tests/tools/__init__.py (new file, 0 lines)
tests/tools/brave_search_tool_test.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from unittest.mock import patch

import pytest

from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool


@pytest.fixture
def brave_tool():
    return BraveSearchTool(n_results=2)


def test_brave_tool_initialization():
    tool = BraveSearchTool()
    assert tool.n_results == 10
    assert tool.save_file is False


@patch("requests.get")
def test_brave_tool_search(mock_get, brave_tool):
    mock_response = {
        "web": {
            "results": [
                {
                    "title": "Test Title",
                    "url": "http://test.com",
                    "description": "Test Description",
                }
            ]
        }
    }
    mock_get.return_value.json.return_value = mock_response

    result = brave_tool.run(search_query="test")
    assert "Test Title" in result
    assert "http://test.com" in result


def test_brave_tool():
    tool = BraveSearchTool(
        n_results=2,
    )
    x = tool.run(search_query="ChatGPT")
    print(x)


if __name__ == "__main__":
    test_brave_tool()
    test_brave_tool_initialization()
    # test_brave_tool_search(brave_tool)
tests/tools/brightdata_serp_tool_test.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import unittest
from unittest.mock import MagicMock, patch

from crewai_tools.tools.brightdata_tool.brightdata_serp import BrightDataSearchTool


class TestBrightDataSearchTool(unittest.TestCase):
    @patch.dict(
        "os.environ",
        {"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
    )
    def setUp(self):
        self.tool = BrightDataSearchTool()

    @patch("requests.post")
    def test_run_successful_search(self, mock_post):
        # Sample mock JSON response
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.text = "mock response text"
        mock_post.return_value = mock_response

        # Define search input
        input_data = {
            "query": "latest AI news",
            "search_engine": "google",
            "country": "us",
            "language": "en",
            "search_type": "nws",
            "device_type": "desktop",
            "parse_results": True,
            "save_file": False,
        }

        result = self.tool._run(**input_data)

        # Assertions
        self.assertIsInstance(result, str)  # The tool returns response.text (a string)
        mock_post.assert_called_once()

    @patch("requests.post")
    def test_run_with_request_exception(self, mock_post):
        mock_post.side_effect = Exception("Timeout")

        result = self.tool._run(query="AI", search_engine="google")
        self.assertIn("Error", result)

    def tearDown(self):
        # Clean up env vars
        pass


if __name__ == "__main__":
    unittest.main()
tests/tools/brightdata_webunlocker_tool_test.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from unittest.mock import Mock, patch

import requests

from crewai_tools.tools.brightdata_tool.brightdata_unlocker import (
    BrightDataWebUnlockerTool,
)


@patch.dict(
    "os.environ",
    {"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
)
@patch("crewai_tools.tools.brightdata_tool.brightdata_unlocker.requests.post")
def test_run_success_html(mock_post):
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.text = "<html><body>Test</body></html>"
    mock_response.raise_for_status = Mock()
    mock_post.return_value = mock_response

    tool = BrightDataWebUnlockerTool()
    result = tool._run(url="https://example.com", format="html", save_file=False)

    print(result)


@patch.dict(
    "os.environ",
    {"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
)
@patch("crewai_tools.tools.brightdata_tool.brightdata_unlocker.requests.post")
def test_run_success_json(mock_post):
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.text = "mock response text"
    mock_response.raise_for_status = Mock()
    mock_post.return_value = mock_response

    tool = BrightDataWebUnlockerTool()
    result = tool._run(url="https://example.com", format="json")

    assert isinstance(result, str)


@patch.dict(
    "os.environ",
    {"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
)
@patch("crewai_tools.tools.brightdata_tool.brightdata_unlocker.requests.post")
def test_run_http_error(mock_post):
    mock_response = Mock()
    mock_response.status_code = 403
    mock_response.text = "Forbidden"
    mock_response.raise_for_status.side_effect = requests.HTTPError(
        response=mock_response
    )
    mock_post.return_value = mock_response

    tool = BrightDataWebUnlockerTool()
    result = tool._run(url="https://example.com")

    assert "HTTP Error" in result
    assert "Forbidden" in result
File diff suppressed because one or more lines are too long (5 files)
tests/tools/couchbase_tool_test.py (new file, 365 lines)
@@ -0,0 +1,365 @@
import pytest
from unittest.mock import MagicMock, patch, ANY

# Mock the couchbase library before importing the tool
# This prevents ImportErrors if couchbase isn't installed in the test environment
mock_couchbase = MagicMock()
mock_couchbase.search = MagicMock()
mock_couchbase.cluster = MagicMock()
mock_couchbase.options = MagicMock()
mock_couchbase.vector_search = MagicMock()

# Simulate the structure needed for checks
mock_couchbase.cluster.Cluster = MagicMock()
mock_couchbase.options.SearchOptions = MagicMock()
mock_couchbase.vector_search.VectorQuery = MagicMock()
mock_couchbase.vector_search.VectorSearch = MagicMock()
mock_couchbase.search.SearchRequest = MagicMock()  # Mock the class itself
mock_couchbase.search.SearchRequest.create = MagicMock()  # Mock the class method


# Add necessary exception types if needed for testing error handling
class MockCouchbaseException(Exception):
    pass


mock_couchbase.exceptions = MagicMock()
mock_couchbase.exceptions.BucketNotFoundException = MockCouchbaseException
mock_couchbase.exceptions.ScopeNotFoundException = MockCouchbaseException
mock_couchbase.exceptions.CollectionNotFoundException = MockCouchbaseException
mock_couchbase.exceptions.IndexNotFoundException = MockCouchbaseException


import sys

sys.modules['couchbase'] = mock_couchbase
sys.modules['couchbase.search'] = mock_couchbase.search
sys.modules['couchbase.cluster'] = mock_couchbase.cluster
sys.modules['couchbase.options'] = mock_couchbase.options
sys.modules['couchbase.vector_search'] = mock_couchbase.vector_search
sys.modules['couchbase.exceptions'] = mock_couchbase.exceptions

# Now import the tool
from crewai_tools.tools.couchbase_tool.couchbase_tool import CouchbaseFTSVectorSearchTool


# --- Test Fixtures ---
@pytest.fixture(autouse=True)
def reset_global_mocks():
    """Reset call counts for globally defined mocks before each test."""
    # Reset the specific mock causing the issue
    mock_couchbase.vector_search.VectorQuery.reset_mock()
    # It's good practice to also reset other related global mocks
    # that might be called in your tests to prevent similar issues:
    mock_couchbase.vector_search.VectorSearch.from_vector_query.reset_mock()
    mock_couchbase.search.SearchRequest.create.reset_mock()


# Additional fixture to handle import pollution in full test suite
@pytest.fixture(autouse=True)
def ensure_couchbase_mocks():
    """Ensure that couchbase imports are properly mocked even when other tests have run first."""
    # This fixture ensures our mocks are in place regardless of import order
    original_modules = {}

    # Store any existing modules
    for module_name in ['couchbase', 'couchbase.search', 'couchbase.cluster', 'couchbase.options', 'couchbase.vector_search', 'couchbase.exceptions']:
        if module_name in sys.modules:
            original_modules[module_name] = sys.modules[module_name]

    # Ensure our mocks are active
    sys.modules['couchbase'] = mock_couchbase
    sys.modules['couchbase.search'] = mock_couchbase.search
    sys.modules['couchbase.cluster'] = mock_couchbase.cluster
    sys.modules['couchbase.options'] = mock_couchbase.options
    sys.modules['couchbase.vector_search'] = mock_couchbase.vector_search
    sys.modules['couchbase.exceptions'] = mock_couchbase.exceptions

    yield

    # Restore original modules if they existed
    for module_name, original_module in original_modules.items():
        if original_module is not None:
            sys.modules[module_name] = original_module


@pytest.fixture
def mock_cluster():
    cluster = MagicMock()
    bucket_manager = MagicMock()
    search_index_manager = MagicMock()
    bucket = MagicMock()
    scope = MagicMock()
    collection = MagicMock()
    scope_search_index_manager = MagicMock()

    # Setup mock return values for checks
    cluster.buckets.return_value = bucket_manager
    cluster.search_indexes.return_value = search_index_manager
    cluster.bucket.return_value = bucket
    bucket.scope.return_value = scope
    scope.collection.return_value = collection
    scope.search_indexes.return_value = scope_search_index_manager

    # Mock bucket existence check
    bucket_manager.get_bucket.return_value = True

    # Mock scope/collection existence check
    mock_scope_spec = MagicMock()
    mock_scope_spec.name = "test_scope"
    mock_collection_spec = MagicMock()
    mock_collection_spec.name = "test_collection"
    mock_scope_spec.collections = [mock_collection_spec]
    bucket.collections.return_value.get_all_scopes.return_value = [mock_scope_spec]

    # Mock index existence check
    mock_index_def = MagicMock()
    mock_index_def.name = "test_index"
    scope_search_index_manager.get_all_indexes.return_value = [mock_index_def]
    search_index_manager.get_all_indexes.return_value = [mock_index_def]

    return cluster


@pytest.fixture
def mock_embedding_function():
    # Simple mock embedding function
    # return lambda query: [0.1] * 10  # Example embedding vector
    return MagicMock(return_value=[0.1] * 10)


@pytest.fixture
def tool_config(mock_cluster, mock_embedding_function):
    return {
        "cluster": mock_cluster,
        "bucket_name": "test_bucket",
        "scope_name": "test_scope",
        "collection_name": "test_collection",
        "index_name": "test_index",
        "embedding_function": mock_embedding_function,
        "limit": 5,
        "embedding_key": "test_embedding",
        "scoped_index": True
    }


@pytest.fixture
def couchbase_tool(tool_config):
    # Patch COUCHBASE_AVAILABLE to True for these tests
    with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
        tool = CouchbaseFTSVectorSearchTool(**tool_config)
        return tool


@pytest.fixture
def mock_search_iter():
    mock_iter = MagicMock()
    # Simulate search results with a 'fields' attribute
    mock_row1 = MagicMock()
    mock_row1.fields = {"id": "doc1", "text": "content 1", "test_embedding": [0.1]*10}
    mock_row2 = MagicMock()
    mock_row2.fields = {"id": "doc2", "text": "content 2", "test_embedding": [0.2]*10}
    mock_iter.rows.return_value = [mock_row1, mock_row2]
    return mock_iter


# --- Test Cases ---

def test_initialization_success(couchbase_tool, tool_config):
    """Test successful initialization with valid config."""
    assert couchbase_tool.cluster == tool_config["cluster"]
    assert couchbase_tool.bucket_name == "test_bucket"
    assert couchbase_tool.scope_name == "test_scope"
    assert couchbase_tool.collection_name == "test_collection"
    assert couchbase_tool.index_name == "test_index"
    assert couchbase_tool.embedding_function is not None
    assert couchbase_tool.limit == 5
    assert couchbase_tool.embedding_key == "test_embedding"
    assert couchbase_tool.scoped_index == True

    # Check if helper methods were called during init (via mocks in fixture)
    couchbase_tool.cluster.buckets().get_bucket.assert_called_once_with("test_bucket")
    couchbase_tool.cluster.bucket().collections().get_all_scopes.assert_called_once()
    couchbase_tool.cluster.bucket().scope().search_indexes().get_all_indexes.assert_called_once()


def test_initialization_missing_required_args(mock_cluster, mock_embedding_function):
    """Test initialization fails when required arguments are missing."""
    base_config = {
        "cluster": mock_cluster, "bucket_name": "b", "scope_name": "s",
        "collection_name": "c", "index_name": "i", "embedding_function": mock_embedding_function
    }
    required_keys = base_config.keys()

    with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
        for key in required_keys:
            incomplete_config = base_config.copy()
            del incomplete_config[key]
            with pytest.raises(ValueError):
                CouchbaseFTSVectorSearchTool(**incomplete_config)


def test_initialization_couchbase_unavailable():
    """Test behavior when couchbase library is not available."""
    with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', False):
        with patch('click.confirm', return_value=False) as mock_confirm:
            with pytest.raises(ImportError, match="The 'couchbase' package is required"):
                CouchbaseFTSVectorSearchTool(cluster=MagicMock(), bucket_name="b", scope_name="s",
                                             collection_name="c", index_name="i", embedding_function=MagicMock())
            mock_confirm.assert_called_once()  # Ensure user was prompted


def test_run_success_scoped_index(couchbase_tool, mock_search_iter, tool_config, mock_embedding_function):
    """Test successful _run execution with a scoped index."""
    query = "find relevant documents"
    # expected_embedding = mock_embedding_function(query)

    # Mock the scope search method
    couchbase_tool._scope.search = MagicMock(return_value=mock_search_iter)
    # Mock the VectorQuery/VectorSearch/SearchRequest creation using runtime patching
    with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.VectorQuery') as mock_vq, \
         patch('crewai_tools.tools.couchbase_tool.couchbase_tool.VectorSearch') as mock_vs, \
         patch('crewai_tools.tools.couchbase_tool.couchbase_tool.search.SearchRequest') as mock_sr, \
         patch('crewai_tools.tools.couchbase_tool.couchbase_tool.SearchOptions') as mock_so:

        # Set up the mock objects and their return values
        mock_vector_query = MagicMock()
        mock_vector_search = MagicMock()
        mock_search_req = MagicMock()
        mock_search_options = MagicMock()

        mock_vq.return_value = mock_vector_query
        mock_vs.from_vector_query.return_value = mock_vector_search
        mock_sr.create.return_value = mock_search_req
        mock_so.return_value = mock_search_options

        result = couchbase_tool._run(query=query)

        # Check embedding function call
        tool_config['embedding_function'].assert_called_once_with(query)

        # Check VectorQuery call
        mock_vq.assert_called_once_with(
            tool_config['embedding_key'], mock_embedding_function.return_value, tool_config['limit']
        )
        # Check VectorSearch call
        mock_vs.from_vector_query.assert_called_once_with(mock_vector_query)
        # Check SearchRequest creation
        mock_sr.create.assert_called_once_with(mock_vector_search)
        # Check SearchOptions creation
        mock_so.assert_called_once_with(limit=tool_config['limit'], fields=["*"])

        # Check that scope search was called correctly
        couchbase_tool._scope.search.assert_called_once_with(
            tool_config['index_name'],
            mock_search_req,
            mock_search_options
        )

        # Check cluster search was NOT called
        couchbase_tool.cluster.search.assert_not_called()

        # Check result format (simple check for JSON structure)
        assert '"id": "doc1"' in result
        assert '"id": "doc2"' in result
        assert result.startswith('[')  # Should be valid JSON after concatenation


def test_run_success_global_index(tool_config, mock_search_iter, mock_embedding_function):
    """Test successful _run execution with a global (non-scoped) index."""
    tool_config['scoped_index'] = False
    with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
        couchbase_tool = CouchbaseFTSVectorSearchTool(**tool_config)

    query = "find global documents"
    # expected_embedding = mock_embedding_function(query)

    # Mock the cluster search method
    couchbase_tool.cluster.search = MagicMock(return_value=mock_search_iter)
    # Mock the VectorQuery/VectorSearch/SearchRequest creation using runtime patching
    with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.VectorQuery') as mock_vq, \
         patch('crewai_tools.tools.couchbase_tool.couchbase_tool.VectorSearch') as mock_vs, \
         patch('crewai_tools.tools.couchbase_tool.couchbase_tool.search.SearchRequest') as mock_sr, \
         patch('crewai_tools.tools.couchbase_tool.couchbase_tool.SearchOptions') as mock_so:

        # Set up the mock objects and their return values
        mock_vector_query = MagicMock()
        mock_vector_search = MagicMock()
        mock_search_req = MagicMock()
        mock_search_options = MagicMock()

        mock_vq.return_value = mock_vector_query
        mock_vs.from_vector_query.return_value = mock_vector_search
        mock_sr.create.return_value = mock_search_req
        mock_so.return_value = mock_search_options

        result = couchbase_tool._run(query=query)

        # Check embedding function call
        tool_config['embedding_function'].assert_called_once_with(query)

        # Check VectorQuery/Search call
        mock_vq.assert_called_once_with(
            tool_config['embedding_key'], mock_embedding_function.return_value, tool_config['limit']
        )
        mock_sr.create.assert_called_once_with(mock_vector_search)
        # Check SearchOptions creation
        mock_so.assert_called_once_with(limit=tool_config['limit'], fields=["*"])

        # Check that cluster search was called correctly
        couchbase_tool.cluster.search.assert_called_once_with(
            tool_config['index_name'],
            mock_search_req,
            mock_search_options
        )

        # Check scope search was NOT called
        couchbase_tool._scope.search.assert_not_called()

        # Check result format
        assert '"id": "doc1"' in result
        assert '"id": "doc2"' in result


def test_check_bucket_exists_fail(tool_config):
    """Test check for bucket non-existence."""
    mock_cluster = tool_config['cluster']
    mock_cluster.buckets().get_bucket.side_effect = mock_couchbase.exceptions.BucketNotFoundException("Bucket not found")

    with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
        with pytest.raises(ValueError, match="Bucket test_bucket does not exist."):
            CouchbaseFTSVectorSearchTool(**tool_config)


def test_check_scope_exists_fail(tool_config):
    """Test check for scope non-existence."""
    mock_cluster = tool_config['cluster']
    # Simulate scope not being in the list returned
    mock_scope_spec = MagicMock()
    mock_scope_spec.name = "wrong_scope"
    mock_cluster.bucket().collections().get_all_scopes.return_value = [mock_scope_spec]

    with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
        with pytest.raises(ValueError, match="Scope test_scope not found"):
            CouchbaseFTSVectorSearchTool(**tool_config)


def test_check_collection_exists_fail(tool_config):
    """Test check for collection non-existence."""
    mock_cluster = tool_config['cluster']
    # Simulate collection not being in the scope's list
    mock_scope_spec = MagicMock()
    mock_scope_spec.name = "test_scope"
    mock_collection_spec = MagicMock()
    mock_collection_spec.name = "wrong_collection"
    mock_scope_spec.collections = [mock_collection_spec]  # Only has wrong collection
    mock_cluster.bucket().collections().get_all_scopes.return_value = [mock_scope_spec]

    with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
        with pytest.raises(ValueError, match="Collection test_collection not found"):
            CouchbaseFTSVectorSearchTool(**tool_config)


def test_check_index_exists_fail_scoped(tool_config):
    """Test check for scoped index non-existence."""
    mock_cluster = tool_config['cluster']
    # Simulate index not being in the list returned by scope manager
    mock_cluster.bucket().scope().search_indexes().get_all_indexes.return_value = []

    with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
        with pytest.raises(ValueError, match="Index test_index does not exist"):
            CouchbaseFTSVectorSearchTool(**tool_config)


def test_check_index_exists_fail_global(tool_config):
    """Test check for global index non-existence."""
    tool_config['scoped_index'] = False
    mock_cluster = tool_config['cluster']
    # Simulate index not being in the list returned by cluster manager
    mock_cluster.search_indexes().get_all_indexes.return_value = []

    with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
        with pytest.raises(ValueError, match="Index test_index does not exist"):
            CouchbaseFTSVectorSearchTool(**tool_config)
tests/tools/crewai_enterprise_tools_test.py (new file, 355 lines)
@@ -0,0 +1,355 @@
import os
import unittest
from unittest.mock import patch, MagicMock


from crewai.tools import BaseTool
from crewai_tools.tools import CrewaiEnterpriseTools
from crewai_tools.adapters.tool_collection import ToolCollection
from crewai_tools.adapters.enterprise_adapter import EnterpriseActionTool


class TestCrewaiEnterpriseTools(unittest.TestCase):
    def setUp(self):
        self.mock_tools = [
            self._create_mock_tool("tool1", "Tool 1 Description"),
            self._create_mock_tool("tool2", "Tool 2 Description"),
            self._create_mock_tool("tool3", "Tool 3 Description"),
        ]
        self.adapter_patcher = patch(
            "crewai_tools.tools.crewai_enterprise_tools.crewai_enterprise_tools.EnterpriseActionKitToolAdapter"
        )
        self.MockAdapter = self.adapter_patcher.start()

        mock_adapter_instance = self.MockAdapter.return_value
        mock_adapter_instance.tools.return_value = self.mock_tools

    def tearDown(self):
        self.adapter_patcher.stop()

    def _create_mock_tool(self, name, description):
        mock_tool = MagicMock(spec=BaseTool)
        mock_tool.name = name
        mock_tool.description = description
        return mock_tool

    @patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
    def test_returns_tool_collection(self):
        tools = CrewaiEnterpriseTools()
        self.assertIsInstance(tools, ToolCollection)

    @patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
    def test_returns_all_tools_when_no_actions_list(self):
        tools = CrewaiEnterpriseTools()
        self.assertEqual(len(tools), 3)
        self.assertEqual(tools[0].name, "tool1")
        self.assertEqual(tools[1].name, "tool2")
        self.assertEqual(tools[2].name, "tool3")

    @patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
    def test_filters_tools_by_actions_list(self):
        tools = CrewaiEnterpriseTools(actions_list=["ToOl1", "tool3"])
        self.assertEqual(len(tools), 2)
        self.assertEqual(tools[0].name, "tool1")
        self.assertEqual(tools[1].name, "tool3")

    def test_uses_provided_parameters(self):
        CrewaiEnterpriseTools(
            enterprise_token="test-token",
            enterprise_action_kit_project_id="project-id",
            enterprise_action_kit_project_url="project-url",
        )

        self.MockAdapter.assert_called_once_with(
            enterprise_action_token="test-token",
            enterprise_action_kit_project_id="project-id",
            enterprise_action_kit_project_url="project-url",
        )

    @patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
    def test_uses_environment_token(self):
        CrewaiEnterpriseTools()
        self.MockAdapter.assert_called_once_with(enterprise_action_token="env-token")

    @patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
    def test_uses_environment_token_when_no_token_provided(self):
        CrewaiEnterpriseTools(enterprise_token="")
        self.MockAdapter.assert_called_once_with(enterprise_action_token="env-token")

    @patch.dict(
        os.environ,
        {
            "CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token",
            "CREWAI_ENTERPRISE_TOOLS_ACTIONS_LIST": '["tool1", "tool3"]',
        },
    )
    def test_uses_environment_actions_list(self):
        tools = CrewaiEnterpriseTools()
        self.assertEqual(len(tools), 2)
        self.assertEqual(tools[0].name, "tool1")
        self.assertEqual(tools[1].name, "tool3")


class TestEnterpriseActionToolSchemaConversion(unittest.TestCase):
    """Test the enterprise action tool schema conversion and validation."""

    def setUp(self):
        self.test_schema = {
            "type": "function",
            "function": {
                "name": "TEST_COMPLEX_ACTION",
                "description": "Test action with complex nested structure",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filterCriteria": {
                            "type": "object",
                            "description": "Filter criteria object",
                            "properties": {
                                "operation": {"type": "string", "enum": ["AND", "OR"]},
                                "rules": {
                                    "type": "array",
                                    "items": {
                                        "type": "object",
                                        "properties": {
                                            "field": {
                                                "type": "string",
                                                "enum": ["name", "email", "status"],
                                            },
                                            "operator": {
                                                "type": "string",
                                                "enum": ["equals", "contains"],
                                            },
                                            "value": {"type": "string"},
                                        },
                                        "required": ["field", "operator", "value"],
                                    },
                                },
                            },
                            "required": ["operation", "rules"],
                        },
                        "options": {
                            "type": "object",
                            "properties": {
                                "limit": {"type": "integer"},
                                "offset": {"type": "integer"},
                            },
                            "required": [],
                        },
                    },
                    "required": [],
                },
            },
        }

    def test_complex_schema_conversion(self):
        """Test that complex nested schemas are properly converted to Pydantic models."""
        tool = EnterpriseActionTool(
            name="gmail_search_for_email",
            description="Test tool",
            enterprise_action_token="test_token",
            action_name="GMAIL_SEARCH_FOR_EMAIL",
            action_schema=self.test_schema,
        )

        self.assertEqual(tool.name, "gmail_search_for_email")
        self.assertEqual(tool.action_name, "GMAIL_SEARCH_FOR_EMAIL")

        schema_class = tool.args_schema
        self.assertIsNotNone(schema_class)

        schema_fields = schema_class.model_fields
        self.assertIn("filterCriteria", schema_fields)
        self.assertIn("options", schema_fields)

        # Test valid input structure
        valid_input = {
            "filterCriteria": {
                "operation": "AND",
                "rules": [
                    {"field": "name", "operator": "contains", "value": "test"},
                    {"field": "status", "operator": "equals", "value": "active"},
                ],
            },
            "options": {"limit": 10},
        }

        # This should not raise an exception
        validated_input = schema_class(**valid_input)
        self.assertIsNotNone(validated_input.filterCriteria)
        self.assertIsNotNone(validated_input.options)

    def test_optional_fields_validation(self):
        """Test that optional fields work correctly."""
        tool = EnterpriseActionTool(
            name="gmail_search_for_email",
            description="Test tool",
            enterprise_action_token="test_token",
            action_name="GMAIL_SEARCH_FOR_EMAIL",
            action_schema=self.test_schema,
        )

        schema_class = tool.args_schema

        minimal_input = {}
        validated_input = schema_class(**minimal_input)
        self.assertIsNone(validated_input.filterCriteria)
        self.assertIsNone(validated_input.options)

        partial_input = {"options": {"limit": 10}}
        validated_input = schema_class(**partial_input)
        self.assertIsNone(validated_input.filterCriteria)
        self.assertIsNotNone(validated_input.options)

    def test_enum_validation(self):
        """Test that enum values are properly validated."""
        tool = EnterpriseActionTool(
            name="gmail_search_for_email",
            description="Test tool",
            enterprise_action_token="test_token",
            action_name="GMAIL_SEARCH_FOR_EMAIL",
            action_schema=self.test_schema,
        )

        schema_class = tool.args_schema

        invalid_input = {
            "filterCriteria": {
                "operation": "INVALID_OPERATOR",
                "rules": [],
            }
        }

        with self.assertRaises(Exception):
            schema_class(**invalid_input)

    def test_required_nested_fields(self):
        """Test that required fields in nested objects are validated."""
        tool = EnterpriseActionTool(
            name="gmail_search_for_email",
            description="Test tool",
            enterprise_action_token="test_token",
            action_name="GMAIL_SEARCH_FOR_EMAIL",
            action_schema=self.test_schema,
        )

        schema_class = tool.args_schema

        incomplete_input = {
            "filterCriteria": {
                "operation": "OR",
                "rules": [
                    {
                        "field": "name",
                        "operator": "contains",
                    }
                ],
            }
        }

        with self.assertRaises(Exception):
            schema_class(**incomplete_input)

    @patch("requests.post")
    def test_tool_execution_with_complex_input(self, mock_post):
        """Test that the tool can execute with complex validated input."""
        mock_response = MagicMock()
        mock_response.ok = True
        mock_response.json.return_value = {"success": True, "results": []}
        mock_post.return_value = mock_response

        tool = EnterpriseActionTool(
            name="gmail_search_for_email",
            description="Test tool",
            enterprise_action_token="test_token",
            action_name="GMAIL_SEARCH_FOR_EMAIL",
            action_schema=self.test_schema,
        )

        tool._run(
            filterCriteria={
                "operation": "OR",
                "rules": [
                    {"field": "name", "operator": "contains", "value": "test"},
                    {"field": "status", "operator": "equals", "value": "active"},
                ],
            },
            options={"limit": 10},
        )

        mock_post.assert_called_once()
        call_args = mock_post.call_args
        payload = call_args[1]["json"]

        self.assertIn("filterCriteria", payload)
        self.assertIn("options", payload)
        self.assertEqual(payload["filterCriteria"]["operation"], "OR")

    def test_model_naming_convention(self):
        """Test that generated model names follow proper conventions."""
        tool = EnterpriseActionTool(
            name="gmail_search_for_email",
            description="Test tool",
            enterprise_action_token="test_token",
            action_name="GMAIL_SEARCH_FOR_EMAIL",
            action_schema=self.test_schema,
        )

        schema_class = tool.args_schema
        self.assertIsNotNone(schema_class)

        self.assertTrue(schema_class.__name__.endswith("Schema"))
        self.assertTrue(schema_class.__name__[0].isupper())

        complex_input = {
            "filterCriteria": {
                "operation": "OR",
                "rules": [
                    {"field": "name", "operator": "contains", "value": "test"},
                    {"field": "status", "operator": "equals", "value": "active"},
                ],
            },
            "options": {"limit": 10},
        }

        validated = schema_class(**complex_input)
        self.assertIsNotNone(validated.filterCriteria)

    def test_simple_schema_with_enums(self):
        """Test a simpler schema with basic enum validation."""
        simple_schema = {
            "type": "function",
            "function": {
                "name": "SIMPLE_TEST",
                "description": "Simple test function",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "status": {
                            "type": "string",
                            "enum": ["active", "inactive", "pending"],
                        },
                        "priority": {"type": "integer", "enum": [1, 2, 3]},
                    },
                    "required": ["status"],
                },
            },
        }

        tool = EnterpriseActionTool(
            name="simple_test",
            description="Simple test tool",
            enterprise_action_token="test_token",
            action_name="SIMPLE_TEST",
            action_schema=simple_schema,
        )

        schema_class = tool.args_schema

        valid_input = {"status": "active", "priority": 2}
        validated = schema_class(**valid_input)
        self.assertEqual(validated.status, "active")
        self.assertEqual(validated.priority, 2)

        with self.assertRaises(Exception):
            schema_class(status="invalid_status")
@@ -0,0 +1,165 @@
import unittest
from unittest.mock import patch, Mock
import pytest
from crewai_tools.tools.crewai_platform_tools import CrewAIPlatformActionTool


class TestCrewAIPlatformActionTool(unittest.TestCase):
    @pytest.fixture
    def sample_action_schema(self):
        return {
            "function": {
                "name": "test_action",
                "description": "Test action for unit testing",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Message to send"
                        },
                        "priority": {
                            "type": "integer",
                            "description": "Priority level"
                        }
                    },
                    "required": ["message"]
                }
            }
        }

    @pytest.fixture
    def platform_action_tool(self, sample_action_schema):
        return CrewAIPlatformActionTool(
            description="Test Action Tool\nTest description",
            action_name="test_action",
            action_schema=sample_action_schema
        )

    @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
    @patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool.requests.post")
    def test_run_success(self, mock_post):
        schema = {
            "function": {
                "name": "test_action",
                "description": "Test action",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "message": {"type": "string", "description": "Message"}
                    },
                    "required": ["message"]
                }
            }
        }

        tool = CrewAIPlatformActionTool(
            description="Test tool",
            action_name="test_action",
            action_schema=schema
        )

        mock_response = Mock()
        mock_response.ok = True
        mock_response.json.return_value = {"result": "success", "data": "test_data"}
        mock_post.return_value = mock_response

        result = tool._run(message="test message")

        mock_post.assert_called_once()
        _, kwargs = mock_post.call_args

        assert "test_action/execute" in kwargs["url"]
        assert kwargs["headers"]["Authorization"] == "Bearer test_token"
        assert kwargs["json"]["message"] == "test message"
        assert "success" in result

    @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
    @patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool.requests.post")
    def test_run_api_error(self, mock_post):
        schema = {
            "function": {
                "name": "test_action",
                "description": "Test action",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "message": {"type": "string", "description": "Message"}
                    },
                    "required": ["message"]
                }
            }
        }

        tool = CrewAIPlatformActionTool(
            description="Test tool",
            action_name="test_action",
            action_schema=schema
        )

        mock_response = Mock()
        mock_response.ok = False
        mock_response.json.return_value = {"error": {"message": "Invalid request"}}
        mock_post.return_value = mock_response

        result = tool._run(message="test message")

        assert "API request failed" in result
        assert "Invalid request" in result

    @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
    @patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool.requests.post")
    def test_run_exception(self, mock_post):
        schema = {
            "function": {
                "name": "test_action",
                "description": "Test action",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "message": {"type": "string", "description": "Message"}
                    },
                    "required": ["message"]
                }
            }
        }

        tool = CrewAIPlatformActionTool(
            description="Test tool",
            action_name="test_action",
            action_schema=schema
        )

        mock_post.side_effect = Exception("Network error")

        result = tool._run(message="test message")

        assert "Error executing action test_action: Network error" in result

    def test_run_without_token(self):
        schema = {
            "function": {
                "name": "test_action",
                "description": "Test action",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "message": {"type": "string", "description": "Message"}
                    },
                    "required": ["message"]
                }
            }
        }

        tool = CrewAIPlatformActionTool(
            description="Test tool",
            action_name="test_action",
            action_schema=schema
        )

        with patch.dict("os.environ", {}, clear=True):
            result = tool._run(message="test message")
            assert "Error executing action test_action:" in result
            assert "No platform integration token found" in result
@@ -0,0 +1,223 @@
import unittest
from unittest.mock import patch, Mock
import pytest
from crewai_tools.tools.crewai_platform_tools import CrewaiPlatformToolBuilder, CrewAIPlatformActionTool


class TestCrewaiPlatformToolBuilder(unittest.TestCase):
    @pytest.fixture
    def platform_tool_builder(self):
        """Create a CrewaiPlatformToolBuilder instance for testing"""
        return CrewaiPlatformToolBuilder(apps=["github", "slack"])

    @pytest.fixture
    def mock_api_response(self):
        return {
            "actions": {
                "github": [
                    {
                        "name": "create_issue",
                        "description": "Create a GitHub issue",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "title": {"type": "string", "description": "Issue title"},
                                "body": {"type": "string", "description": "Issue body"}
                            },
                            "required": ["title"]
                        }
                    }
                ],
                "slack": [
                    {
                        "name": "send_message",
                        "description": "Send a Slack message",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "channel": {"type": "string", "description": "Channel name"},
                                "text": {"type": "string", "description": "Message text"}
                            },
                            "required": ["channel", "text"]
                        }
                    }
                ]
            }
        }

    @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
    @patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
    def test_fetch_actions_success(self, mock_get):
        mock_api_response = {
            "actions": {
                "github": [
                    {
                        "name": "create_issue",
                        "description": "Create a GitHub issue",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "title": {"type": "string", "description": "Issue title"}
                            },
                            "required": ["title"]
                        }
                    }
                ]
            }
        }

        builder = CrewaiPlatformToolBuilder(apps=["github", "slack/send_message"])

        mock_response = Mock()
        mock_response.raise_for_status.return_value = None
        mock_response.json.return_value = mock_api_response
        mock_get.return_value = mock_response

        builder._fetch_actions()

        mock_get.assert_called_once()
        args, kwargs = mock_get.call_args

        assert "/actions" in args[0]
        assert kwargs["headers"]["Authorization"] == "Bearer test_token"
        assert kwargs["params"]["apps"] == "github,slack/send_message"

        assert "create_issue" in builder._actions_schema
        assert builder._actions_schema["create_issue"]["function"]["name"] == "create_issue"

    def test_fetch_actions_no_token(self):
        builder = CrewaiPlatformToolBuilder(apps=["github"])

        with patch.dict("os.environ", {}, clear=True):
            with self.assertRaises(ValueError) as context:
                builder._fetch_actions()
            assert "No platform integration token found" in str(context.exception)

    @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
    @patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
    def test_create_tools(self, mock_get):
        mock_api_response = {
            "actions": {
                "github": [
                    {
                        "name": "create_issue",
                        "description": "Create a GitHub issue",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "title": {"type": "string", "description": "Issue title"}
                            },
                            "required": ["title"]
                        }
                    }
                ],
                "slack": [
                    {
                        "name": "send_message",
                        "description": "Send a Slack message",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "channel": {"type": "string", "description": "Channel name"}
                            },
                            "required": ["channel"]
                        }
                    }
                ]
            }
        }

        builder = CrewaiPlatformToolBuilder(apps=["github", "slack"])

        mock_response = Mock()
        mock_response.raise_for_status.return_value = None
        mock_response.json.return_value = mock_api_response
        mock_get.return_value = mock_response

        tools = builder.tools()

        assert len(tools) == 2
        assert all(isinstance(tool, CrewAIPlatformActionTool) for tool in tools)

        tool_names = [tool.action_name for tool in tools]
        assert "create_issue" in tool_names
        assert "send_message" in tool_names

        github_tool = next((t for t in tools if t.action_name == "create_issue"), None)
        slack_tool = next((t for t in tools if t.action_name == "send_message"), None)

        assert github_tool is not None
        assert slack_tool is not None
        assert "Create a GitHub issue" in github_tool.description
        assert "Send a Slack message" in slack_tool.description

    def test_tools_caching(self):
        builder = CrewaiPlatformToolBuilder(apps=["github"])

        cached_tools = []

        def mock_create_tools():
            builder._tools = cached_tools

        with patch.object(builder, '_fetch_actions') as mock_fetch, \
             patch.object(builder, '_create_tools', side_effect=mock_create_tools) as mock_create:

            tools1 = builder.tools()
            assert mock_fetch.call_count == 1
            assert mock_create.call_count == 1

            tools2 = builder.tools()
            assert mock_fetch.call_count == 1
            assert mock_create.call_count == 1

            assert tools1 is tools2

    @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
    def test_empty_apps_list(self):
        builder = CrewaiPlatformToolBuilder(apps=[])

        with patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get") as mock_get:
            mock_response = Mock()
            mock_response.raise_for_status.return_value = None
            mock_response.json.return_value = {"actions": {}}
            mock_get.return_value = mock_response

            tools = builder.tools()

            assert isinstance(tools, list)
            assert len(tools) == 0

            _, kwargs = mock_get.call_args
            assert kwargs["params"]["apps"] == ""

    def test_detailed_description_generation(self):
        builder = CrewaiPlatformToolBuilder(apps=["test"])

        complex_schema = {
            "type": "object",
            "properties": {
                "simple_string": {"type": "string", "description": "A simple string"},
                "nested_object": {
                    "type": "object",
                    "properties": {
                        "inner_prop": {"type": "integer", "description": "Inner property"}
                    },
                    "description": "Nested object"
                },
                "array_prop": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Array of strings"
                }
            }
        }

        descriptions = builder._generate_detailed_description(complex_schema)

        assert isinstance(descriptions, list)
        assert len(descriptions) > 0

        description_text = "\n".join(descriptions)
        assert "simple_string" in description_text
        assert "nested_object" in description_text
        assert "array_prop" in description_text
@@ -0,0 +1,95 @@
import unittest
from unittest.mock import patch, Mock
from crewai_tools.tools.crewai_platform_tools import CrewaiPlatformTools


class TestCrewaiPlatformTools(unittest.TestCase):

    @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
    @patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
    def test_crewai_platform_tools_basic(self, mock_get):
        mock_response = Mock()
        mock_response.raise_for_status.return_value = None
        mock_response.json.return_value = {"actions": {"github": []}}
        mock_get.return_value = mock_response

        tools = CrewaiPlatformTools(apps=["github"])
        assert tools is not None
        assert isinstance(tools, list)

    @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
    @patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
    def test_crewai_platform_tools_multiple_apps(self, mock_get):
        mock_response = Mock()
        mock_response.raise_for_status.return_value = None
        mock_response.json.return_value = {
            "actions": {
                "github": [
                    {
                        "name": "create_issue",
                        "description": "Create a GitHub issue",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "title": {"type": "string", "description": "Issue title"},
                                "body": {"type": "string", "description": "Issue body"}
                            },
                            "required": ["title"]
                        }
                    }
                ],
                "slack": [
                    {
                        "name": "send_message",
                        "description": "Send a Slack message",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "channel": {"type": "string", "description": "Channel to send to"},
                                "text": {"type": "string", "description": "Message text"}
                            },
                            "required": ["channel", "text"]
                        }
                    }
                ]
            }
        }
        mock_get.return_value = mock_response

        tools = CrewaiPlatformTools(apps=["github", "slack"])
        assert tools is not None
        assert isinstance(tools, list)
        assert len(tools) == 2

        mock_get.assert_called_once()
        args, kwargs = mock_get.call_args
        assert "apps=github,slack" in args[0] or kwargs.get("params", {}).get("apps") == "github,slack"

    @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
    def test_crewai_platform_tools_empty_apps(self):
        with patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get") as mock_get:
            mock_response = Mock()
            mock_response.raise_for_status.return_value = None
            mock_response.json.return_value = {"actions": {}}
            mock_get.return_value = mock_response

            tools = CrewaiPlatformTools(apps=[])
            assert tools is not None
            assert isinstance(tools, list)
            assert len(tools) == 0

    @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
    @patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
    def test_crewai_platform_tools_api_error_handling(self, mock_get):
        mock_get.side_effect = Exception("API Error")

        tools = CrewaiPlatformTools(apps=["github"])
        assert tools is not None
        assert isinstance(tools, list)
        assert len(tools) == 0

    def test_crewai_platform_tools_no_token(self):
        with patch.dict("os.environ", {}, clear=True):
            with self.assertRaises(ValueError) as context:
                CrewaiPlatformTools(apps=["github"])
            assert "No platform integration token found" in str(context.exception)
tests/tools/exa_search_tool_test.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import os
from unittest.mock import patch
from crewai_tools import EXASearchTool

import pytest


@pytest.fixture
def exa_search_tool():
    return EXASearchTool(api_key="test_api_key")


@pytest.fixture(autouse=True)
def mock_exa_api_key():
    with patch.dict(os.environ, {"EXA_API_KEY": "test_key_from_env"}):
        yield


def test_exa_search_tool_initialization():
    with patch("crewai_tools.tools.exa_tools.exa_search_tool.Exa") as mock_exa_class:
        api_key = "test_api_key"
        tool = EXASearchTool(api_key=api_key)

        assert tool.api_key == api_key
        assert tool.content is False
        assert tool.summary is False
        assert tool.type == "auto"
        mock_exa_class.assert_called_once_with(api_key=api_key)


def test_exa_search_tool_initialization_with_env(mock_exa_api_key):
    with patch("crewai_tools.tools.exa_tools.exa_search_tool.Exa") as mock_exa_class:
        EXASearchTool()
        mock_exa_class.assert_called_once_with(api_key="test_key_from_env")
tests/tools/generate_crewai_automation_tool_test.py (new file, 187 lines)
@@ -0,0 +1,187 @@
import os
from unittest.mock import MagicMock, patch

import pytest
import requests

from crewai_tools.tools.generate_crewai_automation_tool.generate_crewai_automation_tool import (
    GenerateCrewaiAutomationTool,
    GenerateCrewaiAutomationToolSchema,
)


@pytest.fixture(autouse=True)
def mock_env():
    with patch.dict(os.environ, {"CREWAI_PERSONAL_ACCESS_TOKEN": "test_token"}):
        os.environ.pop("CREWAI_PLUS_URL", None)
        yield


@pytest.fixture
def tool():
    return GenerateCrewaiAutomationTool()


@pytest.fixture
def custom_url_tool():
    with patch.dict(os.environ, {"CREWAI_PLUS_URL": "https://custom.crewai.com"}):
        return GenerateCrewaiAutomationTool()


def test_default_initialization(tool):
    assert tool.crewai_enterprise_url == "https://app.crewai.com"
    assert tool.personal_access_token == "test_token"
    assert tool.name == "Generate CrewAI Automation"


def test_custom_base_url_from_environment(custom_url_tool):
    assert custom_url_tool.crewai_enterprise_url == "https://custom.crewai.com"


def test_personal_access_token_from_environment(tool):
    assert tool.personal_access_token == "test_token"


def test_valid_prompt_only():
    schema = GenerateCrewaiAutomationToolSchema(
        prompt="Create a web scraping automation"
    )
    assert schema.prompt == "Create a web scraping automation"
    assert schema.organization_id is None


def test_valid_prompt_with_organization_id():
    schema = GenerateCrewaiAutomationToolSchema(
        prompt="Create automation", organization_id="org-123"
    )
    assert schema.prompt == "Create automation"
    assert schema.organization_id == "org-123"


def test_empty_prompt_validation():
    schema = GenerateCrewaiAutomationToolSchema(prompt="")
    assert schema.prompt == ""
    assert schema.organization_id is None


@patch("requests.post")
def test_successful_generation_without_org_id(mock_post, tool):
    mock_response = MagicMock()
    mock_response.json.return_value = {
        "url": "https://app.crewai.com/studio/project-123"
    }
    mock_post.return_value = mock_response

    result = tool.run(prompt="Create automation")

    assert (
        result
        == "Generated CrewAI Studio project URL: https://app.crewai.com/studio/project-123"
    )
    mock_post.assert_called_once_with(
        "https://app.crewai.com/crewai_plus/api/v1/studio",
        headers={
            "Authorization": "Bearer test_token",
            "Content-Type": "application/json",
            "Accept": "application/json",
        },
        json={"prompt": "Create automation"},
    )


@patch("requests.post")
def test_successful_generation_with_org_id(mock_post, tool):
    mock_response = MagicMock()
    mock_response.json.return_value = {
        "url": "https://app.crewai.com/studio/project-456"
    }
    mock_post.return_value = mock_response

    result = tool.run(prompt="Create automation", organization_id="org-456")

    assert (
        result
        == "Generated CrewAI Studio project URL: https://app.crewai.com/studio/project-456"
    )
    mock_post.assert_called_once_with(
        "https://app.crewai.com/crewai_plus/api/v1/studio",
        headers={
            "Authorization": "Bearer test_token",
            "Content-Type": "application/json",
            "Accept": "application/json",
            "X-Crewai-Organization-Id": "org-456",
        },
        json={"prompt": "Create automation"},
    )


@patch("requests.post")
def test_custom_base_url_usage(mock_post, custom_url_tool):
    mock_response = MagicMock()
    mock_response.json.return_value = {
        "url": "https://custom.crewai.com/studio/project-789"
    }
    mock_post.return_value = mock_response

    custom_url_tool.run(prompt="Create automation")

    mock_post.assert_called_once_with(
        "https://custom.crewai.com/crewai_plus/api/v1/studio",
        headers={
            "Authorization": "Bearer test_token",
            "Content-Type": "application/json",
            "Accept": "application/json",
        },
        json={"prompt": "Create automation"},
    )


@patch("requests.post")
def test_api_error_response_handling(mock_post, tool):
    mock_post.return_value.raise_for_status.side_effect = requests.HTTPError(
        "400 Bad Request"
    )

    with pytest.raises(requests.HTTPError):
        tool.run(prompt="Create automation")


@patch("requests.post")
def test_network_error_handling(mock_post, tool):
    mock_post.side_effect = requests.ConnectionError("Network unreachable")

    with pytest.raises(requests.ConnectionError):
        tool.run(prompt="Create automation")


@patch("requests.post")
def test_api_response_missing_url(mock_post, tool):
    mock_response = MagicMock()
    mock_response.json.return_value = {"status": "success"}
    mock_post.return_value = mock_response

    result = tool.run(prompt="Create automation")

    assert result == "Generated CrewAI Studio project URL: None"


def test_authorization_header_construction(tool):
    headers = tool._get_headers()

    assert headers["Authorization"] == "Bearer test_token"
    assert headers["Content-Type"] == "application/json"
    assert headers["Accept"] == "application/json"
    assert "X-Crewai-Organization-Id" not in headers


def test_authorization_header_with_org_id(tool):
    headers = tool._get_headers(organization_id="org-123")

    assert headers["Authorization"] == "Bearer test_token"
    assert headers["X-Crewai-Organization-Id"] == "org-123"


def test_missing_personal_access_token():
    with patch.dict(os.environ, {}, clear=True):
        tool = GenerateCrewaiAutomationTool()
        assert tool.personal_access_token is None
47
tests/tools/parallel_search_tool_test.py
Normal file
@@ -0,0 +1,47 @@
import os
import json
from urllib.parse import urlparse
from unittest.mock import patch

import pytest

from crewai_tools.tools.parallel_tools.parallel_search_tool import (
    ParallelSearchTool,
)


def test_requires_env_var(monkeypatch):
    monkeypatch.delenv("PARALLEL_API_KEY", raising=False)
    tool = ParallelSearchTool()
    result = tool.run(objective="test")
    assert "PARALLEL_API_KEY" in result


@patch("crewai_tools.tools.parallel_tools.parallel_search_tool.requests.post")
def test_happy_path(mock_post, monkeypatch):
    monkeypatch.setenv("PARALLEL_API_KEY", "test")

    mock_post.return_value.status_code = 200
    mock_post.return_value.json.return_value = {
        "search_id": "search_123",
        "results": [
            {
                "url": "https://www.un.org/en/about-us/history-of-the-un",
                "title": "History of the United Nations",
                "excerpts": [
                    "Four months after the San Francisco Conference ended, the United Nations officially began, on 24 October 1945..."
                ],
            }
        ],
    }

    tool = ParallelSearchTool()
    result = tool.run(objective="When was the UN established?", search_queries=["Founding year UN"])
    data = json.loads(result)
    assert "search_id" in data
    urls = [r.get("url", "") for r in data.get("results", [])]
    # Validate host against allowed set instead of substring matching
    allowed_hosts = {"www.un.org", "un.org"}
    assert any(urlparse(u).netloc in allowed_hosts for u in urls)
43
tests/tools/rag/rag_tool_test.py
Normal file
@@ -0,0 +1,43 @@
import os
from tempfile import NamedTemporaryFile
from typing import cast
from unittest import mock

from pytest import fixture

from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
from crewai_tools.tools.rag.rag_tool import RagTool


@fixture(autouse=True)
def mock_embedchain_db_uri():
    with NamedTemporaryFile() as tmp:
        uri = f"sqlite:///{tmp.name}"
        with mock.patch.dict(os.environ, {"EMBEDCHAIN_DB_URI": uri}):
            yield


def test_custom_llm_and_embedder():
    class MyTool(RagTool):
        pass

    tool = MyTool(
        config=dict(
            llm=dict(
                provider="openai",
                config=dict(model="gpt-3.5-custom"),
            ),
            embedder=dict(
                provider="openai",
                config=dict(model="text-embedding-3-custom"),
            ),
        )
    )
    assert tool.adapter is not None
    assert isinstance(tool.adapter, EmbedchainAdapter)

    adapter = cast(EmbedchainAdapter, tool.adapter)
    assert adapter.embedchain_app.llm.config.model == "gpt-3.5-custom"
    assert (
        adapter.embedchain_app.embedding_model.config.model == "text-embedding-3-custom"
    )
129
tests/tools/selenium_scraping_tool_test.py
Normal file
@@ -0,0 +1,129 @@
import os
import tempfile
from unittest.mock import MagicMock, patch

from bs4 import BeautifulSoup
from selenium.webdriver.chrome.options import Options
from crewai_tools.tools.selenium_scraping_tool.selenium_scraping_tool import (
    SeleniumScrapingTool,
)


def mock_driver_with_html(html_content):
    driver = MagicMock()
    mock_element = MagicMock()
    mock_element.get_attribute.return_value = html_content
    bs = BeautifulSoup(html_content, "html.parser")
    mock_element.text = bs.get_text()

    driver.find_elements.return_value = [mock_element]
    driver.find_element.return_value = mock_element

    return driver


def initialize_tool_with(mock_driver):
    tool = SeleniumScrapingTool(driver=mock_driver)
    return tool


@patch("selenium.webdriver.Chrome")
def test_tool_initialization(mocked_chrome):
    temp_dir = tempfile.mkdtemp()
    mocked_chrome.return_value = MagicMock()

    tool = SeleniumScrapingTool()

    assert tool.website_url is None
    assert tool.css_element is None
    assert tool.cookie is None
    assert tool.wait_time == 3
    assert tool.return_html is False

    try:
        os.rmdir(temp_dir)
    except OSError:
        pass


@patch("selenium.webdriver.Chrome")
def test_tool_initialization_with_options(mocked_chrome):
    mocked_chrome.return_value = MagicMock()

    options = Options()
    options.add_argument("--disable-gpu")

    SeleniumScrapingTool(options=options)

    mocked_chrome.assert_called_once_with(options=options)


@patch("selenium.webdriver.Chrome")
def test_scrape_without_css_selector(_mocked_chrome_driver):
    html_content = "<html><body><div>test content</div></body></html>"
    mock_driver = mock_driver_with_html(html_content)
    tool = initialize_tool_with(mock_driver)

    result = tool._run(website_url="https://example.com")

    assert "test content" in result
    mock_driver.get.assert_called_once_with("https://example.com")
    mock_driver.find_element.assert_called_with("tag name", "body")
    mock_driver.close.assert_called_once()


@patch("selenium.webdriver.Chrome")
def test_scrape_with_css_selector(_mocked_chrome_driver):
    html_content = "<html><body><div>test content</div><div class='test'>test content in a specific div</div></body></html>"
    mock_driver = mock_driver_with_html(html_content)
    tool = initialize_tool_with(mock_driver)

    result = tool._run(website_url="https://example.com", css_element="div.test")

    assert "test content in a specific div" in result
    mock_driver.get.assert_called_once_with("https://example.com")
    mock_driver.find_elements.assert_called_with("css selector", "div.test")
    mock_driver.close.assert_called_once()


@patch("selenium.webdriver.Chrome")
def test_scrape_with_return_html_true(_mocked_chrome_driver):
    html_content = "<html><body><div>HTML content</div></body></html>"
    mock_driver = mock_driver_with_html(html_content)
    tool = initialize_tool_with(mock_driver)

    result = tool._run(website_url="https://example.com", return_html=True)

    assert html_content in result
    mock_driver.get.assert_called_once_with("https://example.com")
    mock_driver.find_element.assert_called_with("tag name", "body")
    mock_driver.close.assert_called_once()


@patch("selenium.webdriver.Chrome")
def test_scrape_with_return_html_false(_mocked_chrome_driver):
    html_content = "<html><body><div>HTML content</div></body></html>"
    mock_driver = mock_driver_with_html(html_content)
    tool = initialize_tool_with(mock_driver)

    result = tool._run(website_url="https://example.com", return_html=False)

    assert "HTML content" in result
    mock_driver.get.assert_called_once_with("https://example.com")
    mock_driver.find_element.assert_called_with("tag name", "body")
    mock_driver.close.assert_called_once()


@patch("selenium.webdriver.Chrome")
def test_scrape_with_driver_error(_mocked_chrome_driver):
    mock_driver = MagicMock()
    mock_driver.find_element.side_effect = Exception("WebDriver error occurred")
    tool = initialize_tool_with(mock_driver)
    result = tool._run(website_url="https://example.com")
    assert result == "Error scraping website: WebDriver error occurred"
    mock_driver.close.assert_called_once()


@patch("selenium.webdriver.Chrome")
def test_initialization_with_driver(_mocked_chrome_driver):
    mock_driver = MagicMock()
    tool = initialize_tool_with(mock_driver)
    assert tool.driver == mock_driver
151
tests/tools/serper_dev_tool_test.py
Normal file
@@ -0,0 +1,151 @@
from unittest.mock import patch
import pytest
from crewai_tools.tools.serper_dev_tool.serper_dev_tool import SerperDevTool
import os


@pytest.fixture(autouse=True)
def mock_serper_api_key():
    with patch.dict(os.environ, {"SERPER_API_KEY": "test_key"}):
        yield


@pytest.fixture
def serper_tool():
    return SerperDevTool(n_results=2)


def test_serper_tool_initialization():
    tool = SerperDevTool()
    assert tool.n_results == 10
    assert tool.save_file is False
    assert tool.search_type == "search"
    assert tool.country == ""
    assert tool.location == ""
    assert tool.locale == ""


def test_serper_tool_custom_initialization():
    tool = SerperDevTool(
        n_results=5,
        save_file=True,
        search_type="news",
        country="US",
        location="New York",
        locale="en"
    )
    assert tool.n_results == 5
    assert tool.save_file is True
    assert tool.search_type == "news"
    assert tool.country == "US"
    assert tool.location == "New York"
    assert tool.locale == "en"


@patch("requests.post")
def test_serper_tool_search(mock_post):
    tool = SerperDevTool(n_results=2)
    mock_response = {
        "searchParameters": {
            "q": "test query",
            "type": "search"
        },
        "organic": [
            {
                "title": "Test Title 1",
                "link": "http://test1.com",
                "snippet": "Test Description 1",
                "position": 1
            },
            {
                "title": "Test Title 2",
                "link": "http://test2.com",
                "snippet": "Test Description 2",
                "position": 2
            }
        ],
        "peopleAlsoAsk": [
            {
                "question": "Test Question",
                "snippet": "Test Answer",
                "title": "Test Source",
                "link": "http://test.com"
            }
        ]
    }
    mock_post.return_value.json.return_value = mock_response
    mock_post.return_value.status_code = 200

    result = tool.run(search_query="test query")

    assert "searchParameters" in result
    assert result["searchParameters"]["q"] == "test query"
    assert len(result["organic"]) == 2
    assert result["organic"][0]["title"] == "Test Title 1"


@patch("requests.post")
def test_serper_tool_news_search(mock_post):
    tool = SerperDevTool(n_results=2, search_type="news")
    mock_response = {
        "searchParameters": {
            "q": "test news",
            "type": "news"
        },
        "news": [
            {
                "title": "News Title 1",
                "link": "http://news1.com",
                "snippet": "News Description 1",
                "date": "2024-01-01",
                "source": "News Source 1",
                "imageUrl": "http://image1.com"
            }
        ]
    }
    mock_post.return_value.json.return_value = mock_response
    mock_post.return_value.status_code = 200

    result = tool.run(search_query="test news")

    assert "news" in result
    assert len(result["news"]) == 1
    assert result["news"][0]["title"] == "News Title 1"


@patch("requests.post")
def test_serper_tool_with_location_params(mock_post):
    tool = SerperDevTool(
        n_results=2,
        country="US",
        location="New York",
        locale="en"
    )

    tool.run(search_query="test")

    called_payload = mock_post.call_args.kwargs["json"]
    assert called_payload["gl"] == "US"
    assert called_payload["location"] == "New York"
    assert called_payload["hl"] == "en"


def test_invalid_search_type():
    tool = SerperDevTool()
    with pytest.raises(ValueError) as exc_info:
        tool.run(search_query="test", search_type="invalid")
    assert "Invalid search type" in str(exc_info.value)


@patch("requests.post")
def test_api_error_handling(mock_post):
    tool = SerperDevTool()
    mock_post.side_effect = Exception("API Error")

    with pytest.raises(Exception) as exc_info:
        tool.run(search_query="test")
    assert "API Error" in str(exc_info.value)


if __name__ == "__main__":
    pytest.main([__file__])
336
tests/tools/singlestore_search_tool_test.py
Normal file
@@ -0,0 +1,336 @@
import os
from typing import Generator

import pytest
from singlestoredb import connect
from singlestoredb.server import docker

from crewai_tools import SingleStoreSearchTool
from crewai_tools.tools.singlestore_search_tool import SingleStoreSearchToolSchema


@pytest.fixture(scope="session")
def docker_server_url() -> Generator[str, None, None]:
    """Start a SingleStore Docker server for tests."""
    try:
        sdb = docker.start(license="")
        conn = sdb.connect()
        curr = conn.cursor()
        curr.execute("CREATE DATABASE test_crewai")
        curr.close()
        conn.close()
        yield sdb.connection_url
        sdb.stop()
    except Exception as e:
        pytest.skip(f"Could not start SingleStore Docker container: {e}")


@pytest.fixture(scope="function")
def clean_db_url(docker_server_url) -> Generator[str, None, None]:
    """Provide a clean database URL and clean up tables after test."""
    yield docker_server_url
    try:
        conn = connect(host=docker_server_url, database="test_crewai")
        curr = conn.cursor()
        curr.execute("SHOW TABLES")
        results = curr.fetchall()
        for result in results:
            curr.execute(f"DROP TABLE {result[0]}")
        curr.close()
        conn.close()
    except Exception:
        # Ignore cleanup errors
        pass


@pytest.fixture
def sample_table_setup(clean_db_url):
    """Set up sample tables for testing."""
    conn = connect(host=clean_db_url, database="test_crewai")
    curr = conn.cursor()

    # Create sample tables
    curr.execute(
        """
        CREATE TABLE employees (
            id INT PRIMARY KEY,
            name VARCHAR(100),
            department VARCHAR(50),
            salary DECIMAL(10,2)
        )
        """
    )

    curr.execute(
        """
        CREATE TABLE departments (
            id INT PRIMARY KEY,
            name VARCHAR(100),
            budget DECIMAL(12,2)
        )
        """
    )

    # Insert sample data
    curr.execute(
        """
        INSERT INTO employees VALUES
        (1, 'Alice Smith', 'Engineering', 75000.00),
        (2, 'Bob Johnson', 'Marketing', 65000.00),
        (3, 'Carol Davis', 'Engineering', 80000.00)
        """
    )

    curr.execute(
        """
        INSERT INTO departments VALUES
        (1, 'Engineering', 500000.00),
        (2, 'Marketing', 300000.00)
        """
    )

    curr.close()
    conn.close()
    return clean_db_url


class TestSingleStoreSearchTool:
    """Test suite for SingleStoreSearchTool."""

    def test_tool_creation_with_connection_params(self, sample_table_setup):
        """Test tool creation with individual connection parameters."""
        # Parse URL components for individual parameters
        url_parts = sample_table_setup.split("@")[1].split(":")
        host = url_parts[0]
        port = int(url_parts[1].split("/")[0])
        user = "root"
        password = sample_table_setup.split("@")[0].split(":")[2]
        tool = SingleStoreSearchTool(
            tables=[],
            host=host,
            port=port,
            user=user,
            password=password,
            database="test_crewai",
        )

        assert tool.name == "Search a database's table(s) content"
        assert "SingleStore" in tool.description
        assert (
            "employees(id int(11), name varchar(100), department varchar(50), salary decimal(10,2))"
            in tool.description.lower()
        )
        assert (
            "departments(id int(11), name varchar(100), budget decimal(12,2))"
            in tool.description.lower()
        )
        assert tool.args_schema == SingleStoreSearchToolSchema
        assert tool.connection_pool is not None

    def test_tool_creation_with_connection_url(self, sample_table_setup):
        """Test tool creation with connection URL."""
        tool = SingleStoreSearchTool(host=f"{sample_table_setup}/test_crewai")

        assert tool.name == "Search a database's table(s) content"
        assert tool.connection_pool is not None

    def test_tool_creation_with_specific_tables(self, sample_table_setup):
        """Test tool creation with specific table list."""
        tool = SingleStoreSearchTool(
            tables=["employees"],
            host=sample_table_setup,
            database="test_crewai",
        )

        # Check that description includes specific tables
        assert "employees" in tool.description
        assert "departments" not in tool.description

    def test_tool_creation_with_nonexistent_table(self, sample_table_setup):
        """Test tool creation fails with non-existent table."""

        with pytest.raises(ValueError, match="Table nonexistent does not exist"):
            SingleStoreSearchTool(
                tables=["employees", "nonexistent"],
                host=sample_table_setup,
                database="test_crewai",
            )

    def test_tool_creation_with_empty_database(self, clean_db_url):
        """Test tool creation fails with empty database."""

        with pytest.raises(ValueError, match="No tables found in the database"):
            SingleStoreSearchTool(host=clean_db_url, database="test_crewai")

    def test_description_generation(self, sample_table_setup):
        """Test that tool description is properly generated with table info."""

        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        # Check description contains table definitions
        assert "employees(" in tool.description
        assert "departments(" in tool.description
        assert "id int" in tool.description.lower()
        assert "name varchar" in tool.description.lower()

    def test_query_validation_select_allowed(self, sample_table_setup):
        """Test that SELECT queries are allowed."""
        os.environ["SINGLESTOREDB_URL"] = sample_table_setup
        tool = SingleStoreSearchTool(database="test_crewai")

        valid, message = tool._validate_query("SELECT * FROM employees")
        assert valid is True
        assert message == "Valid query"

    def test_query_validation_show_allowed(self, sample_table_setup):
        """Test that SHOW queries are allowed."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        valid, message = tool._validate_query("SHOW TABLES")
        assert valid is True
        assert message == "Valid query"

    def test_query_validation_case_insensitive(self, sample_table_setup):
        """Test that query validation is case insensitive."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        valid, _ = tool._validate_query("select * from employees")
        assert valid is True

        valid, _ = tool._validate_query("SHOW tables")
        assert valid is True

    def test_query_validation_insert_denied(self, sample_table_setup):
        """Test that INSERT queries are denied."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        valid, message = tool._validate_query(
            "INSERT INTO employees VALUES (4, 'Test', 'Test', 1000)"
        )
        assert valid is False
        assert "Only SELECT and SHOW queries are supported" in message

    def test_query_validation_update_denied(self, sample_table_setup):
        """Test that UPDATE queries are denied."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        valid, message = tool._validate_query("UPDATE employees SET salary = 90000")
        assert valid is False
        assert "Only SELECT and SHOW queries are supported" in message

    def test_query_validation_delete_denied(self, sample_table_setup):
        """Test that DELETE queries are denied."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        valid, message = tool._validate_query("DELETE FROM employees WHERE id = 1")
        assert valid is False
        assert "Only SELECT and SHOW queries are supported" in message

    def test_query_validation_non_string(self, sample_table_setup):
        """Test that non-string queries are rejected."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        valid, message = tool._validate_query(123)
        assert valid is False
        assert "Search query must be a string" in message

    def test_run_select_query(self, sample_table_setup):
        """Test executing a SELECT query."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        result = tool._run("SELECT * FROM employees ORDER BY id")

        assert "Search Results:" in result
        assert "Alice Smith" in result
        assert "Bob Johnson" in result
        assert "Carol Davis" in result

    def test_run_filtered_query(self, sample_table_setup):
        """Test executing a filtered SELECT query."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        result = tool._run(
            "SELECT name FROM employees WHERE department = 'Engineering'"
        )

        assert "Search Results:" in result
        assert "Alice Smith" in result
        assert "Carol Davis" in result
        assert "Bob Johnson" not in result

    def test_run_show_query(self, sample_table_setup):
        """Test executing a SHOW query."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        result = tool._run("SHOW TABLES")

        assert "Search Results:" in result
        assert "employees" in result
        assert "departments" in result

    def test_run_empty_result(self, sample_table_setup):
        """Test executing a query that returns no results."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        result = tool._run("SELECT * FROM employees WHERE department = 'NonExistent'")

        assert result == "No results found."

    def test_run_invalid_query_syntax(self, sample_table_setup):
        """Test executing a query with invalid syntax."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        result = tool._run("SELECT * FORM employees")  # Intentional typo

        assert "Error executing search query:" in result

    def test_run_denied_query(self, sample_table_setup):
        """Test that denied queries return appropriate error message."""
        tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")

        result = tool._run("DELETE FROM employees")

        assert "Invalid search query:" in result
        assert "Only SELECT and SHOW queries are supported" in result

    def test_connection_pool_usage(self, sample_table_setup):
        """Test that connection pooling works correctly."""
        tool = SingleStoreSearchTool(
            host=sample_table_setup,
            database="test_crewai",
            pool_size=2,
        )

        # Execute multiple queries to test pool usage
        results = []
        for _ in range(5):
            result = tool._run("SELECT COUNT(*) FROM employees")
            results.append(result)

        # All queries should succeed
        for result in results:
            assert "Search Results:" in result
            assert "3" in result  # Count of employees

    def test_tool_schema_validation(self):
        """Test that the tool schema validation works correctly."""
        # Valid input
        valid_input = SingleStoreSearchToolSchema(search_query="SELECT * FROM test")
        assert valid_input.search_query == "SELECT * FROM test"

        # Test that description is present
        schema_dict = SingleStoreSearchToolSchema.model_json_schema()
        assert "search_query" in schema_dict["properties"]
        assert "description" in schema_dict["properties"]["search_query"]

    def test_connection_error_handling(self):
        """Test handling of connection errors."""
        with pytest.raises(Exception):
            # This should fail due to invalid connection parameters
            SingleStoreSearchTool(
                host="invalid_host",
                port=9999,
                user="invalid_user",
                password="invalid_password",
                database="invalid_db",
            )
103
tests/tools/snowflake_search_tool_test.py
Normal file
@@ -0,0 +1,103 @@
import asyncio
from unittest.mock import MagicMock, patch

import pytest

from crewai_tools import SnowflakeConfig, SnowflakeSearchTool


# Unit Test Fixtures
@pytest.fixture
def mock_snowflake_connection():
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_cursor.description = [("col1",), ("col2",)]
    mock_cursor.fetchall.return_value = [(1, "value1"), (2, "value2")]
    mock_cursor.execute.return_value = None
    mock_conn.cursor.return_value = mock_cursor
    return mock_conn


@pytest.fixture
def mock_config():
    return SnowflakeConfig(
        account="test_account",
        user="test_user",
        password="test_password",
        warehouse="test_warehouse",
        database="test_db",
        snowflake_schema="test_schema",
    )


@pytest.fixture
def snowflake_tool(mock_config):
    with patch("snowflake.connector.connect") as mock_connect:
        tool = SnowflakeSearchTool(config=mock_config)
        yield tool


# Unit Tests
@pytest.mark.asyncio
async def test_successful_query_execution(snowflake_tool, mock_snowflake_connection):
    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
        mock_create_conn.return_value = mock_snowflake_connection

        results = await snowflake_tool._run(
            query="SELECT * FROM test_table", timeout=300
        )

        assert len(results) == 2
        assert results[0]["col1"] == 1
        assert results[0]["col2"] == "value1"
        mock_snowflake_connection.cursor.assert_called_once()


@pytest.mark.asyncio
async def test_connection_pooling(snowflake_tool, mock_snowflake_connection):
    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
        mock_create_conn.return_value = mock_snowflake_connection

        # Execute multiple queries
        await asyncio.gather(
            snowflake_tool._run("SELECT 1"),
            snowflake_tool._run("SELECT 2"),
            snowflake_tool._run("SELECT 3"),
        )

        # Should reuse connections from pool
        assert mock_create_conn.call_count <= snowflake_tool.pool_size


@pytest.mark.asyncio
async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection):
    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
        mock_create_conn.return_value = mock_snowflake_connection

        # Add connection to pool
        await snowflake_tool._get_connection()

        # Return connection to pool
        async with snowflake_tool._pool_lock:
            snowflake_tool._connection_pool.append(mock_snowflake_connection)

        # Trigger cleanup
        snowflake_tool.__del__()

        mock_snowflake_connection.close.assert_called_once()


def test_config_validation():
    # Test missing required fields
    with pytest.raises(ValueError):
        SnowflakeConfig()

    # Test invalid account format
    with pytest.raises(ValueError):
        SnowflakeConfig(
            account="invalid//account", user="test_user", password="test_pass"
        )

    # Test missing authentication
    with pytest.raises(ValueError):
        SnowflakeConfig(account="test_account", user="test_user")
262
tests/tools/stagehand_tool_test.py
Normal file
@@ -0,0 +1,262 @@
import sys
from unittest.mock import MagicMock, patch

import pytest


# Create mock classes that will be used by our fixture
class MockStagehandModule:
    def __init__(self):
        self.Stagehand = MagicMock()
        self.StagehandConfig = MagicMock()
        self.StagehandPage = MagicMock()


class MockStagehandSchemas:
    def __init__(self):
        self.ActOptions = MagicMock()
        self.ExtractOptions = MagicMock()
        self.ObserveOptions = MagicMock()
        self.AvailableModel = MagicMock()


class MockStagehandUtils:
    def __init__(self):
        self.configure_logging = MagicMock()


@pytest.fixture(scope="module", autouse=True)
def mock_stagehand_modules():
    """Mock stagehand modules at the start of this test module."""
    # Store original modules if they exist
    original_modules = {}
    for module_name in ["stagehand", "stagehand.schemas", "stagehand.utils"]:
        if module_name in sys.modules:
            original_modules[module_name] = sys.modules[module_name]

    # Create and inject mock modules
    mock_stagehand = MockStagehandModule()
    mock_stagehand_schemas = MockStagehandSchemas()
    mock_stagehand_utils = MockStagehandUtils()

    sys.modules["stagehand"] = mock_stagehand
    sys.modules["stagehand.schemas"] = mock_stagehand_schemas
    sys.modules["stagehand.utils"] = mock_stagehand_utils

    # Import after mocking
    from crewai_tools.tools.stagehand_tool.stagehand_tool import StagehandResult, StagehandTool

    # Make these available to tests in this module
    sys.modules[__name__].StagehandResult = StagehandResult
    sys.modules[__name__].StagehandTool = StagehandTool

    yield

    # Restore original modules
    for module_name, module in original_modules.items():
        sys.modules[module_name] = module


class MockStagehandPage(MagicMock):
    def act(self, options):
        mock_result = MagicMock()
        mock_result.model_dump.return_value = {
            "message": "Action completed successfully"
        }
        return mock_result

    def goto(self, url):
        return MagicMock()

    def extract(self, options):
        mock_result = MagicMock()
        mock_result.model_dump.return_value = {
            "data": "Extracted content",
            "metadata": {"source": "test"},
        }
        return mock_result

    def observe(self, options):
        result1 = MagicMock()
        result1.description = "Button element"
        result1.method = "click"

        result2 = MagicMock()
        result2.description = "Input field"
        result2.method = "type"

        return [result1, result2]


class MockStagehand(MagicMock):
    def init(self):
        self.session_id = "test-session-id"
        self.page = MockStagehandPage()

    def close(self):
        pass


@pytest.fixture
def mock_stagehand_instance():
    with patch(
        "crewai_tools.tools.stagehand_tool.stagehand_tool.Stagehand",
        return_value=MockStagehand(),
    ) as mock:
        yield mock


@pytest.fixture
def stagehand_tool():
    return StagehandTool(
        api_key="test_api_key",
        project_id="test_project_id",
        model_api_key="test_model_api_key",
        _testing=True,  # Enable testing mode to bypass dependency check
    )


def test_stagehand_tool_initialization():
    """Test that the StagehandTool initializes with the correct default values."""
    tool = StagehandTool(
        api_key="test_api_key",
        project_id="test_project_id",
        model_api_key="test_model_api_key",
        _testing=True,  # Enable testing mode
    )

    assert tool.api_key == "test_api_key"
    assert tool.project_id == "test_project_id"
    assert tool.model_api_key == "test_model_api_key"
    assert tool.headless is False
    assert tool.dom_settle_timeout_ms == 3000
    assert tool.self_heal is True
    assert tool.wait_for_captcha_solves is True


@patch("crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True)
def test_act_command(mock_run, stagehand_tool):
    """Test the 'act' command functionality."""
    # Setup mock
    mock_run.return_value = "Action result: Action completed successfully"

    # Run the tool
    result = stagehand_tool._run(
        instruction="Click the submit button", command_type="act"
    )

    # Assertions
    assert "Action result" in result
    assert "Action completed successfully" in result


@patch("crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True)
def test_navigate_command(mock_run, stagehand_tool):
    """Test the 'navigate' command functionality."""
    # Setup mock
    mock_run.return_value = "Successfully navigated to https://example.com"

    # Run the tool
    result = stagehand_tool._run(
        instruction="Go to example.com",
        url="https://example.com",
        command_type="navigate",
    )

    # Assertions
    assert "https://example.com" in result


@patch("crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True)
def test_extract_command(mock_run, stagehand_tool):
    """Test the 'extract' command functionality."""
    # Setup mock
    mock_run.return_value = "Extracted data: {\"data\": \"Extracted content\", \"metadata\": {\"source\": \"test\"}}"

    # Run the tool
    result = stagehand_tool._run(
        instruction="Extract all product names and prices", command_type="extract"
    )

    # Assertions
    assert "Extracted data" in result
    assert "Extracted content" in result


@patch("crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True)
def test_observe_command(mock_run, stagehand_tool):
    """Test the 'observe' command functionality."""
    # Setup mock
    mock_run.return_value = "Element 1: Button element\nSuggested action: click\nElement 2: Input field\nSuggested action: type"

    # Run the tool
    result = stagehand_tool._run(
        instruction="Find all interactive elements", command_type="observe"
    )

    # Assertions
    assert "Element 1: Button element" in result
    assert "Element 2: Input field" in result
    assert "Suggested action: click" in result
    assert "Suggested action: type" in result


@patch("crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True)
def test_error_handling(mock_run, stagehand_tool):
    """Test error handling in the tool."""
    # Setup mock
    mock_run.return_value = "Error: Browser automation error"

    # Run the tool
    result = stagehand_tool._run(
        instruction="Click a non-existent button", command_type="act"
    )

    # Assertions
    assert "Error:" in result
    assert "Browser automation error" in result


def test_initialization_parameters():
    """Test that the StagehandTool initializes with the correct parameters."""
    # Create tool with custom parameters
    tool = StagehandTool(
        api_key="custom_api_key",
        project_id="custom_project_id",
        model_api_key="custom_model_api_key",
        headless=True,
        dom_settle_timeout_ms=5000,
        self_heal=False,
        wait_for_captcha_solves=False,
        verbose=3,
        _testing=True,  # Enable testing mode
    )

    # Verify the tool was initialized with the correct parameters
    assert tool.api_key == "custom_api_key"
    assert tool.project_id == "custom_project_id"
    assert tool.model_api_key == "custom_model_api_key"
    assert tool.headless is True
    assert tool.dom_settle_timeout_ms == 5000
    assert tool.self_heal is False
    assert tool.wait_for_captcha_solves is False
    assert tool.verbose == 3


def test_close_method():
    """Test that the close method cleans up resources correctly."""
    # Create the tool with testing mode
    tool = StagehandTool(
        api_key="test_api_key",
        project_id="test_project_id",
        model_api_key="test_model_api_key",
        _testing=True,
    )

    # Setup mock stagehand instance
    tool._stagehand = MagicMock()
    tool._stagehand.close = MagicMock()  # Non-async mock
    tool._page = MagicMock()

    # Call the close method
    tool.close()

    # Verify resources were cleaned up
    assert tool._stagehand is None
    assert tool._page is None
175
tests/tools/test_code_interpreter_tool.py
Normal file
@@ -0,0 +1,175 @@
from unittest.mock import patch

import pytest

from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import (
    CodeInterpreterTool,
    SandboxPython,
)


@pytest.fixture
def printer_mock():
    with patch("crewai_tools.printer.Printer.print") as mock:
        yield mock


@pytest.fixture
def docker_unavailable_mock():
    with patch(
        "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.CodeInterpreterTool._check_docker_available",
        return_value=False,
    ) as mock:
        yield mock


@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
def test_run_code_in_docker(docker_mock, printer_mock):
    tool = CodeInterpreterTool()
    code = "print('Hello, World!')"
    libraries_used = ["numpy", "pandas"]
    expected_output = "Hello, World!\n"

    docker_mock().containers.run().exec_run().exit_code = 0
    docker_mock().containers.run().exec_run().output = expected_output.encode()

    result = tool.run_code_in_docker(code, libraries_used)
    assert result == expected_output
    printer_mock.assert_called_with(
        "Running code in Docker environment", color="bold_blue"
    )


@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
def test_run_code_in_docker_with_error(docker_mock, printer_mock):
    tool = CodeInterpreterTool()
    code = "print(1/0)"
    libraries_used = ["numpy", "pandas"]
    expected_output = "Something went wrong while running the code: \nZeroDivisionError: division by zero\n"

    docker_mock().containers.run().exec_run().exit_code = 1
    docker_mock().containers.run().exec_run().output = (
        b"ZeroDivisionError: division by zero\n"
    )

    result = tool.run_code_in_docker(code, libraries_used)
    assert result == expected_output
    printer_mock.assert_called_with(
        "Running code in Docker environment", color="bold_blue"
    )


@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
def test_run_code_in_docker_with_script(docker_mock, printer_mock):
    tool = CodeInterpreterTool()
    code = """print("This is line 1")
print("This is line 2")"""
    libraries_used = []
    expected_output = "This is line 1\nThis is line 2\n"

    docker_mock().containers.run().exec_run().exit_code = 0
    docker_mock().containers.run().exec_run().output = expected_output.encode()

    result = tool.run_code_in_docker(code, libraries_used)
    assert result == expected_output
    printer_mock.assert_called_with(
        "Running code in Docker environment", color="bold_blue"
    )


def test_restricted_sandbox_basic_code_execution(printer_mock, docker_unavailable_mock):
    """Test basic code execution."""
    tool = CodeInterpreterTool()
    code = """
result = 2 + 2
print(result)
"""
    result = tool.run(code=code, libraries_used=[])
    printer_mock.assert_called_with(
        "Running code in restricted sandbox", color="yellow"
    )
    assert result == 4


def test_restricted_sandbox_running_with_blocked_modules(
    printer_mock, docker_unavailable_mock
):
    """Test that restricted modules cannot be imported."""
    tool = CodeInterpreterTool()
    restricted_modules = SandboxPython.BLOCKED_MODULES

    for module in restricted_modules:
        code = f"""
import {module}
result = "Import succeeded"
"""
        result = tool.run(code=code, libraries_used=[])
        printer_mock.assert_called_with(
            "Running code in restricted sandbox", color="yellow"
        )

        assert f"An error occurred: Importing '{module}' is not allowed" in result


def test_restricted_sandbox_running_with_blocked_builtins(
    printer_mock, docker_unavailable_mock
):
    """Test that restricted builtins are not available."""
    tool = CodeInterpreterTool()
    restricted_builtins = SandboxPython.UNSAFE_BUILTINS

    for builtin in restricted_builtins:
        code = f"""
{builtin}("test")
result = "Builtin available"
"""
        result = tool.run(code=code, libraries_used=[])
        printer_mock.assert_called_with(
            "Running code in restricted sandbox", color="yellow"
        )
        assert f"An error occurred: name '{builtin}' is not defined" in result


def test_restricted_sandbox_running_with_no_result_variable(
    printer_mock, docker_unavailable_mock
):
    """Test behavior when no result variable is set."""
    tool = CodeInterpreterTool()
    code = """
x = 10
"""
    result = tool.run(code=code, libraries_used=[])
    printer_mock.assert_called_with(
        "Running code in restricted sandbox", color="yellow"
    )
    assert result == "No result variable found."


def test_unsafe_mode_running_with_no_result_variable(
    printer_mock, docker_unavailable_mock
):
    """Test behavior when no result variable is set."""
    tool = CodeInterpreterTool(unsafe_mode=True)
    code = """
x = 10
"""
    result = tool.run(code=code, libraries_used=[])
    printer_mock.assert_called_with(
        "WARNING: Running code in unsafe mode", color="bold_magenta"
    )
    assert result == "No result variable found."


def test_unsafe_mode_running_unsafe_code(printer_mock, docker_unavailable_mock):
    """Test behavior when no result variable is set."""
    tool = CodeInterpreterTool(unsafe_mode=True)
    code = """
import os
os.system("ls -la")
result = eval("5/1")
"""
    result = tool.run(code=code, libraries_used=[])
    printer_mock.assert_called_with(
        "WARNING: Running code in unsafe mode", color="bold_magenta"
    )
    assert 5.0 == result
10
tests/tools/test_import_without_warnings.py
Normal file
@@ -0,0 +1,10 @@
import pytest
from pydantic.warnings import PydanticDeprecatedSince20


@pytest.mark.filterwarnings("error", category=PydanticDeprecatedSince20)
def test_import_tools_without_pydantic_deprecation_warnings():
    # This test is to ensure that the import of crewai_tools does not raise any Pydantic deprecation warnings.
    import crewai_tools

    assert crewai_tools
75
tests/tools/test_mongodb_vector_search_tool.py
Normal file
@@ -0,0 +1,75 @@
import json
from unittest.mock import patch

import pytest

from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool


# Unit Test Fixtures
@pytest.fixture
def mongodb_vector_search_tool():
    tool = MongoDBVectorSearchTool(
        connection_string="foo", database_name="bar", collection_name="test"
    )
    tool._embed_texts = lambda x: [[0.1]]
    yield tool


# Unit Tests
def test_successful_query_execution(mongodb_vector_search_tool):
    # Enable embedding
    with patch.object(mongodb_vector_search_tool._coll, "aggregate") as mock_aggregate:
        mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]

        results = json.loads(mongodb_vector_search_tool._run(query="sandwiches"))

        assert len(results) == 1
        assert results[0]["text"] == "foo"
        assert results[0]["_id"] == 1


def test_provide_config():
    query_config = MongoDBVectorSearchConfig(limit=10)
    tool = MongoDBVectorSearchTool(
        connection_string="foo",
        database_name="bar",
        collection_name="test",
        query_config=query_config,
        vector_index_name="foo",
        embedding_model="bar",
    )
    tool._embed_texts = lambda x: [[0.1]]
    with patch.object(tool._coll, "aggregate") as mock_aggregate:
        mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]

        tool._run(query="sandwiches")
        assert mock_aggregate.mock_calls[-1].args[0][0]["$vectorSearch"]["limit"] == 10

        mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]


def test_cleanup_on_deletion(mongodb_vector_search_tool):
    with patch.object(mongodb_vector_search_tool, "_client") as mock_client:
        # Trigger cleanup
        mongodb_vector_search_tool.__del__()

        mock_client.close.assert_called_once()


def test_create_search_index(mongodb_vector_search_tool):
    with patch(
        "crewai_tools.tools.mongodb_vector_search_tool.vector_search.create_vector_search_index"
    ) as mock_create_search_index:
        mongodb_vector_search_tool.create_vector_search_index(dimensions=10)
        kwargs = mock_create_search_index.mock_calls[0].kwargs
        assert kwargs["dimensions"] == 10
        assert kwargs["similarity"] == "cosine"


def test_add_texts(mongodb_vector_search_tool):
    with patch.object(mongodb_vector_search_tool._coll, "bulk_write") as bulk_write:
        mongodb_vector_search_tool.add_texts(["foo"])
        args = bulk_write.mock_calls[0].args
        assert "ReplaceOne" in str(args[0][0])
        assert "foo" in str(args[0][0])
163
tests/tools/test_oxylabs_tools.py
Normal file
@@ -0,0 +1,163 @@
import json
import os
from typing import Type
from unittest.mock import MagicMock

import pytest
from crewai.tools.base_tool import BaseTool
from oxylabs import RealtimeClient
from oxylabs.sources.response import Response as OxylabsResponse
from pydantic import BaseModel

from crewai_tools import (
    OxylabsAmazonProductScraperTool,
    OxylabsAmazonSearchScraperTool,
    OxylabsGoogleSearchScraperTool,
    OxylabsUniversalScraperTool,
)
from crewai_tools.tools.oxylabs_amazon_product_scraper_tool.oxylabs_amazon_product_scraper_tool import (
    OxylabsAmazonProductScraperConfig,
)
from crewai_tools.tools.oxylabs_google_search_scraper_tool.oxylabs_google_search_scraper_tool import (
    OxylabsGoogleSearchScraperConfig,
)


@pytest.fixture
def oxylabs_api() -> RealtimeClient:
    oxylabs_api_mock = MagicMock()

    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Scraping Sandbox</title>
    </head>
    <body>
        <div id="main">
            <div id="product-list">
                <div>
                    <p>Amazing product</p>
                    <p>Price $14.99</p>
                </div>
                <div>
                    <p>Good product</p>
                    <p>Price $9.99</p>
                </div>
            </div>
        </div>
    </body>
    </html>
    """

    json_content = {
        "results": {
            "products": [
                {"title": "Amazing product", "price": 14.99, "currency": "USD"},
                {"title": "Good product", "price": 9.99, "currency": "USD"},
            ],
        },
    }

    html_response = OxylabsResponse({"results": [{"content": html_content}]})
    json_response = OxylabsResponse({"results": [{"content": json_content}]})

    oxylabs_api_mock.universal.scrape_url.side_effect = [json_response, html_response]
    oxylabs_api_mock.amazon.scrape_search.side_effect = [json_response, html_response]
    oxylabs_api_mock.amazon.scrape_product.side_effect = [json_response, html_response]
    oxylabs_api_mock.google.scrape_search.side_effect = [json_response, html_response]

    return oxylabs_api_mock


@pytest.mark.parametrize(
    ("tool_class",),
    [
        (OxylabsUniversalScraperTool,),
        (OxylabsAmazonSearchScraperTool,),
        (OxylabsGoogleSearchScraperTool,),
        (OxylabsAmazonProductScraperTool,),
    ],
)
def test_tool_initialization(tool_class: Type[BaseTool]):
    tool = tool_class(username="username", password="password")
    assert isinstance(tool, tool_class)


@pytest.mark.parametrize(
    ("tool_class",),
    [
        (OxylabsUniversalScraperTool,),
        (OxylabsAmazonSearchScraperTool,),
        (OxylabsGoogleSearchScraperTool,),
        (OxylabsAmazonProductScraperTool,),
    ],
)
def test_tool_initialization_with_env_vars(tool_class: Type[BaseTool]):
    os.environ["OXYLABS_USERNAME"] = "username"
    os.environ["OXYLABS_PASSWORD"] = "password"

    tool = tool_class()
    assert isinstance(tool, tool_class)

    del os.environ["OXYLABS_USERNAME"]
    del os.environ["OXYLABS_PASSWORD"]


@pytest.mark.parametrize(
    ("tool_class",),
    [
        (OxylabsUniversalScraperTool,),
        (OxylabsAmazonSearchScraperTool,),
        (OxylabsGoogleSearchScraperTool,),
        (OxylabsAmazonProductScraperTool,),
    ],
)
def test_tool_initialization_failure(tool_class: Type[BaseTool]):
    # making sure env vars are not set
    for key in ["OXYLABS_USERNAME", "OXYLABS_PASSWORD"]:
        if key in os.environ:
            del os.environ[key]

    with pytest.raises(ValueError):
        tool_class()


@pytest.mark.parametrize(
    ("tool_class", "tool_config"),
    [
        (OxylabsUniversalScraperTool, {"geo_location": "Paris, France"}),
        (
            OxylabsAmazonSearchScraperTool,
            {"domain": "co.uk"},
        ),
        (
            OxylabsGoogleSearchScraperTool,
            OxylabsGoogleSearchScraperConfig(render="html"),
        ),
        (
            OxylabsAmazonProductScraperTool,
            OxylabsAmazonProductScraperConfig(parse=True),
        ),
    ],
)
def test_tool_invocation(
    tool_class: Type[BaseTool],
    tool_config: BaseModel,
    oxylabs_api: RealtimeClient,
):
    tool = tool_class(username="username", password="password", config=tool_config)

    # setting via __dict__ to bypass pydantic validation
    tool.__dict__["oxylabs_api"] = oxylabs_api

    # verifying parsed job returns json content
    result = tool.run("Scraping Query 1")
    assert isinstance(result, str)
    assert isinstance(json.loads(result), dict)

    # verifying raw job returns str
    result = tool.run("Scraping Query 2")
    assert isinstance(result, str)
    assert "<!DOCTYPE html>" in result
309
tests/tools/test_search_tools.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import ANY, MagicMock
|
||||
|
||||
import pytest
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
from crewai_tools.tools import (
|
||||
CodeDocsSearchTool,
|
||||
CSVSearchTool,
|
||||
DirectorySearchTool,
|
||||
DOCXSearchTool,
|
||||
GithubSearchTool,
|
||||
JSONSearchTool,
|
||||
MDXSearchTool,
|
||||
PDFSearchTool,
|
||||
TXTSearchTool,
|
||||
WebsiteSearchTool,
|
||||
XMLSearchTool,
|
||||
YoutubeChannelSearchTool,
|
||||
YoutubeVideoSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
pytestmark = [pytest.mark.vcr(filter_headers=["authorization"])]
|
||||


@pytest.fixture
def mock_adapter():
    mock_adapter = MagicMock(spec=Adapter)
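    # spec=Adapter limits the mock to the Adapter interface, so accessing attributes
    # the real adapter does not define raises AttributeError.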
    return mock_adapter


def test_directory_search_tool():
    with tempfile.TemporaryDirectory() as temp_dir:
        test_file = Path(temp_dir) / "test.txt"
        test_file.write_text("This is a test file for directory search")

        tool = DirectorySearchTool(directory=temp_dir)
        result = tool._run(search_query="test file")
        assert "test file" in result.lower()


def test_pdf_search_tool(mock_adapter):
    mock_adapter.query.return_value = "this is a test"

    tool = PDFSearchTool(pdf="test.pdf", adapter=mock_adapter)
    result = tool._run(query="test content")
    assert "this is a test" in result.lower()
    mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
    mock_adapter.query.assert_called_once_with("test content")

    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

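    # Same checks again, but with the PDF supplied at call time instead of at construction.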
    tool = PDFSearchTool(adapter=mock_adapter)
    result = tool._run(pdf="test.pdf", query="test content")
    assert "this is a test" in result.lower()
    mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
    mock_adapter.query.assert_called_once_with("test content")


def test_txt_search_tool():
    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_file:
        temp_file.write(b"This is a test file for txt search")
        temp_file_path = temp_file.name

    try:
        tool = TXTSearchTool()
        tool.add(temp_file_path)
        result = tool._run(search_query="test file")
        assert "test file" in result.lower()
    finally:
        os.unlink(temp_file_path)


def test_docx_search_tool(mock_adapter):
    mock_adapter.query.return_value = "this is a test"

    tool = DOCXSearchTool(docx="test.docx", adapter=mock_adapter)
    result = tool._run(search_query="test content")
    assert "this is a test" in result.lower()
    mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
    mock_adapter.query.assert_called_once_with("test content")

    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    tool = DOCXSearchTool(adapter=mock_adapter)
    result = tool._run(docx="test.docx", search_query="test content")
    assert "this is a test" in result.lower()
    mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
    mock_adapter.query.assert_called_once_with("test content")


def test_json_search_tool():
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
        temp_file.write(b'{"test": "This is a test JSON file"}')
        temp_file_path = temp_file.name

    try:
        tool = JSONSearchTool()
        result = tool._run(search_query="test JSON", json_path=temp_file_path)
        assert "test json" in result.lower()
    finally:
        os.unlink(temp_file_path)


def test_xml_search_tool(mock_adapter):
    mock_adapter.query.return_value = "this is a test"

    tool = XMLSearchTool(adapter=mock_adapter)
    result = tool._run(search_query="test XML", xml="test.xml")
    assert "this is a test" in result.lower()
    mock_adapter.add.assert_called_once_with("test.xml")
    mock_adapter.query.assert_called_once_with("test XML")


def test_csv_search_tool():
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp_file:
        temp_file.write(b"name,description\ntest,This is a test CSV file")
        temp_file_path = temp_file.name

    try:
        tool = CSVSearchTool()
        tool.add(temp_file_path)
        result = tool._run(search_query="test CSV")
        assert "test csv" in result.lower()
    finally:
        os.unlink(temp_file_path)


def test_mdx_search_tool():
    with tempfile.NamedTemporaryFile(suffix=".mdx", delete=False) as temp_file:
        temp_file.write(b"# Test MDX\nThis is a test MDX file")
        temp_file_path = temp_file.name

    try:
        tool = MDXSearchTool()
        tool.add(temp_file_path)
        result = tool._run(search_query="test MDX")
        assert "test mdx" in result.lower()
    finally:
        os.unlink(temp_file_path)


def test_website_search_tool(mock_adapter):
    mock_adapter.query.return_value = "this is a test"

    website = "https://crewai.com"
    search_query = "what is crewai?"
    tool = WebsiteSearchTool(website=website, adapter=mock_adapter)
    result = tool._run(search_query=search_query)

    mock_adapter.query.assert_called_once_with("what is crewai?")
    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)

    assert "this is a test" in result.lower()

    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    tool = WebsiteSearchTool(adapter=mock_adapter)
    result = tool._run(website=website, search_query=search_query)

    mock_adapter.query.assert_called_once_with("what is crewai?")
    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)

    assert "this is a test" in result.lower()


def test_youtube_video_search_tool(mock_adapter):
    mock_adapter.query.return_value = "some video description"

    youtube_video_url = "https://www.youtube.com/watch?v=sample-video-id"
    search_query = "what is the video about?"
    tool = YoutubeVideoSearchTool(
        youtube_video_url=youtube_video_url,
        adapter=mock_adapter,
    )
    result = tool._run(search_query=search_query)
    assert "some video description" in result

    mock_adapter.add.assert_called_once_with(
        youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
    )
    mock_adapter.query.assert_called_once_with(search_query)

    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    tool = YoutubeVideoSearchTool(adapter=mock_adapter)
    result = tool._run(youtube_video_url=youtube_video_url, search_query=search_query)
    assert "some video description" in result

    mock_adapter.add.assert_called_once_with(
        youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
    )
    mock_adapter.query.assert_called_once_with(search_query)


def test_youtube_channel_search_tool(mock_adapter):
    mock_adapter.query.return_value = "channel description"

    youtube_channel_handle = "@crewai"
    search_query = "what is the channel about?"
    tool = YoutubeChannelSearchTool(
        youtube_channel_handle=youtube_channel_handle, adapter=mock_adapter
    )
    result = tool._run(search_query=search_query)
    assert "channel description" in result
    mock_adapter.add.assert_called_once_with(
        youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
    )
    mock_adapter.query.assert_called_once_with(search_query)

    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    tool = YoutubeChannelSearchTool(adapter=mock_adapter)
    result = tool._run(
        youtube_channel_handle=youtube_channel_handle, search_query=search_query
    )
    assert "channel description" in result

    mock_adapter.add.assert_called_once_with(
        youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
    )
    mock_adapter.query.assert_called_once_with(search_query)


def test_code_docs_search_tool(mock_adapter):
    mock_adapter.query.return_value = "test documentation"

    docs_url = "https://crewai.com/any-docs-url"
    search_query = "test documentation"
    tool = CodeDocsSearchTool(docs_url=docs_url, adapter=mock_adapter)
    result = tool._run(search_query=search_query)
    assert "test documentation" in result
    mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
    mock_adapter.query.assert_called_once_with(search_query)

    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    tool = CodeDocsSearchTool(adapter=mock_adapter)
    result = tool._run(docs_url=docs_url, search_query=search_query)
    assert "test documentation" in result
    mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
    mock_adapter.query.assert_called_once_with(search_query)


def test_github_search_tool(mock_adapter):
    mock_adapter.query.return_value = "repo description"

    # ensure the provided repo and content types are used after initialization
    tool = GithubSearchTool(
        gh_token="test_token",
        github_repo="crewai/crewai",
        content_types=["code"],
        adapter=mock_adapter,
    )
    result = tool._run(search_query="tell me about crewai repo")
    assert "repo description" in result
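    # loader=ANY: the GitHub loader instance is presumably constructed inside the tool,
    # so only the composed query string and data_type are pinned down here.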
    mock_adapter.add.assert_called_once_with(
        "repo:crewai/crewai type:code", data_type="github", loader=ANY
    )
    mock_adapter.query.assert_called_once_with("tell me about crewai repo")

    # ensure content types provided in the run call are used
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
    result = tool._run(
        github_repo="crewai/crewai",
        content_types=["code", "issue"],
        search_query="tell me about crewai repo",
    )
    assert "repo description" in result
    mock_adapter.add.assert_called_once_with(
        "repo:crewai/crewai type:code,issue", data_type="github", loader=ANY
    )
    mock_adapter.query.assert_called_once_with("tell me about crewai repo")

    # ensure default content types are used if not provided
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
    result = tool._run(
        github_repo="crewai/crewai",
        search_query="tell me about crewai repo",
    )
    assert "repo description" in result
    mock_adapter.add.assert_called_once_with(
        "repo:crewai/crewai type:code,repo,pr,issue", data_type="github", loader=ANY
    )
    mock_adapter.query.assert_called_once_with("tell me about crewai repo")

    # ensure nothing is added if no repo is provided
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
    result = tool._run(search_query="tell me about crewai repo")
    mock_adapter.add.assert_not_called()
    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
231
tests/tools/tool_collection_test.py
Normal file
@@ -0,0 +1,231 @@
import unittest
from unittest.mock import MagicMock

from crewai.tools import BaseTool
from crewai_tools.adapters.tool_collection import ToolCollection


class TestToolCollection(unittest.TestCase):
    def setUp(self):
        # Mixed-case on purpose: the stored .name keeps its casing, while lookups by name are case-insensitive
        self.search_tool = self._create_mock_tool("SearcH", "Search Tool")
        self.calculator_tool = self._create_mock_tool("calculator", "Calculator Tool")
        self.translator_tool = self._create_mock_tool("translator", "Translator Tool")

        self.tools = ToolCollection([
            self.search_tool,
            self.calculator_tool,
            self.translator_tool
        ])
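        # ToolCollection appears to keep an internal _name_cache mapping names to tools
        # (lookups are case-insensitive); several tests below assert against it directly.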

    def _create_mock_tool(self, name, description):
        mock_tool = MagicMock(spec=BaseTool)
        mock_tool.name = name
        mock_tool.description = description
        return mock_tool

    def test_initialization(self):
        self.assertEqual(len(self.tools), 3)
        self.assertEqual(self.tools[0].name, "SearcH")
        self.assertEqual(self.tools[1].name, "calculator")
        self.assertEqual(self.tools[2].name, "translator")

    def test_empty_initialization(self):
        empty_collection = ToolCollection()
        self.assertEqual(len(empty_collection), 0)
        self.assertEqual(empty_collection._name_cache, {})

    def test_initialization_with_none(self):
        collection = ToolCollection(None)
        self.assertEqual(len(collection), 0)
        self.assertEqual(collection._name_cache, {})

    def test_access_by_index(self):
        self.assertEqual(self.tools[0], self.search_tool)
        self.assertEqual(self.tools[1], self.calculator_tool)
        self.assertEqual(self.tools[2], self.translator_tool)

    def test_access_by_name(self):
        self.assertEqual(self.tools["search"], self.search_tool)
        self.assertEqual(self.tools["calculator"], self.calculator_tool)
        self.assertEqual(self.tools["translator"], self.translator_tool)

    def test_key_error_for_invalid_name(self):
        with self.assertRaises(KeyError):
            _ = self.tools["nonexistent"]

    def test_index_error_for_invalid_index(self):
        with self.assertRaises(IndexError):
            _ = self.tools[10]

    def test_negative_index(self):
        self.assertEqual(self.tools[-1], self.translator_tool)
        self.assertEqual(self.tools[-2], self.calculator_tool)
        self.assertEqual(self.tools[-3], self.search_tool)

    def test_append(self):
        new_tool = self._create_mock_tool("new", "New Tool")
        self.tools.append(new_tool)

        self.assertEqual(len(self.tools), 4)
        self.assertEqual(self.tools[3], new_tool)
        self.assertEqual(self.tools["new"], new_tool)
        self.assertIn("new", self.tools._name_cache)

    def test_append_duplicate_name(self):
        duplicate_tool = self._create_mock_tool("search", "Duplicate Search Tool")
        self.tools.append(duplicate_tool)
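        # The most recently appended tool is expected to win the name lookup below.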

        self.assertEqual(len(self.tools), 4)
        self.assertEqual(self.tools["search"], duplicate_tool)

    def test_extend(self):
        new_tools = [
            self._create_mock_tool("tool4", "Tool 4"),
            self._create_mock_tool("tool5", "Tool 5"),
        ]
        self.tools.extend(new_tools)

        self.assertEqual(len(self.tools), 5)
        self.assertEqual(self.tools["tool4"], new_tools[0])
        self.assertEqual(self.tools["tool5"], new_tools[1])
        self.assertIn("tool4", self.tools._name_cache)
        self.assertIn("tool5", self.tools._name_cache)

    def test_insert(self):
        new_tool = self._create_mock_tool("inserted", "Inserted Tool")
        self.tools.insert(1, new_tool)

        self.assertEqual(len(self.tools), 4)
        self.assertEqual(self.tools[1], new_tool)
        self.assertEqual(self.tools["inserted"], new_tool)
        self.assertIn("inserted", self.tools._name_cache)

    def test_remove(self):
        self.tools.remove(self.calculator_tool)

        self.assertEqual(len(self.tools), 2)
        with self.assertRaises(KeyError):
            _ = self.tools["calculator"]
        self.assertNotIn("calculator", self.tools._name_cache)

    def test_remove_nonexistent_tool(self):
        nonexistent_tool = self._create_mock_tool("nonexistent", "Nonexistent Tool")

        with self.assertRaises(ValueError):
            self.tools.remove(nonexistent_tool)

    def test_pop(self):
        popped = self.tools.pop(1)

        self.assertEqual(popped, self.calculator_tool)
        self.assertEqual(len(self.tools), 2)
        with self.assertRaises(KeyError):
            _ = self.tools["calculator"]
        self.assertNotIn("calculator", self.tools._name_cache)

    def test_pop_last(self):
        popped = self.tools.pop()

        self.assertEqual(popped, self.translator_tool)
        self.assertEqual(len(self.tools), 2)
        with self.assertRaises(KeyError):
            _ = self.tools["translator"]
        self.assertNotIn("translator", self.tools._name_cache)

    def test_clear(self):
        self.tools.clear()

        self.assertEqual(len(self.tools), 0)
        self.assertEqual(self.tools._name_cache, {})
        with self.assertRaises(KeyError):
            _ = self.tools["search"]

    def test_iteration(self):
        tools_list = list(self.tools)
        self.assertEqual(tools_list, [self.search_tool, self.calculator_tool, self.translator_tool])

    def test_contains(self):
        self.assertIn(self.search_tool, self.tools)
        self.assertIn(self.calculator_tool, self.tools)
        self.assertIn(self.translator_tool, self.tools)

        nonexistent_tool = self._create_mock_tool("nonexistent", "Nonexistent Tool")
        self.assertNotIn(nonexistent_tool, self.tools)

    def test_slicing(self):
        slice_result = self.tools[1:3]
        self.assertEqual(len(slice_result), 2)
        self.assertEqual(slice_result[0], self.calculator_tool)
        self.assertEqual(slice_result[1], self.translator_tool)

        self.assertIsInstance(slice_result, list)
        self.assertNotIsInstance(slice_result, ToolCollection)

    def test_getitem_with_tool_name_as_int(self):
        numeric_name_tool = self._create_mock_tool("123", "Numeric Name Tool")
        self.tools.append(numeric_name_tool)
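        # String keys go through name lookup, while integer keys stay positional,
        # so the int 123 is treated as an out-of-range index rather than a name.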

        self.assertEqual(self.tools["123"], numeric_name_tool)

        with self.assertRaises(IndexError):
            _ = self.tools[123]

    def test_filter_by_names(self):
        filtered = self.tools.filter_by_names(None)

        self.assertIsInstance(filtered, ToolCollection)
        self.assertEqual(len(filtered), 3)

        filtered = self.tools.filter_by_names(["search", "translator"])

        self.assertIsInstance(filtered, ToolCollection)
        self.assertEqual(len(filtered), 2)
        self.assertEqual(filtered[0], self.search_tool)
        self.assertEqual(filtered[1], self.translator_tool)
        self.assertEqual(filtered["search"], self.search_tool)
        self.assertEqual(filtered["translator"], self.translator_tool)

        filtered = self.tools.filter_by_names(["search", "nonexistent"])

        self.assertIsInstance(filtered, ToolCollection)
        self.assertEqual(len(filtered), 1)
        self.assertEqual(filtered[0], self.search_tool)

        filtered = self.tools.filter_by_names(["nonexistent1", "nonexistent2"])

        self.assertIsInstance(filtered, ToolCollection)
        self.assertEqual(len(filtered), 0)

        filtered = self.tools.filter_by_names([])

        self.assertIsInstance(filtered, ToolCollection)
        self.assertEqual(len(filtered), 0)

    def test_filter_where(self):
        filtered = self.tools.filter_where(lambda tool: tool.name.startswith("S"))

        self.assertIsInstance(filtered, ToolCollection)
        self.assertEqual(len(filtered), 1)
        self.assertEqual(filtered[0], self.search_tool)
        self.assertEqual(filtered["search"], self.search_tool)

        filtered = self.tools.filter_where(lambda tool: True)

        self.assertIsInstance(filtered, ToolCollection)
        self.assertEqual(len(filtered), 3)
        self.assertEqual(filtered[0], self.search_tool)
        self.assertEqual(filtered[1], self.calculator_tool)
        self.assertEqual(filtered[2], self.translator_tool)

        filtered = self.tools.filter_where(lambda tool: False)

        self.assertIsInstance(filtered, ToolCollection)
        self.assertEqual(len(filtered), 0)

        filtered = self.tools.filter_where(lambda tool: len(tool.name) > 8)

        self.assertIsInstance(filtered, ToolCollection)
        self.assertEqual(len(filtered), 2)
        self.assertEqual(filtered[0], self.calculator_tool)
        self.assertEqual(filtered[1], self.translator_tool)